(* Theory KD_Region_Tree *)

section ‹K-dimensional Region Trees›

theory KD_Region_Tree
imports
  "HOL-Library.NList"
  "HOL-Library.Tree" (* only for ‹height› *)
begin                                         

(* TODO: In Isabelle after 2023 - remove
   (presumably these facts are then available from HOL-Library.NList itself). *)

(* Lists of length ‹Suc n› over ‹A›: each ‹a ∈ A› prepended to every list of length ‹n›. *)
lemma nlists_Suc: "nlists (Suc n) A = (⋃a∈A. (#) a ` nlists n A)"
by(auto simp: set_eq_iff image_iff in_nlists_Suc_iff)

(* Over ‹UNIV›, membership in ‹nlists k› is only a length constraint. *)
lemma in_nlists_UNIV: "xs ∈ nlists k UNIV ⟷ length xs = k"
unfolding nlists_def by(auto)

(* Over a singleton alphabet there is exactly one list of each length. *)
lemma nlists_singleton: "nlists n {a} = {replicate n a}"
unfolding nlists_def by(auto simp: replicate_length_same dest!: subset_singletonD)

text ‹Generalizes quadtrees. Instead of having ‹2^k› direct children of a node (in ‹k› dimensions),
the children are arranged in a binary tree where each ‹Split› splits along one dimension.›

(* A region tree: ‹Box a› is a leaf whose whole subregion carries the value ‹a›;
   ‹Split l r› divides the current subregion into two halves, ‹l› and ‹r›. *)
datatype 'a kdt = Box 'a | Split "'a kdt" "'a kdt"

(* For quickcheck: *)
datatype_compat kdt

(* Boolean region trees describe sets of points: ‹Box True› is a fully covered subregion,
   ‹Box False› an empty one (see ‹points› below). *)
type_synonym kdtb = "bool kdt"

text ‹A ‹kdt› is most easily explained by showing how quad trees are represented:
‹Q t0 t1 t2 t3› becomes @{term "Split (Split t0' t1') (Split t2' t3')"}
where ‹ti'› is the representation of ‹ti›; ‹L a› becomes @{term "Box a"}.
In general, each level of an abstract ‹k›-dimensional tree subdivides space into ‹2^k›
subregions. This subdivision is represented by a ‹kdt› of depth at most ‹k›.
Further subdivisions of the subregions are seamlessly represented as the subtrees
at depth ‹k›. @{term "Box a"} represents a subregion entirely filled with ‹a›'s.
In contrast to quad trees, ‹Box› leaves can also occur halfway down the subdivision,
because two equal sibling boxes are merged into one.
For example, ‹Q (L a) (L a) (L b) (L c)› becomes @{term "Split (Box a) (Split (Box b) (Box c))"}.
›

(* Height of a ‹kdt›: a ‹Box› has height 0, a ‹Split› adds one level on top of its higher child. *)
instantiation kdt :: (type)height
begin

fun height_kdt :: "'a kdt ⇒ nat" where
"height (Box _) = 0" |
"height (Split l r) = max (height l) (height r) + 1"

instance ..

end

(* The trees of height 0 are exactly the single ‹Box›es. *)
lemma height_0_iff: "height t = 0 ⟷ (∃x. t = Box x)"
by(cases t)auto

(* ‹bits n›: all boolean vectors of length ‹n›. In ‹points›, a vector of length ‹k› is a path
   through ‹k› ‹Split› levels and so selects one of the ‹2^k› subregions. *)
definition bits :: "nat ⇒ bool list set" where
"bits n = nlists n UNIV"

(* Recursive characterization of ‹bits›, registered as code equations (for quickcheck). *)

lemma bits_0[code]: "bits 0 = {[]}"
by(simp add:bits_def)

lemma bits_Suc[code]:
  "bits (Suc n) = (let B = bits n in (#) True ` B ∪ (#) False ` B)"
by(simp_all add: bits_def nlists_Suc UN_bool_eq Let_def)


subsection ‹Subtree›

