(* Theory Schoenhage_Strassen_Runtime_Preliminaries *)

subsection "Some Running Time Formalizations"

theory Schoenhage_Strassen_Runtime_Preliminaries
imports
  Main
  "Karatsuba.Time_Monad_Extended"
  "Karatsuba.Main_TM"
  "Karatsuba.Karatsuba_Preliminaries"
  "Karatsuba.Nat_LSBF"
  "Karatsuba.Nat_LSBF_TM"
  "Karatsuba.Estimation_Method"
  "Schoenhage_Strassen_Preliminaries"
  "Akra_Bazzi.Akra_Bazzi"
  "HOL-Library.Landau_Symbols"
begin

text ‹Running-time variant of zip. Lemma val_zip_tm states that it returns the
  same list as zip; lemma time_zip_tm states that its running time equals
  min (length xs) (length ys) + 1, i.e. one step per recursive call producing a
  pair plus one for the terminating call.›

fun zip_tm :: "'a list ⇒ 'b list ⇒ ('a × 'b) list tm" where
"zip_tm xs [] =1 return []"
| "zip_tm [] ys =1 return []"
| "zip_tm (x # xs) (y # ys) =1 do { rs ← zip_tm xs ys; return ((x, y) # rs) }"

lemma val_zip_tm[simp, val_simp]: "val (zip_tm xs ys) = zip xs ys"
  by (induction xs ys rule: zip_tm.induct; simp)

lemma time_zip_tm[simp]: "time (zip_tm xs ys) = min (length xs) (length ys) + 1"
  by (induction xs ys rule: zip_tm.induct; simp)

text ‹Running-time variant of map3 over three lists. The catch-all equation
  ends the traversal as soon as one of the lists is exhausted. Lemma val_map3_tm
  relates the result to map3 applied to the values of f.›

fun map3_tm :: "('a ⇒ 'b ⇒ 'c ⇒ 'd tm) ⇒ 'a list ⇒ 'b list ⇒ 'c list ⇒ 'd list tm" where
"map3_tm f (x # xs) (y # ys) (z # zs) =1 do {
  r ← f x y z;
  rs ← map3_tm f xs ys zs;
  return (r # rs)
}"
| "map3_tm f _ _ _ =1 return []"

lemma val_map3_tm[simp, val_simp]: "val (map3_tm f xs ys zs) = map3 (λx y z. val (f x y z)) xs ys zs"
  by (induction f xs ys zs rule: map3_tm.induct; simp)

text ‹If every call of f on elements of the three lists takes at most c steps,
  each processed triple costs at most c + 1 steps and the terminating call one
  more.›

lemma time_map3_tm_bounded:
  assumes "⋀x y z. x ∈ set xs ⟹ y ∈ set ys ⟹ z ∈ set zs ⟹ time (f x y z) ≤ c"
  shows "time (map3_tm f xs ys zs) ≤ (c + 1) * min (min (length xs) (length ys)) (length zs) + 1"
using assms proof (induction f xs ys zs rule: map3.induct)
  case (1 f x xs y ys z zs)
  (* Instantiate the induction hypothesis; its premise on f follows from the
     premise for the longer lists. *)
  then have ih: "time (map3_tm f xs ys zs) ≤ (c + 1) * min (min (length xs) (length ys)) (length zs) + 1"
    by simp
  from "1.prems" have fxyz: "time (f x y z) ≤ c" by simp
  show ?case
    unfolding map3_tm.simps tm_time_simps
    apply (estimation estimate: ih)
    apply (estimation estimate: fxyz)
    by simp
  (* Remaining cases: at least one list is empty. *)
qed simp_all

text ‹Running-time variant of map4 over four lists. The catch-all equation
  ends the traversal as soon as one of the lists is exhausted. Lemma val_map4_tm
  relates the result to map4 applied to the values of f.›

fun map4_tm :: "('a ⇒ 'b ⇒ 'c ⇒ 'd ⇒ 'e tm) ⇒ 'a list ⇒ 'b list ⇒ 'c list ⇒ 'd list ⇒ 'e list tm" where
"map4_tm f (x # xs) (y # ys) (z # zs) (w # ws) =1 do {
  r ← f x y z w;
  rs ← map4_tm f xs ys zs ws;
  return (r # rs)
}"
| "map4_tm f _ _ _ _ =1 return []"

