(* Theory Monad_Memo_DP.State_Main *)

subsection ‹Setup for the State Monad›

theory State_Main
  imports
    "../transform/Transform_Cmd"
    Memory
begin

context includes state_monad_syntax begin

(* Monadic analogue of if_cong: ifT applied to a returned (pure) condition
   unfolds, via bind_left_identity, to an ordinary if-then-else, so the usual
   if_cong argument carries over. *)
lemma ifT_cong:
  assumes "b = c" "c ⟹ x = u" "¬c ⟹ y = v"
  shows "State_Monad_Ext.ifT ⟨b⟩ x y = State_Monad_Ext.ifT ⟨c⟩ u v"
  unfolding State_Monad_Ext.ifT_def
  unfolding bind_left_identity
  using if_cong[OF assms] .

(* Applying a returned function to a returned argument is, by
   return_app_return_meta, plain application; hence equal applications
   give equal monadic applications. *)
lemma return_app_return_cong:
  assumes "f x = g y"
  shows "⟨f⟩ . ⟨x⟩ = ⟨g⟩ . ⟨y⟩"
  unfolding State_Monad_Ext.return_app_return_meta assms ..

(* Register both lemmas as congruence rules ([fundef_cong]). *)
lemmas [fundef_cong] =
  return_app_return_cong
  ifT_cong
end

(* Monadified version of function composition, generated from comp_def. *)
memoize_fun compT: comp monadifies (state) comp_def

(* compT is related to comp by crel_vs under the lifted function relations;
   registered as a transfer rule. *)
lemma (in dp_consistency) compT_transfer[transfer_rule]:
  "crel_vs ((R1 ===>⇩T R2) ===>⇩T (R0 ===>⇩T R1) ===>⇩T (R0 ===>⇩T R2)) comp compT"
  apply memoize_combinator_init
  subgoal premises IH [transfer_rule] by memoize_unfold_defs transfer_prover
  done

(* Monadified version of List.map, generated from the list.map equations. *)
memoize_fun mapT: map monadifies (state) list.map

(* mapT is related to map by crel_vs; proved by induction over list_all2,
   with the Nil and Cons cases discharged by transfer_prover. *)
lemma (in dp_consistency) mapT_transfer[transfer_rule]:
  "crel_vs ((R0 ===>⇩T R1) ===>⇩T list_all2 R0 ===>⇩T list_all2 R1) map mapT"
  apply memoize_combinator_init
  apply (erule list_all2_induct)
  (* Nil case *)
  subgoal premises [transfer_rule] by memoize_unfold_defs transfer_prover
  (* Cons case; its premises (incl. the induction hypothesis) are used as transfer rules *)
  subgoal premises [transfer_rule] by memoize_unfold_defs transfer_prover
  done

(* Monadified version of List.fold, generated from fold.simps. *)
memoize_fun foldT: fold monadifies (state) fold.simps

(* foldT is related to fold by crel_vs; proved by induction over list_all2,
   with the Nil and Cons cases discharged by transfer_prover. *)
lemma (in dp_consistency) foldT_transfer[transfer_rule]:
  "crel_vs ((R0 ===>⇩T R1 ===>⇩T R1) ===>⇩T list_all2 R0 ===>⇩T R1 ===>⇩T R1) fold foldT"
  apply memoize_combinator_init
  apply (erule list_all2_induct)
  (* Nil case *)
  subgoal premises [transfer_rule] by memoize_unfold_defs transfer_prover
  (* Cons case; its premises (incl. the induction hypothesis) are used as transfer rules *)
  subgoal premises [transfer_rule] by memoize_unfold_defs transfer_prover
  done

context includes state_monad_syntax begin

(* Monadic analogue of map_cong: mapT applied to returned arguments depends
   only on the values of f on the elements of the list. *)
lemma mapT_cong:
  assumes "xs = ys" "⋀x. x∈set ys ⟹ f x = g x"
  shows "mapT . ⟨f⟩ . ⟨xs⟩ = mapT . ⟨g⟩ . ⟨ys⟩"
  unfolding mapT_def 
  unfolding assms(1)
  using assms(2) by (induction ys) (auto simp: State_Monad_Ext.return_app_return_meta)

