(* Theory PAC_Checker_Relation *)

(*
  File:         PAC_Checker_Relation.thy
  Author:       Mathias Fleury, Daniela Kaufmann, JKU
  Maintainer:   Mathias Fleury, JKU
*)
theory PAC_Checker_Relation
  imports PAC_Checker WB_Sort "Native_Word.Uint64"
begin

section ‹Various Refinement Relations›

text ‹At the time of writing, it was not possible to share this definition with the one in IsaSAT, so it is restated here.›
(* Refinement relation between a 64-bit machine word and the natural number
   it denotes: the abstract value is nat_of_uint64 of the concrete word, and
   the invariant (λ_. True) accepts every word. *)
definition uint64_nat_rel :: "(uint64 × nat) set" where
  ‹uint64_nat_rel = br nat_of_uint64 (λ_. True)›

(* Separation-logic assertion for uint64_nat_rel. It is pure: it only relates
   the values and owns no heap. *)
abbreviation uint64_nat_assn where
  ‹uint64_nat_assn ≡ pure uint64_nat_rel›

(* uint32 keys are hashed by the identity function. *)
instantiation uint32 :: hashable
begin

definition hashcode_uint32 :: ‹uint32 ⇒ uint32› where
  ‹hashcode_uint32 n = n›

definition def_hashmap_size_uint32 :: ‹uint32 itself ⇒ nat› where
  ‹def_hashmap_size_uint32 = (λ_. 16)›
  ― ‹same as @{typ nat}›

(* The instance proof only needs the definition of the default size. *)
instance
  by standard (simp add: def_hashmap_size_uint32_def)

end

(* uint64 keys are hashed to their low 32 bits. *)
instantiation uint64 :: hashable
begin

(* bit_operations_syntax provides the infix AND used below. *)
context
  includes bit_operations_syntax
begin

(* Mask the word with 2^32 - 1 (keeping its low 32 bits) and convert the
   result, via nat, to a uint32. *)
definition hashcode_uint64 :: ‹uint64 ⇒ uint32› where
  ‹hashcode_uint64 n = (uint32_of_nat (nat_of_uint64 (n AND ((2 :: uint64)^32 - 1))))›

end

definition def_hashmap_size_uint64 :: ‹uint64 itself ⇒ nat› where
  ‹def_hashmap_size_uint64 = (λ_. 16)›
  ― ‹same as @{typ nat}›

(* The instance proof only needs the definition of the default size. *)
instance
  by standard (simp add: def_hashmap_size_uint64_def)

end

(* nat_of_uint64 is injective: two words are equal iff they denote the same
   natural number. *)
lemma word_nat_of_uint64_Rep_inject[simp]:
  ‹nat_of_uint64 ai = nat_of_uint64 bi ⟷ ai = bi›
  by transfer (simp add: word_unat_eq_iff)

(* uint64 is a heap type: the class asks for an injection into nat, and
   nat_of_uint64 is supplied as the witness. *)
instance uint64 :: heap
  by standard (auto simp: inj_def exI[of _ nat_of_uint64])

(* uint64 is a semiring_numeral; the default class-introduction method
   alone discharges the instance. *)
instance uint64 :: semiring_numeral
  by standard

(* Values of nat_of_uint64 on the numerals 0, 2 and 1 (simp rules). *)
lemma nat_of_uint64_012[simp]:
  ‹nat_of_uint64 0 = 0› ‹nat_of_uint64 2 = 2› ‹nat_of_uint64 1 = 1›
  by (simp_all add: nat_of_uint64.rep_eq zero_uint64.rep_eq one_uint64.rep_eq)

(* Identity on nat. It serves as the abstract side of the uint64_of_nat
   conversion rule; being [simp], it disappears in proofs. *)
definition uint64_of_nat_conv where
  [simp]: ‹uint64_of_nat_conv (x :: nat) = x›

(* Truncating a non-negative integer below 2^b to b bits leaves it
   unchanged; follows from take_bit_int_eq_self. *)