lemma val_map4_tm[simp, val_simp]: "val (map4_tm f xs ys zs ws) = map4 (λx y z w. val (f x y z w)) xs ys zs ws"
  by (induction f xs ys zs ws rule: map4_tm.induct; simp)

text ‹If every call of f on elements of the four lists takes at most c steps,
  each processed quadruple costs at most c + 1 steps and the terminating call
  one more.›

lemma time_map4_tm_bounded:
  assumes "⋀x y z w. x ∈ set xs ⟹ y ∈ set ys ⟹ z ∈ set zs ⟹ w ∈ set ws ⟹ time (f x y z w) ≤ c"
  shows "time (map4_tm f xs ys zs ws) ≤ (c + 1) * min (min (min (length xs) (length ys)) (length zs)) (length ws) + 1"
using assms proof (induction f xs ys zs ws rule: map4.induct)
  case (1 f x xs y ys z zs w ws)
  (* Instantiate the induction hypothesis; its premise on f follows from the
     premise for the longer lists. *)
  then have ih: "time (map4_tm f xs ys zs ws) ≤ (c + 1) * min (min (min (length xs) (length ys)) (length zs)) (length ws) + 1"
    by simp
  from "1.prems" have fxyzw: "time (f x y z w) ≤ c" by simp
  show ?case
    unfolding map4_tm.simps tm_time_simps
    apply (estimation estimate: ih)
    apply (estimation estimate: fxyzw)
    by simp
  (* Remaining cases: at least one list is empty. *)
qed simp_all

text ‹Running-time variant of map2. It zips the two lists with zip_tm and then
  maps the uncurried f over the pairs with map_tm.›

definition map2_tm where
"map2_tm f xs ys =1 do {
  xys ← zip_tm xs ys;
  map_tm (λ(x,y). f x y) xys
}"

lemma val_map2_tm[simp, val_simp]: "val (map2_tm f xs ys) = map2 (λx y. val (f x y)) xs ys"
  unfolding map2_tm_def by (simp split: prod.splits)

text ‹The time bound is stated for lists of equal length and assumes that
  every call of f on list elements takes at most c steps.›

lemma time_map2_tm_bounded:
  assumes "length xs = length ys"
  assumes "⋀x y. x ∈ set xs ⟹ y ∈ set ys ⟹ time (f x y) ≤ c"
  shows "time (map2_tm f xs ys) ≤ (c + 2) * length xs + 3"
proof -
  (* With equal lengths, time_zip_tm gives length xs + 1 for the zip_tm pass;
     the remaining step comes from the definition itself. *)
  have "time (map2_tm f xs ys) = length xs + 2 + time (map_tm (λ(x, y). f x y) (zip xs ys))"
    unfolding map2_tm_def by (simp add: assms)
  also have "... ≤ length xs + 2 + ((c + 1) * length (zip xs ys) + 1)"
    apply (intro add_mono order.refl time_map_tm_bounded)
    using assms by (auto split: prod.splits elim: in_set_zipE)
  also have "... = (c + 2) * length xs + 3"
    using assms by simp
  finally show ?thesis .
qed

text ‹Running-time variant of rotate_left. It reduces k modulo the list length,
  splits the list at that position and appends the first part to the second.
  The running time is linear in max k (length xs).›

definition rotate_left_tm :: "nat ⇒ 'a list ⇒ 'a list tm" where
"rotate_left_tm k xs =1 do {
  lenxs ← length_tm xs;
  kmod ← k mod⇩t lenxs;
  (xs1, xs2) ← split_at_tm kmod xs;
  xs2 @⇩t xs1
}"

lemma val_rotate_left_tm[simp, val_simp]: "val (rotate_left_tm k xs) = rotate_left k xs"
  unfolding rotate_left_tm_def rotate_left_def by (simp add: Let_def)

