Theory Z_mod_Fermat

subsection "Representing $\\mathbb{Z}_{F_n}$"

theory Z_mod_Fermat
  imports
    Z_mod_power_of_2
    "../NTT_Rings/FNTT_Rings"
    "../Preliminaries/Schoenhage_Strassen_Preliminaries"
    "Karatsuba.Estimation_Method"
begin

lemma to_nat_replicate_True2:
  assumes "Nat_LSBF.to_nat xs = 2 ^ (length xs) - 1"
  shows "xs = replicate (length xs) True"
proof (intro iffD2[OF list_is_replicate_iff], rule ccontr)
  assume "¬ (i{0..<length xs}. xs ! i = True)"
  then obtain i where "i < length xs" "xs ! i = False" by auto
  then obtain xs1 xs2 where "xs = xs1 @ False # xs2"
    by (metis(full_types) id_take_nth_drop)
  then have "Nat_LSBF.to_nat xs < Nat_LSBF.to_nat (xs1 @ True # xs2)"
    using change_bit_ineq[of xs1 xs1 xs2] by argo
  also have "...  2 ^ (length (xs1 @ True # xs2)) - 1"
    by (intro to_nat_length_upper_bound)
  also have "... = 2 ^ (length xs) - 1"
    using xs = xs1 @ False # xs2 by simp
  finally show False using assms by simp
qed

lemma residue_ring_pow: "n > 1  a [^]residue_ring nb = (a ^ b) mod n"
  by (induction b) (simp_all add: residue_ring_def mod_mult_right_eq mult.commute)

lemma (in residues) pow_nat_eq:
"a [^]R(n :: nat) = a ^ n mod m"
  using R_m_def m_gt_one residue_ring_pow by blast

locale int_lsbf_fermat =
  fixes k :: nat
begin

abbreviation n where "n  (2::nat) ^ (2 ^ k) + 1"

lemma n_positive[simp]: "n > 0" by simp
lemma n_gt_1[simp]: "n > 1" by simp
lemma n_gt_2[simp]: "n > 2"
  by (metis add_less_mono1 nat_1_add_1 one_less_numeral_iff one_less_power pos2 semiring_norm(76) zero_less_power)  

definition Fn where "Fn  residue_ring (int n)"

sublocale residues n Fn
  apply unfold_locales
  subgoal by simp
  subgoal by (rule Fn_def)
  done

definition fermat_non_unique_carrier where
"fermat_non_unique_carrier  {xs :: nat_lsbf. length xs = 2 ^ (k + 1)}"

lemma fermat_non_unique_carrierI[intro]:
"length xs = 2 ^ (k + 1)  xs  fermat_non_unique_carrier"
  unfolding fermat_non_unique_carrier_def by simp

lemma fermat_non_unique_carrierE[elim]:
  "xs  fermat_non_unique_carrier  (length xs = 2 ^ (k + 1)  P)  P"
  unfolding fermat_non_unique_carrier_def by simp