(* ‹subtree t bs›: descend into ‹t› along the path ‹bs›; ‹False› selects the left child and
   ‹True› the right one. A ‹Box› is uniform, so every path below it yields that same ‹Box›. *)
fun subtree :: "'a kdt ⇒ bool list ⇒ 'a kdt" where
"subtree t [] = t" |
"subtree (Box x) _ = Box x" |
"subtree (Split l r) (b#bs) = subtree (if b then r else l) bs"

lemma subtree_Box[simp]: "subtree (Box x) bs = Box x"
by(cases bs)auto

(* Descending ‹length bs› levels lowers the height by at least ‹length bs› (truncated at 0). *)
lemma height_subtree: "height (subtree t bs) ≤ height t - length bs"
by(induction t bs rule: subtree.induct) auto

(* One full level of a ‹k›-dimensional tree consists of ‹k› splits. *)
lemma height_subtree2: "⟦ height t ≤ k * (Suc n); length bs = k ⟧ ⟹ height (subtree t bs) ≤ k * n"
using height_subtree[of t bs] by auto

lemma subtree_Split_Box: "length bs ≠ 0 ⟹ subtree (Split (Box b) (Box b)) bs = Box b"
by(auto simp: neq_Nil_conv)

(* Surprisingly, points_SplitC is not needed:
lemma points_Split_Box: "Suc 0 ≤ k*n ⟹ points k n (Split (Box b) (Box b)) = points k n (Box b)"
proof(induction n)
  case (Suc n)
  from ‹Suc 0 ≤ k*Suc n› have "k > 0" using neq0_conv by fastforce
  with Suc show ?case by(simp add: subtree_Split_Box nlists2_simp)
qed simp

lemma points_SplitC: "height (Split l r) ≤ k*n ⟹ points k n (SplitC l r) = points k n (Split l r)"
by(induction l r rule: SplitC.induct)
  (simp_all add: points_Split_Box)
*)

subsection ‹Shifting a coordinate by a boolean vector›

(* ‹mv d bs ps›: add ‹d› to each coordinate whose selector bit is ‹False›; leave the coordinate
   unchanged where the bit is ‹True›. Being defined via ‹map2›, the result has length
   ‹min (length bs) (length ps)›. *)
definition mv :: "nat ⇒ bool list ⇒ nat list ⇒ nat list" where
"mv d = map2 (λb x. x + (if b then 0 else d))"

lemma map_zip1: "⟦ length xs = length ys; ∀p ∈ set(zip xs ys). f p = fst p ⟧ ⟹ map f (zip xs ys) = xs"
by (metis (no_types, lifting) map_eq_conv map_fst_zip)

(* The selector bits can be read back from a shifted vector by comparing with ‹n›. *)
lemma map_mv1: "⟦ ps ∈ nlists (length bs) {0..<n}; length ps = length bs ⟧ ⟹
  map (λi. i < n) (mv (n) bs ps) = bs"
by(fastforce simp: mv_def intro!: map_zip1 dest: set_zip_rightD nlistsE_set split: if_splits)

lemma map_zip2: "⟦ length xs = length ys; ∀p ∈ set(zip xs ys). f p = snd p ⟧ ⟹ map f (zip xs ys) = ys"
by (metis (no_types, lifting) map_eq_conv map_snd_zip)

(* The unshifted coordinates are recovered as residues modulo ‹2^n›. *)
lemma map_mv2: "⟦ ps ∈ nlists (length bs) {0..<2^n} ⟧ ⟹ map (λx. x mod 2^n) (mv (2^n) bs ps) = ps"
by(fastforce simp: mv_def dest: set_zip_rightD nlistsE_set intro!: map_zip2)

(* Conversely, a vector with entries below ‹2 * n› is the shift of its residues by its selector bits. *)
lemma mv_map_map: "set ps ⊆ {0..<2 * n} ⟹ mv (n) (map (λx. x < n) ps) (map (λx. x mod n) ps) = ps"
unfolding nlists_def mv_def
by(auto simp: map_eq_conv[where xs=ps and g=id,simplified] map2_map_map not_less le_iff_add)

