(* Theory NTT_Rings *)

section "Number Theoretic Transforms in Rings"

theory NTT_Rings
imports
  "Number_Theoretic_Transform.NTT"
  Karatsuba.Monoid_Sums
  Karatsuba.Karatsuba_Preliminaries
  "../Preliminaries/Schoenhage_Strassen_Preliminaries"
  "../Preliminaries/Schoenhage_Strassen_Ring_Lemmas"
begin

(* Factorisation of a nonzero natural number a as r * p ^ k, where k is the largest
   exponent with p ^ k dvd a and r = a div p ^ k; for a prime p the cofactor r is
   coprime to p. *)
lemma max_dividing_power_factorization:
  fixes a :: nat
  assumes "a ≠ 0"
  assumes "k = Max {s. p ^ s dvd a}"
  assumes "r = a div (p ^ k)"
  assumes "prime p"
  shows "a = r * p ^ k" "coprime p r"
  subgoal
  proof -
    (* the exponent set contains 0, hence is nonempty, so its maximum k is attained *)
    have "p ^ 0 dvd a" by simp
    then have "{s. p ^ s dvd a} ≠ {}" by blast
    with assms have "p ^ k dvd a"
      by (metis Max_in finite_divisor_powers mem_Collect_eq not_prime_unit)
    with assms show ?thesis by simp
  qed
  subgoal
  proof (rule ccontr)
    assume "¬ coprime p r"
    then have "p dvd r" using prime_imp_coprime_nat ‹prime p› by blast
    (* NOTE(review): ‹a = r * p ^ k› is the conclusion of the first subgoal; confirm it is
       available as a fact inside this second subgoal. *)
    then have "p ^ (k + 1) dvd a" using ‹a = r * p ^ k› by simp
    (* k + 1 would then be an admissible exponent, contradicting the maximality of k *)
    then have "k ≥ k + 1"
      using assms Max_ge[of "{s. p ^ s dvd a}" k] Max_in[of "{s. p ^ s dvd a}"]
      by (metis Max.coboundedI finite_divisor_powers mem_Collect_eq not_prime_unit)
    then show "False" by simp
  qed
  done

context cring
begin    

(* Interpret the units of R, units_of R, as a group; the obligation is discharged by
   the lemma units_group. Gives access to units_group.* lemmas (e.g. int_pow_neg). *)
interpretation units_group: group "units_of R"
  by (rule units_group)

(* Register units_of R as a multiplicative subgroup of R with carrier Units R; this
   provides the transfer lemmas units_subgroup.inv_eq and units_subgroup.int_pow_eq
   used below. *)
interpretation units_subgroup: multiplicative_subgroup R "Units R" "units_of R"
  by (rule units_subgroup)

subsection "Roots of Unity"

(* μ is an n-th root of unity: it lies in the carrier and its n-th power is 𝟭. *)
definition root_of_unity :: "nat ⇒ 'a ⇒ bool" where
"root_of_unity n μ ≡ μ ∈ carrier R ∧ μ [^] n = 𝟭"

(* Introduction rule for root_of_unity. *)
lemma root_of_unityI[intro]: "μ ∈ carrier R ⟹ μ [^] n = 𝟭 ⟹ root_of_unity n μ"
  unfolding root_of_unity_def by simp

(* Simplification rule: the n-th power of an n-th root of unity is 𝟭. *)
lemma root_of_unityD[simp]: "root_of_unity n μ ⟹ μ [^] n = 𝟭"
  unfolding root_of_unity_def by simp

(* Simplification rule: roots of unity lie in the carrier of R. *)
lemma root_of_unity_closed[simp]: "root_of_unity n μ ⟹ μ ∈ carrier R"
  unfolding root_of_unity_def by simp


context
  fixes n :: nat
  assumes "n > 0"
begin