lemma two_pow_half_carrier_length[simp]: "(int 2 ^ (2 ^ k)) mod n = -1 mod n"
  apply simp
  using zmod_minus1[of "int n"] n_positive
  by (metis add_diff_cancel_left' diff_eq_eq of_nat_0_less_iff of_nat_numeral pos2 zero_less_power zless_add1_eq zmod_minus1)

lemma two_pow_half_carrier_length_neq_1: "2 ^ (2 ^ k) mod n  1"
  by simp

lemma two_pow_carrier_length[simp]: "(2::nat) ^ (2 ^ (k + 1)) mod n = 1"
proof -
  have "int 2 ^ (2 ^ (k + 1)) mod n = 1"
  proof -
    have "int 2 ^ (2 ^ (k + 1)) mod n = ((int 2) ^ (2 * 2 ^ k)) mod n"
      by simp
    also have "... = ((int 2) ^ (2 ^ k)) ^ 2 mod n"
      using power_mult[of "int 2" "2 ^ k" 2]
      by (simp add: mult.commute)
    also have "... = (int 2 ^ (2 ^ k) * int 2 ^ (2 ^ k)) mod n"
      by (simp add: power2_eq_square)
    also have "... = (((int 2 ^ (2 ^ k)) mod n) * ((int 2 ^ (2 ^ k)) mod n)) mod n"
      by simp
    also have "(int 2 ^ (2 ^ k)) mod n = -1 mod n"
      using two_pow_half_carrier_length .
    finally have "int 2 ^ (2 ^ (k + 1)) mod n = int 1 mod n"
      by (simp add: mod_simps)
    thus ?thesis by simp
  qed
  then show ?thesis
    by (metis int_ops(2) of_nat_eq_iff of_nat_power zmod_int)
qed

lemma two_pow_half_carrier_length_residue_ring[simp]:
"(2::int) [^]Fn(2::nat) ^ k = Fn𝟭Fn⇙"
proof -
  have "(2::int) [^]Fn(2::nat) ^ k = (2::int) ^ ((2::nat) ^ k) mod n"
    by (intro pow_nat_eq)
  also have "... = - 1 mod n" using two_pow_half_carrier_length by simp
  also have "... = Fn𝟭Fn⇙"
    using res_neg_eq res_one_eq by algebra
  finally show ?thesis .
qed

lemma two_pow_carrier_length_residue_ring[simp]:
  "(2::int) [^]Fn(2::nat) ^ (k + 1) = 𝟭Fn⇙"
proof -
  have "(2::int) [^]Fn(2::nat) ^ (k + 1) = (2::int) ^ ((2::nat) ^ (k + 1)) mod n"
    by (intro pow_nat_eq)
  also have "... = 1" using two_pow_carrier_length zmod_int
    by (metis int_exp_hom int_ops(2) int_ops(3))
  also have "... = 𝟭Fn⇙" by (simp only: res_one_eq)
  finally show ?thesis .
qed

corollary two_is_unit: "2  Units Fn"
  apply (intro pow_one_imp_unit[of "2 ^ (k + 1)"])
  subgoal by simp
  subgoal using res_carrier_eq by (simp add: self_le_power)
  subgoal using two_pow_carrier_length_residue_ring .
  done

corollary two_in_carrier: "2  carrier Fn"
  using Units_closed[OF two_is_unit] .

lemma nat_mod_eqE: "(a::nat) mod m = b mod m  i j. a + i * m = b + j * m"
proof -
  assume "a mod m = b mod m"
  then have "int a mod int m = int b mod int m" using zmod_int by metis
  then obtain l where "int a = int b + l * int m" by (metis mod_eqE mult.commute)
  define i j where "i = (if l  0 then 0 else nat (- l))" "j = (if l  0 then nat l else 0)"
  then have "int a + int i * int m = int b + int j * int m"
    using int a = int b + l * int m by simp
  then have "a + i * m = b + j * m" by (metis int_ops(7) nat_int_add)
  then show ?thesis by blast
qed

corollary pow_mod_carrier_length:
  assumes "(a::nat) mod 2 ^ (k + 1) = b mod 2 ^ (k + 1)"
  shows "2 [^]Fna = 2 [^]Fnb"
proof -
  from assms obtain i j where 0: "a + i * 2 ^ (k + 1) = b + j * 2 ^ (k + 1)"
    using nat_mod_eqE by blast
  have "2 [^]Fna = 2 [^]Fna Fn(2 [^]Fn((2::nat) ^ (k + 1))) [^]Fni"
    using two_pow_carrier_length_residue_ring two_in_carrier nat_pow_closed
    using nat_pow_one by algebra
  also have "... = 2 [^]Fn(a + i * 2 ^ (k + 1))"
    using nat_pow_pow nat_pow_mult two_in_carrier
    using mult.commute by metis
  also have "... = 2 [^]Fn(b + j * 2 ^ (k + 1))"
    using 0 by argo
  also have "... = 2 [^]Fnb Fn(2 [^]Fn((2::nat) ^ (k + 1))) [^]Fnj"
    using nat_pow_pow nat_pow_mult two_in_carrier
    using mult.commute by metis
  also have "... = 2 [^]Fnb"
    using two_pow_carrier_length_residue_ring two_in_carrier nat_pow_closed
    using nat_pow_one by algebra
  finally show ?thesis .
qed
lemma two_powers_trivial:
  assumes "s  2 ^ k"
  shows "2 [^]Fns = 2 ^ s"
proof -
  from assms have "2 ^ s  int n - 1" by simp
  then have "2 ^ s < int n" using n_positive by linarith
  then have "2 ^ s = 2 ^ s mod int n" by simp
  also have "... = 2 [^]Fns" using pow_nat_eq by simp
  finally show ?thesis by argo
qed

lemma two_powers_Units:
  assumes "s  2 ^ k"
  shows "2 ^ s  Units Fn"
  unfolding two_powers_trivial[OF assms, symmetric]
  by (intro Units_pow_closed two_is_unit)
corollary two_powers_in_carrier:
  assumes "s  2 ^ k"
  shows "2 ^ s  carrier Fn"
  using assms two_powers_Units Units_closed by simp

lemma two_powers_half_carrier_length_residue_ring[simp]:
  assumes "i + s = k"
  shows "(2 ^ 2 ^ i) [^]Fn(2::nat) ^ s = Fn𝟭Fn⇙"
proof -
  from assms have "i  k" by simp
  then have "(2 ^ 2 ^ i) [^]Fn(2::nat) ^ s =
    (2 [^]Fn((2::nat) ^ i)) [^]Fn(2::nat) ^ s"
    using two_powers_trivial[of "2 ^ i", symmetric] by simp
  also have "... = 2 [^]Fn((2::nat) ^ (i + s))" 
    using monoid.nat_pow_pow[OF _ two_in_carrier] cring
    using power_add[symmetric, of "2::nat" i s]
    using monoid_axioms by auto
  also have "... = Fn𝟭Fn⇙"
    using i + s = k two_pow_half_carrier_length_residue_ring by argo
  finally show ?thesis .
qed

interpretation z_mod_fermat_unit_group: group "units_of Fn"
  by (rule units_group)


lemma inv_of_2[simp]:
  "invFn2 = 2 [^]Fn((2::nat) ^ (k + 1) - 1)"
proof -
  have "𝟭Fn= 2 Fn2 [^]Fn((2::nat) ^ (k + 1) - 1)"
    by (metis two_is_unit two_pow_carrier_length_residue_ring Units_closed Units_r_inv inv_root_of_unity root_of_unityI zero_less_numeral zero_less_power)
  moreover have "𝟭Fn= 2 [^]Fn((2::nat) ^ (k + 1) - 1) Fn2"
   by (metis two_is_unit two_pow_carrier_length_residue_ring Units_closed Units_l_inv inv_root_of_unity root_of_unityI zero_less_numeral zero_less_power)
  ultimately show "invFn2 = 2 [^]Fn((2::nat) ^ (k + 1) - 1)"
    using less_2_cases_iff two_pow_carrier_length_residue_ring two_in_carrier inv_root_of_unity root_of_unityI by presburger
qed

lemma inv_of_2_powers:
  assumes "s  2 ^ k"
  shows "invFn(2 ^ s) = 2 [^]Fn(2 ^ (k + 1) - s)"
proof (cases "s = 0")
  case True
  then show ?thesis
    using inv_one res_one_eq
    using two_pow_carrier_length_residue_ring
    by simp
next
  case False
  then have "s > 0" by simp
  interpret m : multiplicative_subgroup Fn "Units Fn" "units_of Fn"
    apply unfold_locales
    subgoal by simp
    subgoal by (simp add: units_of_def)
    done
  have "invFn(2 ^ s) = invFn(2 [^]Fns)"
    using two_powers_trivial[OF s  2 ^ k] by argo
  also have "... = (invFn2) [^]Fns"
    using two_is_unit group.nat_pow_inv[OF m.M_group] m.inv_eq m.M_group m.carrier_M
    using m.nat_pow_eq Units_pow_closed by algebra
  also have "... = (2 [^]Fn((2::nat) ^ (k + 1) - 1)) [^]Fns"
    using inv_of_2
    by argo
  also have "... = 2 [^]Fn(((2::nat) ^ (k + 1) - 1) * s)"
    using two_in_carrier nat_pow_pow by presburger
  also have "((2::nat) ^ (k + 1) - 1) * s = (2::nat) ^ (k + 1) * s - s"
    using diff_mult_distrib by simp
  also have "... = 2 ^ (k + 1) * (s - 1) + 2 ^ (k + 1) - s"
    using s > 0 by (metis add.commute mult.commute mult_eq_if zero_less_iff_neq_zero)
  also have "... = 2 ^ (k + 1) * (s - 1) + (2 ^ (k + 1) - s)"
    apply (intro diff_add_assoc) using assms by simp
  also have "2 [^]Fn(2 ^ (k + 1) * (s - 1) + (2 ^ (k + 1) - s)) =
    2 [^]Fn(2 ^ (k + 1) - s)"
    apply (intro pow_mod_carrier_length) by simp
  finally show ?thesis .
qed

lemma inv_pow_mod_carrier_length:
  assumes "(a::nat) mod 2 ^ (k + 1) = b mod 2 ^ (k + 1)"
  shows "(invFn2) [^]Fna = (invFn2) [^]Fnb"
  unfolding inv_of_2 nat_pow_pow[OF two_in_carrier]
  apply (intro pow_mod_carrier_length)
  using assms mod_mult_cong by blast

lemma
  assumes "m > 0"
  shows "i j. (a::nat) = j + i * m  j < m"
  using mod_div_mult_eq[of a m, symmetric] pos_mod_bound[of m a] assms mod_less_divisor 
  by blast

corollary two_powers: "(2::nat) ^ a mod n = (2::nat) ^ (a mod (2 ^ (k + 1))) mod n"
proof -
  define i where "i = a mod 2 ^ (k + 1)"
  define j where "j = a div 2 ^ (k + 1)"
  have "a = i + j * 2 ^ (k + 1)" using mod_div_mult_eq[of a "2 ^ (k + 1)"] i_def j_def
    by simp
  hence "(2::nat) ^ a mod n = 2 ^ i * (2 ^ (2 ^ (k + 1))) ^ j mod n"
    using power_add[of "2::nat" i "j * 2 ^ (k + 1)"]
    using power_mult[of "2::nat" "2 ^ (k + 1)" j]
    using mult.commute[of j "2 ^ (k + 1)"]
    by argo
  also have "... = 2 ^ i * ((2 ^ (2 ^ (k + 1))) ^ j mod n) mod n"
    using mod_mult_right_eq by metis
  also have "... = 2 ^ i * ((2 ^ (2 ^ (k + 1)) mod n) ^ j mod n) mod n"
    using power_mod by metis
  also have "... = 2 ^ i * ((1::nat) ^ j mod n) mod n"
    using two_pow_carrier_length by simp
  also have "... = 2 ^ i mod n" by simp
  finally show ?thesis using i_def by simp
qed

lemma fermat_carrier_length[simp]: "xs  fermat_non_unique_carrier  length xs = 2 ^ (k + 1)"
  unfolding fermat_non_unique_carrier_def by simp

fun to_residue_ring :: "nat_lsbf  int" where
"to_residue_ring xs = int (Nat_LSBF.to_nat xs) mod int n"
fun from_residue_ring :: "int  nat_lsbf" where
"from_residue_ring x = fill (2 ^ (k + 1)) (Nat_LSBF.from_nat (nat x))"

lemma to_residue_ring_in_carrier[simp]: "to_residue_ring xs  carrier Fn"
  using zmod_int[of _ n, symmetric]
  by (simp add: res_carrier_eq)

lemma to_residue_ring_eq_to_nat: "Nat_LSBF.to_nat xs < n  to_residue_ring xs = int (Nat_LSBF.to_nat xs)"
  using zmod_int
  by (metis to_residue_ring.simps mod_less)

(* calculates a * 2 ^ k, where a is in the fermat carrier, k a natural number. *)
definition multiply_with_power_of_2 :: "nat_lsbf  nat  nat_lsbf" where
"multiply_with_power_of_2 xs m = rotate_right m xs"

definition divide_by_power_of_2 :: "nat_lsbf  nat  nat_lsbf" where
"divide_by_power_of_2 xs m = rotate_left m xs"

lemma length_multiply_with_power_of_2[simp]: "length (multiply_with_power_of_2 xs m) = length xs"
  unfolding multiply_with_power_of_2_def by simp

lemma length_divide_by_power_of_2[simp]: "length (divide_by_power_of_2 xs m) = length xs"
  unfolding divide_by_power_of_2_def by simp

lemma (in euclidean_semiring_cancel) sum_list_mod: "(i  xs. (f i mod m)) mod m = (i  xs. f i) mod m"
proof (induction xs)
  case Nil
  then show ?case by simp
next
  case (Cons a xs)
  have "(i  (a # xs). f i) mod m = (f a + (i  xs. f i)) mod m"
    by simp
  also have "... = (f a mod m + (i  xs. f i) mod m) mod m"
    using mod_add_eq[symmetric, of "f a"] by simp
  also have "... = (f a mod m + (i  xs. f i mod m) mod m) mod m"
    using Cons.IH by argo
  also have "... = (f a mod m + (i  xs. f i mod m)) mod m"
    using mod_add_right_eq by blast
  also have "... = (i  (a # xs). f i mod m) mod m"
    by simp
  finally show ?case by argo
qed

lemma (in euclidean_semiring_cancel) sum_list_mod':
  assumes "i. i  set xs  f i mod m = g i mod m"
  shows "(i  xs. f i) mod m = (i  xs. g i) mod m"
proof -
  have "(i  xs. f i) mod m = (i  xs. f i mod m) mod m"
    by (intro sum_list_mod[symmetric])
  also have "... = (i  xs. g i mod m) mod m"
    apply (intro_cong "[cong_tag_1 (λi. i mod m)]")
    apply (intro_cong "[cong_tag_1 sum_list]" more: map_cong refl)
    using assms by assumption
  also have "... = (i  xs. g i) mod m"
    by (intro sum_list_mod)
  finally show ?thesis .
qed

lemma multiply_with_power_of_2_correct': "xs  fermat_non_unique_carrier  Nat_LSBF.to_nat (multiply_with_power_of_2 xs m) mod n = Nat_LSBF.to_nat xs * 2 ^ m mod n  multiply_with_power_of_2 xs m  fermat_non_unique_carrier"
proof (intro conjI)
  assume "xs  fermat_non_unique_carrier"
  then have length_xs: "length xs = 2 ^ (k + 1)" by simp
  then have "length xs > 0" by simp

  let ?m = "length xs - m mod length xs"

  define ys zs where "ys = take ?m xs" "zs = drop ?m xs"
  then have "xs = ys @ zs"
    and length_ys: "length ys = ?m"
    and length_zs: "length zs = m mod length xs"
    using length xs = 2 ^ (k + 1) by simp_all
  
  have 1: "Nat_LSBF.to_nat xs = Nat_LSBF.to_nat ys + 2 ^ ?m * Nat_LSBF.to_nat zs" (is "_ = ?y + _ * ?z")
    apply (unfold xs = ys @ zs to_nat_app)
    apply (unfold xs = ys @ zs[symmetric] length_ys)
    apply (rule refl)
    done

  have 2: "multiply_with_power_of_2 xs m = zs @ ys"
  proof -
    have "multiply_with_power_of_2 xs m = rotate_right (m mod length xs) xs"
      unfolding multiply_with_power_of_2_def
      by (rule rotate_right_conv_mod)
    also have "... = rotate_right (length zs) (ys @ zs)"
      using xs = ys @ zs length_zs by simp
    also have "... = zs @ ys"
      by (rule rotate_right_append)
    finally show ?thesis .
  qed
  then have 3: "Nat_LSBF.to_nat (multiply_with_power_of_2 xs m)
    = ?z + 2 ^ (m mod length xs) * ?y"
    by (simp add: to_nat_app length_zs)

  from 1 have "Nat_LSBF.to_nat xs * 2 ^ m mod n = (?y + 2 ^ ?m * ?z) * 2 ^ m mod n"
    by argo
  also have "... = (?y + 2 ^ ?m * ?z) * (2 ^ m mod n) mod n"
    by (simp add: mod_simps)
  also have "... = (?y + 2 ^ ?m * ?z) * (2 ^ (m mod length xs) mod n) mod n"
    using length_xs two_powers by algebra
  also have "... = (?y + 2 ^ ?m * ?z) * 2 ^ (m mod length xs) mod n"
    by (simp add: mod_simps)
  also have "... = (?y * 2 ^ (m mod length xs) + 2 ^ (?m + (m mod length xs)) * ?z) mod n"
    by (simp add: algebra_simps power_add)
  also have "... = (?y * 2 ^ (m mod length xs) + 2 ^ length xs * ?z) mod n"
    by (simp add: length_xs)
  also have "... = (?y * 2 ^ (m mod length xs) + (2 ^ length xs mod n) * ?z mod n) mod n"
    by (simp add: mod_simps)
  also have "... = (?y * 2 ^ (m mod length xs) + 1 * ?z mod n) mod n"
    by (simp only: length_xs two_pow_carrier_length)
  also have "... = (?z + 2 ^ (m mod length xs) * ?y) mod n"
    by (simp add: mod_simps algebra_simps)
  also have "... = Nat_LSBF.to_nat (multiply_with_power_of_2 xs m) mod n"
    using 3 by argo
  finally show "Nat_LSBF.to_nat (multiply_with_power_of_2 xs m) mod n = Nat_LSBF.to_nat xs * 2 ^ m mod n"
    by argo

  have "length (multiply_with_power_of_2 xs m) = length xs"
    using 2 xs = ys @ zs by simp
  then show "multiply_with_power_of_2 xs m  fermat_non_unique_carrier"
    apply (intro fermat_non_unique_carrierI)
    using length_xs by argo
qed

corollary multiply_with_power_of_2_closed:
  assumes "xs  fermat_non_unique_carrier"
  shows "multiply_with_power_of_2 xs m  fermat_non_unique_carrier"
  by (intro conjunct2[OF multiply_with_power_of_2_correct'] assms)

corollary multiply_with_power_of_2_correct:
  assumes "xs  fermat_non_unique_carrier"
  shows "to_residue_ring (multiply_with_power_of_2 xs m) = to_residue_ring xs Fn2 [^]Fnm"
proof -
  have "to_residue_ring (multiply_with_power_of_2 xs m)
       = int (Nat_LSBF.to_nat (multiply_with_power_of_2 xs m) mod n)"
    using zmod_int by simp
  also have "... = int (Nat_LSBF.to_nat xs * 2 ^ m mod n)"
    using multiply_with_power_of_2_correct'[OF assms] by simp
  also have "... = (int (Nat_LSBF.to_nat xs)) * (2 ^ m) mod int n"
    using zmod_int by simp
  also have "... = (int (Nat_LSBF.to_nat xs) mod int n) * ((2 ^ m) mod int n) mod int n"
    by (simp add: mod_mult_eq)
  also have "... = (to_residue_ring xs) Fn((2 ^ m) mod int n)"
    using res_mult_eq by simp
  also have "(2 ^ m) mod int n = 2 [^]Fnm"
    using pow_nat_eq by simp
  finally show ?thesis .
qed

lemma
  assumes "xs  fermat_non_unique_carrier"
  shows divide_by_power_of_2_correct: "to_residue_ring (divide_by_power_of_2 xs m) = to_residue_ring xs Fn(invFn2) [^]Fnm"
    and divide_by_power_of_2_closed: "divide_by_power_of_2 xs m  fermat_non_unique_carrier"
  unfolding atomize_conj
proof (intro conjI)
  from assms show c: "divide_by_power_of_2 xs m  fermat_non_unique_carrier"
    unfolding fermat_non_unique_carrier_def by simp
  define divxs where "divxs = divide_by_power_of_2 xs m"
  define mulxs where "mulxs = multiply_with_power_of_2 xs m"

  have "multiply_with_power_of_2 divxs m = xs"
    unfolding divxs_def multiply_with_power_of_2_def divide_by_power_of_2_def by simp
  then have "to_residue_ring xs = to_residue_ring (multiply_with_power_of_2 divxs m)"
    by simp
  also have "... = to_residue_ring divxs Fn2 [^]Fnm"
    apply (intro multiply_with_power_of_2_correct)
    unfolding divxs_def by (rule c)
  finally have "to_residue_ring xs Fn(invFn2) [^]Fnm = to_residue_ring divxs Fn2 [^]Fnm Fn(invFn2) [^]Fnm"
    by simp
  also have "... = to_residue_ring divxs Fn(2 [^]Fnm Fn(invFn2) [^]Fnm)"
    apply (intro m_assoc to_residue_ring_in_carrier nat_pow_closed two_in_carrier)
    using two_is_unit by auto
  also have "(2 [^]Fnm Fn(invFn2) [^]Fnm) = (2 Fn(invFn2)) [^]Fnm"
    apply (intro pow_mult_distrib[symmetric] m_comm two_in_carrier)
    using two_is_unit by auto
  also have "... = 𝟭Fn[^]Fnm"
    by (intro arg_cong2[where f = "([^]Fn)"] refl Units_r_inv two_is_unit)
  also have "... = 𝟭Fn⇙" by simp
  also have "to_residue_ring divxs Fn𝟭Fn= to_residue_ring divxs"
    by (intro r_one to_residue_ring_in_carrier)
  finally show "to_residue_ring divxs = to_residue_ring xs FninvFn2 [^]Fnm" by simp
qed

definition add_fermat where
"add_fermat xs ys = (let zs = add_nat xs ys in if length zs = 2 ^ (k + 1) + 1 then inc_nat (butlast zs) else zs)"

lemma add_fermat_correct':
  assumes "xs  fermat_non_unique_carrier"
  assumes "ys  fermat_non_unique_carrier"
  shows "add_fermat xs ys  fermat_non_unique_carrier  Nat_LSBF.to_nat (add_fermat xs ys) mod n = (Nat_LSBF.to_nat xs + Nat_LSBF.to_nat ys) mod n"
proof -
  define zs where "zs = add_nat xs ys"
  show ?thesis
  proof (cases "length zs = 2 ^ (k + 1) + 1")
    case True
    then have "add_fermat xs ys = inc_nat (butlast zs)"
      using zs_def unfolding add_fermat_def by simp
    then have 1: "Nat_LSBF.to_nat (add_fermat xs ys) = 1 + Nat_LSBF.to_nat (butlast zs)" by (simp add: inc_nat_correct)
    from True obtain zs' where "zs = zs' @ [True]"
      using add_nat_last_bit_True assms zs_def by fastforce
    then have "butlast zs = zs'" by simp
    then have "Nat_LSBF.to_nat (add_fermat xs ys) = 1 + Nat_LSBF.to_nat zs'" using 1 by simp
    moreover have "Nat_LSBF.to_nat zs = Nat_LSBF.to_nat zs' + 2 ^ (2 ^ (k + 1))"
      using zs = zs' @ [True] True by (simp add: to_nat_app)
    hence "Nat_LSBF.to_nat zs mod n = (Nat_LSBF.to_nat zs' + 1) mod n"
      using two_pow_carrier_length by (metis mod_add_right_eq)
    ultimately have 2: "Nat_LSBF.to_nat (add_fermat xs ys) mod n = (Nat_LSBF.to_nat xs + Nat_LSBF.to_nat ys) mod n"
      using add_nat_correct[of xs ys] zs_def by auto

    have "length zs' = 2 ^ (k + 1)" using True zs = zs' @ [True] by simp

    have "Nat_LSBF.to_nat zs = Nat_LSBF.to_nat xs + Nat_LSBF.to_nat ys" using zs_def by (simp add: add_nat_correct)
    also have "...  (2 ^ length xs - 1) + (2 ^ length ys - 1)"
      using to_nat_length_upper_bound add_le_mono by algebra
    also have "... = (2 ^ (2 ^ (k + 1)) - 1) + (2 ^ (2 ^ (k + 1)) - 1)"
      using assms by simp
    also have "... < (2 ^ (2 ^ (k + 1)) - 1) + (2 ^ (2 ^ (k + 1)))"
      by (meson add_strict_left_mono diff_less pos2 zero_less_one zero_less_power)
    finally have "Nat_LSBF.to_nat zs' < 2 ^ (2 ^ (k + 1)) - 1"
      using Nat_LSBF.to_nat zs = Nat_LSBF.to_nat zs' + 2 ^ (2 ^ (k + 1)) by simp
    then have "length (inc_nat zs') = length zs'"
      using length_inc_nat' length zs' = 2 ^ (k + 1) by simp
    then have "length (add_fermat xs ys) = 2 ^ (k + 1)"
      using add_fermat xs ys = inc_nat (butlast zs) butlast zs = zs' length zs' = 2 ^ (k + 1)
      by simp
    with 2 show ?thesis unfolding fermat_non_unique_carrier_def by simp
  next
    case False
    have "length zs  2 ^ (k + 1)"
      using assms zs_def length_add_nat_lower[of xs ys] by simp
    moreover have "length zs  2 ^ (k + 1) + 1"
      using assms zs_def length_add_nat_upper[of xs ys] by simp
    ultimately have "length zs = 2 ^ (k + 1)" using False by simp
    then have "add_fermat xs ys  fermat_non_unique_carrier"
      unfolding fermat_non_unique_carrier_def add_fermat_def
      by (simp add: Let_def zs_def)
    moreover have "Nat_LSBF.to_nat zs = Nat_LSBF.to_nat xs + Nat_LSBF.to_nat ys"
      by (simp add: zs_def add_nat_correct)
    moreover have "add_fermat xs ys = zs"
      unfolding add_fermat_def using False zs_def by simp
    ultimately show ?thesis by algebra
  qed
qed

corollary add_fermat_closed:
  assumes "xs  fermat_non_unique_carrier"
  assumes "ys  fermat_non_unique_carrier"
  shows "add_fermat xs ys  fermat_non_unique_carrier"
  by (intro conjunct1[OF add_fermat_correct'] assms)

corollary add_fermat_correct:
  assumes "xs  fermat_non_unique_carrier"
  assumes "ys  fermat_non_unique_carrier"
  shows "to_residue_ring (add_fermat xs ys) = to_residue_ring xs Fnto_residue_ring ys"
proof -
  have "to_residue_ring (add_fermat xs ys) = (int (Nat_LSBF.to_nat xs) + int (Nat_LSBF.to_nat ys)) mod int n"
    using add_fermat_correct'[OF assms]
    by (metis of_nat_add of_nat_mod to_residue_ring.simps)
  also have "... = (int (Nat_LSBF.to_nat xs) mod int n + int (Nat_LSBF.to_nat ys) mod int n) mod int n"
    using mod_add_eq by presburger
  also have "... = (int (Nat_LSBF.to_nat xs mod n) + int (Nat_LSBF.to_nat ys mod n)) mod int n"
    using zmod_int by simp
  also have "... = to_residue_ring xs Fnto_residue_ring ys"
    by (simp add: res_add_eq zmod_int) 
  finally show ?thesis .
qed

definition subtract_fermat where
  "subtract_fermat xs ys = add_fermat xs (multiply_with_power_of_2 ys (2 ^ k))"

lemma subtract_fermat_correct':
  assumes "xs  fermat_non_unique_carrier"
  assumes "ys  fermat_non_unique_carrier"
  shows "subtract_fermat xs ys  fermat_non_unique_carrier  int (Nat_LSBF.to_nat (subtract_fermat xs ys)) mod n = (int (Nat_LSBF.to_nat xs) - int (Nat_LSBF.to_nat ys)) mod n"
proof -
  from assms(2) have "multiply_with_power_of_2 ys (2 ^ k)  fermat_non_unique_carrier"
    unfolding fermat_non_unique_carrier_def multiply_with_power_of_2_def rotate_right_def by simp
  with assms(1) have 1: "subtract_fermat xs ys  fermat_non_unique_carrier"
    unfolding subtract_fermat_def using add_fermat_correct' by simp
  have "int (Nat_LSBF.to_nat (subtract_fermat xs ys)) mod n = int (Nat_LSBF.to_nat (subtract_fermat xs ys) mod n)"
    using zmod_int by presburger
  also have "... = int ((Nat_LSBF.to_nat xs + Nat_LSBF.to_nat (multiply_with_power_of_2 ys (2 ^ k))) mod n)"
    using add_fermat_correct'
    using multiply_with_power_of_2 ys (2 ^ k)  fermat_non_unique_carrier
    using assms(1) subtract_fermat_def by presburger
  also have "... = int ((Nat_LSBF.to_nat xs + Nat_LSBF.to_nat (multiply_with_power_of_2 ys (2 ^ k)) mod n) mod n)"
    by presburger
  also have "... = int ((Nat_LSBF.to_nat xs + (Nat_LSBF.to_nat ys * 2 ^ (2 ^ k)) mod n) mod n)"
    using multiply_with_power_of_2_correct' assms(2) by presburger
  also have "... = (int (Nat_LSBF.to_nat xs) + int (Nat_LSBF.to_nat ys) * (int (2 ^ (2 ^ k)) mod n)) mod n"
    using zmod_int int_ops(7) int_plus
    by (simp add: mod_add_right_eq mod_mult_right_eq)
  also have "... = (int (Nat_LSBF.to_nat xs) + int (Nat_LSBF.to_nat ys) * ((-1) mod n)) mod n"
    using two_pow_half_carrier_length by simp
  also have "... = (int (Nat_LSBF.to_nat xs) - int (Nat_LSBF.to_nat ys)) mod n"
    by (simp add: mod_add_cong mod_mult_right_eq)
  finally show ?thesis using 1 by blast
qed

corollary subtract_fermat_closed:
  assumes "xs  fermat_non_unique_carrier"
  assumes "ys  fermat_non_unique_carrier"
  shows "subtract_fermat xs ys  fermat_non_unique_carrier"
  by (intro conjunct1[OF subtract_fermat_correct'] assms)

corollary subtract_fermat_correct:
  assumes "xs  fermat_non_unique_carrier"
  assumes "ys  fermat_non_unique_carrier"
  shows "to_residue_ring (subtract_fermat xs ys) = to_residue_ring xs Fnto_residue_ring ys"
proof -
  have "to_residue_ring (subtract_fermat xs ys) = (int (Nat_LSBF.to_nat xs) - int (Nat_LSBF.to_nat ys)) mod int n"
    using zmod_int subtract_fermat_correct' assms by simp
  also have "... = (int (Nat_LSBF.to_nat xs) mod int n - int (Nat_LSBF.to_nat ys) mod int n) mod int n"
    using mod_diff_eq by metis
  also have "... = (int (Nat_LSBF.to_nat xs mod n) - int (Nat_LSBF.to_nat ys mod n)) mod int n"
    using zmod_int by simp
  also have "... = to_residue_ring xs Fnto_residue_ring ys"
    using residues_minus_eq by (simp add: zmod_int)
  finally show ?thesis .
qed

end

context int_lsbf_fermat begin

definition reduce :: "nat_lsbf  nat_lsbf" where
"reduce xs = (let (ys, zs) = split xs in
  if compare_nat zs ys then
    subtract_nat ys zs
  else
    subtract_nat (add_nat (True # replicate (2 ^ k - 1) False @ [True]) ys) zs)"

lemma reduce_correct':
  assumes "xs  fermat_non_unique_carrier"
  shows "Nat_LSBF.to_nat (reduce xs) < n  Nat_LSBF.to_nat (reduce xs) mod n = Nat_LSBF.to_nat xs mod n" and "length (reduce xs)  2 ^ k + 2"
proof -
  obtain ys zs where "split xs = (ys, zs)" by fastforce
  then have "length ys = 2 ^ k" "length zs = 2 ^ k" using assms by (auto simp: split_def Let_def)
  then have "Nat_LSBF.to_nat ys < n" "Nat_LSBF.to_nat zs < n"
    using to_nat_length_upper_bound
    by (metis add.commute add_strict_increasing le_Suc_ex nat_le_linear nat_zero_less_power_iff not_add_less1 power_0 to_nat_bound_to_length_bound)+
  have "(int (Nat_LSBF.to_nat ys) - int (Nat_LSBF.to_nat zs)) mod n = (int (Nat_LSBF.to_nat ys) + (-1) mod n * int (Nat_LSBF.to_nat zs)) mod n"
    by (metis diff_minus_eq_add left_minus_one_mult_self mod_add_right_eq mod_mult_left_eq mult_minus1 power_one_right)
  also have "... = (int (Nat_LSBF.to_nat ys) + 2 ^ (2 ^ k) mod n * int (Nat_LSBF.to_nat zs)) mod n"
    using two_pow_half_carrier_length by simp
  also have "... = (int (Nat_LSBF.to_nat ys + 2 ^ (2 ^ k) * Nat_LSBF.to_nat zs)) mod n"
    by auto
  also have "... = (int (Nat_LSBF.to_nat (ys @ zs))) mod n"
    using length ys = 2 ^ k to_nat_app by presburger
  also have "... = (int (Nat_LSBF.to_nat xs)) mod n"
    using split xs = (ys, zs) app_split by presburger
  finally have 0: "(int (Nat_LSBF.to_nat ys) - int (Nat_LSBF.to_nat zs)) mod n = (int (Nat_LSBF.to_nat xs)) mod n" .
  have "Nat_LSBF.to_nat (reduce xs) < n  Nat_LSBF.to_nat (reduce xs) mod n = Nat_LSBF.to_nat xs mod n  length (reduce xs)  2 ^ k + 2"
  proof (cases "compare_nat zs ys")
    case True
    then have "reduce xs = subtract_nat ys zs"
      unfolding reduce_def split xs = (ys, zs) by simp
    then have 1: "Nat_LSBF.to_nat (reduce xs) = Nat_LSBF.to_nat ys - Nat_LSBF.to_nat zs"
      using subtract_nat_correct by presburger
    from True have "Nat_LSBF.to_nat zs  Nat_LSBF.to_nat ys"
      using compare_nat_correct by blast
    with 1 have "int (Nat_LSBF.to_nat (reduce xs)) = int (Nat_LSBF.to_nat ys) - int (Nat_LSBF.to_nat zs)"
      by linarith
    then have "int (Nat_LSBF.to_nat (reduce xs)) mod n = (int (Nat_LSBF.to_nat xs)) mod n"
      using 0 by presburger
    then have "Nat_LSBF.to_nat (reduce xs) mod n = Nat_LSBF.to_nat xs mod n"
      using zmod_int by (metis of_nat_eq_iff)

    have "Nat_LSBF.to_nat (reduce xs)  Nat_LSBF.to_nat ys" using 1 by linarith
    also have "... < n" using Nat_LSBF.to_nat ys < n .
    finally have "Nat_LSBF.to_nat (reduce xs) < n  Nat_LSBF.to_nat (reduce xs) mod n = Nat_LSBF.to_nat xs mod n"
      using Nat_LSBF.to_nat (reduce xs) mod n = Nat_LSBF.to_nat xs mod n by blast
    moreover have "length (reduce xs)  2 ^ k + 2" unfolding reduce xs = subtract_nat ys zs
      apply (estimation estimate: conjunct2[OF subtract_nat_aux])
      using length zs = 2 ^ k length ys = 2 ^ k by simp
    ultimately show ?thesis by simp
  next
    case False
    then have reduce_eq: "reduce xs = subtract_nat (add_nat (True # replicate (2 ^ k - 1) False @ [True]) ys) zs"
      unfolding reduce_def split xs = (ys, zs) by simp
    then have "Nat_LSBF.to_nat (reduce xs) = 1 + 2 * (2 ^ (2 ^ k - 1)) + Nat_LSBF.to_nat ys - Nat_LSBF.to_nat zs"
      by (simp add: subtract_nat_correct add_nat_correct to_nat_app)
    also have "(1::nat) + 2 * (2 ^ (2 ^ k - 1)) = 1 + 2 ^ (2 ^ k - 1 + 1)"
      by (metis add.commute power_add power_one_right)
    also have "... = n"
      by simp
    finally have 1: "Nat_LSBF.to_nat (reduce xs) = n + Nat_LSBF.to_nat ys - Nat_LSBF.to_nat zs" .
    then have "Nat_LSBF.to_nat (reduce xs) < n"
      using False Nat_LSBF.to_nat ys < n Nat_LSBF.to_nat zs < n unfolding compare_nat_correct
      by linarith
    from 1 have "int (Nat_LSBF.to_nat (reduce xs)) = int n + int (Nat_LSBF.to_nat ys) - int (Nat_LSBF.to_nat zs)"
      using Nat_LSBF.to_nat zs < n by linarith
    also have "... mod n = ((int n) mod n + (int (Nat_LSBF.to_nat ys) - int (Nat_LSBF.to_nat zs))) mod n"
      using add_diff_eq
      using mod_add_left_eq[of "int n" "int n" "int (Nat_LSBF.to_nat ys) - int (Nat_LSBF.to_nat zs)", symmetric]
      by metis
    also have "... = (int (Nat_LSBF.to_nat ys) - int (Nat_LSBF.to_nat zs)) mod n"
      using mod_self[of "int n"]
      by simp
    finally have "int (Nat_LSBF.to_nat (reduce xs)) mod n = int (Nat_LSBF.to_nat xs) mod n" using 0 by presburger
    then have "Nat_LSBF.to_nat (reduce xs) < n  Nat_LSBF.to_nat (reduce xs) mod n = Nat_LSBF.to_nat xs mod n"
      using Nat_LSBF.to_nat (reduce xs) < n zmod_int nat_int_comparison(1) by presburger
    moreover have "length (reduce xs)  2 ^ k + 2"
      unfolding reduce_eq
      apply (estimation estimate: conjunct2[OF subtract_nat_aux])
      apply (estimation estimate: length_add_nat_upper)
      unfolding length ys = 2 ^ k length zs = 2 ^ k by simp
    ultimately show ?thesis by simp
  qed
  then show "Nat_LSBF.to_nat (reduce xs) < n  Nat_LSBF.to_nat (reduce xs) mod n = Nat_LSBF.to_nat xs mod n" "length (reduce xs)  2 ^ k + 2"
    by simp_all
qed

lemma reduce_correct:
  assumes "xs  fermat_non_unique_carrier"
  shows "Nat_LSBF.to_nat xs mod n = Nat_LSBF.to_nat (reduce xs)"
  using reduce_correct'[OF assms] mod_less by metis

lemma add_take_drop_carry_aux:
  assumes "xs' = add_nat (take e xs) (drop e xs)"
  assumes "length xs = e + 1"
  assumes "e  1"
  shows "length xs'  e  (xs' = replicate e False @ [True]  xs = replicate e True @ [True])"
proof (intro verit_and_neg(3))
  assume a: "¬ (length xs'  e)"
  then have "length xs'  e + 1" by simp
  moreover have "length xs'  e + 1"
    unfolding assms(1)
    apply (estimation estimate: length_add_nat_upper)
    using assms by simp
  ultimately have len_xs': "length xs' = e + 1" by simp
  moreover have "max (length (take e xs)) (length (drop e xs)) = e"
    using assms by simp
  ultimately have "zs. xs' = zs @ [True]"
    unfolding assms(1) by (intro add_nat_last_bit_True, argo)
  then obtain zs where zs_def: "xs' = zs @ [True]" and len_zs: "length zs = e" using len_xs' by auto

  have "Nat_LSBF.to_nat xs' = Nat_LSBF.to_nat xs mod 2 ^ e + Nat_LSBF.to_nat xs div 2 ^ e"
    unfolding assms(1) by (simp add: add_nat_correct to_nat_take to_nat_drop)
  also have "... < (2 ^ e - 1) + (2 ^ (e + 1)) div 2 ^ e"
    apply (intro add_le_less_mono)
    subgoal using pos_mod_bound[of "2 ^ e" "Nat_LSBF.to_nat xs"] two_pow_pos 
      by (metis Suc_mask_eq_exp mask_eq_exp_minus_1 mod_Suc_le_divisor)
    subgoal using to_nat_length_upper_bound[of xs] assms div_le_mono
      by (metis add_diff_cancel_left' le_add1 less_mult_imp_div_less power_add power_commutes power_diff power_one_right to_nat_length_bound zero_neq_numeral)
    done
  also have "... = 2 ^ e + 1" by simp
  finally have "Nat_LSBF.to_nat xs'  2 ^ e" by simp
  moreover have "Nat_LSBF.to_nat xs' = Nat_LSBF.to_nat zs + 2 ^ e"
    unfolding zs_def by (simp add: to_nat_app len_zs)
  ultimately have "Nat_LSBF.to_nat zs = 0" by simp
  then have "zs = replicate e False" "Nat_LSBF.to_nat xs' = 2 ^ e"
    using len_zs to_nat_zero_iff truncate_Nil_iff Nat_LSBF.to_nat xs' = Nat_LSBF.to_nat zs + 2 ^ e
    by auto
  then have "xs' = replicate e False @ [True]" using zs_def by simp
  from assms(2) obtain xst xsh where xs_decomp: "xs = xst @ [xsh]" "length xst = e"
    by (metis Suc_eq_plus1 length_Suc_conv_rev)
  then have "take e xs = xst" "drop e xs = [xsh]" using assms by simp_all
  moreover have[simp]: "xsh = True"
  proof (rule ccontr)
    assume "xsh  True"
    then have "drop e xs = [False]" using xs_decomp by simp
    then have "Nat_LSBF.to_nat xs' = Nat_LSBF.to_nat (take e xs)"
      unfolding assms(1) add_nat_correct by simp
    also have "... < 2 ^ e"
      using assms(2) to_nat_length_bound[of "take e xs"] by simp
    finally show False using Nat_LSBF.to_nat xs' = 2 ^ e by simp
  qed
  ultimately have "Nat_LSBF.to_nat xs' = Nat_LSBF.to_nat xst + 1" unfolding assms(1) add_nat_correct
    by simp
  then have "Nat_LSBF.to_nat xst = 2 ^ e - 1" using Nat_LSBF.to_nat xs' = 2 ^ e by simp
  then have "xst = replicate e True" using to_nat_replicate_True2[of xst] length xst = e by argo
  then have "xs = replicate e True @ [True]"
    using xs = xst @ [xsh] by simp
  then show "xs' = replicate e False @ [True]  xs = replicate e True @ [True]"
    using xs' = replicate e False @ [True]
    by (simp add: replicate_append_same)
qed

function from_nat_lsbf :: "nat_lsbf  nat_lsbf" where
"from_nat_lsbf xs = (if length xs  2 ^ (k + 1) then fill (2 ^ (k + 1)) xs
    else from_nat_lsbf (add_nat (take (2 ^ (k + 1)) xs) (drop (2 ^ (k + 1)) xs)))"
  by pat_completeness auto
lemma from_nat_lsbf_dom_termination: "All from_nat_lsbf_dom"
proof (relation "measures [length, Nat_LSBF.to_nat]")
  show "wf (measures [length, Nat_LSBF.to_nat])" by simp
  fix xs :: nat_lsbf
  define e :: nat where "e = 2 ^ (k + 1)"
  then have e_ge_1: "e  1" and e_ge_2: "e  2" by simp_all
  define xs' where "xs' = add_nat (take e xs) (drop e xs)"
  assume "¬ length xs  2 ^ (k + 1)"
  then have a: "length xs  e + 1" unfolding e_def by simp
  then consider "length xs = e + 1  length xs'  e" |
      "length xs = e + 1  length xs'  e + 1" |
      "length xs  e + 2"
    by linarith
  then show "(add_nat (take (2 ^ (k + 1)) xs) (drop (2 ^ (k + 1)) xs), xs)
           measures [length, Nat_LSBF.to_nat]"
    unfolding e_def[symmetric] xs'_def[symmetric]
  proof cases
    case 1
    then show "(xs', xs)  measures [length, Nat_LSBF.to_nat]" by simp
  next
    case 2
    with add_take_drop_carry_aux[OF xs'_def _ e_ge_1] have
      xs'_rep: "xs' = replicate e False @ [True]" and
      xs_rep: "xs = replicate e True @ [True]"
      by simp_all
    then have "Nat_LSBF.to_nat xs' < Nat_LSBF.to_nat xs  (0::nat) < 2 ^ e - 1"
      by (auto simp: to_nat_app)
    also have "..." using e_ge_1
      by (metis One_nat_def Suc_le_lessD less_2_cases_iff one_less_power zero_less_diff)
    finally show "(xs', xs)  measures [length, Nat_LSBF.to_nat]"
      using 2 xs'_rep by simp
  next
    case 3
    have "length xs'  max e (length xs - e) + 1"
      unfolding xs'_def
      apply (estimation estimate: length_add_nat_upper)
      by simp
    also have "... < length xs" using 3 e_ge_2 by simp
    finally show "(xs', xs)  measures [length, Nat_LSBF.to_nat]" by simp
  qed
qed
termination by (rule from_nat_lsbf_dom_termination)

declare from_nat_lsbf.simps[simp del]

lemma from_nat_lsbf_correct:
  shows "from_nat_lsbf xs  fermat_non_unique_carrier"
    "to_residue_ring (from_nat_lsbf xs) = to_residue_ring xs"
proof (induction xs rule: from_nat_lsbf.induct)
  case (1 xs)
  then show "from_nat_lsbf xs  fermat_non_unique_carrier"
    apply (cases "length xs  2 ^ (k + 1)")
    subgoal
      unfolding fermat_non_unique_carrier_def
      by (simp add: from_nat_lsbf.simps[of xs] length_fill)
    subgoal
      by (simp add: from_nat_lsbf.simps[of xs])
    done
  show "to_residue_ring (from_nat_lsbf xs) = to_residue_ring xs"
  proof (cases "length xs  2 ^ (k + 1)")
    case True
    then show ?thesis
      by (simp add: from_nat_lsbf.simps[of xs])
  next
    case False
    let ?xs1 = "take (2 ^ (k + 1)) xs"
    let ?xs2 = "drop (2 ^ (k + 1)) xs"
    from False have "xs = ?xs1 @ ?xs2" by simp
    from False have "from_nat_lsbf xs = from_nat_lsbf (add_nat ?xs1 ?xs2)"
      by (simp add: from_nat_lsbf.simps[of xs])
    then have "to_residue_ring (from_nat_lsbf xs) = to_residue_ring (add_nat ?xs1 ?xs2)"
      using 1[OF False] by argo
    also have "... = (Nat_LSBF.to_nat ?xs1 + Nat_LSBF.to_nat ?xs2) mod n" by (simp add: add_nat_correct zmod_int)
    also have "... = (Nat_LSBF.to_nat ?xs1 + (2 ^ (2 ^ (k + 1))) * Nat_LSBF.to_nat ?xs2) mod n"
      using two_pow_carrier_length mod_add_right_eq mod_mult_left_eq
      by (metis (no_types, opaque_lifting) mult_numeral_1 numerals(1))
    also have "... = (Nat_LSBF.to_nat xs) mod n"
      by (intro_cong "[cong_tag_1 int, cong_tag_2 (mod)]" more: refl to_nat_drop_take[symmetric])
    finally show ?thesis by (simp add: zmod_int)
  qed
qed

lemma length_from_nat_lsbf: "length (from_nat_lsbf xs) = 2 ^ (k + 1)"
  using fermat_carrier_length[OF from_nat_lsbf_correct(1)] .

subsection "Implementing FNTT in $\\mathbb{Z}_{F_n}$"

lemma n_odd: "odd n"
  by simp

lemma ord_2: "ord n 2 = 2 ^ (k + 1)"
proof -
  have "ord n 2 dvd 2 ^ (k + 1)"
    using ord_divides[of "2::nat" "2 ^ (k + 1)" n]
    using two_pow_carrier_length
    by (simp add: cong_def)
  then obtain i where "ord n 2 = 2 ^ i" "i  k + 1"
    using divides_primepow_nat[OF two_is_prime_nat]
    by blast
  have "i = k + 1"
  proof (rule ccontr)
    assume "i  k + 1"
    then have "i  k" using i  k + 1 by linarith
    have "1  (2::nat) ^ (2 ^ k) mod n" using two_pow_half_carrier_length_neq_1[symmetric] .
    moreover have "(2::nat) ^ (2 ^ k) mod n = 1"
    proof -
      have "(2::nat) ^ (2 ^ k) mod n = (2 ^ (2 ^ i)) ^ (2 ^ (k - i)) mod n"
        by (simp add: i  k power_add[symmetric] power_mult[symmetric])
      also have "... = (2 ^ (2 ^ i) mod n) ^ (2 ^ (k - i)) mod n"
        by (simp add: power_mod)
      also have "2 ^ (2 ^ i) mod n = 1" using ord n 2 = 2 ^ i
        using ord[of 2 n] unfolding cong_def using n_gt_1 by simp
      finally show ?thesis by simp
    qed
    ultimately show "False" by argo
  qed
  then show ?thesis using ord n 2 = 2 ^ i by argo
qed
corollary ord_2_int: "ord (int n) 2 = 2 ^ (k + 1)"
  using ord_2 ord_int[of n 2] by simp

lemma two_is_primitive_root: "primitive_root (2 ^ (k + 1)) 2"
  apply (intro primitive_rootI)
  subgoal
    using two_in_carrier .
  subgoal
    using two_pow_carrier_length_residue_ring .
  subgoal for i
    using ord_2_int unfolding ord_def
    using pow_nat_eq not_less_Least cong_def
    by (metis (no_types, lifting) less_nat_zero_code one_cong)
  done

lemma two_inv_is_primitive_root: "primitive_root (2 ^ (k + 1)) (invFn2)"
  using primitive_root_inv[OF _ two_is_primitive_root] by simp

lemma two_powers_primitive_root:
  assumes "i + s = k + 1"
  assumes "i  k"
  shows "primitive_root (2 ^ s) (2 [^]Fn(2::nat) ^ i)"
proof (intro primitive_rootI nat_pow_closed two_in_carrier)

  have "(2 [^]Fn(2::nat) ^ i) [^]Fn(2::nat) ^ s = 2 [^]Fn((2::nat) ^ (i + s))"
    by (simp add: nat_pow_pow[OF two_in_carrier] power_add)
  also have "... = 𝟭Fn⇙"
    unfolding assms(1) by (rule two_pow_carrier_length_residue_ring)
  finally show "(2 [^]Fn(2::nat) ^ i) [^]Fn(2::nat) ^ s = 𝟭Fn⇙" .

  fix j :: nat
  assume "0 < j" "j < 2 ^ s"
  then have "2 ^ i * j < 2 ^ (k + 1)"
    using power_add assms(1)
    by (metis nat_mult_less_cancel1 pos2 zero_less_power)
  have "2 ^ i * j > 0" using j > 0 by simp
  have 1: "(l{1..<(2::nat) ^ (k + 1)}. 2 [^]Fnl  𝟭Fn)"
    using two_is_primitive_root unfolding primitive_root_def by simp
  have "(2 [^]Fn(2::nat) ^ i) [^]Fnj = 2 [^]Fn(2 ^ i * j)"
    by (simp add: nat_pow_pow[OF two_in_carrier])
  also have "...  𝟭Fn⇙"
    using 1 2 ^ i * j > 0 2 ^ i * j < 2 ^ (k + 1) by simp
  finally show "(2 [^]Fn(2::nat) ^ i) [^]Fnj  𝟭Fn⇙" .
qed

fun fft_combine_b_c_aux :: "(nat_lsbf  nat_lsbf  nat_lsbf)  (nat_lsbf  nat  nat_lsbf)  nat  nat_lsbf list × nat  nat_lsbf list  nat_lsbf list  nat_lsbf list" where
"fft_combine_b_c_aux f g l (revs, e) [] [] = rev revs"
| "fft_combine_b_c_aux f g l (revs, e) (b # bs) (c # cs) =
    fft_combine_b_c_aux f g l ((f b (g c e)) # revs, (e + l) mod 2 ^ (k + 1)) bs cs"
| "fft_combine_b_c_aux f g l _ _ _ = undefined"

fun fft_ifft_combine_b_c_add where
"fft_ifft_combine_b_c_add True l bs cs = fft_combine_b_c_aux add_fermat divide_by_power_of_2 l ([], 0) bs cs"
| "fft_ifft_combine_b_c_add False l bs cs = fft_combine_b_c_aux add_fermat multiply_with_power_of_2 l ([], 0) bs cs"

fun fft_ifft_combine_b_c_subtract where
"fft_ifft_combine_b_c_subtract True l bs cs = fft_combine_b_c_aux subtract_fermat divide_by_power_of_2 l ([], 0) bs cs"
| "fft_ifft_combine_b_c_subtract False l bs cs = fft_combine_b_c_aux subtract_fermat multiply_with_power_of_2 l ([], 0) bs cs"

lemma fft_combine_b_c_aux_correct:
  assumes "length bs = len_bc" "length cs = len_bc"
  assumes "e < 2 ^ (k + 1)"
  shows "fft_combine_b_c_aux f g l (revs, e) bs cs = rev revs @ map3 (λx y i. f x (g y ((e + l * i) mod 2 ^ (k + 1)))) bs cs [0..<len_bc]"
using assms proof (induction len_bc arbitrary: bs cs revs e)
  case 0
  then have "bs = []" "cs = []" by simp_all
  then show ?case by simp
next
  case (Suc len_bc)
  then obtain b bs' c cs' where bcs: "bs = b # bs'" "cs = c # cs'" by (meson length_Suc_conv)
  with Suc.prems have len_bcs': "length bs' = len_bc" "length cs' = len_bc" by simp_all
  have "(e + l * i) mod 2 ^ (k + 1) < 2 ^ (k + 1)" for i by simp
  note ih = Suc.IH[OF len_bcs' this]
  have "fft_combine_b_c_aux f g l (revs, e) bs cs =
    fft_combine_b_c_aux f g l (f b (g c e) # revs, (e + l) mod (2 * 2 ^ k)) bs' cs'"
    unfolding bcs by simp
  also have "... = rev (f b (g c e) # revs) @
    map3 (λx y i. f x (g y (((e + l * 1) mod 2 ^ (k + 1) + l * i) mod 2 ^ (k + 1)))) bs' cs'
    [0..<len_bc]"
    using ih[of "f b (g c e) # revs" 1] by simp
  also have "... = rev revs @ (f b (g c e) #
      map3 (λx y i. f x (g y (((e + l * 1) mod 2 ^ (k + 1) + l * i) mod 2 ^ (k + 1)))) bs' cs'
      [0..<len_bc])"
    by simp
  finally have r: "fft_combine_b_c_aux f g l (revs, e) bs cs = ..." .
  show ?case unfolding r
  proof (intro arg_cong2[where f = "(@)"] refl)
    have "f b (g c e) #
      map3 (λx y i. f x (g y (((e + l * 1) mod 2 ^ (k + 1) + l * i) mod 2 ^ (k + 1)))) bs' cs' [0..<len_bc] =
      f b (g c (e + l * 0)) #
      map3 (λx y i. f x (g y ((e + l * Suc i) mod 2 ^ (k + 1)))) bs' cs' [0..<len_bc]"
      (is "?l = ?f # ?m3")
      apply (intro arg_cong2[where f = "(#)"])
      subgoal by simp
      subgoal
        unfolding append.append_Nil
        apply (intro arg_cong[where f = "λi. map3 i _ _ _"])
        by (simp add: add.assoc mod_add_left_eq)
      done
    also have "?m3 = map3 (λx y i. f x (g y ((e + l * i) mod 2 ^ (k + 1)))) bs' cs' (map Suc [0..<len_bc])"
      by (rule map3_compose3)
    also have "... = map3 (λx y i. f x (g y ((e + l * i) mod 2 ^ (k + 1)))) bs' cs' [Suc 0..<Suc len_bc]"
      by (subst map_Suc_upt) (rule refl)
    also have "?f # ... = map3 (λx y i. f x (g y ((e + l * i) mod 2 ^ (k + 1)))) bs cs [0..<Suc len_bc]"
      unfolding upt_conv_Cons[OF zero_less_Suc[of len_bc]] bcs using Suc.prems by simp
    finally show "?l = ..." .
  qed
qed

lemma fft_ifft_combine_b_c_add_correct:
  assumes "length bs = len_bc" "length cs = len_bc"
  shows "fft_ifft_combine_b_c_add it l bs cs = map3 (λx y i. add_fermat x ((if it then divide_by_power_of_2 else multiply_with_power_of_2) y ((l * i) mod 2 ^ (k + 1)))) bs cs [0..<len_bc]"
  by (cases it; simp add: fft_combine_b_c_aux_correct[OF assms])

lemma fft_ifft_combine_b_c_subtract_correct:
  assumes "length bs = len_bc" "length cs = len_bc"
  shows "fft_ifft_combine_b_c_subtract it l bs cs = map3 (λx y i. subtract_fermat x ((if it then divide_by_power_of_2 else multiply_with_power_of_2) y ((l * i) mod 2 ^ (k + 1)))) bs cs [0..<len_bc]"
  by (cases it; simp add: fft_combine_b_c_aux_correct[OF assms])

lemma fft_ifft_combine_b_c_add_carrier:
  assumes "length bs = len_bc" "length cs = len_bc"
  assumes "set bs  fermat_non_unique_carrier"
  assumes "set cs  fermat_non_unique_carrier"
  shows "set (fft_ifft_combine_b_c_add it l bs cs)  fermat_non_unique_carrier"
  unfolding fft_ifft_combine_b_c_add_correct[OF assms(1) assms(2)]
  apply (intro set_map3_subseteqI[OF _ assms(3) assms(4) subset_refl] add_fermat_closed)
   apply (simp_all add: divide_by_power_of_2_closed multiply_with_power_of_2_closed)
  done

lemma fft_ifft_combine_b_c_subtract_carrier:
  assumes "length bs = len_bc" "length cs = len_bc"
  assumes "set bs  fermat_non_unique_carrier"
  assumes "set cs  fermat_non_unique_carrier"
  shows "set (fft_ifft_combine_b_c_subtract it l bs cs)  fermat_non_unique_carrier"
  unfolding fft_ifft_combine_b_c_subtract_correct[OF assms(1) assms(2)]
  apply (intro set_map3_subseteqI[OF _ assms(3) assms(4) subset_refl] subtract_fermat_closed)
   apply (simp_all add: divide_by_power_of_2_closed multiply_with_power_of_2_closed)
  done

fun fft_ifft :: "bool  nat  nat_lsbf list  nat_lsbf list" where
"fft_ifft it l [] = []"
| "fft_ifft it l [x] = [x]"
| "fft_ifft it l [x, y] = [add_fermat x y, subtract_fermat x y]"
| "fft_ifft it l a = (let nums1 = evens_odds True a;
                  nums2 = evens_odds False a;
                  b = fft_ifft it (2 * l) nums1;
                  c = fft_ifft it (2 * l) nums2;
                  g = fft_ifft_combine_b_c_add it l b c;
                  h = fft_ifft_combine_b_c_subtract it l b c
                   in g@h)"

(* fft l xs performs the fast fourier transform of xs with primitive root of unity 2 ^ l *)
fun fft where "fft l xs = fft_ifft False l xs"
fun ifft where "ifft l xs = fft_ifft True l xs"

end

locale fft_context = int_lsbf_fermat +
  fixes it :: bool
  fixes l e :: nat
  fixes a1 a2 a3 :: nat_lsbf
  fixes as :: "nat_lsbf list"
  assumes length_a': "length (a1 # a2 # a3 # as) = 2 ^ e"
begin

definition a where "a = a1 # a2 # a3 # as"
definition nums1 where "nums1 = evens_odds True a"
definition nums2 where "nums2 = evens_odds False a"
definition b where "b = fft_ifft it (2 * l) nums1"
definition c where "c = fft_ifft it (2 * l) nums2"
definition g where "g = fft_ifft_combine_b_c_add it l b c"
definition h where "h = fft_ifft_combine_b_c_subtract it l b c"
lemmas defs = a_def nums1_def nums2_def b_def c_def g_def h_def

lemma length_a: "length a = 2 ^ e" unfolding a_def by (rule length_a')
lemma e_ge_2: "e  2"
proof (rule ccontr)
  assume "¬ e  2"
  then have "e  1" by simp
  have "(2::nat) ^ e  2" using power_increasing[OF e  1, of "2::nat"] by simp
  then show False using length_a' by simp
qed
lemma e_pos: "e > 0" using e_ge_2 by simp
lemma two_pow_e_div_2: "(2::nat) ^ e div 2 = 2 ^ (e - 1)"
  using gr0_implies_Suc[OF e_pos] by auto
lemma two_pow_e_as_sum: "(2::nat) ^ e = 2 ^ (e - 1) + 2 ^ (e - 1)"
  by (metis e_pos two_pow_e_div_2 even_power even_two_times_div_two gcd_nat.eq_iff mult_2)

lemma
  shows length_nums1: "length nums1 = 2 ^ (e - 1)"
  and length_nums2: "length nums2 = 2 ^ (e - 1)"
  unfolding nums1_def nums2_def length_evens_odds length_a 
  using two_pow_e_div_2 by simp_all

lemma result_eq: "fft_ifft it l a = g @ h"
  unfolding a_def fft_ifft.simps[of it l] Let_def
  unfolding defs[symmetric] by (rule refl)

lemma
  assumes "set a  fermat_non_unique_carrier"
  shows nums1_carrier: "set nums1  fermat_non_unique_carrier"
  and nums2_carrier: "set nums2  fermat_non_unique_carrier"
  unfolding nums1_def nums2_def atomize_conj
  by (intro conjI subset_trans[OF set_evens_odds] assms)

end

context int_lsbf_fermat
begin

lemma length_fft_ifft:
  assumes "length a = 2 ^ e"
  shows "length (fft_ifft it l a) = 2 ^ e"
  using assms
proof (induction it l a arbitrary: e rule: fft_ifft.induct)
  case (4 it l a1 a2 a3 as)
  interpret fft_context k it l e a1 a2 a3 as
    apply unfold_locales
    using 4 by argo

  have len_b: "length b = 2 ^ (e - 1)"
    unfolding b_def
    apply (intro "4.IH"[of nums1 nums2])
    unfolding defs[symmetric] length_nums1
    by (rule refl)+
  have len_c: "length c = 2 ^ (e - 1)"
    unfolding c_def
    apply (intro "4.IH"(2)[of nums1 nums2 b])
    unfolding defs[symmetric] length_nums2
    by (rule refl)+
  have len_g: "length g = 2 ^ (e - 1)"
    unfolding g_def fft_ifft_combine_b_c_add_correct[OF len_b len_c] map3_as_map
    by (simp add: len_b len_c)
  have len_h: "length h = 2 ^ (e - 1)"
    unfolding h_def fft_ifft_combine_b_c_subtract_correct[OF len_b len_c] map3_as_map
    by (simp add: len_b len_c)
  show ?case
    unfolding a_def[symmetric] result_eq
    by (simp add: len_g len_h e_pos two_pow_e_as_sum)
qed simp_all

lemma length_fft:
  assumes "length a = 2 ^ e"
  shows "length (fft l a) = 2 ^ e"
  unfolding fft.simps length_fft_ifft[OF assms] by (rule refl)

lemma length_ifft:
  assumes "length a = 2 ^ e"
  shows "length (ifft l a) = 2 ^ e"
  unfolding ifft.simps length_fft_ifft[OF assms] by (rule refl)

end

context fft_context begin

lemma length_b: "length b = 2 ^ (e - 1)"
  unfolding b_def by (intro length_fft_ifft length_nums1)
lemma length_c: "length c = 2 ^ (e - 1)"
  unfolding c_def by (intro length_fft_ifft length_nums2)
lemma length_g: "length g = 2 ^ (e - 1)"
  unfolding g_def fft_ifft_combine_b_c_add_correct[OF length_b length_c] map3_as_map
  by (simp add: length_b length_c)
lemma length_h: "length h = 2 ^ (e - 1)"
  unfolding h_def fft_ifft_combine_b_c_subtract_correct[OF length_b length_c] map3_as_map
  by (simp add: length_b length_c)

end

context int_lsbf_fermat
begin

lemma fft_ifft_carrier:
  assumes "length a = 2 ^ l"
  assumes "set a  fermat_non_unique_carrier"
  shows "set (fft_ifft it s a)  fermat_non_unique_carrier"
using assms proof (induction it s a arbitrary: l rule: fft_ifft.induct)
  case (1 it s)
  then show ?case by simp
next
  case (2 it s x)
  then show ?case by simp
next
  case (3 it s x y)
  then show ?case by (simp add: add_fermat_closed subtract_fermat_closed)
next
  case (4 it s a1 a2 a3 as)
  interpret fft_context k it s l a1 a2 a3 as
    apply unfold_locales using 4 by simp

  have b_carrier: "set b  fermat_non_unique_carrier"
    unfolding b_def
    apply (intro "4.IH"(1)[of nums1 nums2 "l - 1"])
    unfolding defs[symmetric]
    using nums1_carrier length_nums1 "4.prems" a_def by simp_all
  have c_carrier: "set c  fermat_non_unique_carrier"
    unfolding c_def
    apply (intro "4.IH"(2)[of nums1 nums2 b "l - 1"])
    unfolding defs[symmetric]
    using nums2_carrier length_nums2 "4.prems" a_def by simp_all
    

  have g_carrier: "set g  fermat_non_unique_carrier"
    unfolding g_def
    apply (intro fft_ifft_combine_b_c_add_carrier)
    using length_b length_c b_carrier c_carrier by simp_all
  have h_carrier: "set h  fermat_non_unique_carrier"
    unfolding h_def
    apply (intro fft_ifft_combine_b_c_subtract_carrier)
    using length_b length_c b_carrier c_carrier by simp_all

  show ?case
    unfolding defs[symmetric] result_eq
    using g_carrier h_carrier by simp
qed

lemma fft_carrier:
  assumes "length a = 2 ^ l"
  assumes "set a  fermat_non_unique_carrier"
  shows "set (fft s a)  fermat_non_unique_carrier"
  using fft_ifft_carrier[OF assms] by simp

lemma ifft_carrier:
  assumes "length a = 2 ^ l"
  assumes "set a  fermat_non_unique_carrier"
  shows "set (ifft s a)  fermat_non_unique_carrier"
  using fft_ifft_carrier[OF assms] by simp

lemma fft_ifft_correct':
  assumes "length a = 2 ^ l"
  assumes "l  k + 1"
  assumes "set a  fermat_non_unique_carrier"
  shows "map to_residue_ring (fft_ifft it s a) = FNTT'' ((if it then invFn2 else 2) [^]Fns) (map to_residue_ring a)"
  using assms
proof (induction it s a arbitrary: l rule: fft_ifft.induct)
  case (1 it s)
  then show ?case by simp
next
  case (2 it s x)
  then show ?case by simp
next
  case (3 it s x y)
  then show ?case using add_fermat_correct subtract_fermat_correct by simp
next
  case (4 it s a1 a2 a3 as)
  interpret fft_context k it s l a1 a2 a3 as
    apply unfold_locales using 4 by simp

(* definitions of FNTT *)
  define nums1' where "nums1' = evens_odds True (map to_residue_ring local.a)"
  define nums2' where "nums2' = evens_odds False (map to_residue_ring local.a)"
  define b' where "b' = FNTT'' (((if it then invFn2 else 2) [^]Fns) [^]Fn(2::nat)) nums1'"
  define c' where "c' = FNTT'' (((if it then invFn2 else 2) [^]Fns) [^]Fn(2::nat)) nums2'"
  define g' where "g' = map2 (⊕Fn) b'
              (map2 (⊗Fn)
                (map (([^]Fn) ((if it then invFn2 else 2) [^]Fns)) [0..<length local.a div 2]) c')"
  define h' where "h' = map2 (a_minus Fn) b'
              (map2 (⊗Fn)
                (map (([^]Fn) ((if it then invFn2 else 2) [^]Fns)) [0..<length local.a div 2]) c')"
  note defs' = a_def nums1'_def nums2'_def b'_def c'_def g'_def h'_def

  have fntt_def: "FNTT'' ((if it then invFn2 else 2) [^]Fns) (map to_residue_ring (a1 # a2 # a3 # as))
    = g' @ h'"
    using length_map[of to_residue_ring local.a]
    by (simp only: list.map(2) FNTT''.simps Let_def defs')

(* correspondences *)
  from two_is_primitive_root have "root_of_unity (2 ^ (k + 1)) 2"
    unfolding primitive_root_def by blast

  from e_ge_2 have "Suc (k + 1 - l)  k" using l  k + 1 by linarith

  have pr_unit: "(if it then invFn2 else 2)  Units Fn"
    using two_is_unit Units_inv_Units by algebra
  then have pr_carrier: "(if it then invFn2 else 2)  carrier Fn"
    by (rule Units_closed)
  have pow_2s: "((if it then invFn2 else 2) [^]Fns) [^]Fn(2::nat) = (if it then invFn2 else 2) [^]Fn(2 * s)"
    using nat_pow_pow[OF pr_carrier, of s 2] mult.commute[of s 2] by argo

  from e_ge_2 obtain l' where l'_def[simp]: "l = Suc l'"
    by (metis not_numeral_le_zero old.nat.exhaust)

  have l'_le: "l'  k + 1"
    using l  k + 1 l = Suc l' by linarith

  have nums1'_def': "nums1' = map to_residue_ring nums1"
    by (simp add: nums1'_def nums1_def map_evens_odds)
  then have length_nums1': "length nums1' = 2 ^ l'" using length_nums1 l = Suc l' by simp
  have nums2'_def': "nums2' = map to_residue_ring nums2"
    by (simp add: nums2'_def nums2_def map_evens_odds)
  then have length_nums2': "length nums2' = 2 ^ l'" using length_nums2 l = Suc l' by simp

  have "set local.a  fermat_non_unique_carrier" using 4 by (simp only: a_def)
  have nums1_carrier: "set nums1  fermat_non_unique_carrier"
    unfolding nums1_def using set local.a  fermat_non_unique_carrier set_evens_odds by fastforce

  have b_b': "b' = map to_residue_ring b"
    unfolding b'_def b_def nums1'_def map_evens_odds[symmetric] pow_2s nums1_def
    apply (intro 4(1)[symmetric, of _ nums2 l'])
    subgoal unfolding a_def by (rule refl)
    subgoal unfolding nums2_def a_def by (rule refl)
    subgoal unfolding defs[symmetric] length_nums1 by simp
    subgoal by (rule l'_le)
    subgoal unfolding defs[symmetric] by (rule nums1_carrier)
    done
  then have length_b': "length b' = 2 ^ l'"
    using length_b by simp

  have nums2_carrier: "set nums2  fermat_non_unique_carrier"
    unfolding nums2_def using set local.a  fermat_non_unique_carrier set_evens_odds by fastforce

  have c_c': "c' = map to_residue_ring c"
    unfolding c'_def c_def nums2'_def map_evens_odds[symmetric] pow_2s nums2_def
    apply (intro 4(2)[symmetric, of nums1 _ b l'])
    subgoal unfolding defs by (rule refl)
    subgoal unfolding a_def by (rule refl)
    subgoal unfolding defs[symmetric] by (rule refl)
    subgoal unfolding defs[symmetric] length_nums2 by simp
    subgoal by (rule l'_le)
    subgoal unfolding defs[symmetric] by (rule nums2_carrier)
    done
  then have length_c': "length c' = 2 ^ l'"
    using length_c by simp
  
  have b_carrier: "set b  fermat_non_unique_carrier"
    unfolding b_def
    apply (intro fft_ifft_carrier nums1_carrier) using length_nums1 by simp
  have c_carrier: "set c  fermat_non_unique_carrier"
    unfolding c_def
    apply (intro fft_ifft_carrier nums2_carrier) using length_nums2 by simp

  have length_nums1': "length nums1' = 2 ^ l'"
    using length_nums1 nums1'_def' by simp
  have length_nums2': "length nums2' = 2 ^ l'"
    using length_nums2 nums2'_def' by simp

  have length_g': "length g' = 2 ^ l'"
    unfolding g'_def by (simp add: length_a length_b' length_c')
  have length_h': "length h' = 2 ^ l'"
    unfolding h'_def by (simp add: length_a length_b' length_c')

  have g_g': "g' = map to_residue_ring g"
  proof (intro nth_equalityI)
    show "length g' = length (map to_residue_ring g)"
      by (simp add: length_g length_g')
  next
    fix i
    assume "i < length g'"
    then have i_less: "i < 2 ^ (l - 1)" unfolding length_g' using l = Suc l' by simp

    have bi_carrier: "b ! i  fermat_non_unique_carrier"
      using set_subseteqD[OF b_carrier] length_b i_less by simp
    have ci_carrier: "c ! i  fermat_non_unique_carrier"
      using set_subseteqD[OF c_carrier] length_c i_less by simp

    have bi_b'i: "to_residue_ring (b ! i) = b' ! i"
      unfolding b_b' by (intro nth_map[symmetric]; simp add: length_b i_less del: l = Suc l' One_nat_def)
    have ci_c'i: "to_residue_ring (c ! i) = c' ! i"
      unfolding c_c' by (intro nth_map[symmetric]; simp add: length_c i_less del: l = Suc l' One_nat_def)
    
    show "g' ! i = (map to_residue_ring g) ! i"
    proof (cases it)
      case True
      then have it_op: "(if it then divide_by_power_of_2 else multiply_with_power_of_2) = divide_by_power_of_2" by argo
      have "map to_residue_ring g ! i = to_residue_ring (g ! i)"
        apply (intro nth_map)
        unfolding length_g using i_less by simp
      also have "... = to_residue_ring (add_fermat (b ! i) (divide_by_power_of_2 (c ! i) (s * ([0..<2 ^ (l - 1)] ! i) mod 2 ^ (k + 1))))"
        unfolding g_def fft_ifft_combine_b_c_add_correct[OF length_b length_c] it_op
        apply (intro arg_cong[where f = to_residue_ring] nth_map3)
        unfolding length_b length_c using i_less by simp_all
      also have "... = to_residue_ring (add_fermat (b ! i) (divide_by_power_of_2 (c ! i) (s * i mod 2 ^ (k + 1))))"
        using i_less by simp
      also have "... = to_residue_ring (b ! i) Fnto_residue_ring (divide_by_power_of_2 (c ! i) (s * i mod 2 ^ (k + 1)))"
        by (intro add_fermat_correct bi_carrier divide_by_power_of_2_closed ci_carrier)
      also have "... = to_residue_ring (b ! i) Fnto_residue_ring (c ! i) FninvFn2 [^]Fn(s * i mod 2 ^ (k + 1))"
        by (intro arg_cong2[where f = "(⊕Fn)"] divide_by_power_of_2_correct refl ci_carrier)
      also have "... = (b' ! i) Fn(c' ! i) FninvFn2 [^]Fn(s * i mod 2 ^ (k + 1))"
        unfolding bi_b'i ci_c'i by (rule refl)
      also have "... = (b' ! i) Fn(c' ! i) FninvFn2 [^]Fn(s * i)"
        by (intro_cong "[cong_tag_2 (⊕Fn), cong_tag_2 (⊗Fn)]" more: inv_pow_mod_carrier_length mod_mod_trivial)
      also have "... = (b' ! i) Fn(c' ! i) Fn((invFn2 [^]Fns) [^]Fni)"
        by (intro_cong "[cong_tag_2 (⊕Fn), cong_tag_2 (⊗Fn)]" more: nat_pow_pow[symmetric] Units_inv_closed two_is_unit)
      finally have 1: "map to_residue_ring g ! i = ..." .
      have "g' ! i = map3 (λx y z. x Fny Fnz) b' (map (([^]Fn) (invFn2 [^]Fns)) [0..<length local.a div 2]) c' ! i"
        unfolding g'_def eqTrueI[OF True] if_True map2_of_map2_r by (rule refl)
      also have "... = (b' ! i) Fn((map (([^]Fn) (invFn2 [^]Fns)) [0..<length local.a div 2]) ! i) Fn(c' ! i)"
        using i_less length_b' length_c' l = Suc l' length_a by (intro nth_map3) simp_all
      also have "... = (b' ! i) Fn(invFn2 [^]Fns) [^]Fni Fn(c' ! i)"
        apply (intro_cong "[cong_tag_2 (⊕Fn), cong_tag_2 (⊗Fn)]")
        using nth_map length_a l = Suc l' i_less by simp
      also have "... = (b' ! i) Fn(c' ! i) Fn(invFn2 [^]Fns) [^]Fni"
        apply (intro arg_cong2[where f = "(⊕Fn)"] refl m_comm nat_pow_closed Units_inv_closed two_is_unit)
        using to_residue_ring_in_carrier ci_c'i[symmetric] by simp
      finally show ?thesis unfolding 1 .
    next
      case False
      then have it_op: "(if it then divide_by_power_of_2 else multiply_with_power_of_2) = multiply_with_power_of_2" by argo
      have "map to_residue_ring g ! i = to_residue_ring (g ! i)"
        apply (intro nth_map)
        unfolding length_g using i_less by simp
      also have "... = to_residue_ring (add_fermat (b ! i) (multiply_with_power_of_2 (c ! i) (s * ([0..<2 ^ (l - 1)] ! i) mod 2 ^ (k + 1))))"
        unfolding g_def fft_ifft_combine_b_c_add_correct[OF length_b length_c] it_op
        apply (intro arg_cong[where f = to_residue_ring] nth_map3)
        unfolding length_b length_c using i_less by simp_all
      also have "... = to_residue_ring (add_fermat (b ! i) (multiply_with_power_of_2 (c ! i) (s * i mod 2 ^ (k + 1))))"
        using i_less by simp
      also have "... = to_residue_ring (b ! i) Fnto_residue_ring (multiply_with_power_of_2 (c ! i) (s * i mod 2 ^ (k + 1)))"
        by (intro add_fermat_correct bi_carrier multiply_with_power_of_2_closed ci_carrier)
      also have "... = to_residue_ring (b ! i) Fnto_residue_ring (c ! i) Fn2 [^]Fn(s * i mod 2 ^ (k + 1))"
        by (intro arg_cong2[where f = "(⊕Fn)"] multiply_with_power_of_2_correct refl ci_carrier)
      also have "... = (b' ! i) Fn(c' ! i) Fn2 [^]Fn(s * i mod 2 ^ (k + 1))"
        unfolding bi_b'i ci_c'i by (rule refl)
      also have "... = (b' ! i) Fn(c' ! i) Fn2 [^]Fn(s * i)"
        by (intro_cong "[cong_tag_2 (⊕Fn), cong_tag_2 (⊗Fn)]" more: pow_mod_carrier_length mod_mod_trivial)
      also have "... = (b' ! i) Fn(c' ! i) Fn((2 [^]Fns) [^]Fni)"
        by (intro_cong "[cong_tag_2 (⊕Fn), cong_tag_2 (⊗Fn)]" more: nat_pow_pow[symmetric] two_in_carrier)
      finally have 1: "map to_residue_ring g ! i = ..." .
      have "g' ! i = map3 (λx y z. x Fny Fnz) b' (map (([^]Fn) (2 [^]Fns)) [0..<length local.a div 2]) c' ! i"
        unfolding g'_def if_False map2_of_map2_r by (simp add: False)
      also have "... = (b' ! i) Fn((map (([^]Fn) (2 [^]Fns)) [0..<length local.a div 2]) ! i) Fn(c' ! i)"
        using i_less length_b' length_c' l = Suc l' length_a by (intro nth_map3) simp_all
      also have "... = (b' ! i) Fn(2 [^]Fns) [^]Fni Fn(c' ! i)"
        apply (intro_cong "[cong_tag_2 (⊕Fn), cong_tag_2 (⊗Fn)]")
        using nth_map length_a l = Suc l' i_less by simp
      also have "... = (b' ! i) Fn(c' ! i) Fn(2 [^]Fns) [^]Fni"
        apply (intro arg_cong2[where f = "(⊕Fn)"] refl m_comm nat_pow_closed two_in_carrier)
        using to_residue_ring_in_carrier ci_c'i[symmetric] by simp
      finally show ?thesis unfolding 1 .
    qed
  qed

  have h_h': "h' = map to_residue_ring h"
  proof (intro nth_equalityI)
    show "length h' = length (map to_residue_ring h)"
      by (simp add: length_h length_h')
  next
    fix i
    assume "i < length h'"
    then have i_less: "i < 2 ^ (l - 1)" unfolding length_h' using l = Suc l' by simp

    have bi_carrier: "b ! i  fermat_non_unique_carrier"
      using set_subseteqD[OF b_carrier] length_b i_less by simp
    have ci_carrier: "c ! i  fermat_non_unique_carrier"
      using set_subseteqD[OF c_carrier] length_c i_less by simp

    have bi_b'i: "to_residue_ring (b ! i) = b' ! i"
      unfolding b_b' by (intro nth_map[symmetric]; simp add: length_b i_less del: l = Suc l' One_nat_def)
    have ci_c'i: "to_residue_ring (c ! i) = c' ! i"
      unfolding c_c' by (intro nth_map[symmetric]; simp add: length_c i_less del: l = Suc l' One_nat_def)
    
    show "h' ! i = (map to_residue_ring h) ! i"
    proof (cases it)
      case True
      then have it_op: "(if it then divide_by_power_of_2 else multiply_with_power_of_2) = divide_by_power_of_2" by argo
      have "map to_residue_ring h ! i = to_residue_ring (h ! i)"
        apply (intro nth_map)
        unfolding length_h using i_less by simp
      also have "... = to_residue_ring (subtract_fermat (b ! i) (divide_by_power_of_2 (c ! i) (s * ([0..<2 ^ (l - 1)] ! i) mod 2 ^ (k + 1))))"
        unfolding h_def fft_ifft_combine_b_c_subtract_correct[OF length_b length_c] it_op
        apply (intro arg_cong[where f = to_residue_ring] nth_map3)
        unfolding length_b length_c using i_less by simp_all
      also have "... = to_residue_ring (subtract_fermat (b ! i) (divide_by_power_of_2 (c ! i) (s * i mod 2 ^ (k + 1))))"
        using i_less by simp
      also have "... = to_residue_ring (b ! i) Fnto_residue_ring (divide_by_power_of_2 (c ! i) (s * i mod 2 ^ (k + 1)))"
        by (intro subtract_fermat_correct bi_carrier divide_by_power_of_2_closed ci_carrier)
      also have "... = to_residue_ring (b ! i) Fnto_residue_ring (c ! i) FninvFn2 [^]Fn(s * i mod 2 ^ (k + 1))"
        by (intro arg_cong2[where f = "(λx y. x Fny)"] divide_by_power_of_2_correct refl ci_carrier)
      also have "... = (b' ! i) Fn(c' ! i) FninvFn2 [^]Fn(s * i mod 2 ^ (k + 1))"
        unfolding bi_b'i ci_c'i by (rule refl)
      also have "... = (b' ! i) Fn(c' ! i) FninvFn2 [^]Fn(s * i)"
        by (intro_cong "[cong_tag_2 (λx y. x Fny), cong_tag_2 (⊗Fn)]" more: inv_pow_mod_carrier_length mod_mod_trivial)
      also have "... = (b' ! i) Fn(c' ! i) Fn((invFn2 [^]Fns) [^]Fni)"
        by (intro_cong "[cong_tag_2 (λx y. x Fny), cong_tag_2 (⊗Fn)]" more: nat_pow_pow[symmetric] Units_inv_closed two_is_unit)
      finally have 1: "map to_residue_ring h ! i = ..." .
      have "h' ! i = map3 (λx y z. x Fny Fnz) b' (map (([^]Fn) (invFn2 [^]Fns)) [0..<length local.a div 2]) c' ! i"
        unfolding h'_def eqTrueI[OF True] if_True map2_of_map2_r by (rule refl)
      also have "... = (b' ! i) Fn((map (([^]Fn) (invFn2 [^]Fns)) [0..<length local.a div 2]) ! i) Fn(c' ! i)"
        using i_less length_b' length_c' l = Suc l' length_a by (intro nth_map3) simp_all
      also have "... = (b' ! i) Fn(invFn2 [^]Fns) [^]Fni Fn(c' ! i)"
        apply (intro_cong "[cong_tag_2 (λx y. x Fny), cong_tag_2 (⊗Fn)]")
        using nth_map length_a l = Suc l' i_less by simp
      also have "... = (b' ! i) Fn(c' ! i) Fn(invFn2 [^]Fns) [^]Fni"
        apply (intro arg_cong2[where f = "(λx y. x Fny)"] refl m_comm nat_pow_closed Units_inv_closed two_is_unit)
        using to_residue_ring_in_carrier ci_c'i[symmetric] by simp
      finally show ?thesis unfolding 1 .
    next
      case False
      then have it_op: "(if it then divide_by_power_of_2 else multiply_with_power_of_2) = multiply_with_power_of_2" by argo
      have "map to_residue_ring h ! i = to_residue_ring (h ! i)"
        apply (intro nth_map)
        unfolding length_h using i_less by simp
      also have "... = to_residue_ring (subtract_fermat (b ! i) (multiply_with_power_of_2 (c ! i) (s * ([0..<2 ^ (l - 1)] ! i) mod 2 ^ (k + 1))))"
        unfolding h_def fft_ifft_combine_b_c_subtract_correct[OF length_b length_c] it_op
        apply (intro arg_cong[where f = to_residue_ring] nth_map3)
        unfolding length_b length_c using i_less by simp_all
      also have "... = to_residue_ring (subtract_fermat (b ! i) (multiply_with_power_of_2 (c ! i) (s * i mod 2 ^ (k + 1))))"
        using i_less by simp
      also have "... = to_residue_ring (b ! i) Fnto_residue_ring (multiply_with_power_of_2 (c ! i) (s * i mod 2 ^ (k + 1)))"
        by (intro subtract_fermat_correct bi_carrier multiply_with_power_of_2_closed ci_carrier)
      also have "... = to_residue_ring (b ! i) Fnto_residue_ring (c ! i) Fn2 [^]Fn(s * i mod 2 ^ (k + 1))"
        by (intro arg_cong2[where f = "(λx y. x Fny)"] multiply_with_power_of_2_correct refl ci_carrier)
      also have "... = (b' ! i) Fn(c' ! i) Fn2 [^]Fn(s * i mod 2 ^ (k + 1))"
        unfolding bi_b'i ci_c'i by (rule refl)
      also have "... = (b' ! i) Fn(c' ! i) Fn2 [^]Fn(s * i)"
        by (intro_cong "[cong_tag_2 (λx y. x Fny), cong_tag_2 (⊗Fn)]" more: pow_mod_carrier_length mod_mod_trivial)
      also have "... = (b' ! i) Fn(c' ! i) Fn((2 [^]Fns) [^]Fni)"
        by (intro_cong "[cong_tag_2 (λx y. x Fny), cong_tag_2 (⊗Fn)]" more: nat_pow_pow[symmetric] two_in_carrier)
      finally have 1: "map to_residue_ring h ! i = ..." .
      have "h' ! i = map3 (λx y z. x Fny Fnz) b' (map (([^]Fn) (2 [^]Fns)) [0..<length local.a div 2]) c' ! i"
        unfolding h'_def if_False map2_of_map2_r by (simp add: False)
      also have "... = (b' ! i) Fn((map (([^]Fn) (2 [^]Fns)) [0..<length local.a div 2]) ! i) Fn(c' ! i)"
        using i_less length_b' length_c' l = Suc l' length_a by (intro nth_map3) simp_all
      also have "... = (b' ! i) Fn(2 [^]Fns) [^]Fni Fn(c' ! i)"
        apply (intro_cong "[cong_tag_2 (λx y. x Fny), cong_tag_2 (⊗Fn)]")
        using nth_map length_a l = Suc l' i_less by simp
      also have "... = (b' ! i) Fn(c' ! i) Fn(2 [^]Fns) [^]Fni"
        apply (intro arg_cong2[where f = "(λx y. x Fny)"] refl m_comm nat_pow_closed two_in_carrier)
        using to_residue_ring_in_carrier ci_c'i[symmetric] by simp
      finally show ?thesis unfolding 1 .
    qed
  qed
  show ?case
    using fntt_def
    unfolding a_def[symmetric] result_eq map_append g_g'[symmetric] h_h'[symmetric]
    by argo
qed

lemma fft_ifft_correct:
  assumes "length a = 2 ^ l"
  assumes "s = 2 ^ i"
  assumes "i + l = k + 1"
  assumes "l > 0"
  assumes "set a  fermat_non_unique_carrier"
  shows "map to_residue_ring (fft_ifft it s a) = NTT ((if it then invFn2 else 2) [^]Fns) (map to_residue_ring a)"
proof -
  let  = "(if it then invFn2 else 2) [^]Fns"
  have inv2s: "(invFn2 [^]Fns) = invFn(2 [^]Fns)"
    by (intro inv_nat_pow[symmetric] two_is_unit)

  then have it_unfold: "it   = invFn(2 [^]Fns)" "¬ it   = 2 [^]Fns"
    by simp_all
  from assms have "l  k + 1" by simp
  from assms have "i  k" by simp
  then have "(2::nat) ^ i  2 ^ k" by simp

  have "2 ^ l div 2 = (2::nat) ^ (l - 1)" using l > 0 by (simp add: Suc_leI power_diff)
  then have pow_2sl: "(2 [^]Fns) [^]Fn((2::nat) ^ l div 2) = Fn𝟭Fn⇙"
    using two_powers_half_carrier_length_residue_ring[of i "l - 1"]
    using l > 0 i + l = k + 1 s = 2 ^ i two_powers_trivial[of "2 ^ i"] i  k
    by simp
  have pr: "primitive_root (2 ^ l) (2 [^]Fns)"
    unfolding assms(2) by (intro two_powers_primitive_root assms i  k)

  from fft_ifft_correct'[OF length a = 2 ^ l l  k + 1 set a  fermat_non_unique_carrier]
  have "map to_residue_ring (fft_ifft it s a) = FNTT''  (map to_residue_ring a)"
    by simp
  also have "... = FNTT'  (map to_residue_ring a)"
    apply (intro FNTT''_FNTT')
    using assms(1) by simp
  also have "... = FNTT  (map to_residue_ring a)"
    by (intro FNTT'_FNTT)
  also have "... = NTT  (map to_residue_ring a)"
    apply (intro FNTT_NTT[of _ "2 ^ l" l])
    subgoal by (intro nat_pow_closed two_in_carrier prop_ifI[where P = "λx. x  carrier Fn"] Units_inv_closed two_is_unit)
    subgoal by argo
    subgoal
    proof (cases it)
      case True
      then show ?thesis unfolding it_unfold(1)[OF True]
      apply (intro primitive_root_inv)
      subgoal by simp
      subgoal by (rule pr)
      done
    next
      case False
      then show ?thesis unfolding it_unfold(2)[OF False] by (intro pr)
    qed
    subgoal
    proof (cases it)
      case True
      show ?thesis unfolding it_unfold(1)[OF True]
        by (intro inv_halfway_property Units_pow_closed two_is_unit pow_2sl)
    next
      case False
      then show ?thesis
        unfolding it_unfold(2)[OF False] by (intro pow_2sl)
    qed
    subgoal using assms(1) by simp
    subgoal apply (intro set_subseteqI) using to_residue_ring_in_carrier by simp
    done
  finally show ?thesis .
qed

lemma fft_correct:
  assumes "length a = 2 ^ l"
  assumes "s = 2 ^ i"
  assumes "i + l = k + 1"
  assumes "l > 0"
  assumes "set a  fermat_non_unique_carrier"
  shows "map to_residue_ring (fft s a) = NTT (2 [^]Fns) (map to_residue_ring a)"
  unfolding fft.simps using fft_ifft_correct[OF assms] by simp

lemma ifft_correct:
  assumes "length a = 2 ^ l"
  assumes "s = 2 ^ i"
  assumes "i + l = k + 1"
  assumes "l > 0"
  assumes "set a  fermat_non_unique_carrier"
  shows "map to_residue_ring (ifft s a) = NTT ((invFn2) [^]Fns) (map to_residue_ring a)"
  unfolding ifft.simps using fft_ifft_correct[OF assms] by simp

end

end