(* Theory Z_mod_power_of_2 *)

section "The Schoenhage-Strassen Algorithm"

subsection "Representing $\\mathbb{Z}_{2 ^ n}$"

theory Z_mod_power_of_2
  imports 
    "Karatsuba.Nat_LSBF_TM" 
    "Finite_Fields.Ring_Characteristic" 
    "Karatsuba.Abstract_Representations_2" 
    "HOL-Number_Theory.Number_Theory"
begin

context cring begin
(* Inside the cring locale: a carrier element a whose n-th power equals the
   ring's one, for some n > 0, is a unit.  The two `using` facts expose
   n as a successor and split one factor off the power, which metis then
   combines with unit_factor. *)
lemma pow_one_imp_unit:
"(n::nat) > 0 \<Longrightarrow> a \<in> carrier R \<Longrightarrow> a [^] n = 𝟭 \<Longrightarrow> a \<in> Units R"
  using gr0_implies_Suc[of n] nat_pow_Suc2[of a]
  by (metis Units_one_closed nat_pow_closed unit_factor)
end

(* ensure_length k xs applies fill k to xs and then keeps only the first k
   entries of the result. *)
definition ensure_length where "ensure_length k xs = take k (fill k xs)"
(* The output of ensure_length k always has length exactly k; registered as a
   simp rule. *)
lemma ensure_length_correct[simp]: "length (ensure_length k xs) = k" using fill_def ensure_length_def by simp
(* If the number denoted by xs is below 2 ^ n, then ensure_length n does not
   change the denoted number. *)
lemma to_nat_ensure_length: "Nat_LSBF.to_nat xs < 2 ^ n \<Longrightarrow> Nat_LSBF.to_nat (ensure_length n xs) = Nat_LSBF.to_nat xs"
  by (simp add: to_nat_take ensure_length_def)

(* Locale fixing a natural number k under the single assumption k > 0.
   Everything up to the matching `end` is stated relative to this k. *)
locale int_lsbf_mod =
  fixes k :: nat
  assumes k_positive: "k > 0"
begin

(* Shorthand: n stands for the natural number 2 ^ k. *)
abbreviation n where "n \<equiv> (2::nat) ^ k"

(* Zn is residue_ring applied to the modulus int n. *)
definition Zn where "Zn \<equiv> residue_ring (int n)"

(* n is strictly positive; registered as a simp rule. *)
lemma n_positive[simp]: "n > 0"
  by simp

(* Make this locale a sublocale of `residues`, instantiated with n and Zn. *)
sublocale residues n Zn
  apply unfold_locales
  (* first obligation: discharged by simp with the help of k_positive *)
  subgoal using k_positive by simp
  (* second obligation: is exactly the defining equation Zn_def *)
  subgoal by (rule Zn_def)
  done

(* Map a bit list to an integer: the number it denotes (Nat_LSBF.to_nat),
   cast to int and reduced modulo int n. *)
definition to_residue_ring :: "nat_lsbf \<Rightarrow> int" where
"to_residue_ring xs = int (Nat_LSBF.to_nat xs) mod int n"

(* Every value produced by to_residue_ring lies in the carrier of Zn. *)
lemma to_residue_ring_in_carrier: "to_residue_ring xs \<in> carrier Zn"
  unfolding to_residue_ring_def res_carrier_eq by simp

(* Map an integer x to a bit list: convert nat x with Nat_LSBF.from_nat and
   apply fill k to the result. *)
definition from_residue_ring :: "int \<Rightarrow> nat_lsbf" where
"from_residue_ring x = fill k (Nat_LSBF.from_nat (nat x))"

(* reduce is ensure_length specialised to the locale parameter k. *)
definition reduce where
"reduce xs = ensure_length k xs"

(* The output of reduce always has length exactly k. *)
lemma length_reduce: "length (reduce xs) = k"
  unfolding reduce_def using fill_def ensure_length_def by simp

