(* Theory Randomized_Closest_Pair_Correct *)

section ‹Correctness›

text ‹This section verifies that the algorithm always returns the correct result.

Because the algorithm checks every pair of points in the same or in neighboring cells, it is
enough to establish that the grid distance is at least the distance of the closest pair.

The latter is true by construction, because the grid distance is chosen as a minimum of actually
occurring point distances.›

theory Randomized_Closest_Pair_Correct
  imports Randomized_Closest_Pair
begin

(* min_dist xs: the least value of dist x y over all two-element sub-multisets {#x, y#} of
   mset xs, i.e. over two (not necessarily different) points taken from two occurrences in xs. *)
definition min_dist :: "('a::metric_space) list ⇒ real"
  where "min_dist xs = Min {dist x y|x y. {# x, y#} ⊆# mset xs}"

text ‹For a list with length at least two, the result is the minimum distance between any two
elements of the list. This means that @{term "min_dist xs = 0"} if and only if the same
point occurs twice in the list.

Note that this means we do not assume the distinctness of the input list, and show the correctness
of the algorithm in the above sense.›

(* Rewrites a two-variable set comprehension as the image of a set of pairs under case_prod f. *)
lemma image_conv_2: "{f x y|x y. p x y} = (case_prod f) ` {(x,y). p x y}" by auto

(* The set of distances over which min_dist takes its minimum is finite. *)
lemma min_dist_set_fin: "finite {dist x y|x y. {#x, y#} ⊆# mset xs}"
proof -
  have a:"finite (set xs × set xs)" by simp
  (* Both members of a two-element sub-multiset occur in the list. *)
  have "x ∈# mset xs ∧ y ∈# mset xs" if "{#x, y#} ⊆# mset xs" for x y
    using that by (meson insert_union_subset_iff mset_subset_eq_insertD)
  (* So the distance set is the image of a subset of the finite set from fact a. *)
  thus ?thesis unfolding image_conv_2 by (intro finite_imageI finite_subset[OF _ a]) auto
qed

(* The distance set is non-empty exactly when the list has at least two entries. *)
lemma min_dist_ne:  "length xs ≥ 2 ⟷ {dist x y|x y. {# x,y#} ⊆# mset xs} ≠ {}" (is "?L ⟷ ?R")
proof
  assume ?L
  (* A list of length at least two starts with two entries, which form a witness pair. *)
  then obtain xh1 xh2 xt where xs:"xs=xh1#xh2#xt" by (metis Suc_le_length_iff numerals(2))
  hence "{#xh1,xh2#} ⊆# mset xs" unfolding xs by simp
  thus ?R by auto
next
  assume ?R
  then obtain x y where xy: "{#x,y#} ⊆# mset xs" by auto
  (* A sub-multiset of size two forces the list to have length at least two. *)
  have "2 ≤ size {#x, y#}" by simp
  also have "... ≤ size (mset xs)" by (intro size_mset_mono xy)
  finally have "2 ≤ size (mset xs)" by simp
  thus ?L by simp
qed
(* Forward direction of min_dist_ne as a rule: length xs ≥ 2 gives a non-empty distance set. *)
lemmas min_dist_neI = iffD1[OF min_dist_ne]

(* min_dist is non-negative for lists with at least two entries. *)
lemma min_dist_nonneg:
  assumes "length xs ≥ 2"
  shows "min_dist xs ≥ 0"
  unfolding min_dist_def by (intro Min.boundedI min_dist_set_fin assms iffD1[OF min_dist_ne]) auto

(* For lists with at least two entries: the list is distinct iff min_dist is strictly positive.
   The proof shows the contrapositive form: a repeated entry exists iff min_dist xs = 0. *)
lemma min_dist_pos_iff:
  assumes "length xs ≥ 2"
  shows "distinct xs ⟷ 0 < min_dist xs"
proof -
  have "¬(distinct xs) ⟷ (∃x. count (mset xs) x ≠ of_bool (x ∈ set xs))"
    unfolding of_bool_def distinct_count_atmost_1 by fastforce
  also have "... ⟷ (∃x. count (mset xs) x ∉ {0,1})"
    using count_mset_0_iff by (intro ex_cong1) simp
  also have "... ⟷ (∃x. count (mset xs) x ≥ count {#x, x#} x)"
    by (intro ex_cong1) (simp add:numeral_eq_Suc Suc_le_eq dual_order.strict_iff_order)
  also have "... ⟷ (∃x. {#x, x#} ⊆# mset xs)" by (intro ex_cong1) (simp add: subseteq_mset_def)
  also have "... ⟷ 0 ∈ {dist x y |x y. {#x, y#} ⊆# mset xs}" by auto
  also have "... ⟷ min_dist xs = 0" (is "?L ⟷ ?R")
  proof
    assume "?L"
    (* 0 is in the set, so the minimum is at most 0; non-negativity gives equality. *)
    hence "min_dist xs ≤ 0" unfolding min_dist_def by (intro Min_le min_dist_set_fin)
    thus "min_dist xs = 0" using min_dist_nonneg[OF assms] by auto
  next
    assume "?R"
    (* The minimum of a finite non-empty set is one of its members. *)
    thus "0 ∈ {dist x y |x y. {#x, y#} ⊆# mset xs}"
      unfolding min_dist_def using Min_in[OF min_dist_set_fin min_dist_neI[OF assms]] by simp
  qed
  finally  have "¬(distinct xs) ⟷ min_dist xs = 0" by simp
  thus ?thesis using min_dist_nonneg[OF assms] by auto
qed

(* Filtering a multiset is monotone in the predicate: if P implies Q on the elements of xs,
   then filtering by P yields a sub-multiset of filtering by Q. *)
lemma multiset_filter_mono_2:
  assumes "⋀x. x ∈ set_mset xs ⟹ P x ⟹ Q x"
  shows "filter_mset P xs ⊆# filter_mset Q xs" (is "?L ⊆# ?R")
proof -
  (* On elements of xs, P is equivalent to Q ∧ P, which is a nested filter. *)
  have "?L = filter_mset (λx. Q x ∧ P x) xs" using assms by (intro filter_mset_cong) auto
  also have "... = filter_mset P (filter_mset Q xs)" by (simp add:filter_filter_mset)
  also have "... ⊆# ?R" by simp
  finally show ?thesis by simp
qed

(* Splits a filter by a disjunction into two disjoint filters: elements satisfying p but not q,
   plus elements satisfying q. *)
lemma filter_mset_disj:
  "filter_mset (λx. p x ∨ q x) xs = filter_mset (λx. p x ∧ ¬ q x) xs + filter_mset q xs"
  by (induction xs) auto

(* For a finite set T, the number of elements whose f-value lies in T is the sum, over t in T,
   of the number of elements whose f-value equals t. *)
lemma size_filter_mset_decompose:
  assumes "finite T"
  shows "size (filter_mset (λx. f x ∈ T) xs) = (∑t ∈ T. size (filter_mset (λx. f x = t) xs))"
  using assms
proof (induction T)
  case empty thus ?case by simp
next
  case (insert x F) thus ?case by (simp add:filter_mset_disj) metis
qed

(* Variant of size_filter_mset_decompose without the finiteness assumption, stated with sum'. *)
lemma size_filter_mset_decompose':
  "size (filter_mset (λx. f x ∈ T) xs) = sum' (λt. size (filter_mset (λx. f x = t) xs)) T"
  (is "?L = ?R")
proof -
  (* Restrict T to the f-values that actually occur in xs. *)
  let ?T = "f ` set_mset xs ∩ T"
  have "?L = size (filter_mset (λx. f x ∈ ?T) xs)"
    by (intro arg_cong[where f="size"] filter_mset_cong) auto
  also have "... = (∑t ∈ ?T. size (filter_mset (λx. f x = t) xs))"
    by (intro size_filter_mset_decompose) auto
  also have "... = sum' (λt. size (filter_mset (λx. f x = t) xs)) ?T"
    by (intro sum.eq_sum[symmetric]) auto
  (* Extend the index set from ?T back to T. *)
  also have "... = ?R" by (intro sum.mono_neutral_left') auto
  finally show ?thesis by simp
qed


(* Filtering a list product by a conjunction of component-wise predicates equals the product of
   the separately filtered lists. *)
lemma filter_product:
  "filter (λx. P (fst x)∧Q (snd x)) (List.product xs ys) = List.product (filter P xs) (filter Q ys)"
proof (induction xs)
  case Nil thus ?case by simp
next
  case (Cons xh xt) thus ?case by (simp add:filter_map comp_def)
qed

(* The absolute difference of two floors is at most the ceiling of the absolute difference. *)
lemma floor_diff_bound: "¦⌊x⌋-⌊y⌋¦ ≤ ⌈¦x - (y::real)¦⌉" by linarith

(* Squaring is strictly monotone with respect to absolute values in a linearly ordered
   integral domain. *)
lemma power2_strict_mono:
  fixes x y :: "'a :: linordered_idom"
  assumes "¦x¦ < ¦y¦"
  shows "x^2 < y^2"
  using assms unfolding power2_eq_square
  by (metis abs_mult_less abs_mult_self_eq)

(* Specification-level grid record: g_dist is d, and looking up a cell q returns (wrapped with
   map_tm return) the points of ps that to_grid d maps to q. *)
definition "grid ps d = ⦇ g_dist = d, g_lookup = (λq. map_tm return (filter (λx. to_grid d x = q) ps)) ⦈"

(* The value computed by build_grid coincides with the grid record defined by grid. *)
lemma build_grid_val: "val (build_grid ps d) = grid ps d"
  unfolding build_grid_def grid_def by simp

(* Characterizes the result of lookup_neighborhood on the grid of ps as a multiset: all points of
   ps whose cell offset relative to p is one of the five listed offsets, with one copy of p
   removed. *)
lemma lookup_neighborhood:
  "mset (val (lookup_neighborhood (grid ps d) p)) =
  filter_mset (λx. to_grid d x - to_grid d p ∈ {(0,0),(0,1),(1,-1),(1,0),(1,1)}) (mset ps) - {#p#}"
proof -
  define ls where  "ls = [(0::int,0::int),(0,1),(1,-1),(1,0),(1,1)]"
  define g where "g = grid ps d"
  (* NOTE(review): cs is only referenced through cs_def in the simp set below and does not occur
     in any stated goal; it looks unused, confirm before removing. *)
  define cs where "cs = map ((+) (to_grid (g_dist g) p)) ([(0,0),(0,1),(1,-1),(1,0),(1,1)])"

  have distinct_ls: "distinct ls" unfolding ls_def by (simp add: upto.simps)

  (* Looking up each offset cell is the same as filtering ps by that offset. *)
  have "mset (concat (map (λx. val (g_lookup g (x + to_grid (g_dist g) p))) ls)) =
    mset (concat (map (λx. filter (λq. to_grid d q - to_grid d p = x) ps) ls))"
    by (simp add:grid_def filter_eq_val_filter_tm cs_def comp_def algebra_simps ls_def g_def)
  (* Because the offsets are distinct, concatenating the per-offset filters counts each point
     exactly once. *)
  also have "... = {# q ∈# mset ps. to_grid d q - to_grid d p ∈ set ls #}"
    using distinct_ls by (induction ls) (simp_all add:filter_mset_disj, metis)
  also have "... = {#x ∈# mset ps. to_grid d x - to_grid d p ∈ {(0,0),(0,1),(1,-1),(1,0),(1,1)}#}"
    unfolding ls_def by simp
  finally have a:
    "mset (concat (map (λx. val (g_lookup g (x + to_grid (g_dist g) p))) ls)) =
    {#x ∈# mset ps. to_grid d x - to_grid d p ∈ {(0,0),(0,1),(1,-1),(1,0),(1,1)}#}" by simp

  thus ?thesis
    unfolding g_def[symmetric] lookup_neighborhood_def ls_def[symmetric]
    by (simp add:val_remove1 comp_def)
qed

(* The set of index pairs (i, j) with i < j < n is finite. *)
lemma fin_nat_pairs: "finite {(i, j). i < j ∧ j < (n::nat)}"
  by (rule finite_subset[where B="{..<n }×{..<n}"]) auto

(* Selecting entries of xs at a distinct list of valid indices yields a sub-multiset of xs. *)
lemma mset_list_subset:
  assumes "distinct ys" "set ys ⊆ {..<length xs}"
  shows  "mset (map ((!) xs) ys) ⊆# mset xs" (is "?L ⊆# ?R")
proof -
  (* Distinct valid indices form a sub-multiset of all indices. *)
  have "mset ys ⊆# mset [0..<length xs]" using assms
    by (metis finite_lessThan mset_set_set mset_set_upto_eq_mset_upto subset_imp_msubset_mset_set)
  hence "image_mset ((!) xs) (mset ys) ⊆# image_mset ((!) xs) (mset ([0..<length xs]))"
    by (intro image_mset_subseteq_mono)
  (* Indexing xs at every position reproduces xs. *)
  moreover have "image_mset ((!) xs) (mset ([0..<length xs])) = mset xs" by (metis map_nth mset_map)
  ultimately show ?thesis by simp
qed

(* Every value that sample_distance can return is at least min_dist ps. *)
lemma sample_distance:
  assumes "length ps ≥ 2"
  shows "AE d in map_pmf val (sample_distance ps). min_dist ps ≤ d"
proof -
  (* Index pairs (i, j) with i < j < length ps. *)
  let ?S = "{i. fst i < snd i ∧ snd i < length ps}"
  let ?p = "pmf_of_set ?S"

  have "(0,1) ∈ ?S" using assms by auto
  hence a:"finite ?S" "?S ≠ {}"
    using fin_nat_pairs[where n="length ps"] by (auto simp:case_prod_beta')

  (* The distance of the entries at any such index pair is bounded below by min_dist. *)
  have "min_dist ps ≤ dist (ps ! (fst x)) (ps ! (snd x))" if "x ∈ ?S" for x
  proof -
    have "mset (map ((!) ps) [fst x, snd x]) ⊆# mset ps"
      using that by (intro mset_list_subset) auto
    hence "{#ps ! fst x, ps ! snd x#} ⊆# mset ps" by simp
    hence "(λ(x, y). dist x y) (ps ! (fst x), ps ! (snd x)) ∈ {dist x y |x y. {#x, y#} ⊆# mset ps}"
      unfolding image_conv_2 by (intro imageI) simp
    thus ?thesis unfolding min_dist_def by (intro Min_le min_dist_set_fin) simp
  qed
  thus ?thesis
    using a unfolding sample_distance_def map_pmf_def[symmetric] val_tpmf_simps
    by (intro AE_pmfI) (auto)
qed

(* Every value that first_phase can return is at least min_dist ps. *)
lemma first_phase:
  assumes "length ps ≥ 2"
  shows "AE d in map_pmf val (first_phase ps). min_dist ps ≤ d"
proof -
  (* The minimum of a list of sampled distances, of the same length as ps, is bounded below by
     min_dist. *)
  have "min_dist ps ≤ val (min_list_tm ds)"
    if ds_range:"set ds ⊆ set_pmf (map_pmf val (sample_distance ps))" and "length ds=length ps" for ds
  proof -
    have ds_ne: "ds ≠ []" using assms that(2) by auto

    have "min_dist ps ≤ a" if "a ∈ set ds" for a
    proof -
      have "a ∈ set_pmf (map_pmf val (sample_distance ps))" using ds_range that by auto
      thus ?thesis using sample_distance[OF assms] by (auto simp add: AE_measure_pmf_iff)
    qed
    hence "min_dist ps ≤ Min (set ds)" using ds_ne by (intro Min.boundedI) auto
    also have "... = min_list ds" unfolding min_list_Min[OF ds_ne] by simp
    also have "... = val (min_list_tm ds)" by (intro val_min_list[symmetric] ds_ne)
    finally show ?thesis by simp
  qed

  thus ?thesis
    unfolding first_phase_def val_tpmf_simps val_replicate_tpmf
    by (intro AE_pmfI) (auto simp:set_replicate_pmf)
qed

(* Lexicographic (non-strict) order on grid cells: compare first components, and break ties by
   the second components. *)
definition grid_lex_ord :: "int * int ⇒ int * int ⇒ bool"
  where "grid_lex_ord x y = (fst x < fst y ∨ (fst x = fst y ∧ snd x ≤ snd y))"

(* Any two grid cells are comparable under grid_lex_ord in at least one direction.
   Despite the lemma name, the statement is totality of the order, not antisymmetry. *)
lemma grid_lex_order_antisym: "grid_lex_ord x y ∨ grid_lex_ord y x"
  unfolding grid_lex_ord_def by auto

(* Per-coordinate bound: the floors of the k-th coordinates of p and q, each divided by d,
   differ by at most the ceiling of dist p q / d. *)
lemma grid_dist:
  fixes p q :: point
  assumes "d > 0"
  shows  "¦⌊p $ k/d⌋ - ⌊q $ k/d⌋¦ ≤ ⌈dist p q/d⌉"
proof -
  (* A single coordinate difference is bounded by the Euclidean distance. *)
  have "¦p$k - q$k¦ = sqrt ((p$k - q$k)^2)" by simp
  also have "... = sqrt (∑j∈UNIV. of_bool(j=k)*(p$j - q$j)^2)" by simp
  also have "... ≤ dist p q" unfolding dist_vec_def L2_set_def
    by (intro real_sqrt_le_mono sum_mono) (auto simp:dist_real_def)
  finally have "¦p$k - q$k¦ ≤ dist p q" by simp
  hence 0:"¦p$k /d - q$k /d¦ ≤ dist p q /d" using assms by (simp add:field_simps)
  have "¦⌊p$k/d⌋ - ⌊q$k/d⌋¦ ≤ ⌈¦p$k /d - q$k /d¦⌉" by (intro floor_diff_bound)
  also have "... ≤ ⌈dist p q/d⌉" by (intro ceiling_mono 0)
  finally show ?thesis by simp
qed

(* If the ceiling of dist p q / d is at most s, then the grid cells of p and q differ by at most
   s in each component. *)
lemma grid_dist_2:
  fixes p q :: point
  assumes "d > 0"
  assumes "⌈dist p q/d⌉ ≤ s"
  shows  "to_grid d p - to_grid d q ∈ {-s..s}×{-s..s}"
proof -
  (* Handle both components uniformly via a projection f. *)
  have "f (to_grid d p) - f (to_grid d q) ∈ {-s..s}" if "f = fst ∨ f = snd" for f
  proof -
    have "¦f (to_grid d p) - f (to_grid d q)¦ ≤ ⌈dist p q/d⌉"
      using that grid_dist[OF assms(1)] unfolding to_grid_def by auto
    also have "... ≤ s" by (intro assms(2))
    finally have "¦f (to_grid d p) - f (to_grid d q)¦ ≤ s" by simp
    thus ?thesis by auto
  qed
  thus ?thesis by (simp add:mem_Times_iff)
qed

(* If the ceiling of dist q p / d is at most 1 and the cell of p precedes the cell of q in
   grid_lex_ord, then the cell offset from p to q is one of the five listed offsets. *)
lemma grid_dist_3:
  fixes p q :: point
  assumes "d > 0"
  assumes "⌈dist q p/d⌉ ≤ 1" "grid_lex_ord (to_grid d p) (to_grid d q)"
  shows  "to_grid d q - to_grid d p ∈ {(0,0),(0,1),(1,-1),(1,0),(1,1)}"
proof -
  have a:"{-1..1} = {-1,0,1::int}" by auto
  let ?r = "to_grid d q - to_grid d p"
  (* The offset lies in the 3x3 block around the origin ... *)
  have "?r ∈ {-1..1}×{-1..1}" by (intro grid_dist_2 assms(1-2))
  (* ... and the ordering assumption excludes the four lexicographically smaller offsets. *)
  moreover have "?r ∉ {(-1,0),(-1,-1),(-1,1),(0,-1)}" using assms(3)
    unfolding grid_lex_ord_def insert_iff de_Morgan_disj
    by (intro conjI notI) (simp_all add:algebra_simps)
  ultimately show ?thesis unfolding a by simp
qed

(* For a positive grid distance d that is at least min_dist ps, there is a pair u, v attaining
   min_dist ps such that u is in ps and v is found by lookup_neighborhood for u. *)
lemma second_phase_aux:
  assumes "d > 0" "min_dist ps ≤ d" "length ps ≥ 2"
  obtains u v where
    "min_dist ps = dist u v"
    "{#u, v#} ⊆# mset ps"
    "grid_lex_ord (to_grid d u) (to_grid d v)"
    "u ∈ set ps" "v ∈ set (val (lookup_neighborhood (grid ps d) u))"
proof -
  (* The minimum is attained by some pair. *)
  have "∃u v. min_dist ps = dist u v ∧ {#u, v#} ⊆# mset ps"
    unfolding min_dist_def using Min_in[OF min_dist_set_fin min_dist_neI[OF assms(3)]] by auto

  (* Swap the pair if necessary so that the cell of u precedes the cell of v. *)
  then obtain u v where uv:
    "min_dist ps = dist u v" "{#u, v#} ⊆# mset ps"
    "grid_lex_ord (to_grid d u) (to_grid d v)"
    using add_mset_commute dist_commute grid_lex_order_antisym by (metis (no_types, lifting))

  have u_range: "u ∈ set ps" using uv(2) set_mset_mono by fastforce

  have "to_grid d v - to_grid d u ∈ {(0,0),(0,1),(1,-1),(1,0),(1,1)}"
    using assms(1,2) uv(1,3) by (intro grid_dist_3) (simp_all add:dist_commute)

  hence "v ∈# mset (val (lookup_neighborhood (grid ps d) u))"
    using uv(2) unfolding lookup_neighborhood by (simp add: in_diff_count insert_subset_eq_iff)

  thus ?thesis using that u_range uv by simp
qed

(* For a positive grid distance d that is at least min_dist ps, second_phase returns exactly
   min_dist ps. *)
lemma second_phase:
  assumes "d > 0" "min_dist ps ≤ d" "length ps ≥ 2"
  shows "val (second_phase d ps) = min_dist ps" (is "?L = ?R")
proof -
  let ?g = "grid ps d"

  (* NOTE(review): this fact is chained into the obtain step below, which is discharged from
     second_phase_aux; it may be redundant. *)
  have "∃u v. min_dist ps = dist u v ∧ {#u, v#} ⊆# mset ps"
    unfolding min_dist_def using Min_in[OF min_dist_set_fin min_dist_neI[OF assms(3)]] by auto

  then obtain u v where uv:
    "min_dist ps = dist u v" "{#u, v#} ⊆# mset ps"
    "grid_lex_ord (to_grid d u) (to_grid d v)"
    and u_range: "u ∈ set ps"
    and v_range: "v ∈ set (val (lookup_neighborhood (grid ps d) u))"
    using second_phase_aux[OF assms] by auto

  hence a: "val (lookup_neighborhood (grid ps d) u) ≠ []" by auto

  (* min_dist occurs among the distances computed over all neighborhoods ... *)
  have "∃x∈set ps. min_dist ps ∈ dist x ` set (val (lookup_neighborhood (grid ps d) x))"
    using v_range uv(1) by (intro bexI[where x="u"] u_range) simp

  (* ... so the minimum of these distances is at most min_dist. *)
  hence b: "Min (⋃x∈set ps. dist x ` set (val (lookup_neighborhood (grid ps d) x))) ≤ min_dist ps"
    by (intro Min.coboundedI finite_UN_I) simp_all

  (* Conversely, every computed distance belongs to a pair that is a sub-multiset of ps. *)
  have "{# x, y#} ⊆# mset ps"
    if "x ∈ set ps" "y ∈ set (val (lookup_neighborhood (grid ps d) x))" for x y
  proof -
    have "y ∈# mset (val (lookup_neighborhood (grid ps d) x))" using that by simp
    moreover have "mset (val (lookup_neighborhood (grid ps d) x)) ⊆#  mset ps - {#x#}"
      using that(1) unfolding lookup_neighborhood subset_eq_diff_conv by simp
    ultimately have "y ∈# mset ps - {#x#}" by (metis mset_subset_eqD)
    moreover have "x ∈# mset ps" using that(1) by simp
    ultimately show "{#x, y#} ⊆# mset ps" by (simp add: insert_subset_eq_iff)
  qed
  hence c: "min_dist ps ≤ Min (⋃x∈set ps. dist x ` set (val (lookup_neighborhood (grid ps d) x)))"
    unfolding min_dist_def using a u_range by (intro Min_antimono min_dist_set_fin) auto

  (* Unfold the algorithm into a minimum over the concatenated neighborhood distances. *)
  have "?L = val (min_list_tm (concat (map (λx. map (dist x) (val (lookup_neighborhood ?g x))) ps)))"
    unfolding second_phase_def by (simp add:calc_dists_neighborhood_def build_grid_val)
  also have "... = min_list (concat (map (λx. map (dist x) (val (lookup_neighborhood ?g x))) ps))"
    using assms(3) a u_range by (intro val_min_list) auto
  also have "... = Min (⋃x∈set ps. dist x ` set (val (lookup_neighborhood ?g x)))"
    using a u_range by (subst min_list_Min) auto
  also have "... = min_dist ps" using b c by simp
  finally show ?thesis by simp
qed

text ‹Main result of this section:›

(* Correctness: for inputs with at least two points, every result that closest_pair can return
   equals min_dist ps. *)
theorem closest_pair_correct:
  assumes "length ps ≥ 2"
  shows "AE r in map_pmf val (closest_pair ps). r = min_dist ps"
proof -
  define fp where "fp = map_pmf val (first_phase ps)"

  (* d ranges over the possible results of the first phase; r is the final result for that d. *)
  have "r = min_dist ps" if
    "d ∈ fp"
    "r = (if d = 0 then 0 else val (second_phase d ps))" for r d
  proof -
    have d_ge: "d ≥ min_dist ps"
      using that(1) first_phase[OF assms] unfolding AE_measure_pmf_iff fp_def[symmetric] by simp
    show ?thesis
    proof (cases "d > 0")
      case True
      thus ?thesis using second_phase[OF True d_ge assms] that(2)
        by (simp add: AE_measure_pmf_iff)
    next
      case False
      (* d is squeezed between 0 ≤ min_dist ps ≤ d and ¬ d > 0, so both are 0. *)
      hence "d = 0" "min_dist ps = 0" using d_ge min_dist_nonneg[OF assms] by auto
      then show ?thesis using that(2) by auto
    qed
  qed
  thus ?thesis unfolding closest_pair_def val_tpmf_simps fp_def[symmetric] if_distrib
    by (intro AE_pmfI) (auto simp:if_distrib)
qed

end