(* Theory Multiset_Extra *)

theory Multiset_Extra
  imports
    "HOL-Library.Multiset"
    "HOL-Library.Multiset_Order"
    Nested_Multisets_Ordinals.Multiset_More
    Abstract_Substitution.Natural_Magma_Functor
begin

(* Every element belongs to some multiset; declared as an introduction rule. *)
lemma exists_multiset [intro]: "∃M. x ∈ set_mset M"
  by (meson union_single_eq_member)

(* Multisets as a magma with an empty element: set_mset is the element view, (+) the binary
   combination, λl. {#l#} the singleton wrapper, add_mset insertion, and {#} the empty
   multiset.
   NOTE(review): the prefix `muliset_magma` looks like a typo for `multiset_magma`; renaming
   it changes the names of all generated facts, so check users before changing. *)
global_interpretation muliset_magma: natural_magma_with_empty where
  to_set = set_mset and plus = "(+)" and wrap = "λl. {#l#}" and add = add_mset and empty = "{#}"
  by unfold_locales simp_all

(* Multisets as a finite natural functor, with image_mset as map and set_mset as the
   element view. *)
global_interpretation multiset_functor: finite_natural_functor where
  map = image_mset and to_set = set_mset
  by unfold_locales auto

(* Conversion instance in which every map/to_set parameter is instantiated with
   image_mset/set_mset (multisets converting to multisets).
   NOTE(review): this reuses the prefix `multiset_functor` from the interpretation above. *)
global_interpretation multiset_functor: natural_functor_conversion where
  map = image_mset and to_set = set_mset and map_to = image_mset and map_from = image_mset and
  map' = image_mset and to_set' = set_mset
  by unfold_locales simp_all

(* Combines the functor (image_mset) and the magma operations ((+), singleton, add_mset).
   NOTE(review): the prefix `muliset_functor` looks like a typo for `multiset_functor`; kept
   as is because renaming changes the generated fact names. *)
global_interpretation muliset_functor: natural_magma_functor where
  map = image_mset and to_set = set_mset and plus = "(+)" and wrap = "λl. {#l#}" and add = add_mset
  by unfold_locales simp_all

(* If x occurs at least once in M, one copy of x can be split off M. *)
lemma one_le_countE:
  assumes "1 ≤ count M x"
  obtains M' where "M = add_mset x M'"
  using assms by (meson count_greater_eq_one_iff multi_member_split)

(* If x occurs at least twice in M, two copies of x can be split off M. *)
lemma two_le_countE:
  assumes "2 ≤ count M x"
  obtains M' where "M = add_mset x (add_mset x M')"
  using assms
  by (metis Suc_1 Suc_eq_plus1_left Suc_leD add.right_neutral count_add_mset multi_member_split
      not_in_iff not_less_eq_eq)

(* If x occurs at least three times in M, three copies of x can be split off M;
   proved via two_le_countE. *)
lemma three_le_countE:
  assumes "3 ≤ count M x"
  obtains M' where "M = add_mset x (add_mset x (add_mset x M'))"
  using assms
  by (metis One_nat_def Suc_1 Suc_leD add_le_cancel_left count_add_mset numeral_3_eq_3 plus_1_eq_Suc
      two_le_countE)

(* One-step criterion for multp⇩H⇩O. Let J = B - A and K = A - B. If J is non-empty and
   every element of K is R-below some element of J, then multp⇩H⇩O R A B. *)
lemma one_step_implies_multpHO_strong:
  fixes A B J K :: "_ multiset"
  defines "J ≡ B - A" and "K ≡ A - B"
  assumes "J ≠ {#}" and "∀k ∈# K. ∃x ∈# J. R k x"
  shows "multp⇩H⇩O R A B"
  unfolding multp⇩H⇩O_def
proof (intro conjI allI impI)
  show "A ≠ B"
    using assms
    by force
next
  fix y
  assume "count B y < count A y"

  (* y is then in K = A - B, so by assumption it has an R-greater element in J = B - A *)
  then show "∃x. R y x ∧ count A x < count B x"
    using assms
    by (metis in_diff_count)
qed

(* Uniq is antitone in its predicate. If Q ≤ P pointwise and P has at most one witness,
   then so does Q. *)
lemma Uniq_antimono: "Q ≤ P ⟹ Uniq Q ≥ Uniq P"
  unfolding le_fun_def le_bool_def
  by (rule impI) (simp only: Uniq_I Uniq_D)