(* Shifting a point of the cube with side ‹2^n› lands in the cube with side ‹2 * 2^n›. *)
lemma mv_in_nlists:
  "⟦ p ∈ nlists k {0..<2 ^ n}; bs ∈ bits k ⟧ ⟹ mv (2^n) bs p ∈ nlists k {0..<2 * 2 ^ n}"
unfolding mv_def nlists_def bits_def
by (fastforce dest: set_zip_rightD)

(* Every point of the doubled cube arises that way; the witnesses are its selector bits and residues. *)
lemma in_nlists2D: "xs ∈ nlists k {0..<2 * 2^n} ⟹ ∃bs ∈ bits k. xs ∈ mv (2^n) bs ` nlists k {0..<2^n}"
unfolding nlists_def bits_def image_def
apply(rule bexI[where x  = "map (λx. x < 2^n) xs"])
 apply(simp)
 apply(rule exI[where x = "map (λi. i mod 2^n) xs"])
 apply (auto simp add: mv_map_map)
done

(* The doubled cube is the union of the ‹2^k› shifted copies of the smaller cube. *)
lemma nlists2_simp: "nlists k {0..<2 * 2 ^ n} = (⋃bs ∈ bits k. mv (2^n) bs ` nlists k {0..<2 ^ n})"
by (auto simp: mv_in_nlists in_nlists2D)

(* Membership in a shifted set of points (a subset of the smaller cube), decided by splitting
   the point into residues and selector bits. *)
lemma in_mv_image: "⟦ ps ∈ nlists k {0..<2*2^n}; Ps ⊆ nlists k {0..<2^n}; bs ∈ bits k ⟧ ⟹
  ps ∈ mv (2^n) bs ` Ps ⟷ map (λx. x mod 2^n) ps ∈ Ps ∧ (bs = map (λi. i < 2^n) ps)"
by (auto simp: map_mv1 map_mv2 mv_map_map bits_def intro!: image_eqI)


subsection ‹Points in a tree›

(* ‹cube k n›: the ‹k›-dimensional grid cube with side length ‹2^n›. *)
fun cube :: "nat ⇒ nat ⇒ nat list set" where
"cube k n = nlists k {0..<2^n}"

(* ‹points k n t›: the grid points of ‹cube k n› in the region represented by ‹t›.
   ‹Box True› covers the whole cube and ‹Box False› covers nothing. At level ‹Suc n› the tree is
   cut into its subtrees at the paths ‹bs ∈ bits k›; each is interpreted at level ‹n› and shifted
   into place with ‹mv›. ‹points k 0 (Split _ _)› is left unspecified, so the lemmas assume a
   height bound ‹height t ≤ k * n›. *)
fun points :: "nat ⇒ nat ⇒ kdtb ⇒ nat list set" where
"points k n (Box b) = (if b then cube k n else {})" |
"points k (Suc n) t = (⋃bs ∈ bits k. mv (2^n) bs ` points k n (subtree t bs))"

(* The recursive equation, stated for every ‹t› (including ‹Box›). *)
lemma points_Suc: "points k (Suc n) t = (⋃bs ∈ bits k. mv (2^n) bs ` points k n (subtree t bs))"
by(cases t) (simp_all add: nlists2_simp)

(* Under the height bound, all points lie inside the cube. *)
lemma points_subset: "height t ≤ k*n ⟹ points k n t ⊆ nlists k {0..<2^n}"
proof(induction k n t rule: points.induct)
  case (2 k n l r)
  have "⋀bs. bs ∈ bits k ⟹ height (subtree (Split l r) bs) ≤ k*n"
    unfolding bits_def using "2.prems" height_subtree2 in_nlists_UNIV by blast
  with "2.IH" show ?case
    by(auto intro: mv_in_nlists dest: subsetD)
qed auto


subsection ‹Compression›

text ‹Compressing Split:›