lemma less_upper_bintrunc_id: ‹n < 2 ^b ⟹ n ≥ 0 ⟹ take_bit b n = n› for n :: int
  by (rule take_bit_int_eq_self)

(* Converting a nat below 2^64 to uint64 and back yields the original
   number. *)
lemma nat_of_uint64_uint64_of_nat_id: ‹n < 2^64 ⟹ nat_of_uint64 (uint64_of_nat n) = n›
  by transfer (simp add: take_bit_nat_eq_self unsigned_of_nat)

(* Sepref rule: uint64_of_nat implements uint64_of_nat_conv (the identity on
   nat), producing a uint64 related by uint64_nat_assn. The precondition
   requires the argument to be below 2^64; the nat argument is kept (⇧k). *)
lemma [sepref_fr_rules]:
  ‹(return o uint64_of_nat, RETURN o uint64_of_nat_conv)
    ∈ [λa. a < 2 ^64]⇩a nat_assn⇧k → uint64_nat_assn›
  by sepref_to_hoare
   (sep_auto simp: uint64_nat_rel_def br_def nat_of_uint64_uint64_of_nat_id)

(* Relates a String.literal to the character list it denotes
   (String.explode). *)
definition string_rel :: ‹(String.literal × string) set› where
  ‹string_rel = {(x, y). y = String.explode x}›

(* Pure assertion for string_rel: abstract string first, literal second. *)
abbreviation string_assn :: ‹string ⇒ String.literal ⇒ assn› where
  ‹string_assn ≡ pure string_rel›

(* Equality on literals refines equality on their exploded strings. *)
lemma eq_string_eq:
  ‹((=), (=)) ∈ string_rel → string_rel → bool_rel›
  by (auto intro!: frefI simp: string_rel_def String.less_literal_def
    less_than_char_def rel2p_def literal.explode_inject)

(* eq_string_eq turned into a Sepref implementation rule. *)
lemmas eq_string_eq_hnr =
   eq_string_eq[sepref_import_param]

(* Identity relation on strings, lifted elementwise over the list. *)
definition string2_rel :: ‹(string × string) set› where
  ‹string2_rel ≡ ⟨Id⟩list_rel›

(* Pure assertion for string2_rel. *)
abbreviation string2_assn :: ‹string ⇒ string ⇒ assn› where
  ‹string2_assn ≡ pure string2_rel›

(* A monom (a list of strings) is represented by a list of literals. *)
abbreviation monom_rel where
  ‹monom_rel ≡ ⟨string_rel⟩list_rel›

(* Assertion for monoms: elementwise string_assn. *)
abbreviation monom_assn where
  ‹monom_assn ≡ list_assn string_assn›

(* A monomial is a monom paired with an integer. *)
abbreviation monomial_rel where
  ‹monomial_rel ≡ monom_rel ×⇩r int_rel›

(* Assertion for monomials: monom_assn on the first component, int_assn on
   the second. *)
abbreviation monomial_assn where
  ‹monomial_assn ≡ monom_assn ×⇩a int_assn›

(* A polynomial is a list of monomials. *)
abbreviation poly_rel where
  ‹poly_rel ≡ ⟨monomial_rel⟩list_rel›


(* Assertion for polynomials: elementwise monomial_assn. *)
abbreviation poly_assn where
  ‹poly_assn ≡ list_assn monomial_assn›

(* poly_assn is built only from pure component assertions, so it is itself
   the pure assertion of poly_rel. *)
lemma poly_assn_alt_def:
  ‹poly_assn = pure poly_rel›
  by (simp add: list_assn_pure_conv)

(* Finite map from uint64-represented nat keys to polynomials, implemented
   via hm_fmap_assn (presumably a hash map - confirm against its
   definition). *)
abbreviation polys_assn where
  ‹polys_assn ≡ hm_fmap_assn uint64_nat_assn poly_assn›

(* The pure assertion ↑((c, a) ∈ string_rel) coincides with
   string_assn a c. *)
