(* Theory Refine_Imperative_HOL.Sepref_Frame *)
section ‹Frame Inference›
theory Sepref_Frame
imports Sepref_Basic Sepref_Constraints
begin
text ‹ In this theory, we provide two specific frame inference tactics
for Sepref.
The first tactic, ‹frame_tac›, is a standard frame inference tactic,
based on the assumption that only @{const hn_ctxt}-assertions need to be
matched.
The second tactic, ‹merge_tac›, resolves entailments of the form
‹F1 ∨⇩A F2 ⟹⇩t ?F›
that occur during translation of if and case statements.
It synthesizes a new frame ?F, where refinements of variables
with equal refinements in ‹F1› and ‹F2› are preserved,
and the others are set to @{const hn_invalid}.
›
(* Combines two refinement assertions R1 and R2 into one that holds for x and y
   whenever R1 x y or R2 x y holds (assertion-level disjunction). *)
definition mismatch_assn :: "('a ⇒ 'c ⇒ assn) ⇒ ('a ⇒ 'c ⇒ assn) ⇒ 'a ⇒ 'c ⇒ assn"
where "mismatch_assn R1 R2 x y ≡ R1 x y ∨⇩A R2 x y"
(* Shorthand for the mismatch assertion of R1 and R2 wrapped into hn_ctxt. *)
abbreviation "hn_mismatch R1 R2 ≡ hn_ctxt (mismatch_assn R1 R2)"
(* If R is pure (side condition CONSTRAINT is_pure R), an invalidated context
   element hn_invalid R x y entails the valid element hn_ctxt R x y again. *)
lemma recover_pure_aux: "CONSTRAINT is_pure R ⟹ hn_invalid R x y ⟹⇩t hn_ctxt R x y"
by (auto simp: is_pure_conv invalid_pure_recover hn_ctxt_def)
(* Base rules used by frame_step_tac for matching frame goals:
   1. reflexivity of the entailment;
   2. monotonicity of separating conjunction (match component-wise);
   3. a valid context element may be weakened to its invalidated form;
   4. a valid context element may be weakened to the trivial relation (λ_ _. true);
   5. an invalidated element of a pure relation may be recovered
      (leaves the side condition CONSTRAINT is_pure R). *)
lemma frame_thms:
"P ⟹⇩t P"
"P⟹⇩tP' ⟹ F⟹⇩tF' ⟹ F*P ⟹⇩t F'*P'"
"hn_ctxt R x y ⟹⇩t hn_invalid R x y"
"hn_ctxt R x y ⟹⇩t hn_ctxt (λ_ _. true) x y"
"CONSTRAINT is_pure R ⟹ hn_invalid R x y ⟹⇩t hn_ctxt R x y"
apply -
applyS simp
applyS (rule entt_star_mono; assumption)
subgoal
apply (simp add: hn_ctxt_def)
apply (rule enttI)
apply (rule ent_trans[OF invalidate[of R]])
by solve_entails
applyS (sep_auto simp: hn_ctxt_def)
applyS (erule recover_pure_aux)
done
(* User-extensible collection of frame rules; frame_step_tac appends its
   contents to frame_thms when matching frame goals. *)
named_theorems_rev sepref_frame_match_rules ‹Sepref: Additional frame rules›
text ‹Rules to discharge the unmatched remainder of a frame goal›
(* frame_rem1: the remainder is discharged by reflexivity. *)
lemma frame_rem1: "P⟹⇩tP" by simp
(* frame_rem2: keep the same context element on both sides and continue with the rest. *)
lemma frame_rem2: "F ⟹⇩t F' ⟹ F * hn_ctxt A x y ⟹⇩t F' * hn_ctxt A x y"
apply (rule entt_star_mono) by auto
(* frame_rem3: drop a context element from the left-hand side. *)
lemma frame_rem3: "F ⟹⇩t F' ⟹ F * hn_ctxt A x y ⟹⇩t F'"
using frame_thms(2) by fastforce
(* frame_rem4: w.r.t. ⟹⇩t, every assertion entails emp. *)
lemma frame_rem4: "P ⟹⇩t emp" by simp
(* Bundle of the remainder rules; frame_tac uses it in solve_remainder_tac. *)
lemmas frame_rem_thms = frame_rem1 frame_rem2 frame_rem3 frame_rem4
(* User-extensible collection of remainder rules; frame_tac appends its
   contents to frame_rem_thms. *)
named_theorems_rev sepref_frame_rem_rules
‹Sepref: Additional rules to resolve remainder of frame-pairing›
(* Component-wise merging for plain entailment: if the disjunction of the left
   factors entails E and the disjunction of the right factors entails F, then
   the disjunction of the two products entails E*F. *)