(* ‹SplitC l r›: smart constructor for ‹Split› that merges two equal ‹Box›es into one. *)
fun SplitC :: "'a kdt ⇒ 'a kdt ⇒ 'a kdt" where
"SplitC (Box b1) (Box b2) = (if b1=b2 then Box b1 else Split (Box b1) (Box b2))" |
"SplitC t1 t2 = Split t1 t2"

(* A tree is compressed if no ‹Split› anywhere has two equal ‹Box›es as children. *)
fun compressed :: "'a kdt ⇒ bool" where
"compressed (Box _) = True" |
"compressed (Split l r) = (compressed l ∧ compressed r ∧ ¬(∃b. l = Box b ∧ r = Box b))"

lemma compressedI: "⟦ compressed l; compressed r ⟧ ⟹ compressed (SplitC l r)"
by(induction l r rule: SplitC.induct) auto

(* Below the root, ‹SplitC› and ‹Split› are indistinguishable. *)
lemma subtree_SplitC:
  "1 ≤ length bs ⟹ subtree (SplitC l r) bs = subtree (Split l r) bs"
by(induction l r rule: SplitC.induct)
  (simp_all add: subtree_Split_Box Suc_le_eq)

(* ‹SplitC› adds at most one level of height. *)
lemma height_SplitC: "height(SplitC l r) ≤ Suc (max (height l) (height r))"
by(cases "(l,r)" rule: SplitC.cases)(auto)

lemma height_SplitC2: "⟦ height l ≤ n; height r ≤ n ⟧ ⟹ height(SplitC l r) ≤ Suc n"
using height_SplitC[of l r] by simp

subsection ‹Extracting a point from a tree›

text ‹Also the abstraction function.›

(* ‹get n t ps›: the value that ‹t› assigns to grid point ‹ps› of a cube with side ‹2^n›.
   At each level, the path ‹map (λi. i < 2^n) ps› selects the subtree, and the residues
   ‹i mod 2^n› are the coordinates inside it. A ‹Box› answers directly;
   ‹get 0 (Split _ _) _› is left unspecified. *)
fun get :: "nat ⇒ 'a kdt ⇒ nat list ⇒ 'a"  where
"get _ (Box b) _ = b" |
"get (Suc n) t ps = get n (subtree t (map (λi. i < 2^n) ps)) (map (λi. i mod 2^n) ps)"

(* The recursive equation, stated for every ‹t› (including ‹Box›). *)
lemma get_Suc: "get (Suc n) t ps =
  get n (subtree t (map (λi. i < 2 ^ n) ps)) (map (λi. i mod 2 ^ n) ps)"
by(cases t)auto

(* On boolean trees within the height bound, ‹get› is the characteristic function of ‹points›. *)
lemma points_get: "⟦ height t ≤ k*n; ps ∈ nlists k {0..<2^n} ⟧ ⟹
  get n t ps = (ps ∈ points k n t)"
proof(induction n arbitrary: k t ps)
  case 0
  then show ?case by(clarsimp simp add: height_0_iff)
next
  case (Suc n)
  show ?case
  proof (cases t)
    case Box
    thus ?thesis using Suc.prems(2) by(simp)
  next
    case (Split l r)
    obtain k0 where "k = Suc k0" using Suc.prems(1) Split
      by(cases k) auto
    hence "ps ≠ []"
      using Suc.prems(2) by (auto simp: in_nlists_Suc_iff)
    then show ?thesis using Suc.prems Split Suc.IH[OF height_subtree2[OF Suc.prems(1)]] in_nlists2D
      by(simp add: height_subtree2 in_mv_image points_subset bits_def)
  qed
qed


subsection ‹Modifying a point in a tree›

(* ‹modify f bs t›: replace the subtree of ‹t› at path ‹bs› by ‹f› applied to it.
   A ‹Box a› met on the way is treated like ‹Split (Box a) (Box a)›. Every node rebuilt along
   the path goes through ‹SplitC›, so equal ‹Box› siblings are merged again. *)