(* On the level of denoted numbers, reduce is reduction modulo n = 2 ^ k.
   The proof splits on whether xs has at most k entries. *)
lemma to_nat_reduce: "Nat_LSBF.to_nat (reduce xs) = Nat_LSBF.to_nat xs mod n"
proof (cases "length xs \<le> k")
  case True
  (* At most k entries: reduce only appends False entries, so the value is
     unchanged, and that value is already below 2 ^ k. *)
  then have "reduce xs = fill k xs" unfolding reduce_def using fill_def ensure_length_def by simp
  also have "... = xs @ (replicate (k - length xs) False)" using fill_def by simp
  finally have "Nat_LSBF.to_nat (reduce xs) = Nat_LSBF.to_nat xs" by simp
  moreover have "Nat_LSBF.to_nat xs \<le> 2 ^ k - 1" using to_nat_length_upper_bound[of xs] True
    by (meson diff_le_mono le_trans one_le_numeral power_increasing)
  hence "Nat_LSBF.to_nat xs < 2 ^ k"
    using Nat.le_diff_conv2 by auto
  ultimately show ?thesis by simp
next
  case False
  (* More than k entries: split xs into its first k entries and the rest;
     the value is (value of prefix) + n * (value of rest), and the prefix
     value is below 2 ^ k. *)
  then have "length (take k xs) = k" "fill k xs = xs" "xs = (take k xs) @ (drop k xs)" using fill_def by simp_all
  then have "Nat_LSBF.to_nat xs = Nat_LSBF.to_nat (take k xs) + n * Nat_LSBF.to_nat (drop k xs)"
    using to_nat_app[of "take k xs" "drop k xs"] by simp
  moreover have "Nat_LSBF.to_nat (take k xs) \<le> 2 ^ k - 1"
    using to_nat_length_upper_bound[of "take k xs"] \<open>length (take k xs) = k\<close> by simp
  hence "Nat_LSBF.to_nat (take k xs) < 2 ^ k"
    using Nat.le_diff_conv2 by auto
  ultimately show ?thesis unfolding reduce_def using fill_def ensure_length_def by simp
qed

(*fun to_ZMod where
"to_ZMod xs = ZMod n (int (to_nat xs))"
*)

(* Modular addition on bit lists: add with add_nat, then apply reduce. *)
definition add_mod where
"add_mod x y = reduce (add_nat x y)"

(* Modular subtraction on bit lists.  When compare_nat xs ys holds, ys is
   subtracted from (fill k xs) extended by one True entry and the result is
   reduced; otherwise subtract_nat xs ys is returned unchanged. *)
definition subtract_mod where
"subtract_mod xs ys =
  (if compare_nat xs ys then
    reduce (subtract_nat ((fill k xs) @ [True]) ys)
  else
    subtract_nat xs ys)"

(* The number denoted by add_mod x y is the sum of the denoted numbers,
   taken modulo n. *)
lemma to_nat_add_mod: "Nat_LSBF.to_nat (add_mod x y) = (Nat_LSBF.to_nat x + Nat_LSBF.to_nat y) mod n"
  by (simp only: to_nat_reduce add_nat_correct add_mod_def)

(* For inputs of length at most k, the number denoted by subtract_mod xs ys
   (as an int) is the integer difference of the denoted numbers modulo n.
   The proof splits on whether the value of xs is at most the value of ys. *)