lemma string_rel_string_assn:
  ‹(↑ ((c, a) ∈ string_rel)) = string_assn a c›
  by (auto simp: pure_app_eq)

(* Each literal is related to exactly one string. *)
lemma single_valued_string_rel:
  ‹single_valued string_rel›
  by (auto simp: single_valued_def string_rel_def)

(* Each string is related to at most one literal (explode is injective). *)
lemma IS_LEFT_UNIQUE_string_rel:
  ‹IS_LEFT_UNIQUE string_rel›
  by (auto simp: IS_LEFT_UNIQUE_def single_valued_def string_rel_def
    literal.explode_inject)

(* Right uniqueness of string_rel: each literal determines its string. *)
lemma IS_RIGHT_UNIQUE_string_rel:
  ‹IS_RIGHT_UNIQUE string_rel›
  by (auto simp: single_valued_def string_rel_def
    literal.explode_inject)

(* monom_rel is single-valued, lifted from string_rel through list_rel. *)
lemma single_valued_monom_rel: ‹single_valued monom_rel›
  by (rule list_rel_sv)
    (auto intro!: frefI simp: string_rel_def
    rel2p_def single_valued_def p2rel_def)

(* monomial_rel is single-valued: monom_rel is, and int_rel is the
   identity. *)
lemma single_valued_monomial_rel:
  ‹single_valued monomial_rel›
  using single_valued_monom_rel
  by (auto intro!: frefI simp:
    rel2p_def single_valued_def p2rel_def)

(* Despite its name, this states LEFT uniqueness of monom_rel, i.e.
   single-valuedness of its converse. The name is kept because later proofs
   refer to it. *)
lemma single_valued_monom_rel': ‹IS_LEFT_UNIQUE monom_rel›
  unfolding IS_LEFT_UNIQUE_def inv_list_rel_eq string2_rel_def
  by (rule list_rel_sv)+
   (auto intro!: frefI simp: string_rel_def
    rel2p_def single_valued_def p2rel_def literal.explode_inject)