(* Monadic analogue of fold_cong: foldT applied to returned arguments depends
   only on the values of f on the elements of the list. *)
lemma foldT_cong:
  assumes "xs = ys" "⋀x. x∈set ys ⟹ f x = g x"
  shows "foldT . ⟨f⟩ . ⟨xs⟩ = foldT . ⟨g⟩ . ⟨ys⟩"
  unfolding foldT_def
  unfolding assms(1)
  using assms(2) by (induction ys) (auto simp: State_Monad_Ext.return_app_return_meta)

(* Congruence for constant functions on unit (used for lazy checkmem). *)
lemma abs_unit_cong:
  assumes "x = y"
  shows "(λ_::unit. x) = (λ_. y)"
  using assms ..

(* Register the congruence rules ([fundef_cong]). *)
lemmas [fundef_cong] =
  return_app_return_cong
  ifT_cong
  mapT_cong
  foldT_cong
  abs_unit_cong
end

context dp_consistency begin
context includes lifting_syntax state_monad_syntax begin

(* Collection of rules of the shape "Rel R <original term> <monadified term>"
   declared below. *)
named_theorems dp_match_rule

(* Relational analogue of if_cong: a pure if relates to ifT with a returned
   condition. *)
lemma ifT_cong2:
  assumes "Rel (=) b c" "c ⟹ Rel (crel_vs R) x xT" "¬c ⟹ Rel (crel_vs R) y yT"
  shows "Rel (crel_vs R) (if (Wrap b) then x else y) (State_Monad_Ext.ifT ⟨c⟩ xT yT)"
  using assms unfolding State_Monad_Ext.ifT_def bind_left_identity Rel_def Wrap_def
  by (auto split: if_split)