lemma to_nat_subtract_mod: "length xs \<le> k \<Longrightarrow> length ys \<le> k \<Longrightarrow> int (Nat_LSBF.to_nat (subtract_mod xs ys)) = (int (Nat_LSBF.to_nat xs) - int (Nat_LSBF.to_nat ys)) mod n"
proof (cases "Nat_LSBF.to_nat xs \<le> Nat_LSBF.to_nat ys")
  case True
  assume "length xs \<le> k"
  assume "length ys \<le> k"
  then have "Nat_LSBF.to_nat ys \<le> n - 1"
    using to_nat_length_upper_bound[of ys]
    by (meson diff_le_mono le_trans one_le_numeral power_increasing)
  then have "Nat_LSBF.to_nat ys \<le> Nat_LSBF.to_nat xs + n" by simp

  (* Appending True to fill k xs adds n to the denoted value, so the natural
     number subtraction does not truncate; the extra n vanishes modulo n. *)
  have "int (Nat_LSBF.to_nat (subtract_nat (fill k xs @ [True]) ys) mod n)
    = int ((Nat_LSBF.to_nat xs + n - Nat_LSBF.to_nat ys) mod n)"
    by (simp add: subtract_nat_correct to_nat_app length_fill \<open>length xs \<le> k\<close>)
  also have "... = (int (Nat_LSBF.to_nat xs + n - Nat_LSBF.to_nat ys)) mod n"
    using zmod_int by simp
  also have "... = (int (Nat_LSBF.to_nat xs) + int n - int (Nat_LSBF.to_nat ys)) mod n"
    using \<open>Nat_LSBF.to_nat ys \<le> Nat_LSBF.to_nat xs + n\<close> by (simp add: of_nat_diff)
  also have "... = (int (Nat_LSBF.to_nat xs) - int (Nat_LSBF.to_nat ys)) mod n"
    by (metis diff_add_eq int_ops(3) mod_add_self2 of_nat_power)
  finally have "int (Nat_LSBF.to_nat (subtract_nat (fill k xs @ [True]) ys) mod n) = (int (Nat_LSBF.to_nat xs) - int (Nat_LSBF.to_nat ys)) mod n" .
  then show ?thesis
    by (simp add: subtract_mod_def compare_nat_correct to_nat_reduce True split: if_splits)
next
  case False
  assume "length xs \<le> k"
  then have "Nat_LSBF.to_nat xs \<le> n - 1" using to_nat_length_upper_bound[of xs]
    by (meson diff_le_mono le_trans one_le_numeral power_increasing)
  assume "length ys \<le> k"
  (* Here the value of xs exceeds that of ys, so the plain difference is
     already in {0..<n} and taking it modulo n changes nothing. *)
  from False have "int (Nat_LSBF.to_nat (subtract_nat xs ys)) = int (Nat_LSBF.to_nat xs) - int (Nat_LSBF.to_nat ys)"
    by (simp add: subtract_nat_correct)
  moreover have "... \<in> {0..<n}"
  proof -
    have "int (Nat_LSBF.to_nat xs) - int (Nat_LSBF.to_nat ys) \<le> int (Nat_LSBF.to_nat xs)" by simp
    also have "... \<le> n - 1" using \<open>Nat_LSBF.to_nat xs \<le> n - 1\<close> n_positive by simp
    also have "... < n" by simp
    finally have "int (Nat_LSBF.to_nat xs) - int (Nat_LSBF.to_nat ys) < n" by simp
    moreover have "int (Nat_LSBF.to_nat xs) - int (Nat_LSBF.to_nat ys) \<ge> 0" using \<open>¬ Nat_LSBF.to_nat xs \<le> Nat_LSBF.to_nat ys\<close> by simp
    ultimately show ?thesis by simp
  qed
  ultimately have "int (Nat_LSBF.to_nat (subtract_nat xs ys)) = (int (Nat_LSBF.to_nat xs) - int (Nat_LSBF.to_nat ys)) mod n"
    by simp
  then show ?thesis by (simp add: subtract_mod_def compare_nat_correct to_nat_reduce False split: if_splits)
qed

(* For inputs of length at most k, the output of subtract_mod also has length
   at most k.  Proved by splitting on the compare_nat test that subtract_mod
   branches on. *)
