(* Theory Schoenhage_Strassen_TM *)
section "Running Time Formalization"
theory "Schoenhage_Strassen_TM"
imports
"Schoenhage_Strassen"
"../Preliminaries/Schoenhage_Strassen_Preliminaries"
"Z_mod_Fermat_TM"
"Karatsuba.Karatsuba_TM"
"Landau_Symbols.Landau_More"
begin
(* Tick-counted (tm monad) counterpart of solve_special_residue_problem; the
   correspondence is stated by val_solve_special_residue_problem_tm.
   Steps, in order: compute n + 2, keep the first n + 2 entries of ξ, subtract
   them from η with int_lsbf_mod.subtract_mod_tm (locale parameter n + 2),
   shift that difference by 2 ^ n with >>⇩n⇩t, add the shifted and the unshifted
   difference, and finally add this sum to ξ. The running-time lemma
   time_solve_special_residue_problem_tm_le is stated for exactly this
   sequence of steps. *)
definition solve_special_residue_problem_tm where
"solve_special_residue_problem_tm n ξ η =1 do {
  n_plus_2 ← n +⇩t 2;
  ξ_trunc ← take_tm n_plus_2 ξ;
  d ← int_lsbf_mod.subtract_mod_tm n_plus_2 η ξ_trunc;
  two_pow_n ← 2 ^⇩t n;
  d_shifted ← d >>⇩n⇩t two_pow_n;
  d_sum ← d_shifted +⇩n⇩t d;
  ξ +⇩n⇩t d_sum
}"
(* Functional correctness of solve_special_residue_problem_tm: its value is
   solve_special_residue_problem n ξ η. The int_lsbf_mod locale is instantiated
   at n + 2 (which is positive, fact a) so that val_subtract_mod_tm applies. *)
lemma val_solve_special_residue_problem_tm[simp, val_simp]:
"val (solve_special_residue_problem_tm n ξ η) = solve_special_residue_problem n ξ η"
proof -
have a: "n + 2 > 0" by simp
show ?thesis
unfolding solve_special_residue_problem_tm_def solve_special_residue_problem_def
using int_lsbf_mod.val_subtract_mod_tm[OF int_lsbf_mod.intro[OF a]]
by (simp add: Let_def)
qed
(* Running-time bound for solve_special_residue_problem_tm, linear in 2 ^ n,
   length η and length ξ.
   Proof outline: name every intermediate value (n2, ξmod, δ, pown, δ_shifted,
   δ1), split the total time into the cost of each step, bound each step,
   collect the result into 245 + 12 * 2 ^ n + 62 * n + 55 * length η
   + 2 * length ξ, and absorb 62 * n into 62 * 2 ^ n using n < 2 ^ n. *)
