(* Theory CRYSTALS-Kyber.Mod_Plus_Minus *)
theory Mod_Plus_Minus
imports Kyber_spec
begin
(* For an odd integer x, rounding the real quotient x / 2 down yields
   the integer quotient (x - 1) div 2. *)
lemma odd_half_floor:
  assumes "odd x"
  shows "⌊real_of_int x / 2⌋ = (x - 1) div 2"
  using assms by (rule oddE) simp
section ‹Re-centered Modulo Operation›
text ‹To define the compress and decompress functions, 
  we need a special form of the modulo operation. It returns the 
  representative of the equivalence class that lies in ‹(-q div 2, q div 2]›.
  Using these representatives, we ensure that the norm of the 
  representative is as small as possible.›
(* Re-centred remainder: start from the ordinary remainder m mod b and,
   when it is strictly greater than ⌊b/2⌋, shift it down by b; otherwise
   return it unchanged.
   Declared as the left-associative infix operator mod+- with priority 70. *)
definition mod_plus_minus :: "int ⇒ int ⇒ int" 
  (infixl ‹mod+-› 70) where
"m mod+- b = 
  (if m mod b > ⌊b/2⌋ then m mod b - b else m mod b)"
text ‹Range of the (re-centered) modulo operation›
 
(* The remainder of an integer a modulo a positive integer b lies in {0..b-1}. *)
lemma mod_range:
  assumes "b > 0"
  shows "(a::int) mod (b::int) ∈ {0..b-1}"
  using range_mod assms by auto
(* An integer that already lies in the half-open interval {0..<b}
   is equal to its own remainder modulo b. *)
lemma mod_rangeE:
  assumes in_range: "(a::int) ∈ {0..<b}"
  shows "a = a mod b"
proof -
  from in_range show ?thesis by auto
qed
(* For positive odd b: whenever the remainder y mod b is strictly greater
   than ⌊b/2⌋, the shifted value y mod b - b lies between -⌊b/2⌋ and ⌊b/2⌋
   (both bounds inclusive). The two bounds are stated as separate conclusions. *)
lemma half_mod_odd:
  assumes "b > 0" "odd b" "⌊real_of_int b / 2⌋ < y mod b" 
  shows "- ⌊real_of_int b / 2⌋ ≤ y mod b - b"
    "y mod b - b ≤ ⌊real_of_int b / 2⌋"
proof -
  (* Lower bound: from odd_half_floor instantiated at b, plus the assumptions. *)
  from odd_half_floor [of b]
  show "- ⌊real_of_int b / 2⌋ ≤ y mod b - b"
    using assms by linarith
  (* Upper bound: first bound y mod b by b + ⌊b/2⌋ (using pos_mod_bound and
     b > 0), then subtract b. *)
  then have "y mod b ≤ b + ⌊real_of_int b / 2⌋"
    by (smt (verit) ‹b > 0› pos_mod_bound)
  then show "y mod b - b ≤ ⌊real_of_int b / 2⌋"
    by auto
qed
(* For positive b, every remainder y mod b is at least -⌊b/2⌋. *)
lemma half_mod:
assumes "b>0"
shows "- ⌊real_of_int b / 2⌋ ≤ y mod b"
using assms
by (smt (verit, best) floor_less_zero half_gt_zero mod_int_pos_iff of_int_pos)
(* Range of mod+- for a positive odd modulus b:
   the result always lies in the symmetric interval {-⌊b/2⌋..⌊b/2⌋}. *)
lemma mod_plus_minus_range_odd:
  assumes "b>0" "odd b"
  shows "y mod+- b ∈ {-⌊b/2⌋..⌊b/2⌋}"
proof -
  (* Bounds for the shifted branch (remainder above ⌊b/2⌋). *)
  note shifted_bounds = half_mod_odd[OF assms]
  (* Lower bound for the unshifted branch. *)
  note plain_bound = half_mod[OF assms(1)]
  show ?thesis
    unfolding mod_plus_minus_def
    by (auto simp add: shifted_bounds plain_bound)
qed
(* For odd b, twice the rounded-down half ⌊b/2⌋ is strictly smaller than b. *)
lemma odd_smaller_b:
  assumes "odd b" 
  shows "⌊ real_of_int b / 2 ⌋ + ⌊ real_of_int b / 2 ⌋ < b"
using assms 
by (smt (z3) floor_divide_of_int_eq odd_two_times_div_two_succ 
  of_int_hom.hom_add of_int_hom.hom_one)
 