lemma length_subtract_mod: "length xs \<le> k \<Longrightarrow> length ys \<le> k \<Longrightarrow> length (subtract_mod xs ys) \<le> k"
  unfolding subtract_mod_def
  apply (cases "compare_nat xs ys")
  using subtract_nat_aux[of xs ys]
  by (auto simp: Let_def reduce_def ensure_length_def)
  
(* add_mod implements addition in Zn: mapping the result of add_mod into the
   residue ring equals the ring sum of the mapped arguments. *)
lemma add_mod_correct: "to_residue_ring (add_mod x y) = to_residue_ring x \<oplus>\<^bsub>Zn\<^esub> to_residue_ring y"
proof -
  have "to_residue_ring (add_mod x y) = to_residue_ring (reduce (add_nat x y))"
    unfolding add_mod_def by simp
  also have "... = (Nat_LSBF.to_nat x + Nat_LSBF.to_nat y) mod n"
    using to_nat_reduce add_nat_correct to_residue_ring_def by simp
  (* rewrite into the shape (a mod n + b mod n) mod n used by res_add_eq *)
  also have "... = (int (Nat_LSBF.to_nat x) mod n + (int (Nat_LSBF.to_nat y) mod n)) mod n"
    by (simp add: zmod_int mod_add_eq)
  finally show ?thesis
    by (simp only: res_add_eq to_residue_ring_def)
qed

(* subtract_mod implements subtraction in Zn for inputs of length at most k,
   under the additional assumption n > 1. *)
lemma subtract_mod_correct:
  assumes "length x \<le> k"
  assumes "length y \<le> k"
  assumes "n > 1"
  shows "to_residue_ring (subtract_mod x y) = to_residue_ring x \<ominus>\<^bsub>Zn\<^esub> to_residue_ring y"
