Theory Build

(*
  File:     Build.thy
  Author:   Martin Rau, TU München
*)

section ‹Building a balanced ‹k›-d Tree from a List of Points›

theory Build
imports
  KD_Tree
  Median_Of_Medians_Selection.Median_Of_Medians_Selection
begin

text ‹
  Build a balanced ‹k›-d Tree by recursively partitioning the points into two lists.
  The partitioning criterion is the median at a particular axis ‹k›.
  The left list will contain all points ‹p› with @{term "p$k ≤ median"}.
  The right list will contain all points with @{term "median < p$k"}.
  The lengths of the left and right list differ by at most one.
  The axis ‹k› will be the axis with the widest spread.
›

subsection "Auxiliary Lemmas"

(*
  For a sorted list without duplicates, the elements that are ≤ its n-th element are
  exactly its first n + 1 elements (stated as an equality of multisets).
*)
lemma length_filter_mset_sorted_nth:
  assumes "distinct xs" "n < length xs" "sorted xs"
  shows "{# x ∈# mset xs. x ≤ xs ! n #} = mset (take (n + 1) xs)"
  using assms
proof (induction xs arbitrary: n rule: list.induct)
  case (Cons x xs)
  thus ?case
  proof (cases n)
    case 0
    (* (x # xs) ! 0 = x; by sortedness and distinctness every element of xs is > x. *)
    thus ?thesis
      using Cons.prems(1,3) filter_mset_is_empty_iff by fastforce
  next
    case (Suc n')
    (* x is kept by the filter (sortedness), and the rest follows from the IH on xs. *)
    thus ?thesis
      using Cons by simp
  qed
qed auto

(*
  In a list without duplicates, exactly n + 1 elements are ≤ the n-th smallest element
  (sort xs ! n). Reduced to length_filter_mset_sorted_nth applied to sort xs.
*)
lemma length_filter_sort_nth:
  assumes "distinct xs" "n < length xs"
  shows "length (filter (λx. x ≤ sort xs ! n) xs) = n + 1"
proof -
  (* Filtering commutes with sorting, so the count may be taken in sort xs instead. *)
  have "length (filter (λx. x ≤ sort xs ! n) xs) = length (filter (λx. x ≤ sort xs ! n) (sort xs))"
    by (simp add: filter_sort)
  also have "... = size (mset (filter (λx. x ≤ sort xs ! n) (sort xs)))"
    using size_mset by metis
  also have "... = size ({# x ∈# mset (sort xs). x ≤ sort xs ! n #})"
    using mset_filter by simp
  also have "... = size (mset (take (n + 1) (sort xs)))"
    using length_filter_mset_sorted_nth assms sorted_sort distinct_sort length_sort by metis
  finally show ?thesis
    using assms(2) by auto
qed


subsection ‹Widest Spread Axis›

(*
  Executable spread of the points ps along axis k: the largest k-th coordinate minus the
  smallest one. Both extremes are computed by folding max / min over the coordinates of
  the tail, seeded with the coordinate of the head. The empty list has spread 0.
*)
definition calc_spread :: "('k::finite) ⇒ 'k point list ⇒ real" where
  "calc_spread k ps = (case ps of [] ⇒ 0 | ps ⇒
    let ks = map (λp. p$k) (tl ps) in
    fold max ks ((hd ps)$k) - fold min ks ((hd ps)$k)
  )"

(*
  Select, among the axes ks, one with maximal calc_spread over ps and return it together
  with its spread. The head axis k only wins if its spread is strictly larger than the best
  spread among the remaining axes, so ties go to the axis that appears later in ks.
  Undefined for an empty list of axes.
*)
fun widest_spread :: "('k::finite) list ⇒ 'k point list ⇒ 'k × real" where
  "widest_spread [] _ = undefined"
| "widest_spread [k] ps = (k, calc_spread k ps)"
| "widest_spread (k # ks) ps = (
    let (k', s') = widest_spread ks ps in
    let s = calc_spread k ps in
    if s ≤ s' then (k', s') else (k, s)
  )"

(*
  calc_spread computes spread (defined in KD_Tree) on the set of the given points.
  The proof unfolds spread_def and calc_spread_def and uses Max.set_eq_fold /
  Min.set_eq_fold to identify the max / min folds with Max / Min of the coordinates.
*)
lemma calc_spread_spec:
  "calc_spread k ps = spread k (set ps)"
  using Max.set_eq_fold[of "(hd ps)$k"] Min.set_eq_fold[of "(hd ps)$k"]
  by (auto simp: Let_def spread_def calc_spread_def split: list.splits, metis set_map)

(* For a non-empty list of axes, the spread returned by widest_spread is the calc_spread of the returned axis. *)
lemma widest_spread_calc_spread:
  "ks ≠ [] ⟹ (k, s) = widest_spread ks ps ⟹ s = calc_spread k ps"
  by (induction ks ps rule: widest_spread.induct) (auto simp: Let_def split: prod.splits if_splits)

(*
  Adding one candidate axis k' to a set K with widest spread axis k
  (widest_spread_axis is defined in KD_Tree):
  (1) k stays a widest spread axis if k' does not spread more than k;
  (2) k' becomes one if it spreads at least as much as k.
*)
lemma widest_spread_axis_Un:
  shows "widest_spread_axis k K P ⟹ spread k' P ≤ spread k P ⟹ widest_spread_axis k (K ∪ { k' }) P"
    and "widest_spread_axis k K P ⟹ spread k P ≤ spread k' P ⟹ widest_spread_axis k' (K ∪ { k' }) P"
  unfolding widest_spread_axis_def by auto

(*
  Correctness of widest_spread: the returned axis is a widest spread axis among set ks
  for the point set set ps. Follows the recursion of widest_spread; the recursive case
  compares the head axis with the recursive result via widest_spread_axis_Un.
*)
lemma widest_spread_spec:
  "(k, s) = widest_spread ks ps ⟹ widest_spread_axis k (set ks) (set ps)"
proof (induction ks ps arbitrary: k s rule: widest_spread.induct)
  case (3 k0 k1 ks ps)
  (* K' with spread S' is the best axis among the remaining axes k1 # ks. *)
  obtain K' S' where K'_def: "(K', S') = widest_spread (k1 # ks) ps"
    by (metis surj_pair)
  hence IH: "widest_spread_axis K' (set (k1 # ks)) (set ps)"
    using "3.IH" by blast
  hence 0: "S' = spread K' (set ps)"
    using K'_def widest_spread_calc_spread calc_spread_spec by blast
  (* S is the spread of the head axis k0. *)
  define S where "S = calc_spread k0 ps"
  hence 1: "S = spread k0 (set ps)"
    using calc_spread_spec by blast
  show ?case
  proof (cases "S ≤ S'")
    case True
    (* k0 does not spread more, so widest_spread keeps K'. *)
    hence "widest_spread_axis K' (set (k0 # k1 # ks)) (set ps)"
      using 0 1 widest_spread_axis_Un(1)[OF IH, of k0]  by auto
    thus ?thesis
      using True K'_def S_def "3.prems" by (auto split: prod.splits)
  next
    case False
    (* k0 spreads strictly more, so widest_spread returns k0. *)
    hence "widest_spread_axis k0 (set (k0 # k1 # ks)) (set ps)"
      using 0 1 widest_spread_axis_Un(2)[OF IH, of k0] "3.prems"(1) by auto
    thus ?thesis
      using False K'_def S_def "3.prems" by (auto split: prod.splits)
  qed
qed (auto simp: widest_spread_axis_def)


subsection ‹Fast Axis Median›

(*
  Median of the k-th coordinates of ps, computed with fast_select at index
  (length ps - 1) div 2 of the sorted coordinates, i.e. the lower median when
  length ps is even.
*)
definition axis_median :: "('k::finite) ⇒ 'k point list ⇒ real" where
  "axis_median k ps = (let n = (length ps - 1) div 2 in fast_select n (map (λp. p$k) ps))"

(*
  If all coordinates are pairwise distinct, exactly (length ps - 1) div 2 + 1 points
  lie at or below the axis median along k.
*)
lemma length_filter_le_axis_median:
  assumes "0 < length ps" "∀k. distinct (map (λp. p$k) ps)"
  shows "length (filter (λp. p$k ≤ axis_median k ps) ps) = (length ps - 1) div 2 + 1"
proof -
  let ?n = "(length ps - 1) div 2"
  let ?ps = "map (λp. p$k) ps"
  let ?m = "fast_select ?n ?ps"
  have 0: "?n < length ?ps"
    using assms(1) by (auto, linarith)
  have 1: "distinct ?ps"
    using assms(2) by blast
  (* fast_select agrees with select, i.e. with the ?n-th element of the sorted coordinates. *)
  have "?m = select ?n ?ps"
    using fast_select_correct[OF 0] by blast
  hence "length (filter (λp. p$k ≤ axis_median k ps) ps) =
        length (filter (λp. p$k ≤ sort ?ps ! ?n) ps)"
    unfolding axis_median_def by (auto simp add: Let_def select_def simp del: fast_select.simps)
  (* Move from filtering the points to filtering their k-th coordinates. *)
  also have "... = length (filter (λv. v ≤ sort ?ps ! ?n) ?ps)"
    by (induction ps) (auto, metis comp_apply)
  also have "... = ?n + 1"
    using length_filter_sort_nth[OF 1 0] by blast
  finally show ?thesis .
qed

(*
  Partition ps at the axis median m along axis k: l contains the points with p$k ≤ m,
  r the remaining ones. The result is the triple (l, m, r).
*)
definition partition_by_median :: "('k::finite) ⇒ 'k point list ⇒ 'k point list × real × 'k point list" where
  "partition_by_median k ps = (
     let m = axis_median k ps in
     let (l, r) = partition (λp. p$k ≤ m) ps in
     (l, m, r)
  )"

(* partition_by_median neither loses nor invents points. *)
lemma set_partition_by_median:
  "(l, m, r) = partition_by_median k ps ⟹ set ps = set l ∪ set r"
  unfolding partition_by_median_def by (auto simp: Let_def)

(* Points in l lie at or below the median along k; points in r do not. *)
lemma filter_partition_by_median:
  assumes "(l, m, r) = partition_by_median k ps"
  shows "∀p ∈ set l. p$k ≤ m"
    and "∀p ∈ set r. ¬p$k ≤ m"
  using assms unfolding partition_by_median_def by (auto simp: Let_def)

(* Together the two halves have as many points as ps. *)
lemma sum_length_partition_by_median:
  assumes "(l, m, r) = partition_by_median k ps"
  shows "length ps = length l + length r"
  using assms sum_length_filter_compl[of "(λp. p $ k ≤ axis_median k ps)"]
  unfolding partition_by_median_def by (simp add: Let_def o_def)

(* With pairwise distinct coordinates, l receives exactly (length ps - 1) div 2 + 1 points. *)
lemma length_l_partition_by_median:
  assumes "0 < length ps" "∀k. distinct (map (λp. p$k) ps)" "(l, m, r) = partition_by_median k ps"
  shows "length l = (length ps - 1) div 2 + 1"
  using assms unfolding partition_by_median_def by (auto simp: Let_def length_filter_le_axis_median)

(*
  Consequences for non-empty ps: l has at most one point more than r, l is non-empty,
  and r is strictly shorter than ps.
*)
corollary lengths_partition_by_median_1:
  assumes "0 < length ps"  "∀k. distinct (map (λp. p$k) ps)" "(l, m, r) = partition_by_median k ps"
  shows "length l - length r ≤ 1"
    and "length r ≤ length l"
    and "0 < length l"
    and "length r < length ps"
  using length_l_partition_by_median[OF assms] sum_length_partition_by_median[OF assms(3)] by auto

(* With at least two points, r is non-empty as well, hence l is strictly shorter than ps. *)
corollary lengths_partition_by_median_2:
  assumes "1 < length ps" "∀k. distinct (map (λp. p$k) ps)" "(l, m, r) = partition_by_median k ps"
  shows "0 < length r"
    and "length l < length ps"
proof -
  have *: "0 < length ps"
    using assms(1) by auto
  show "0 < length r" "length l < length ps"
    using length_l_partition_by_median[OF * assms(2,3)] sum_length_partition_by_median[OF assms(3)]
    using assms(1) by linarith+
qed

(*
  Bundle of the partition length facts, indexed as follows:
  (1) sum_length_partition_by_median,
  (2) length_l_partition_by_median,
  (3)-(6) lengths_partition_by_median_1: length l - length r ≤ 1, length r ≤ length l,
          0 < length l, length r < length ps,
  (7)-(8) lengths_partition_by_median_2: 0 < length r, length l < length ps.
*)
lemmas length_partition_by_median =
  sum_length_partition_by_median length_l_partition_by_median
  lengths_partition_by_median_1 lengths_partition_by_median_2


subsection ‹Building the Tree›

(*
  Build a k-d tree from a list of points, choosing split axes among ks.

  - A single point becomes a leaf, and the empty list is mapped to undefined.
  - A longer list is split by partition_by_median along the axis returned by
    widest_spread, and both halves are built recursively.

  build is declared as a partial function (domintros). Its termination is proved
  separately in build_termination, under the assumption of pairwise distinct
  coordinates.
*)
function (domintros, sequential) build :: "('k::finite) list ⇒ 'k point list ⇒ 'k kdt" where
  "build _ [] = undefined"
| "build _ [p] = Leaf p"
| "build ks ps = (
    let (k, _) = widest_spread ks ps in
    let (l, m, r) = partition_by_median k ps in
    Node k m (build ks l) (build ks r)
  )"
  by pat_completeness auto

(*
  Domain introduction for the recursive case of build, on lists with at least two points.
  If build is defined on both halves produced by partition_by_median, along the axis
  chosen by widest_spread, then it is defined on the whole list.
*)
lemma build_domintros3:
  assumes "(k, s) = widest_spread ks (x # y # zs)" "(l, m, r) = partition_by_median k (x # y # zs)"
  assumes "build_dom (ks, l)" "build_dom (ks, r)"
  shows "build_dom (ks, x # y # zs)"
proof -
  (* build.domintros(3) quantifies over all possible results of widest_spread and
     partition_by_median; these coincide with the fixed k, l, r from the assumptions. *)
  {
    fix k s l m r
    assume "(k, s) = widest_spread ks (x # y # zs)" "(l, m, r) = partition_by_median k (x # y # zs)"
    hence "build_dom (ks, l)" "build_dom (ks, r)"
      using assms by (metis Pair_inject)+
  }
  thus ?thesis
    by (simp add: build.domintros(3))
qed

(*
  build terminates on every list of points whose coordinates are pairwise distinct along
  every axis. Strong induction on the length: for at least two points, both halves are
  strictly shorter than the input and inherit the distinctness assumption.
*)
lemma build_termination:
  assumes "∀k. distinct (map (λp. p$k) ps)"
  shows "build_dom (ks, ps)"
  using assms
proof (induction ps rule: length_induct)
  case (1 xs)
  consider (A) "xs = []" | (B) "∃x. xs = [x]" | (C) "∃x y zs. xs = x # y # zs"
    by (induction xs rule: induct_list012) auto
  then show ?case
  proof cases
    case C
    then obtain x y zs where xyzs_def: "xs = x # y # zs"
      by blast
    obtain k s where ks_def: "(k, s) = widest_spread ks xs"
      by (metis surj_pair)
    obtain l m r where lmr_def: "(l, m, r) = partition_by_median k xs"
      by (metis prod_cases3)
    note defs = xyzs_def ks_def lmr_def
    (* l and r are filters of xs, so they keep pairwise distinct coordinates. *)
    have "∀k. distinct (map (λp. p $ k) l)" "∀k. distinct (map (λp. p $ k) r)"
      using lmr_def unfolding partition_by_median_def
      by (auto simp: Let_def "1.prems" distinct_map_filter)
    (* Facts (8) and (6) of length_partition_by_median: both halves are strictly shorter. *)
    moreover have "length l < length xs" "length r < length xs"
      using length_partition_by_median(8)[OF _ "1.prems"] length_partition_by_median(6)[OF _ "1.prems"]
      using defs by auto
    ultimately have "build_dom (ks, l)" "build_dom (ks, r)"
      using "1.IH" by blast+
    thus ?thesis
      using build_domintros3 defs by blast
  qed (auto intro: build.domintros)
qed

(* On a single point build is defined and returns a leaf; no distinctness premise is needed. *)
lemma build_psimp_1:
  "ps = [p] ⟹ build k ps = Leaf p"
  by (simp add: build.domintros(2) build.psimps(2))

(*
  Unfolding equation for build on lists with at least two points. It requires build to be
  defined (build_dom) on both halves produced by partition_by_median.
*)
lemma build_psimp_2:
  assumes "(k, s) = widest_spread ks (x # y # zs)" "(l, m, r) = partition_by_median k (x # y # zs)"
  assumes "build_dom (ks, l)" "build_dom (ks, r)"
  shows "build ks (x # y # zs) = Node k m (build ks l) (build ks r)"
proof -
  have 0: "build_dom (ks, x # y # zs)"
    using assms build_domintros3 by blast
  thus ?thesis
    using build.psimps(3)[OF 0] assms(1,2) by (auto split: prod.splits)
qed

(* A list with more than one element starts with two elements. *)
lemma length_xs_gt_1:
  "1 < length xs ⟹ ∃x y ys. xs = x # y # ys"
  by (cases xs, auto simp: neq_Nil_conv)

(*
  Variant of build_psimp_2 stated with the premise 1 < length ps instead of an explicit
  x # y # zs pattern. It is obtained from build_psimp_2 via length_xs_gt_1.
*)
lemma build_psimp_3:
  assumes "1 < length ps" "(k, s) = widest_spread ks ps" "(l, m, r) = partition_by_median k ps"
  assumes "build_dom (ks, l)" "build_dom (ks, r)"
  shows "build ks ps = Node k m (build ks l) (build ks r)"
  using build_psimp_2 length_xs_gt_1 assms by blast

(* Simp rules for build. build_psimp_3 is conditional: its length, definition and build_dom premises must still be discharged. *)
lemmas build_psimps[simp] = build_psimp_1 build_psimp_3


subsection ‹Main Theorems›

(* build preserves the set of points (for non-empty input with pairwise distinct coordinates). *)
theorem set_build:
  "0 < length ps ⟹ ∀k. distinct (map (λp. p$k) ps) ⟹ set ps = set_kdt (build ks ps)"
proof (induction ps rule: length_induct)
  case (1 ps)
  show ?case
  proof (cases "1 < length ps")
    case True
    obtain k s where ks_def: "(k, s) = widest_spread ks ps"
      by (metis surj_pair)
    obtain l m r where lmr_def: "(l, m, r) = partition_by_median k ps"
      by (metis prod_cases3)
    (* Both halves inherit pairwise distinct coordinates ... *)
    have D: "∀k. distinct (map (λp. p$k) l)" "∀k. distinct (map (λp. p$k) r)"
      using lmr_def unfolding partition_by_median_def
      by (auto simp: "1.prems"(2) Let_def distinct_map_filter)
    (* ... and are non-empty and strictly shorter than ps, so the IH applies to them. *)
    moreover have "length l < length ps" "0 < length l"
                  "length r < length ps" "0 < length r"
      using length_partition_by_median(8)[OF True "1.prems"(2)]
            length_partition_by_median(5)[OF "1.prems"(1) "1.prems"(2)]
            length_partition_by_median(6)[OF "1.prems"(1) "1.prems"(2)]
            length_partition_by_median(7)[OF True "1.prems"(2)]
            lmr_def by blast+
    ultimately have "set l = set_kdt (build ks l)" "set r = set_kdt (build ks r)"
      using "1.IH" by blast+
    moreover have "set ps = set l ∪ set r"
      using lmr_def unfolding partition_by_median_def by (auto simp: Let_def)
    (* Unfold build one step; the recursive calls are defined by build_termination and D. *)
    moreover have "build ks ps = Node k m (build ks l) (build ks r)"
      using build_psimp_3[OF True ks_def lmr_def] build_termination D by blast
    ultimately show ?thesis
      by simp
  next
    case False
    (* 0 < length ps and not 1 < length ps: ps is a single point. *)
    thus ?thesis
      using "1.prems" by (cases ps) auto
  qed
qed

(*
  With every axis available for splitting (set ks = UNIV), build produces a tree
  satisfying invar (defined in KD_Tree). At the root, the proof establishes three facts:
  - l lies at or below the median along k;
  - r lies strictly above the median along k;
  - k is a widest spread axis of all points.
  These are presumably what invar demands of a node; confirm in KD_Tree.
*)
theorem invar_build:
  "0 < length ps ⟹ ∀k. distinct (map (λp. p$k) ps) ⟹ set ks = UNIV ⟹ invar (build ks ps)"
proof (induction ps rule: length_induct)
  case (1 ps)
  show ?case
  proof (cases "1 < length ps")
    case True
    obtain k s where ks_def: "(k, s) = widest_spread ks ps"
      by (metis surj_pair)
    obtain l m r where lmr_def: "(l, m, r) = partition_by_median k ps"
      by (metis prod_cases3)
    have D: "∀k. distinct (map (λp. p$k) l)" "∀k. distinct (map (λp. p$k) r)"
      using lmr_def unfolding partition_by_median_def
      by (auto simp: "1.prems"(2) Let_def distinct_map_filter)
    moreover have "length l < length ps" "0 < length l"
                  "length r < length ps" "0 < length r"
      using length_partition_by_median(8)[OF True "1.prems"(2)]
            length_partition_by_median(5)[OF "1.prems"(1) "1.prems"(2)]
            length_partition_by_median(6)[OF "1.prems"(1) "1.prems"(2)]
            length_partition_by_median(7)[OF True "1.prems"(2)]
            lmr_def by blast+
    ultimately have "invar (build ks l)" "invar (build ks r)"
      using "1.IH" "1.prems"(3) by blast+
    (* Position of both halves relative to the median along k. *)
    moreover have "∀p ∈ set l. p$k ≤ m" "∀p ∈ set r. m < p$k"
      using filter_partition_by_median(1)[OF lmr_def]
            filter_partition_by_median(2)[OF lmr_def] by auto
    (* k is a widest spread axis over all axes, because set ks = UNIV. *)
    moreover have "widest_spread_axis k UNIV (set l ∪ set r)"
      using widest_spread_spec[OF ks_def] "1.prems"(3) set_partition_by_median[OF lmr_def] by simp
    moreover have "build ks ps = Node k m (build ks l) (build ks r)"
      using build_psimp_3[OF True ks_def lmr_def] build_termination D by blast
    ultimately show ?thesis
      using set_build[OF ‹0 < length l› D(1)] set_build[OF ‹0 < length r› D(2)] by simp
  next
    case False
    thus ?thesis
      using "1.prems" by (cases ps) auto
  qed
qed

(* size_kdt of the tree built from ps equals length ps (non-empty input, pairwise distinct coordinates). *)
theorem size_build:
  "0 < length ps ⟹ ∀k. distinct (map (λp. p$k) ps) ⟹ size_kdt (build ks ps) = length ps"
proof (induction ps rule: length_induct)
  case (1 ps)
  show ?case
  proof (cases "1 < length ps")
    case True
    obtain k s where ks_def: "(k, s) = widest_spread ks ps"
      by (metis surj_pair)
    obtain l m r where lmr_def: "(l, m, r) = partition_by_median k ps"
      by (metis prod_cases3)
    have D: "∀k. distinct (map (λp. p$k) l)" "∀k. distinct (map (λp. p$k) r)"
      using lmr_def unfolding partition_by_median_def
      by (auto simp: "1.prems"(2) Let_def distinct_map_filter)
    moreover have "length l < length ps" "0 < length l"
                  "length r < length ps" "0 < length r"
      using length_partition_by_median(8)[OF True "1.prems"(2)]
            length_partition_by_median(5)[OF "1.prems"(1) "1.prems"(2)]
            length_partition_by_median(6)[OF "1.prems"(1) "1.prems"(2)]
            length_partition_by_median(7)[OF True "1.prems"(2)]
            lmr_def by blast+
    ultimately have "size_kdt (build ks l) = length l" "size_kdt (build ks r) = length r"
      using "1.IH" by blast+
    moreover have "build ks ps = Node k m (build ks l) (build ks r)"
      using build_psimp_3[OF True ks_def lmr_def] build_termination D by blast
    (* length ps = length l + length r closes the recursive case. *)
    ultimately show ?thesis
      using length_partition_by_median(1)[OF lmr_def] by simp
  next
    case False
    thus ?thesis
      using "1.prems" by (cases ps) auto
  qed
qed

(*
  build produces a balanced tree (non-empty input, pairwise distinct coordinates).
  At the root the halves satisfy length r + 1 = length l or length r = length l. This is
  turned into a statement about subtree sizes via size_build and then combined with
  balanced_Node_if_wbal1 / balanced_Node_if_wbal2 (from KD_Tree).
*)
theorem balanced_build:
  "0 < length ps ⟹ ∀k. distinct (map (λp. p$k) ps) ⟹ balanced (build ks ps)"
proof (induction ps rule: length_induct)
  case (1 ps)
  show ?case
  proof (cases "1 < length ps")
    case True
    obtain k s where ks_def: "(k, s) = widest_spread ks ps"
      by (metis surj_pair)
    obtain l m r where lmr_def: "(l, m, r) = partition_by_median k ps"
      by (metis prod_cases3)
    have D: "∀k. distinct (map (λp. p$k) l)" "∀k. distinct (map (λp. p$k) r)"
      using lmr_def unfolding partition_by_median_def
      by (auto simp: "1.prems"(2) Let_def distinct_map_filter)
    moreover have "length l < length ps" "0 < length l"
                  "length r < length ps" "0 < length r"
      using length_partition_by_median(8)[OF True "1.prems"(2)]
            length_partition_by_median(5)[OF "1.prems"(1) "1.prems"(2)]
            length_partition_by_median(6)[OF "1.prems"(1) "1.prems"(2)]
            length_partition_by_median(7)[OF True "1.prems"(2)]
            lmr_def by blast+
    ultimately have IH: "balanced (build ks l)" "balanced (build ks r)"
      using "1.IH" by blast+
    have "build ks ps = Node k m (build ks l) (build ks r)"
      using build_psimp_3[OF True ks_def lmr_def] build_termination D by blast
    (* The halves differ in length by at most one, with l being the longer one. *)
    moreover have "length r + 1 = length l ∨ length r = length l"
      using length_partition_by_median(1)[OF lmr_def]
            length_partition_by_median(3)[OF "1.prems"(1) "1.prems"(2) lmr_def]
            length_partition_by_median(4)[OF "1.prems"(1) "1.prems"(2) lmr_def]
      by linarith
    ultimately show ?thesis
      using balanced_Node_if_wbal1[OF IH] balanced_Node_if_wbal2[OF IH]
            size_build[OF ‹0 < length l› D(1)] size_build[OF ‹0 < length r› D(2)]
      by auto
  next
    case False
    thus ?thesis
      using "1.prems" by (cases ps) (auto simp: balanced_def)
  qed
qed

(*
  A balanced tree whose size is a power of two is complete. The proof is by contradiction.
  If kdt were not complete, then 2 ^ min_height kdt < 2 ^ h < 2 ^ height kdt.
  Hence height and min_height would differ by more than one, contradicting balanced
  (via balanced_def).
*)
lemma complete_if_balanced_size_2powh:
  assumes "balanced kdt" "size_kdt kdt = 2 ^ h"
  shows "complete kdt"
proof (rule ccontr)
  assume "¬ complete kdt"
  hence "2 ^ (min_height kdt) < size_kdt kdt" "size_kdt kdt < 2 ^ height kdt"
    by (simp_all add: min_height_size_if_incomplete size_height_if_incomplete)
  hence "height kdt - min_height kdt > 1"
    using assms(2) by simp
  hence "¬ balanced kdt"
    using balanced_def by force
  thus "False"
    using assms(1) by simp
qed

(*
  For 2 ^ h points with pairwise distinct coordinates, build produces a complete tree.
  Note: the free variable k in "build k ps" is the list of axes passed to build; it is
  unrelated to the bound k in the distinctness premise.
*)
theorem complete_build:
  "length ps = 2 ^ h ⟹ ∀k. distinct (map (λp. p$k) ps) ⟹ complete (build k ps)"
  by (simp add: balanced_build complete_if_balanced_size_2powh size_build)

(* For 2 ^ h points with pairwise distinct coordinates, the tree built has height exactly h. *)
corollary height_build:
  assumes "length ps = 2 ^ h" "∀k. distinct (map (λp. p$k) ps)"
  shows "h = height (build k ps)"
  using complete_build[OF assms] size_build[OF _ assms(2)] by (simp add: assms(1) complete_iff_size)

end