(* Left uniqueness of monomial_rel, from that of monom_rel (the name mirrors
   single_valued_monom_rel'). *)
lemma single_valued_monomial_rel':
  ‹IS_LEFT_UNIQUE monomial_rel›
  using single_valued_monom_rel'
  unfolding IS_LEFT_UNIQUE_def inv_list_rel_eq
  by (auto intro!: frefI simp:
    rel2p_def single_valued_def p2rel_def)

(* Register single-valuedness and left uniqueness of string_rel so that
   Sepref can discharge the corresponding CONSTRAINT side conditions. *)
lemma [safe_constraint_rules]:
  ‹Sepref_Constraints.CONSTRAINT single_valued string_rel›
  ‹Sepref_Constraints.CONSTRAINT IS_LEFT_UNIQUE string_rel›
  by (auto simp: CONSTRAINT_def single_valued_def
    string_rel_def IS_LEFT_UNIQUE_def literal.explode_inject)

(* Sepref rule: equality of monoms is implemented by equality of their
   literal-list representations; both arguments are kept (⇧k). The proof
   uses that monom_rel is single-valued and left-unique. *)
lemma eq_string_monom_hnr[sepref_fr_rules]:
  ‹(uncurry (return oo (=)), uncurry (RETURN oo (=)))
    ∈ monom_assn⇧k *⇩a monom_assn⇧k →⇩a bool_assn›
  using single_valued_monom_rel' single_valued_monom_rel
  unfolding list_assn_pure_conv
  by sepref_to_hoare
   (sep_auto simp: list_assn_pure_conv string_rel_string_assn
       single_valued_def IS_LEFT_UNIQUE_def
     dest!: mod_starD
     simp flip: inv_list_rel_eq)


(* Curried predicate form of membership in term_order_rel. *)
definition term_order_rel' where
  [simp]: ‹term_order_rel' x y = ((x, y) ∈ term_order_rel)›

(* Pattern rule for Sepref preprocessing: rewrites membership
   (x, y) ∈ term_order_rel into term_order_rel' x y. *)
lemma term_order_rel[def_pat_rules]:
  ‹(∈)$(x,y)$term_order_rel ≡ term_order_rel'$x$y›
  by auto

(* term_order_rel is the lexicographic extension of the character order
   char.lexordp. *)
lemma term_order_rel_alt_def:
  ‹term_order_rel = lexord (p2rel char.lexordp)›
  by (auto simp: p2rel_def char.lexordp_conv_lexord var_order_rel_def intro!: arg_cong[of _ _ lexord])


(* Linear order on characters that reuses PAC_Polynomials_Term.less_char and
   PAC_Polynomials_Term.less_eq_char. The [symmetric, simp] attribute makes
   the simplifier rewrite the PAC_Polynomials_Term constants into the class
   constants. The instance proof reuses char.linorder_axioms. *)
instantiation char :: linorder
begin
  definition less_char where [symmetric, simp]: "less_char = PAC_Polynomials_Term.less_char"
  definition less_eq_char where [symmetric, simp]: "less_eq_char = PAC_Polynomials_Term.less_eq_char"
instance
  apply standard
  using char.linorder_axioms
  by (auto simp: class.linorder_def class.order_def class.preorder_def
       less_eq_char_def less_than_char_def class.order_axioms_def
       class.linorder_axioms_def p2rel_def less_char_def)
end


(* Lexicographic linear order on lists over a linorder: < is lexordp (<)
   and ≤ is lexordp_eq. *)
instantiation list :: (linorder) linorder
begin
  definition less_list where  "less_list = lexordp (<)"
  definition less_eq_list where "less_eq_list = lexordp_eq"

instance
proof standard
  (* The strict lexicographic order and the converse of lexordp_eq are
     incompatible. *)
  have [dest]: ‹⋀x y :: 'a :: linorder list. (x, y) ∈ lexord {(x, y). x < y} ⟹
           lexordp_eq y x ⟹ False›
    by (metis lexordp_antisym lexordp_conv_lexord lexordp_eq_conv_lexord)
  (* lexordp_eq without its converse is the strict lexicographic order. *)
  have [simp]: ‹⋀x y :: 'a :: linorder list. lexordp_eq x y ⟹
           ¬ lexordp_eq y x ⟹
           (x, y) ∈ lexord {(x, y). x < y}›
    using lexordp_conv_lexord lexordp_conv_lexordp_eq by blast
  show
   ‹(x < y) = Restricted_Predicates.strict (≤) x y›
   ‹x ≤ x›
   ‹x ≤ y ⟹ y ≤ z ⟹ x ≤ z›
   ‹x ≤ y ⟹ y ≤ x ⟹ x = y›
   ‹x ≤ y ∨ y ≤ x›
   for x y z :: ‹'a :: linorder list›
    by (auto simp: less_list_def less_eq_list_def List.lexordp_def
    lexordp_conv_lexord lexordp_into_lexordp_eq lexordp_antisym
    antisym_def lexordp_eq_refl lexordp_eq_linear intro: lexordp_eq_trans
    dest: lexordp_eq_antisym)
qed

end


(* term_order_rel' is ord_class.lexordp and coincides with the list order <
   defined above. *)
lemma term_order_rel'_alt_def_lexord:
    ‹term_order_rel' x y = ord_class.lexordp x y› and
  term_order_rel'_alt_def:
    ‹term_order_rel' x y ⟷ x < y›
proof -
  show
    ‹term_order_rel' x y = ord_class.lexordp x y›
    ‹term_order_rel' x y ⟷ x < y›
    unfolding less_than_char_of_char[symmetric, abs_def]
    by (auto simp: lexordp_conv_lexord less_eq_list_def
         less_list_def lexordp_def var_order_rel_def
         rel2p_def term_order_rel_alt_def p2rel_def)
qed

(* The order < on literal lists agrees with < on the related string lists.
   Each direction splits a strict lexicographic comparison into its two
   cases: a proper prefix, or a first difference at some position. It then
   transports that case across ⟨string_rel⟩list_rel. *)
lemma list_rel_list_rel_order_iff:
  assumes ‹(a, b) ∈ ⟨string_rel⟩list_rel› ‹(a', b') ∈ ⟨string_rel⟩list_rel›
  shows ‹a < a' ⟷ b < b'›
proof
  (* H: the related string list is uniquely determined (right uniqueness). *)
  have H: ‹(a, b) ∈ ⟨string_rel⟩list_rel ⟹
       (a, cs) ∈ ⟨string_rel⟩list_rel ⟹ b = cs› for cs
     using single_valued_monom_rel' IS_RIGHT_UNIQUE_string_rel
     unfolding string2_rel_def
     by (subst (asm) list_rel_sv_iff[symmetric])
       (auto simp: single_valued_def)
  assume ‹a < a'›
  (* Either a is a proper prefix of a', or they first differ at aa < aaa. *)
  then consider
    u u' where ‹a' = a @ u # u'› |
    u aa v w aaa where ‹a = u @ aa # v› ‹a' = u @ aaa # w› ‹aa < aaa›
    by (subst (asm) less_list_def)
     (auto simp: lexord_def List.lexordp_def
      list_rel_append1 list_rel_split_right_iff)
  then show ‹b < b'›
  proof cases
    case 1
    then show ‹b < b'›
      using assms
      by (subst less_list_def)
        (auto simp: lexord_def List.lexordp_def
        list_rel_append1 list_rel_split_right_iff dest: H)
  next
    case 2
    (* Split b and b' at the same position as a and a'. *)
    then obtain u' aa' v' w' aaa' where
       ‹b = u' @ aa' # v'› ‹b' = u' @ aaa' # w'›
       ‹(aa, aa') ∈ string_rel›
       ‹(aaa, aaa') ∈ string_rel›
      using assms
      by (smt (verit) list_rel_append1 list_rel_split_right_iff single_valued_def single_valued_monom_rel)
    (* Transfer the order on literals to the exploded strings. *)
    with ‹aa < aaa› have ‹aa' < aaa'›
      by (auto simp: string_rel_def less_literal.rep_eq less_list_def
        lexordp_conv_lexord lexordp_def char.lexordp_conv_lexord
          simp flip: less_char_def PAC_Polynomials_Term.less_char_def)
    then show ‹b < b'›
      using ‹b = u' @ aa' # v'› ‹b' = u' @ aaa' # w'›
      by (subst less_list_def)
        (fastforce simp: lexord_def List.lexordp_def
        list_rel_append1 list_rel_split_right_iff)
  qed
next
  (* H: the related literal list is uniquely determined (left uniqueness). *)
  have H: ‹(a, b) ∈ ⟨string_rel⟩list_rel ⟹
       (a', b) ∈ ⟨string_rel⟩list_rel ⟹ a = a'› for a a' b
     using single_valued_monom_rel'
     by (auto simp: single_valued_def IS_LEFT_UNIQUE_def
       simp flip: inv_list_rel_eq)
  assume ‹b < b'›
  (* Either b is a proper prefix of b', or they first differ at aa < aaa. *)
  then consider
    u u' where ‹b' = b @ u # u'› |
    u aa v w aaa where ‹b = u @ aa # v› ‹b' = u @ aaa # w› ‹aa < aaa›
    by (subst (asm) less_list_def)
     (auto simp: lexord_def List.lexordp_def
      list_rel_append1 list_rel_split_right_iff)
  then show ‹a < a'›
  proof cases
    case 1
    then show ‹a < a'›
      using assms
      by (subst less_list_def)
        (auto simp: lexord_def List.lexordp_def
        list_rel_append2 list_rel_split_left_iff dest: H)
  next
    case 2
    (* Split a and a' at the same position as b and b'. *)
    then obtain u' aa' v' w' aaa' where
       ‹a = u' @ aa' # v'› ‹a' = u' @ aaa' # w'›
       ‹(aa', aa) ∈ string_rel›
       ‹(aaa', aaa) ∈ string_rel›
      using assms
      by (auto simp: lexord_def List.lexordp_def
        list_rel_append2 list_rel_split_left_iff dest: H)
    (* Transfer the order on exploded strings back to the literals. *)
    with ‹aa < aaa› have ‹aa' < aaa'›
      by (auto simp: string_rel_def less_literal.rep_eq less_list_def
        lexordp_conv_lexord lexordp_def char.lexordp_conv_lexord
          simp flip: less_char_def PAC_Polynomials_Term.less_char_def)
    then show ‹a < a'›
      using ‹a = u' @ aa' # v'› ‹a' = u' @ aaa' # w'›
      by (subst less_list_def)
        (fastforce simp: lexord_def List.lexordp_def
        list_rel_append1 list_rel_split_right_iff)
  qed
qed


(* Parametricity of the STRICT order < on string lists, imported as a Sepref
   rule. The "_le" suffix is historical; the name is kept for existing
   references. *)
lemma string_rel_le[sepref_import_param]:
  shows ‹((<), (<)) ∈ ⟨string_rel⟩list_rel → ⟨string_rel⟩list_rel → bool_rel›
  by (auto intro!: fun_relI simp: list_rel_list_rel_order_iff)

(* TODO Move *)
(* remove1 is parametric for relations that are both left- and right-unique:
   remove1 compares elements for equality, so R must preserve and reflect
   equality. Proved by induction on the list relation. *)
lemma [sepref_import_param]:
  assumes ‹CONSTRAINT IS_LEFT_UNIQUE R› ‹CONSTRAINT IS_RIGHT_UNIQUE R›
  shows ‹(remove1, remove1) ∈ R → ⟨R⟩list_rel → ⟨R⟩list_rel›
  apply (intro fun_relI)
  subgoal premises p for x y xs ys
    using p(2) p(1) assms
    by (induction xs ys rule: list_rel_induct)
      (auto simp: IS_LEFT_UNIQUE_def single_valued_def)
  done

(* pac_step is a heap type whenever its three parameters are. An injection
   into nat is built as follows:
   - tag each constructor with 0..3;
   - map its arguments through injections f ('a), h ('b) and i ('c);
   - pad the tuple with 0;
   - encode the resulting 5-tuple with an injection g into nat. *)
instantiation pac_step :: (heap, heap, heap) heap
begin

instance
proof standard
  obtain f :: ‹'a ⇒ nat› where
    f: ‹inj f›
    by blast
  obtain g :: ‹nat × nat × nat × nat × nat ⇒ nat› where
    g: ‹inj g›
    by blast
  obtain h :: ‹'b ⇒ nat› where
    h: ‹inj h›
    by blast
  obtain i :: ‹'c ⇒ nat› where
    i: ‹inj i›
    by blast
  have [iff]: ‹g a = g b ⟷ a = b› ‹h a'' = h b'' ⟷ a'' = b''› ‹f a' = f b' ⟷ a' = b'›
    ‹i a''' = i b''' ⟷ a''' = b'''› for a b a' b' a'' b'' a''' b'''
    using f g h i unfolding inj_def by blast+
  (* The first component (0..3) separates the constructors. *)
  let ?f = ‹λx :: ('a, 'b, 'c) pac_step.
     g (case x of
        Add a b c d ⇒ (0, i a,  i b,  i c, f d)
      | Del a ⇒ (1, i a,    0,   0,   0)
      | Mult a b c d ⇒ (2, i a, f b, i c, f d)
      | Extension a b c ⇒ (3, i a, f c, 0, h b))›
  have ‹inj ?f›
    apply (auto simp: inj_def)
    apply (case_tac x; case_tac y)
    apply auto
    done
  then show ‹∃f :: ('a, 'b, 'c) pac_step ⇒ nat. inj f›
    by blast
qed

end

end