lemma time_rotate_left_tm_le: "time (rotate_left_tm k xs) ≤ 13 + 14 * max k (length xs)"
proof -
  obtain xs1 xs2 where 1: "(xs1, xs2) = split_at (k mod length xs) xs"
    by simp
  then have 2: "length xs2 ≤ length xs" by simp
  (* Decompose the running time along the steps of the definition. *)
  have "time (rotate_left_tm k xs) =
    time (length_tm xs) +
    time (k mod⇩t (length xs)) +
    time (split_at_tm (k mod length xs) xs) + time (xs2 @⇩t xs1) + 1"
  unfolding rotate_left_tm_def tm_time_simps val_length_tm val_mod_nat_tm val_split_at_tm
  Product_Type.prod.case 1[symmetric] by simp
  also have "... ≤ (length xs + 1) + (8 * k + 2 * length xs + 7) + (2 * length xs + 3) + (length xs + 1) + 1"
    apply (intro add_mono order.refl)
    subgoal by simp
    subgoal by (estimation estimate: time_mod_nat_tm_le) (rule order.refl)
    subgoal by (simp add: time_split_at_tm)
    subgoal by (simp add: 2)
    done
  also have "... = 13 + 6 * length xs + 8 * k" by simp
  (* Finally 6 * length xs + 8 * k ≤ 14 * max k (length xs). *)
  finally show ?thesis by simp
qed

text ‹Running-time variant of rotate_right, reduced to rotate_left_tm: the list
  is rotated left by length xs - k mod length xs. The running time is linear in
  max k (length xs).›

definition rotate_right_tm :: "nat ⇒ 'a list ⇒ 'a list tm" where
"rotate_right_tm k xs =1 do {
  lenxs ← length_tm xs;
  kmod ← k mod⇩t lenxs;
  rk ← lenxs -⇩t kmod;
  rotate_left_tm rk xs
}"

lemma val_rotate_right_tm[simp, val_simp]: "val (rotate_right_tm k xs) = rotate_right k xs"
  unfolding rotate_right_tm_def rotate_right_def by (simp add: Let_def)

lemma time_rotate_right_tm_le: "time (rotate_right_tm k xs) ≤ 23 + 26 * max k (length xs)"
proof -
  (* Decompose the running time along the steps of the definition. *)
  have "time (rotate_right_tm k xs) =
    time (length_tm xs) +
    time (k mod⇩t length xs) +
    time (length xs -⇩t (k mod length xs)) +
    time (rotate_left_tm (length xs - k mod length xs) xs) + 1"
    unfolding rotate_right_tm_def tm_time_simps val_length_tm val_mod_nat_tm val_minus_nat_tm
    by simp
  (* The left-rotation amount is at most length xs, so the bound of
     time_rotate_left_tm_le becomes 14 * length xs + 13. *)
  also have "... ≤ (length xs + 1) +
    (8 * k + 2 * length xs + 7) +
    (length xs + 1) +
    (14 * length xs + 13) + 1"
    apply (intro add_mono order.refl)
    subgoal by simp
    subgoal by (estimation estimate: time_mod_nat_tm_le) (rule order.refl)
    subgoal by simp
    subgoal by (estimation estimate: time_rotate_left_tm_le) simp
    done
  also have "... = 23 + 18 * length xs + 8 * k" by simp
  finally show ?thesis by simp
qed

subsection "Auxiliary Lemmas for Landau Notation"

text ‹A big-O bound between natural-number functions only holds from some point
  n1 on. If g is positive from n0 on, the finitely many arguments in
  {n0..<n1} can be absorbed into the constant, so that the bound holds for all
  x ≥ n0.›

lemma eventually_early_nat:
  fixes f g :: "nat ⇒ nat"
  assumes "f ∈ O(g)"
  assumes "⋀x. x ≥ n0 ⟹ g x > 0"
  shows "∃c. (∀x. x ≥ n0 ⟶ f x ≤ c * g x)"