lemma time_solve_special_residue_problem_tm_le:
"time (solve_special_residue_problem_tm n ξ η) ≤ 245 + 74 * 2 ^ n + 55 * length η + 2 * length ξ"
proof -
define n2 where "n2 = n + 2"
define ξmod where "ξmod = take n2 ξ"
define δ where "δ = int_lsbf_mod.subtract_mod n2 η ξmod"
define pown where "pown = (2::nat) ^ n"
define δ_shifted where "δ_shifted = δ >>⇩n pown"
define δ1 where "δ1 = add_nat δ_shifted δ"
note defs = n2_def ξmod_def δ_def pown_def δ_shifted_def δ1_def
(* Locale instance at n2 = n + 2, needed for the subtract_mod time/length lemmas. *)
interpret mr: int_lsbf_mod n2 apply (intro int_lsbf_mod.intro) unfolding n2_def by simp
(* Length facts for the intermediate values, used by the per-step bounds. *)
have length_ξmod_le: "length ξmod ≤ n2" unfolding ξmod_def by simp
have length_δ_le: "length δ ≤ max n2 (length η)"
unfolding δ_def mr.subtract_mod_def if_distrib[where f = length] mr.length_reduce
apply (estimation estimate: conjunct2[OF subtract_nat_aux])
using length_ξmod_le by auto
have length_δ1add_le: "max (length δ_shifted) (length δ) ≤ 2 ^ n + (n + 2) + length η"
unfolding δ_shifted_def pown_def
using length_δ_le unfolding n2_def by simp
(* Decompose the running time into the costs of the individual steps. *)
have "time (solve_special_residue_problem_tm n ξ η) =
n + 1 + time (take_tm n2 ξ) + time (int_lsbf_mod.subtract_mod_tm n2 η ξmod) +
time (2 ^⇩t n) +
time (δ >>⇩n⇩t pown) +
time (δ_shifted +⇩n⇩t δ) +
time (ξ +⇩n⇩t δ1) +
1"
unfolding solve_special_residue_problem_tm_def tm_time_simps
by (simp del: One_nat_def add_2_eq_Suc' add: add.assoc[symmetric] defs[symmetric])
(* Bound each step separately (one subgoal per summand). *)
also have "... ≤ n + 1 + (n + 3) + (118 + 51 * (n + 2 + length η)) +
(3 * 2 ^ Suc n + 5 * n + 1) +
(2 * 2 ^ n + 3) +
(2 * 2 ^ n + 2 * length η + 2 * n + 7) +
(2 * length ξ + 2 * 2 ^ n + 2 * n + 2 * length η + 9) +
1"
apply (intro add_mono order.refl)
subgoal apply (estimation estimate: time_take_tm_le) unfolding n2_def by simp
subgoal
apply (estimation estimate: mr.time_subtract_mod_tm_le)
apply (estimation estimate: length_ξmod_le)
apply (estimation estimate: Nat_max_le_sum[of "length η"])
by (simp add: n2_def Nat_max_le_sum)
subgoal by (rule time_power_nat_tm_le)
subgoal unfolding time_shift_right_tm pown_def by simp
subgoal
apply (estimation estimate: time_add_nat_tm_le)
apply (estimation estimate: length_δ1add_le)
by simp
subgoal
apply (estimation estimate: time_add_nat_tm_le)
unfolding δ1_def
apply (estimation estimate: length_add_nat_upper)
apply (estimation estimate: length_δ1add_le)
apply (estimation estimate: Nat_max_le_sum)
by simp
done
(* Collect terms, then absorb 62 * n into 62 * 2 ^ n since n < 2 ^ n. *)
also have "... = 245 + 12 * 2 ^ n + 62 * n + 55 * length η + 2 * length ξ" unfolding n2_def by simp
also have "... ≤ 245 + 74 * 2 ^ n + 55 * length η + 2 * length ξ"
using less_exp[of n] by simp
finally show ?thesis .
qed
(* Tick-counted counterpart of combine_z_aux (see val_combine_z_aux_tm).
   While at least two segments remain, the head segment is split at position k
   with split_at_tm; the first part is pushed onto the accumulator and the rest
   is added (+⇩n⇩t) to the following segment. A single remaining segment is
   pushed unchanged, and once the list is empty the accumulator is reversed
   and concatenated. *)
fun combine_z_aux_tm where
  "combine_z_aux_tm k acc [] =1 rev_tm acc ⤜ concat_tm"
| "combine_z_aux_tm k acc [y] =1 combine_z_aux_tm k (y # acc) []"
| "combine_z_aux_tm k acc (y1 # y2 # ys) =1 do {
     (y1_lo, y1_hi) ← split_at_tm k y1;
     y_sum ← y1_hi +⇩n⇩t y2;
     combine_z_aux_tm k (y1_lo # acc) (y_sum # ys)
   }"
(* Functional correctness of combine_z_aux_tm, by induction along the
   recursion of combine_z_aux. *)
lemma val_combine_z_aux_tm[simp, val_simp]: "val (combine_z_aux_tm l acc zs) = combine_z_aux l acc zs"
by (induction l acc zs rule: combine_z_aux.induct; simp)
(* Running-time bound for combine_z_aux_tm on a nonempty list z # zs, where
   every element of zs has length ≤ lz. The head z may have length lz + 1,
   because in the recursive step it is the result of an add_nat
   (see fact lena below). Proof by induction on zs, generalizing acc and z. *)
lemma time_combine_z_aux_tm_le:
assumes "⋀z. z ∈ set zs ⟹ length z ≤ lz"
assumes "length z ≤ lz + 1"
assumes "l > 0"
shows "time (combine_z_aux_tm l acc (z # zs)) ≤ (2 * l + 2 * lz + 7) * length zs + 3 * (length acc + length zs) + length (concat acc) + length zs * l + lz + 9"
using assms proof (induction zs arbitrary: acc z)
case Nil
then show ?case
by (simp del: One_nat_def)
next
case (Cons z1 zs)
then have len_drop_z: "length (drop l z) ≤ lz" by simp
(* The new head passed to the recursive call still satisfies the length ≤ lz + 1 premise. *)
have lena: "length (add_nat (drop l z) z1) ≤ lz + 1"
apply (estimation estimate: length_add_nat_upper)
using len_drop_z Cons.prems by simp
(* One step = split_at_tm + one addition + the recursive call + 1 tick. *)
have "time (combine_z_aux_tm l acc (z # z1 # zs)) =
time (split_at_tm l z) +
time (drop l z +⇩n⇩t z1) +
time (combine_z_aux_tm l (take l z # acc) ((drop l z +⇩n z1) # zs)) + 1"
by simp
also have "... ≤
(2 * l + 3) +
(2 * lz + 3) +
((2 * l + 2 * lz + 7) * length zs + 3 * (length (take l z # acc) + length zs) +
length (concat (take l z # acc)) + length zs * l + lz + 9) + 1"
apply (intro add_mono order.refl)
subgoal by (simp add: time_split_at_tm)
subgoal
apply (estimation estimate: time_add_nat_tm_le)
using len_drop_z Cons.prems by simp
subgoal
apply (intro Cons.IH)
subgoal using Cons.prems by simp
subgoal using lena .
subgoal using Cons.prems(3) .
done
done
also have "... = (2 * l + 2 * lz + 7) * length (z1 # zs) + 3 * (length acc + 1 + length zs) +
length (concat acc) + length (take l z) + length zs * l + lz + 9"
by simp
(* The pushed segment take l z has length at most l. *)
also have "... ≤ (2 * l + 2 * lz + 7) * length (z1 # zs) + 3 * (length acc + 1 + length zs) +
length (concat acc) + l + length zs * l + lz + 9"
apply (intro add_mono order.refl) by simp
also have "... = (2 * l + 2 * lz + 7) * length (z1 # zs) + 3 * (length acc + length (z1 # zs)) +
length (concat acc) + length (z1 # zs) * l + lz + 9"
by simp
finally show ?case .
qed
(* Entry point of the segment combination: runs combine_z_aux_tm with an
   empty accumulator (one extra tick for the =1). *)
definition combine_z_tm where
  "combine_z_tm k ys =1 combine_z_aux_tm k [] ys"
(* Functional correctness of combine_z_tm, from val_combine_z_aux_tm. *)
lemma val_combine_z_tm[simp, val_simp]: "val (combine_z_tm l zs) = combine_z l zs"
unfolding combine_z_tm_def combine_z_def by simp
(* Running-time bound for combine_z_tm when all segments have length ≤ lz and
   l > 0. The empty list costs exactly 5. Otherwise the bound follows from
   time_combine_z_aux_tm_le with acc = [], and the extra lz is absorbed by the
   factor length zs = length zs' + 1. *)
lemma time_combine_z_tm_le:
assumes "⋀z. z ∈ set zs ⟹ length z ≤ lz"
assumes "l > 0"
shows "time (combine_z_tm l zs) ≤ 10 + (3 * l + 2 * lz + 10) * length zs"
proof (cases zs)
case Nil
then have "time (combine_z_tm l zs) = 5"
unfolding combine_z_tm_def by simp
then show ?thesis by simp
next
case (Cons z zs')
then have "time (combine_z_tm l zs) = time (combine_z_aux_tm l [] (z # zs')) + 1"
unfolding combine_z_tm_def by simp
also have "... ≤ (2 * l + 2 * lz + 7) * length zs' + 3 * (length ([] :: nat_lsbf list) + length zs') + length (concat ([] :: nat_lsbf list)) +
length zs' * l + lz + 9 + 1"
apply (intro add_mono time_combine_z_aux_tm_le order.refl)
subgoal using Cons assms by simp
subgoal using Cons assms by force
subgoal using assms(2) .
done
also have "... = 10 + (3 * l + 2 * lz + 10) * length zs' + lz"
by (simp add: add_mult_distrib)
also have "... ≤ 10 + (3 * l + 2 * lz + 10) * length zs"
unfolding Cons by simp
finally show ?thesis .
qed
(* Arithmetic fact: for m ≥ 3, Suc (m div 2) < m. For both parities of m,
   Suc (m div 2) equals the recursion argument n computed in
   schoenhage_strassen_tm, so the argument decreases.
   NOTE(review): the termination proof visible below does not cite this lemma;
   it may be used elsewhere. *)
lemma schoenhage_strassen_tm_termination_aux: "¬ m < 3 ⟹ Suc (m div 2) < m"
by linarith
(* Tick-counted (tm monad) Schoenhage-Strassen step over int_lsbf_fermat m.
   val_schoenhage_strassen_tm shows it computes schoenhage_strassen m a b on
   carrier inputs.
   - If m < 3: multiply directly with *⇩n⇩t and convert the product with
     int_lsbf_fermat.from_nat_lsbf_tm m.
   - Otherwise: n = (m + 1) div 2 for odd m and (m + 2) div 2 for even m;
     oe_n is n + 1 for odd m and n otherwise. The inputs are subdivided into
     segments of length 2 ^ (n - 1).
     From these segments two families of residues are computed.
     (a) η: the segments are reduced with int_lsbf_mod.reduce_tm (n + 2),
         filled to pad_length = 3 n + 5, concatenated, and multiplied with
         karatsuba_mul_nat_tm. The first four chunks γ0..γ3 of the
         subdivided product are then combined into η.
     (b) ξ: forward FFTs are taken over int_lsbf_fermat n. Recursive calls
         schoenhage_strassen_tm n are made on the entries selected by
         evens_odds_tm False. An inverse FFT follows, then a
         divide_by_power_of_2 / add_fermat correction and a reduction
         give ξ.
     solve_special_residue_problem_tm combines ξ and η pointwise. The results
     are filled, extended by constant segments, and combined with combine_z_tm.
   The binder interval2 is computed but not referenced afterwards. It still
   contributes to the running time; presumably it is kept so that the program
   mirrors the functional definition step by step.
   Termination: the only recursive call is on n, and n < m whenever ¬ m < 3.
   The guard binder is named m_lt_3 because it holds the strict test m < 3. *)
function schoenhage_strassen_tm :: "nat ⇒ nat_lsbf ⇒ nat_lsbf ⇒ nat_lsbf tm" where
"schoenhage_strassen_tm m a b =1 do {
m_lt_3 ← m <⇩t 3;
if m_lt_3 then do {
ab ← a *⇩n⇩t b;
int_lsbf_fermat.from_nat_lsbf_tm m ab
} else do {
odd_m ← odd_tm m;
n ← (if odd_m then do {
m1 ← m +⇩t 1;
m1 div⇩t 2
} else do {
m2 ← m +⇩t 2;
m2 div⇩t 2
});
n_plus_1 ← n +⇩t 1;
n_minus_1 ← n -⇩t 1;
n_plus_2 ← n +⇩t 2;
oe_n ← (if odd_m then return n_plus_1 else return n);
segment_lens ← 2 ^⇩t n_minus_1;
a' ← subdivide_tm segment_lens a;
b' ← subdivide_tm segment_lens b;
α ← map_tm (int_lsbf_mod.reduce_tm n_plus_2) a';
three_n ← 3 *⇩t n;
pad_length ← three_n +⇩t 5;
α_padded ← map_tm (fill_tm pad_length) α;
u ← concat_tm α_padded;
β ← map_tm (int_lsbf_mod.reduce_tm n_plus_2) b';
β_padded ← map_tm (fill_tm pad_length) β;
v ← concat_tm β_padded;
oe_n_plus_1 ← oe_n +⇩t 1;
two_pow_oe_n_plus_1 ← 2 ^⇩t oe_n_plus_1;
uv_length ← pad_length *⇩t two_pow_oe_n_plus_1;
uv_unpadded ← karatsuba_mul_nat_tm u v;
uv ← ensure_length_tm uv_length uv_unpadded;
oe_n_minus_1 ← oe_n -⇩t 1;
two_pow_oe_n_minus_1 ← 2 ^⇩t oe_n_minus_1;
γs ← subdivide_tm pad_length uv;
γ ← subdivide_tm two_pow_oe_n_minus_1 γs;
γ0 ← nth_tm γ 0;
γ1 ← nth_tm γ 1;
γ2 ← nth_tm γ 2;
γ3 ← nth_tm γ 3;
η ← map4_tm
(λx y z w. do {
xmod ← take_tm n_plus_2 x;
ymod ← take_tm n_plus_2 y;
zmod ← take_tm n_plus_2 z;
wmod ← take_tm n_plus_2 w;
xy ← int_lsbf_mod.subtract_mod_tm n_plus_2 xmod ymod;
zw ← int_lsbf_mod.subtract_mod_tm n_plus_2 zmod wmod;
int_lsbf_mod.add_mod_tm n_plus_2 xy zw
})
γ0 γ1 γ2 γ3;
prim_root_exponent ← if odd_m then return 1 else return 2;
fn_carrier_len ← 2 ^⇩t n_plus_1;
a'_carrier ← map_tm (fill_tm fn_carrier_len) a';
b'_carrier ← map_tm (fill_tm fn_carrier_len) b';
a_dft ← int_lsbf_fermat.fft_tm n prim_root_exponent a'_carrier;
b_dft ← int_lsbf_fermat.fft_tm n prim_root_exponent b'_carrier;
a_dft_odds ← evens_odds_tm False a_dft;
b_dft_odds ← evens_odds_tm False b_dft;
c_dft_odds ← map2_tm (schoenhage_strassen_tm n) a_dft_odds b_dft_odds;
prim_root_exponent_2 ← prim_root_exponent *⇩t 2;
c_diffs ← int_lsbf_fermat.ifft_tm n prim_root_exponent_2 c_dft_odds;
two_pow_oe_n ← 2 ^⇩t oe_n;
interval1 ← upt_tm 0 two_pow_oe_n_minus_1;
interval2 ← upt_tm two_pow_oe_n_minus_1 two_pow_oe_n;
two_pow_n ← 2 ^⇩t n;
oe_n_plus_two_pow_n ← oe_n +⇩t two_pow_n;
oe_n_plus_two_pow_n_zeros ← replicate_tm oe_n_plus_two_pow_n False;
oe_n_plus_two_pow_n_one ← oe_n_plus_two_pow_n_zeros @⇩t [True];
ξ' ← map2_tm (λx y. do {
v1 ← prim_root_exponent *⇩t y;
v2 ← oe_n +⇩t v1;
v3 ← v2 -⇩t 1;
summand1 ← int_lsbf_fermat.divide_by_power_of_2_tm x v3;
summand2 ← int_lsbf_fermat.from_nat_lsbf_tm n oe_n_plus_two_pow_n_one;
int_lsbf_fermat.add_fermat_tm n summand1 summand2
})
c_diffs interval1;
ξ ← map_tm (int_lsbf_fermat.reduce_tm n) ξ';
z ← map2_tm (solve_special_residue_problem_tm n) ξ η;
z_filled ← map_tm (fill_tm segment_lens) z;
z_consts ← replicate_tm two_pow_oe_n_minus_1 oe_n_plus_two_pow_n_one;
z_complete ← z_filled @⇩t z_consts;
z_sum ← combine_z_tm segment_lens z_complete;
result ← int_lsbf_fermat.from_nat_lsbf_tm m z_sum;
return result
}
}"
by pat_completeness auto
termination
apply (relation "Wellfounded.measure (λ(n, a, b). n)")
subgoal by blast
subgoal for m by (cases "odd m"; simp)
subgoal for m by (cases "odd m"; simp)
subgoal for m by (cases "odd m"; simp)
subgoal for m by (cases "odd m"; simp)
subgoal for m by (cases "odd m"; simp)
subgoal for m by (cases "odd m"; simp)
subgoal for m by (cases "odd m"; simp)
subgoal for m by (cases "odd m"; simp)
subgoal for m by (cases "odd m"; simp)
subgoal for m by (cases "odd m"; simp)
subgoal for m by (cases "odd m"; simp)
subgoal for m by (cases "odd m"; simp)
subgoal for m by (cases "odd m"; simp)
subgoal for m by (cases "odd m"; simp)
subgoal for m by (cases "odd m"; simp)
subgoal for m by (cases "odd m"; simp)
subgoal for m by (cases "odd m"; simp)
subgoal for m by (cases "odd m"; simp)
subgoal for m by (cases "odd m"; simp)
subgoal for m by (cases "odd m"; simp)
done
(* Additions to schoenhage_strassen_context (declared elsewhere; interpreted on
   m a b in the proofs below). It adds names for intermediate values of
   schoenhage_strassen_tm and val_simp rules that rewrite the corresponding
   tm steps to their functional values. *)
context schoenhage_strassen_context begin
(* γ0 .. γ3: the entries of γ read by nth_tm γ 0 .. nth_tm γ 3. *)
abbreviation γ0 where "γ0 ≡ γ ! 0"
abbreviation γ1 where "γ1 ≡ γ ! 1"
abbreviation γ2 where "γ2 ≡ γ ! 2"
abbreviation γ3 where "γ3 ≡ γ ! 3"
(* Named counterparts of binders of the same names in schoenhage_strassen_tm. *)
definition fn_carrier_len where "fn_carrier_len = (2::nat) ^ (n + 1)"
definition segment_lens where "segment_lens = (2::nat) ^ (n - 1)"
definition interval1 where "interval1 = [0..<2 ^ (oe_n - 1)]"
definition interval2 where "interval2 = [2 ^ (oe_n - 1)..<2 ^ oe_n]"
definition oe_n_plus_two_pow_n_zeros where "oe_n_plus_two_pow_n_zeros = replicate (oe_n + 2 ^ n) False"
definition oe_n_plus_two_pow_n_one where "oe_n_plus_two_pow_n_one = append oe_n_plus_two_pow_n_zeros [True]"
definition z_complete where "z_complete = z_filled @ z_consts"
(* The definitions above (plus c_diffs_def), collected for folding/unfolding. *)
lemmas defs' =
segment_lens_def fn_carrier_len_def
c_diffs_def interval1_def interval2_def
oe_n_plus_two_pow_n_zeros_def oe_n_plus_two_pow_n_one_def
z_complete_def
(* z_filled and z_sum restated in terms of segment_lens and z_complete. *)
lemma z_filled_def': "z_filled = map (fill segment_lens) z"
unfolding z_filled_def defs'[symmetric] by (rule refl)
lemma z_sum_def': "z_sum = combine_z segment_lens z_complete"
unfolding z_sum_def defs'[symmetric] by (rule refl)
lemmas defs'' = defs' z_filled_def' z_sum_def'
lemma segment_lens_pos: "segment_lens > 0" unfolding segment_lens_def by simp
lemma length_γs: "length γs = 2 ^ (oe_n + 1)"
using scuv(1) unfolding defs[symmetric] .
(* Same length, with 2 ^ (oe_n + 1) written as 2 ^ (oe_n - 1) * 4. *)
lemma length_γs': "length γs = 2 ^ (oe_n - 1) * 4"
using two_pow_Suc_oe_n_as_prod length_γs unfolding defs[symmetric]
by simp
(* Values of the four nth_tm lookups on γ. *)
lemma val_nth_γ[simp, val_simp]:
"val (nth_tm γ 0) = γ ! 0"
"val (nth_tm γ 1) = γ ! 1"
"val (nth_tm γ 2) = γ ! 2"
"val (nth_tm γ 3) = γ ! 3"
unfolding defs' using scγ by simp_all
(* Forward FFTs of A.num_blocks_carrier and B.num_blocks_carrier,
   via int_lsbf_fermat.val_fft_tm with m = oe_n. *)
lemma val_fft1[simp, val_simp]: "val (int_lsbf_fermat.fft_tm n prim_root_exponent A.num_blocks_carrier) =
int_lsbf_fermat.fft n prim_root_exponent A.num_blocks_carrier"
by (intro int_lsbf_fermat.val_fft_tm[where m = oe_n] A.length_num_blocks_carrier)
lemma val_fft2[simp, val_simp]: "val (int_lsbf_fermat.fft_tm n prim_root_exponent B.num_blocks_carrier) =
int_lsbf_fermat.fft n prim_root_exponent B.num_blocks_carrier"
by (intro int_lsbf_fermat.val_fft_tm[where m = oe_n] B.length_num_blocks_carrier)
(* Inverse FFT of c_dft_odds with exponent prim_root_exponent * 2,
   via int_lsbf_fermat.val_ifft_tm with m = oe_n - 1. *)
lemma val_ifft[simp, val_simp]: "val (int_lsbf_fermat.ifft_tm n (prim_root_exponent * 2) c_dft_odds) =
int_lsbf_fermat.ifft n (prim_root_exponent * 2) c_dft_odds"
apply (intro int_lsbf_fermat.val_ifft_tm[where m = "oe_n - 1"])
apply (simp add: c_dft_odds_def)
done
end
(* Functional correctness of schoenhage_strassen_tm. On inputs from
   int_lsbf_fermat.fermat_non_unique_carrier m it computes
   schoenhage_strassen m a b. The proof is by strong induction on m:
   - For m < 3, both sides are unfolded and compared directly.
   - Otherwise, schoenhage_strassen_context m a b is interpreted. The
     induction hypothesis at n < m handles the recursive calls (val_ih), the
     val_simp rules rewrite all remaining tm steps, and result_eq closes the
     goal. *)
lemma val_schoenhage_strassen_tm[simp, val_simp]:
assumes "a ∈ int_lsbf_fermat.fermat_non_unique_carrier m"
assumes "b ∈ int_lsbf_fermat.fermat_non_unique_carrier m"
shows "val (schoenhage_strassen_tm m a b) = schoenhage_strassen m a b"
using assms proof (induction m arbitrary: a b rule: less_induct)
case (less m)
show ?case
proof (cases "m < 3")
case True
then show ?thesis
unfolding schoenhage_strassen_tm.simps[of m a b] val_simps
unfolding schoenhage_strassen.simps[of m a b]
using int_lsbf_fermat.val_from_nat_lsbf_tm by simp
next
case False
interpret schoenhage_strassen_context m a b
apply unfold_locales using False less.prems by simp_all
(* Recursive calls: every pair of entries of the odd DFT lists lies in the
   carrier at n < m, so the induction hypothesis applies. *)
have val_ih: "map2 (λx y. val (schoenhage_strassen_tm n x y)) A.num_dft_odds B.num_dft_odds =
map2 (λx y. schoenhage_strassen n x y) A.num_dft_odds B.num_dft_odds"
apply (intro map_cong refl)
subgoal premises prems for p
proof -
from prems set_zip obtain i
where i_le: "i < min (length A.num_dft_odds) (length B.num_dft_odds)"
and p_i: "p = (A.num_dft_odds ! i, B.num_dft_odds ! i)"
by blast
show ?thesis unfolding p_i prod.case
apply (intro less.IH n_lt_m set_subseteqD A.num_dft_odds_carrier B.num_dft_odds_carrier)
using i_le by simp_all
qed
done
have "val (schoenhage_strassen_tm m a b) = result"
unfolding schoenhage_strassen_tm.simps[of m a b]
unfolding val_simp
val_times_nat_tm
val_subdivide_tm[OF segment_lens_pos] val_subdivide_tm[OF pad_length_gt_0]
Znr.val_reduce_tm Znr.val_subtract_mod_tm Znr.val_add_mod_tm
val_nth_γ val_subdivide_tm[OF two_pow_pos] val_fft1 val_fft2 val_ih val_ifft
defs[symmetric] Let_def
Fnr.val_ifft_tm[OF length_c_dft_odds]
using False by argo
then show ?thesis using result_eq by argo
qed
qed
(* Upper bound for time (schoenhage_strassen_tm m a b); see
   time_schoenhage_strassen_tm_le.
   - For m < 3 the bound is the constant 5336.
   - Otherwise n and oe_n are computed as in schoenhage_strassen_tm. The bound
     then sums:
     * explicit terms in 2 ^ m and n * 2 ^ (2 * n);
     * time_karatsuba_mul_nat_bound ((3 * n + 5) * 2 ^ oe_n), presumably for
       the multiplication of the padded operands u and v (confirm);
     * 4 * karatsuba_lower_bound;
     * 2 ^ (oe_n - 1) copies of the bound at n, one per recursive call.
   The simp rule is removed from the simpset, and proofs unfold it explicitly
   via schoenhage_strassen_Fm_bound.simps[of m]. *)
fun schoenhage_strassen_Fm_bound where
"schoenhage_strassen_Fm_bound m = (if m < 3 then 5336 else
let n = (if odd m then (m + 1) div 2 else (m + 2) div 2);
oe_n = (if odd m then n + 1 else n) in
23525 * 2 ^ m + 8093 * (n * 2 ^ (2 * n)) + 8410 +
time_karatsuba_mul_nat_bound ((3 * n + 5) * 2 ^ oe_n) +
4 * karatsuba_lower_bound +
schoenhage_strassen_Fm_bound n * 2 ^ (oe_n - 1))"
declare schoenhage_strassen_Fm_bound.simps[simp del]
lemma time_schoenhage_strassen_tm_le:
assumes "a ∈ int_lsbf_fermat.fermat_non_unique_carrier m"
assumes "b ∈ int_lsbf_fermat.fermat_non_unique_carrier m"
shows "time (schoenhage_strassen_tm m a b) ≤ schoenhage_strassen_Fm_bound m"
using assms proof (induction m arbitrary: a b rule: less_induct)
case (less m)
consider "m = 0" | "m ≥ 1 ∧ m < 3" | "¬ m < 3" by linarith
then show ?case
proof cases
case 1
from less.prems int_lsbf_fermat.fermat_carrier_length
have len_ab: "length a = 2" "length b = 2" unfolding 1 by simp_all
then have len_mul_ab: "length (grid_mul_nat a b) ≤ 4"
using length_grid_mul_nat[of a b] by simp
from 1 have "time (schoenhage_strassen_tm m a b) =
time (m <⇩t 3) +
time (a *⇩n⇩t b) +
time (int_lsbf_fermat.from_nat_lsbf_tm m (grid_mul_nat a b)) + 1"
unfolding schoenhage_strassen_tm.simps[of m a b] time_bind_tm val_less_nat_tm
by (simp del: One_nat_def)
also have "... ≤ (2 * m + 2) +
(8 * length a * max (length a) (length b) + 1) +
int_lsbf_fermat.time_from_nat_lsbf_tm_bound m (length (grid_mul_nat a b)) + 1"
apply (intro add_mono order.refl)
subgoal by (simp add: time_less_nat_tm 1)
subgoal by (rule time_grid_mul_nat_tm_le)
subgoal by (intro int_lsbf_fermat.time_from_nat_lsbf_tm_le_bound order.refl)
done
also have "... ≤ 2 + 33 + 240 + 1"
apply (intro add_mono order.refl)
subgoal unfolding 1 by simp
subgoal unfolding len_ab by simp
subgoal unfolding int_lsbf_fermat.time_from_nat_lsbf_tm_bound.simps[of 0 "(length (grid_mul_nat a b))"] 1
using len_mul_ab by simp
done
also have "... = 276" by simp
finally show ?thesis unfolding schoenhage_strassen_Fm_bound.simps[of m] using 1 by simp
next
case 2
then have "(2::nat) ^ (m + 1) ≥ 4"
using power_increasing[of 2 "m + 1" "2::nat"] by simp
from 2 have "(2::nat) ^ (m + 1) ≤ 8"
using power_increasing[of "m + 1" 3 "2::nat"] by simp
from less.prems have len_ab: "length a = 2 ^ (m + 1)" "length b = 2 ^ (m + 1)"
using int_lsbf_fermat.fermat_carrier_length by simp_all
then have len_ab_le: "length a ≤ 8" "length b ≤ 8"
using ‹2 ^ (m + 1) ≤ 8› by linarith+
have len_mul_ab_le: "length (grid_mul_nat a b) ≤ 2 * 2 ^ (m + 1)"
using length_grid_mul_nat[of a b] len_ab by simp
from 2 have "time (schoenhage_strassen_tm m a b) =
time (m <⇩t 3) +
time (a *⇩n⇩t b) +
time (int_lsbf_fermat.from_nat_lsbf_tm m (grid_mul_nat a b)) + 1"
unfolding schoenhage_strassen_tm.simps[of m a b] time_bind_tm val_less_nat_tm
by (simp del: One_nat_def)
also have "... ≤ (2 * m + 2) +
(8 * length a * max (length a) (length b) + 1) +
(720 + 512 * 2 ^ (m + 1)) + 1"
apply (intro add_mono order.refl)
subgoal by (simp add: time_less_nat_tm 2)
subgoal by (rule time_grid_mul_nat_tm_le)
subgoal using int_lsbf_fermat.time_from_nat_lsbf_tm_le[OF ‹4 ≤ 2 ^ (m + 1)› len_mul_ab_le]
by simp
done
also have "... ≤ 6 + 513 + (720 + 512 * 8) + 1"
apply (intro add_mono mult_le_mono order.refl)
subgoal using 2 by simp
subgoal
apply (estimation estimate: max.boundedI[OF len_ab_le])
using len_ab_le by simp
subgoal using ‹2 ^ (m + 1) ≤ 8› .
done
also have "... = 5336" by simp
finally show ?thesis unfolding schoenhage_strassen_Fm_bound.simps[of m] using 2 by simp
next
case 3
interpret schoenhage_strassen_context m a b
apply unfold_locales using 3 less.prems by simp_all
define time_η where "time_η = time (map4_tm
(λx y z w. do {
xmod ← take_tm (n + 2) x;
ymod ← take_tm (n + 2) y;
zmod ← take_tm (n + 2) z;
wmod ← take_tm (n + 2) w;
xy ← Znr.subtract_mod_tm xmod ymod;
zw ← Znr.subtract_mod_tm zmod wmod;
Znr.add_mod_tm xy zw
})
γ0 γ1 γ2 γ3)" (is "time_η = time (map4_tm ?η_fun _ _ _ _)")
define time_ξ' where "time_ξ' = time (map2_tm (λx y. do {
v1 ← prim_root_exponent *⇩t y;
v2 ← oe_n +⇩t v1;
v3 ← v2 -⇩t 1;
summand1 ← Fnr.divide_by_power_of_2_tm x v3;
summand2 ← Fnr.from_nat_lsbf_tm oe_n_plus_two_pow_n_one;
Fnr.add_fermat_tm summand1 summand2
})
c_diffs interval1)"
define time_ξ where "time_ξ = time (map_tm (int_lsbf_fermat.reduce_tm n) ξ')"
define time_z where "time_z = time (map2_tm (solve_special_residue_problem_tm n) ξ η)"
define time_z_filled where "time_z_filled = time (map_tm (fill_tm segment_lens) z)"
note map_time_defs = time_η_def time_ξ'_def time_ξ_def time_z_def time_z_filled_def
from Fmr.res_carrier_eq have Fm_carrierI: "⋀i. 0 ≤ i ⟹ i < 2 ^ 2 ^ m + 1 ⟹ i ∈ carrier Fm"
by simp
have length_uv_unpadded_le: "length uv_unpadded ≤ 12 * (3 * n + 5) * 2 ^ oe_n +
(6 + 2 * karatsuba_lower_bound)"
unfolding uv_unpadded_def
apply (estimation estimate: length_karatsuba_mul_nat_le)
unfolding A.length_num_Zn_pad B.length_num_Zn_pad pad_length_def by simp
have prim_root_exponent_le: "prim_root_exponent ≤ 2" unfolding prim_root_exponent_def by simp
then have prim_root_exponent_2_le: "prim_root_exponent * 2 ≤ 4"
by simp
have length_interval1: "length interval1 = 2 ^ (oe_n - 1)"
unfolding interval1_def by simp
have length_interval2: "length interval2 = 2 ^ (oe_n - 1)"
unfolding interval2_def using two_pow_oe_n_as_halves by simp
have length_oe_n_plus_two_pow_n_zeros: "length oe_n_plus_two_pow_n_zeros = oe_n + 2 ^ n"
unfolding oe_n_plus_two_pow_n_zeros_def by simp
have length_oe_n_plus_two_pow_n_one: "length oe_n_plus_two_pow_n_one = oe_n + 2 ^ n + 1"
unfolding oe_n_plus_two_pow_n_one_def
using length_oe_n_plus_two_pow_n_zeros by simp
have c_dft_odds_carrier: "set c_dft_odds ⊆ Fnr.fermat_non_unique_carrier"
unfolding c_dft_odds_def
apply (intro set_subseteqI)
subgoal premises prems for i
proof -
have "map2 (schoenhage_strassen n) A.num_dft_odds B.num_dft_odds ! i =
schoenhage_strassen n (A.num_dft_odds ! i) (B.num_dft_odds ! i)"
using nth_map prems by simp
also have "... ∈ Fnr.fermat_non_unique_carrier"
apply (intro conjunct2[OF schoenhage_strassen_correct'])
subgoal
apply (intro set_subseteqD[OF A.num_dft_odds_carrier])
using prems by simp
subgoal
apply (intro set_subseteqD[OF B.num_dft_odds_carrier])
using prems by simp
done
finally show ?thesis .
qed
done
have c_diffs_carrier: "c_diffs ! i ∈ Fnr.fermat_non_unique_carrier" if "i < 2 ^ (oe_n - 1)" for i
unfolding c_diffs_def Fnr.ifft.simps
apply (intro set_subseteqD[OF Fnr.fft_ifft_carrier[of _ "oe_n - 1"]])
subgoal using length_c_dft_odds .
subgoal using c_dft_odds_carrier .
subgoal using Fnr.length_ifft[OF length_c_dft_odds] that by simp
done
have ξ'_carrier: "ξ' ! i ∈ Fnr.fermat_non_unique_carrier" if "i < 2 ^ (oe_n - 1)" for i
proof -
from that have "ξ' ! i = Fnr.add_fermat
(Fnr.divide_by_power_of_2 (c_diffs ! i)
(oe_n + prim_root_exponent * ([0..<2 ^ (oe_n - 1)] ! i) - 1))
(Fnr.from_nat_lsbf (replicate (oe_n + 2 ^ n) False @ [True]))"
unfolding ξ'_def using nth_map2 that length_c_diffs by simp
also have "... ∈ Fnr.fermat_non_unique_carrier"
apply (intro Fnr.add_fermat_closed)
subgoal
by (intro Fnr.divide_by_power_of_2_closed that c_diffs_carrier)
subgoal by (intro Fnr.from_nat_lsbf_correct(1))
done
finally show "ξ' ! i ∈ Fnr.fermat_non_unique_carrier" .
qed
have ξ'_carrier': "set ξ' ⊆ Fnr.fermat_non_unique_carrier"
apply (intro set_subseteqI ξ'_carrier) unfolding length_ξ' .
have length_ξ_entries: "length x ≤ 2 ^ n + 2" if "x ∈ set ξ" for x
proof -
from that obtain x' where "x' ∈ set ξ'" "x = Fnr.reduce x'" unfolding ξ_def
by auto
from that show ?thesis unfolding ‹x = Fnr.reduce x'›
apply (intro Fnr.reduce_correct'(2))
using ‹x' ∈ set ξ'› ξ'_carrier' by auto
qed
have length_η_entries: "length (η ! i) = n + 2" if "i < 2 ^ (oe_n - 1)" for i
proof -
have "η ! i = Znr.add_mod (Znr.subtract_mod (take (n + 2) (γ0 ! i)) (take (n + 2) (γ1 ! i)))
(Znr.subtract_mod (take (n + 2) (γ2 ! i)) (take (n + 2) (γ3 ! i)))"
unfolding η_def Let_def defs'[symmetric]
apply (intro nth_map4)
unfolding length_γs defs' using length_γ_i that by simp_all
then show ?thesis using Znr.add_mod_closed by simp
qed
have length_z_entries: "length (z ! i) ≤ 2 ^ n + n + 4" if "i < 2 ^ (oe_n - 1)" for i
proof -
have "z ! i = solve_special_residue_problem n (ξ ! i) (η ! i)"
unfolding z_def apply (intro nth_map2) using that length_ξ length_η by simp_all
also have "length ... ≤ max (length (ξ ! i))
(2 ^ n + length (Znr.subtract_mod (η ! i) (take (n + 2) (ξ ! i))) + 1) + 1"
unfolding solve_special_residue_problem_def Let_def defs[symmetric]
apply (estimation estimate: length_add_nat_upper)
apply (estimation estimate: length_add_nat_upper)
by (simp del: One_nat_def)
also have "... ≤ max (2 ^ n + 2) ((2 ^ n + (n + 2)) + 1) + 1"
apply (intro add_mono order.refl max.mono)
subgoal using length_ξ_entries nth_mem[of i ξ] length_ξ that by simp
subgoal apply (intro Znr.length_subtract_mod)
subgoal using length_η_entries[OF that] by simp
subgoal by simp
done
done
also have "... = 2 ^ n + n + 4" by simp
finally show ?thesis .
qed
have length_z_filled_entries: "length (z_filled ! i) ≤ 2 ^ n + n + 4" if "i < 2 ^ (oe_n - 1)" for i
proof -
have "z_filled ! i = fill (2 ^ (n - 1)) (z ! i)"
unfolding z_filled_def segment_lens_def
using nth_map[of i z] unfolding length_z
using that by auto
also have "length ... ≤ max (2 ^ (n - 1)) (2 ^ n + n + 4)"
using length_z_entries[OF that] unfolding length_fill' by simp
also have "... ≤ 2 ^ n + n + 4"
apply (intro max.boundedI order.refl)
using power_increasing[of "n - 1" n "2::nat"] by linarith
finally show ?thesis .
qed
have length_z_complete_entries: "length i ≤ 2 ^ n + n + 4" if "i ∈ set z_complete" for i
proof -
from that consider "i ∈ set z_filled" | "i ∈ set z_consts"
unfolding z_complete_def by auto
then show ?thesis
proof cases
case 1
show ?thesis
using iffD1[OF in_set_conv_nth 1] length_z_filled_entries length_z_filled
by auto
next
case 2
then have i_eq: "i = oe_n_plus_two_pow_n_one"
unfolding z_consts_def defs'
by simp
show ?thesis unfolding i_eq length_oe_n_plus_two_pow_n_one
using oe_n_le_n by simp
qed
qed
have length_z_complete: "length z_complete = 2 ^ oe_n"
unfolding z_complete_def
by (simp add: length_z_filled length_z_consts two_pow_oe_n_as_halves)
have length_z_sum_le: "length z_sum ≤ 28 * Fmr.e"
proof -
have "length z_sum ≤ ((2 ^ n + n + 4) + 1) * length z_complete"
unfolding z_sum_def z_complete_def
apply (intro length_combine_z_le segment_lens_pos)
using length_z_complete_entries z_complete_def by simp_all
also have "... = (2 ^ n + n + 5) * 2 ^ oe_n"
unfolding length_z_complete by simp
also have "... ≤ (2 ^ n + 2 ^ n + 5 * 2 ^ n) * (2 * 2 ^ n)"
apply (intro mult_le_mono add_mono order.refl)
subgoal using less_exp by simp
subgoal by simp
subgoal by (estimation estimate: oe_n_le_n; simp)
done
also have "... = 14 * 2 ^ (2 * n)"
by (simp add: mult_2[of n] power_add)
also have "... ≤ 28 * Fmr.e"
using two_pow_two_n_le by simp
finally show ?thesis .
qed
have val_ih: "map2 (λx y. val (schoenhage_strassen_tm n x y)) A.num_dft_odds B.num_dft_odds =
c_dft_odds"
unfolding c_dft_odds_def
apply (intro map_cong ext refl)
subgoal premises prems for p
proof -
from prems obtain i where p_decomp: "i < length A.num_dft_odds" "i < length B.num_dft_odds"
"p = (A.num_dft_odds ! i, B.num_dft_odds ! i)"
using set_zip[of A.num_dft_odds B.num_dft_odds] by auto
show ?thesis unfolding p_decomp prod.case
apply (intro val_schoenhage_strassen_tm)
subgoal using set_subseteqD[OF A.num_dft_odds_carrier]
using p_decomp by simp
subgoal using set_subseteqD[OF B.num_dft_odds_carrier]
using p_decomp by simp
done
qed
done
have ξ'_alt: "map2
(λx y. Fnr.add_fermat
(Fmr.divide_by_power_of_2 x (oe_n + prim_root_exponent * y - 1))
(Fnr.from_nat_lsbf oe_n_plus_two_pow_n_one))
c_diffs interval1 = ξ'"
unfolding ξ'_def Let_def defs'[symmetric] by (rule refl)
have "time_η ≤ ((112 * (n + 2) + 254) + 1) * min (min (min (length γ0) (length γ1)) (length γ2)) (length γ3) + 1"
unfolding time_η_def
apply (intro time_map4_tm_bounded)
unfolding tm_time_simps add.assoc[symmetric] val_take_tm Znr.val_subtract_mod_tm Znr.val_add_mod_tm
subgoal premises prems for x y z w
proof -
have "time (take_tm (n + 2) x) + time (take_tm (n + 2) y) + time (take_tm (n + 2) z) + time (take_tm (n + 2) w) +
time (Znr.subtract_mod_tm (take (n + 2) x) (take (n + 2) y)) +
time (Znr.subtract_mod_tm (take (n + 2) z) (take (n + 2) w)) +
time (Znr.add_mod_tm (Znr.subtract_mod (take (n + 2) x) (take (n + 2) y)) (Znr.subtract_mod (take (n + 2) z) (take (n + 2) w))) ≤
((n + 2) + 1) + ((n + 2) + 1) + ((n + 2) + 1) + ((n + 2) + 1) +
(118 + 51 * (n + 2)) +
(118 + 51 * (n + 2)) +
(14 + 4 * (n + 2) + 2 * (n + 2))"
apply (intro add_mono time_take_tm_le)
subgoal
apply (estimation estimate: Znr.time_subtract_mod_tm_le)
unfolding length_take
apply (estimation estimate: min.cobounded2)
apply (estimation estimate: min.cobounded2)
by (simp add: defs')
subgoal
apply (estimation estimate: Znr.time_subtract_mod_tm_le)
unfolding length_take
apply (estimation estimate: min.cobounded2)
apply (estimation estimate: min.cobounded2)
by (simp add: defs')
subgoal
apply (estimation estimate: Znr.time_add_mod_tm_le)
apply (estimation estimate: Znr.length_subtract_mod[OF length_take_cobounded1 length_take_cobounded1])
apply (estimation estimate: Znr.length_subtract_mod[OF length_take_cobounded1 length_take_cobounded1])
apply simp
done
done
also have "... = 112 * (n + 2) + 254" by simp
finally show ?thesis .
qed
done
also have "... = (255 + 112 * (n + 2)) * 2 ^ (oe_n - 1) + 1"
unfolding length_γs defs' using length_γ_i by simp
also have "... ≤ (255 + 112 * (n + 2)) * 2 ^ n + 1 * 2 ^ n"
apply (intro add_mono mult_le_mono order.refl)
unfolding oe_n_def by simp_all
also have "... = (256 + 112 * (n + 2)) * 2 ^ n"
by (simp add: add_mult_distrib)
also have "... ≤ (128 * (n + 2) + 112 * (n + 2)) * 2 ^ n"
apply (intro add_mono mult_le_mono order.refl)
by simp
finally have time_η_le: "time_η ≤ 240 * (n + 2) * 2 ^ n" by simp
have oe_n_prim_root_le: "oe_n + prim_root_exponent * y - 1 ≤ fn_carrier_len" if "y ∈ set interval1" for y
proof -
have "oe_n + prim_root_exponent * y - 1 ≤ n + prim_root_exponent * y"
using oe_n_minus_1_le_n by simp
also have "... ≤ n + prim_root_exponent * 2 ^ (oe_n - 1)"
using that unfolding interval1_def defs' by simp
also have "... = n + 2 ^ n"
unfolding oe_n_def prim_root_exponent_def
by (cases "odd m"; simp add: n_gt_0 power_Suc[symmetric])
also have "... ≤ 2 ^ n + 2 ^ n"
by simp
also have "... = fn_carrier_len"
unfolding defs' by simp
finally show ?thesis .
qed
have "time_ξ' ≤ ((475 + 378 * Fnr.e) + 2) * length c_diffs + 3"
unfolding time_ξ'_def
apply (intro time_map2_tm_bounded)
subgoal unfolding length_c_diffs length_interval1 by (rule refl)
subgoal premises prems for x y
unfolding tm_time_simps add.assoc[symmetric] val_times_nat_tm defs[symmetric]
val_plus_nat_tm val_minus_nat_tm Fmr.val_divide_by_power_of_2_tm
Fnr.val_from_nat_lsbf_tm
proof -
have "time (prim_root_exponent *⇩t y) +
time (oe_n +⇩t (prim_root_exponent * y)) +
time ((oe_n + prim_root_exponent * y) -⇩t 1) +
time (Fmr.divide_by_power_of_2_tm x (oe_n + prim_root_exponent * y - 1)) +
time (Fnr.from_nat_lsbf_tm oe_n_plus_two_pow_n_one) +
time
(Fnr.add_fermat_tm
(Fmr.divide_by_power_of_2 x (oe_n + prim_root_exponent * y - 1))
(Fnr.from_nat_lsbf oe_n_plus_two_pow_n_one)) ≤
(2 * y + 5) + (oe_n + 1) + 2 + (24 + 26 * fn_carrier_len + 26 * length x) +
(288 * 1 + 144 + (96 + 192 * 1 + 8 * 1 * 1) * Fnr.e) +
(13 + 7 * length x + 21 * Fnr.e)" (is "?t ≤ _")
apply (intro add_mono)
subgoal unfolding time_times_nat_tm
apply (estimation estimate: prim_root_exponent_le)
by simp
subgoal unfolding time_plus_nat_tm by simp
subgoal unfolding time_minus_nat_tm by simp
subgoal apply (estimation estimate: Fmr.time_divide_by_power_of_2_tm_le)
apply (estimation estimate: oe_n_prim_root_le[OF prems(2)])
apply (estimation estimate: Nat_max_le_sum)
by simp
subgoal
apply (intro Fnr.time_from_nat_lsbf_tm_le Fnr.e_ge_4 n_gt_0)
unfolding length_oe_n_plus_two_pow_n_one using oe_n_n_bound_1 by simp
subgoal
apply (estimation estimate: Fnr.time_add_fermat_tm_le)
unfolding Fmr.length_multiply_with_power_of_2 Fnr.length_from_nat_lsbf
apply (estimation estimate: Nat_max_le_sum)
by simp
done
also have "... = 477 + 2 * y + oe_n + 343 * Fnr.e + 33 * length x"
unfolding fn_carrier_len_def by simp
also have "... = 477 + 2 * y + oe_n + 376 * Fnr.e"
using prems set_subseteqI[OF c_diffs_carrier] length_c_diffs by auto
also have "... ≤ 477 + 2 * 2 ^ (oe_n - 1) + oe_n + 376 * Fnr.e"
using prems unfolding interval1_def
by simp
also have "... ≤ 477 + oe_n + 377 * Fnr.e"
unfolding oe_n_def by simp
also have "... ≤ 475 + 378 * Fnr.e"
using oe_n_n_bound_1 by simp
finally show "?t ≤ 475 + 378 * Fnr.e" unfolding defs[symmetric] .
qed
done
also have "... ≤ (475 + 758 * 2 ^ n) * 2 ^ n + 3"
apply (intro add_mono[of _ _ 3] order.refl mult_le_mono)
subgoal by simp
subgoal unfolding length_c_diffs oe_n_def by simp
done
also have "... = 3 + 475 * 2 ^ n + 758 * 2 ^ (2 * n)"
by (simp add: add_mult_distrib power_add mult_2)
finally have time_ξ'_le: "time_ξ' ≤ ..." .
have time_reduce_ξ'_nth: "time (Fnr.reduce_tm i) ≤ 155 + 216 * 2 ^ n" if "i ∈ set ξ'" for i
proof -
have "length i = Fnr.e"
using iffD1[OF in_set_conv_nth that]
Fnr.fermat_carrier_length[OF ξ'_carrier] length_ξ' by auto
show ?thesis
by (estimation estimate: Fnr.time_reduce_tm_le)
(simp add: ‹length i = Fnr.e›)
qed
have "time_ξ ≤ ((155 + 216 * 2 ^ n) + 1) * length ξ' + 1"
unfolding time_ξ_def
by (intro time_map_tm_bounded time_reduce_ξ'_nth)
also have "... ≤ (156 + 216 * 2 ^ n) * 2 ^ n + 1"
unfolding length_ξ' oe_n_def by simp
also have "... = 1 + 156 * 2 ^ n + 216 * 2 ^ (2 * n)"
by (simp add: add_mult_distrib power_add mult_2)
finally have time_ξ_le: "time_ξ ≤ ..." .
have "time_z ≤ ((245 + 74 * 2 ^ n + 55 * (n + 2) + 2 * (2 ^ n + 2)) + 2) * length ξ + 3"
unfolding time_z_def
apply (intro time_map2_tm_bounded)
subgoal unfolding length_ξ length_η by (rule refl)
subgoal premises prems for x y
apply (estimation estimate: time_solve_special_residue_problem_tm_le)
apply (intro add_mono mult_le_mono order.refl)
subgoal using length_η_entries length_η iffD1[OF in_set_conv_nth ‹y ∈ set η›] by auto
subgoal using length_ξ_entries[OF ‹x ∈ set ξ›] .
done
done
also have "... = (361 + 76 * 2 ^ n + 55 * n) * 2 ^ (oe_n - 1) + 3"
unfolding length_ξ by simp
also have "... ≤ (361 + 76 * 2 ^ n + 55 * 2 ^ n) * 2 ^ n + 3"
apply (intro add_mono order.refl mult_le_mono)
subgoal using less_exp by simp
subgoal unfolding oe_n_def by simp
done
also have "... = 131 * 2 ^ (2 * n) + 361 * 2 ^ n + 3"
by (simp add: add_mult_distrib mult_2 power_add)
finally have time_z_le: "time_z ≤ ..." .
have "time_z_filled ≤ ((2 * (2 ^ n + n + 4) + 2 ^ (n - 1) + 5) + 1) * length z + 1"
unfolding time_z_filled_def
apply (intro time_map_tm_bounded)
unfolding time_fill_tm segment_lens_def
using length_z_entries in_set_conv_nth[of _ z] unfolding length_z
by fastforce
also have "... ≤ (2 * 2 ^ n + 2 * n + 2 ^ (n - 1) + 14) * 2 ^ n + 1"
apply (intro add_mono[of _ _ _ 1] mult_le_mono order.refl)
subgoal by simp
subgoal unfolding length_z oe_n_def by simp
done
also have "... ≤ (5 * 2 ^ n + 14) * 2 ^ n + 1"
apply (intro add_mono[of _ _ _ 1] mult_le_mono order.refl)
using less_exp[of n] power_increasing[of "n - 1" n "2::nat"] by linarith
also have "... = 5 * 2 ^ (2 * n) + 14 * 2 ^ n + 1"
by (simp add: add_mult_distrib mult_2 power_add)
finally have time_z_filled_le: "time_z_filled ≤ ..." .
have "time (map2_tm (schoenhage_strassen_tm n) A.num_dft_odds B.num_dft_odds) ≤
(schoenhage_strassen_Fm_bound n + 2) * length A.num_dft_odds + 3"
apply (intro time_map2_tm_bounded)
subgoal unfolding A.length_num_dft_odds B.length_num_dft_odds by (rule refl)
subgoal premises prems for x y
apply (intro less.IH[OF n_lt_m])
subgoal using prems A.num_dft_odds_carrier by blast
subgoal using prems B.num_dft_odds_carrier by blast
done
done
also have "... ≤ schoenhage_strassen_Fm_bound n * 2 ^ (oe_n - 1) + 2 * 2 ^ n + 3"
unfolding A.length_num_dft_odds
using oe_n_minus_1_le_n
by simp
finally have recursive_time: "time (map2_tm (schoenhage_strassen_tm n) A.num_dft_odds B.num_dft_odds) ≤
..." .
have two_pow_pos: "(2::nat) ^ x > 0" for x
by simp
have "time (schoenhage_strassen_tm m a b) =
time (m <⇩t 3) + time (odd_tm m) +
(if odd m then time (m +⇩t 1) + time ((m + 1) div⇩t 2)
else time (m +⇩t 2) + time ((m + 2) div⇩t 2)) +
time (n +⇩t 1) +
time (n -⇩t 1) +
time (n +⇩t 2) +
(if odd m then 0 else 0) +
time (2 ^⇩t (n - 1)) +
time (subdivide_tm segment_lens a) +
time (subdivide_tm segment_lens b) +
time (map_tm Znr.reduce_tm A.num_blocks) +
time (3 *⇩t n) +
time ((3 * n) +⇩t 5) +
time (map_tm (fill_tm pad_length) A.num_Zn) +
time (concat_tm (map (fill pad_length) A.num_Zn)) +
time (map_tm Znr.reduce_tm B.num_blocks) +
time (map_tm (fill_tm pad_length) B.num_Zn) +
time (concat_tm (map (fill pad_length) B.num_Zn)) +
time (oe_n +⇩t 1) +
time (2 ^⇩t (oe_n + 1)) +
time (pad_length *⇩t 2 ^ (oe_n + 1)) +
time (karatsuba_mul_nat_tm A.num_Zn_pad B.num_Zn_pad) +
time (ensure_length_tm uv_length uv_unpadded) +
time (oe_n -⇩t 1) +
time (2 ^⇩t (oe_n - 1)) +
time (subdivide_tm pad_length uv) +
time (subdivide_tm (2 ^ (oe_n - 1)) γs) +
time (nth_tm γ 0) +
time (nth_tm γ 1) +
time (nth_tm γ 2) +
time (nth_tm γ 3) +
time_η +
(if odd m then 0 else 0) +
time (2 ^⇩t (n + 1)) +
time (map_tm (fill_tm fn_carrier_len) A.num_blocks) +
time (map_tm (fill_tm fn_carrier_len) B.num_blocks) +
time (Fnr.fft_tm prim_root_exponent A.num_blocks_carrier) +
time (Fnr.fft_tm prim_root_exponent B.num_blocks_carrier) +
time (evens_odds_tm False A.num_dft) +
time (evens_odds_tm False B.num_dft) +
time (map2_tm (schoenhage_strassen_tm n) A.num_dft_odds B.num_dft_odds) +
time (prim_root_exponent *⇩t 2) +
time (Fnr.ifft_tm (prim_root_exponent * 2) c_dft_odds) +
time (2 ^⇩t oe_n) +
time (upt_tm 0 (2 ^ (oe_n - 1))) +
time (upt_tm (2 ^ (oe_n - 1)) (2 ^ oe_n)) +
time (2 ^⇩t n) +
time (oe_n +⇩t 2 ^ n) +
time (replicate_tm (oe_n + 2 ^ n) False) +
time (oe_n_plus_two_pow_n_zeros @⇩t [True]) +
time_ξ' +
time_ξ +
time_z +
time (map_tm (fill_tm segment_lens) z) +
time (replicate_tm (2 ^ (oe_n - 1)) oe_n_plus_two_pow_n_one) +
time (z_filled @⇩t z_consts) +
time (combine_z_tm segment_lens z_complete) +
time (Fmr.from_nat_lsbf_tm z_sum) +
0 +
1"
unfolding schoenhage_strassen_tm.simps[of m a b] tm_time_simps
unfolding val_simp val_times_nat_tm val_subdivide_tm[OF two_pow_pos] val_subdivide_tm[OF pad_length_gt_0] Znr.val_reduce_tm defs[symmetric]
Let_def val_nth_γ val_fft1 val_fft2 val_ifft val_ih Fnr.val_ifft_tm[OF length_c_dft_odds]
unfolding Eq_FalseI[OF 3] if_False add.assoc[symmetric] time_z_filled_def[symmetric]
apply (intro arg_cong2[where f = "(+)"] refl)
unfolding defs''[symmetric] time_ξ'_def[symmetric] time_η_def[symmetric] time_ξ_def[symmetric]
time_z_def[symmetric] time_z_filled_def[symmetric]
by (intro refl)+
also have "... ≤ 8 + (8 * m + 14) +
(28 + 9 * m) +
(n + 1) +
2 +
(n + 1) +
0 +
(8 * 2 ^ n + 1) +
(10 * 2 ^ m + 2 ^ n + 4) +
(10 * 2 ^ m + 2 ^ n + 4) +
(((2 ^ n + 2 * n + 12) + 1) * length A.num_blocks + 1) +
(7 + 3 * n) +
(6 + 3 * n) +
(((5 * n + 14) + 1) * length A.num_Zn + 1) +
(14 * 2 ^ n + 6 * (n * 2 ^ n) + 1) +
(((2 ^ n + 2 * n + 12) + 1) * length B.num_blocks + 1) +
(((5 * n + 14) + 1) * length B.num_Zn + 1) +
(14 * 2 ^ n + 6 * (n * 2 ^ n) + 1) +
(n + 2) +
(24 * 2 ^ n + 5 * n + 11) +
(12 * (n * 2 ^ n) + 20 * 2 ^ n + 6 * n + 11) +
time (karatsuba_mul_nat_tm A.num_Zn_pad B.num_Zn_pad) +
(168 * (n * 2 ^ n) + 280 * 2 ^ n + (4 * karatsuba_lower_bound + 19)) +
2 +
12 * 2 ^ n +
(14 + 60 * (n * 2 ^ n) + (100 * 2 ^ n + 6 * n)) +
(22 * 2 ^ n + 4) +
(0 + 1) +
(1 + 1) +
(2 + 1) +
(3 + 1) +
(480 * 2 ^ n + 240 * (n * 2 ^ n)) +
0 +
24 * 2 ^ n +
(((3 * 2 ^ n + 5) + 1) * length A.num_blocks + 1) +
(((3 * 2 ^ n + 5) + 1) * length B.num_blocks + 1) +
(2 ^ oe_n * (66 + 87 * Fnr.e) + oe_n * 2 ^ oe_n * (76 + 116 * Fnr.e) +
8 * prim_root_exponent * 2 ^ (2 * oe_n)) +
(2 ^ oe_n * (66 + 87 * Fnr.e) + oe_n * 2 ^ oe_n * (76 + 116 * Fnr.e) +
8 * prim_root_exponent * 2 ^ (2 * oe_n)) +
(2 * 2 ^ n + 1) +
(2 * 2 ^ n + 1) +
(schoenhage_strassen_Fm_bound n * 2 ^ (oe_n - 1) + 2 * 2 ^ n + 3) +
9 +
(2 ^ (oe_n - 1) * (66 + 87 * Fnr.e) +
(oe_n - 1) * 2 ^ (oe_n - 1) * (76 + 116 * Fnr.e) +
8 * (prim_root_exponent * 2) * 2 ^ (2 * (oe_n - 1))) +
24 * 2 ^ n +
(2 * 2 ^ (2 * n) + 5 * 2 ^ n + 2) +
(8 * 2 ^ (2 * n) + 10 * 2 ^ n + 2) +
12 * 2 ^ n +
(n + 2) +
(2 ^ n + n + 2) +
(2 ^ n + n + 2) +
(3 + 475 * 2 ^ n + 758 * 2 ^ (2 * n)) +
(1 + 156 * 2 ^ n + 216 * 2 ^ (2 * n)) +
(131 * 2 ^ (2 * n) + 361 * 2 ^ n + 3) +
(5 * 2 ^ (2 * n) + 14 * 2 ^ n + 1) +
(2 ^ n + 1) +
(2 ^ n + 1) +
(10 + (3 * segment_lens + 2 * (2 ^ n + n + 4) + 10) * length z_complete) +
(8208 + 23488 * 2 ^ m) + 0 + 1"
apply (intro add_mono)
subgoal unfolding time_less_nat_tm by simp
subgoal by (rule time_odd_tm_le)
subgoal
apply (estimation estimate: if_le_max)
unfolding time_plus_nat_tm
apply (estimation estimate: time_divide_nat_tm_le)
apply (estimation estimate: time_divide_nat_tm_le)
by simp
subgoal unfolding time_plus_nat_tm by (rule order.refl)
subgoal unfolding time_minus_nat_tm by simp
subgoal unfolding time_plus_nat_tm by (rule order.refl)
subgoal by simp
subgoal
apply (estimation estimate: time_power_nat_tm_le)
unfolding Suc_diff_1[OF n_gt_0]
using less_exp[of "n - 1"] power_increasing[of "n - 1" n "2::nat"]
by linarith
subgoal
apply (estimation estimate: time_subdivide_tm_le[OF segment_lens_pos])
unfolding A.length_num segment_lens_def power_Suc[symmetric]
Suc_diff_1[OF n_gt_0] by simp
subgoal
apply (estimation estimate: time_subdivide_tm_le[OF segment_lens_pos])
unfolding B.length_num segment_lens_def power_Suc[symmetric]
Suc_diff_1[OF n_gt_0] by simp
subgoal
apply (intro time_map_tm_bounded)
subgoal premises prems for i
proof -
have "time (Znr.reduce_tm i) = 8 + 2 * length i + 2 * n + 4"
unfolding Znr.time_reduce_tm by simp
also have "... = 8 + 2 * 2 ^ (n - 1) + 2 * n + 4"
apply (intro arg_cong2[where f = "(+)"] arg_cong2[where f = "(*)"] refl)
using A.length_nth_num_blocks iffD1[OF in_set_conv_nth prems]
unfolding A.length_num_blocks by auto
also have "... = 2 ^ n + 2 * n + 12"
unfolding power_Suc[symmetric] Suc_diff_1[OF n_gt_0] by simp
finally show ?thesis by simp
qed
done
subgoal by (simp del: One_nat_def)
subgoal by simp
subgoal apply (intro time_map_tm_bounded)
subgoal premises prems for i
proof -
have "time (fill_tm pad_length i) = 2 * length i + 3 * n + 10"
unfolding time_fill_tm pad_length_def by simp
also have "... = 2 * (n + 2) + 3 * n + 10"
apply (intro arg_cong2[where f = "(+)"] arg_cong2[where f = "(*)"] refl)
using A.length_nth_num_Zn iffD1[OF in_set_conv_nth prems]
unfolding A.length_num_Zn by auto
also have "... = 5 * n + 14"
by simp
finally show ?thesis by simp
qed
done
subgoal unfolding time_concat_tm length_map A.num_Zn_pad_def[symmetric] A.length_num_Zn_pad
unfolding A.length_num_Zn pad_length_def
apply (estimation estimate: oe_n_le_n)
by (simp add: add_mult_distrib)
subgoal
apply (intro time_map_tm_bounded)
subgoal premises prems for i
proof -
have "time (Znr.reduce_tm i) = 8 + 2 * length i + 2 * n + 4"
unfolding Znr.time_reduce_tm by simp
also have "... = 8 + 2 * 2 ^ (n - 1) + 2 * n + 4"
apply (intro arg_cong2[where f = "(+)"] arg_cong2[where f = "(*)"] refl)
using B.length_nth_num_blocks iffD1[OF in_set_conv_nth prems]
unfolding B.length_num_blocks by auto
also have "... = 2 ^ n + 2 * n + 12"
unfolding power_Suc[symmetric] Suc_diff_1[OF n_gt_0] by simp
finally show ?thesis by simp
qed
done
subgoal apply (intro time_map_tm_bounded)
subgoal premises prems for i
proof -
have "time (fill_tm pad_length i) = 2 * length i + 3 * n + 10"
unfolding time_fill_tm pad_length_def by simp
also have "... = 2 * (n + 2) + 3 * n + 10"
apply (intro arg_cong2[where f = "(+)"] arg_cong2[where f = "(*)"] refl)
using B.length_nth_num_Zn iffD1[OF in_set_conv_nth prems]
unfolding B.length_num_Zn by auto
also have "... = 5 * n + 14"
by simp
finally show ?thesis by simp
qed
done
subgoal unfolding time_concat_tm length_map B.num_Zn_pad_def[symmetric] B.length_num_Zn_pad
unfolding B.length_num_Zn pad_length_def
apply (estimation estimate: oe_n_le_n)
by (simp add: add_mult_distrib)
subgoal unfolding oe_n_def by simp
subgoal
apply (estimation estimate: time_power_nat_tm_le)
apply (estimation estimate: oe_n_le_n)
by simp_all
subgoal
unfolding time_times_nat_tm pad_length_def
unfolding add_mult_distrib add_mult_distrib2
apply (estimation estimate: oe_n_le_n)
by simp_all
subgoal by (rule order.refl)
subgoal unfolding time_ensure_length_tm
apply (estimation estimate: length_uv_unpadded_le)
unfolding uv_length_def pad_length_def
apply (estimation estimate: oe_n_le_n)
by (simp_all add: add_mult_distrib)
subgoal by simp
subgoal
apply (estimation estimate: time_power_nat_tm_2_le)
apply (estimation estimate: oe_n_minus_1_le_n)
by simp_all
subgoal
apply (estimation estimate: time_subdivide_tm_le[OF pad_length_gt_0])
unfolding uv_length_def length_uv pad_length_def
apply (estimation estimate: oe_n_le_n)
by (simp_all add: add_mult_distrib)
subgoal
apply (estimation estimate: time_subdivide_tm_le[OF two_pow_pos])
unfolding length_γs'
apply (estimation estimate: oe_n_minus_1_le_n)
by simp_all
subgoal using time_nth_tm[of 0 γ] scγ(1) by simp
subgoal using time_nth_tm[of 1 γ] scγ(1) by simp
subgoal using time_nth_tm[of 2 γ] scγ(1) by simp
subgoal using time_nth_tm[of 3 γ] scγ(1) by simp
subgoal
apply (estimation estimate: time_η_le)
by (simp add: add_mult_distrib)
subgoal by simp
subgoal
apply (estimation estimate: time_power_nat_tm_2_le)
by simp
subgoal apply (intro time_map_tm_bounded)
unfolding time_fill_tm
subgoal premises prems for i
proof -
have leni: "length i = 2 ^ (n - 1)"
using iffD1[OF in_set_conv_nth prems]
unfolding A.length_num_blocks
using A.length_nth_num_blocks by auto
show ?thesis unfolding leni power_Suc[symmetric] Suc_diff_1[OF n_gt_0]
unfolding fn_carrier_len_def
by simp
qed
done
subgoal apply (intro time_map_tm_bounded)
unfolding time_fill_tm
subgoal premises prems for i
proof -
have leni: "length i = 2 ^ (n - 1)"
using iffD1[OF in_set_conv_nth prems]
unfolding B.length_num_blocks
using B.length_nth_num_blocks by auto
show ?thesis unfolding leni power_Suc[symmetric] Suc_diff_1[OF n_gt_0]
unfolding fn_carrier_len_def
by simp
qed
done
subgoal apply (intro Fnr.time_fft_tm_le A.length_num_blocks_carrier)
using A.fill_num_blocks_carrier
using Fnr.fermat_carrier_length
unfolding defs[symmetric] by blast
subgoal apply (intro Fnr.time_fft_tm_le B.length_num_blocks_carrier)
using B.fill_num_blocks_carrier
using Fnr.fermat_carrier_length
unfolding defs[symmetric] by blast
subgoal
apply (estimation estimate: time_evens_odds_tm_le)
unfolding A.length_num_dft
apply (estimation estimate: oe_n_le_n)
by simp_all
subgoal
apply (estimation estimate: time_evens_odds_tm_le)
unfolding B.length_num_dft
apply (estimation estimate: oe_n_le_n)
by simp_all
subgoal by (rule recursive_time)
subgoal using prim_root_exponent_le by simp
subgoal apply (intro Fnr.time_ifft_tm_le length_c_dft_odds)
using c_dft_odds_carrier Fnr.fermat_carrier_length by auto
subgoal
apply (estimation estimate: time_power_nat_tm_2_le)
apply (estimation estimate: oe_n_le_n)
by simp_all
subgoal apply (estimation estimate: time_upt_tm_le')
apply (estimation estimate: oe_n_minus_1_le_n)
by (simp_all only: power_add[symmetric] mult.assoc mult_2[symmetric])
subgoal apply (estimation estimate: time_upt_tm_le')
apply (estimation estimate: oe_n_le_n)
by (simp_all add: power_add[symmetric] mult_2[symmetric])
subgoal apply (estimation estimate: time_power_nat_tm_2_le)
by (rule order.refl)
subgoal using oe_n_le_n by simp
subgoal unfolding time_replicate_tm
using oe_n_le_n by simp
subgoal using oe_n_le_n
by (simp add: oe_n_plus_two_pow_n_zeros_def)
subgoal by (rule time_ξ'_le)
subgoal by (rule time_ξ_le)
subgoal by (rule time_z_le)
subgoal unfolding time_z_filled_def[symmetric] by (rule time_z_filled_le)
subgoal unfolding time_replicate_tm
using oe_n_minus_1_le_n by simp
subgoal using oe_n_minus_1_le_n by (simp add: length_z_filled)
subgoal apply (intro time_combine_z_tm_le[OF _ segment_lens_pos])
using length_z_complete_entries .
subgoal
apply (estimation estimate: Fmr.time_from_nat_lsbf_tm_le[OF Fmr.e_ge_4, OF m_gt_0 length_z_sum_le])
by simp
subgoal by (rule order.refl)
subgoal by (rule order.refl)
done
also have "... ≤ 8410 + 23508 * 2 ^ m + 2069 * 2 ^ n + 1141 * 2 ^ (2 * n) + 29 * n +
32 * 2 ^ (2 * oe_n) +
2 * (oe_n * (2 ^ oe_n * (76 + 232 * 2 ^ n))) +
2 * (2 ^ oe_n * (66 + 174 * 2 ^ n)) +
2 * (2 ^ oe_n * (6 + 3 * 2 ^ n)) +
492 * (n * 2 ^ n) +
2 * (2 ^ oe_n * (15 + 5 * n)) +
2 * (2 ^ oe_n * (13 + 2 ^ n + 2 * n)) +
17 * m +
time (karatsuba_mul_nat_tm A.num_Zn_pad B.num_Zn_pad) +
4 * karatsuba_lower_bound +
schoenhage_strassen_Fm_bound n * 2 ^ (oe_n - 1) +
2 ^ (oe_n - 1) * (66 + 174 * 2 ^ n) +
(oe_n - 1) * 2 ^ (oe_n - 1) * (76 + 232 * 2 ^ n) +
32 * 2 ^ (2 * (oe_n - 1)) +
(18 + 3 * 2 ^ (n - 1) + 2 * 2 ^ n + 2 * n) * 2 ^ oe_n"
unfolding A.length_num_blocks A.length_num_Zn B.length_num_blocks B.length_num_Zn
apply (estimation estimate: prim_root_exponent_le)
apply (estimation estimate: prim_root_exponent_2_le)
unfolding segment_lens_def length_z_complete
by (simp add: add.assoc[symmetric])
also have "... ≤ 8410 + 23508 * 2 ^ m + 2069 * 2 ^ n + 1141 * 2 ^ (2 * n) + 29 * n +
128 * 2 ^ (2 * n) +
2464 * (n * 2 ^ (2 * n)) +
(264 * 2 ^ n + 696 * 2 ^ (2 * n)) +
(24 * 2 ^ n + 12 * 2 ^ (2 * n)) +
492 * (n * 2 ^ n) +
(60 * 2 ^ n + 20 * (n * 2 ^ n)) +
(52 * 2 ^ n + 4 * 2 ^ (2 * n) + 8 * (n * 2 ^ n)) +
17 * m +
time (karatsuba_mul_nat_tm A.num_Zn_pad B.num_Zn_pad) +
4 * karatsuba_lower_bound +
schoenhage_strassen_Fm_bound n * 2 ^ (oe_n - 1) +
(66 * 2 ^ n + 174 * 2 ^ (2 * n)) +
(76 * (n * 2 ^ n) + n * (232 * 2 ^ (2 * n))) +
32 * 2 ^ (2 * n) +
(36 * 2 ^ n + 10 * 2 ^ (2 * n) + 4 * (n * 2 ^ n))"
apply (intro add_mono order.refl)
subgoal apply (estimation estimate: oe_n_le_n) by simp_all
subgoal
proof -
have "2 * (oe_n * (2 ^ oe_n * (76 + 232 * 2 ^ n))) ≤
2 * ((2 * n) * (2 ^ (n + 1) * (76 + 232 * 2 ^ n)))"
apply (intro add_mono mult_le_mono order.refl)
subgoal apply (estimation estimate: oe_n_le_n)
unfolding mult_2 using n_gt_0 by simp
subgoal by (estimation estimate: oe_n_le_n; simp)
done
also have "... = 8 * n * 2 ^ n * (76 + 232 * 2 ^ n)"
by simp
also have "... ≤ 8 * n * 2 ^ n * (76 * 2 ^ n + 232 * 2 ^ n)"
by (intro add_mono mult_le_mono order.refl; simp)
also have "... = 2464 * (n * 2 ^ (2 * n))"
by (simp add: mult_2 power_add)
finally show ?thesis .
qed
subgoal apply (estimation estimate: oe_n_le_n)
by (simp add: add_mult_distrib2 mult_2 power_add)
subgoal apply (estimation estimate: oe_n_le_n)
by (simp add: add_mult_distrib2 mult_2 power_add)
subgoal apply (estimation estimate: oe_n_le_n)
by (simp add: add_mult_distrib2 mult_2 power_add)
subgoal apply (estimation estimate: oe_n_le_n)
by (simp add: add_mult_distrib2 power_add[symmetric])
subgoal apply (estimation estimate: oe_n_minus_1_le_n)
by (simp add: add_mult_distrib2 mult_2 power_add)
subgoal apply (estimation estimate: oe_n_minus_1_le_n)
by (simp add: add_mult_distrib2 mult_2 power_add)
subgoal apply (estimation estimate: oe_n_minus_1_le_n)
by (simp add: add_mult_distrib2 mult_2 power_add)
subgoal apply (estimation estimate: oe_n_le_n)
using power_increasing[of "n - 1" n "2::nat"]
by (simp add: add_mult_distrib2 add_mult_distrib mult_2[of n, symmetric] power_add[symmetric])
done
also have "... = 600 * (n * 2 ^ n) + 2197 * 2 ^ (2 * n) + 2571 * 2 ^ n +
2696 * (n * 2 ^ (2 * n)) +
8410 +
23508 * 2 ^ m +
29 * n +
17 * m +
time (karatsuba_mul_nat_tm A.num_Zn_pad B.num_Zn_pad) +
4 * karatsuba_lower_bound +
schoenhage_strassen_Fm_bound n * 2 ^ (oe_n - 1)"
by (simp add: add.assoc[symmetric])
also have "... ≤ 600 * (n * 2 ^ n) + 2197 * 2 ^ (2 * n) + 2571 * 2 ^ n +
2696 * (n * 2 ^ (2 * n)) +
8410 +
23508 * 2 ^ m +
29 * n +
17 * m +
time_karatsuba_mul_nat_bound ((3 * n + 5) * 2 ^ oe_n) +
4 * karatsuba_lower_bound +
schoenhage_strassen_Fm_bound n * 2 ^ (oe_n - 1)"
apply (intro add_mono order.refl time_karatsuba_mul_nat_tm_le)
unfolding A.length_num_Zn_pad B.length_num_Zn_pad pad_length_def by simp
also have "... ≤ 600 * (n * 2 ^ (2 * n)) + 2197 * (n * 2 ^ (2 * n)) + 2571 * (n * 2 ^ (2 * n)) +
2696 * (n * 2 ^ (2 * n)) +
8410 +
23508 * 2 ^ m +
29 * (n * 2 ^ (2 * n)) +
17 * 2 ^ m +
time_karatsuba_mul_nat_bound ((3 * n + 5) * 2 ^ oe_n) +
4 * karatsuba_lower_bound +
schoenhage_strassen_Fm_bound n * 2 ^ (oe_n - 1)"
apply (intro add_mono mult_le_mono order.refl power_increasing)
subgoal by simp
subgoal by simp
subgoal using n_gt_0 by simp
subgoal using power_increasing[of n "2 * n" "2::nat"] ‹2 ^ (2 * n) ≤ n * 2 ^ (2 * n)› by linarith
subgoal by simp
subgoal by simp
done
also have "... = 23525 * 2 ^ m + 8093 * (n * 2 ^ (2 * n)) + 8410 +
time_karatsuba_mul_nat_bound ((3 * n + 5) * 2 ^ oe_n) +
4 * karatsuba_lower_bound +
schoenhage_strassen_Fm_bound n * 2 ^ (oe_n - 1)"
by simp
also have "... = schoenhage_strassen_Fm_bound m"
unfolding schoenhage_strassen_Fm_bound.simps[of m] Let_def defs[symmetric] using 3 by argo
finally show ?thesis .
qed
qed
(* Constant factor for the Karatsuba running-time bound:
   time_karatsuba_mul_nat_bound x <= karatsuba_const * nat (floor (x powr log 2 3))
   for every x > 0.  The constant is selected with Hilbert choice (SOME); that a
   suitable constant exists is proved in time_kar_le_kar_const. *)
definition karatsuba_const where
"karatsuba_const = (SOME c. (∀x. x > 0 ⟶ time_karatsuba_mul_nat_bound x ≤ c * nat (floor (real x powr log 2 3))))"
(* Cancellation law a / c * c = a for a nonzero real c.  It is stated as a separate
   lemma so that it can be used as an explicit rule, possibly with [symmetric], in
   the powr calculations below (powr_unbounded, kar_aux_lem). *)
lemma real_divide_mult_eq:
assumes "(c :: real) ≠ 0"
shows "a / c * c = a"
using assms by simp
(* For a positive exponent c, x powr c eventually exceeds any fixed real bound d
   as x tends to infinity. *)
lemma powr_unbounded:
assumes "(c :: real) > 0"
shows "eventually (λx. d ≤ x powr c) at_top"
proof (cases "d > 0")
case True
(* Threshold N = d powr (1 / c) satisfies N powr c = d.  Hence every x >= N gives
   d <= x powr c by monotonicity of powr in its base (powr_mono2). *)
define N where "N = d powr (1 / c)"
have "d ≤ x powr c" if "x ≥ N" for x
proof -
have "d = d powr 1" apply (intro powr_one[symmetric]) using True by simp
also have "... = (d powr (1 / c)) powr c"
unfolding powr_powr
apply (intro arg_cong2[where f = "(powr)"] refl real_divide_mult_eq[symmetric]) using assms by simp
also have "... = N powr c" unfolding N_def by (rule refl)
also have "... ≤ x powr c"
apply (intro powr_mono2)
subgoal using assms by simp
subgoal unfolding N_def by (rule powr_ge_pzero)
subgoal by (rule that)
done
finally show ?thesis .
qed
then show ?thesis unfolding eventually_at_top_linorder by blast
next
case False
(* d <= 0: the bound holds for every x, because x powr c is never negative
   (powr_ge_pzero). *)
then show ?thesis
apply (intro always_eventually allI)
subgoal for x using powr_ge_pzero[of x c] by argo
done
qed
(* karatsuba_const satisfies its defining property.
   Proof sketch: time_karatsuba_mul_nat_bound is big-O of floor (x powr log 2 3),
   via time_karatsuba_mul_nat_bound_bigo and bigo_floor.  eventually_early_nat
   turns this asymptotic statement into one explicit constant that works for all
   x >= 1.  The second subgoal shows floor (x powr log 2 3) >= 1 for x >= 1;
   presumably this is the positivity side condition eventually_early_nat requires
   (TODO confirm).  Finally someI_ex moves the witness into the SOME-term of
   karatsuba_const_def. *)
lemma time_kar_le_kar_const:
assumes "x > 0"
shows "time_karatsuba_mul_nat_bound x ≤ karatsuba_const * nat (floor (real x powr log 2 3))"
proof -
have "∃c. (∀x. x ≥ 1 ⟶ time_karatsuba_mul_nat_bound x ≤ c * nat (floor (real x powr log 2 3)))"
apply (intro eventually_early_nat)
subgoal
apply (intro bigo_floor)
subgoal by (rule time_karatsuba_mul_nat_bound_bigo)
subgoal apply (intro eventually_nat_real[OF powr_unbounded[of "log 2 3" 1]]) by simp
done
subgoal premises prems for x
proof -
have "real x ≥ 1" using prems by simp
then have "real x powr log 2 3 ≥ 1 powr log 2 3"
by (intro powr_mono2; simp)
then have "real x powr log 2 3 ≥ 1" by simp
then have "floor (real x powr log 2 3) ≥ 1" by simp
then show ?thesis by simp
qed
done
(* Restate the existential for x > 0 instead of x >= 1, which is the same on
   naturals (handled by metis), and extract the witness via someI_ex. *)
then have "∀x > 0. time_karatsuba_mul_nat_bound x ≤ karatsuba_const * nat ⌊real x powr log 2 3⌋"
unfolding karatsuba_const_def
apply (intro someI_ex[of "λc. ∀x>0. time_karatsuba_mul_nat_bound x ≤ c * nat ⌊real x powr log 2 3⌋"])
by (metis int_one_le_iff_zero_less nat_int nat_mono nat_one_as_int of_nat_0_less_iff)
then show ?thesis using assms by blast
qed
(* For c > 1, every power n powr d is little-o of the exponential c powr n, viewed
   as functions of a natural number n.  It is obtained from power_smallo_exponential
   via smallo_real_nat_transfer. *)
lemma poly_smallo_exp:
assumes "c > 1"
shows "(λn. (real n) powr d) ∈ o(λn. c powr (real n))"
by (intro smallo_real_nat_transfer power_smallo_exponential assms)
(* (n * 2^n) powr log 2 3 is O(2^(2n)).  This is used to bound the Karatsuba term in
   schoenhage_strassen_Fm_bound_le_schoenhage_strassen_Fm_bound'.
   Idea: pick c = 2 powr (2 / log 2 3 - 1).  Then c > 1 and
   (log 2 c + 1) * log 2 3 = 2.  Since n is o(c^n), n * 2^n is O(c^n * 2^n).
   Raising both sides to the power log 2 3 gives
   2 powr (n * (log 2 c + 1) * log 2 3) = 2 powr (2 n). *)
lemma kar_aux_lem: "(λn. real (n * 2 ^ n) powr log 2 3) ∈ O(λn. real (2 ^ (2 * n)))"
proof -
define c where "c = 2 powr (2 / log 2 3 - 1)"
(* c > 1 because the exponent 2 / log 2 3 - 1 is positive, i.e. log 2 3 < 2,
   which follows from 3 < 2 powr 2. *)
have "c > 1" unfolding c_def
apply (intro gr_one_powr)
subgoal by simp
subgoal apply simp using less_powr_iff[of 2 3 2] by simp
done
(* The key identity that makes the final exponent exactly 2 * n. *)
have 1: "(log 2 c + 1) * log 2 3 = 2"
proof -
have "log 2 c = 2 / log 2 3 - 1"
unfolding c_def by (intro log_powr_cancel; simp)
then have "log 2 c + 1 = 2 / log 2 3" by simp
then have "(log 2 c + 1) * log 2 3 = 2 / log 2 3 * log 2 3" by simp
also have "... = 2" apply (intro real_divide_mult_eq)
using zero_less_log_cancel_iff[of 2 3] by linarith
finally show ?thesis .
qed
(* n is o(c^n); multiply both sides by 2^n, weaken to big-O, then raise to the
   power log 2 3 and rewrite everything as a power of 2. *)
from poly_smallo_exp[OF ‹c > 1›, of 1] have "real ∈ o(λn. c powr real n)" by simp
then have "(λn. real (n * 2 ^ n)) ∈ o(λn. (c powr real n) * real (2 ^ n))"
by simp
then have "(λn. real (n * 2 ^ n)) ∈ O(λn. (c powr real n) * real (2 ^ n))"
using landau_o.small_imp_big by blast
then have "(λn. real (n * 2 ^ n) powr log 2 3) ∈ O(λn. ((c powr real n) * real (2 ^ n)) powr log 2 3)"
by (intro iffD2[OF bigo_powr_iff]; simp)
also have "... = O(λn. ((c powr real n) * 2 powr (real n)) powr log 2 3)"
using powr_realpow[of 2] by simp
also have "... = O(λn. (((2 powr log 2 c) powr real n) * 2 powr (real n)) powr log 2 3)"
using powr_log_cancel[of 2 c] ‹c > 1› by simp
also have "... = O(λn. 2 powr ((log 2 c * real n + real n) * log 2 3))"
unfolding powr_powr powr_add[symmetric] by (rule refl)
also have "... = O(λn. 2 powr (real n * (log 2 c + 1) * log 2 3))"
apply (intro_cong "[cong_tag_1 (λf. O(f)), cong_tag_2 (powr), cong_tag_2 (*)]" more: refl ext)
by argo
also have "... = O(λn. 2 powr (real n * 2))"
apply (intro_cong "[cong_tag_1 (λf. O(f)), cong_tag_2 (powr)]" more: ext refl)
using 1 by simp
also have "... = O(λn. real (2 ^ (2 * n)))"
apply (intro_cong "[cong_tag_1 (λf. O(f))]" more: ext)
subgoal for n
using powr_realpow[of 2 "2 * n", symmetric]
by (simp add: mult.commute)
done
finally show ?thesis .
qed
(* Witness constant for kar_aux_lem such that the bound holds for every n >= 1.
   It is chosen with Hilbert choice; the defining property is recovered in
   kar_aux_lem_le. *)
definition kar_aux_const where "kar_aux_const = (SOME c. ∀n ≥ 1. real (n * 2 ^ n) powr log 2 3 ≤ c * real (2 ^ (2 * n)))"
(* Explicit, non-asymptotic form of kar_aux_lem with the fixed constant
   kar_aux_const, valid for all n > 0.  eventually_early_real turns the big-O fact
   into a bound for all n >= 1, and someI_ex transfers that bound to the SOME-term
   in kar_aux_const_def. *)
lemma kar_aux_lem_le:
assumes "n > 0"
shows "real (n * 2 ^ n) powr log 2 3 ≤ kar_aux_const * real (2 ^ (2 * n))"
proof -
have "(∃c. ∀n ≥ 1. real (n * 2 ^ n) powr log 2 3 ≤ c * real (2 ^ (2 * n)))"
using eventually_early_real[OF kar_aux_lem] by simp
then have "∀n ≥ 1. real (n * 2 ^ n) powr log 2 3 ≤ kar_aux_const * real (2 ^ (2 * n))"
unfolding kar_aux_const_def apply (intro someI_ex[of "λc. ∀n ≥ 1. real (n * 2 ^ n) powr log 2 3 ≤ c * real (2 ^ (2 * n))"]) .
then show ?thesis using assms by simp
qed
(* kar_aux_const is strictly positive.  If it were <= 0, then kar_aux_lem_le at
   n = 1 would bound the positive number 2 powr log 2 3 by the non-positive
   kar_aux_const * 4. *)
lemma kar_aux_const_gt_0: "kar_aux_const > 0"
proof (rule ccontr)
assume "¬ kar_aux_const > 0"
then have "kar_aux_const ≤ 0" by simp
then show "False" using kar_aux_lem_le[of 1] by simp
qed
(* Natural-number constant that absorbs the Karatsuba term.  The proof of
   schoenhage_strassen_Fm_bound_le_schoenhage_strassen_Fm_bound' shows
   time_karatsuba_mul_nat_bound ((3 * n + 5) * 2 ^ oe_n) <= kar_aux_const_nat * 2 ^ (2 * n).
   Keep the exact shape of the right-hand side: that proof folds it back with
   kar_aux_const_nat_def[symmetric]. *)
definition kar_aux_const_nat where "kar_aux_const_nat = karatsuba_const * nat ⌈16 powr log 2 3⌉ * nat ⌈kar_aux_const⌉"
(* Coefficient of the m * 2 ^ m term in schoenhage_strassen_Fm_bound'.
   55897 = 23525 + 4 * 8093 comes from two terms of schoenhage_strassen_Fm_bound,
   the 23525 * 2 ^ m term and the 8093 * (n * 2 ^ (2 * n)) term.  Each is bounded by
   a multiple of m * 2 ^ m; the Karatsuba term contributes 4 * kar_aux_const_nat.
   Keep this exact shape: it is folded back via s_const1_def[symmetric]. *)
definition s_const1 where "s_const1 = 55897 + 4 * kar_aux_const_nat"
(* Additive constant term of the recurrence schoenhage_strassen_Fm_bound'.
   Keep this exact shape: it is folded back via s_const2_def[symmetric]. *)
definition s_const2 where "s_const2 = 8410 + 4 * karatsuba_lower_bound"
(* Simplified recurrence that dominates schoenhage_strassen_Fm_bound.
   The domination is proved in
   schoenhage_strassen_Fm_bound_le_schoenhage_strassen_Fm_bound'.
     m < 3:  5336
     m >= 3: s_const1 * m * 2^m + s_const2
             + bound' ((m + 2) div 2) * 2 ^ ((m + 1) div 2)
   Termination: for m >= 3 the recursive argument (m + 2) div 2 is below m. *)
function schoenhage_strassen_Fm_bound' :: "nat ⇒ nat" where
"m < 3 ⟹ schoenhage_strassen_Fm_bound' m = 5336"
| "m ≥ 3 ⟹ schoenhage_strassen_Fm_bound' m = s_const1 * (m * 2 ^ m) + s_const2 + schoenhage_strassen_Fm_bound' ((m + 2) div 2) * 2 ^ ((m + 1) div 2)"
by fastforce+
termination
by (relation "Wellfounded.measure (λm. m)"; simp)
(* The equations are taken out of the simpset; the proofs below instantiate
   schoenhage_strassen_Fm_bound'.simps explicitly. *)
declare schoenhage_strassen_Fm_bound'.simps[simp del]
(* The simplified recurrence schoenhage_strassen_Fm_bound' dominates
   schoenhage_strassen_Fm_bound pointwise.  Proof by strong induction on m:
   - for m < 3 both equal 5336;
   - for m >= 3 the Karatsuba term is bounded by kar_aux_const_nat * (n * 2 ^ (2 * n));
   - all non-recursive terms are bounded by multiples of m * 2 ^ m;
   - the recursive call on n < m is handled by the induction hypothesis. *)
lemma schoenhage_strassen_Fm_bound_le_schoenhage_strassen_Fm_bound':
shows "schoenhage_strassen_Fm_bound m ≤ schoenhage_strassen_Fm_bound' m"
proof (induction m rule: less_induct)
case (less m)
show ?case
proof (cases "m < 3")
case True
from True have "schoenhage_strassen_Fm_bound m = 5336" unfolding schoenhage_strassen_Fm_bound.simps[of m] by simp
also have "... = schoenhage_strassen_Fm_bound' m" using schoenhage_strassen_Fm_bound'.simps True by simp
finally show ?thesis by simp
next
case False
then interpret m_lemmas: m_lemmas m
by (unfold_locales; simp)
from False have "m ≥ 3" by simp
define n where "n = m_lemmas.n"
define oe_n where "oe_n = m_lemmas.oe_n"
have kar_arg_pos: "(3 * n + 5) * 2 ^ oe_n > 0" by simp
(* Unfold one step of the original bound and name its six summands. *)
have fm: "schoenhage_strassen_Fm_bound m = 23525 * 2 ^ m + 8093 * (n * 2 ^ (2 * n)) + 8410 +
time_karatsuba_mul_nat_bound ((3 * n + 5) * 2 ^ oe_n) +
4 * karatsuba_lower_bound +
schoenhage_strassen_Fm_bound n * 2 ^ (oe_n - 1)" (is "_ = ?t1 + ?t2 + ?t3 + ?t4 + ?t5 + ?t6")
unfolding schoenhage_strassen_Fm_bound.simps[of m] n_def oe_n_def using False m_lemmas.n_def m_lemmas.oe_n_def
by simp
(* Bound the Karatsuba summand ?t4.  First use karatsuba_const, then enlarge the
   argument to 16 * (n * 2 ^ n) and apply kar_aux_lem_le.  Floors and ceilings
   bring the real-valued bound back to natural numbers. *)
have "?t4 ≤ karatsuba_const * nat ⌊real ((3 * n + 5) * 2 ^ oe_n) powr log 2 3⌋"
by (intro time_kar_le_kar_const[OF kar_arg_pos])
also have "... ≤ karatsuba_const * nat ⌊real ((8 * n) * 2 ^ (n + 1)) powr log 2 3⌋"
apply (intro add_mono order.refl mult_le_mono nat_mono floor_mono powr_mono2 iffD1[OF real_mono] power_increasing)
using m_lemmas.oe_n_gt_0 m_lemmas.n_gt_0 m_lemmas.oe_n_le_n by (simp_all add: n_def oe_n_def)
also have "... = karatsuba_const * nat ⌊real (16 * (n * 2 ^ n)) powr log 2 3⌋"
by simp
also have "... = karatsuba_const * nat ⌊(16 powr log 2 3) * ((n * 2 ^ n) powr log 2 3)⌋"
unfolding real_multiplicative using powr_mult[of "real 16" "real n * real (2 ^ n)" "log 2 3"]
by simp
also have "... ≤ karatsuba_const * nat ⌊(16 powr log 2 3) * (kar_aux_const * real (2 ^ (2 * n)))⌋"
apply (intro mult_le_mono order.refl nat_mono floor_mono mult_mono kar_aux_lem_le)
subgoal using m_lemmas.n_gt_0 unfolding n_def .
subgoal by simp
subgoal by simp
done
also have "... ≤ karatsuba_const * nat ⌈(16 powr log 2 3) * (kar_aux_const * real (2 ^ (2 * n)))⌉"
by (intro mult_le_mono order.refl nat_mono floor_le_ceiling)
also have "... ≤ karatsuba_const * (nat (⌈16 powr log 2 3⌉ * ⌈kar_aux_const * real (2 ^ (2 * n))⌉))"
using kar_aux_const_gt_0 by (intro mult_le_mono order.refl nat_mono mult_ceiling_le; simp)
also have "... = karatsuba_const * (nat ⌈16 powr log 2 3⌉ * nat ⌈kar_aux_const * real (2 ^ (2 * n))⌉)"
apply (intro arg_cong2[where f = "(*)"] refl nat_mult_distrib)
using powr_ge_pzero[of 16 "log 2 3"] by linarith
also have "... ≤ karatsuba_const * (nat ⌈16 powr log 2 3⌉ * nat (⌈kar_aux_const⌉ * ⌈real (2 ^ (2 * n))⌉))"
apply (intro mult_le_mono order.refl nat_mono mult_ceiling_le)
using kar_aux_const_gt_0 by simp_all
also have "... = karatsuba_const * (nat ⌈16 powr log 2 3⌉ * (nat ⌈kar_aux_const⌉ * nat ⌈real (2 ^ (2 * n))⌉))"
apply (intro arg_cong2[where f = "(*)"] refl nat_mult_distrib)
using kar_aux_const_gt_0 by simp
also have "... = karatsuba_const * nat ⌈16 powr log 2 3⌉ * nat ⌈kar_aux_const⌉ * (2 ^ (2 * n))"
by simp
also have "... = kar_aux_const_nat * 2 ^ (2 * n)"
unfolding kar_aux_const_nat_def[symmetric] by (rule refl)
also have "... ≤ kar_aux_const_nat * (n * 2 ^ (2 * n))"
using m_lemmas.n_gt_0 n_def by simp
finally have t4_le: "?t4 ≤ ..." .
(* Substitute the bound for ?t4 and collect all non-recursive summands into
   multiples of m * 2 ^ m, giving s_const1 and s_const2. *)
have "schoenhage_strassen_Fm_bound m ≤ ?t1 + ?t2 + ?t3 + kar_aux_const_nat * (n * 2 ^ (2 * n)) + ?t5 + ?t6"
unfolding fm
by (intro add_mono order.refl t4_le)
also have "... = ?t1 + (8093 + kar_aux_const_nat) * (n * 2 ^ (2 * n)) + 8410 + 4 * karatsuba_lower_bound + schoenhage_strassen_Fm_bound n * 2 ^ (oe_n - 1)"
by (simp add: add_mult_distrib)
also have "... ≤ 23525 * (m * 2 ^ m) + (8093 + kar_aux_const_nat) * (m * (2 * 2 ^ (m + 1))) + 8410 + 4 * karatsuba_lower_bound + schoenhage_strassen_Fm_bound n * 2 ^ (oe_n - 1)"
apply (intro add_mono order.refl mult_le_mono)
subgoal using m_lemmas.m_gt_0 by simp
subgoal using m_lemmas.n_lt_m n_def by simp
subgoal using m_lemmas.two_pow_two_n_le n_def by simp
done
also have "... = (55897 + 4 * kar_aux_const_nat) * (m * 2 ^ m) + (8410 + 4 * karatsuba_lower_bound) + schoenhage_strassen_Fm_bound n * 2 ^ (oe_n - 1)"
by (simp add: add_mult_distrib)
(* Recursive call: the induction hypothesis applies because n < m. *)
also have "... ≤ (55897 + 4 * kar_aux_const_nat) * (m * 2 ^ m) + (8410 + 4 * karatsuba_lower_bound) + schoenhage_strassen_Fm_bound' n * 2 ^ (oe_n - 1)"
apply (intro add_mono order.refl mult_le_mono less.IH)
unfolding n_def using m_lemmas.n_lt_m .
(* Express n and oe_n - 1 in terms of m, splitting on the parity of m, to match
   the defining equation of schoenhage_strassen_Fm_bound'. *)
also have "... = (55897 + 4 * kar_aux_const_nat) * (m * 2 ^ m) + (8410 + 4 * karatsuba_lower_bound) + schoenhage_strassen_Fm_bound' ((m + 2) div 2) * 2 ^ ((m + 1) div 2)"
apply (intro_cong "[cong_tag_2 (+), cong_tag_2 (*), cong_tag_2 (^), cong_tag_1 schoenhage_strassen_Fm_bound']" more: refl)
subgoal unfolding n_def m_lemmas.n_def by (cases "odd m"; simp)
subgoal unfolding oe_n_def m_lemmas.oe_n_def m_lemmas.n_def by (cases "odd m"; simp)
done
also have "... = schoenhage_strassen_Fm_bound' m" using schoenhage_strassen_Fm_bound'.simps[of m] False unfolding s_const1_def[symmetric] s_const2_def[symmetric] by simp
finally show ?thesis .
qed
qed
(* Single constant that absorbs both s_const1 and s_const2 in
   schoenhage_strassen_Fm_bound'_oe_rec:
   s_const1 * (2 * n) * 2 ^ k + s_const2 * (n * 2 ^ k) = γ_0 * n * 2 ^ k. *)
definition γ_0 where "γ_0 = 2 * s_const1 + s_const2"
(* One unfolding step of schoenhage_strassen_Fm_bound', for n >= 3, at the even
   argument 2 * n - 2 and the odd argument 2 * n - 1.  In both cases the recursive
   call is on n.  The non-recursive part s_const1 * k * 2 ^ k + s_const2 is bounded
   by γ_0 * n * 2 ^ k.  The proof discharges the second (odd) claim first. *)
lemma schoenhage_strassen_Fm_bound'_oe_rec:
assumes "n ≥ 3"
shows "schoenhage_strassen_Fm_bound' (2 * n - 2) ≤ γ_0 * n * 2 ^ (2 * n - 2) + schoenhage_strassen_Fm_bound' n * 2 ^ (n - 1)"
and "schoenhage_strassen_Fm_bound' (2 * n - 1) ≤ γ_0 * n * 2 ^ (2 * n - 1) + schoenhage_strassen_Fm_bound' n * 2 ^ n"
proof -
from assms have r: "2 * n - 2 ≥ 4" by linarith
(* Odd case 2 * n - 1: the recursive argument is (2 * n + 1) div 2 = n and the
   exponent is (2 * n) div 2 = n. *)
from r have "schoenhage_strassen_Fm_bound' (2 * n - 1) = s_const1 * (2 * n - 1) * 2 ^ (2 * n - 1) + s_const2 + schoenhage_strassen_Fm_bound' n * 2 ^ n"
using schoenhage_strassen_Fm_bound'.simps[of "2 * n - 1"] by auto
also have "... ≤ s_const1 * (2 * n) * 2 ^ (2 * n - 1) + s_const2 * (n * 2 ^ (2 * n - 1)) + schoenhage_strassen_Fm_bound' n * 2 ^ n"
apply (intro add_mono order.refl mult_le_mono)
subgoal by simp
subgoal using assms by simp
done
also have "... = γ_0 * n * 2 ^ (2 * n - 1) + schoenhage_strassen_Fm_bound' n * 2 ^ n"
unfolding γ_0_def by (simp add: add_mult_distrib)
finally show "schoenhage_strassen_Fm_bound' (2 * n - 1) ≤ ..." .
(* Even case 2 * n - 2: the recursive argument is (2 * n - 2 + 2) div 2 = n and the
   exponent is (2 * n - 2 + 1) div 2 = n - 1. *)
from r have "schoenhage_strassen_Fm_bound' (2 * n - 2) = s_const1 * ((2 * n - 2) * 2 ^ (2 * n - 2)) + s_const2 +
schoenhage_strassen_Fm_bound' ((2 * n - 2 + 2) div 2) * 2 ^ ((2 * n - 2 + 1) div 2)"
using schoenhage_strassen_Fm_bound'.simps(2)[of "2 * n - 2"] by fastforce
also have "... = s_const1 * ((2 * n - 2) * 2 ^ (2 * n - 2)) + s_const2 +
schoenhage_strassen_Fm_bound' n * 2 ^ (n - 1)"
apply (intro_cong "[cong_tag_2 (+), cong_tag_2 (*), cong_tag_2 (^), cong_tag_1 schoenhage_strassen_Fm_bound']" more: refl)
subgoal using r by linarith
subgoal using r by linarith
done
also have "... ≤ s_const1 * ((2 * n) * 2 ^ (2 * n - 2)) + s_const2 * (n * 2 ^ (2 * n - 2)) + schoenhage_strassen_Fm_bound' n * 2 ^ (n - 1)"
apply (intro add_mono order.refl mult_le_mono)
subgoal by simp
subgoal using assms by simp
done
also have "... = γ_0 * n * 2 ^ (2 * n - 2) + schoenhage_strassen_Fm_bound' n * 2 ^ (n - 1)"
unfolding γ_0_def by (simp add: add_mult_distrib)
finally show "schoenhage_strassen_Fm_bound' (2 * n - 2) ≤ ..." .
qed
(* γ is a single constant that dominates both γ_0 (the per-level coefficient)
   and the base values schoenhage_strassen_Fm_bound' 0, 1, 2 and 3.  It is the
   constant in the closed-form bound schoenhage_strassen_Fm_bound'_le_aux1. *)
definition γ where "γ = Max {γ_0, schoenhage_strassen_Fm_bound' 0, schoenhage_strassen_Fm_bound' 1, schoenhage_strassen_Fm_bound' 2, schoenhage_strassen_Fm_bound' 3}"
(* Closed-form bound on schoenhage_strassen_Fm_bound', by strong induction on k:
   if m ≤ 2 ^ Suc k + 1, then the bound is at most γ * Suc k * 2 ^ (Suc k + m).
   Arguments m ≤ 3 follow directly from the definition of γ.  For m ≥ 4, one
   recursion step (schoenhage_strassen_Fm_bound'_oe_rec) reduces the problem to
   the recursive argument ml.n, to which the induction hypothesis applies. *)
lemma schoenhage_strassen_Fm_bound'_le_aux1:
assumes "m ≤ 2 ^ Suc k + 1"
shows "schoenhage_strassen_Fm_bound' m ≤ γ * Suc k * 2 ^ (Suc k + m)"
using assms proof (induction k arbitrary: m rule: less_induct)
case (less k)
consider "m ≤ 3" | "m ≥ 4" by linarith
then show ?case
proof cases
case 1
(* Base cases m ∈ {0, 1, 2, 3}: the value is bounded by γ because γ is the Max
   of a set containing it, and γ ≤ γ * Suc k * 2 ^ (Suc k + m) because both
   extra factors are ≥ 1. *)
then have "m ∈ {0, 1, 2, 3}" by auto
then have "schoenhage_strassen_Fm_bound' m ∈ {γ_0, schoenhage_strassen_Fm_bound' 0, schoenhage_strassen_Fm_bound' 1, schoenhage_strassen_Fm_bound' 2, schoenhage_strassen_Fm_bound' 3}" by auto
then have "schoenhage_strassen_Fm_bound' m ≤ γ" unfolding γ_def by (intro Max.coboundedI; simp)
also have "... = γ * 1 * 1" by simp
also have "... ≤ γ * Suc k * 2 ^ (Suc k + m)"
by (intro mult_le_mono order.refl; simp)
finally show ?thesis .
next
case 2
(* Recursive case m ≥ 4.  Here k = 0 is impossible, since it would force
   m ≤ 2 ^ 1 + 1 = 3; hence k = Suc k' with k' < k. *)
have "k > 0"
proof (rule ccontr)
assume "¬ k > 0"
with less.prems have "m ≤ 3" by simp
thus False using 2 by simp
qed
then obtain k' where "k = Suc k'" "k' < k"
using gr0_conv_Suc by auto
(* ih': the induction hypothesis at k', restated in terms of k = Suc k'. *)
have ih': "schoenhage_strassen_Fm_bound' m ≤ γ * k * 2 ^ (k + m)" if "m ≤ 2 ^ k + 1" for m
using less.IH[OF ‹k' < k›] unfolding ‹k = Suc k'›[symmetric] using that by simp
interpret ml: m_lemmas m
apply unfold_locales
using 2 by simp
(* n' is the exponent of the multiplier 2 ^ n' in one recursion step: ml.n for
   odd m and ml.n - 1 for even m, i.e. ml.oe_n - 1.  It satisfies
   ml.n + n' = m + 1. *)
define n' where "n' = (if odd m then ml.n else ml.n - 1)"
have "n' = ml.oe_n - 1"
unfolding n'_def ml.oe_n_def by simp
have "ml.n + n' = m + 1"
unfolding ml.m1 ‹n' = ml.oe_n - 1›
using Nat.add_diff_assoc[of 1 ml.oe_n ml.n]
using Nat.diff_add_assoc2[of 1 ml.n ml.oe_n]
using ml.oe_n_gt_0 ml.n_gt_0
by simp
(* ml.n ≥ 3 is the precondition of schoenhage_strassen_Fm_bound'_oe_rec, and
   ml.n ≤ 2 ^ k + 1 allows the induction hypothesis to be applied at ml.n. *)
have "ml.n ≥ 3" using 2 ml.mn by (cases "odd m"; simp)
have "ml.n ≤ 2 ^ k + 1"
using less.prems ml.mn by (cases "odd m"; simp)
note ih = ih'[OF this]
(* Chain of estimates: one recursion step, then the induction hypothesis on
   the recursive call (with γ_0 ≤ γ), then merging the powers of two via
   ml.n + n' = m + 1, and finally bounding ml.n by 2 ^ (k + 1). *)
have "schoenhage_strassen_Fm_bound' m ≤ γ_0 * ml.n * 2 ^ m + schoenhage_strassen_Fm_bound' ml.n * 2 ^ n'"
unfolding n'_def
using schoenhage_strassen_Fm_bound'_oe_rec[OF ‹ml.n ≥ 3›] ml.mn
by (cases "odd m"; algebra)
also have "... ≤ γ * ml.n * 2 ^ m + (γ * k * 2 ^ (k + ml.n)) * 2 ^ n'"
apply (intro add_mono mult_le_mono order.refl ih)
apply (unfold γ_def)
apply simp
done
also have "... = γ * ml.n * 2 ^ m + γ * k * 2 ^ (k + ml.n + n')"
by (simp add: power_add)
also have "... = γ * ml.n * 2 ^ m + γ * k * 2 ^ (k + m + 1)"
using ‹ml.n + n' = m + 1› by (simp add: add.assoc)
also have "... = γ * 2 ^ m * (ml.n + k * 2 ^ (k + 1))"
by (simp add: Nat.add_mult_distrib2 power_add)
also have "... ≤ γ * 2 ^ m * (2 ^ (k + 1) + k * 2 ^ (k + 1))"
apply (intro mono_intros)
apply (estimation estimate: ‹ml.n ≤ 2 ^ k + 1›)
apply simp
done
also have "... = γ * 2 ^ m * (k + 1) * 2 ^ (k + 1)"
by (simp add: Nat.add_mult_distrib2 Nat.add_mult_distrib)
also have "... = γ * (k + 1) * 2 ^ (k + 1 + m)"
by (simp add: power_add Nat.add_mult_distrib)
finally show ?thesis by simp
qed
qed
(* Restatement of schoenhage_strassen_Fm_bound'_le_aux1 for an arbitrary k ≥ 1
   instead of Suc k. *)
lemma schoenhage_strassen_Fm_bound'_le_aux2:
assumes "k ≥ 1"
assumes "m ≤ 2 ^ k + 1"
shows "schoenhage_strassen_Fm_bound' m ≤ γ * k * 2 ^ (k + m)"
proof -
from assms(1) obtain k' where "k = Suc k'"
by (metis Suc_le_D numeral_nat(7))
then show ?thesis using schoenhage_strassen_Fm_bound'_le_aux1[of m k'] assms(2) by argo
qed
subsection "Multiplication in $\\mathbb{N}$"
(* Time-instrumented multiplication of two natural numbers a and b; its value
   is schoenhage_strassen_mul a b (see val_schoenhage_strassen_mul_tm).
   The steps are:
   - m = max (bitsize (length a)) (bitsize (length b)) + 1;
   - both inputs are brought to length car_len = 2 ^ (m + 1) with fill_tm;
   - the product is computed by schoenhage_strassen_tm m;
   - the result is normalised with int_lsbf_fermat.reduce_tm m.
   NOTE(review): a and b are presumably least-significant-bit-first bit
   lists, as the int_lsbf naming suggests.
   The "=1" definition form contributes one extra time step; compare the
   trailing "+ 1" in the time decomposition in
   time_schoenhage_strassen_mul_tm_le. *)
definition schoenhage_strassen_mul_tm where
"schoenhage_strassen_mul_tm a b =1 do {
bits_a ← length_tm a ⤜ bitsize_tm;
bits_b ← length_tm b ⤜ bitsize_tm;
m' ← max_nat_tm bits_a bits_b;
m ← m' +⇩t 1;
m_plus_1 ← m +⇩t 1;
car_len ← 2 ^⇩t m_plus_1;
fill_a ← fill_tm car_len a;
fill_b ← fill_tm car_len b;
fm_result ← schoenhage_strassen_tm m fill_a fill_b;
int_lsbf_fermat.reduce_tm m fm_result
}"
(* The instrumented multiplication computes schoenhage_strassen_mul.  Beyond
   unfolding, the only obligation is that the filled inputs fill_a and fill_b
   lie in int_lsbf_fermat.fermat_non_unique_carrier m.  This is the
   precondition of val_schoenhage_strassen_tm. *)
lemma val_schoenhage_strassen_mul_tm[simp, val_simp]:
"val (schoenhage_strassen_mul_tm a b) = schoenhage_strassen_mul a b"
proof -
interpret schoenhage_strassen_mul_context a b .
have val_fm[val_simp]: "val (schoenhage_strassen_tm m fill_a fill_b) = schoenhage_strassen m fill_a fill_b"
apply (intro val_schoenhage_strassen_tm)
subgoal unfolding fill_a_def car_len_def
by (intro int_lsbf_fermat.fermat_non_unique_carrierI length_fill length_a')
subgoal unfolding fill_b_def car_len_def
by (intro int_lsbf_fermat.fermat_non_unique_carrierI length_fill length_b')
done
show ?thesis
unfolding schoenhage_strassen_mul_tm_def schoenhage_strassen_mul_def
unfolding val_simp Let_def int_lsbf_fermat.val_reduce_tm defs[symmetric]
by (rule refl)
qed
(* For a positive base, casting a natural-number power to the reals gives the
   corresponding real powr. *)
lemma real_power: "a > 0 ⟹ real ((a :: nat) ^ x) = real a powr real x"
using powr_realpow[of "real a" x] by simp
(* Explicit upper bound on the running time of schoenhage_strassen_mul_tm for
   inputs of length at most n (proved in time_schoenhage_strassen_mul_tm_le).
   It arises from the step-by-step estimate after replacing m by its upper
   bound bitsize n + 1. *)
definition schoenhage_strassen_bound where
"schoenhage_strassen_bound n = 146 * n + 218 + 4 * (bitsize n + 1) + 126 * 2 ^ (bitsize n + 2) +
γ * bitsize (bitsize n + 1) * 2 ^ (bitsize (bitsize n + 1) + (bitsize n + 1))"
(* Running-time bound for natural-number multiplication when both inputs have
   length at most n.  Proof outline:
   1. m ≤ bitsize n + 1, and the filled inputs lie in the Fermat carrier;
   2. split the running time exactly into the cost of each step;
   3. bound every step separately;
   4. collect the linear and constant parts into 146 * n + 218;
   5. bound schoenhage_strassen_Fm_bound' m using
      schoenhage_strassen_Fm_bound'_le_aux2 with k = bitsize m;
   6. replace m by bitsize n + 1 using monotonicity. *)
theorem time_schoenhage_strassen_mul_tm_le:
assumes "length a ≤ n" "length b ≤ n"
shows "time (schoenhage_strassen_mul_tm a b) ≤ schoenhage_strassen_bound n"
proof -
interpret schoenhage_strassen_mul_context a b .
(* Step 1: size bounds on m and the bit sizes, and carrier membership of the
   filled inputs (needed for the value and time lemmas of the core step). *)
have m_le: "m ≤ bitsize n + 1"
unfolding defs
by (intro add_mono order.refl max.boundedI bitsize_mono assms)
have m_gt_0: "m > 0" unfolding m_def by simp
have bits_a_le: "bits_a ≤ m - 1"
unfolding bits_a_def
by (intro iffD2[OF bitsize_length] length_a)
have bits_b_le: "bits_b ≤ m - 1"
unfolding bits_b_def
by (intro iffD2[OF bitsize_length] length_b)
have a_carrier: "fill_a ∈ int_lsbf_fermat.fermat_non_unique_carrier m"
unfolding fill_a_def car_len_def
by (intro int_lsbf_fermat.fermat_non_unique_carrierI length_fill length_a')
have b_carrier: "fill_b ∈ int_lsbf_fermat.fermat_non_unique_carrier m"
unfolding fill_b_def car_len_def
by (intro int_lsbf_fermat.fermat_non_unique_carrierI length_fill length_b')
have val_fm: "val (schoenhage_strassen_tm m fill_a fill_b) = schoenhage_strassen m fill_a fill_b"
by (intro val_schoenhage_strassen_tm a_carrier b_carrier)
(* Step 2: exact decomposition of the running time, one summand per step of
   schoenhage_strassen_mul_tm, plus the trailing 1 from the definition. *)
have "time (schoenhage_strassen_mul_tm a b) = time (length_tm a) + time (bitsize_tm (length a)) + time (length_tm b) +
time (bitsize_tm (length b)) +
time (max_nat_tm bits_a bits_b) +
time (m' +⇩t 1) +
time (m +⇩t 1) +
time (2 ^⇩t (m + 1)) +
time (fill_tm car_len a) +
time (fill_tm car_len b) +
time (schoenhage_strassen_tm m fill_a fill_b) +
time (int_lsbf_fermat.reduce_tm m (schoenhage_strassen m fill_a fill_b)) +
1"
unfolding schoenhage_strassen_mul_tm_def
unfolding tm_time_simps defs[symmetric] val_length_tm val_bitsize_tm val_simps
val_max_nat_tm Let_def val_plus_nat_tm val_power_nat_tm val_fill_tm val_fm add.assoc[symmetric]
by (rule refl)
(* Step 3: bound each summand separately, in the same order as above. *)
also have "... ≤ (n + 1) + (72 * n + 23) + (n + 1) +
(72 * n + 23) +
(2 * (m - 1) + 3) +
m +
(m + 1) +
12 * 2 ^ (m + 1) +
(3 * 2 ^ (m + 1) + 5) +
(3 * 2 ^ (m + 1) + 5) +
schoenhage_strassen_Fm_bound' m +
(155 + 108 * 2 ^ (m + 1)) + 1"
apply (intro add_mono order.refl)
subgoal using assms by simp
subgoal apply (estimation estimate: time_bitsize_tm_le) using assms by simp
subgoal using assms by simp
subgoal apply (estimation estimate: time_bitsize_tm_le) using assms by simp
subgoal apply (estimation estimate: time_max_nat_tm_le)
apply (estimation estimate: min.cobounded1)
apply (estimation estimate: bits_a_le)
by (rule order.refl)
subgoal by (simp add: m_def)
subgoal by simp
subgoal apply (estimation estimate: time_power_nat_tm_2_le)
unfolding defs[symmetric] by (rule order.refl)
subgoal apply (estimation estimate: time_fill_tm_le)
apply (estimation estimate: length_a')
unfolding defs[symmetric] by simp
subgoal apply (estimation estimate: time_fill_tm_le)
apply (estimation estimate: length_b')
unfolding defs[symmetric] by simp
subgoal
apply (estimation estimate: time_schoenhage_strassen_tm_le[OF a_carrier b_carrier])
apply (estimation estimate: schoenhage_strassen_Fm_bound_le_schoenhage_strassen_Fm_bound')
by (rule order.refl)
subgoal
apply (estimation estimate: int_lsbf_fermat.time_reduce_tm_le)
unfolding int_lsbf_fermat.fermat_carrier_length[OF conjunct2[OF schoenhage_strassen_correct'[OF a_carrier b_carrier]]]
by simp
done
(* Step 4: collect terms.  The constants sum to 218 and the multiples of
   2 ^ (m + 1) to 126; then 2 * (m - 1) + 2 * m is at most 4 * m. *)
also have "... = 146 * n + 218 +
2 * (m - 1) + 2 * m + 126 * 2 ^ (m + 1) + schoenhage_strassen_Fm_bound' m"
by simp
also have "... ≤ 146 * n + 218 +
4 * m + 126 * 2 ^ (m + 1) + schoenhage_strassen_Fm_bound' m"
by simp
(* Step 5: closed-form bound for the core multiplication with k = bitsize m.
   bitsize m ≥ 1 because m > 0, and m ≤ 2 ^ bitsize m + 1. *)
also have "... ≤ 146 * n + 218 +
4 * m + 126 * 2 ^ (m + 1) + (γ * bitsize m * 2 ^ (bitsize m + m))"
apply (intro add_mono order.refl schoenhage_strassen_Fm_bound'_le_aux2)
subgoal using bitsize_zero_iff[of m] iffD2[OF neq0_conv m_gt_0] by simp
subgoal using iffD1[OF bitsize_length order.refl[of "bitsize m"]]
by simp
done
(* Step 6: replace m by its upper bound bitsize n + 1, using monotonicity of
   bitsize and of the remaining terms. *)
also have "... ≤ 146 * n + 218 + 4 * (bitsize n + 1) + 126 * 2 ^ (bitsize n + 2) +
γ * bitsize (bitsize n + 1) * 2 ^ (bitsize (bitsize n + 1) + (bitsize n + 1))"
apply (estimation estimate: m_le)
by (intro bitsize_mono m_le order.refl)+ simp
finally show ?thesis unfolding schoenhage_strassen_bound_def[symmetric] .
qed
(* Truncated natural-number subtraction commutes with the cast to real when
   the subtraction does not truncate (a ≤ b). *)
lemma real_diff: "a ≤ b ⟹ real (b - a) = real b - real a"
by simp
(* For positive n, bitsize n ≤ log 2 n + 1.  Since n > 0 gives bitsize n ≥ 1,
   we have n ≥ 2 ^ (bitsize n - 1), hence bitsize n - 1 ≤ log 2 n. *)
lemma bitsize_le_log: "n > 0 ⟹ real (bitsize n) ≤ log 2 (real n) + 1"
proof -
assume "n > 0"
then have "bitsize n > 0" using bitsize_zero_iff[of n] by simp
then have "¬ (bitsize n ≤ bitsize n - 1)" by simp
then have "n ≥ 2 ^ (bitsize n - 1)" using bitsize_length[of n "bitsize n - 1"] by simp
then have "log 2 (real n) ≥ real (bitsize n - 1)"
using le_log2_of_power by simp
then show ?thesis by simp
qed
(* powr_mono specialised to base 2: the function x ↦ 2 powr x is monotone. *)
lemma powr_mono_base2: "a ≤ b ⟹ 2 powr (a :: real) ≤ 2 powr b"
by (intro powr_mono; simp)
(* log 2 is monotone on positive reals (from log_le_cancel_iff). *)
lemma log_mono_base2: "a > 0 ⟹ b > 0 ⟹ a ≤ b ⟹ log 2 a ≤ log 2 b"
using log_le_cancel_iff[of 2 a b] by simp
(* log 2 x is non-negative for x ≥ 1 (from zero_le_log_cancel_iff). *)
lemma log_nonneg_base2: "x ≥ 1 ⟹ log 2 x ≥ 0"
using zero_le_log_cancel_iff[of 2 x] by simp
(* powr_log_cancel specialised to base 2: 2 powr (log 2 x) = x for x > 0. *)
lemma powr_log_cancel_base2: "x > 0 ⟹ 2 powr (log 2 x) = x"
by (intro powr_log_cancel; simp)
(* The constant 1 is in O(log 2): for all x ≥ 2 we have log 2 x ≥ 1, so the
   big-O condition holds eventually with constant c = 1. *)
lemma const_bigo_log: "1 ∈ O(log 2)"
proof -
have 0: "log 2 x ≥ 1" if "x ≥ 2" for x
using log_mono_base2[of 2 x] that by simp
show ?thesis apply (intro landau_o.bigI[where c = 1])
subgoal by simp
subgoal unfolding eventually_at_top_linorder using 0 by fastforce
done
qed
(* The constant 1 is in O(λx. log 2 (log 2 x)).  For x ≥ 4 we have
   log 2 x ≥ log 2 4 = 2, hence log 2 (log 2 x) ≥ 1, which gives the big-O
   condition eventually with c = 1. *)
lemma const_bigo_log_log: "1 ∈ O(λx. log 2 (log 2 x))"
proof -
have "log 2 4 = 2"
by (metis log2_of_power_eq mult_2 numeral_Bit0 of_nat_numeral power2_eq_square)
then have 0: "log 2 x ≥ 2" if "x ≥ 4" for x
using log_mono_base2[of 4 x] that by simp
have 1: "log 2 (log 2 x) ≥ 1" if "x ≥ 4" for x
using log_mono_base2[of 2 "log 2 x"] that 0[OF that] by simp
show ?thesis apply (intro landau_o.bigI[where c = 1])
subgoal by simp
subgoal unfolding eventually_at_top_linorder using 1 by fastforce
done
qed
(* Asymptotic form of the running-time bound: O(n * log n * log log n).
   Proof: for n ≥ 2, schoenhage_strassen_bound n ≤ explicit_bound n.  This uses
   bitsize n ≤ log 2 n + 1, the bound
   bitsize (bitsize n + 1) ≤ log 2 (log 2 n) + 1 + log 2 3,
   and 2 powr (log 2 x) = x.  Each of the five summands of explicit_bound is
   then shown separately to lie in O(λn. n * log 2 n * log 2 (log 2 n)). *)
theorem schoenhage_strassen_bound_bigo: "schoenhage_strassen_bound ∈ O(λn. n * log 2 n * log 2 (log 2 n))"
proof -
(* Coefficients of explicit_bound:
   1154 = 146 + 126 * 8, from 126 * 2 powr (log 2 x + 3) = 1008 * x;
   226 = 218 + 8, from 4 * ((log 2 x + 1) + 1);
   24 = 4 * 2 * 3, from the factor 2 powr (1 + log 2 3 + 2) in the γ term. *)
define explicit_bound where "explicit_bound = (λx. 1154 * x + 226 + 4 * log 2 x + (real γ * 24) * x * log 2 x * log 2 (log 2 x) + (real γ * 24 * (1 + log 2 3)) * x * log 2 x)"
have le: "real (schoenhage_strassen_bound n) ≤ explicit_bound (real n)" if "n ≥ 2" for n
proof -
have "(2::nat) > 0" by simp
from that have "n ≥ 1" "n > 0" by simp_all
have 0: "bitsize n + 1 > 0" by simp
define x where "x = real n"
then have "x ≥ 2" "x ≥ 1" "x > 0" using ‹n ≥ 2› ‹n ≥ 1› ‹n > 0› by simp_all
have log_ge: "log 2 x ≥ 1" using log_mono_base2[of 2 x] using ‹x ≥ 2› by simp
then have log_log_ge: "log 2 (log 2 x) ≥ 0" and "log 2 x > 0" by simp_all
(* Logarithmic bounds for the two bitsize expressions in the bound. *)
have log_n: "real (bitsize n) ≤ log 2 x + 1"
unfolding x_def by (rule bitsize_le_log[OF ‹n > 0›])
have log_log_n: "real (bitsize (bitsize n + 1)) ≤ log 2 (log 2 x) + (1 + log 2 3)"
proof -
have "real (bitsize (bitsize n + 1)) ≤ log 2 (real (bitsize n + 1)) + 1"
apply (intro bitsize_le_log) by simp
also have "... = log 2 (real (bitsize n) + 1) + 1"
unfolding real_linear by simp
also have "... ≤ log 2 (log 2 (real n) + 1 + 1) + 1"
apply (intro add_mono order.refl log_mono_base2 bitsize_le_log ‹n > 0›)
subgoal by simp
subgoal using log_nonneg_base2[of "real n"] ‹n ≥ 1› by linarith
done
also have "... = log 2 (log 2 x + 2 * 1) + 1" unfolding x_def by argo
also have "... ≤ log 2 (log 2 x + 2 * log 2 x) + 1"
apply (intro add_mono order.refl log_mono_base2 mult_mono)
using log_ge by simp_all
also have "... = log 2 (3 * log 2 x) + 1" by simp
also have "... = (log 2 3 + log 2 (log 2 x)) + 1"
apply (intro arg_cong2[where f = "(+)"] refl log_mult)
using log_ge by simp_all
also have "... = log 2 (log 2 (log 2 x)) + (1 + log 2 3)" by simp
finally show ?thesis .
qed
have 1: "0 ≤ log 2 (log 2 x) + (1 + log 2 3)"
using log_log_ge by simp
(* Cast the bound to the reals, substitute the logarithmic bounds, and
   simplify the powers of two with powr_log_cancel_base2. *)
have "real (schoenhage_strassen_bound n) = 146 * x + 218 + 4 * (real (bitsize n) + 1) + 126 * 2 powr (real (bitsize n) + 2) +
real γ * real (bitsize (bitsize n + 1)) * 2 powr (real (bitsize (bitsize n + 1)) + (real (bitsize n) + 1))"
unfolding schoenhage_strassen_bound_def real_linear real_multiplicative x_def real_power[OF ‹2 > 0›]
by (intro_cong "[cong_tag_2 (+), cong_tag_2 (*), cong_tag_2 (powr)]" more: refl; simp)
also have "... ≤ 146 * x + 218 + 4 * ((log 2 x + 1) + 1) + 126 * 2 powr ((log 2 x + 1) + 2) +
real γ * (log 2 (log 2 x) + (1 + log 2 3)) * 2 powr ((log 2 (log 2 x) + (1 + log 2 3)) + ((log 2 x + 1) + 1))"
apply (intro add_mono mult_mono order.refl powr_mono_base2 log_n log_log_n mult_nonneg_nonneg 1)
unfolding x_def by simp_all
also have "... = 1154 * x + (226 + 4 * log 2 x) + real γ * (log 2 (log 2 x) + (1 + log 2 3)) * (24 * (log 2 x * x))"
unfolding powr_add powr_log_cancel_base2[OF ‹x > 0›] powr_log_cancel_base2[OF ‹log 2 x > 0›] by simp
also have "... = 1154 * x + 226 + 4 * log 2 x + (real γ * 24) * x * log 2 x * log 2 (log 2 x) + (real γ * 24 * (1 + log 2 3)) * x * log 2 x"
unfolding distrib_left distrib_right add.assoc[symmetric] mult.assoc[symmetric] by simp
also have "... = explicit_bound x"
unfolding explicit_bound_def by (rule refl)
finally show ?thesis unfolding x_def .
qed
(* le holds for every n ≥ 2, i.e. eventually, so with c = 1 the bound is in
   O(explicit_bound). *)
have le_bigo: "schoenhage_strassen_bound ∈ O(explicit_bound)"
apply (intro landau_o.bigI[where c = 1])
subgoal by simp
subgoal unfolding eventually_at_top_linorder using le by fastforce
done
(* Split explicit_bound into its five summands, in order:
   1154 * x, the constant 226, 4 * log 2 x, and the two γ terms. *)
have bigo: "explicit_bound ∈ O(λn. n * log 2 n * log 2 (log 2 n))"
unfolding explicit_bound_def
apply (intro sum_in_bigo(1))
subgoal
proof -
have "(*) 1154 ∈ O(λx. x)" by simp
moreover have "1 ∈ O(λx. log 2 x)" by (rule const_bigo_log)
moreover have "1 ∈ O(λx. log 2 (log 2 x))" by (rule const_bigo_log_log)
ultimately show ?thesis using landau_o.big_mult[of 1 _ _ 1] by auto
qed
(* NOTE(review): this summand of explicit_bound is the constant 226, but the
   helper facts below use 225.  The proof presumably still goes through
   because simp normalises nonzero constant functions inside Landau symbols.
   Using 226 here would be clearer; confirm before changing. *)
subgoal
proof -
have a: "(λx. 225) ∈ O(λx. x :: real)" by simp
have b: "1 ∈ O(λx. log 2 x)" by (rule const_bigo_log)
have c: "(λx. 225) ∈ O(λx. x * log 2 x)"
using landau_o.big_mult[OF a b] by simp
have d: "1 ∈ O(λx. log 2 (log 2 x))" by (rule const_bigo_log_log)
show ?thesis using landau_o.big_mult[OF c d] by simp
qed
subgoal
proof -
have a: "(λx. 4) ∈ O(λx. x :: real)" by simp
have b: "(λx. 4 * log 2 x) ∈ O(λx. x * log 2 x)"
by (rule landau_o.big.mult_right[OF a])
have c: "1 ∈ O(λx. log 2 (log 2 x))" by (rule const_bigo_log_log)
show ?thesis using landau_o.big_mult[OF b c] by simp
qed
subgoal
proof -
have a: "(λx. real γ * 24 * x) ∈ O(λx. x :: real)" by simp
show ?thesis by (intro landau_o.big.mult_right a)
qed
subgoal
proof -
have a: "(λx. real γ * 24 * (1 + log 2 3) * x) ∈ O(λx. x :: real)" by simp
have b: "(λx. real γ * 24 * (1 + log 2 3) * x * log 2 x) ∈ O(λx. x * log 2 x)"
by (intro landau_o.big.mult_right a)
show ?thesis using landau_o.big_mult[OF b const_bigo_log_log] by simp
qed
done
(* Combine both results by transitivity of O. *)
show ?thesis using bigo landau_o.big_trans[OF le_bigo] by blast
qed
end