proof -
  have "to_residue_ring (subtract_mod x y) = int (Nat_LSBF.to_nat (subtract_mod x y)) mod int n"
    unfolding to_residue_ring_def by argo
  also have "... = (int (Nat_LSBF.to_nat x) - (int (Nat_LSBF.to_nat y))) mod int n"
    by (simp add: to_nat_subtract_mod assms)
  (* express the difference as a sum with the integer negation, modulo n *)
  also have "... = (to_residue_ring x + (- to_residue_ring y mod n)) mod n"
    unfolding diff_conv_add_uminus to_residue_ring_def
    by (simp add: mod_add_eq mod_diff_right_eq)
  (* replace the integer negation modulo n by the ring's additive inverse *)
  also have "... = (to_residue_ring x + (\<ominus>\<^bsub>residue_ring n\<^esub> (to_residue_ring y mod n))) mod n"
    apply (intro_cong "[cong_tag_2 (mod), cong_tag_2 (+)]" more: refl)
    using residues.neg_cong[symmetric, of n] unfolding residues_def using \<open>n > 1\<close>
    by (metis int_ops(2) nat_int_comparison(2))
  also have "... = to_residue_ring x \<ominus>\<^bsub>residue_ring n\<^esub> (to_residue_ring y mod n)"
    unfolding a_minus_def
    by (simp add: residue_ring_def)
  (* to_residue_ring y is already reduced, so the inner mod is redundant *)
  also have "to_residue_ring y mod n = to_residue_ring y"
    using to_residue_ring_def by simp
  finally show ?thesis unfolding Zn_def .
qed

(* For x below 2 ^ k, from_residue_ring x has length exactly k. *)
lemma length_from_residue_ring: "x < 2 ^ k \<Longrightarrow> length (from_residue_ring x) = k"
proof -
  assume "x < 2 ^ k"
  have "truncated (Nat_LSBF.from_nat (nat x))"
    using truncate_from_nat by simp
  moreover have "Nat_LSBF.to_nat (Nat_LSBF.from_nat (nat x)) = nat x"
    using nat_lsbf.to_from by simp
  (* a truncated list denoting a value below 2 ^ k has at most k entries *)
  ultimately have "length (Nat_LSBF.from_nat (nat x)) \<le> k" using \<open>x < 2 ^ k\<close> to_nat_length_bound_truncated
    by simp
  then show "length (from_residue_ring x) = k"
    unfolding from_residue_ring_def using length_fill by simp
qed

(* Interpret abstract_representation_2 with from_residue_ring,
   to_residue_ring and the set {0..<int n}.  Two rewrite rules are
   established alongside: the locale's reduce coincides with the reduce
   defined above, and its set of representations is the set of bit lists of
   length k. *)
interpretation int_lsbf_mod: abstract_representation_2 from_residue_ring to_residue_ring "{0..<int n}"
  rewrites "int_lsbf_mod.reduce = reduce"
  and "int_lsbf_mod.representations = {x :: bool list. length x = k}"
proof -
  show "abstract_representation_2 from_residue_ring to_residue_ring {0..<int n}"
    apply unfold_locales
    unfolding to_residue_ring_def from_residue_ring_def by simp_all
  (* make the locale's constants and facts available for the two rewrites *)
  then interpret int_lsbf_mod: abstract_representation_2 from_residue_ring to_residue_ring "{0..<int n}" .
  show "int_lsbf_mod.reduce = reduce"
    unfolding int_lsbf_mod.reduce_def reduce_def
    apply (intro ext)
    apply (intro nat_lsbf_eqI)
    (* first goal: both sides denote the same number *)
    subgoal for x
      unfolding from_residue_ring_def to_nat_fill to_nat_from_nat
    proof -
      have "nat (to_residue_ring x) = nat (int (Nat_LSBF.to_nat x) mod int n)"
        by (simp add: from_residue_ring_def to_residue_ring_def ensure_length_def to_nat_take)
      also have "... = Nat_LSBF.to_nat x mod n"
        unfolding zmod_int[symmetric] nat_int by (rule refl)
      also have "... = Nat_LSBF.to_nat (ensure_length k x)"
        unfolding ensure_length_def by (simp add: to_nat_take)
      finally show "nat (to_residue_ring x) = ..." .
    qed
    (* second goal: both sides have the same length, namely k *)
    subgoal for x
    proof -
      have "length (from_residue_ring (to_residue_ring x)) = k"
        apply (intro length_from_residue_ring)
        unfolding to_residue_ring_def
        using mod_less_divisor[OF n_positive] by simp
      then show ?thesis by simp
    qed
    done
  show "int_lsbf_mod.representations = {x :: bool list. length x = k}"
  proof (intro equalityI subsetI)
    (* every representation is the image of some y < 2 ^ k, hence has length k *)
    fix x
    assume "x \<in> int_lsbf_mod.representations"
    then obtain y where "y < 2 ^ k" "x = from_residue_ring y"
      unfolding int_lsbf_mod.representations_def by auto
    then have "length x = k" by (simp add: length_from_residue_ring)
    then show "x \<in> {x. length x = k}" by simp
  next
    (* conversely, a list of length k is fixed by reduce, so it is the image
       of its own residue *)
    fix x :: "bool list"
    assume "x \<in> {x. length x = k}"
    then have "length x = k" by simp
    have "from_residue_ring (to_residue_ring x) = int_lsbf_mod.reduce x"
      using int_lsbf_mod.reduce_def by simp
    also have "... = reduce x" using \<open>int_lsbf_mod.reduce = reduce\<close> by simp
    also have "... = x" using \<open>length x = k\<close> unfolding reduce_def ensure_length_def fill_def by simp
    finally show "x \<in> int_lsbf_mod.representations"
      unfolding int_lsbf_mod.representations_def
      using int_lsbf_mod.to_type_in_represented_set
      by (metis imageI)
  qed
qed

(* The output of add_mod always has length exactly k; follows from
   int_lsbf_mod.range_reduce and the definition of add_mod. *)
lemma add_mod_closed: "length (add_mod x y) = k"
  using int_lsbf_mod.range_reduce add_mod_def by blast

end

end