(* Let b be positive and odd and let y lie in {-⌊b/2⌋..⌊b/2⌋}.
   If the remainder y mod b is strictly greater than ⌊b/2⌋,
   then y equals the shifted remainder y mod b - b. *)
lemma mod_plus_minus_rangeE_neg:
  assumes "y ∈ {-⌊real_of_int b/2⌋..⌊real_of_int b/2⌋}"
          "odd b" "b > 0"
           "⌊real_of_int b / 2⌋ < y mod b"
  shows "y = y mod b - b"
proof -
  (* Step 1: under the assumptions y must be negative. *)
  have "y ∈ {-⌊real_of_int b/2⌋..<0}" using assms
  by (meson atLeastAtMost_iff atLeastLessThan_iff linorder_not_le order_trans zmod_le_nonneg_dividend)
  (* Step 2: widen the lower bound from -⌊b/2⌋ to -b. *)
  then have "y ∈ {-b..<0}" using assms(2-3)
  by (metis atLeastLessThan_iff floor_divide_of_int_eq int_div_less_self linorder_linear 
    linorder_not_le neg_le_iff_le numeral_code(1) numeral_le_iff of_int_numeral order_trans 
    semiring_norm(69))
  (* Step 3: for y in {-b..<0} the remainder is y + b, which gives the claim. *)
  then have "y mod b = y + b" 
  by (smt (verit) atLeastLessThan_iff mod_add_self2 mod_pos_pos_trivial)
  then show ?thesis by auto
qed
(* Let b be positive and odd and let y lie in {-⌊b/2⌋..⌊b/2⌋}.
   If the remainder y mod b is at most ⌊b/2⌋, then y equals y mod b. *)
lemma mod_plus_minus_rangeE_pos:
  assumes "y ∈ {-⌊real_of_int b/2⌋..⌊real_of_int b/2⌋}"
          "odd b" "b > 0"
          "⌊real_of_int b / 2⌋ ≥ y mod b"
  shows "y = y mod b"
proof -
  (* First show that y is non-negative, by contradiction. *)
  have "y ∈ {0..⌊real_of_int b/2⌋}" 
  proof (rule ccontr)
    assume "y ∉ {0..⌊real_of_int b / 2⌋} "
    (* Otherwise y is negative, so y mod b = y + b ... *)
    then have "y ∈ {-⌊real_of_int b/2⌋..<0}" using assms(1) by auto
    then have "y ∈ {-b..<0}" using assms(2-3)
    by (metis atLeastLessThan_iff floor_divide_of_int_eq int_div_less_self linorder_linear 
      linorder_not_le neg_le_iff_le numeral_code(1) numeral_le_iff of_int_numeral order_trans 
      semiring_norm(69))
    then have "y mod b = y + b" 
      by (smt (verit) atLeastLessThan_iff mod_add_self2 mod_pos_pos_trivial)
    (* ... which is at least b - ⌊b/2⌋ and hence, by odd_smaller_b, strictly
       greater than ⌊b/2⌋, contradicting the fourth assumption. *)
    then have "y mod b ≥ b - ⌊real_of_int b/2⌋" using assms(1) by auto
    then have "y mod b > ⌊real_of_int b/2⌋"
      using assms(2) odd_smaller_b by fastforce
    then show False using assms(4) by auto
  qed
  (* A value in {0..⌊b/2⌋} is in {0..<b}, where mod b is the identity. *)
  then have "y ∈ {0..<b}" using assms(2-3)
  by (metis atLeastAtMost_iff atLeastLessThan_empty atLeastLessThan_iff floor_divide_of_int_eq 
    int_div_less_self linorder_not_le numeral_One numeral_less_iff of_int_numeral semiring_norm(76))
  then show ?thesis by auto
qed
(* For positive odd b, mod+- leaves every y in {-⌊b/2⌋..⌊b/2⌋} unchanged.
   The proof combines the two branch lemmas, one for each case of the
   if-expression in the definition of mod+-. *)
lemma mod_plus_minus_rangeE:
  assumes "y ∈ {-⌊real_of_int b/2⌋..⌊real_of_int b/2⌋}"
          "odd b" "b > 0"
  shows "y = y mod+- b"
proof -
  (* Branch where the remainder is at most ⌊b/2⌋. *)
  note small_case = mod_plus_minus_rangeE_pos[OF assms]
  (* Branch where the remainder exceeds ⌊b/2⌋. *)
  note large_case = mod_plus_minus_rangeE_neg[OF assms]
  show ?thesis
    unfolding mod_plus_minus_def
    using small_case large_case
    by auto