(* Every n-th root of unity is a unit (using the context assumption n > 0): writing
   n = Suc n', the identity 𝟭 = μ ⊗ μ [^] n' exhibits μ [^] n' as an inverse. *)
lemma roots_Units[simp]:
  assumes "root_of_unity n μ"
  shows "μ ∈ Units R"
proof -
  from ‹n > 0› obtain n' where "n = Suc n'"
    using gr0_implies_Suc by auto
  then have "𝟭 = μ ⊗ (μ [^] n')"
    using assms nat_pow_Suc2 unfolding root_of_unity_def by auto
  then show "μ ∈ Units R" using assms m_comm[of μ "μ [^] n'"] nat_pow_closed[of μ n']
    unfolding Units_def root_of_unity_def by auto
qed

(* The n-th roots of unity as a monoid record that reuses R's multiplication and unit. *)
definition roots_of_unity_group where
"roots_of_unity_group ≡ ⦇ carrier = {μ. root_of_unity n μ}, monoid.mult = (⊗), one = 𝟭 ⦈"

(* Under the context assumption n > 0, the n-th roots of unity form a group under ⊗. *)
lemma roots_of_unity_group_is_group:
  shows "group roots_of_unity_group"
  apply (intro groupI)
  unfolding roots_of_unity_group_def root_of_unity_def
  apply (simp_all add: nat_pow_distrib m_assoc)
  (* the remaining groupI obligation (presumably existence of inverses) needs n > 0 *)
  subgoal for x
    using ‹n > 0›
      by (metis Group.nat_pow_Suc Nat.lessE mult.commute nat_pow_closed nat_pow_one nat_pow_pow)
  done

(* Local group interpretation of the n-th roots of unity (valid under n > 0), discharged
   by roots_of_unity_group_is_group. *)
interpretation root_group : group "roots_of_unity_group"
  by (rule roots_of_unity_group_is_group)

(* The roots-of-unity group is a multiplicative subgroup of R with carrier
   {μ. root_of_unity n μ}; the first obligation uses roots_Units and n > 0, the second
   unfolds the record definition. *)
interpretation root_subgroup : multiplicative_subgroup R "{μ. root_of_unity n μ}" roots_of_unity_group
  apply unfold_locales
  subgoal using roots_Units ‹n > 0› by blast
  subgoal unfolding roots_of_unity_group_def by simp
  done

(* The inverse of an n-th root of unity is again an n-th root of unity; proved by closure
   of the roots-of-unity group under inverses, transferred to R via root_subgroup. *)
lemma root_of_unity_inv:
  assumes "root_of_unity n μ"
  shows "root_of_unity n (inv μ)"
  using assms root_group.inv_closed[of μ] root_subgroup.carrier_M root_subgroup.inv_eq[of μ] by simp

(* The inverse of an n-th root of unity is its (n - 1)-st power. The proof computes with
   integer powers: inv μ = μ [^] -1 = μ [^] n ⊗ μ [^] -1 = μ [^] (n - 1). *)
lemma inv_root_of_unity:
  assumes "root_of_unity n μ"
  shows "inv μ = μ [^] (n - 1)"
proof -
  have "μ ∈ Units R" using assms
    using roots_Units by blast
  then have "inv μ = μ [^] (-1 :: int)"
    using units_group.int_pow_neg units_subgroup.inv_eq units_subgroup.int_pow_eq
    using units_group.int_pow_1 by force
  also have "... = 𝟭 ⊗ μ [^] (-1 :: int)"
    apply (intro l_one[symmetric])
    using ‹μ ∈ Units R› by (metis Units_inv_closed calculation)
  also have "... = μ [^] n ⊗ μ [^] (-1 :: int)"
    using assms by simp
  also have "... = μ [^] (int n) ⊗ μ [^] (-1 :: int)"
    using Units_closed[OF ‹μ ∈ Units R›]
    by (simp add: int_pow_int)
  also have "... = μ [^] (int n - 1)"
    using units_group.int_pow_mult[of μ] ‹μ ∈ Units R› units_subgroup.int_pow_eq[of μ]
    using units_of_mult units_subgroup.carrier_M 
    by (metis add.commute uminus_add_conv_diff)
  (* back from int to nat exponents; needs n > 0 so that n - 1 does not truncate *)
  also have "... = μ [^] (n - 1)"
    using ‹n > 0› Units_closed[OF ‹μ ∈ Units R›]
    by (metis Suc_diff_1 add_diff_cancel_left' int_pow_int mult_Suc_right nat_mult_1 of_nat_1 of_nat_add)
  finally show ?thesis .
qed

(* For 1 ≤ i < n, the i-th power of inv μ equals μ [^] (n - i), and n - i again lies in
   {1..<n}. *)
lemma inv_pow_root_of_unity:
  assumes "root_of_unity n μ"
  assumes "i ∈ {1..<n}"
  shows "(inv μ) [^] i = μ [^] (n - i)" "n - i ∈ {1..<n}"
proof -
  have "(inv μ) [^] i = (μ [^] (n - (1::nat))) [^] i"
    using assms inv_root_of_unity by algebra
  also have "... = μ [^] ((n - 1) * i)"
    apply (intro nat_pow_pow) using assms roots_Units Units_closed by blast
  also have "... = μ [^] n ⊗ μ [^] ((n - 1) * i)"
    using assms root_of_unity_def[of n μ] by fastforce
  also have "... = μ [^] (n + (n - 1) * i)"
    apply (intro nat_pow_mult) using assms roots_Units Units_closed by blast
  (* exponent arithmetic; done in int to deal with the truncated nat subtractions *)
  also have "... = μ [^] (n * i + (n - i))"
  proof (intro arg_cong[where f = "([^]) μ"])
    have "int (n + (n - 1) * i) = int (n * i + (n - i))"
    proof -
      have "int (n + (n - 1) * i) = int n + int (n - 1) * int i"
        by simp
      also have "... = int n + (int n - int 1) * int i"
        using ‹n > 0› by fastforce
      also have "... = int n + int n * int i - int i"
        by (simp add: left_diff_distrib')
      also have "... = int n * int i + (int n - int i)"
        by simp
      also have "... = int (n * i) + int (n - i)"
        using assms(2) by fastforce
      finally show ?thesis by presburger
    qed
    then show "n + (n - 1) * i = n * i + (n - i)" by presburger
  qed
  also have "... = (μ [^] n) [^] i ⊗ μ [^] (n - i)"
    using nat_pow_mult nat_pow_pow
    using assms roots_Units Units_closed by algebra
  also have "... = μ [^] (n - i)"
    using assms unfolding root_of_unity_def by simp
  finally show "(inv μ) [^] i = μ [^] (n - i)" by blast
  show "n - i ∈ {1..<n}" using assms by auto
qed

(* Natural powers of an n-th root of unity are again n-th roots of unity (closure of the
   roots-of-unity group under powers, transferred to R via root_subgroup). *)
lemma root_of_unity_nat_pow_closed:
  assumes "root_of_unity n μ"
  shows "root_of_unity n (μ [^] (m :: nat))"
  using assms root_group.nat_pow_closed root_subgroup.nat_pow_eq by simp

(* Natural powers of an n-th root of unity depend only on the exponent modulo n. *)
lemma root_of_unity_powers:
  assumes "root_of_unity n μ"
  shows "μ [^] i = μ [^] (i mod n)"
proof -
  have[simp]: "μ ∈ carrier R" using assms by simp
  (* split i = s * n + t with t = i mod n; the factor (μ [^] n) [^] s collapses to 𝟭 *)
  define s t where "s = i div n" "t = i mod n"
  then have "i = s * n + t" "t < n" using ‹n > 0› by simp_all
  then have "μ [^] i = μ [^] (s * n) ⊗ μ [^] t" by (simp add: nat_pow_mult)
  also have "μ [^] (s * n) = (μ [^] n) [^] s" by (simp add: nat_pow_pow mult.commute)
  also have "... = 𝟭" using assms by simp
  finally show ?thesis using ‹t = i mod n› by simp
qed

(* Integer powers of an n-th root of unity depend only on the exponent modulo n. *)
lemma root_of_unity_powers_modint:
  assumes "root_of_unity n μ"
  shows "μ [^] (i :: int) = μ [^] (i mod int n)"
proof -
  have "μ ∈ Units R" "μ [^] n = 𝟭" using assms by simp_all
  (* split i = s * n + t with 0 ≤ t < n; the factor (μ [^] n) [^] s collapses to 𝟭 *)
  define s t where "s = i div int n" "t = i mod int n"
  then have "i = s * int n + t" "t ≥ 0" "t < int n" using ‹n > 0› by simp_all
  then have "μ [^] i = μ [^] (s * int n) ⊗ μ [^] t"
    using int_pow_mult[OF ‹μ ∈ Units R›] by simp
  also have "... = (μ [^] int n) [^] s ⊗ μ [^] t"
    by (intro_cong "[cong_tag_2 (⊗)]" more: refl) (simp add: int_pow_pow ‹μ ∈ Units R› mult.commute)
  also have "... = (μ [^] n) [^] s ⊗ μ [^] t"
    apply (intro_cong "[cong_tag_2 (⊗), cong_tag_1 (λi. i [^] s)]" more: refl)
    using ‹n > 0› by (simp add: int_pow_int)
  also have "... = μ [^] t"
    using int_pow_closed[OF ‹μ ∈ Units R›] Units_closed l_one
    by (simp add: ‹μ [^] n = 𝟭› int_pow_one int_pow_closed)
  finally show ?thesis unfolding s_t_def .
qed

(* Natural exponents that agree modulo n give equal powers of an n-th root of unity;
   a direct consequence of root_of_unity_powers. *)
lemma root_of_unity_powers_nat:
  assumes "root_of_unity n μ"
  assumes "i mod n = j mod n"
  shows "μ [^] i = μ [^] j"
  using assms root_of_unity_powers by metis

(* Integer exponents that agree modulo n give equal powers of an n-th root of unity;
   a direct consequence of root_of_unity_powers_modint. *)
lemma root_of_unity_powers_int:
  assumes "root_of_unity n μ"
  assumes "i mod int n = j mod int n"
  shows "μ [^] i = μ [^] j"
  using assms root_of_unity_powers_modint by metis

end

subsection "Primitive Roots"

(* μ is a primitive n-th root of unity: an n-th root of unity with μ [^] i ≠ 𝟭 for all
   1 ≤ i < n. *)
definition primitive_root :: "nat ⇒ 'a ⇒ bool" where
"primitive_root n μ ≡ root_of_unity n μ ∧ (∀i ∈ {1..<n}. μ [^] i ≠ 𝟭)"

(* Introduction rule for primitive_root. *)
lemma primitive_rootI[intro]:
  assumes "μ ∈ carrier R"
  assumes "μ [^] n = 𝟭"
  assumes "⋀i. i > 0 ⟹ i < n ⟹ μ [^] i ≠ 𝟭"
  shows "primitive_root n μ"
  unfolding primitive_root_def root_of_unity_def using assms by simp

(* Every primitive n-th root of unity is an n-th root of unity. *)
lemma primitive_root_is_root_of_unity[simp]: "primitive_root n μ ⟹ root_of_unity n μ"
  unfolding primitive_root_def by simp

(* For even n, the square of a primitive n-th root of unity is a primitive
   (n div 2)-th root of unity. *)
lemma primitive_root_recursion:
  assumes "even n"
  assumes "primitive_root n μ"
  shows "primitive_root (n div 2) (μ [^] (2 :: nat))"
  unfolding primitive_root_def root_of_unity_def
  apply (intro conjI)
  subgoal
    using assms(2) unfolding primitive_root_def root_of_unity_def by blast
  subgoal
    using nat_pow_pow[of μ "2::nat" "n div 2"] assms apply simp
    unfolding primitive_root_def root_of_unity_def apply simp
    done
  subgoal
  proof
    (* a nontrivial power of μ² below n div 2 is a nontrivial power of μ below n *)
    fix i
    assume "i ∈ {1..<n div 2}"
    then have "2 * i ∈ {1..<n}" using ‹even n› by auto
    have "(μ [^] (2::nat)) [^] i = μ [^] (2 * i)"
      using assms unfolding primitive_root_def root_of_unity_def by (simp add: nat_pow_pow)
    also have "... ≠ 𝟭"
      using assms unfolding primitive_root_def using ‹2 * i ∈ {1..<n}› by blast
    finally show "(μ [^] (2::nat)) [^] i ≠ 𝟭" .
  qed
  done

(* For n > 0, the inverse of a primitive n-th root of unity is again primitive. *)
lemma primitive_root_inv:
  assumes "n > 0"
  assumes "primitive_root n μ"
  shows "primitive_root n (inv μ)"
  unfolding primitive_root_def
proof (intro conjI)
  show "root_of_unity n (inv μ)" using assms unfolding primitive_root_def
    by (simp add: root_of_unity_inv)
  show "∀i∈{1..<n}. inv μ [^] i ≠ 𝟭" using assms unfolding primitive_root_def
    by (metis Group.nat_pow_0 Units_inv_inv bot_nat_0.extremum_strict nat_neq_iff root_of_unity_def root_of_unity_inv roots_Units)
qed

subsection "Number Theoretic Transforms"

(* Number theoretic transform of a with respect to μ: with n = length a, entry i
   (for i < n) is ⨁j ← [0..<n]. a ! j ⊗ (μ [^] i) [^] j. *)
definition NTT :: "'a ⇒ 'a list ⇒ 'a list" where
"NTT μ a ≡ let n = length a in [⨁j ← [0..<n]. (a ! j) ⊗ (μ [^] i) [^] j. i ← [0..<n]]"

(* The NTT preserves the length of its input list. *)
lemma NTT_length[simp]: "length (NTT μ a) = length a"
  unfolding NTT_def by (metis length_map map_nth)

(* Entry i < n of the NTT of a list of length n, unfolded as a sum. *)
lemma NTT_nth:
  assumes "length a = n"
  assumes "i < n"
  shows "NTT μ a ! i = (⨁j ← [0..<n]. (a ! j) ⊗ (μ [^] i) [^] j)"
  unfolding NTT_def using assms by auto

(* Variant of NTT_nth in which the iterated power is merged into μ [^] (i * j). *)
lemma NTT_nth_2:
  assumes "length a = n"
  assumes "i < n"
  assumes "μ ∈ carrier R"
  shows "NTT μ a ! i = (⨁j ← [0..<n]. (a ! j) ⊗ (μ [^] (i * j)))"
  unfolding NTT_nth[OF assms(1) assms(2)]
  by (intro monoid_sum_list_cong arg_cong[where f = "(⊗) _"] nat_pow_pow assms(3))

(* Each entry of the NTT of a carrier list with μ ∈ carrier R lies in the carrier. *)
lemma NTT_nth_closed:
  assumes "set a ⊆ carrier R"
  assumes "μ ∈ carrier R"
  assumes "length a = n"
  assumes "i < n"
  shows "NTT μ a ! i ∈ carrier R"
proof -
  have "NTT μ a ! i = (⨁j ← [0..<length a]. (a ! j) ⊗ (μ [^] i) [^] j)"
    using NTT_nth assms by blast
  also have "... ∈ carrier R"
    by (intro monoid_sum_list_closed m_closed nat_pow_closed assms(2) set_subseteqD[OF assms(1)]) simp
  finally show ?thesis .
qed

(* The NTT maps carrier lists to carrier lists, for μ ∈ carrier R. *)
lemma NTT_closed:
  assumes "set a ⊆ carrier R"
  assumes "μ ∈ carrier R"
  shows "set (NTT μ a) ⊆ carrier R"
  using assms NTT_nth_closed[of a μ]
  by (intro subsetI)(metis NTT_length in_set_conv_nth)

(* Sanity check: 𝟭 is a primitive first root of unity (the range {1..<1} is empty). *)
lemma "primitive_root 1 𝟭"
  unfolding primitive_root_def root_of_unity_def
  by simp
  
(* Sanity check: the square of ⊖ 𝟭 is 𝟭. *)
lemma "(⊖ 𝟭) [^] (2::nat) = 𝟭"
  by (simp add: numeral_2_eq_2) algebra
(* If 𝟭 ⊕ 𝟭 ≠ 𝟬 (i.e. the characteristic is not 2), then ⊖ 𝟭 is a primitive square root
   of unity. *)
lemma "𝟭 ⊕ 𝟭 ≠ 𝟬 ⟹ primitive_root 2 (⊖ 𝟭)"
  unfolding primitive_root_def root_of_unity_def
  apply (intro conjI)
  subgoal by simp
  subgoal by (simp add: numeral_2_eq_2, algebra)
  subgoal
  proof (standard, rule ccontr)
    (* the only candidate exponent is i = 1; ⊖ 𝟭 = 𝟭 would force 𝟭 ⊕ 𝟭 = 𝟬 *)
    fix i
    assume "𝟭 ⊕ 𝟭 ≠ 𝟬" "i ∈ {1::nat..<2}"
    then have "i = 1" by simp
    assume "¬ ⊖ 𝟭 [^] i ≠ 𝟭"
    then have "⊖ 𝟭 = 𝟭" using ‹i = 1› by simp
    then have "𝟭 ⊕ 𝟭 = 𝟬" using l_neg by fastforce
    thus False using ‹𝟭 ⊕ 𝟭 ≠ 𝟬› by simp
  qed
  done

subsubsection "Inversion Rule"

(* Inversion rule: let μ be a primitive n-th root of unity whose nontrivial power sums
   vanish (assumption good). Then transforming a with μ and afterwards with inv μ
   multiplies every entry of a by nat_embedding n. *)
theorem inversion_rule:
  fixes μ :: 'a
  fixes n :: nat
  assumes "n > 0"
  assumes "primitive_root n μ"
  assumes good: "⋀i. i ∈ {1..<n} ⟹ (⨁j ← [0..<n]. (μ [^] i) [^] j) = 𝟬"
  assumes[simp]: "length a = n"
  assumes[simp]: "set a ⊆ carrier R"
  shows "NTT (inv μ) (NTT μ a) = map (λx. nat_embedding n ⊗ x) a"
proof (intro nth_equalityI)
  have "μ ∈ Units R" using assms unfolding primitive_root_def using roots_Units by blast
  then have[simp]: "μ ∈ carrier R" by blast
  show "length (NTT (inv μ) (NTT μ a)) = length (map ((⊗) (nat_embedding n)) a)" using NTT_length
    by simp
  fix i
  assume "i < length (NTT (inv μ) (NTT μ a))"
  then have "i < n" by simp

  have[simp]: "inv μ ∈ carrier R" using assms roots_Units unfolding primitive_root_def by blast
  then have[simp]: "⋀i :: nat. (inv μ) [^] i ∈ carrier R" by simp

  (* Step 1: expand both transforms into a double sum of integer powers of μ *)
  have 0: "NTT (inv μ) (NTT μ a) ! i = (⨁j ← [0..<n]. (NTT μ a ! j) ⊗ ((inv μ) [^] i) [^] j)"
    using NTT_nth
    using assms NTT_length ‹i < n› by blast
  also have "... = (⨁j ← [0..<n]. (⨁k ← [0..<n]. a ! k ⊗ μ [^] ((int k - int i) * int j)))"
  proof (intro monoid_sum_list_cong)
    fix j
    assume "j ∈ set [0..<n]"
    then have[simp]: "j < n" by simp
    have nj: "(NTT μ a ! j) = (⨁k ← [0..<n]. a ! k ⊗ (μ [^] j) [^] k)"
      using NTT_nth by simp
    have "... ⊗ ((inv μ) [^] i) [^] j = (⨁k ← [0..<n]. a ! k ⊗ ((μ [^] j) [^] k) ⊗ ((inv μ) [^] i) [^] j)"
      apply (intro monoid_sum_list_in_right[symmetric] nat_pow_closed m_closed)
      using set_subseteqD[OF assms(5)] by simp_all
    also have "... = (⨁k ← [0..<n]. a ! k ⊗ μ [^] ((int k - int i) * int j))"
    proof (intro monoid_sum_list_cong)
      fix k
      assume "k ∈ set [0..<n]"
      have "a ! k ⊗ (μ [^] j) [^] k ⊗ (inv μ [^] i) [^] j = a ! k ⊗ ((μ [^] j) [^] k ⊗ (inv μ [^] i) [^] j)"
        apply (intro m_assoc nat_pow_closed)
        using set_subseteqD[OF assms(5)] ‹k ∈ set [0..<n]› by simp_all
      also have "inv μ [^] i = μ [^] (- int i)"
        by (metis ‹μ ∈ Units R› cring.units_int_pow_neg int_pow_int is_cring)
      also have "((μ [^] j) [^] k ⊗ (μ [^] (- int i)) [^] j) = μ [^] (int j * int k - int i * int j)"
        using ‹μ ∈ Units R›
        by (simp add: int_pow_int[symmetric] int_pow_pow int_pow_mult)
      also have "... = μ [^] ((int k - int i) * int j)"
        apply (intro arg_cong[where f = "([^]) _"])
        by (simp add: mult.commute right_diff_distrib')
      finally show "a ! k ⊗ (μ [^] j) [^] k ⊗ (inv μ [^] i) [^] j = a ! k ⊗ μ [^] ((int k - int i) * int j)"
        using ‹inv μ [^] i = μ [^] (- int i)› by argo
    qed
    finally show "NTT μ a ! j ⊗ (inv μ [^] i) [^] j = monoid_sum_list (λk. a ! k ⊗ μ [^] ((int k - int i) * int j)) [0..<n]"
      by (simp add: nj)
  qed
  (* Step 2: swap the order of summation and pull a ! k out of the inner sum *)
  also have "... = (⨁k ← [0..<n]. (⨁j ← [0..<n]. a ! k ⊗ μ [^] ((int k - int i) * int j)))"
    apply (intro monoid_sum_list_swap m_closed)
    subgoal for j k
      using assms by (metis atLeastLessThan_iff atLeastLessThan_upt nth_mem subset_eq)
    subgoal for j k
      using ‹μ ∈ Units R›
      using units_of_int_pow[OF ‹μ ∈ Units R›]
      using group.int_pow_closed[OF units_group, of μ]
      by (metis Units_closed units_of_carrier)
    done
  also have "... = (⨁k ← [0..<n]. a ! k ⊗ (⨁j ← [0..<n]. μ [^] ((int k - int i) * int j)))"
    apply (intro monoid_sum_list_cong monoid_sum_list_in_left)
    subgoal using set_subseteqD[OF assms(5)] by simp
    subgoal for j
      by (simp add: Units_closed int_pow_closed ‹μ ∈ Units R›)
    done
  (* Step 3: the inner sum is nat_embedding n for k = i and 𝟬 otherwise (via good) *)
  also have "... = (⨁k ← [0..<n]. a ! k ⊗ (if i = k then nat_embedding n else 𝟬))"
  proof (intro monoid_sum_list_cong arg_cong[where f = "(⊗) _"])
    fix k
    assume "k ∈ set [0..<n]"
    then have[simp]: "k < n" by simp
    consider "i < k" | "i = k" | "i > k" by fastforce
    then show "(⨁j ← [0..<n]. μ [^] ((int k - int i) * int j)) = (if i = k then nat_embedding n else 𝟬)"
    proof (cases)
      case 1
      (* i < k: the summands are powers of μ [^] (k - i), with k - i ∈ {1..<n} *)
      then have "⋀j. j < n ⟹ μ [^] ((int k - int i) * int j) = (μ [^] (k - i)) [^] j"
      proof -
        fix j
        assume "j < n"
        have "(int k - int i) * int j = int ((k - i) * j)" using 1 by auto
        then have "μ [^] ((int k - int i) * int j) = μ [^] int ((k - i) * j)"
          by argo
        also have "... = μ [^] ((k - i) * j)"
          by (intro int_pow_int)
        also have "... = (μ [^] (k - i)) [^] j"
          by (intro nat_pow_pow[symmetric] ‹μ ∈ carrier R›)
        finally show "μ [^] ((int k - int i) * int j) = (μ [^] (k - i)) [^] j" .
      qed
      then have "(⨁j ← [0..<n]. μ [^] ((int k - int i) * int j)) = (⨁j ← [0..<n]. (μ [^] (k - i)) [^] j)"
        by (intro monoid_sum_list_cong, simp)
      also have "... = 𝟬"
        using good[of "k - i"]
      proof
        show "k - i ∈ {1..<n}" using 1 ‹k < n› by (simp add: less_imp_diff_less)
      qed simp
      finally show ?thesis using 1 by simp
    next
      case 2
      (* i = k: every summand is 𝟭, so the sum is nat_embedding n *)
      then have "⋀j. j < n ⟹ μ [^] ((int k - int i) * int j) = 𝟭" by simp
      then have "(⨁j ← [0..<n]. μ [^] ((int k - int i) * int j)) = nat_embedding n"
        using monoid_sum_list_const[of 𝟭 "[0..<n]"]
        using monoid_sum_list_cong[of "[0..<n]" "λj. μ [^] ((int k - int i) * int j)" "λj. 𝟭"]
        by simp
      then show ?thesis using 2 by simp
    next
      case 3
      (* i > k: shift the exponent by n to get powers of μ [^] (n + k - i), with
         n + k - i ∈ {1..<n} *)
      then have "⋀j. j < n ⟹ μ [^] ((int k - int i) * int j) = (μ [^] (n + k - i)) [^] j"
      proof -
        fix j
        assume "j < n"
        have "μ [^] ((int k - int i) * int j) = (μ [^] (int k - int i)) [^] j"
          using int_pow_pow by (metis ‹μ ∈ Units R› int_pow_int)
        also have "... = (μ [^] n ⊗ μ [^] (int k - int i)) [^] j"
        proof -
          have "μ [^] (int k - int i) ∈ carrier R"
            using ‹μ ∈ Units R› int_pow_closed Units_closed by simp
          then have "μ [^] (int k - int i) = μ [^] n ⊗ μ [^] (int k - int i)"
            using l_one assms(2) unfolding primitive_root_def root_of_unity_def
            by presburger
          then show ?thesis by simp
        qed
        also have "... = (μ [^] (int n) ⊗ μ [^] (int k - int i)) [^] j"
          by (simp add: int_pow_int)
        also have "... = (μ [^] (int n + int k - int i)) [^] j"
          using ‹μ ∈ Units R› by (simp add: int_pow_mult add_diff_eq)
        finally show "μ [^] ((int k - int i) * int j) = (μ [^] (n + k - i)) [^] j" using 3
          by (metis (no_types, opaque_lifting) ‹i < n› diff_cancel2 diff_diff_cancel diff_le_self int_plus int_pow_int less_or_eq_imp_le of_nat_diff)
      qed
      then have "(⨁j ← [0..<n]. μ [^] ((int k - int i) * int j)) = (⨁j ← [0..<n]. (μ [^] (n + k - i)) [^] j)"
        by (intro monoid_sum_list_cong, simp)
      also have "... = 𝟬"
        using good[of "n + k - i"]
      proof
        show "n + k - i ∈ {1..<n}" using 3 ‹k < n› ‹i < n› by fastforce
      qed simp
      finally show ?thesis using 3 by simp
    qed
  qed
  (* Step 4: rewrite with delta, factor out nat_embedding n and collapse the delta sum *)
  also have "... = (⨁k ← [0..<n]. a ! k ⊗ (nat_embedding n ⊗ delta k i))"
    apply (intro monoid_sum_list_cong)
    unfolding delta_def
    by simp
  also have "... = (⨁k ← [0..<n]. nat_embedding n ⊗ (delta k i ⊗ a ! k))"
    apply (intro monoid_sum_list_cong)
    using m_assoc m_comm delta_closed set_subseteqD[OF assms(5)] nat_embedding_closed by simp
  also have "... = nat_embedding n ⊗ (⨁k ← [0..<n]. delta k i ⊗ a ! k)"
    using set_subseteqD[OF assms(5)]
    by (intro monoid_sum_list_in_left) auto
  also have "... = nat_embedding n ⊗ a ! i"
    using monoid_sum_list_delta[of n "λk. a ! k" i] ‹i < n› assms
    by (metis (no_types, lifting) nth_mem subsetD)
  finally show "NTT (inv μ) (NTT μ a) ! i = map ((⊗) (nat_embedding n)) a ! i"
    using nth_map ‹i < n› ‹length a = n› NTT_length 0
    by simp
qed

(* Both hypotheses of inversion_rule carry over from μ to inv μ: inv μ is again a
   primitive n-th root of unity, and its power sums vanish because
   (inv μ) [^] i = μ [^] (n - i) with n - i ∈ {1..<n}. *)
lemma inv_good:
  assumes "n > 0"
  assumes "primitive_root n μ"
  assumes good: "⋀i. i ∈ {1..<n} ⟹ (⨁j ← [0..<n]. (μ [^] i) [^] j) = 𝟬"
  shows "primitive_root n (inv μ)"
      "⋀i. i ∈ {1..<n} ⟹ (⨁j ← [0..<n]. ((inv μ) [^] i) [^] j) = 𝟬"
  subgoal using assms by (simp add: primitive_root_inv)
  subgoal for i
  proof -
    assume "i ∈ {1..<n}"
    then have "n - i ∈ {1..<n}" by auto
    then have "(⨁j ← [0..<n]. (μ [^] (n - i)) [^] j) = 𝟬"
      using assms by blast
    moreover have "μ [^] (n - i) = inv μ [^] i"
      using assms inv_pow_root_of_unity[of n μ i] ‹i ∈ {1..<n}›
      by auto
    ultimately show "(⨁j ← [0..<n]. ((inv μ) [^] i) [^] j) = 𝟬" by simp
  qed
  done

(* If μ is a unit with μ [^] i = ⊖ 𝟭, then also (inv μ) [^] i = ⊖ 𝟭. The proof moves to
   the unit group units_of R to commute inversion with exponentiation. *)
lemma inv_halfway_property:
  assumes "μ ∈ Units R"
  assumes "μ [^] (i::nat) = ⊖ 𝟭"
  shows "(inv μ) [^] i = ⊖ 𝟭"
proof -
  have "(inv μ) [^] i = (inv⇘units_of R⇙μ) [^] i"
    by (intro arg_cong[where f = "λj. j [^] i"] units_of_inv[symmetric] assms(1))
  also have "... = (inv⇘units_of R⇙μ) [^]⇘units_of R⇙i"
    apply (intro units_of_pow[symmetric])
    using units_group.Units_inv_Units assms(1) by simp
  also have "... = inv⇘units_of R⇙(μ [^]⇘units_of R⇙i)"
    apply (intro units_group.nat_pow_inv)
    using assms(1) by (simp add: units_of_def)
  also have "... = inv (μ [^]⇘units_of R⇙i)"
    apply (intro units_of_inv)
    using assms(1) units_group.nat_pow_closed by (simp add: units_of_def)
  also have "... = inv (μ [^] i)"
    using units_of_pow assms(1) by simp
  finally have "(inv μ) [^] i = inv (μ [^] i)" .
  also have "... = inv (⊖ 𝟭)" using assms(2) by simp
  also have "... = ⊖ 𝟭" by simp
  finally show ?thesis .
qed

(* Auxiliary lemma for sufficiently_good. Let η be a primitive m-th root of unity with
   m = 2 ^ j and η [^] (m div 2) = ⊖ 𝟭. If r is odd and r * 2 ^ k < m, then the power sum
   of η [^] (r * 2 ^ k) over [0..<m] vanishes. The proof is by induction on k. *)
lemma sufficiently_good_aux:
  assumes "primitive_root m η"
  assumes "m = 2 ^ j"
  assumes "η [^] (m div 2) = ⊖ 𝟭"
  assumes "odd r"
  assumes "r * 2 ^ k < m"
  shows "(⨁l ← [0..<m]. (η [^] (r * 2 ^ k)) [^] l) = 𝟬"
  using assms
proof (induction k arbitrary: η m j)
  case 0
  (* base case: split the sum into two halves of length m div 2; the second half is the
     negation of the first, because (η [^] r) [^] (m div 2) = ⊖ 𝟭 for odd r *)
  then have "root_of_unity m η" by simp
  then have "η ∈ carrier R" by simp
  have "j > 0"
  proof (rule ccontr)
    assume "¬ j > 0"
    then have "j = 0" by simp
    then have "m = 1" using 0 by simp
    then have "r * 2 ^ k = 0" using 0 by simp
    then have "r = 0" by simp
    then show "False" using ‹odd r› by simp
  qed
  then have "even m" using 0 by simp
  then have "m = m div 2 + m div 2" by auto
  then have "(⨁l ← [0..<m]. (η [^] (r * 2 ^ 0)) [^] l) = (⨁l ← [0..<m div 2 + m div 2]. (η [^] r) [^] l)"
    by simp
  also have "... = (⨁l ← [0..<m div 2]. (η [^] r) [^] l) ⊕ (⨁l ← [m div 2..<m div 2 + m div 2]. (η [^] r) [^] l)"
    by (intro monoid_sum_list_split[symmetric] nat_pow_closed, rule ‹η ∈ carrier R›)
  also have "... = (⨁l ← [0..<m div 2]. (η [^] r) [^] l) ⊕ (⨁l ← [0..<m div 2]. (η [^] r) [^] (m div 2 + l))"
    by (intro arg_cong[where f = "(⊕) _"] monoid_sum_list_index_shift_0)
  also have "... = (⨁l ← [0..<m div 2]. (η [^] r) [^] l ⊕ (η [^] r) [^] (m div 2 + l))"
    by (intro monoid_sum_list_add_in nat_pow_closed; rule ‹η ∈ carrier R›)
  also have "... = (⨁l ← [0..<m div 2]. (η [^] r) [^] l ⊖ (η [^] r) [^] l)"
  proof (intro monoid_sum_list_cong)
    fix l
    have "(η [^] r) [^] (m div 2 + l) = (η [^] r) [^] (m div 2) ⊗ (η [^] r) [^] l"
      by (intro nat_pow_mult[symmetric] nat_pow_closed, rule ‹η ∈ carrier R›)
    also have "(η [^] r) [^] (m div 2) = (⊖ 𝟭) [^] r"
      unfolding nat_pow_pow[OF ‹η ∈ carrier R›] mult.commute[of r _]
      by (simp only: nat_pow_pow[symmetric] ‹η ∈ carrier R› ‹η [^] (m div 2) = ⊖ 𝟭›)
    also have "... = ⊖ 𝟭" using ‹odd r›
      by (simp add: powers_of_negative)
    finally have "(η [^] r) [^] (m div 2 + l) = ⊖ ((η [^] r) [^] l)"
      using ‹η ∈ carrier R› nat_pow_closed by algebra
    then show "(η [^] r) [^] l ⊕ (η [^] r) [^] (m div 2 + l) = (η [^] r) [^] l ⊖ (η [^] r) [^] l"
      unfolding minus_eq
      by (intro arg_cong[where f = "(⊕) _"])
  qed
  also have "... = (⨁l ← [0..<m div 2]. 𝟬)"
    by (intro monoid_sum_list_cong) (simp add: ‹η ∈ carrier R›)
  also have "... = 𝟬" by simp
  finally show ?case .
next
  case (Suc k)
  (* step: η [^] (r * 2 ^ Suc k) = (η [^] 2) [^] (r * 2 ^ k), and η [^] 2 is a primitive
     (m div 2)-th root; the induction hypothesis kills both halves of the sum *)
  have "j > 0"
  proof (rule ccontr)
    assume "¬ j > 0"
    then have "j = 0" by simp
    then have "m = 1" using Suc by simp
    then have "r * 2 ^ k = 0" using Suc by simp
    then have "r = 0" by simp
    then show "False" using ‹odd r› by simp
  qed
  then have "even m" using Suc by simp
  then have "m = m div 2 + m div 2" by auto
  have "root_of_unity m η" using ‹primitive_root m η› by simp
  then have "η ∈ carrier R" by simp
  from ‹j > 0› obtain j' where "j = Suc j'"
    using gr0_implies_Suc by blast
  then have "m div 2 = 2 ^ j'" using ‹m = 2 ^ j› by simp
  have "j' > 0"
  proof (rule ccontr)
    assume "¬ j' > 0"
    then have "j' = 0" by simp
    then have "m = 2" using ‹m = 2 ^ j› ‹j = Suc j'› by simp
    then have "r * 2 ^ Suc k < 2" using Suc by simp
    then show "False" using ‹odd r› by simp
  qed
  then have "even (m div 2)" using ‹m div 2 = 2 ^ j'› by simp
  have IH': "(⨁l ← [0..<m div 2]. ((η [^] (2::nat)) [^] (r * 2 ^ k)) [^] l) = 𝟬"
    apply (intro Suc.IH[of "m div 2" "η [^] (2::nat)" j'])
    subgoal using primitive_root_recursion[OF ‹even m›, OF ‹primitive_root m η›] .
    subgoal using ‹m = 2 ^ j› ‹j = Suc j'› by simp
    subgoal
      by (simp add: ‹η ∈ carrier R› nat_pow_pow ‹even (m div 2)› ‹η [^] (m div 2) = ⊖ 𝟭›)
    subgoal using ‹odd r› .
    subgoal using ‹r * 2 ^ (Suc k) < m› ‹even m› by auto
    done
  have "(⨁l ← [0..<m]. (η [^] (r * 2 ^ (Suc k))) [^] l) = (⨁l ← [0..<m]. ((η [^] (2::nat)) [^] (r * 2 ^ k)) [^] l)"
    unfolding nat_pow_pow[OF ‹η ∈ carrier R›]
    apply (intro monoid_sum_list_cong arg_cong[where f = "λi. i [^] _"])
    apply (intro arg_cong[where f = "([^]) _"])
    by simp
  also have "... = (⨁l ← [0..<m div 2 + m div 2]. ((η [^] (2::nat)) [^] (r * 2 ^ k)) [^] l)"
    using ‹m = m div 2 + m div 2› by argo
  also have "... = (⨁l ← [0..<m div 2]. ((η [^] (2::nat)) [^] (r * 2 ^ k)) [^] l) ⊕ (⨁l ← [m div 2..<m div 2 + m div 2]. ((η [^] (2::nat)) [^] (r * 2 ^ k)) [^] l)"
    by (intro monoid_sum_list_split[symmetric] nat_pow_closed, rule ‹η ∈ carrier R›)
  also have "... = 𝟬 ⊕ (⨁l ← [m div 2..<m div 2 + m div 2]. ((η [^] (2::nat)) [^] (r * 2 ^ k)) [^] l)"
    using IH' by argo
  also have "... = (⨁l ← [m div 2..<m div 2 + m div 2]. ((η [^] (2::nat)) [^] (r * 2 ^ k)) [^] l)"
    by (intro l_zero monoid_sum_list_closed nat_pow_closed, rule ‹η ∈ carrier R›)
  also have "... = (⨁l ← [0..<m div 2]. ((η [^] (2::nat)) [^] (r * 2 ^ k)) [^] (m div 2 + l))"
    by (intro monoid_sum_list_index_shift_0)
  also have "... = (⨁l ← [0..<m div 2]. ((η [^] (2::nat)) [^] (r * 2 ^ k)) [^] (m div 2) ⊗ ((η [^] (2::nat)) [^] (r * 2 ^ k)) [^] l)"
    by (intro monoid_sum_list_cong nat_pow_mult[symmetric] nat_pow_closed, rule ‹η ∈ carrier R›)
  also have "... = ((η [^] (2::nat)) [^] (r * 2 ^ k)) [^] (m div 2) ⊗ (⨁l ← [0..<m div 2]. ((η [^] (2::nat)) [^] (r * 2 ^ k)) [^] l)"
    by (intro monoid_sum_list_in_left nat_pow_closed; rule ‹η ∈ carrier R›)
  also have "... = ((η [^] (2::nat)) [^] (r * 2 ^ k)) [^] (m div 2) ⊗ 𝟬"
    using IH' by argo
  also have "... = 𝟬"
    by (intro r_null nat_pow_closed, rule ‹η ∈ carrier R›)
  finally show ?case .
qed


(* Sufficient conditions for hypothesis good of inversion_rule: either R is an integral
   domain, or n = 2 ^ k and μ [^] (n div 2) = ⊖ 𝟭. *)
lemma sufficiently_good:
  assumes "primitive_root n μ"
  assumes "domain R ∨ (n = 2 ^ k ∧ μ [^] (n div 2) = ⊖ 𝟭)"
  shows good: "⋀i. i ∈ {1..<n} ⟹ (⨁j ← [0..<n]. (μ [^] i) [^] j) = 𝟬"
proof (cases "domain R")
  case True
  (* geometric series: (𝟭 ⊖ μ [^] i) ⊗ (power sum) = 𝟭 ⊖ (μ [^] i) [^] n = 𝟬; the first
     factor is nonzero because μ is primitive, so the power sum vanishes in a domain *)
  fix i
  assume "i ∈ {1..<n}"

  have "root_of_unity n μ" using assms(1) by simp
  then have "μ ∈ carrier R" "μ [^] n = 𝟭" by simp_all

  have "μ [^] i ≠ 𝟭" using assms(1) ‹i ∈ {1..<n}› unfolding primitive_root_def
    by simp
  then have "𝟭 ⊖ μ [^] i ≠ 𝟬" using ‹μ ∈ carrier R› by simp

  have "(μ [^] i) [^] n = 𝟭"
    unfolding nat_pow_pow[OF ‹μ ∈ carrier R›]
    using root_of_unity_powers[OF _ ‹root_of_unity n μ›, of "i * n"]
    by (cases "n > 0"; simp)
  then have "𝟬 = 𝟭 ⊖ (μ [^] i) [^] n" by algebra
  also have "... = (𝟭 ⊖ μ [^] i) ⊗ (⨁j ← [0..<n]. (μ [^] i) [^] j)"
    by (intro geo_monoid_list_sum[symmetric] nat_pow_closed ‹μ ∈ carrier R›)
  finally show "(⨁j ← [0..<n]. (μ [^] i) [^] j) = 𝟬"
    using ‹𝟭 ⊖ μ [^] i ≠ 𝟬› True ‹μ ∈ carrier R›
    by (metis domain.integral minus_closed monoid_sum_list_closed nat_pow_closed one_closed)
next
  case False
  (* write i = r * 2 ^ l with r odd and apply sufficiently_good_aux *)
  then have "n = 2 ^ k" "μ [^] (n div 2) = ⊖ 𝟭" using assms(2) by auto

  have "root_of_unity n μ" using ‹primitive_root n μ› by simp
  then have "μ ∈ carrier R" "μ [^] n = 𝟭" by simp_all
  
  fix i
  assume "i ∈ {1..<n}"
  define l where "l = Max {s. 2 ^ s dvd i}"
  define r where "r = i div 2 ^ l"
  from ‹i ∈ {1..<n}› have "i ≠ 0" by simp
  then have "i = r * 2 ^ l" "odd r" using max_dividing_power_factorization[of i l 2 r]
    using l_def r_def coprime_left_2_iff_odd[of r] by simp_all

  show "(⨁j ← [0..<n]. (μ [^] i) [^] j) = 𝟬"
    apply (simp only: ‹i = r * 2 ^ l›)
    apply (intro sufficiently_good_aux[of n μ k r l, OF ‹primitive_root n μ› ‹n = 2 ^ k› ‹μ [^] (n div 2) = ⊖ 𝟭› ‹odd r›])
    using ‹i = r * 2 ^ l› ‹i ∈ {1..<n}› by simp
qed

(* Inversion in the opposite order: transforming with inv μ and then with μ also scales
   every entry of a by nat_embedding n. Obtained from inversion_rule for inv μ, whose
   hypotheses are supplied by inv_good. *)
corollary inversion_rule_inv:
  fixes μ :: 'a
  fixes n :: nat
  assumes "n > 0"
  assumes "primitive_root n μ"
  assumes good: "⋀i. i ∈ {1..<n} ⟹ (⨁j ← [0..<n]. (μ [^] i) [^] j) = 𝟬"
  assumes[simp]: "length a = n"
  assumes[simp]: "set a ⊆ carrier R"
  shows "NTT μ (NTT (inv μ) a) = map (λx. nat_embedding n ⊗ x) a"
  using assms inv_good[of n μ] inversion_rule[of n "inv μ" a]
  using Units_inv_inv[of μ]
  using roots_Units[of n μ]
  unfolding primitive_root_def
  by algebra

subsubsection "Convolution Theorem"

(* For an n-th root of unity x, the product of the power sums of f and g in x equals the
   power sum in x of their cyclic convolution
   k ↦ (⨁i ← [0..<n]. f i ⊗ g ((n + k - i) mod n)). *)
lemma root_of_unity_power_sum_product:
  assumes "root_of_unity n x"
  assumes[simp]: "⋀i. i < n ⟹ f i ∈ carrier R"
  assumes[simp]: "⋀i. i < n ⟹ g i ∈ carrier R"
  shows "(⨁i ← [0..<n]. f i ⊗ x [^] i) ⊗ (⨁i ← [0..<n]. g i ⊗ x [^] i) =
    (⨁k ← [0..<n]. (⨁i ← [0..<n]. f i ⊗ g ((n + k - i) mod n)) ⊗ x [^] k)"
proof (cases "n > 0")
  case False
  then have "n = 0" by simp
  then show ?thesis by simp
next
  case True
  have[simp]: "x ∈ carrier R" using ‹root_of_unity n x› by simp

  (* distribute x [^] k into the inner sum and split its exponent, taken modulo n *)
  have "(⨁k ← [0..<n]. (⨁i ← [0..<n]. f i ⊗ g ((n + k - i) mod n)) ⊗ x [^] k) =
      (⨁k ← [0..<n]. (⨁i ← [0..<n]. f i ⊗ g ((n + k - i) mod n) ⊗ x [^] k))"
    by (intro monoid_sum_list_cong monoid_sum_list_in_right[symmetric] nat_pow_closed m_closed)
        simp_all
  also have "... = (⨁k ← [0..<n]. (⨁i ← [0..<n]. f i ⊗ g ((n + k - i) mod n) ⊗ x [^] ((n + k - i) mod n + i)))"
    apply (intro monoid_sum_list_cong arg_cong[where f = "(⊗) _"])
    apply (intro root_of_unity_powers_nat[OF ‹n > 0› ‹root_of_unity n x›])
    by (simp add: add.commute mod_add_right_eq)
  also have "... = (⨁k ← [0..<n]. (⨁i ← [0..<n]. f i ⊗ g ((n + k - i) mod n) ⊗ (x [^] ((n + k - i) mod n) ⊗ x [^] i)))"
    by (intro monoid_sum_list_cong arg_cong[where f = "(⊗) _"] nat_pow_mult[symmetric]) simp
  also have "... = (⨁k ← [0..<n]. (⨁i ← [0..<n]. f i ⊗ x [^] i ⊗ (g ((n + k - i) mod n) ⊗ x [^] ((n + k - i) mod n))))"
  proof -
    have reorder: "⋀a b c d. ⟦ a ∈ carrier R; b ∈ carrier R; c ∈ carrier R; d ∈ carrier R ⟧ ⟹
      a ⊗ b ⊗ (c ⊗ d) = a ⊗ d ⊗ (b ⊗ c)"
      using m_comm m_assoc by algebra
    show ?thesis
      by (intro monoid_sum_list_cong reorder nat_pow_closed) simp_all
  qed
  (* swap the sums, factor out f i ⊗ x [^] i, and reindex the inner sum via the
     bijection k ↦ (n + k - i) mod n on [0..<n] *)
  also have "... = (⨁i ← [0..<n]. (⨁k ← [0..<n]. f i ⊗ x [^] i ⊗ (g ((n + k - i) mod n) ⊗ x [^] ((n + k - i) mod n))))"
    by (intro monoid_sum_list_swap m_closed nat_pow_closed) simp_all
  also have "... = (⨁i ← [0..<n]. f i ⊗ x [^] i ⊗ (⨁k ← [0..<n]. (g ((n + k - i) mod n) ⊗ x [^] ((n + k - i) mod n))))"
    by (intro monoid_sum_list_cong monoid_sum_list_in_left m_closed nat_pow_closed) simp_all
  also have "... = (⨁i ← [0..<n]. f i ⊗ x [^] i ⊗ (⨁j ← [0..<n]. (g j ⊗ x [^] j)))"
    (is "(⨁i ← _. _ ⊗ ?lhs i) = (⨁i ← _. _ ⊗ ?rhs i)")
  proof -
    have "⋀i. i ∈ set [0..<n] ⟹ ?lhs i = ?rhs i"
    proof (intro monoid_sum_list_index_permutation[symmetric] m_closed nat_pow_closed)
      fix i
      assume "i ∈ set [0..<n]"
      have "bij_betw (λia. (n - i + ia) mod n) {0..<n} {0..<n}"
        by (intro const_add_mod_bij)
      also have "bij_betw (λia. (n - i + ia) mod n) {0..<n} {0..<n} =
        bij_betw (λia. (n + ia - i) mod n) {0..<n} {0..<n}"
        apply (intro bij_betw_cong)
        using ‹i ∈ set [0..<n]› by simp
      finally show "bij_betw (λia. (n + ia - i) mod n) (set [0..<n]) (set [0..<n])" by simp
    qed simp_all
    then show ?thesis
      by (intro monoid_sum_list_cong) (intro arg_cong[where f = "(⊗) _"])
  qed
  also have "... = (⨁i ← [0..<n]. f i ⊗ x [^] i) ⊗ (⨁j ← [0..<n]. (g j ⊗ x [^] j))"
    by (intro monoid_sum_list_in_right monoid_sum_list_closed) simp_all
  finally show ?thesis by argo
qed

context
  fixes n :: nat
begin

(* Cyclic convolution of length n: entry i (for i < n) is
   ⨁σ ← [0..<n]. a ! σ ⊗ b ! ((n + i - σ) mod n).
   NOTE(review): the infix symbol was lost in a garbled copy of this theory and is
   restored here as ⋆ — confirm against theories that use this notation. *)
definition cyclic_convolution :: "'a list ⇒ 'a list ⇒ 'a list" (infixl "⋆" 70) where
  "cyclic_convolution a b ≡ [(⨁σ ← [0..<n]. (a ! σ ⊗ b ! ((n + i - σ) mod n))). i ← [0..<n]]"

(* The cyclic convolution always has length n, independently of its arguments. *)
lemma cyclic_convolution_length[simp]:
  "length (a ⋆ b) = n" unfolding cyclic_convolution_def by simp

(* Entry i < n of the cyclic convolution, unfolded as a sum. *)
lemma cyclic_convolution_nth:
"i < n ⟹ (a ⋆ b) ! i = (⨁σ ← [0..<n]. (a ! σ ⊗ b ! ((n + i - σ) mod n)))"
  unfolding cyclic_convolution_def by simp

(* The cyclic convolution of two length-n carrier lists is a carrier list. *)
lemma cyclic_convolution_closed:
  assumes "length a = n" "length b = n"
  assumes "set a ⊆ carrier R" "set b ⊆ carrier R"
  shows "set (a ⋆ b) ⊆ carrier R"
proof (intro set_subseteqI)
  fix i
  assume "i < length (a ⋆ b)"
  then have "i < n" using assms(1) assms(2) by simp
  then have "(a ⋆ b) ! i = (⨁σ ← [0..<n]. (a ! σ ⊗ b ! ((n + i - σ) mod n)))"
    using cyclic_convolution_nth by presburger
  also have "... ∈ carrier R"
    apply (intro monoid_sum_list_closed m_closed)
    subgoal for σ using set_subseteqD[OF assms(3)] ‹length a = n› by simp
    subgoal for σ using set_subseteqD[OF assms(4)] ‹length b = n› by simp
    done
  finally show "(a ⋆ b) ! i ∈ carrier R" .
qed

(* Convolution theorem: for an n-th root of unity μ, the NTT turns cyclic convolution into
   pointwise multiplication — entry i of NTT μ (a ⋆ b) is the product of the i-th entries
   of NTT μ a and NTT μ b. *)
theorem convolution_rule:
  assumes "length a = n"
  assumes "length b = n"
  assumes "set a ⊆ carrier R"
  assumes "set b ⊆ carrier R"
  assumes "root_of_unity n μ"
  assumes "i < n"
  shows "NTT μ a ! i ⊗ NTT μ b ! i = NTT μ (a ⋆ b) ! i"
proof (cases "n > 0")
  case False
  then show ?thesis using ‹i < n› by simp
next
  case True

  then interpret root_group : group "roots_of_unity_group n"
    by (rule roots_of_unity_group_is_group)

  interpret root_subgroup : multiplicative_subgroup R "{μ. root_of_unity n μ}" "roots_of_unity_group n"
    apply unfold_locales
    subgoal using roots_Units ‹n > 0› by blast
    subgoal unfolding roots_of_unity_group_def[OF ‹n > 0›] by simp
    done

  have "μ ∈ carrier R" using assms(5) by simp
  (* both transforms are power sums in the n-th root of unity μ [^] i; multiply them with
     root_of_unity_power_sum_product and recognise the cyclic convolution *)
  have "NTT μ a ! i ⊗ NTT μ b ! i =
    (⨁j ← [0..<n]. a ! j ⊗ (μ [^] i) [^] j) ⊗ (⨁j ← [0..<n]. b ! j ⊗ (μ [^] i) [^] j)"
    unfolding NTT_nth[OF assms(1) ‹i < n›] NTT_nth[OF assms(2) ‹i < n›] by argo
  also have "... = (⨁j ← [0..<n]. (⨁k ← [0..<n]. (a ! k) ⊗ (b ! ((n + j - k) mod n))) ⊗ (μ [^] i) [^] j)"
    apply (intro root_of_unity_power_sum_product root_of_unity_nat_pow_closed)
    using True ‹root_of_unity n μ› set_subseteqD[OF assms(3)] set_subseteqD[OF assms(4)] assms(1) assms(2)
    by simp_all
  also have "... = (⨁j ← [0..<n]. (a ⋆ b) ! j ⊗ (μ [^] i) [^] j)"
    apply (intro monoid_sum_list_cong arg_cong[where f = "λj. j ⊗ _"] cyclic_convolution_nth[symmetric])
    by simp
  also have "... = NTT μ (a ⋆ b) ! i"
    apply (intro NTT_nth[symmetric]) using ‹i < n› by simp_all
  finally show ?thesis .
qed

end

end

end