(* Rule form of Uniq_antimono, with the pointwise implication as a meta-level premise. *)
lemma Uniq_antimono': "(⋀x. Q x ⟹ P x) ⟹ Uniq P ⟹ Uniq Q"
  by (fact Uniq_antimono[unfolded le_fun_def le_bool_def, rule_format])

(* For transitive R, M is multp-below the singleton {#x#} iff every element of M is
   R-below x. Registered as a simp rule. *)
lemma multp_singleton_right[simp]:
  assumes "transp R"
  shows "multp R M {#x#} ⟷ (∀y ∈# M. R y x)"
proof (rule iffI)
  show "∀y ∈# M. R y x ⟹ multp R M {#x#}"
    using one_step_implies_multp[of "{#x#}" _ R "{#}", simplified] .
next
  show "multp R M {#x#} ⟹ ∀y∈#M. R y x"
    using multp_implies_one_step[OF ‹transp R›]
    by (smt (verit, del_insts) add_0 set_mset_add_mset_insert set_mset_empty single_is_union
        singletonD)
qed

(* For transitive R, {#x#} is multp-below M iff either {#x#} is a proper sub-multiset of M,
   or some element of M is R-above x. Registered as a simp rule. *)
lemma multp_singleton_left[simp]:
  assumes "transp R"
  shows "multp R {#x#} M ⟷ ({#x#} ⊂# M ∨ (∃y ∈# M. R x y))"
proof (rule iffI)
  show "{#x#} ⊂# M ∨ (∃y∈#M. R x y) ⟹ multp R {#x#} M"
  proof (elim disjE bexE)
    show "{#x#} ⊂# M ⟹ multp R {#x#} M"
      by (simp add: subset_implies_multp)
  next
    show "⋀y. y ∈# M ⟹ R x y ⟹ multp R {#x#} M"
      using one_step_implies_multp[of M "{#x#}" R "{#}", simplified] by force
  qed
next
  show "multp R {#x#} M ⟹ {#x#} ⊂# M ∨ (∃y∈#M. R x y)"
    using multp_implies_one_step[OF ‹transp R›, of "{#x#}" M]
    by (metis (no_types, opaque_lifting) add_cancel_right_left subset_mset.gr_zeroI
        subset_mset.less_add_same_cancel2 union_commute union_is_single union_single_eq_member)
qed

(* For transitive R, comparing two singletons reduces to comparing their elements. *)
lemma multp_singleton_singleton[simp]: "transp R ⟹ multp R {#x#} {#y#} ⟷ R x y"
  using multp_singleton_right[of R "{#x#}" y] by simp

(* For transitive R, multp is preserved when the smaller side is shrunk to a sub-multiset C
   and the larger side is grown to a super-multiset D. *)
lemma multp_subset_supersetI: "transp R ⟹ multp R A B ⟹ C ⊆# A ⟹ B ⊆# D ⟹ multp R C D"
  by (metis subset_implies_multp subset_mset.antisym_conv2 transpE transp_multp)

(* Doubling both sides preserves multp for transitive R; this is the case n = 2 of
   multp_repeat_mset_repeat_msetI. *)
lemma multp_double_doubleI:
  assumes "transp R" "multp R A B"
  shows "multp R (A + A) (B + B)"
  using multp_repeat_mset_repeat_msetI[OF ‹transp R› ‹multp R A B›, of 2]
  by (simp add: numeral_Bit0)

(* Converse of the one-step criterion for transitive and asymmetric R. If multp R A B, then
   J = B - A is non-empty and every element of K = A - B is R-below some element of J.
   Obtained through the equivalence of multp and multp⇩H⇩O. *)
lemma multp_implies_one_step_strong:
  fixes A B J K :: "_ multiset"
  assumes "transp R" and "asymp R" and "multp R A B"
  defines "J ≡ B - A" and "K ≡ A - B"
  shows "J ≠ {#}" and "∀k ∈# K. ∃x ∈# J. R k x"
proof -
  from assms have "multp⇩H⇩O R A B"
    by (simp add: multp_eq_multp⇩H⇩O)

  thus "J ≠ {#}" and "∀k ∈# K. ∃x ∈# J. R k x"
    using multp⇩H⇩O_implies_one_step_strong[OF ‹multp⇩H⇩O R A B›]
    by (simp_all add: J_def K_def)
qed

(* Cancellation of doubling: for transitive and asymmetric R, multp R (A + A) (B + B)
   implies multp R A B. *)
lemma multp_double_doubleD:
  assumes "transp R" and "asymp R" and "multp R (A + A) (B + B)"
  shows "multp R A B"
proof -
  (* one-step decomposition of the doubled comparison *)
  from assms have
    "B + B - (A + A) ≠ {#}" and
    "∀k∈#A + A - (B + B). ∃x∈#B + B - (A + A). R k x"
    using multp_implies_one_step_strong[OF assms] by simp_all

  (* rebuild a one-step witness for A and B around their common part A ∩# B *)
  have "multp R (A ∩# B + (A - B)) (A ∩# B + (B - A))"
  proof (rule one_step_implies_multp[of "B - A" "A - B" R "A ∩# B"])
    show "B - A ≠ {#}"
      using ‹B + B - (A + A) ≠ {#}›
      by (meson Diff_eq_empty_iff_mset mset_subset_eq_mono_add)
  next
    show "∀k∈#A - B. ∃j∈#B - A. R k j"
    proof (intro ballI)
      fix x assume "x ∈# A - B"
      hence "x ∈# A + A - (B + B)"
        by (simp add: in_diff_count)
      then obtain y where "y ∈# B + B - (A + A)" and "R x y"
        using ‹∀k∈#A + A - (B + B). ∃x∈#B + B - (A + A). R k x› by auto
      then show "∃j∈#B - A. R x j"
        by (auto simp add: in_diff_count)
    qed
  qed

  moreover have "A = A ∩# B + (A - B)"
    by (simp add: inter_mset_def)

  moreover have "B = A ∩# B + (B - A)"
    by (metis diff_intersect_right_idem subset_mset.add_diff_inverse subset_mset.inf.cobounded2)

  ultimately show ?thesis
    by argo
qed

(* For transitive and asymmetric R, doubling both sides neither adds nor removes multp
   comparisons. *)
lemma multp_double_double:
  "transp R ⟹ asymp R ⟹ multp R (A + A) (B + B) ⟷ multp R A B"
  using multp_double_doubleD multp_double_doubleI by metis

(* Instance for doubled singletons: {#x, x#} < {#y, y#} iff R x y. *)
lemma multp_doubleton_doubleton[simp]:
  "transp R ⟹ asymp R ⟹ multp R {#x, x#} {#y, y#} ⟷ R x y"
  using multp_double_double[of R "{#x#}" "{#y#}", simplified] by simp

(* A non-empty multiset is multp-below its own double, for any R (adding elements only). *)
lemma multp_single_doubleI: "M ≠ {#} ⟹ multp R M (M + M)"
  using one_step_implies_multp[of M "{#}" _ M, simplified] by simp

(* Set-based one-step property of mult1. For an asymmetric relation r, (A, B) ∈ mult1 r
   implies that B - A is non-empty and every element of A - B is r-below some element of
   B - A.
   NOTE(review): the proof uses only assms(2) (asym r); "trans r" appears unused. *)
lemma mult1_implies_one_step_strong:
  assumes "trans r" and "asym r" and "(A, B) ∈ mult1 r"
  shows "B - A ≠ {#}" and "∀k ∈# A - B. ∃j ∈# B - A. (k, j) ∈ r"
proof -
  (* B = B' + {#b#} and A = B' + A', where every element of A' is r-below b *)
  from ‹(A, B) ∈ mult1 r› obtain b B' A' where
    B_def: "B = add_mset b B'" and
    A_def: "A = B' + A'" and
    "⋀a. a ∈# A' ⟹ (a, b) ∈ r"
    unfolding mult1_def by auto

  (* b ∈# A' would give (b, b) ∈ r, contradicting asymmetry *)
  have "b ∉# A'"
    by (meson ‹⋀a. a ∈# A' ⟹ (a, b) ∈ r› assms(2) asym_onD iso_tuple_UNIV_I)
  then have "b ∈# B - A"
    by (simp add: A_def B_def)
  thus "B - A ≠ {#}"
    by auto

  show "∀k ∈# A - B. ∃j ∈# B - A. (k, j) ∈ r"
    by (metis A_def B_def ‹⋀a. a ∈# A' ⟹ (a, b) ∈ r› ‹b ∈# B - A› ‹b ∉# A'› add_diff_cancel_left'
        add_mset_add_single diff_diff_add_mset diff_single_trivial)
qed

(* multp R is asymmetric when R is asymmetric and transitive; transferred from
   asymp_multp⇩H⇩O via multp_eq_multp⇩H⇩O. *)
lemma asymp_multp:
  assumes "asymp R" and "transp R"
  shows "asymp (multp R)"
  using asymp_multp⇩H⇩O[OF assms]
  unfolding multp_eq_multp⇩H⇩O[OF assms].

(* For transitive R, {#x, x#} is multp-below {#y#} iff R x y. The proof splits on whether
   x = y. *)
lemma multp_doubleton_singleton: "transp R ⟹ multp R {# x, x #} {# y #} ⟷ R x y"
  by (cases "x = y") auto

(* For injective f, removing one copy of f a from the image of X equals the image of X with
   one copy of a removed. Specialises image_mset_remove1_mset_if by rewriting membership in
   the image with inj_image_mem_iff. *)
lemma image_mset_remove1_mset:
  assumes "inj f"
  shows "remove1_mset (f a) (image_mset f X) = image_mset f (remove1_mset a X)"
  using image_mset_remove1_mset_if
  unfolding image_mset_remove1_mset_if inj_image_mem_iff[OF assms, symmetric]
  by simp

(* Mapping preserves multp⇩D⇩M. If f is monotone from R to S on the elements of M1 + M2 and
   multp⇩D⇩M R M1 M2, then multp⇩D⇩M S (image_mset f M1) (image_mset f M2). The Dershowitz–Manna
   witness (Y removed, X added) is mapped through f. *)
lemma multpDM_map_strong:
  assumes
    f_mono: "monotone_on (set_mset (M1 + M2)) R S f" and
    M1_lt_M2: "multp⇩D⇩M R M1 M2"
  shows "multp⇩D⇩M S (image_mset f M1) (image_mset f M2)"
proof -
  obtain Y X where
    "Y ≠ {#}" and "Y ⊆# M2" and M1_eq: "M1 = M2 - Y + X" and
    ex_y: "∀x. x ∈# X ⟶ (∃y. y ∈# Y ∧ R x y)"
    using M1_lt_M2[unfolded multp⇩D⇩M_def Let_def mset_map] by blast


  let ?fY = "image_mset f Y"
  let ?fX = "image_mset f X"

  show ?thesis
    unfolding multp⇩D⇩M_def
  proof (intro exI conjI)
    show "image_mset f Y ≠ {#}"
      using ‹Y ≠ {#}› unfolding image_mset_is_empty_iff .
  next
    show "image_mset f Y ⊆# image_mset f M2"
      using ‹Y ⊆# M2› image_mset_subseteq_mono by metis
  next
    show "image_mset f M1 = image_mset f M2 - ?fY + ?fX"
      using M1_eq[THEN arg_cong, of "image_mset f"] ‹Y ⊆# M2›
      by (metis image_mset_Diff image_mset_union)
  next
    (* choose, for every added element x, a removed element g x that is R-above it *)
    obtain g where y: "∀x. x ∈# X ⟶ g x ∈# Y ∧ R x (g x)"
      using ex_y by moura

    show "∀fx. fx ∈# ?fX ⟶ (∃fy. fy ∈# ?fY ∧ S fx fy)"
    proof (intro allI impI)
      fix x' assume "x' ∈# ?fX"
      then obtain x where x': "x' = f x" and x_in: "x ∈# X"
        by auto
      hence y_in: "g x ∈# Y" and y_gt: "R x (g x)"
        using y[rule_format, OF x_in] by blast+

      moreover have "X ⊆# M1"
        using M1_eq by simp

      (* x and g x both lie in M1 + M2, so monotonicity of f applies *)
      ultimately have "f (g x) ∈# ?fY ∧ S (f x)(f (g x)) "
        using f_mono[THEN monotone_onD, of x "g x"] ‹Y ⊆# M2› ‹X ⊆# M1› x_in
        by (metis imageI in_image_mset mset_subset_eqD union_iff)
      thus "∃fy. fy ∈# ?fY ∧ S x' fy"
        unfolding x' by auto
    qed
  qed
qed

(* Mapping preserves multp. For transitive R, if f is monotone from R to S on the elements
   of M1 + M2 and multp R M1 M2, then multp S (image_mset f M1) (image_mset f M2).
   Obtained from monotone_on_multp_multp_image_mset. *)
lemma multp_map_strong:
  assumes
    transp: "transp R" and
    f_mono: "monotone_on (set_mset (M1 + M2)) R S f" and
    M1_lt_M2: "multp R M1 M2"
  shows "multp S (image_mset f M1) (image_mset f M2)"
  using monotone_on_multp_multp_image_mset[THEN monotone_onD, OF f_mono transp _ _ M1_lt_M2]
  by simp

(* Compatibility of multp⇩H⇩O with insertion. For asymmetric and transitive R, adding x on
   the left and an R-greater y on the right preserves multp⇩H⇩O R X Y. *)
lemma multpHO_add_mset:
  assumes "asymp R" "transp R" "R x y" "multp⇩H⇩O R X Y"
  shows "multp⇩H⇩O R (add_mset x X) (add_mset y Y)"
  unfolding multp⇩H⇩O_def
proof(intro allI conjI impI)
  show "add_mset x X ≠ add_mset y Y"
    using assms(1, 3, 4)
    unfolding multp⇩H⇩O_def
    by (metis asympD count_add_mset lessI less_not_refl)
next
  fix x'
  assume count_x': "count (add_mset y Y) x' < count (add_mset x X) x'"
  show "∃y'. R x' y' ∧ count (add_mset x X) y' < count (add_mset y Y) y'"
  proof(cases "x' = x")
    case True
    then show ?thesis
      using assms
      unfolding multp⇩H⇩O_def
      by (metis count_add_mset irreflpD irreflp_on_if_asymp_on not_less_eq transpE)
  next
    case x'_neq_x: False
    show ?thesis
    proof(cases "y = x'")
      case True
      then show ?thesis
        using assms(1, 3, 4) count_x' x'_neq_x
        unfolding multp⇩H⇩O_def count_add_mset
        by (smt (verit) Suc_lessD asympD)
    next
      case False
      then show ?thesis
        using assms count_x' x'_neq_x
        unfolding multp⇩H⇩O_def count_add_mset
        by (smt (verit, del_insts) irreflpD irreflp_on_if_asymp_on not_less_eq transpE)
    qed
  qed
qed

(* TODO: Better names? *)
(* multp version of multpHO_add_mset, transferred via multp_eq_multp⇩H⇩O. For asymmetric and
   transitive R with R x y, multp R X Y implies
   multp R (add_mset x X) (add_mset y Y). *)
lemma multp_add_mset:
  assumes "asymp R" "transp R" "R x y" "multp R X Y"
  shows "multp R (add_mset x X) (add_mset y Y)"
  using multpHO_add_mset[OF assms(1-3)] assms(4)
  unfolding multp_eq_multp⇩H⇩O[OF assms(1, 2)]
  by simp

(* Replacing an inserted element x by an R-greater y, with the rest X shared on both sides,
   gives a multp-greater multiset. Unlike multp_add_mset, no asymp/transp assumption on R
   is needed. *)
lemma multp_add_mset':
  assumes "R x y"
  shows "multp R (add_mset x X) (add_mset y X)"
  using assms
  by (metis add_mset_add_single empty_iff insert_iff one_step_implies_multp set_mset_add_mset_insert
        set_mset_empty)

(* Variant of multp_add_mset in which X and Y only need to be related by the reflexive
   closure of multp R (equal, or multp-below). *)
lemma multp_add_mset_reflclp:
  assumes "asymp R" "transp R" "R x y" "(multp R)⇧=⇧= X Y"
  shows "multp R (add_mset x X) (add_mset y Y)"
  using
    assms(4)
    multp_add_mset'[of R, OF assms(3)]
    multp_add_mset[OF assms(1-3)]
  by blast

(* For asymmetric and transitive R, a common inserted element x cancels from a multp
   comparison. Registered as a simp rule. *)
lemma multp_add_same [simp]:
  assumes "asymp R" "transp R"
  shows "multp R (add_mset x X) (add_mset x Y) ⟷ multp R X Y"
  by (meson assms asymp_on_subset irreflp_on_if_asymp_on multp_cancel_add_mset top_greatest)

(* Doubling a multiset, X ↦ X + X, is injective. *)
lemma inj_mset_plus_same: "inj (λX :: 'a multiset . X + X)"
proof(unfold inj_def, intro allI impI)
  fix X Y :: "'a multiset"
  assume "X + X = Y + Y"

  (* induction on X, generalising over Y *)
  then show "X = Y"
  proof(induction X arbitrary: Y)
    case empty
    then show ?case
      by simp
  next
    case (add x X)
    then show ?case
      by (metis diff_single_eq_union diff_union_single_conv single_subset_iff
          subset_mset.add_diff_assoc2 union_iff union_single_eq_member)
  qed
qed

(* TODO: This should be provable:
lemma
  assumes "wfP (multp⇩D⇩M R)" (* "asymp (multp⇩D⇩M R)" "transp (multp⇩D⇩M R)" ? *)
  shows "wfP R"
  using assms
  sorry
*)

(* TODO: everywhere less_eq → lesseq *)
(* For asymmetric and transitive R: if f x is R-below-or-equal to g x for every x in X,
   then the f-image of X is below-or-equal to the g-image in the reflexive closure of
   multp R. Proved by induction on X. *)
lemma multp_image_lesseq_if_all_lesseq:
  assumes
    asymp: "asymp R" and
    transp: "transp R" and
    all_lesseq: "∀x∈#X. R⇧=⇧= (f x) (g x)"
  shows "(multp R)⇧=⇧= (image_mset f X) (image_mset g X)"
  using assms
  by(induction X) (auto simp: multp_add_mset multp_add_mset')

(* Strict version of multp_image_lesseq_if_all_lesseq. For asymmetric and transitive R, if
   f x is R-below-or-equal to g x for all x in X, and strictly R-below for at least one x,
   then the f-image of X is multp-below the g-image. *)
lemma multp_image_less_if_all_lesseq_ex_less:
  assumes
    asymp: "asymp R" and
    transp: "transp R" and
    all_less_eq: "∀x∈#X. R⇧=⇧= (f x) (g x)" and
    ex_less: "∃x∈#X. R (f x) (g x)"
  shows "multp R {# f x. x ∈# X #} {# g x. x ∈# X #}"
  using all_less_eq ex_less
proof(induction X)
  case empty
  then show ?case
    by simp
next
  case (add x X)

  show ?case
  proof(cases "∃x∈#X. R (f x) (g x)")
    case True

    (* the strict witness lies in X: use the induction hypothesis, then add x *)
    then have "∀x∈#X. R⇧=⇧= (f x) (g x)" "∃x∈#X. R (f x) (g x)"
      using add.prems
      by auto

    then have "multp R (image_mset f X) (image_mset g X)"
      using add.IH
      by blast

    then show ?thesis
      using add.prems(1) multp_add_mset[OF asymp transp] multp_add_same[OF asymp transp]
      by auto
  next
    case False

    (* the only strict witness is the new element x; f and g agree on X *)
    then have "R (f x) (g x)"
      using add.prems(2) by fastforce

    moreover have "∀x∈#X. f x = g x"
      using False add.prems(1) by auto

    ultimately show ?thesis
      by (metis image_mset_add_mset multiset.map_cong0 multp_add_mset')
  qed
qed

(* multp⇩D⇩M R is not reflexive. *)
lemma not_reflp_multpDM: "¬ reflp (multp⇩D⇩M R)"
  unfolding multp⇩D⇩M_def reflp_def
  by force

(* No multiset is multp⇩D⇩M-below the empty multiset. *)
lemma not_less_empty_multpDM: "¬ multp⇩D⇩M R X {#}"
  by (simp add: multp⇩D⇩M_def)

(* multp⇩H⇩O R is not reflexive. *)
lemma not_reflp_multpHO: "¬ reflp (multp⇩H⇩O R)"
  unfolding multp⇩H⇩O_def reflp_def
  by simp

(* No multiset is multp⇩H⇩O-below the empty multiset. *)
lemma not_less_empty_multpHO: "¬ multp⇩H⇩O R X {#}"
  by (simp add: multp⇩H⇩O_def)

(* The set-based multiset extension mult R is not reflexive. *)
lemma not_refl_mult: "¬ refl (mult R)"
  unfolding refl_on_def mult_def
  by (meson UNIV_I not_less_empty trancl.cases)

(* Nothing is mult R-below the empty multiset. *)
lemma not_less_empty_mult: "(X, {#}) ∉ mult R"
  by (metis mult_def not_less_empty tranclD2)

(* The empty multiset is mult R-below every non-empty multiset. *)
lemma empty_less_mult: "X ≠ {#} ⟹ ({#}, X) ∈ mult R"
  using subset_implies_mult
  by force

(* multp R is not reflexive; transferred from not_refl_mult. *)
lemma not_reflp_multp: "¬ reflp (multp R)"
  using not_refl_mult
  unfolding multp_def reflp_refl_eq
  by blast

(* The empty multiset is multp R-below every non-empty multiset. *)
lemma empty_less_multp: "X ≠ {#} ⟹ multp R {#} X"
  by (simp add: subset_implies_multp subset_mset.not_eq_extremum)

(* Nothing is multp R-below the empty multiset; transferred from not_less_empty_mult. *)
lemma not_less_empty_multp: "¬ multp R X {#}"
  using not_less_empty_mult
  unfolding multp_def
  by blast

end