qed
text ‹Image of $0$.›
(* If the re-centred remainder x mod+- b is 0, then the ordinary
   remainder x mod b is 0 as well. No assumption on b is required. *)
lemma mod_plus_minus_zero:
  assumes "x mod+- b = 0"
  shows "x mod b = 0"
using assms unfolding mod_plus_minus_def 
by (metis eq_iff_diff_eq_0 mod_mod_trivial mod_self)
(* The re-centred remainder of 0 is 0.
   Note: the hypothesis "odd b" is part of the statement, but the proof
   below only uses positivity of b. *)
lemma mod_plus_minus_zero':
  assumes "b>0" "odd b"
  shows "0 mod+- b = (0::int)"
proof -
  note b_pos = assms(1)
  from b_pos mod_plus_minus_def show ?thesis by force
qed
text ‹‹mod+-› with negative values.›
(* For positive odd b, mod+- commutes with negation:
   (-x) mod+- b = -(x mod+- b). *)
lemma neg_mod_plus_minus:
  assumes "odd b"
          "b>0"
  shows "(- x) mod+- b = - (x mod+- b)"
proof -
  (* Step 1: write (-x) mod+- b as -x plus some integer multiple k of b. *)
  obtain k :: int where k_def: "(-x) mod+- b = (-x)+ k* b" 
  using mod_plus_minus_def
  proof -
    assume a1: "⋀k. - x mod+- b = - x + k * b ⟹ thesis"
    have "∃i. i mod b + - (x + i) = - x mod+- b" 
    by (smt (verit, del_insts) mod_add_self1 mod_plus_minus_def)
    then show ?thesis
      using a1 by (metis (no_types) diff_add_cancel diff_diff_add 
      diff_minus_eq_add minus_diff_eq minus_mult_div_eq_mod 
      mult.commute mult_minus_left)
  qed
  (* Step 2: rewrite the right-hand side as the negation of x - k*b. *)
  then have "(-x) mod+- b = -(x - k*b)" using k_def by auto
  (* Step 3: x - k*b lies in the symmetric range {-⌊b/2⌋..⌊b/2⌋}, so by
     mod_plus_minus_rangeE it is left unchanged by mod+-. *)
  also have "… = - ((x-k*b) mod+- b)"
  proof -
    have range_xkb:"x - k * b ∈ 
      {- ⌊real_of_int b / 2⌋..⌊real_of_int b / 2⌋}"
      using k_def mod_plus_minus_range_odd[OF assms(2) assms(1)]
      by (smt (verit, ccfv_SIG) atLeastAtMost_iff)
    have "x - k*b = (x - k*b) mod+- b" 
      using mod_plus_minus_rangeE[OF range_xkb assms] by auto
    then show ?thesis by auto
  qed
  (* Step 4: subtracting the multiple k*b does not change the result of mod+-
     (after unfolding the definition, via mod_mult_self1). *)
  also have "-((x - k*b) mod+- b) = -(x mod+- b)" 
    unfolding mod_plus_minus_def 
    by (smt (verit, best) mod_mult_self1)
  finally show ?thesis by auto
qed
text ‹Representative with ‹mod+-››
(* Every integer x differs from its re-centred remainder x mod+- b by an
   integer multiple of b: there is a k with x = k*b + x mod+- b.
   The proof unfolds the definition and splits on its if-expression. *)
lemma mod_plus_minus_rep_ex:
"∃k. x = k*b + x mod+- b"
unfolding mod_plus_minus_def 
by (split if_splits)(metis add.right_neutral add_diff_eq div_mod_decomp_int 
  eq_iff_diff_eq_0 mod_add_self2)
(* Elimination-style (obtains) form of mod_plus_minus_rep_ex:
   provides a witness k with x = k*b + x mod+- b. *)
lemma mod_plus_minus_rep: 
  obtains k where "x = k*b + x mod+- b"
using mod_plus_minus_rep_ex by auto
text ‹Multiplication in ‹mod+-››
(* Compatibility of mod+- with multiplication: the re-centred remainder of a
   product s*x is unchanged when both factors are first replaced by their own
   re-centred remainders. No assumption on q is required. *)
lemma mod_plus_minus_mult: 
  "s*x mod+- q = (s mod+- q) * (x mod+- q) mod+- q"
unfolding mod_plus_minus_def 
by (smt (verit, ccfv_threshold) minus_mod_self2 mod_mult_left_eq mod_mult_right_eq)
end