(* Theory FNTT_Rings *)

subsection "Fast Number Theoretic Transforms in Rings"

theory FNTT_Rings
  imports NTT_Rings "Number_Theoretic_Transform.Butterfly"
begin

context cring begin

text "The following lemma is the essence of Fast Number Theoretic Transforms (FNTTs)."

(* Splitting step behind the fast transform: for even n and a primitive root μ (as given by
   primitive_root n μ), entry j of NTT μ a is written in terms of entry j' of the transforms
   (with root μ [^] 2) of the even-indexed and the odd-indexed elements of a, where j' is j
   reduced below n div 2.
   NOTE(review): the operator symbols, cartouches and the list-sum binder (⨁ _ ← _. _) in
   this block were restored after being stripped from the source text; confirm the binder
   against the notation declared for monoid_sum_list. *)
lemma NTT_recursion:
  assumes "even n"
  assumes "primitive_root n μ"
  assumes[simp]: "length a = n"
  assumes[simp]: "j < n"
  assumes[simp]: "set a ⊆ carrier R"
  defines "j' ≡ (if j < n div 2 then j else j - n div 2)"
  shows "j' < n div 2" "j = (if j < n div 2 then j' else j' + n div 2)"
  and "(NTT μ a) ! j = (NTT (μ [^] (2::nat)) [a ! i. i ← filter even [0..<n]]) ! j'
    ⊕ μ [^] j ⊗ (NTT (μ [^] (2::nat)) [a ! i. i ← filter odd [0..<n]]) ! j'"