fun modify :: "('a kdt ⇒ 'a kdt) ⇒ bool list ⇒ 'a kdt ⇒ 'a kdt" where
"modify f [] t = f t" |
"modify f (b # bs) (Split l r) = (if b then SplitC l (modify f bs r) else SplitC (modify f bs l) r)" |
"modify f (b # bs) (Box a) =
  (let t = modify f bs (Box a) in if b then SplitC (Box a) t else SplitC t (Box a))"

(* ‹put ps a n t›: set the value at grid point ‹ps› (cube side ‹2^n›) to ‹a›.
   This mirrors ‹get›: ‹map (λi. i < 2^n) ps› is the path for one level, and the residues
   ‹i mod 2^n› locate the point inside the selected subtree.
   ‹put ps a 0 (Split _ _)› is left unspecified. *)
fun put :: "nat list ⇒ 'a ⇒ nat ⇒ 'a kdt ⇒ 'a kdt" where
"put ps a 0 (Box _) = Box a" |
"put ps a (Suc n) t = modify (put (map (λi. i mod 2^n) ps) a n) (map (λi. i < 2^n) ps) t"


(* If ‹f› respects the height bound ‹nk›, then modifying at a path of length ‹k› respects ‹k + nk›. *)
lemma height_modify: "⟦ ⋀t. height t ≤ nk ⟹ height (f t) ≤ nk;
     height t ≤ k + nk; length bs = k ⟧
     ⟹ height (modify f bs t) ≤ k + nk"
apply(induction f bs t arbitrary: k rule: modify.induct)
by (auto simp: height_SplitC2 Let_def)

(* ‹put› preserves the height bound of a ‹length ps›-dimensional tree of level ‹n›. *)
lemma height_put: "height t ≤ n * length ps ⟹ height (put ps a n t) ≤ n * length ps"
proof(induction ps a n t rule: put.induct)
  case 2
  then show ?case by (auto simp: height_modify)
qed auto

(* Among paths of the same length as ‹bs›, only the subtree at ‹bs› changes (to ‹f› of it). *)
lemma subtree_modify: "⟦ length bs' = length bs ⟧ ⟹
     subtree (modify f bs t) bs' = (if bs' = bs then f(subtree t bs) else subtree t bs')"