lemma ent_disj_star_mono:
"⟦ A ∨⇩A C ⟹⇩A E; B ∨⇩A D ⟹⇩A F ⟧ ⟹ A*B ∨⇩A C*D ⟹⇩A E*F"
by (metis ent_disjI1 ent_disjI2 ent_disjE ent_star_mono)
(* Same component-wise merging rule as ent_disj_star_mono, but for the
   entailment ⟹⇩t. The proof goes through ⟹⇩A with explicit true factors
   (enttD / enttI) and then merges the true factors. *)
lemma entt_disj_star_mono:
"⟦ A ∨⇩A C ⟹⇩t E; B ∨⇩A D ⟹⇩t F ⟧ ⟹ A*B ∨⇩A C*D ⟹⇩t E*F"
proof -
assume a1: "A ∨⇩A C ⟹⇩t E"
assume "B ∨⇩A D ⟹⇩t F"
then have "A * B ∨⇩A C * D ⟹⇩A true * E * (true * F)"
using a1 by (simp add: ent_disj_star_mono enttD)
then show ?thesis
by (metis (no_types) assn_times_comm enttI merge_true_star_ctx star_aci(3))
qed
(* Structural merge rules:
   1. merging two identical frames yields that frame;
   2. a merge of two products is solved by merging the last context elements
      (for the same x, x') and, separately, the remaining frames Fl and Fr. *)
lemma hn_merge1:
"F ∨⇩A F ⟹⇩t F"
"⟦ hn_ctxt R1 x x' ∨⇩A hn_ctxt R2 x x' ⟹⇩t hn_ctxt R x x'; Fl ∨⇩A Fr ⟹⇩t F ⟧
⟹ Fl * hn_ctxt R1 x x' ∨⇩A Fr * hn_ctxt R2 x x' ⟹⇩t F * hn_ctxt R x x'"
apply simp
by (rule entt_disj_star_mono; simp)
(* Merging a valid and an invalidated context element of the same relation R
   (in either order) yields the invalidated element. *)
lemma hn_merge2:
"hn_invalid R x x' ∨⇩A hn_ctxt R x x' ⟹⇩t hn_invalid R x x'"
"hn_ctxt R x x' ∨⇩A hn_invalid R x x' ⟹⇩t hn_invalid R x x'"
by (sep_auto eintros: invalidate ent_disjE intro!: ent_imp_entt simp: hn_ctxt_def)+
(* Invalidation is monotone: an entailment between valid context elements
   carries over to their invalidated forms. *)
lemma invalid_assn_mono: "hn_ctxt A x y ⟹⇩t hn_ctxt B x y
⟹ hn_invalid A x y ⟹⇩t hn_invalid B x y"
by (clarsimp simp: invalid_assn_def entailst_def entails_def hn_ctxt_def)
(force simp: mod_star_conv)
(* Merging an invalidated element with a valid one of a possibly different
   relation: if the valid forms merge to Rm, the result is hn_invalid Rm.
   The NO_MATCH premise presumably keeps the rule from firing when the other
   relation is itself of the form hn_invalid _ (to be confirmed).
   Note: this lemma is not part of merge_thms. *)
lemma hn_merge3:
"⟦NO_MATCH (hn_invalid XX) R2; hn_ctxt R1 x x' ∨⇩A hn_ctxt R2 x x' ⟹⇩t hn_ctxt Rm x x'⟧ ⟹ hn_invalid R1 x x' ∨⇩A hn_ctxt R2 x x' ⟹⇩t hn_invalid Rm x x'"
"⟦NO_MATCH (hn_invalid XX) R1; hn_ctxt R1 x x' ∨⇩A hn_ctxt R2 x x' ⟹⇩t hn_ctxt Rm x x'⟧ ⟹ hn_ctxt R1 x x' ∨⇩A hn_invalid R2 x x' ⟹⇩t hn_invalid Rm x x'"
apply (meson entt_disjD1 entt_disjD2 entt_disjE entt_trans frame_thms(3) invalid_assn_mono)
apply (meson entt_disjD1 entt_disjD2 entt_disjE entt_trans frame_thms(3) invalid_assn_mono)
done
(* Built-in merge rules used by frame_step_tac on merge goals. *)
lemmas merge_thms = hn_merge1 hn_merge2
(* User-extensible collection of merge rules; frame_step_tac appends its
   contents to merge_thms. *)
named_theorems sepref_frame_merge_rules ‹Sepref: Additional merge rules›
(* Two context elements with arbitrary relations R1 and R2 can always be
   merged into the mismatch element hn_mismatch R1 R2. *)
lemma hn_merge_mismatch: "hn_ctxt R1 x x' ∨⇩A hn_ctxt R2 x x' ⟹⇩t hn_mismatch R1 R2 x x'"
by (sep_auto simp: hn_ctxt_def mismatch_assn_def)
(* Identity rule on merge goals: premise and conclusion coincide.
   Presumably used to test whether a goal has the shape of a merge
   (to be confirmed against the callers). *)
lemma is_merge: "P1∨⇩AP2⟹⇩tP ⟹ P1∨⇩AP2⟹⇩tP" .
(* Both sides of a merge may first be weakened (A to A', B to B') before the
   merge of the weakened sides is solved. *)
lemma merge_mono: "⟦A⟹⇩tA'; B⟹⇩tB'; A'∨⇩AB' ⟹⇩t C⟧ ⟹ A∨⇩AB ⟹⇩t C"
by (meson entt_disjE entt_disjI1_direct entt_disjI2_direct entt_trans)
text ‹Apply forward rule on left or right side of merge›
(* gen_merge_cons1: weaken only the left side of a merge (A to A'). *)
lemma gen_merge_cons1: "⟦A⟹⇩tA'; A'∨⇩AB ⟹⇩t C⟧ ⟹ A∨⇩AB ⟹⇩t C"
by (meson merge_mono entt_refl)
(* gen_merge_cons2: weaken only the right side of a merge (B to B'). *)
lemma gen_merge_cons2: "⟦B⟹⇩tB'; A∨⇩AB' ⟹⇩t C⟧ ⟹ A∨⇩AB ⟹⇩t C"
by (meson merge_mono entt_refl)
(* Both one-sided weakening rules as a bundle. *)
lemmas gen_merge_cons = gen_merge_cons1 gen_merge_cons2
text ‹These rules are applied to recover pure values that have been destroyed by rule application›
(* Tag constant: RECOVER_PURE P Q is just the entailment P ⟹⇩t Q. The separate
   constant gives the recover_pure rules a distinct goal shape to match on. *)
definition "RECOVER_PURE P Q ≡ P ⟹⇩t Q"
(* Syntax-directed rules for RECOVER_PURE goals:
   1. emp is mapped to emp;
   2. products are handled component-wise (the subgoal for the second factor
      is listed first);
   3. an invalidated element of a pure relation is recovered
      (side condition CONSTRAINT is_pure R);
   4. a valid element is left unchanged. *)
lemma recover_pure:
"RECOVER_PURE emp emp"
"⟦RECOVER_PURE P2 Q2; RECOVER_PURE P1 Q1⟧ ⟹ RECOVER_PURE (P1*P2) (Q1*Q2)"
"CONSTRAINT is_pure R ⟹ RECOVER_PURE (hn_invalid R x y) (hn_ctxt R x y)"
"RECOVER_PURE (hn_ctxt R x y) (hn_ctxt R x y)"
unfolding RECOVER_PURE_def
subgoal by sep_auto
subgoal by (drule (1) entt_star_mono)
subgoal by (rule recover_pure_aux)
subgoal by sep_auto
done
(* Trivial solution of a RECOVER_PURE goal: leave the assertion unchanged.
   recover_pure_tac uses this when the goal contains no hn_invalid element. *)
lemma recover_pure_triv:
"RECOVER_PURE P P"
unfolding RECOVER_PURE_def by sep_auto
text ‹Weakening the postcondition by converting @{const invalid_assn} to @{term "λ_ _. true"}›
(* WEAKEN_HNR_POST Γ Γ' Γ'' states: if Γ is satisfiable by some heap, then the
   third argument Γ'' entails the second argument Γ'. Note the direction. *)
definition "WEAKEN_HNR_POST Γ Γ' Γ'' ≡ (∃h. h⊨Γ) ⟶ (Γ'' ⟹⇩t Γ')"
(* Postcondition weakening for hn_refine: a refinement with postcondition Γ'
   yields one with postcondition Γ'', provided WEAKEN_HNR_POST Γ Γ'' Γ' holds,
   i.e. Γ' ⟹⇩t Γ'' under the assumption that the precondition Γ is satisfiable.
   The satisfiability assumption is obtained via hn_refine_preI. *)
lemma weaken_hnr_postI:
assumes "WEAKEN_HNR_POST Γ Γ'' Γ'"
assumes "hn_refine Γ c Γ' R a"
shows "hn_refine Γ c Γ'' R a"
apply (rule hn_refine_preI)
apply (rule hn_refine_cons_post)
apply (rule assms)
using assms(1) unfolding WEAKEN_HNR_POST_def by blast
(* Trivial instance: the postcondition is left unchanged. *)
lemma weaken_hnr_post_triv: "WEAKEN_HNR_POST Γ P P"
unfolding WEAKEN_HNR_POST_def
by sep_auto
(* Syntax-directed rules for WEAKEN_HNR_POST goals:
   1. products of precondition and postconditions are handled component-wise;
   2. a valid context element is kept as it is;
   3. for a precondition element hn_ctxt R x y, the postcondition element
      hn_invalid R x y is paired with hn_ctxt (λ_ _. true) x y. *)
lemma weaken_hnr_post:
"⟦WEAKEN_HNR_POST Γ P P'; WEAKEN_HNR_POST Γ' Q Q'⟧ ⟹ WEAKEN_HNR_POST (Γ*Γ') (P*Q) (P'*Q')"
"WEAKEN_HNR_POST (hn_ctxt R x y) (hn_ctxt R x y) (hn_ctxt R x y)"
"WEAKEN_HNR_POST (hn_ctxt R x y) (hn_invalid R x y) (hn_ctxt (λ_ _. true) x y)"
proof (goal_cases)
case 1 thus ?case
unfolding WEAKEN_HNR_POST_def
apply clarsimp
apply (rule entt_star_mono)
by (auto simp: mod_star_conv)
next
case 2 thus ?case by (rule weaken_hnr_post_triv)
next
case 3 thus ?case
unfolding WEAKEN_HNR_POST_def
by (sep_auto simp: invalid_assn_def hn_ctxt_def)
qed
(* Two ⟹⇩t entailments are interchangeable (as a meta-equation) if their
   left-hand sides and their right-hand sides agree after multiplying with
   true. prepare_fi_conv uses this to justify permuting the star-components
   of a frame goal. *)
lemma reorder_enttI:
assumes "A*true = C*true"
assumes "B*true = D*true"
shows "(A⟹⇩tB) ≡ (C⟹⇩tD)"
apply (intro eq_reflection)
unfolding entt_def_true
by (simp add: assms)
(* merge_sat1: if A and A' merge to Am, then A and Am also merge to Am. *)
lemma merge_sat1: "(A∨⇩AA' ⟹⇩t Am) ⟹ (A∨⇩AAm ⟹⇩t Am)"
using entt_disjD1 entt_disjE by blast
(* merge_sat2: if A and A' merge to Am, then Am and A' also merge to Am. *)
lemma merge_sat2: "(A∨⇩AA' ⟹⇩t Am) ⟹ (Am∨⇩AA' ⟹⇩t Am)"
using entt_disjD2 entt_disjE by blast
ML ‹
(* Public interface of the Sepref frame-inference and merge machinery. *)
signature SEPREF_FRAME = sig
(* Test whether a proposition is a merge goal (disjunction on the left of the entailment) *)
val is_merge: term -> bool
(* Frame inference and merge tactics; the first argument solves side conditions *)
val frame_tac: (Proof.context -> tactic') -> Proof.context -> tactic'
val merge_tac: (Proof.context -> tactic') -> Proof.context -> tactic'
(* Single step of frame/merge solving; the bool argument is the debug flag *)
val frame_step_tac: (Proof.context -> tactic') -> bool -> Proof.context -> tactic'
(* Normalize a frame goal and pair up matching context elements *)
val prepare_frame_tac : Proof.context -> tactic'
(* Solve RECOVER_PURE goals *)
val recover_pure_tac: Proof.context -> tactic'
(* Alignment and normalization of hn_refine goals and rules *)
val align_goal_tac: Proof.context -> tactic'
val norm_goal_pre_tac: Proof.context -> tactic'
val align_rl_conv: Proof.context -> conv
(* Apply weaken_hnr_postI and solve the resulting WEAKEN_HNR_POST goal *)
val weaken_post_tac: Proof.context -> tactic'
(* Management of the relation-normalization equations (sepref_frame_normrel_eqs) *)
val add_normrel_eq : thm -> Context.generic -> Context.generic
val del_normrel_eq : thm -> Context.generic -> Context.generic
val get_normrel_eqs : Proof.context -> thm list
(* Debug switch (configuration option sepref_debug_frame) *)
val cfg_debug: bool Config.T
(* Theory setup; must be run once after the structure is defined *)
val setup: theory -> theory
end
structure Sepref_Frame : SEPREF_FRAME = struct
(* Boolean configuration option sepref_debug_frame, default false. *)
val cfg_debug =
Attrib.setup_config_bool @{binding sepref_debug_frame} (K false)
(* Debugging helpers from Sepref_Debugging, parameterized with cfg_debug. *)
val DCONVERSION = Sepref_Debugging.DBG_CONVERSION cfg_debug
val dbg_msg_tac = Sepref_Debugging.dbg_msg_tac cfg_debug
(* Named theorem collection sepref_frame_normrel_eqs: equations that are used
   as simp rules to normalize relations before frame matching. *)
structure normrel_eqs = Named_Thms (
val name = @{binding sepref_frame_normrel_eqs}
val description = "Equations to normalize relations for frame matching"
)
(* Exported accessors for the collection. *)
val add_normrel_eq = normrel_eqs.add_thm
val del_normrel_eq = normrel_eqs.del_thm
val get_normrel_eqs = normrel_eqs.get
(* Builds the binary application of the constant entailst to a pair of terms;
   used by prepare_fi_conv to construct the reordered frame goal. *)
val mk_entailst = HOLogic.mk_binrel @{const_name "entailst"}
local
open Sepref_Basic Refine_Util Conv
(* Ordering on assertions used for sorting star-components:
   - terms for which dest_hn_ctxt_opt yields SOME come before all others;
   - two such terms are compared by the middle component of the destructed
     triple, using Term_Ord.fast_term_ord;
   - two terms for which it yields NONE are considered EQUAL. *)
fun assn_ord p = case apply2 dest_hn_ctxt_opt p of
(NONE,NONE) => EQUAL
| (SOME _, NONE) => LESS
| (NONE, SOME _) => GREATER
| (SOME (_,a,_), SOME (_,a',_)) => Term_Ord.fast_term_ord (a,a')
in
(* Conversion that sorts the star-components of an assertion by assn_ord.
   The components are obtained with strip_star, sorted, and reassembled with
   list_star; the equation between the old and the new term is then proved by
   simplification with the star_aci rules. *)
fun reorder_ctxt_conv ctxt ct = let
val cert = Thm.cterm_of ctxt
val new_ct = Thm.term_of ct
|> strip_star
|> sort assn_ord
|> list_star
|> cert
val thm = Goal.prove_internal ctxt [] (mk_cequals (ct,new_ct))
(fn _ => simp_tac
(put_simpset HOL_basic_ss ctxt addsimps @{thms star_aci}) 1)
in
thm
end
(* Conversion that rearranges a frame goal  P ==>t Q  into the form
     (matched part of P) * (unmatched part of P) ==>t (matched part of Q) * (unmatched part of Q)
   where hn_ctxt components of P and Q are paired by the second component
   returned by dest_hn_ctxt (presumably the abstract variable -- confirm).
   Fails with no_conv on terms that are not of the form  P ==>t Q.
   Re-raises Termtab.DUP (after a tracing message) if Q contains two hn_ctxt
   components with the same key. *)
fun prepare_fi_conv ctxt ct = case Thm.term_of ct of
(t as @{mpat "?P ⟹⇩t ?Q"}) => let
(* Components of Q without those satisfying is_true, split into hn_ctxt components and the rest *)
val (Qm, Qum) = strip_star Q |> filter_out is_true |> List.partition is_hn_ctxt
(* Table: key |-> (partner from P, initially NONE; component of Q) *)
val Qtab = (
Qm |> map (fn x => (#2 (dest_hn_ctxt x),(NONE,x)))
|> Termtab.make
) handle
e as (Termtab.DUP _) => (
tracing ("Dup heap: " ^ @{make_string} ct); raise e)
(* Assign each hn_ctxt component of P to the still-unpaired entry of Q with
   the same key; everything else is collected in Pum *)
val (Qtab,Pum) = fold (fn a => fn (Qtab,Pum) =>
case dest_hn_ctxt_opt a of
NONE => (Qtab,a::Pum)
| SOME (_,p,_) => ( case Termtab.lookup Qtab p of
SOME (NONE,tg) => (Termtab.update (p,(SOME a,tg)) Qtab, Pum)
| _ => (Qtab,a::Pum)
)
) (strip_star P) (Qtab,[])
val Pum = filter_out is_true Pum
(* pairs: entries that found a partner; Qum2: components of Q without partner *)
val (pairs,Qum2) = Termtab.dest Qtab |> map #2
|> List.partition (is_some o #1)
|> apfst (map (apfst the))
|> apsnd (map #2)
val P' = mk_star (list_star (map fst pairs), list_star Pum)
val Q' = mk_star (list_star (map snd pairs), list_star (Qum2@Qum))
val new_t = mk_entailst (P', Q')
val goal_t = Logic.mk_equals (t,new_t)
val goal_ctxt = Variable.declare_term goal_t ctxt
val msg_tac = dbg_msg_tac (Sepref_Debugging.msg_allgoals "Solving frame permutation") goal_ctxt 1
(* Prove the equation via reorder_enttI, then solve the resulting
   permutation goals with star_permute_tac *)
val tac =
msg_tac
THEN ALLGOALS (resolve_tac goal_ctxt @{thms reorder_enttI})
THEN star_permute_tac goal_ctxt
val goal_ct = Thm.cterm_of ctxt goal_t
val thm = Goal.prove_internal ctxt [] goal_ct (fn _ => tac)
in
thm
end
| _ => no_conv ct
end
(* True iff the proposition is a merge goal, i.e. an entailment whose
   left-hand side is a disjunction of assertions. *)
fun is_merge t = case t of
    @{mpat "Trueprop (_ ∨⇩A _ ⟹⇩t _)"} => true
  | _ => false
(* True iff the proposition is an entailment goal of any shape. *)
fun is_gen_frame t = case t of
    @{mpat "Trueprop (_ ⟹⇩t _)"} => true
  | _ => false
(* Preprocessing of a frame goal: eta-normalize, simplify with the unit laws
   of multiplication on assertions (mult_1_right / mult_1_left instantiated
   to assn), and then rearrange the conclusion with prepare_fi_conv. *)
fun prepare_frame_tac ctxt = let
open Refine_Util Conv
val frame_ss = put_simpset HOL_basic_ss ctxt addsimps
@{thms mult_1_right[where 'a=assn] mult_1_left[where 'a=assn]}
in
CONVERSION Thm.eta_conversion THEN'
simp_tac frame_ss THEN'
CONVERSION (HOL_concl_conv prepare_fi_conv ctxt)
end
local
(* Applies tac and then treats every new subgoal: subgoals whose conclusion is
   an entailment (is_gen_frame) are accepted as they are; all others are
   handed to side_tac, which must solve them (SOLVED'), unless dbg is set, in
   which case TRY_SOLVED' is used instead. *)
fun wrap_side_tac side_tac dbg tac = tac THEN_ALL_NEW_FWD (
CONCL_COND' is_gen_frame
ORELSE' (if dbg then TRY_SOLVED' else SOLVED') side_tac
)
in
(* One step of frame/merge solving.
   side_tac: solver for side conditions; Sepref_Constraints.constraint_tac is
             tried before it.
   dbg:      debug flag, passed on to wrap_side_tac.
   The goal is first simplified with the sepref_frame_normrel_eqs equations.
   Merge goals are then resolved with merge_thms plus sepref_frame_merge_rules,
   all other goals with frame_thms plus sepref_frame_match_rules. *)
fun frame_step_tac side_tac dbg ctxt = let
open Refine_Util Conv
val side_tac = Sepref_Constraints.constraint_tac ctxt ORELSE' side_tac ctxt
val frame_thms = @{thms frame_thms} @
Named_Theorems_Rev.get ctxt @{named_theorems_rev sepref_frame_match_rules}
val merge_thms = @{thms merge_thms} @
Named_Theorems.get ctxt @{named_theorems sepref_frame_merge_rules}
val ss = put_simpset HOL_basic_ss ctxt addsimps normrel_eqs.get ctxt
fun frame_thm_tac dbg = wrap_side_tac side_tac dbg (resolve_tac ctxt frame_thms)
fun merge_thm_tac dbg = wrap_side_tac side_tac dbg (resolve_tac ctxt merge_thms)
fun thm_tac dbg = CONCL_COND' is_merge THEN_ELSE' (merge_thm_tac dbg, frame_thm_tac dbg)
in
full_simp_tac ss THEN' thm_tac dbg
end
end
(* Exhaustively applies deterministic, non-debug frame steps to the goal and
   to all subgoals arising from it; never fails (wrapped in TRY). *)
fun frame_loop_tac side_tac ctxt =
  TRY o REPEAT_ALL_NEW (DETERM o frame_step_tac side_tac false ctxt)
(* Frame inference tactic.
   The goal is preprocessed with prepare_frame_tac and split with
   ent_star_mono / entt_star_mono. The resulting subgoals are then handled by
   the listed tactics: frame_loop_tac for the paired part and
   solve_remainder_tac (frame_rem_thms plus sepref_frame_rem_rules) for the
   remainder. THEN_ALL_NEW_LIST presumably applies the i-th tactic to the
   i-th new subgoal -- confirm against Refine_Util. *)
fun frame_tac side_tac ctxt = let
open Refine_Util Conv
val frame_rem_thms = @{thms frame_rem_thms}
@ Named_Theorems_Rev.get ctxt @{named_theorems_rev sepref_frame_rem_rules}
val solve_remainder_tac = TRY o REPEAT_ALL_NEW (DETERM o resolve_tac ctxt frame_rem_thms)
in
(prepare_frame_tac ctxt
THEN' resolve_tac ctxt @{thms ent_star_mono entt_star_mono})
THEN_ALL_NEW_LIST [
frame_loop_tac side_tac ctxt,
solve_remainder_tac
]
end
(* Merge tactic for goals of the form  F1 \/A F2 ==>t ?F.
   Steps: eta-normalize; fail unless the conclusion is a merge goal; normalize
   with the star_aci rules; sort the star-components of both operands of the
   disjunction with reorder_ctxt_conv; finally run frame_loop_tac. *)
fun merge_tac side_tac ctxt = let
open Refine_Util Conv
(* Reorder both operands of the binary operator in the first argument of the entailment *)
fun merge_conv ctxt = arg1_conv (binop_conv (reorder_ctxt_conv ctxt))
in
CONVERSION Thm.eta_conversion THEN'
CONCL_COND' is_merge THEN'
simp_tac (put_simpset HOL_basic_ss ctxt addsimps @{thms star_aci}) THEN'
CONVERSION (HOL_concl_conv merge_conv ctxt) THEN'
frame_loop_tac side_tac ctxt
end
(* Theory setup: registers the sepref_frame_normrel_eqs collection. *)
val setup = normrel_eqs.setup
local
  open Sepref_Basic
  (* True for assertions of the form  hn_invalid _ _ _ *)
  fun is_invalid t = case t of
      @{mpat "hn_invalid _ _ _ :: assn"} => true
    | _ => false
  (* True iff the goal is a RECOVER_PURE goal whose first argument has an
     hn_invalid assertion among its star-components *)
  fun contains_invalid t = case t of
      @{mpat "Trueprop (RECOVER_PURE ?Q _)"} => strip_star Q |> exists is_invalid
    | _ => false
in
  (* Solves RECOVER_PURE goals. If the goal mentions an invalidated element,
     the recover_pure rules are applied exhaustively, with constraint_tac as
     fallback for the arising side conditions; otherwise the goal is closed
     with recover_pure_triv. *)
  fun recover_pure_tac ctxt = let
    val step_tac =
      resolve_tac ctxt @{thms recover_pure}
      ORELSE' Sepref_Constraints.constraint_tac ctxt
    val recover_tac = REPEAT_ALL_NEW (DETERM o step_tac)
    val triv_tac = resolve_tac ctxt @{thms recover_pure_triv}
  in
    CONCL_COND' contains_invalid THEN_ELSE' (recover_tac, triv_tac)
  end
end
local
open Sepref_Basic Refine_Util
(* Context element: either an hn_ctxt assertion, split into its three
   arguments, or any other term. *)
datatype cte = Other of term | Hn of term * term * term
(* Classify a term as context element. *)
fun dest_ctxt_elem @{mpat "hn_ctxt ?R ?a ?c"} = Hn (R,a,c)
| dest_ctxt_elem t = Other t
(* Inverse of dest_ctxt_elem: rebuild the term. *)
fun mk_ctxt_elem (Other t) = t
| mk_ctxt_elem (Hn (R,a,c)) = @{mk_term "hn_ctxt ?R ?a ?c"}
(* Does the hn_ctxt element refer to the term x in its second argument
   (compared up to alpha-conversion)? Other elements never match. *)
fun match x (Hn (_,y,_)) = x aconv y
| match _ _ = false
(* Destructs an hn_refine term and splits its precondition into the context
   elements that match the arguments of the abstract function (pre_args) and
   the remaining elements (frame). The first parameter is unused.
   Raises TERM if split_matching cannot find an element for every argument.
   Returns ((frame,pre_args),c,Q,R,a), where c, Q, R, a are the remaining
   components returned by dest_hn_refine. *)
fun dest_with_frame _ t = let
val (P,c,Q,R,a) = dest_hn_refine t
val (_,(_,args)) = dest_hnr_absfun a
val pre_ctes = strip_star P |> map dest_ctxt_elem
val (pre_args,frame) =
(case split_matching match args pre_ctes of
NONE => raise TERM("align_conv: Could not match all arguments",[P,a])
| SOME x => x)
in
((frame,pre_args),c,Q,R,a)
end
(* Computes the aligned form of an hn_refine goal term: the precondition is
   rebuilt as (frame) * (argument elements), everything else is unchanged.
   Raises TERM (from dest_with_frame) if not all arguments can be matched. *)
fun align_goal_conv_aux ctxt t = let
val ((frame,pre_args),c,Q,R,a) = dest_with_frame ctxt t
val P' = apply2 (list_star o map mk_ctxt_elem) (frame,pre_args) |> mk_star
val t' = mk_hn_refine (P',c,Q,R,a)
in t' end
(* Computes the aligned form of an hn_refine rule term: the precondition is
   rebuilt from the argument elements only.
   Raises TERM if the precondition contains elements beyond those matching the
   arguments, or (from dest_with_frame) if not all arguments can be matched. *)
fun align_rl_conv_aux ctxt t = let
val ((frame,pre_args),c,Q,R,a) = dest_with_frame ctxt t
val _ = frame = [] orelse raise TERM ("align_rl_conv: Extra preconditions in rule",[t,list_star (map mk_ctxt_elem frame)])
val P' = list_star (map mk_ctxt_elem pre_args)
val t' = mk_hn_refine (P',c,Q,R,a)
in t' end
(* Conversion that rewrites with the sepref_frame_normrel_eqs equations only
   (on top of HOL_basic_ss). *)
fun normrel_conv ctxt =
  Simplifier.rewrite (put_simpset HOL_basic_ss ctxt addsimps normrel_eqs.get ctxt)
in
(* Conversion from a term to its aligned form (align_goal_conv_aux); the
   equation is established by f_tac_conv with star_permute_tac. *)
fun align_goal_conv ctxt = f_tac_conv ctxt (align_goal_conv_aux ctxt) star_permute_tac
(* Normalizes the relations in the first component (the precondition) of an
   hn_refine conclusion with normrel_conv; all other components are left
   untouched. *)
fun norm_goal_pre_conv ctxt = let
  open Conv
  fun pre_conv ctxt' =
    hn_refine_conv (normrel_conv ctxt') all_conv all_conv all_conv all_conv
in
  HOL_concl_conv pre_conv ctxt
end
(* Tactic version of norm_goal_pre_conv. *)
fun norm_goal_pre_tac ctxt = CONVERSION (norm_goal_pre_conv ctxt)
(* Conversion for hn_refine rules: first aligns the precondition of the
   conclusion (align_rl_conv_aux, justified by star_permute_tac), then
   normalizes relations with normrel_conv in the first, third and fourth
   component of hn_refine; the second and fifth are left unchanged. *)
fun align_rl_conv ctxt = let
open Conv
fun conv ctxt = let
val nr_conv = normrel_conv ctxt
in
hn_refine_conv nr_conv all_conv nr_conv nr_conv all_conv
end
in
HOL_concl_conv (fn ctxt => f_tac_conv ctxt (align_rl_conv_aux ctxt) star_permute_tac) ctxt
then_conv HOL_concl_conv conv ctxt
end
(* Aligns the precondition of an hn_refine goal; fails if the conclusion is
   not an hn_refine goal. Uses the debugging variant of CONVERSION. *)
fun align_goal_tac ctxt =
CONCL_COND' is_hn_refine_concl
THEN' DCONVERSION ctxt (HOL_concl_conv align_goal_conv ctxt)
end
(* Applies weaken_hnr_postI and requires the WEAKEN_HNR_POST side goal to be
   solved completely by exhaustive, deterministic resolution with the
   weaken_hnr_post and weaken_hnr_post_triv rules. Runs under TRADE. *)
fun weaken_post_tac ctxt = let
  fun tac ctxt' = let
    val step_tac =
      DETERM o resolve_tac ctxt' @{thms weaken_hnr_post weaken_hnr_post_triv}
  in
    resolve_tac ctxt' @{thms weaken_hnr_postI}
    THEN' SOLVED' (REPEAT_ALL_NEW step_tac)
  end
in
  TRADE tac ctxt
end
end
›
(* Run the ML-level setup, which registers sepref_frame_normrel_eqs. *)
setup Sepref_Frame.setup
(* Proof method wrapper around Sepref_Frame.weaken_post_tac; takes no arguments. *)
method_setup weaken_hnr_post = ‹Scan.succeed (fn ctxt => SIMPLE_METHOD' (Sepref_Frame.weaken_post_tac ctxt))›
‹Convert "hn_invalid" to "hn_ctxt (λ_ _. true)" in postcondition of hn_refine goal›
(* Eisbach method: applies hn_refine_preI to an hn_refine goal and then, as
   often as possible (also zero times), either applies mod_starD / hn_invalidI
   as destruction rules or eliminates conjunctions and existentials. Presumably
   this exposes the facts about invalidated context elements contained in the
   precondition -- confirm against hn_refine_preI and hn_invalidI.
   NOTE(review): in this copy the header  method <name>  was missing, leaving a
   dangling  = (  directly after the method_setup description, which does not
   parse. The name extract_hnr_invalids is restored from the upstream theory --
   please confirm. *)
method extract_hnr_invalids = (
rule hn_refine_preI,
((drule mod_starD hn_invalidI | elim conjE exE)+)?
)
(* Register the_pure_pure and pure_the_pure as relation-normalization
   equations; they are used as simp rules by frame_step_tac and normrel_conv. *)
lemmas [sepref_frame_normrel_eqs] = the_pure_pure pure_the_pure
end