proof -
  from landau_o.bigE[OF ‹f ∈ O(g)›]
  obtain c_real where "eventually (λx. norm (f x) ≤ c_real * norm (g x)) sequentially"
    by auto
  then have "eventually (λx. f x ≤ c_real * g x) at_top" by simp
  then obtain n1 where f_le_g_real: "f x ≤ c_real * g x" if "x ≥ n1" for x
    using eventually_at_top_linorder by meson
  (* Round the real constant up to a natural number. *)
  define c where "c = nat (ceiling c_real)"
  then have f_le_g: "f x ≤ c * g x" if "x ≥ n1" for x
  proof -
    have "real (f x) ≤ c_real * real (g x)" using f_le_g_real[OF that] .
    also have "... ≤ real c * real (g x)" unfolding c_def
      by (simp add: mult_mono real_nat_ceiling_ge)
    also have "... = real (c * g x)" by simp
    finally show ?thesis by linarith
  qed
  consider "n1 ≤ n0" | "n1 > n0" by linarith
  then show ?thesis
  proof cases
    case 1
    (* The eventual bound already covers every x ≥ n0. *)
    then show ?thesis
      apply (intro exI[of _ c]) using f_le_g by simp
  next
    case 2
    (* Arguments in {n0..<n1} are bounded by the maximum M of f on that range. *)
    define M where "M = Max (f ` {n0..<n1})"
    define C where "C = (max M 1) * (max c 1)"
    have "f x ≤ C * g x" if "x ≥ n0" for x
    proof (cases "x < n1")
      case True
      then have "f x ≤ M"
        unfolding M_def using 2
        by (intro Max.coboundedI; simp add: that)
      also have "... ≤ C" unfolding C_def
        using nat_mult_max_right by auto
      also have "... ≤ C * g x"
        using assms(2)[OF that] by simp
      finally show ?thesis .
    next
      case False
      then have "f x ≤ c * g x" using f_le_g by simp
      also have "... ≤ C * g x" unfolding C_def using nat_mult_max_left
        by simp
      finally show ?thesis .
    qed
    then show ?thesis by blast
  qed
qed

text ‹Real-valued analogue of eventually_early_nat. If f is non-negative and g is
  at least 1 from n0 on, an eventual big-O bound can be turned into one that
  holds for all x ≥ n0.›

lemma eventually_early_real:
  fixes f g :: "nat ⇒ real"
  assumes "f ∈ O(g)"
  assumes "⋀x. x ≥ n0 ⟹ f x ≥ 0 ∧ g x ≥ 1"
  shows "∃c. (∀x ≥ n0. f x ≤ c * g x)"
proof -
  from landau_o.bigE[OF ‹f ∈ O(g)›]
  obtain c where "eventually (λx. norm (f x) ≤ c * norm (g x)) at_top"
    by auto
  then obtain n1 where f_le_g: "norm (f x) ≤ c * norm (g x)" if "x ≥ n1" for x
    using eventually_at_top_linorder by meson
  consider "n1 ≤ n0" | "n1 > n0" by linarith
  then show ?thesis
  proof cases
    case 1
    (* The eventual bound already covers every x ≥ n0; the norms vanish
       because f and g are non-negative there. *)
    then show ?thesis
      apply (intro exI[of _ c] allI impI)
      subgoal for x using f_le_g[of x] assms(2)[of x] by simp
      done
  next
    case 2
    (* Arguments in {n0..<n1} are bounded by the maximum M of f on that range. *)
    define M where "M = Max (f ` {n0..<n1})"
    define C where "C = (max M 1) * (max c 1)"
    then have "C ≥ 1" using mult_mono[OF max.cobounded2[of 1 M] max.cobounded2[of 1 c]] by argo
    have "C ≥ c" unfolding C_def using mult_mono[OF max.cobounded2[of 1 M] max.cobounded1[of c 1]]
      by linarith
    have "f x ≤ C * g x" if "x ≥ n0" for x
    proof (cases "x < n1")
      case True
      then have "f x ≤ M"
        unfolding M_def using 2
        by (intro Max.coboundedI; simp add: that)
      also have "... ≤ C" unfolding C_def
        using mult_mono[OF max.cobounded1[of M 1] max.cobounded2[of 1 c]] by simp
      also have "... ≤ C * g x"
        using assms(2)[OF that] mult_left_mono[of 1 "g x" C] ‹C ≥ 1› by argo
      finally show ?thesis .
    next
      case False
      then have "f x ≤ c * g x" using f_le_g[of x] assms(2)[OF that] by simp
      also have "... ≤ C * g x" apply (intro mult_mono[OF ‹C ≥ c›])
        subgoal by (rule order.refl)
        subgoal using ‹C ≥ 1› by simp
        subgoal using assms(2)[OF that] by simp
        done
      finally show ?thesis .
    qed
    then show ?thesis by blast
  qed