(* map relates to mapT when the list arguments are equal (R is an equality)
   and f x relates to fT' x for every element x of the list. *)
lemma mapT_cong2:
  assumes
    "is_equality R"
    "Rel R xs ys"
    "⋀x. x∈set ys ⟹ Rel (crel_vs S) (f x) (fT' x)"
  shows "Rel (crel_vs (list_all2 S)) (App (App map (Wrap f)) (Wrap xs)) (mapT . ⟨fT'⟩ . ⟨ys⟩)"
  unfolding mapT_def
  unfolding State_Monad_Ext.return_app_return_meta
  unfolding assms(2)[unfolded Rel_def assms(1)[unfolded is_equality_def]]
  using assms(3)
  unfolding Rel_def Wrap_def App_def
  apply (induction ys)
  (* Nil case *)
  subgoal premises by (memoize_unfold_defs (state) map) transfer_prover
  (* Cons case: prems(1) is the induction hypothesis, prems(2) the element-wise premise *)
  subgoal premises prems for a ys
    apply (memoize_unfold_defs (state) map)
    apply (unfold State_Monad_Ext.return_app_return_meta Wrap_App_Wrap)
    supply [transfer_rule] =
      prems(2)[OF list.set_intros(1)]
      prems(1)[OF prems(2)[OF list.set_intros(2)], simplified]
    by transfer_prover
  done

(* fold relates to foldT when the list arguments are equal (R is an equality)
   and f x relates to fT' x for every element x of the list. *)
lemma foldT_cong2:
  assumes
    "is_equality R"
    "Rel R xs ys"
    "⋀x. x∈set ys ⟹ Rel (crel_vs (S ===> crel_vs S)) (f x) (fT' x)"
  shows
    "Rel (crel_vs (S ===> crel_vs S)) (fold f xs) (foldT . ⟨fT'⟩ . ⟨ys⟩)"
  unfolding foldT_def
  unfolding State_Monad_Ext.return_app_return_meta
  unfolding assms(2)[unfolded Rel_def assms(1)[unfolded is_equality_def]]
  using assms(3)
  unfolding Rel_def
  apply (induction ys)
  (* Nil case *)
  subgoal premises by (memoize_unfold_defs (state) fold) transfer_prover
  (* Cons case: prems(1) is the induction hypothesis, prems(2) the element-wise premise *)
  subgoal premises prems for a ys
    apply (memoize_unfold_defs (state) fold)
    apply (unfold State_Monad_Ext.return_app_return_meta Wrap_App_Wrap)
    supply [transfer_rule] =
      prems(2)[OF list.set_intros(1)]
      prems(1)[OF prems(2)[OF list.set_intros(2)], simplified]
    by transfer_prover
  done

(* Reflexivity of Rel for an equality relation. *)
lemma refl2:
  "is_equality R ⟹ Rel R x x"
  unfolding is_equality_def Rel_def by simp

(* Pointwise-related functions are related by rel_fun when the argument
   relation is an equality. *)
lemma rel_fun2:
  assumes "is_equality R0" "⋀x. Rel R1 (f x) (g x)"
  shows "Rel (rel_fun R0 R1) f g"
  using assms unfolding is_equality_def Rel_def by auto

(* A wrapped application relates to the monadic application of returned
   values, by return_app_return_meta and Wrap_App_Wrap. *)
lemma crel_vs_return_app_return:
  assumes "Rel R (f x) (g x)"
  shows "Rel R (App (Wrap f) (Wrap x)) (⟨g⟩ . ⟨x⟩)"
  using assms unfolding State_Monad_Ext.return_app_return_meta Wrap_App_Wrap .

(* Relational analogue of option.case_cong. *)
lemma option_case_cong':
  "Rel (=) option' option ⟹
   (option = None ⟹ Rel R f1 g1) ⟹
   (⋀x2. option = Some x2 ⟹ Rel R (f2 x2) (g2 x2)) ⟹
   Rel R (case option' of None ⇒ f1 | Some x2 ⇒ f2 x2)
         (case option of None ⇒ g1 | Some x2 ⇒ g2 x2)"
  unfolding Rel_def by (auto split: option.split)

(* Relational analogue of prod.case_cong. *)
lemma prod_case_cong': fixes prod prod' shows
  "Rel (=) prod prod' ⟹
   (⋀x1 x2. prod' = (x1, x2) ⟹ Rel R (f x1 x2) (g x1 x2)) ⟹
   Rel R (case prod of (x1, x2) ⇒ f x1 x2)
         (case prod' of (x1, x2) ⇒ g x1 x2)"
  unfolding Rel_def by (auto split: prod.splits)

(* Relational analogue of nat.case_cong. *)
lemma nat_case_cong': fixes nat nat' shows
  "Rel (=) nat nat' ⟹
   (nat' = 0 ⟹ Rel R f1 g1) ⟹
   (⋀x2. nat' = Suc x2 ⟹ Rel R (f2 x2) (g2 x2)) ⟹
   Rel R (case nat of 0 ⇒ f1 | Suc x2 ⇒ f2 x2) (case nat' of 0 ⇒ g1 | Suc x2 ⇒ g2 x2)"
  unfolding Rel_def by (auto split: nat.splits)

(* Case-expression congruences. *)
lemmas [dp_match_rule] =
  prod_case_cong'
  option_case_cong'
  nat_case_cong'

(* Application of returned values. *)
lemmas [dp_match_rule] =
  crel_vs_return_app_return

(* Monadic combinators. *)
lemmas [dp_match_rule] =
  mapT_cong2
  foldT_cong2
  ifT_cong2

(* Generic rules: return, function application, reflexivity, rel_fun. *)
lemmas [dp_match_rule] =
  crel_vs_return
  crel_vs_fun_app
  refl2
  rel_fun2

(*
lemmas [dp_match_rule] =
  crel_vs_checkmem_tupled
*)

end (* context includes lifting_syntax state_monad_syntax *)
end (* context dp_consistency *)


subsubsection ‹Code Setup›

(* Code-generator preprocessing: fold checkmem back into checkmem' (note the
   [symmetric]), unfold checkmem', and unfold mapT. *)
declare
  state_mem_defs.checkmem_checkmem'[symmetric, code_unfold]
  state_mem_defs.checkmem'_def[code_unfold]
  mapT_def[code_unfold]

end (* theory *)