Theory Median_Of_Medians_Selection.Median_Of_Medians_Selection

(*
  File:     Median_Of_Medians_Selection.thy
  Author:   Manuel Eberl, TU München

  The median-of-medians selection algorithm, which runs deterministically in linear time.
*)
section ‹The Median-of-Medians Selection Algorithm›
theory Median_Of_Medians_Selection
  imports Complex_Main "HOL-Library.Multiset"
begin

subsection ‹Some facts about lists and multisets›

(* The multiset of a concatenation is the sum (as multisets) of the multisets of the parts. *)
lemma mset_concat: "mset (concat xss) = sum_list (map mset xss)"
  by (induction xss) simp_all

(* The elements of a sum of a list of multisets are the union of the elements of the
   individual multisets. *)
lemma set_mset_sum_list [simp]: "set_mset (sum_list xs) = (xset xs. set_mset x)"
  by (induction xs) auto

(* Filtering the image of a multiset under f by P equals the image of the multiset filtered
   by P composed with f. *)
lemma filter_mset_image_mset:
  "filter_mset P (image_mset f A) = image_mset f (filter_mset (λx. P (f x)) A)"
  by (induction A) auto

(* Filtering distributes over the sum of a list of multisets. *)
lemma filter_mset_sum_list: "filter_mset P (sum_list xs) = sum_list (map (filter_mset P) xs)"
  by (induction xs) simp_all

(* Monotonicity of multiset sums w.r.t. sub-multiset inclusion: if f x is a sub-multiset of
   g x for every x in A, then the sum over A of f is a sub-multiset of the sum over A of g. *)
lemma sum_mset_mset_mono: 
  assumes "(x. x ∈# A  f x ⊆# g x)"
  shows   "(x∈#A. f x) ⊆# (x∈#A. g x)"
  using assms by (induction A) (auto intro!: subset_mset.add_mono)

(* Filtering is monotone both in the multiset (A is a sub-multiset of B) and in the
   predicate (P implies Q, required only on the elements of A). *)
lemma mset_filter_mono:
  assumes "A ⊆# B" "x. x ∈# A  P x  Q x"
  shows   "filter_mset P A ⊆# filter_mset Q B"
  by (rule mset_subset_eqI) (insert assms, auto simp: mset_subset_eq_count count_eq_zero_iff)

(* The size of a sum of multisets is the sum of the sizes of the summands. *)
lemma size_mset_sum_mset_distrib: "size (sum_mset A :: 'a multiset) = sum_mset (image_mset size A)"
  by (induction A) auto

(* Pointwise monotonicity of multiset sums: a bound of f x by g x for every element x of A
   carries over to the sums. Stated for ordered additive commutative monoids. *)
lemma sum_mset_mono:
  assumes "x. x ∈# A  f x  (g x :: 'a :: {ordered_ab_semigroup_add,comm_monoid_add})"
  shows   "(x∈#A. f x)  (x∈#A. g x)"
  using assms by (induction A) (auto intro!: add_mono)

(* A filtered multiset is empty iff no element of the original multiset satisfies the
   predicate. *)
lemma filter_mset_is_empty_iff: "filter_mset P A = {#}  (x. x ∈# A  ¬P x)"
  by (auto simp: multiset_eq_iff count_eq_zero_iff)

(* In a sorted list, all elements strictly smaller than the i-th element occur among the
   first i elements, so their multiset is contained in mset (take i xs).
   Proof by induction on xs, generalising over the index i. *)
lemma sorted_filter_less_subset_take:
  assumes "sorted xs" "i < length xs"
  shows   "{# x ∈# mset xs. x < xs ! i #} ⊆# mset (take i xs)"
  using assms
proof (induction xs arbitrary: i rule: list.induct)
  case (Cons x xs i)
  show ?case
  proof (cases i)
    case 0
    thus ?thesis using Cons.prems by (auto simp: filter_mset_is_empty_iff)
  next
    case (Suc i')
    (* Index i = Suc i': the head x is counted at most once, and the tail is handled by the
       induction hypothesis with index i'. *)
    have "{#y ∈# mset (x # xs). y < (x # xs) ! i#} ⊆# add_mset x {#y ∈# mset xs. y < xs ! i'#}"
      using Suc Cons.prems by (auto)
    also have " ⊆# add_mset x (mset (take i' xs))"
      unfolding mset_subset_eq_add_mset_cancel using Cons.prems Suc
      by (intro Cons.IH) (auto)
    also have " = mset (take i (x # xs))" by (simp add: Suc)
    finally show ?thesis .
  qed
qed auto

(* Dual of sorted_filter_less_subset_take: in a sorted list, all elements strictly greater
   than the i-th element occur after position i, so their multiset is contained in
   mset (drop (Suc i) xs). Proof by induction on xs, generalising over i. *)
lemma sorted_filter_greater_subset_drop:
  assumes "sorted xs" "i < length xs"
  shows   "{# x ∈# mset xs. x > xs ! i #} ⊆# mset (drop (Suc i) xs)"
  using assms
proof (induction xs arbitrary: i rule: list.induct)
  case (Cons x xs i)
  show ?case
  proof (cases i)
    case 0
    thus ?thesis by (auto simp: sorted_append filter_mset_is_empty_iff)
  next
    case (Suc i')
    (* Index i = Suc i': the head x is not greater than xs ! i', so it is dropped, and the
       tail is handled by the induction hypothesis. *)
    have "{#y ∈# mset (x # xs). y > (x # xs) ! i#} ⊆# {#y ∈# mset xs. y > xs ! i'#}"
      using Suc Cons.prems by (auto simp: set_conv_nth)
    also have " ⊆# mset (drop (Suc i') xs)"
      using Cons.prems Suc by (intro Cons.IH) (auto)
    also have " = mset (drop (Suc i) (x # xs))" by (simp add: Suc)
    finally show ?thesis .
  qed
qed auto


subsection ‹The dual order type›

text ‹
  The following type is a copy of a given ordered base type, but with the ordering reversed.
  This will be useful later because we can do some of our reasoning simply by symmetry.
›
(* A copy of 'a (converted with of_dual_ord / to_dual_ord) whose only purpose is to carry
   the reversed order defined below. *)
typedef 'a dual_ord = "UNIV :: 'a set" morphisms of_dual_ord to_dual_ord
  by auto

setup_lifting type_definition_dual_ord

(* The order on 'a dual_ord is the order on 'a with its arguments swapped: a is strictly
   less than b in the dual order iff a > b in the base order. *)
instantiation dual_ord :: (ord) ord
begin

lift_definition less_eq_dual_ord :: "'a dual_ord  'a dual_ord  bool" is
  "λa b :: 'a. a  b" .

lift_definition less_dual_ord :: "'a dual_ord  'a dual_ord  bool" is
  "λa b :: 'a. a > b" .

instance ..
end

(* Reversing a preorder (resp. linear order) again gives a preorder (resp. linear order);
   the class axioms are transferred back to the base type. *)
instance dual_ord :: (preorder) preorder
  by standard (transfer; force simp: less_le_not_le intro: order_trans)+

instance dual_ord :: (linorder) linorder
  by standard (transfer; force simp: not_le)+


subsection ‹Chopping a list into equal-sized sublists›

(* TODO: Move to library? *)
(* chop n xs splits xs into consecutive chunks of n elements each; only the last chunk may
   be shorter. Chopping the empty list, or chopping with n = 0, yields no chunks.
   The third equation has the side conditions n > 0 and xs non-empty, so it is stated with
   the function package: pattern completeness and compatibility are discharged by force+,
   and termination by lexicographic_order. *)
function chop :: "nat  'a list  'a list list" where
  "chop n [] = []"
| "chop 0 xs = []"
| "n > 0  xs  []  chop n xs = take n xs # chop n (drop n xs)"
  by force+
termination by lexicographic_order

context
  includes lifting_syntax
begin

(* Parametricity of chop: chopping two lists that are related elementwise by R (with the
   same chunk size) yields chunk lists that are related chunkwise by list_all2 R. This lets
   the transfer method move statements involving chop to the dual order type. *)
lemma chop_transfer [transfer_rule]: 
  "((=) ===> list_all2 R ===> list_all2 (list_all2 R)) chop chop"
proof (intro rel_funI)
  fix m n ::nat and xs :: "'a list" and ys :: "'b list"
  assume "m = n" "list_all2 R xs ys"
  (* Induct along the recursion of chop; in the recursive case, ys is non-empty because
     it is related to the non-empty xs. *)
  from this(2) have "list_all2 (list_all2 R) (chop n xs) (chop n ys)"
  proof (induction n xs arbitrary: ys rule: chop.induct)
    case (3 n xs ys)
    hence "ys  []" by auto
    with 3 show ?case by auto
  qed auto
  with m = n show "list_all2 (list_all2 R) (chop m xs) (chop n ys)" by simp
qed

end

(* One-step unfolding of chop as a single conditional equation. *)
lemma chop_reduce: "chop n xs = (if n = 0  xs = [] then [] else take n xs # chop n (drop n xs))"
  by (cases "n = 0"; cases "xs = []") auto

(* Concatenating the chunks gives back the original list (for a positive chunk size). *)
lemma concat_chop [simp]: "n > 0  concat (chop n xs) = xs"
  by (induction n xs rule: chop.induct) auto

(* No chunk is empty. *)
lemma chop_elem_not_Nil [simp,dest]: "ys  set (chop n xs)  ys  []"
  by (induction n xs rule: chop.induct) (auto simp: eq_commute[of "[]"])

(* chop yields no chunks exactly when n = 0 or the list is empty. *)
lemma chop_eq_Nil_iff [simp]: "chop n xs = []  n = 0  xs = []"
  by (induction n xs rule: chop.induct) auto  

(* A non-empty list that is no longer than the chunk size forms a single chunk. *)
lemma chop_ge_length_eq: "n > 0  xs  []  n  length xs  chop n xs = [xs]"
  by simp

(* Every chunk has at most n elements. *)
lemma length_chop_part_le: "ys  set (chop n xs)  length ys  n"
  by (induction n xs rule: chop.induct) auto

(* Exact length of the i-th chunk: every chunk has n elements, except for the last one when
   n does not divide length xs; that chunk has length xs mod n elements.
   The proof inducts along the recursion of chop. The case n = 0 is vacuous because chop 0
   yields no chunks, so no index i is in range. *)
lemma length_nth_chop:
  assumes "i < length (chop n xs)"
  shows   "length (chop n xs ! i) = 
             (if i = length (chop n xs) - 1  ¬n dvd length xs then length xs mod n else n)"
proof (cases "n = 0")
  case False
  thus ?thesis
    using assms
  proof (induction n xs arbitrary: i rule: chop.induct)
    case (3 n xs i)
    show ?case
    proof (cases i)
      case 0
      thus ?thesis using "3.prems"
      by (cases "length xs < n") (auto simp: le_Suc_eq dest: dvd_imp_le)
    next
      case [simp]: (Suc i')
      with "3.prems" have [simp]: "xs  []" by auto
      with "3.prems" have *: "length xs > n" by (cases "length xs  n") simp_all
      with "3.prems" have "chop n xs ! i = chop n (drop n xs) ! i'" by simp
      also have "length  = (if i = length (chop n xs) - 1  ¬ n dvd (length xs - n)
                                then (length xs - n) mod n else n)"
        by (subst "3.IH") (use Suc "3.prems" in auto)
      also have "n dvd (length xs - n)  n dvd length xs"
        using * by (subst dvd_minus_self) auto
      also have "(length xs - n) mod n = length xs mod n"
        using * by (subst le_mod_geq [symmetric]) auto
      finally show ?thesis .
    qed
  qed auto
qed (insert assms, auto)

(* For a positive chunk size n, the number of chunks is the ceiling of length xs / n,
   converted to nat. Induction along chop: either the whole list fits into a single chunk,
   or one chunk of n elements is split off, which lowers the quotient length xs / n by
   exactly one. *)
lemma length_chop:
  assumes "n > 0"
  shows   "length (chop n xs) = nat length xs / n"
  using assms
proof (induction n xs rule: chop.induct)
  case (3 n xs)
  show ?case
  proof (cases "length xs  n")
    case False
    hence "real (length xs) / real n = 1" using "3.hyps"
      by (intro ceiling_unique) auto
    with False show ?thesis using "3.prems" "3.hyps"
      by (auto simp: chop_ge_length_eq not_le)
  next
    case True
    hence "real (length xs) = real n + real (length (drop n xs))"
      by simp
    also have " / real n = real (length (drop n xs)) / real n + 1"
      using n > 0 by (simp add: divide_simps)
    also have "ceiling  = ceiling (real (length (drop n xs)) / real n) + 1" by simp
    also have "nat  = nat (ceiling (real (length (drop n xs)) / real n)) + nat 1"
      by (intro nat_add_distrib[OF order.trans[OF _ ceiling_mono[of 0]]]) auto
    also have " = length (chop n xs)"
      using n > 0 "3.hyps" by (subst "3.IH" [symmetric]) auto
    finally show ?thesis ..
  qed
qed auto

(* Summing the multisets of all chunks gives the multiset of the original list. *)
lemma sum_msets_chop: "n > 0  (yschop n xs. mset ys) = mset xs"
  by (subst mset_concat [symmetric]) simp_all

(* The union of the element sets of all chunks is the element set of the original list. *)
lemma UN_sets_chop: "n > 0  (ysset (chop n xs). set ys) = set xs"
  by (simp only: set_concat [symmetric] concat_chop)

(* Every element of a chunk is an element of the original list. This also holds for d = 0,
   where there are no chunks at all. *)
lemma in_set_chopD [dest]:
  assumes "x  set ys" "ys  set (chop d xs)"
  shows   "x  set xs"
proof (cases "d > 0")
  case True
  thus ?thesis by (subst UN_sets_chop [symmetric]) (use assms in auto)
qed (use assms in auto)


subsection ‹$k$-th order statistics and medians›

text ‹
  This returns the $k$-th smallest element of a list, counting from $0$ (so $k = 0$ yields the
  minimum). This is also known as the $k$-th order statistic.
›
(* select k xs is the element at the 0-based position k of the sorted list. The definition
   does not restrict k; the lemmas below assume k < length xs. *)
definition select :: "nat  'a list  ('a :: linorder)" where
  "select k xs = sort xs ! k"

text ‹
  The median of a list. For lists of even length, the smaller of the two middle elements is
  chosen:
›
(* The median selects position (length xs - 1) div 2 of the sorted list. For odd lengths this
   is the middle element; for even lengths it is the lower of the two middle elements. The
   definition does not exclude the empty list; lemmas such as median_in_set assume xs is
   non-empty. *)
definition median where "median xs = select ((length xs - 1) div 2) xs"

(* A selected element belongs to the list, provided the index is in range. *)
lemma select_in_set [intro,simp]:
  assumes "k < length xs"
  shows   "select k xs  set xs"
proof -
  from assms have "sort xs ! k  set (sort xs)" by (intro nth_mem) auto
  also have "set (sort xs) = set xs" by simp
  finally show ?thesis by (simp add: select_def)
qed

(* The median of a non-empty list belongs to the list, since its index
   (length xs - 1) div 2 is in range. *)
lemma median_in_set [intro, simp]: 
  assumes "xs  []"
  shows   "median xs  set xs"
proof -
  from assms have "length xs > 0" by auto
  hence "(length xs - 1) div 2 < length xs" by linarith
  thus ?thesis by (simp add: median_def)
qed

text ‹
  We show that selection and medians do not depend on the order of the elements:
›
(* Sorting depends only on the multiset of the list. *)
lemma sort_cong: "mset xs = mset ys  sort xs = sort ys"
  by (rule properties_for_sort) simp_all

(* Hence selection depends only on the index and the multiset of the list ... *)
lemma select_cong:
  "k = k'  mset xs = mset xs'  select k xs = select k' xs'"
  by (auto simp: select_def dest: sort_cong)

(* ... and so does the median (equal multisets also have equal lengths). *)
lemma median_cong: "mset xs = mset xs'  median xs = median xs'"
  unfolding median_def by (intro select_cong) (auto dest: mset_eq_length)


text ‹
  Selection distributes over appending lists under certain conditions:
›
(* If every element of xs is below every element of ys, sorting xs @ ys is the same as
   sorting both parts separately and appending the results. *)
lemma sort_append:
  assumes "x y. x  set xs  y  set ys  x  y"
  shows   "sort (xs @ ys) = sort xs @ sort ys"
  using assms  by (intro properties_for_sort) (auto simp: sorted_append)

(* Consequently, selection on ys @ zs reduces to selection on ys for indices below
   length ys, and to selection on zs (index shifted by length ys) for the remaining valid
   indices. *)
lemma select_append:
  assumes "y z. y  set ys  z  set zs  y  z"
  shows   "k < length ys  select k (ys @ zs) = select k ys"
          "k  {length ys..<length ys + length zs} 
             select k (ys @ zs) = select (k - length ys) zs"
  using assms by (simp_all add: select_def sort_append nth_append)

(* Combined form of select_append for any valid index k. *)
lemma select_append':
  assumes "y z. y  set ys  z  set zs  y  z" "k < length ys + length zs"
  shows   "select k (ys @ zs) = (if k < length ys then select k ys else select (k - length ys) zs)"
  using assms by (auto intro!: select_append)


text ‹
  We can find simple upper bounds for the number of elements that are strictly less than (resp.
  greater than) the median of a list.
›
(* At most (length xs - 1) div 2 elements are strictly smaller than the median: they all lie
   before the median's position in the sorted list (sorted_filter_less_subset_take). *)
lemma size_less_than_median:
  "size {#y ∈# mset xs. y < median xs#}  (length xs - 1) div 2"
proof (cases "xs = []")
  case False
  hence "length xs > 0" by simp
  hence "(length xs - 1) div 2 < length xs" by linarith
  hence "size {#y ∈# mset (sort xs). y < median xs#}  
           size (mset (take ((length xs - 1) div 2) (sort xs)))"
    unfolding median_def select_def using False
    by (intro size_mset_mono sorted_filter_less_subset_take) auto
  thus ?thesis using False by simp
qed auto

(* At most length xs div 2 elements are strictly greater than the median: they all lie after
   the median's position in the sorted list (sorted_filter_greater_subset_drop). *)
lemma size_greater_than_median:
  "size {#y ∈# mset xs. y > median xs#}  length xs div 2"
proof (cases "xs = []")
  case False
  hence "length xs > 0" by simp
  hence "(length xs - 1) div 2 < length xs" by linarith
  hence "size {#y ∈# mset (sort xs). y > median xs#}  
           size (mset (drop (Suc ((length xs - 1) div 2)) (sort xs)))"
    unfolding median_def select_def using False
    by (intro size_mset_mono sorted_filter_greater_subset_drop) auto
  hence "size (filter_mset (λy. y > median xs) (mset xs)) 
           length xs - Suc ((length xs - 1) div 2)" by simp
  also have " = length xs div 2" by linarith
  finally show ?thesis .
qed auto


subsection ‹A more liberal notion of medians›

text ‹
  We now define a more relaxed version of being ``a median'' as opposed to being ``\emph{the}
  median''. A value is a median if at most half the values in the list are strictly smaller 
  than it and at most half are strictly greater. Note that, by this definition, the median does
  not even have to be in the list itself.
›
(* x is a median of xs if at most length xs div 2 elements of xs are strictly smaller than x,
   and at most length xs div 2 elements are strictly greater than x. The definition is
   symmetric in < and >, which the dual-order argument below relies on. *)
definition is_median :: "'a :: linorder  'a list  bool" where
  "is_median x xs  length (filter (λy. y < x) xs)  length xs div 2 
                      length (filter (λy. y > x) xs)  length xs div 2"

text ‹
  We set up some transfer rules for @{const is_median}. In particular, we have a rule that
  shows that something is a median for a list iff it is a median on that list w.\,r.\,t.\ 
  the dual order, which will later allow us to argue by symmetry.
›
context
  includes lifting_syntax
begin
(* is_median is parametric with respect to any relation r that preserves the strict order. *)
lemma transfer_is_median [transfer_rule]:
  assumes [transfer_rule]: "(r ===> r ===> (=)) (<) (<)"
  shows   "(r ===> list_all2 r ===> (=)) is_median is_median"
  unfolding is_median_def by transfer_prover

(* list_all2 with the relation (x = f y) characterises xs = map f ys. *)
lemma list_all2_eq_fun_conv_map: "list_all2 (λx y. x = f y) xs ys  xs = map f ys"
proof
  assume "list_all2 (λx y. x = f y) xs ys"
  thus "xs = map f ys" by induction auto
next
  assume "xs = map f ys"
  moreover have "list_all2 (λx y. x = f y) (map f ys) ys"
    by (induction ys) auto
  ultimately show "list_all2 (λx y. x = f y) xs ys" by simp
qed

(* Transfer rule relating is_median on 'a dual_ord to is_median on 'a through the lifting
   correspondence pcr_dual_ord (=). *)
lemma transfer_is_median_dual_ord [transfer_rule]:
  "(pcr_dual_ord (=) ===> list_all2 (pcr_dual_ord (=)) ===> (=)) is_median is_median"
  by (auto simp: pcr_dual_ord_def cr_dual_ord_def OO_def rel_fun_def is_median_def 
        list_all2_eq_fun_conv_map o_def less_dual_ord.rep_eq)
end

(* x is a median of xs iff to_dual_ord x is a median of the image of xs in the dual order:
   reversing the order only swaps the two conditions of is_median. *)
lemma is_median_to_dual_ord_iff [simp]:
  "is_median (to_dual_ord x) (map to_dual_ord xs)  is_median x xs"
  unfolding is_median_def by transfer auto


text ‹
  The following is an obviously equivalent definition of @{const is_median} in terms of
  multisets that is occasionally nicer to use.
›
(* is_median expressed with sizes of filtered multisets instead of lengths of filtered
   lists. *)
lemma is_median_altdef:
  "is_median x xs  size (filter_mset (λy. y < x) (mset xs))  length xs div 2 
                      size (filter_mset (λy. y > x) (mset xs))  length xs div 2"
proof -
  have *: "length (filter P xs) = size (filter_mset P (mset xs))" for P and xs :: "'a list"
    by (simp flip: mset_filter)
  show ?thesis by (simp only: is_median_def *)
qed

(* is_median depends only on the multiset of the list. *)
lemma is_median_cong:
  assumes "x = y" "mset xs = mset ys"
  shows   "is_median x xs  is_median y ys"
  unfolding is_median_altdef by (simp only: assms mset_eq_length[OF assms(2)])

text ‹
  If an element is a median of a list of odd length, we can add any element to the list
  and the element is still a median. Conversely, if we want to compute a median of a list with
  even length $n$, we can simply drop one element and reduce the problem to finding a median of
  a list of size $n - 1$.
›
(* Prepending an arbitrary element to a list of odd length preserves a median. Each of the
   two counts grows by at most one, and because the old length is odd, the bound
   length div 2 also grows by one. *)
lemma is_median_Cons_odd:
  assumes "is_median x xs" and "odd (length xs)"
  shows   "is_median x (y # xs)"
  using assms by (auto simp: is_median_def)

text ‹
  And, of course, \emph{the} median is a median.
›
(* The median (median xs) is a median in the sense of is_median. This follows from
   size_less_than_median and size_greater_than_median. *)
lemma is_median_median [simp,intro]: "is_median (median xs) xs"
  using size_less_than_median[of xs] size_greater_than_median[of xs]
  unfolding is_median_def size_mset [symmetric] mset_filter by linarith+


subsection ‹Properties of a median-of-medians›

text ‹
  We can now bound the number of list elements that can be strictly smaller than a 
  median-of-medians of a chopped-up list (where each part has length $d$ except for the last one,
  which can also be shorter).

  The core argument is that at least roughly half of the medians of the sublists are greater than
  or equal to the median-of-medians, and about $\frac{d}{2}$ elements in each such sublist are
  greater than or equal to their median and thereby also to the median-of-medians.
›
(* Bound on the number of elements strictly smaller than a median-of-medians x, for an
   arbitrary positive chunk size d. Hypotheses: med returns a median (is_median) of every
   non-empty list with at most d elements, and x is a median of the list of chunk medians.
   With m the number of chunks, the count is at most
     m * (d div 2) + (m div 2) * ((d + 1) div 2).
   The chunks are split into YS1 (median smaller than x) and YS2 (all others). Each chunk in
   YS1 contributes at most d elements. Each chunk in YS2 contributes only elements below its
   own median, i.e. at most d div 2 of them. *)
lemma size_less_than_median_of_medians_strong:
  fixes xs :: "'a :: linorder list" and d :: nat
  assumes d: "d > 0"
  assumes median: "xs. xs  []  length xs  d  is_median (med xs) xs"
  assumes median': "is_median x (map med (chop d xs))"
  defines "m  length (chop d xs)"
  shows   "size {#y ∈# mset xs. y < x#}  m * (d div 2) + m div 2 * ((d + 1) div 2)"
proof -
  define n where [simp]: "n = length xs"
  ― ‹The medians of the sublists›
  define M where "M = mset (map med (chop d xs))"
  define YS where "YS = mset (chop d xs)"
  ― ‹The sublists with a smaller median than the median-of-medians @{term x} and the rest.›
  define YS1 where "YS1 = filter_mset (λys. med ys < x) (mset (chop d xs))"
  define YS2 where "YS2 = filter_mset (λys. ¬(med ys < x)) (mset (chop d xs))"

  ― ‹At most roughly half of the lists have a median that is smaller than @{term x}›
  have "size YS1 = size (image_mset med YS1)" by simp
  also have "image_mset med YS1 = {#y ∈# mset (map med (chop d xs)). y < x#}"
    unfolding YS1_def by (subst filter_mset_image_mset [symmetric]) simp_all
  also have "size   (length (map med (chop d xs))) div 2"
    using median' unfolding is_median_altdef by simp
  also have " = m div 2" by (simp add: m_def)
  finally have size_YS1: "size YS1  m div 2" .

  have "m = size (mset (chop d xs))" by (simp add: m_def)
  also have "mset (chop d xs) = YS1 + YS2" unfolding YS1_def YS2_def
    by (rule multiset_partition)
  finally have m_eq: "m = size YS1 + size YS2" by simp

  ― ‹We estimate the number of elements less than @{term x} by grouping them into elements
      coming from @{term YS1} and elements coming from @{term YS2}. In the first case, we 
      just note that no more than @{term d} elements can come from each sublist, whereas in
      the second case, we make the analysis more precise and note that only elements that are
      less than the median of their sublist can be less than @{term x}.›
  have "{# y ∈# mset xs. y < x#} = {# y ∈# (yschop d xs. mset ys). y < x#}" using d
    by (subst sum_msets_chop) simp_all
  also have " = (yschop d xs. {#y ∈# mset ys. y < x#})"
    by (subst filter_mset_sum_list) (simp add: o_def)
  also have " = (ys∈#YS. {#y ∈# mset ys. y < x#})" unfolding YS_def
    by (subst sum_mset_sum_list [symmetric]) simp_all
  also have "YS = YS1 + YS2"
    by (simp add: YS_def YS1_def YS2_def not_le)
  also have "(ys∈#. {#y ∈# mset ys. y < x#}) = 
               (ys∈#YS1. {#y ∈# mset ys. y < x#}) + (ys∈#YS2. {#y ∈# mset ys. y < x#})"
    by simp
  also have " ⊆# (ys∈#YS1. mset ys) + (ys∈#YS2. {#y ∈# mset ys. y < med ys#})"
    by (intro subset_mset.add_mono sum_mset_mset_mono mset_filter_mono) (auto simp: YS2_def)
  finally have "{# y ∈# mset xs. y < x #} ⊆# " .
  hence "size {# y ∈# mset xs. y < x #}  size " by (rule size_mset_mono)

  ― ‹We do some further straightforward estimations and arrive at our goal.›
  also have " = (ys∈#YS1. length ys) + (x∈#YS2. size {#y ∈# mset x. y < med x#})"
    by (simp add: size_mset_sum_mset_distrib multiset.map_comp o_def)
  also have "(ys∈#YS1. length ys)  (ys∈#YS1. d)"
    by (intro sum_mset_mono) (auto simp: YS1_def length_chop_part_le)
  also have " = size YS1 * d" by simp
  also have d: "d = (d div 2) + ((d + 1) div 2)" using d by linarith
  have "size YS1 * d = size YS1 * (d div 2) + size YS1 * ((d + 1) div 2)"
    by (subst d) (simp add: algebra_simps)
  also have "(ys∈#YS2. size {#y ∈# mset ys. y < med ys#}) 
               (ys∈#YS2. length ys div 2)"
  proof (intro sum_mset_mono size_less_than_median, goal_cases)
    case (1 ys)
    hence "ys  []" "length ys  d" by (auto simp: YS2_def length_chop_part_le)
    from median[OF this] show ?case by (auto simp: is_median_altdef)
  qed
  also have "  (ys∈#YS2. d div 2)"
    by (intro sum_mset_mono div_le_mono diff_le_mono) (auto simp: YS2_def dest: length_chop_part_le)
  also have " = size YS2 * (d div 2)" by simp
  also have "size YS1 * (d div 2) + size YS1 * ((d + 1) div 2) +  =
               m * (d div 2) + size YS1 * ((d + 1) div 2)" by (simp add: m_eq algebra_simps)
  also have "size YS1 * ((d + 1) div 2)  (m div 2) * ((d + 1) div 2)"
    by (intro mult_right_mono size_YS1) auto
  finally show "size {#y ∈# mset xs. y < x#} 
                  m * (d div 2) + m div 2 * ((d + 1) div 2)" by simp_all
qed

text ‹
  We now focus on the case of an odd chopping size and make some further estimations to 
  simplify the above result a little bit.
›
(* Specialisation to the odd chunk size 2d + 1. Let c = (3d + 1) / (2 (2d + 1)) and let n be
   the length of xs. Then at most ceiling (c * n) + (5 * d) div 2 + 1 elements are strictly
   smaller than the median-of-medians x. For chunk size 5 (d = 2), c is 0.7.
   The main case first bounds m div 2 (with m the number of chunks) by n / (2 (2d + 1)) + 1/2,
   then feeds this into size_less_than_median_of_medians_strong. The empty list is handled
   by the final auto. *)
theorem size_less_than_median_of_medians:
  fixes xs :: "'a :: linorder list" and d :: nat
  assumes median: "xs. xs  []  length xs  Suc (2 * d)  is_median (med xs) xs"
  assumes median': "is_median x (map med (chop (Suc (2*d)) xs))"
  defines "n  length xs"
  defines "c  (3 * real d + 1) / (2 * (2 * d + 1))"
  shows   "size {#y ∈# mset xs. y < x#}  nat c * n + (5 * d) div 2 + 1"
proof (cases "xs = []")
  case False
  define m where "m = length (chop (Suc (2*d)) xs)"

  have "real (m div 2)  real (nat real n / (1 + 2 * real d)) / 2"
    by (simp add: m_def length_chop n_def flip: of_nat_int_ceiling)
  also have "real (nat real n / (1 + 2 * real d)) =
               of_int real n / (1 + 2 * real d)"
    by (intro of_nat_nat) (auto simp: divide_simps)
  also have " / 2  (real n / (1 + 2 * real d) + 1) / 2"
    by (intro divide_right_mono) linarith+
  also have " = n / (2 * (2 * real d + 1)) + 1 / 2" by (simp add: field_simps)
  finally have m: "real (m div 2)  " .

  have "size {#y ∈# mset xs. y < x#}  d * m + Suc d * (m div 2)"
    using size_less_than_median_of_medians_strong[of "Suc (2 * d)" med x xs] assms
    unfolding m_def by (simp add: algebra_simps)
  also have "  d * (2 * (m div 2) + 1) + Suc d * (m div 2)"
    by (intro add_mono mult_left_mono) linarith+
  also have " = (3 * d + 1) * (m div 2) + d"
    by (simp add: algebra_simps)
  finally have "real (size {#y ∈# mset xs. y < x#})  real "
    by (subst of_nat_le_iff)
  also have "  (3 * real d + 1) * (n / (2 * (2 * d + 1)) + 1/2) + real d"
    unfolding of_nat_add of_nat_mult of_nat_1 of_nat_numeral
    by (intro add_mono mult_mono order.refl m) (auto simp: m_def length_chop n_def add_ac)
  also have " = c * real n + (5 * real d + 1) / 2"
    by (simp add: field_simps c_def)
  also have "  real (nat c * n + ((5 * d) div 2 + 1))"
    unfolding of_nat_add by (intro add_mono) (linarith, simp add: field_simps)
  finally show ?thesis by (subst (asm) of_nat_le_iff) (simp_all add: add_ac)
qed auto

text ‹
  We get the analogous result for the number of elements that are greater than a median-of-medians
  by looking at the dual order and using the \emph{transfer} method.
›
(* The symmetric bound for elements strictly greater than the median-of-medians. It is
   derived from size_less_than_median_of_medians on 'a dual_ord: med' is med transported to
   the dual order, and the transfer method turns the hypotheses about 'a into the required
   hypotheses about 'a dual_ord. *)
theorem size_greater_than_median_of_medians:
  fixes xs :: "'a :: linorder list" and d :: nat
  assumes median: "xs. xs  []  length xs  Suc (2 * d)  is_median (med xs) xs"
  assumes median': "is_median x (map med (chop (Suc (2*d)) xs))" 
  defines "n  length xs"
  defines "c  (3 * real d + 1) / (2 * (2 * d + 1))"
  shows   "size {#y ∈# mset xs. y > x#}  nat c * n + (5 * d) div 2 + 1"
proof -
  include lifting_syntax
  define med' where "med' = (λxs. to_dual_ord (med (map of_dual_ord xs)))"
  have "xs = map of_dual_ord ys" if "list_all2 cr_dual_ord xs ys" for xs :: "'a list" and ys
    using that by induction (auto simp: cr_dual_ord_def)
  hence [transfer_rule]: "(list_all2 (pcr_dual_ord (=)) ===> pcr_dual_ord (=)) med med'"
    by (auto simp: rel_fun_def pcr_dual_ord_def OO_def med'_def cr_dual_ord_def 
                   dual_ord.to_dual_ord_inverse)

  (* Rewrite "greater than x" in xs as "less than to_dual_ord x" in the mapped list, then
     apply the bound for smaller elements in the dual order. *)
  have "size {#y ∈# mset xs. y > x#} = length (filter (λy. y > x) xs)"
    by (subst size_mset [symmetric]) (simp only: mset_filter)
  also have " = length (map to_dual_ord (filter (λy. y > x) xs))" by simp
  also have "(λy. y > x) = (λy. to_dual_ord y < to_dual_ord x)"
    by transfer simp_all
  hence "length (map to_dual_ord (filter (λy. y > x) xs)) = length (map to_dual_ord (filter  xs))"
    by simp
  also have " = length (filter (λy. y < to_dual_ord x) (map to_dual_ord xs))"
    unfolding filter_map o_def by simp
  also have " = size {#y ∈# mset (map to_dual_ord xs). y < to_dual_ord x#}"
    by (subst size_mset [symmetric]) (simp only: mset_filter)
  also have "  nat (3 * real d + 1) / real (2 * (2 * d + 1)) * length (map to_dual_ord xs)
                    + 5 * d div 2 + 1"
  proof (intro size_less_than_median_of_medians)
    fix xs :: "'a dual_ord list" assume xs: "xs  []" "length xs  Suc (2 * d)"
    from xs show "is_median (med' xs) xs" by (transfer fixing: d) (rule median)
  next
    show "is_median (to_dual_ord x) (map med' (chop (Suc (2 * d)) (map to_dual_ord xs)))"
      by (transfer fixing: d x xs) (use median' in simp_all)
  qed
  finally show ?thesis by (simp add: n_def c_def)
qed


text ‹
  The most important case is that of chopping size 5, since that is the most practical one
  for the median-of-medians selection algorithm. For it, we obtain the following nice
  and simple bounds:
›
(* Chunk size 5, i.e. d = 2: then c = 7/10 and (5 * 2) div 2 + 1 = 6. So at most
   ceiling (0.7 * length xs) + 6 elements are strictly smaller, and at most that many
   strictly greater, than the median-of-medians x. *)
corollary size_less_greater_median_of_medians_5:
  fixes xs :: "'a :: linorder list"
  assumes "xs. xs  []  length xs  5  is_median (med xs) xs"
  assumes "is_median x (map med (chop 5 xs))" 
  shows "length (filter (λy. y < x) xs)  nat 0.7 * length xs + 6"
    and "length (filter (λy. y > x) xs)  nat 0.7 * length xs + 6"
  using size_less_than_median_of_medians[of 2 med x xs]
        size_greater_than_median_of_medians[of 2 med x xs] assms
  by (simp_all add: size_mset [symmetric] mset_filter mult_ac add_ac del: size_mset)


subsection ‹The recursive step›

text ‹
  We now turn to the actual selection algorithm itself. The following simple reduction lemma 
  illustrates the idea of the algorithm quite well already, but it has the disadvantage that,
  if one were to use it as a recursive algorithm, it would only work for lists with distinct
  elements. If the list contains repeated elements, this may not even terminate.

  The basic idea is that we choose some pivot element, partition the list into elements that 
  are bigger than the pivot and those that are not, and then recurse into one of these (hopefully
  smaller) lists.
›
(* Selection via a two-way partition around an arbitrary pivot x. The elements that are not
   greater than x go to ys and the rest go to zs; selection then continues in the part that
   contains index k, with the index shifted by length ys for zs.
   NOTE(review): the assumption d > 0 does not occur in the conclusion and is not used in
   the proof. It is kept so that existing uses of this theorem are unaffected. *)
theorem select_rec_partition:
  assumes "d > 0" "k < length xs"
  shows "select k xs = (
           let (ys, zs) = partition (λy. y  x) xs
           in  if k < length ys then select k ys else select (k - length ys) zs
          )" (is "_ = ?rhs")
proof -
  define ys zs where "ys = filter (λy. y  x) xs" and "zs = filter (λy. ¬(y  x)) xs"
  have "select k xs = select k (ys @ zs)"
    by (intro select_cong) (simp_all add: ys_def zs_def)
  also have " = (if k < length ys then select k ys else select (k - length ys) zs)"
    using assms(2) by (intro select_append') (auto simp: ys_def zs_def sum_length_filter_compl)
  finally show ?thesis by (simp add: ys_def zs_def Let_def o_def)
qed

text ‹
  The following variant uses a three-way partitioning function instead. This way, by the previous
  estimates, the list in the final recursive call has at most roughly $\frac{3d'+1}{2(2d'+1)}$
  times the length of the original list (plus a constant), given that the chopping size is
  $d = 2d'+1$. For a chopping size of 5, we get a factor of $0.7$.
›
(* Split xs into the elements strictly smaller than, equal to, and strictly greater than the
   pivot x. Each part keeps the original relative order, since it is produced by filter. *)
definition threeway_partition :: "'a  'a :: linorder list  'a list × 'a list × 'a list" where
  "threeway_partition x xs = (filter (λy. y < x) xs, filter (λy. y = x) xs, filter (λy. y > x) xs)"

(* Code equations that compute all three parts in a single traversal of the list, instead of
   three separate filter passes. *)
lemma threeway_partition_code [code]:
  "threeway_partition x [] = ([], [], [])"
  "threeway_partition x (y # ys) =
     (case threeway_partition x ys of (ls, es, gs) 
        if y < x then (y # ls, es, gs) else if x = y then (ls, y # es, gs) else (ls, es, y # gs))"
  by (auto simp: threeway_partition_def)

(* Selection via a three-way partition around an arbitrary pivot x. If k falls into the
   block of elements equal to x, the result is x itself. Otherwise selection continues in the
   smaller part (same index) or in the greater part (index shifted by nl + ne).
   NOTE(review): the assumption d > 0 is not used in the proof. It is kept so that existing
   uses of this theorem are unaffected. *)
theorem select_rec_threeway_partition:
  assumes "d > 0" "k < length xs"
  shows "select k xs = (
           let (ls, es, gs) = threeway_partition x xs;
               nl = length ls; ne = length es
           in
             if k < nl then select k ls 
             else if k < nl + ne then x
             else select (k - nl - ne) gs
          )" (is "_ = ?rhs")
proof -
  define ls es gs where "ls = filter (λy. y < x) xs" and "es = filter (λy. y = x) xs"
                    and "gs = filter (λy. y > x) xs"
  define nl ne where [simp]: "nl = length ls" "ne = length es"
  have mset_eq: "mset xs = mset ls + mset es + mset gs" unfolding ls_def es_def gs_def
    by (induction xs) auto
  have length_eq: "length xs = length ls + length es + length gs" unfolding ls_def es_def gs_def
    by (induction xs) (auto simp del: filter_True)
    
  (* Every element of es equals x, so any in-range selection from es returns x. *)
  have [simp]: "select i es = x" if "i < length es" for i
  proof -
    have "select i es  set (sort es)" unfolding select_def
      using that by (intro nth_mem) auto
    hence "select i es  set es" using that by (auto simp: select_def)
    also have "set es  {x}" unfolding es_def by (induction es) auto
    finally show ?thesis by simp
  qed

  (* Replace xs by ls @ (es @ gs), which has the same multiset, then apply select_append'
     twice. *)
  have "select k xs = select k (ls @ (es @ gs))"
    by (intro select_cong) (simp_all add: mset_eq)
  also have " = (if k < nl then select k ls else select (k - nl) (es @ gs))" 
    unfolding nl_ne_def using assms
    by (intro select_append') (auto simp: ls_def es_def gs_def length_eq)
  also have " = (if k < nl then select k ls else if k < nl + ne then x
                    else select (k - nl - ne) gs)" (is "?lhs' = ?rhs'")
  proof (cases "k < nl")
    case False
    hence "?lhs' = select (k - nl) (es @ gs)" by simp
    also have " = (if k - nl < ne then select (k - nl) es else select (k - nl - ne) gs)"
      unfolding nl_ne_def using assms False
      by (intro select_append') (auto simp: ls_def es_def gs_def length_eq)
    also have " = (if k - nl < ne then x else select (k - nl - ne) gs)"
      by simp
    also from False have " = ?rhs'" by auto
    finally show ?thesis .
  qed simp_all
  also have " = ?rhs"
    by (simp add: threeway_partition_def Let_def ls_def es_def gs_def)
  finally show ?thesis .
qed

text ‹
  By the above results, it can be seen quite easily that, in each recursive step, the algorithm
  takes a list of length $n$, does $O(n)$ work for the chopping, computing the medians of the
  sublists, and partitioning, and it calls itself recursively with lists of size at most
  $\lceil 0.2n\rceil$ and $\lceil 0.7n\rceil + 6$, respectively. This means that the runtime
  of the algorithm is bounded above by the Akra--Bazzi-style recurrence
    \[T(n) = T(\lceil 0.2n\rceil) + T(\lceil 0.7n\rceil + 6) + O(n)\]
  which, by the Akra--Bazzi theorem, can be shown to fulfil $T\in \Theta(n)$.

  However, a proper analysis of this would require an actual execution model and some way of 
  measuring the runtime of the algorithm, which is not what we aim to do here. Additionally, the
  entire algorithm can be performed in-place in an imperative way, but this becomes quite tedious.

  Instead of this, we will now focus on developing the above recursion into an executable 
  functional algorithm.
›


subsection ‹Medians of lists of length at most 5›

text ‹
  We now show some basic results about how to efficiently find a median of a list of size
  at most 5. For lengths 1 and 2, this is trivial, since we can just pick any element. For lengths
  3 and 4, we need at most three comparisons. For length 5, we need at most six comparisons.

  This allows us to save some comparisons compared with the naive method of performing insertion
  sort and then returning the element in the middle.
›
(* Median of three values, using at most three comparisons. *)
definition median_3 :: "'a :: linorder  _" where
  "median_3 a b c =
     (if a  b then
        if b  c then b else max a c
      else
        if c  b then b else min a c)"

(* median_3 agrees with the general median on three-element lists. *)
lemma median_3: "median_3 a b c = median [a, b, c]"
  by (auto simp: median_3_def median_def select_def min_def max_def)

(* Helper for median_5. The argument x1 is never used in the body. Under the preconditions of
   the lemma below (x1 <= x2, x4 <= x5, x1 <= x4), x1 is less than or equal to x2, x4 and x5,
   and the median is determined by the other four values alone. *)
definition median_5_aux :: "'a :: linorder  _" where
  "median_5_aux x1 x2 x3 x4 x5 = (
     if x2  x3 then if x2  x4 then min x3 x4 else min x2 x5
     else if x4  x3 then min x3 x5 else min x2 x4)"

(* Correctness of median_5_aux under its preconditions. *)
lemma median_5_aux:
  assumes "x1  x2" "x4  x5" "x1  x4" 
  shows   "median_5_aux x1 x2 x3 x4 x5 = median [x1,x2,x3,x4,x5]"
  using assms by (auto simp: median_5_aux_def median_def select_def min_def)

(* Median of five values. First order the pairs (a, b) and (d, e). Then call median_5_aux so
   that the pair whose smaller element is smaller comes first, which establishes its
   preconditions. This takes at most six comparisons in total. *)
definition median_5 :: "'a :: linorder  _" where
  "median_5 a b c d e = (
     let (x1, x2) = (if a  b then (a, b) else (b, a));
         (x4, x5) = (if d  e then (d, e) else (e, d))
     in
         if x1  x4 then median_5_aux x1 x2 c x4 x5 else median_5_aux x4 x5 c x1 x2)"

(* median_5 agrees with the general median on five-element lists. *)
lemma median_5: "median_5 a b c d e = median [a, b, c, d, e]"
  by (auto simp: median_5_def Let_def median_5_aux intro: median_cong)

(* A median of a list with 1 to 5 elements. For four elements the last one is ignored: a
   median of the first three values is still a median of all four (see is_median_Cons_odd).
   Empty lists and lists with more than five elements map to undefined. *)
fun median_le_5 where
  "median_le_5 [a] = a"
| "median_le_5 [a,b] = a"
| "median_le_5 [a,b,c] = median_3 a b c"
| "median_le_5 [a,b,c,d] = median_3 a b c"
| "median_le_5 [a,b,c,d,e] = median_5 a b c d e"
| "median_le_5 _ = undefined"

(* The result of median_5 is one of its five arguments. *)
lemma median_5_in_set: "median_5 a b c d e  {a, b, c, d, e}"
proof -
  have "median_5 a b c d e  set [a, b, c, d, e]"
    unfolding median_5 by (rule median_in_set) auto
  thus ?thesis by simp
qed

(* For lists with 1 to 5 elements, median_le_5 returns an element of the list. *)
lemma median_le_5_in_set:
  assumes "xs  []" "length xs  5"
  shows   "median_le_5 xs  set xs"
proof (cases xs rule: median_le_5.cases)
  case (5 a b c d e)
  with median_5_in_set[of a b c d e] show ?thesis by simp
qed (insert assms, auto simp: median_3_def min_def max_def)

(* median_le_5 returns a median (is_median) of every list with 1 to 5 elements. For lengths
   1 and 2 this follows directly from is_median_def. For lengths 3 and 5 it follows from
   median_3 / median_5. For length 4, the median of the first three elements is extended with
   is_median_Cons_odd, and the list is reordered with is_median_cong. *)
lemma median_le_5:
  assumes "xs  []" "length xs  5"
  shows   "is_median (median_le_5 xs) xs"
proof (cases xs rule: median_le_5.cases)
  case (3 a b c)
  have "is_median (median xs) xs" by simp
  also have "median xs = median_3 a b c" by (simp add: median_3 3)
  finally show ?thesis using 3 by simp
next
  case (4 a b c d)
  have "is_median (median [a,b,c]) [a,b,c]" by simp
  also have "median [a,b,c] = median_3 a b c" by (simp add: median_3 4)
  finally have "is_median (median_3 a b c) (d # [a,b,c])" by (rule is_median_Cons_odd) auto
  also have "?this  is_median (median_3 a b c) [a,b,c,d]" by (intro is_median_cong) auto
  finally show ?thesis using 4 by simp
next
  case (5 a b c d e)
  have "is_median (median xs) xs" by simp
  also have "median xs = median_5 a b c d e" by (simp add: median_5 5)
  finally show ?thesis using 5 by simp  
qed (insert assms, auto simp: is_median_def)


subsection ‹Median-of-medians selection algorithm›

text ‹
  The fast selection function now simply computes the median-of-medians of the chopped-up list
  as a pivot, partitions the list with respect to that pivot, and recurses into one of
  the resulting sublists.
›
(*
  Median-of-medians selection: fast_select k xs computes the k-th (0-based) order statistic
  select k xs, as established by fast_select_correct below.
  - Lists of length at most 20 are sorted directly.
  - Otherwise, a recursive call computes the pivot x as the median of the medians of the
    groups of five returned by chop 5.
  - threeway_partition then splits xs into the elements less than (ls), equal to (es) and
    greater than (gs) x.
  - The result is either x itself or a recursive call on ls or gs with a suitably shifted
    index.
  Termination depends on the correctness of the pivot computation and is therefore proved
  separately further below; "by auto" discharges the remaining obligations of the function
  package.
*)
function fast_select where
  "fast_select k xs = (
     if length xs ≤ 20 then
       sort xs ! k
     else
       let x = fast_select (((length xs + 4) div 5 - 1) div 2) (map median_le_5 (chop 5 xs));
           (ls, es, gs) = threeway_partition x xs
       in
         if k < length ls then fast_select k ls
         else if k < length ls + length es then x
         else fast_select (k - length ls - length es) gs
      )"
  by auto

text ‹
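  The correctness of this follows easily from the above theorems. However, the proof is
  complicated somewhat by the fact that termination depends on the correctness of the function.
  We therefore first prove correctness only for inputs in the domain of the function, i.e. those
  on which the recursion terminates.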
›
(*
  Correctness of fast_select for every input in its domain, i.e. every input on which the
  recursion terminates.  The proof is by induction over the domain predicate fast_select_dom,
  so the lemma is available to the termination proof below before termination is known.
*)
lemma fast_select_correct_aux:
  assumes "fast_select_dom (k, xs)" "k < length xs"
  shows   "fast_select k xs = select k xs"
  using assms
proof induction
  case (1 k xs)
  show ?case
  proof (cases "length xs ≤ 20")
    case True
    (* Short lists: unfold one step of fast_select and compare with select_def. *)
    thus ?thesis using "1.prems" "1.hyps"
      by (subst fast_select.psimps) (auto simp: select_def)
  next
    case False
    (* Name the pivot, the three parts of the partition, and the sizes of the first two. *)
    define x where
      "x = fast_select (((length xs + 4) div 5 - Suc 0) div 2) (map median_le_5 (chop 5 xs))"
    define ls where "ls = filter (λy. y < x) xs"
    define es where "es = filter (λy. y = x) xs"
    define gs where "gs = filter (λy. y > x) xs"
    define nl ne where "nl = length ls" and "ne = length es"
    note defs = nl_def ne_def x_def ls_def es_def gs_def
    have tw: "(ls, es, gs) = threeway_partition (fast_select (((length xs + 4) div 5 - 1) div 2)
                               (map median_le_5 (chop 5 xs))) xs"
      unfolding threeway_partition_def defs One_nat_def ..
    have tw': "(ls, es, gs) = threeway_partition x xs"
      by (simp add: tw x_def)

    (* Unfold one step of fast_select; psimps applies because xs is in the domain (1.hyps). *)
    have "fast_select k xs = (if k < nl then fast_select k ls else if k < nl + ne then x
                                else fast_select (k - nl - ne) gs)" using "1.hyps" False
      by (subst fast_select.psimps) (simp_all add: threeway_partition_def defs [symmetric])
    (* Replace both recursive calls by select, using the induction hypotheses. *)
    also have "… = (if k < nl then select k ls else if k < nl + ne then x
                       else select (k - nl - ne) gs)"
    proof (intro if_cong refl)
      assume *: "k < nl"
      show "fast_select k ls = select k ls"
        by (rule 1; (rule refl tw)?)
           (insert *, auto simp: False threeway_partition_def ls_def x_def nl_def)+
    next
      assume *: "¬k < nl" "¬k < nl + ne"
      (* ls, es and gs together account for all of xs, so the shifted index lies within gs. *)
      have **: "length xs = length ls + length es + length gs"
        unfolding ls_def es_def gs_def by (induction xs) (auto simp del: filter_True)
      show "fast_select (k - nl - ne) gs = select (k - nl - ne) gs"
        unfolding nl_def ne_def
        by (rule 1; (rule refl tw)?) (insert False * ** ‹k < length xs›, auto simp: nl_def ne_def)
    qed
    (* select satisfies the same recursion over the three-way partition. *)
    also have "… = select k xs" using ‹k < length xs›
      by (subst (3) select_rec_threeway_partition[of "5::nat" _ _ x])
         (unfold Let_def nl_def ne_def ls_def gs_def es_def x_def threeway_partition_def, simp_all)
    finally show ?thesis .
  qed
qed

text ‹
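  Termination of the algorithm is reasonably obvious. The recursive call that computes the pivot
  operates on the list of group medians, which is shorter than the input. The lists that are
  recursed into afterwards never contain the pivot (the median-of-medians), while the original
  list clearly does. The proof is still somewhat technical, though.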
›
(*
  fast_select terminates on all inputs.  The termination measure is the length of the list
  argument.
  - Goal 1 is the recursive call that computes the pivot.  It operates on the list of group
    medians, which is shorter than xs because length xs > 20.
  - For the calls on ls and gs, the pivot x is first shown to be the median of the group
    medians.  This uses fast_select_correct_aux on the inner call, whose termination is among
    the assumptions.
  - Hence x is an element of xs.  Since ls and gs are filtered to exclude x, both are strictly
    shorter than xs.
*)
lemma fast_select_termination: "All fast_select_dom"
proof (relation "measure (length ∘ snd)"; (safe)?, goal_cases)
  case (1 k xs)
  thus ?case
    by (auto simp: length_chop nat_less_iff ceiling_less_iff)
next
  fix k :: nat and xs ls es gs :: "'a list"
  define x where "x = fast_select (((length xs + 4) div 5 - 1) div 2) (map median_le_5 (chop 5 xs))"
  assume A: "¬ length xs ≤ 20"
            "(ls, es, gs) = threeway_partition x xs"
            "fast_select_dom (((length xs + 4) div 5 - 1) div 2,
                             map median_le_5 (chop 5 xs))"
  from A have eq: "ls = filter (λy. y < x) xs" "gs = filter (λy. y > x) xs"
    by (simp_all add: x_def threeway_partition_def)
  (* The number of groups of five, written as a ceiling, and the median index is below it. *)
  have len: "(length xs + 4) div 5 = nat ⌈length xs / 5⌉" by linarith
  have less: "(nat ⌈real (length xs) / 5⌉ - Suc 0) div 2 < nat ⌈real (length xs) / 5⌉"
    using A(1) by linarith
  (* The pivot really is the median of the group medians. *)
  have "x = select (((length xs + 4) div 5 - 1) div 2) (map median_le_5 (chop 5 xs))"
    using less unfolding x_def by (intro fast_select_correct_aux A) (auto simp: length_chop len)
  also have "… = median (map median_le_5 (chop 5 xs))" by (simp add: median_def len length_chop)
  finally have x: "x = …" .
  (* Hence the pivot is an element of xs, since every group median lies in its group. *)
  moreover {
    have "x ∈ set (map median_le_5 (chop 5 xs))"
      using A(1) unfolding x by (intro median_in_set) auto
    also have "… ⊆ (⋃ys∈set (chop 5 xs). {median_le_5 ys})" by auto
    also have "… ⊆ (⋃ys∈set (chop 5 xs). set ys)" using A(1)
      by (intro UN_mono) (auto simp: median_le_5_in_set length_chop_part_le)
    also have "… = set xs" by (subst UN_sets_chop) auto
    finally have "x ∈ set xs" .
  }
  ultimately show "((k, ls), k, xs) ∈ measure (length ∘ snd)"
              and "((k - length ls - length es, gs), k, xs) ∈ measure (length ∘ snd)"
    using A(1) by (auto simp: eq intro!: length_filter_less[of x])
qed


text ‹
  We now have all the ingredients to show that @{const fast_select} terminates and does,
  indeed, compute the $k$-th order statistic.
›
(* Discharge the termination obligation of fast_select with fast_select_termination. *)
termination fast_select by (rule fast_select_termination)

(*
  Total correctness: fast_select terminates everywhere, so the domain assumption of
  fast_select_correct_aux always holds, and fast_select agrees with select on every index
  within bounds.
*)
theorem fast_select_correct: "k < length xs ⟹ fast_select k xs = select k xs"
  using fast_select_termination by (intro fast_select_correct_aux) auto


text ‹
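  The following version, which sorts short lists by insertion sort (@{term "fold insort"}) and
  computes the partition sizes only once, is then suitable for code export.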
›
(*
  Code equation for fast_select.  It is fast_select.simps with two changes:
  - sort is rewritten into insertion sort (fold insort, via sort_conv_fold);
  - the partition sizes are bound once.  nl is the number of elements smaller than the pivot,
    and ne is the number of elements smaller than or equal to it.
*)
lemma fast_select_code [code]:
  "fast_select k xs = (
     if length xs ≤ 20 then
       fold insort xs [] ! k
     else
       let x = fast_select (((length xs + 4) div 5 - 1) div 2) (map median_le_5 (chop 5 xs));
           (ls, es, gs) = threeway_partition x xs;
           nl = length ls; ne = nl + length es
       in
         if k < nl then fast_select k ls
         else if k < ne then x
         else fast_select (k - ne) gs
      )"
  by (subst fast_select.simps) (simp_all only: Let_def algebra_simps sort_conv_fold)

(*
  Code equation for select.  In-range indices are dispatched to fast_select, which is justified
  by fast_select_correct.  Out-of-range indices make the generated code abort with the message
  below.  Logically, the abort branch is just select k xs again (Code.abort_def).
*)
lemma select_code [code]: 
  "select k xs = (if k < length xs then fast_select k xs 
                    else Code.abort (STR ''Selection index out of bounds.'') (λ_. select k xs))"
proof (cases "k < length xs")
  case True
  thus ?thesis by (simp only: if_True fast_select_correct)
qed (simp_all only: Code.abort_def if_False)

end