qed

text ‹The floor of x is a natural number exactly when x is non-negative.›

lemma floor_in_nat_iff: "floor x ∈ ℕ ⟷ x ≥ 0"
proof
  assume "floor x ∈ ℕ"
  then obtain n where "floor x = of_nat n" unfolding Nats_def by auto
  then have "floor x ≥ 0" using of_nat_0_le_iff by simp
  then show "x ≥ 0" by simp
next
  assume "0 ≤ x"
  then have "floor x ≥ 0" by simp
  then obtain n where "floor x = of_nat n" using nat_0_le by metis
  then show "floor x ∈ ℕ" unfolding Nats_def by simp
qed

text ‹If g is eventually at least 1, a big-O bound by g carries over to a bound by
  the natural number nat ⌊g x⌋, with the constant doubled, because
  x ≤ 2 * ⌊x⌋ holds for x ≥ 1.›

lemma bigo_floor:
  fixes f :: "nat ⇒ nat"
  fixes g :: "nat ⇒ real"
  assumes "(λx. real (f x)) ∈ O(g)"
  assumes "eventually (λx. g x ≥ 1) at_top"
  shows "(λx. real (f x)) ∈ O(λx. real (nat (floor (g x))))"
proof -
  (* For x ≥ 1: x ≤ ⌊x⌋ + 1 ≤ 2 * ⌊x⌋. *)
  have ineq: "x ≤ 2 * real_of_int (floor x)" if "x ≥ 1" for x :: real
  proof -
    have "x ≤ real_of_int (floor x) + 1"
      by (rule real_of_int_floor_add_one_ge)
    also have "... ≤ 2 * real_of_int (floor x)"
      using that by simp
    finally show ?thesis .
  qed
  obtain c where "c > 0" and f_le_g: "eventually (λx. real (f x) ≤ c * norm (g x)) at_top"
    using landau_o.bigE[OF assms(1)] by auto
  have "eventually (λx. g x ≤ 2 * of_int (floor (g x))) at_top"
    using eventually_rev_mp[OF assms(2), of "λx. g x ≤ 2 * of_int (floor (g x))"]
    using assms(2) ineq by simp
  then have 1: "eventually (λx. c * g x ≤ (2 * c) * of_int (floor (g x))) at_top"
    using eventually_mp[of "λx. g x ≤ 2 * of_int (floor (g x))" "λx. c * g x ≤ (2 * c) * of_int (floor (g x))"]
    using ‹c > 0› by simp
  (* Eventually g x ≥ 1, so norm (g x) = g x. *)
  have 2: "eventually (λx. c * norm (g x) = c * g x) at_top"
    using eventually_rev_mp[OF assms(2)] by simp
  have 3: "eventually (λx. c * norm (g x) ≤ (2 * c) * of_int (floor (g x))) at_top"
    apply (intro eventually_rev_mp[OF eventually_conj[OF 1 2], of "λx. c * norm (g x) ≤ (2 * c) * of_int (floor (g x))"])
    apply (intro always_eventually allI impI)
    by argo
  have 4: "eventually (λx. real (f x) ≤ (2 * c) * of_int (floor (g x))) at_top"
    apply (intro eventually_rev_mp[OF eventually_conj[OF f_le_g 3], where Q = "λx. real (f x) ≤ (2 * c) * of_int (floor (g x))"])
    by simp
  show ?thesis
    apply (intro landau_o.bigI[where c = "2 * c"])
    subgoal using ‹c > 0› by argo
    subgoal apply (intro eventually_rev_mp[OF eventually_conj[OF 4 assms(2)], where Q = "λx. norm (real (f x)) ≤ (2 * c) * norm (real (nat ⌊g x⌋))"])
      by simp
    done
qed

end