proof -
  from assms have "n > 0" by linarith
  have[simp]: "μ ∈ carrier R" using ‹primitive_root n μ› unfolding primitive_root_def root_of_unity_def by blast
  then have μ_pow_carrier[simp]: "μ [^] i ∈ carrier R" for i :: nat by simp
  show "j' < n div 2" unfolding j'_def using ‹j < n› ‹even n› by fastforce
  show j'_alt: "j = (if j < n div 2 then j' else j' + n div 2)"
    unfolding j'_def by simp

  (* Closure facts for the even- and odd-indexed entries of a. *)
  have a_even_carrier[simp]: "a ! (2 * i) ∈ carrier R" if "i < n div 2" for i
    using set_subseteqD[OF ‹set a ⊆ carrier R›] assms that by simp
  have a_odd_carrier[simp]: "a ! (2 * i + 1) ∈ carrier R" if "i < n div 2" for i
    using set_subseteqD[OF ‹set a ⊆ carrier R›] assms that by simp

  (* Powers of μ at even exponents only depend on j': the extra summand n in the exponent
     (case j ≥ n div 2) is removed via root_of_unity_powers_nat. *)
  have μ_pow: "μ [^] (j * (2 * i)) = (μ [^] (2::nat)) [^] (j' * i)" for i
  proof -
    have "μ [^] (j * (2 * i)) = (μ [^] (j * 2)) [^] i"
      using mult.assoc nat_pow_pow[symmetric] by simp
    also have "μ [^] (j * 2) = μ [^] (j' * 2)"
    proof (cases "j < n div 2")
      case True
      then show ?thesis unfolding j'_def by simp
    next
      case False
      then have "μ [^] (j * 2) = μ [^] (j' * 2 + n)"
        using j'_alt by (simp add: ‹even n›)
      also have "... = μ [^] (j' * 2)"
        using ‹n > 0› ‹primitive_root n μ›
        by (intro root_of_unity_powers_nat[of n]) auto
      finally show ?thesis .
    qed
    finally show ?thesis unfolding nat_pow_pow[OF ‹μ ∈ carrier R›]
      by (simp add: mult.assoc mult.commute)
  qed

  (* Main calculation: split the defining sum into even and odd indices, pull the factor
     μ [^] j out of the odd part, then recognise both parts as half-length transforms. *)
  have "(NTT μ a) ! j = (⨁i ← [0..<n]. a ! i ⊗ (μ [^] (j * i)))"
    using NTT_nth_2[of a n j μ] by simp
  also have "... = (⨁i ← [0..<n div 2]. a ! (2 * i) ⊗ (μ [^] (j * (2 * i))))
      ⊕ (⨁i ← [0..<n div 2]. a ! (2 * i + 1) ⊗ (μ [^] (j * (2 * i + 1))))"
    using ‹even n›
    by (intro monoid_sum_list_even_odd_split m_closed nat_pow_closed set_subseteqD) simp_all
  also have "(⨁i ← [0..<n div 2]. a ! (2 * i + 1) ⊗ (μ [^] (j * (2 * i + 1))))
           = (⨁i ← [0..<n div 2]. μ [^] j ⊗ (a ! (2 * i + 1) ⊗ (μ [^] (j * (2 * i)))))"
  proof (intro monoid_sum_list_cong)
    fix i
    assume "i ∈ set [0..<n div 2]"
    then have[simp]: "i < n div 2" by simp
    have "a ! (2 * i + 1) ⊗ (μ [^] (j * (2 * i + 1))) =
          a ! (2 * i + 1) ⊗ (μ [^] (j * (2 * i)) ⊗ μ [^] j)"
      unfolding distrib_left mult_1_right
      unfolding nat_pow_mult[symmetric, OF ‹μ ∈ carrier R›]
      by (rule refl)
    also have "... = (a ! (2 * i + 1) ⊗ μ [^] (j * (2 * i))) ⊗ μ [^] j"
      using a_odd_carrier[OF ‹i < n div 2›]
      by (intro m_assoc[symmetric]; simp)
    also have "... = μ [^] j ⊗ (a ! (2 * i + 1) ⊗ μ [^] (j * (2 * i)))"
      using a_odd_carrier[OF ‹i < n div 2›]
      by (intro m_comm; simp)
    finally show "a ! (2 * i + 1) ⊗ μ [^] (j * (2 * i + 1)) = ..." .
  qed
  also have "... = μ [^] j ⊗ (⨁i ← [0..<n div 2]. a ! (2 * i + 1) ⊗ (μ [^] (j * (2 * i))))"
    using a_odd_carrier by (intro monoid_sum_list_in_left; simp)
  finally have "(NTT μ a) ! j = (⨁i ← [0..<n div 2]. a ! (2 * i) ⊗ (μ [^] (2::nat)) [^] (j' * i))
      ⊕ μ [^] j ⊗ (⨁i ← [0..<n div 2]. a ! (2 * i + 1) ⊗ (μ [^] (2::nat)) [^] (j' * i))"
    unfolding μ_pow .
  also have "... = (⨁i ← [0..<n div 2]. [a ! i. i ← filter even [0..<n]] ! i ⊗ (μ [^] (2::nat)) [^] (j' * i))
      ⊕ μ [^] j ⊗ (⨁i ← [0..<n div 2]. [a ! i. i ← filter odd [0..<n]] ! i ⊗ (μ [^] (2::nat)) [^] (j' * i))"
    by (intro_cong "[cong_tag_2 (⊕), cong_tag_2 (⊗)]" more: monoid_sum_list_cong)
       (simp_all add: filter_even_nth length_filter_even length_filter_odd filter_odd_nth)
  also have "... = (NTT (μ [^] (2::nat)) [a ! i. i ← filter even [0..<n]]) ! j'
      ⊕ μ [^] j ⊗ (NTT (μ [^] (2::nat)) [a ! i. i ← filter odd [0..<n]]) ! j'"
    by (intro_cong "[cong_tag_2 (⊕), cong_tag_2 (⊗)]" more: NTT_nth_2[symmetric])
       (simp_all add: length_filter_even length_filter_odd ‹even n› ‹j' < n div 2›)
  finally show "(NTT μ a) ! j = ..." .
qed

(* Specialisation of NTT_recursion to the lower half of the indices (j < n div 2):
   there the reduced index coincides with j itself.
   The operator symbols and cartouches were restored after being stripped from the source text. *)
lemma NTT_recursion_1:
  assumes "even n"
  assumes "primitive_root n μ"
  assumes[simp]: "length a = n"
  assumes[simp]: "j < n div 2"
  assumes[simp]: "set a ⊆ carrier R"
  shows "(NTT μ a) ! j =
        (NTT (μ [^] (2::nat)) [a ! i. i ← filter even [0..<n]]) ! j
      ⊕ μ [^] j ⊗ (NTT (μ [^] (2::nat)) [a ! i. i ← filter odd [0..<n]]) ! j"
proof -
  have "j < n" using ‹j < n div 2› by linarith
  show ?thesis
    using NTT_recursion[OF ‹even n› ‹primitive_root n μ› ‹length a = n› ‹j < n› ‹set a ⊆ carrier R›]
    using ‹j < n div 2› by presburger
qed

(* Specialisation of NTT_recursion to the upper half of the indices (n div 2 + j with
   j < n div 2). Under the extra assumption μ [^] (n div 2) = ⊖ 𝟭 the twiddle factor
   μ [^] (n div 2 + j) becomes ⊖ (μ [^] j), which turns the sum into a difference.
   The operator symbols and cartouches were restored after being stripped from the source text. *)
lemma NTT_recursion_2:
  assumes "even n"
  assumes "primitive_root n μ"
  assumes[simp]: "length a = n"
  assumes[simp]: "j < n div 2"
  assumes[simp]: "set a ⊆ carrier R"
  assumes halfway_property: "μ [^] (n div 2) = ⊖ 𝟭"
  shows "(NTT μ a) ! (n div 2 + j) =
        (NTT (μ [^] (2::nat)) [a ! i. i ← filter even [0..<n]]) ! j
      ⊖ μ [^] j ⊗ (NTT (μ [^] (2::nat)) [a ! i. i ← filter odd [0..<n]]) ! j"
proof -
  from assms have "μ ∈ carrier R" unfolding primitive_root_def root_of_unity_def by simp
  then have carrier_1: "μ [^] j ∈ carrier R"
    by simp
  have carrier_2: "NTT (μ [^] (2::nat)) (map ((!) a) (filter odd [0..<n])) ! j ∈ carrier R"
    apply (intro NTT_nth_closed[where n = "n div 2"])
    subgoal using ‹set a ⊆ carrier R› ‹length a = n› by fastforce
    subgoal using ‹μ ∈ carrier R› by simp
    subgoal by (simp add: length_filter_odd)
    subgoal using ‹j < n div 2› .
    done
  have "n div 2 + j < n" using ‹j < n div 2› ‹even n› by linarith
  then have "(NTT μ a) ! (n div 2 + j) =
        (NTT (μ [^] (2::nat)) [a ! i. i ← filter even [0..<n]]) ! j
      ⊕ μ [^] (n div 2 + j) ⊗ (NTT (μ [^] (2::nat)) [a ! i. i ← filter odd [0..<n]]) ! j"
    using NTT_recursion[OF ‹even n› ‹primitive_root n μ› ‹length a = n› ‹n div 2 + j < n› ‹set a ⊆ carrier R›]
    by simp
  (* Rewrite the twiddle factor using the halfway property. *)
  also have "μ [^] (n div 2 + j) = ⊖ (μ [^] j)"
    unfolding nat_pow_mult[symmetric, OF ‹μ ∈ carrier R›] halfway_property
    by (intro minus_eq_mult_one[symmetric]; simp add: ‹μ ∈ carrier R›)
  finally show ?thesis unfolding minus_eq l_minus[OF carrier_1 carrier_2] .
qed

(* Difference of the two transform entries that share the same half-length sub-results:
   entry j minus entry n div 2 + j equals nat_embedding 2 times the twiddled entry of the
   odd-indexed half transform. Follows from NTT_recursion_1 and NTT_recursion_2.
   The operator symbols and cartouches were restored after being stripped from the source text. *)
lemma NTT_diffs:
  assumes "even n"
  assumes "primitive_root n μ"
  assumes "length a = n"
  assumes "j < n div 2"
  assumes "set a ⊆ carrier R"
  assumes "μ [^] (n div 2) = ⊖ 𝟭"
  shows "NTT μ a ! j ⊖ NTT μ a ! (n div 2 + j) = nat_embedding 2 ⊗ (μ [^] j ⊗ NTT (μ [^] (2::nat)) (map ((!) a) (filter odd [0..<n])) ! j)"
proof -
  have[simp]: "μ ∈ carrier R" using ‹primitive_root n μ› unfolding primitive_root_def root_of_unity_def by blast
  (* ntt1 / ntt2 abbreviate entry j of the even / odd half transforms. *)
  define ntt1 where "ntt1 = NTT (μ [^] (2::nat)) (map ((!) a) (filter even [0..<n])) ! j"
  have "ntt1 ∈ carrier R" unfolding ntt1_def
    apply (intro set_subseteqD[OF NTT_closed] set_subseteqI)
    subgoal for i
      using set_subseteqD[OF ‹set a ⊆ carrier R›]
      by (simp add: filter_even_nth ‹length a = n› ‹even n› length_filter_even)
    subgoal by simp
    subgoal using assms by (simp add: length_filter_even ‹even n›)
    done
  define ntt2 where "ntt2 = NTT (μ [^] (2::nat)) (map ((!) a) (filter odd [0..<n])) ! j"
  have "ntt2 ∈ carrier R" unfolding ntt2_def
    apply (intro set_subseteqD[OF NTT_closed] set_subseteqI)
    subgoal for i
      using set_subseteqD[OF ‹set a ⊆ carrier R›]
      by (simp add: filter_odd_nth ‹length a = n› ‹even n› length_filter_odd)
    subgoal by simp
    subgoal using assms by (simp add: length_filter_odd ‹even n›)
    done
  have "NTT μ a ! j ⊖ NTT μ a ! (n div 2 + j) =
    (ntt1 ⊕ μ [^] j ⊗ ntt2) ⊖ (ntt1 ⊖ μ [^] j ⊗ ntt2)"
    apply (intro arg_cong2[where f = "λi j. i ⊖ j"])
    unfolding ntt1_def ntt2_def
    subgoal by (intro NTT_recursion_1 assms)
    subgoal by (intro NTT_recursion_2 assms)
    done
  also have "... = μ [^] j ⊗ (ntt2 ⊕ ntt2)"
    using ‹ntt1 ∈ carrier R› ‹ntt2 ∈ carrier R› nat_pow_closed[OF ‹μ ∈ carrier R›]
    by algebra
  also have "... = μ [^] j ⊗ ((𝟭 ⊕ 𝟭) ⊗ ntt2)"
    using ‹ntt2 ∈ carrier R› one_closed by algebra
  also have "... = μ [^] j ⊗ (nat_embedding 2 ⊗ ntt2)"
    by (simp add: numeral_2_eq_2)
  also have "... = nat_embedding 2 ⊗ (μ [^] j ⊗ ntt2)"
    using nat_pow_closed[OF ‹μ ∈ carrier R›] ‹ntt2 ∈ carrier R› nat_embedding_closed
    by algebra
  finally show ?thesis unfolding ntt2_def .
qed

text "The following algorithm is adapted from @{theory Number_Theoretic_Transform.Butterfly}"
(* Length bound on filtered index ranges. It is a simp rule only while FNTT is being
   defined (presumably for the automatic termination proof of `fun`) and is removed from
   the simp set again right after the definition. *)
lemma FNTT_term_aux[simp]: "length (filter P [0..<l]) < Suc l"
  by (metis diff_zero le_imp_less_Suc length_filter_le length_upt)
(* Recursive fast transform. Lists of length 0, 1 and 2 are handled directly; longer lists
   are split into their even-indexed and odd-indexed elements, both parts are transformed
   recursively with root μ [^] 2, and the results b, c are combined entrywise into
   b!i ⊕ μ[^]i ⊗ c!i (first half) and b!i ⊖ μ[^]i ⊗ c!i (second half).
   The operator symbols were restored after being stripped from the source text. *)
fun FNTT :: "'a ⇒ 'a list ⇒ 'a list" where
"FNTT μ [] = []"
| "FNTT μ [x] = [x]"
| "FNTT μ [x, y] = [x ⊕ y, x ⊖ y]"
| "FNTT μ a = (let n = length a;
                  nums1 = [a!i.  i ← filter even [0..<n]];
                  nums2 = [a!i.  i ← filter odd [0..<n]];
                  b = FNTT (μ [^] (2::nat)) nums1;
                  c = FNTT (μ [^] (2::nat)) nums2;
                  g = [b!i ⊕ (μ [^] i) ⊗ c!i. i ← [0..<(n div 2)]];
                  h = [b!i ⊖ (μ [^] i) ⊗ c!i. i ← [0..<(n div 2)]]
               in g@h)"
lemmas [simp del] = FNTT_term_aux

(* The defining equations are unfolded explicitly where needed. *)
declare FNTT.simps[simp del]

(* FNTT preserves the length of its input when that length is a power of two.
   Proved by induction along the recursion of FNTT; in the recursive case the output is
   g @ h with both parts of length n div 2, and n is even.
   The operator symbols and cartouches were restored after being stripped from the source text. *)
lemma length_FNTT[simp]:
  assumes "length a = 2 ^ k"
  shows "length (FNTT μ a) = length a"
  using assms
proof (induction rule: FNTT.induct)
  case (1 μ)
  then show ?case by simp
next
  case (2 μ x)
  then show ?case by (simp add: FNTT.simps)
next
  case (3 μ x y)
  then show ?case by (simp add: FNTT.simps)
next
  case (4 μ a1 a2 a3 as)
  define a where "a = a1 # a2 # a3 # as"
  define n where "n = length a"
  with a_def have "even n" using 4(3)
    by (cases "k = 0") simp_all
  (* Local names mirroring the let-bindings in the definition of FNTT. *)
  define nums1 where "nums1 = [a!i.  i ← filter even [0..<n]]"
  define nums2 where "nums2 = [a!i.  i ← filter odd [0..<n]]"
  define b where "b = FNTT (μ [^] (2::nat)) nums1"
  define c where "c = FNTT (μ [^] (2::nat)) nums2"
  define g where "g = [b!i ⊕ (μ [^] i) ⊗ c!i. i ← [0..<(n div 2)]]"
  define h where "h = [b!i ⊖ (μ [^] i) ⊗ c!i. i ← [0..<(n div 2)]]"

  note defs = a_def n_def nums1_def nums2_def b_def c_def g_def h_def

  have "length (FNTT μ a) = length g + length h"
    using defs by (simp add: Let_def FNTT.simps)
  also have "... = (n div 2) + (n div 2)" unfolding g_def h_def by simp
  also have "... = n" using ‹even n› by fastforce
  finally show ?case by (simp only: a_def n_def)
qed

(* Correctness of the fast transform: for input length n = 2 ^ k, a primitive root μ
   (primitive_root n μ) satisfying μ [^] (n div 2) = ⊖ 𝟭, and input entries in the carrier,
   FNTT computes the same list as NTT.
   Proof by induction along the recursion of FNTT; the recursive case applies the induction
   hypotheses to both halves and compares entries using NTT_recursion_1 (lower half) and
   NTT_recursion_2 (upper half).
   The operator symbols and cartouches were restored after being stripped from the source text. *)
theorem FNTT_NTT:
  assumes[simp]: "μ ∈ carrier R"
  assumes "n = 2 ^ k"
  assumes "primitive_root n μ"
  assumes halfway_property: "μ [^] (n div 2) = ⊖ 𝟭"
  assumes[simp]: "length a = n"
  assumes "set a ⊆ carrier R"
  shows "FNTT μ a = NTT μ a"
  using assms
proof (induction μ a arbitrary: n k rule: FNTT.induct)
  case (1 μ)
  then show ?case unfolding NTT_def by simp
next
  case (2 μ x)
  then have "n = 1" by simp
  then have "k = 0" using ‹n = 2 ^ k› by simp
  moreover have "x ∈ carrier R" using 2 by simp
  ultimately show ?case unfolding NTT_def by (simp add: Let_def FNTT.simps)
next
  case (3 μ x y)
  then have[simp]: "x ∈ carrier R" "y ∈ carrier R" by simp_all
  from 3 have "n = 2" by simp
  (* For n = 2 the halfway property pins down μ itself. *)
  with ‹μ [^] (n div 2) = ⊖ 𝟭› have "μ [^] (1 :: nat) = ⊖ 𝟭" by simp
  then have "μ = ⊖ 𝟭" by (simp add: ‹μ ∈ carrier R›)
  have "NTT μ [x, y] = [x ⊕ y, x ⊖ y]"
    unfolding NTT_def
    apply (simp add: Let_def 3 ‹μ = ⊖ 𝟭›)
    using ‹x ∈ carrier R› ‹y ∈ carrier R› by algebra
  then show ?case by (simp add: FNTT.simps)
next
  case (4 μ a1 a2 a3 as)
  define a where "a = a1 # a2 # a3 # as"
  then have[simp]: "length a = n" using 4(7) by simp
  (* Local names mirroring the let-bindings in the definition of FNTT. *)
  define nums1 where "nums1 = [a!i.  i ← filter even [0..<n]]"
  define nums2 where "nums2 = [a!i.  i ← filter odd [0..<n]]"
  define b where "b = FNTT (μ [^] (2::nat)) nums1"
  define c where "c = FNTT (μ [^] (2::nat)) nums2"
  define g where "g = [b!i ⊕ (μ [^] i) ⊗ c!i. i ← [0..<(n div 2)]]"
  then have "length g = n div 2" by simp
  define h where "h = [b!i ⊖ (μ [^] i) ⊗ c!i. i ← [0..<(n div 2)]]"
  then have "length h = n div 2" by simp

  note defs = a_def nums1_def nums2_def b_def c_def g_def h_def

  have "k > 0"
    using ‹length (a1 # a2 # a3 # as) = n› ‹n = 2 ^ k›
    by (cases "k = 0") simp_all
  then have "even n" "n div 2 = 2 ^ (k - 1)"
    using ‹n = 2 ^ k› by (simp_all add: power_diff)

  have "FNTT μ (a1 # a2 # a3 # as) = g @ h"
    unfolding FNTT.simps
    using ‹length (a1 # a2 # a3 # as) = n› by (simp only: Let_def defs)
  then have "FNTT μ a = g @ h" using a_def by simp

  (* The halfway property carries over to the squared root on half the length. *)
  have recursive_halfway: "(μ [^] (2 :: nat)) [^] (n div 2 div 2) = ⊖ 𝟭"
  proof -
    have "n ≥ 3"
      using ‹length (a1 # a2 # a3 # as) = n› by simp
    then have "k ≥ 2" using ‹n = 2 ^ k› by (cases "k ∈ {0, 1}") auto
    then have "even (n div 2)" using ‹n div 2 = 2 ^ (k - 1)› by fastforce
    then show ?thesis
      by (simp add: nat_pow_pow ‹μ ∈ carrier R› ‹μ [^] (n div 2) = ⊖ 𝟭›)
  qed

  (* Induction hypothesis for the even-indexed half. *)
  have "b = NTT (μ [^] (2::nat)) nums1"
    unfolding b_def
    apply (intro 4(1)[of n nums1 nums2 "n div 2" "k - 1"])
    subgoal using ‹length (a1 # a2 # a3 # as) = n› by simp
    subgoal using nums1_def a_def by simp
    subgoal using nums2_def a_def by simp
    subgoal using ‹μ ∈ carrier R› by simp
    subgoal using ‹n div 2 = 2 ^ (k - 1)› .
    subgoal using primitive_root_recursion ‹even n› ‹primitive_root n μ› by blast
    subgoal using recursive_halfway .
    subgoal using nums1_def length_filter_even ‹even n› by simp
    subgoal
      unfolding nums1_def
      apply (intro set_subseteqI)
      using set_subseteqD[OF ‹set (a1 # a2 # a3 # as) ⊆ carrier R›]
      by (simp add: a_def[symmetric] filter_even_nth length_filter_even ‹even n›)
    done

  (* Induction hypothesis for the odd-indexed half. *)
  have "c = NTT (μ [^] (2::nat)) nums2"
    unfolding c_def
    apply (intro 4(2)[of n nums1 nums2 b "n div 2" "k - 1"])
    subgoal using ‹length (a1 # a2 # a3 # as) = n› by simp
    subgoal unfolding nums1_def a_def by simp
    subgoal unfolding nums2_def a_def by simp
    subgoal using b_def .
    subgoal using ‹μ ∈ carrier R› by simp
    subgoal using ‹n div 2 = 2 ^ (k - 1)› .
    subgoal using primitive_root_recursion ‹even n› ‹primitive_root n μ› by blast
    subgoal using recursive_halfway .
    subgoal unfolding nums2_def using length_filter_odd by simp
    subgoal
      unfolding nums2_def
      apply (intro set_subseteqI)
      using set_subseteqD[OF ‹set (a1 # a2 # a3 # as) ⊆ carrier R›]
      by (simp add: a_def[symmetric] filter_odd_nth length_filter_odd)
    done

  show ?case
  proof (intro nth_equalityI)
    have[simp]: "length (FNTT μ (a1 # a2 # a3 # as)) = n"
      using ‹length (a1 # a2 # a3 # as) = n› ‹n = 2 ^ k› length_FNTT[of "a1 # a2 # a3 # as"]
      by blast
    then show "length (FNTT μ (a1 # a2 # a3 # as)) = length (NTT μ (a1 # a2 # a3 # as))"
      using NTT_length[of μ "a1 # a2 # a3 # as"] ‹length (a1 # a2 # a3 # as) = n› by argo
    fix i
    assume "i < length (FNTT μ (a1 # a2 # a3 # as))"
    then have "i < n" by simp

    have "FNTT μ a ! i = NTT μ a ! i"
    proof (cases "i < n div 2")
      (* Lower half: entry i lies in g. *)
      case True
      then have "NTT μ a ! i =
        (NTT (μ [^] (2::nat)) [a ! i. i ← filter even [0..<n]]) ! i
      ⊕ μ [^] i ⊗ (NTT (μ [^] (2::nat)) [a ! i. i ← filter odd [0..<n]]) ! i"
        apply (intro NTT_recursion_1)
        using True ‹even n› ‹primitive_root n μ› ‹set (a1 # a2 # a3 # as) ⊆ carrier R› a_def
        using ‹μ ∈ carrier R› ‹length (a1 # a2 # a3 # as) = n›
        by simp_all

      also have "... = (NTT (μ [^] (2::nat)) nums1) ! i
      ⊕ μ [^] i ⊗ (NTT (μ [^] (2::nat)) nums2) ! i"
        unfolding nums1_def nums2_def by blast
      also have "... = b ! i ⊕ μ [^] i ⊗ c ! i"
        using ‹b = NTT (μ [^] 2) nums1› ‹c = NTT (μ [^] 2) nums2› by blast
      also have "... = g ! i"
        unfolding g_def using True by simp
      also have "... = FNTT μ a ! i"
        using ‹FNTT μ a = g @ h› ‹length g = n div 2› True
        by (simp add: nth_append)

      finally show ?thesis by simp
    next
      (* Upper half: write i = n div 2 + j; entry i is entry j of h. *)
      case False
      then obtain j where j_def: "i = n div 2 + j" "j < n div 2"
        using ‹i < n› ‹even n›
        by (metis add_diff_inverse_nat add_self_div_2 div_plus_div_distrib_dvd_right nat_add_left_cancel_less)
      have "NTT μ a ! (n div 2 + j) =
        (NTT (μ [^] (2::nat)) [a ! i. i ← filter even [0..<n]]) ! j
      ⊖ μ [^] j ⊗ (NTT (μ [^] (2::nat)) [a ! i. i ← filter odd [0..<n]]) ! j"
        apply (intro NTT_recursion_2)
        subgoal using ‹even n› .
        subgoal using ‹primitive_root n μ› .
        subgoal using ‹length (a1 # a2 # a3 # as) = n› a_def by simp
        subgoal using j_def by simp
        subgoal using ‹set (a1 # a2 # a3 # as) ⊆ carrier R› a_def by simp
        subgoal using ‹μ [^] (n div 2) = ⊖ 𝟭› .
        done

      also have "... = (NTT (μ [^] (2::nat)) nums1) ! j
      ⊖ μ [^] j ⊗ (NTT (μ [^] (2::nat)) nums2) ! j"
        unfolding nums1_def nums2_def by blast
      also have "... = b ! j ⊖ μ [^] j ⊗ c ! j"
        using ‹b = NTT (μ [^] 2) nums1› ‹c = NTT (μ [^] 2) nums2› by blast
      also have "... = h ! j"
        unfolding g_def h_def using j_def by simp
      also have "... = FNTT μ a ! i"
        using ‹FNTT μ a = g @ h› ‹length g = n div 2› j_def
        by (simp add: nth_append)

      finally show ?thesis using j_def by simp
    qed
    then show "FNTT μ (a1 # a2 # a3 # as) ! i = NTT μ (a1 # a2 # a3 # as) ! i"
      using a_def by simp
  qed
qed

end

text "The following is copied from @{theory Number_Theoretic_Transform.Butterfly} and moved outside
of the @{locale butterfly} locale."

(* Selects every second element of a list. The flag says whether the current head is
   kept (True) or dropped (False); it flips at every element. *)
fun evens_odds :: "bool ⇒ 'a list ⇒ 'a list" where
  "evens_odds keep [] = []"
| "evens_odds True (y # ys) = y # evens_odds False ys"
| "evens_odds False (y # ys) = evens_odds True ys"

(* Mapping over the even numbers below Suc g: the result starts with f 0, followed by
   the shifted function mapped over the odd numbers below g. *)
lemma map_filter_shift:
  "map f (filter even [0..<Suc g]) = f 0 # map (λx. f (x + 1)) (filter odd [0..<g])"
proof (induction g)
  case 0
  show ?case by auto
next
  case (Suc g)
  then show ?case by auto
qed

(* Mapping over the odd numbers below Suc g equals mapping the shifted function over
   the even numbers below g. *)
lemma map_filter_shift':
  "map f (filter odd [0..<Suc g]) = map (λx. f (x + 1)) (filter even [0..<g])"
proof (induction g)
  case 0
  show ?case by auto
next
  case (Suc g)
  then show ?case by auto
qed

(* Links the index-based selection of even/odd positions to evens_odds:
   picking the entries at even (resp. odd) indices equals evens_odds True (resp. False).
   Both statements are proved together by induction on the list, since each case of one
   uses the induction hypothesis of the other (via map_filter_shift / map_filter_shift').
   The list-comprehension arrows and the conjunction were restored after being stripped
   from the source text. *)
lemma filter_comprehension_evens_odds:
      "[xs ! i. i ← filter even [0..<length xs]] = evens_odds True xs ∧
       [xs ! i. i ← filter odd [0..<length xs]] = evens_odds False xs "
  apply(induction xs)
   apply simp
  subgoal for x xs
    apply rule
    subgoal 
      apply(subst evens_odds.simps)
      apply(rule trans[of _ "map ((!) (x # xs)) (filter even [0..<Suc (length xs)])"])
      subgoal by simp
      apply(rule trans[OF  map_filter_shift[of "(!) (x # xs)" "length xs"]])
      apply simp
      done

      apply(subst evens_odds.simps)
      apply(rule trans[of _ "map ((!) (x # xs)) (filter odd [0..<Suc (length xs)])"])
      subgoal by simp
      apply(rule trans[OF  map_filter_shift'[of "(!) (x # xs)" "length xs"]])
      apply simp
    done
  done

(* Both evens_odds selections are strictly shorter than Suc (length xs).
   Declared as a simp rule; judging by its name it supports the termination proof of
   FNTT' — confirm against the definition of FNTT'. *)
lemma FNTT'_termination_aux[simp]: "length (evens_odds True xs) < Suc (length xs)" 
              "length (evens_odds False xs) < Suc (length xs)"
  by (metis filter_comprehension_evens_odds le_imp_less_Suc length_filter_le length_map map_nth)+

text "(End of copy)"

(* map commutes with evens_odds. Proved by induction along the recursion of evens_odds,
   one case per defining equation. *)
lemma map_evens_odds: "map f (evens_odds x a) = evens_odds x (map f a)"
proof (induction x a rule: evens_odds.induct)
  case (1 uu)
  show ?case by simp
next
  case (2 y ys)
  then show ?case by simp
next
  case (3 y ys)
  then show ?case by simp
qed

(* Lengths of the two selections: evens_odds True yields half the length, rounded up;
   evens_odds False yields half the length, rounded down. Both are derived from
   filter_comprehension_evens_odds and the length lemmas for filtered index ranges. *)
lemma length_evens_odds:
  "length (evens_odds True a) = (if even (length a) then length a div 2 else length a div 2 + 1)"
  "length (evens_odds False a) = length a div 2"
  using filter_comprehension_evens_odds[of a] length_filter_even[of "length a"] length_filter_odd[of "length a"]
  using length_map by (metis, metis)

(* evens_odds only selects elements of its input list.
   The subset symbol was restored after being stripped from the source text. *)
lemma set_evens_odds:
  "set (evens_odds x a) ⊆ set a"
  by (induction x a rule: evens_odds.induct) fastforce+

context cring begin

text "Similar to @{theory Number_Theoretic_Transform.Butterfly}, we give an abstract algorithm that can be
refined more easily to a verifiably efficient FNTT algorithm."

(* Variant of FNTT that splits the input with evens_odds instead of index-based list
   comprehensions; the combination step is unchanged.
   The operator symbols were restored after being stripped from the source text. *)
fun FNTT' :: "'a ⇒ 'a list ⇒ 'a list" where
"FNTT' μ [] = []"
| "FNTT' μ [x] = [x]"
| "FNTT' μ [x, y] = [x ⊕ y, x ⊖ y]"
| "FNTT' μ a = (let n = length a;
                  nums1 = evens_odds True a;
                  nums2 = evens_odds False a;
                  b = FNTT' (μ [^] (2::nat)) nums1;
                  c = FNTT' (μ [^] (2::nat)) nums2;
                  g = [b!i ⊕ (μ [^] i) ⊗ c!i. i ← [0..<(n div 2)]];
                  h = [b!i ⊖ (μ [^] i) ⊗ c!i. i ← [0..<(n div 2)]]
               in g@h)"

(* FNTT' and FNTT compute the same function (no assumption on the input length).
   Induction along the recursion of FNTT'; the three base cases follow from the defining
   equations, the recursive case from filter_comprehension_evens_odds, which identifies
   the two ways of splitting the input. *)
lemma FNTT'_FNTT: "FNTT' μ xs = FNTT μ xs"
  apply (induction μ xs rule: FNTT'.induct)
  subgoal by (simp add: FNTT.simps)
  subgoal by (simp add: FNTT.simps)
  subgoal by (simp add: FNTT.simps)
  subgoal for μ a1 a2 a3 as
    unfolding FNTT.simps FNTT'.simps Let_def
    using filter_comprehension_evens_odds[of "a1 # a2 # a3 # as"] by presburger
  done

(* Variant of FNTT' whose combination step is expressed with map2 over whole lists
   (the list of powers μ [^] i for i < n div 2, and the recursive results b, c) instead of
   index-based list comprehensions.
   The operator symbols were restored after being stripped from the source text. *)
fun FNTT'' :: "'a ⇒ 'a list ⇒ 'a list" where
"FNTT'' μ [] = []"
| "FNTT'' μ [x] = [x]"
| "FNTT'' μ [x, y] = [x ⊕ y, x ⊖ y]"
| "FNTT'' μ a = (let n = length a;
                  nums1 = evens_odds True a;
                  nums2 = evens_odds False a;
                  b = FNTT'' (μ [^] (2::nat)) nums1;
                  c = FNTT'' (μ [^] (2::nat)) nums2;
                  g = map2 (⊕) b (map2 (⊗) [μ [^] i. i ← [0..<(n div 2)]] c);
                  h = map2 (λx y. x ⊖ y) b (map2 (⊗) [μ [^] i. i ← [0..<(n div 2)]] c)
               in g@h)"

(* FNTT'' agrees with FNTT' on inputs whose length is a power of two.
   Only the recursive case needs work: with b and c of length n div 2, the map2-based
   combination lists (g2, h2) coincide entrywise with the index-based ones (g1, h1).
   The operator symbols and cartouches were restored after being stripped from the source text. *)
lemma FNTT''_FNTT':
  assumes "length a = 2 ^ k"
  shows "FNTT'' μ a = FNTT' μ a"
  using assms
proof (induction μ a arbitrary: k rule: FNTT''.induct)
  case (4 μ a1 a2 a3 as)
  define a where "a = a1 # a2 # a3 # as"
  define n where "n = length a"
  then have "n = 2 ^ k" using 4 a_def by simp
  then have "k ≥ 2" using n_def a_def by (cases "k = 0"; cases "k = 1") simp_all
  then have "even n" using ‹n = 2 ^ k› by simp
  have "n div 2 = 2 ^ (k - 1)" using ‹n = 2 ^ k› ‹k ≥ 2› by (simp add: power_diff)
  then have "even (n div 2)" using ‹k ≥ 2› by simp
  define nums1 where "nums1 = evens_odds True a"
  then have "length nums1 = n div 2"
    using length_filter_even[of n] filter_comprehension_evens_odds[of a] n_def ‹even n›
    by (metis length_map)
  define nums2 where "nums2 = evens_odds False a"
  then have "length nums2 = n div 2"
    using length_filter_odd[of n] filter_comprehension_evens_odds[of a] n_def
    by (metis length_map)
  define b where "b = FNTT' (μ [^] (2::nat)) nums1"
  then have "length b = n div 2"
    by (simp add: FNTT'_FNTT ‹length nums1 = n div 2› ‹n div 2 = 2 ^ (k - 1)›)
  define c where "c = FNTT' (μ [^] (2::nat)) nums2"
  then have "length c = n div 2"
    by (simp add: FNTT'_FNTT ‹length nums2 = n div 2› ‹n div 2 = 2 ^ (k - 1)›)
  (* g1 / h1: combination step as in FNTT'; g2 / h2: combination step as in FNTT''. *)
  define g1 where "g1 = [b!i ⊕ (μ [^] i) ⊗ c!i. i ← [0..<(n div 2)]]"
  then have "length g1 = n div 2" by simp
  define h1 where "h1 = [b!i ⊖ (μ [^] i) ⊗ c!i. i ← [0..<(n div 2)]]"
  then have "length h1 = n div 2" by simp
  define g2 where "g2 = map2 (⊕) b (map2 (⊗) [μ [^] i. i ← [0..<(n div 2)]] c)"
  then have "length g2 = n div 2"
    by (simp add: ‹length b = n div 2› ‹length c = n div 2›)

  have "g1 = g2"
    apply (intro nth_equalityI)
    subgoal by (simp only: ‹length g1 = n div 2› ‹length g2 = n div 2›)
    subgoal for i
      by (simp add: g1_def g2_def ‹length b = n div 2› ‹length c = n div 2›)
    done
    
  define h2 where "h2 = map2 (λx y. x ⊖ y) b (map2 (⊗) [μ [^] i. i ← [0..<(n div 2)]] c)"
  then have "length h2 = n div 2"
    by (simp add: ‹length b = n div 2› ‹length c = n div 2›)

  have "h1 = h2"
    apply (intro nth_equalityI)
    subgoal by (simp only: ‹length h1 = n div 2› ‹length h2 = n div 2›)
    subgoal for i
      by (simp add: h1_def h2_def ‹length b = n div 2› ‹length c = n div 2›)
    done

  (* Induction hypotheses for the two recursive calls. *)
  have 1: "FNTT'' (μ [^] (2::nat)) nums1 = FNTT' (μ [^] (2::nat)) nums1"
    apply (intro 4(1))
    using a_def n_def ‹length (a1 # a2 # a3 # as) = 2 ^ k› ‹length nums1 = n div 2› ‹n div 2 = 2 ^ (k - 1)›
    by (simp_all add: nums1_def)
  have 2: "FNTT'' (μ [^] (2::nat)) nums2 = FNTT' (μ [^] (2::nat)) nums2"
    apply (intro 4(2))
    using a_def n_def ‹length (a1 # a2 # a3 # as) = 2 ^ k› ‹length nums2 = n div 2› ‹n div 2 = 2 ^ (k - 1)›
    by (simp_all add: nums2_def)

  show ?case
    apply (simp only: FNTT'.simps FNTT''.simps)
    apply (simp only: Let_def 1 2 a_def[symmetric] nums1_def[symmetric] nums2_def[symmetric]
        b_def[symmetric] c_def[symmetric])
    using ‹h1 = h2› ‹g1 = g2›  n_def g1_def h1_def g2_def h2_def
    by argo
qed simp_all

end

end