apply(induction f bs t arbitrary: bs' rule: modify.induct)
apply(auto simp add: length_Suc_conv Let_def subtree_SplitC split: if_splits)
done

(* Two numbers in the upper half ‹{n..<2 * n}› with the same residue modulo ‹n› are equal. *)
lemma mod_eq1: "⟦ y < 2 * n; ya < 2 * n; ¬ ya < n; ¬ y < n; ya mod n = y mod n ⟧
        ⟹ ya = (y::nat)"
by(simp add: mod_if mult_2 split: if_splits)

(* Distinct points of the doubled cube with identical selector bits have distinct residues. *)
lemma nlist_eq_mod: "⟦ ps ∈ nlists k {0..<(2::nat) * 2 ^ n}; ps' ∈ nlists k {0..<2 * 2 ^ n};
     map (λi. i < 2 ^ n) ps' = map (λi. i < 2 ^ n) ps; ps' ≠ ps ⟧ ⟹
      map (λi. i mod 2 ^ n) ps' ≠ map (λi. i mod 2 ^ n) ps"
apply(induction k arbitrary: ps ps')
 apply simp
apply (fastforce simp: in_nlists_Suc_iff mod_eq1)
done

(* Reading after writing: ‹put› changes exactly the value at ‹ps›. *)
lemma get_put: "⟦ height t ≤ k*n; ps ∈ cube k n; ps' ∈ cube k n ⟧ ⟹
  get n (put ps a n t) ps' = (if ps' = ps then a else get n t ps')"
proof(induction ps a n t arbitrary: ps' rule: put.induct)
  case 1
  then show ?case by (simp add: nlists_singleton)
next
  case 2
  thus ?case using in_nlists2D
    by(auto simp add: subtree_modify get_Suc height_subtree2 nlist_eq_mod in_mv_image)
qed auto

(* ‹modify› preserves compression, provided ‹f› yields a compressed tree at the target. *)
lemma compressed_modify: "⟦ compressed t; compressed (f (subtree t bs)) ⟧ ⟹ compressed (modify f bs t)"
by(induction f bs t rule: modify.induct) (auto simp: compressedI Let_def)

(* Subtrees of compressed trees are compressed. *)
lemma compressed_subtree: "compressed t ⟹ compressed (subtree t bs)"
by(induction t bs rule: subtree.induct) auto

(* ‹put› preserves compression (under the height bound). *)
lemma compressed_put:
  "⟦ height t ≤ k*n; k = length ps; compressed t ⟧ ⟹ compressed (put ps a n t)"
proof(induction ps a n t rule: put.induct)
  case 1
  then show ?case by (simp)
next
  case 2
  thus ?case by (simp add: compressed_modify compressed_subtree height_subtree2)
qed auto


subsection ‹Union›

(* Union of two regions given as boolean trees. A ‹Box True› absorbs the other argument and a
   ‹Box False› is neutral. Two ‹Split›s are merged child by child and rebuilt with ‹SplitC›. *)
fun union :: "kdtb ⇒ kdtb ⇒ kdtb" where
"union (Box b) t = (if b then Box True else t)" |
"union t (Box b) = (if b then Box True else t)" |
"union (Split l1 r1) (Split l2 r2) = SplitC (union l1 l2) (union r1 r2)"

(* The second equation for every ‹t›; ‹fun› only states it for non-‹Box› first arguments. *)
lemma union_Box2: "union t (Box b) = (if b then Box True else t)"
by(cases t) auto

(* ‹union› commutes with taking subtrees. *)
lemma subtree_union: "subtree (union t1 t2) bs = union (subtree t1 bs) (subtree t2 bs)"
proof(induction t1 t2 arbitrary: bs rule: union.induct)
  case 2 thus ?case by(auto simp: union_Box2)
next
  case 3 thus ?case by(cases bs) (auto simp: subtree_SplitC)
qed auto

(* Under the height bound, ‹union› is set union of the represented point sets ... *)
lemma points_union:
  "⟦ max (height t1) (height t2) ≤ k*n ⟧ ⟹
  points k n (union t1 t2) = points k n t1 ∪ points k n t2"
proof(induction n arbitrary: t1 t2)
  case 0 thus ?case by(clarsimp simp add: height_0_iff)
next
  case (Suc n)
  have "height t1 ≤ k * Suc n" "height t2 ≤ k * Suc n"
    using Suc.prems by auto
  from height_subtree2[OF this(1)] height_subtree2[OF this(2)] show ?case
    by(auto simp: Suc.IH subtree_union points_Suc bits_def)
qed

(* ... and pointwise disjunction under ‹get›. *)
lemma get_union:
  "⟦ max (height t1) (height t2) ≤ length ps * n ⟧ ⟹
  get n (union t1 t2) ps = (get n t1 ps ∨ get n t2 ps)"
proof(induction n arbitrary: t1 t2 ps)
  case 0 thus ?case by(clarsimp simp add: height_0_iff)
next
  case (Suc n)
  have "height t1 ≤ length ps * Suc n" "height t2 ≤ length ps * Suc n"
    using Suc.prems(1) by auto
  from height_subtree2[OF this(1)] height_subtree2[OF this(2)] show ?case
    by(simp add: Suc.IH subtree_union get_Suc)
qed

(* ‹union› never increases the height. *)
lemma height_union: "height (union t1 t2) ≤ max (height t1) (height t2)"
by(induction t1 t2 rule: union.induct) (auto simp: height_SplitC2)

(* ‹union› preserves compression. *)
lemma compressed_union: "compressed t1 ⟹ compressed t2 ⟹ compressed(union t1 t2)"
by(induction t1 t2 rule: union.induct) (simp_all add: compressedI)

(*unused_thms*)

end