File ‹monad_convert.ML›
structure Monad_Convert = struct
(* Put the separator `sep` between every two neighbouring elements of a list.
   Lists with fewer than two elements are returned as they are. *)
fun intersperse sep items =
  (case items of
    x :: (rest as _ :: _) => x :: sep :: intersperse sep rest
  | _ => items)
(* Unwrap an option: the payload of SOME is returned, NONE raises the
   caller-supplied exception `exc`. *)
fun theE opt exc =
  (case opt of
    SOME x => x
  | NONE => raise exc)
(* First element of a list; an empty list raises the caller-supplied
   exception `exc`. Any further elements are ignored. *)
fun oneE xs exc =
  (case xs of
    x :: _ => x
  | [] => raise exc)
(* Apply `tm` to one dummy pattern per leading lambda binder (each dummy has
   the binder's type, and the applications are beta-reduced by
   Term.betapplys), then pass the result through
   Term.replace_dummy_patterns with start index 1 and keep only the term. *)
fun apply_dummies tm =
  let
    val binders = #1 (Term.strip_abs tm);
    val dummy_args = map (fn (_, T) => Term.dummy_pattern T) binders;
    val applied = Term.betapplys (tm, dummy_args);
    val (result, _) = Term.replace_dummy_patterns applied 1;
  in result end;
(* Read the string `nm` as a term pattern in context `ctxt`.

   The string is first resolved to a constant name: either the constant it
   parses to, or the result of interning it. If that constant is an
   abbreviation, the expanded right-hand side is used, with its leading
   binders filled in by apply_dummies. Otherwise the string is read with
   Proof_Context.read_term_pattern. *)
fun parse_pattern ctxt nm =
  let
    val consts = Proof_Context.consts_of ctxt;
    val const_name =
      (case Syntax.parse_term ctxt nm of
        Const (c, _) => c
      | _ => Consts.intern consts nm);
    (* `try` turns "not an abbreviation" into NONE instead of an exception. *)
    val abbrev = try (Consts.the_abbreviation consts) const_name;
  in
    (case abbrev of
      NONE => Proof_Context.read_term_pattern ctxt nm
    | SOME (_, rhs) => apply_dummies (Proof_Context.expand_abbrevs ctxt rhs))
  end;
(* Breadth-first traversal of a term, written in continuation style.

   cont  : called as `cont (vars, term) resume` for every visited term that
           satisfies `pred`; calling `resume ()` continues the traversal with
           the remaining queue, not calling it stops the search. Subterms of
           a matching term are not enqueued.
   pred  : selects the terms to report.
   prune : when true for a non-matching term, its subterms are skipped.

   Abstractions are entered by applying them to a Free whose name is the
   bound name, primed until it differs from all names collected so far;
   `vars` lists those names, innermost first. For an application the
   function part is enqueued before the argument. *)
fun term_search_bf cont pred prune =
  let
    (* Append primes until the name is not among the used ones. *)
    fun fresh_var used name =
      if member (op =) used name then fresh_var used (name ^ "'") else name

    fun visit ((vars, term), pending) =
      if pred term then cont (vars, term) (fn () => drain pending)
      else if prune term then drain pending
      else
        (case term of
          f $ x =>
            let
              val with_fun = Queue.enqueue (vars, f) pending
              val with_arg = Queue.enqueue (vars, x) with_fun
            in drain with_arg end
        | Abs (v, typ, _) =>
            let
              val v' = fresh_var vars v
              val body = betapply (term, Free (v', typ))
            in drain (Queue.enqueue (v' :: vars, body) pending) end
        | _ => drain pending)

    (* Process the next queued term, or finish when nothing is left. *)
    and drain pending =
      if Queue.is_empty pending then () else visit (Queue.dequeue pending)
  in
    fn term => visit (([], term), Queue.empty)
  end
(* Breadth-first search that stops at the first hit: returns
   SOME (vars, subterm) for the first term reported by term_search_bf,
   or NONE when nothing satisfies `pred`. *)
fun term_search_bf_first pred prune term =
  let
    val found = Unsynchronized.ref NONE
    (* Record the hit and ignore the resume continuation, which ends the search. *)
    fun stop_at_first hit _ = found := SOME hit
    val _ = term_search_bf stop_at_first pred prune term
  in ! found end
(* True when `pat` matches `obj` itself or any of its subterms
   (via Pattern.matches). Abstractions are opened with
   Term.dest_abs_fresh, using Name.bound of the current nesting depth as
   the name of the replacement variable. *)
fun matches_subterm thy (pat, obj) =
  let
    fun msub depth t =
      if Pattern.matches thy (pat, t) then true
      else
        (case t of
          Abs _ => msub (depth + 1) (#2 (Term.dest_abs_fresh (Name.bound depth) t))
        | f $ x => msub depth f orelse msub depth x
        | _ => false)
  in msub 0 obj end;
(* Build a searcher for `pattern`: the returned function takes a term and
   yields the first subterm (breadth-first, with the names of the binders
   passed on the way) that the pattern matches, or NONE. Branches that
   contain no match at all are pruned. *)
fun grep_term ctxt pattern =
  let
    val thy = Proof_Context.theory_of ctxt
    fun is_match term = Pattern.matches thy (pattern, term)
    fun no_match_inside term = not (matches_subterm thy (pattern, term))
  in
    term_search_bf_first is_match no_match_inside
  end
(* Result of Monad_Types.check_lifting_head for the L2 constants listed
   below. NOTE(review): presumably a predicate testing whether a term is
   headed by one of these constants; confirm against Monad_Types. *)
val term_is_L2 =
  let
    val l2_heads =
      [@{term "L2_unknown"}, @{term "L2_seq"}, @{term "L2_modify"},
       @{term "L2_gets"}, @{term "L2_condition"}, @{term "L2_catch"}, @{term "L2_while"},
       @{term "L2_throw"}, @{term "L2_spec"}, @{term "L2_assume"},
       @{term "L2_guard"}, @{term "L2_fail"},
       @{term "L2_call"}]
  in Monad_Types.check_lifting_head l2_heads end
local
  (* Eta-contraction rule specialised to case_prod. *)
  val case_prod_eta_contract_thm =
    @{lemma "(λx. (case_prod s) x) == (case_prod s)" by simp}
in
(* Conversion that tries the case_prod eta-contraction rule at every
   subterm, bottom-up, and then applies Drule.beta_eta_conversion to the
   result. *)
fun case_prod_eta_conv ctxt =
  let
    val contract_here = Conv.try_conv (Conv.rewrs_conv [case_prod_eta_contract_thm])
    val contract_everywhere = Conv.bottom_conv (K contract_here) ctxt
  in
    contract_everywhere then_conv Drule.beta_eta_conversion
  end
end
(* Rewrite with the theorem unit_bind' when the cterm is an abstraction over
   a variable of type unit; fail with Conv.no_conv on anything else. *)
fun unit_fun_rewr_conv ct =
  let
    fun binds_unit (Abs (_, @{typ unit}, _)) = true
      | binds_unit _ = false
  in
    if binds_unit (Thm.term_of ct)
    then Conv.rewr_conv @{thm unit_bind'} ct
    else Conv.no_conv ct
  end
(* Try unit_fun_rewr_conv at every subterm, bottom-up; positions where it
   does not apply are left unchanged. *)
fun unit_fun_conv ctxt =
  Conv.bottom_conv (fn _ => Conv.try_conv unit_fun_rewr_conv) ctxt
local
local
  (* Second component of Synthesize_Rules.strip_abs_prod; the first one is
     discarded. NOTE(review): presumably the body left after removing
     (tupled) abstractions; confirm against Synthesize_Rules. *)
  fun strip t = #2 (Synthesize_Rules.strip_abs_prod t)
in
(* Rebuild a term headed by a compound L2 combinator (L2_seq, L2_while,
   L2_condition, L2_try) with selected arguments stripped and, where they
   are program bodies, processed recursively. A term of any other shape is
   returned unchanged. *)
fun l2_compound_index t =
  (case strip_comb t of
    (head as Const (@{const_name "L2_seq"}, _), [L, R]) =>
      list_comb (head, [L, l2_compound_index (strip R)])
  | (head as Const (@{const_name "L2_while"}, _), [C, B, I, ns]) =>
      list_comb (head, [strip C, l2_compound_index (strip B), I, ns])
  | (head as Const (@{const_name "L2_condition"}, _), [C, L, R]) =>
      list_comb (head, [strip C, l2_compound_index L, l2_compound_index R])
  | (head as Const (@{const_name "L2_try"}, _), [B]) =>
      head $ l2_compound_index B
  | _ => t)
end
(* For a proposition of the form `Trueprop (refines f s f' s' R)`, replace
   the first argument f by `l2_compound_index f`. Any other term is
   returned unchanged. *)
fun l2_index prop =
  (case prop of
    @{const Trueprop} $ ((sim as Const (@{const_name ‹refines›}, _)) $ f $ s $ f' $ s' $ R) =>
      @{const Trueprop} $ (sim $ l2_compound_index f $ s $ f' $ s' $ R)
  | _ => prop)
(* True when the second argument is a `refines` proposition whose first
   argument is headed by L2_seq, L2_while, L2_condition, L2_try or
   L2_guarded; false for everything else. The first argument is ignored. *)
fun check_compound _ t =
  (case t of
    @{term_pat ‹Trueprop (refines ?f _ _ _ _)›} =>
      let
        val head = #1 (strip_comb f)
      in
        (case head of
          @{term_pat "L2_seq"} => true
        | @{term_pat "L2_while"} => true
        | @{term_pat "L2_condition"} => true
        | @{term_pat "L2_try"} => true
        | @{term_pat "L2_guarded"} => true
        | _ => false)
      end
  | _ => false)
in
(* Try to prove `goal` with the refinement rules of monad type `mt`.

   The proof simplifies with DYN_CALL_def, substitutes `prev_def` once, and
   then runs a cached deepening search over the synthesis rules registered
   under the monad type's rule-set name, extended by a THIN rule.

   Returns SOME thm on success. If the proof raises ERROR, the message is
   logged at verbosity 2 and NONE is returned.
   `prog_info` and `phase` are not used in this function. *)
fun sim_nondet prog_info phase ctxt
    (mt : Monad_Types.monad_type)
    prev_def goal =
  let
    val monad_name = #name mt
    val {rules_name, lift_prev, ...} = #refines_nondet mt

    (* Concrete program of a `refines` statement, used by the THIN tactic. *)
    fun get_concr @{term_pat "refines ?f _ _ _ _"} = f
      | get_concr t = error ("prune_unused_bounds_sim_nondet_tac, unexpected term: " ^ @{make_string} t)
    val THIN_tac = Utils.THIN_tac (Utils.prune_unused_bounds_from_concr_tac get_concr)

    (* Context with the THIN rule added to the rule set. *)
    val rule_ctxt =
      Context.proof_map
        (Synthesize_Rules.add_pattern_tac_rule rules_name THIN_tac @{binding THIN} 10 @{pattern ‹THIN (PROP ?P)›})
        ctxt
    val _ = Utils.verbose_fn 2 rule_ctxt (fn _ => Synthesize_Rules.print_rules (Context.Proof rule_ctxt) rules_name NONE)
    val sim_rules = the (Synthesize_Rules.get_rules rule_ctxt rules_name)

    (* Build a tactic from every rule `r` for which `r OF [thm]` succeeds,
       tried in list order; no rules left means no_tac. `goal` is only
       passed along. *)
    fun lift [] _ _ = K CT.no_tac
      | lift (r :: rest) goal thm =
          let
            val attempt = try (fn prev => r OF [prev]) thm
            val fallback = lift rest goal thm
          in
            (case attempt of
              NONE => fallback
            | SOME lifted => CT.resolve_tac [lifted] ORELSE_CTXT' fallback)
          end

    val cache = Synthesize_Rules.gen_cond_cache check_compound l2_index (lift lift_prev) sim_rules

    val thm = Goal.prove rule_ctxt [] [] goal (fn {context = goal_ctxt, ...} =>
      full_simp_tac (Simplifier.clear_simpset goal_ctxt addsimps @{thms DYN_CALL_def}) 1 THEN
      EqSubst.eqsubst_tac goal_ctxt [0] [prev_def] 1 THEN
      Context_Tactic.NO_CONTEXT_TACTIC goal_ctxt (
        CT.cache_deepen_tac (fn ctxt => Config.get ctxt Utils.verbose) cache
          (Synthesize_Rules.resolve_tacs sim_rules goal_ctxt) 1)
      )

    val _ = Utils.verbose_msg 2 ctxt (fn _ => ("sim_nondet_rewrite (" ^ monad_name ^ ") thm:\n " ^
      (Thm.string_of_thm ctxt thm)))
  in
    SOME thm
  end
  handle ERROR msg => (Utils.verbose_msg 2 ctxt (fn _ => "sim_nondet proof failed:\n " ^ msg); NONE)
end
(* Debug switch, off by default. Within this view it is only read by
   polish_arg, where it enables the dprint_conv trace points. *)
val d1 = Unsynchronized.ref false;

(* Conversion used as a trace point: Utils.print_conv with message `msg`
   when the flag is set, the identity conversion otherwise. *)
fun dprint_conv true msg = Utils.print_conv msg
  | dprint_conv false _ = Conv.all_conv
(* Clean up one part of the conclusion of `thm`.

   arg                : lifts a conversion to the position inside the
                        conclusion that is to be rewritten.
   ctxt               : proof context; used with position reports made
                        invisible.
   mt                 : monad type; not used in this function.
   do_polish          : when true, the `polish` rules and `polish_cong`
                        congruences are added to the simplifier; when false
                        only HOL_ss merged with the record simpset is used.
   pretty_bounds_conv : context-dependent conversion; it is applied with
                        the import context.
   final_conv         : context-dependent conversion applied last.

   The theorem is imported with Variable.import, rewritten, and exported
   back to the (invisible) context. *)
fun polish_arg (arg: conv->conv) ctxt (mt : Monad_Types.monad_type) do_polish pretty_bounds_conv final_conv thm =
  let
    val quiet_ctxt = Context_Position.set_visible false ctxt

    fun rules_if_polishing named =
      if do_polish then Utils.get_rules quiet_ctxt named else []
    val simps = rules_if_polishing @{named_theorems polish}
    val congs = rules_if_polishing @{named_theorems polish_cong}

    val record_ss = RecursiveRecordPackage.get_simpset (Proof_Context.theory_of quiet_ctxt)
    val simp_ctxt =
      quiet_ctxt
      |> put_simpset (merge_ss (HOL_ss, record_ss))
      |> Simplifier.add_simps simps
      |> fold Simplifier.add_proc [@{simproc NO_MATCH}, @{simproc ETA_TUPLED_HINT}]
      |> fold Simplifier.add_cong congs

    val ((_, [imported_thm]), import_ctxt) = Variable.import true [thm] quiet_ctxt

    (* NOTE(review): then_force_conv and force_then_conv are not defined in
       this part of the file; presumably infix combinators declared
       elsewhere. The chain is kept token for token so that it parses the
       same way. *)
    val polished =
      Conv.fconv_rule (Conv.concl_conv (Thm.nprems_of imported_thm) (arg (
        (unit_fun_conv quiet_ctxt) then_conv
        (pretty_bounds_conv import_ctxt) then_conv
        (unit_fun_conv quiet_ctxt) then_force_conv
        (pretty_bounds_conv import_ctxt) force_then_conv
        (dprint_conv (!d1) "before simp_conv:") then_conv
        (Simplifier.rewrite simp_ctxt) then_conv
        (case_prod_eta_conv quiet_ctxt) then_conv
        (dprint_conv (!d1) "before final_conv:") then_conv
        (final_conv quiet_ctxt)
      ))) imported_thm

    val [thm_p] = Proof_Context.export import_ctxt quiet_ctxt [polished]
  in
    thm_p
  end
(* polish_arg with the conversion aimed at the position selected by
   `Utils.nth_arg_conv 2` inside the argument of the conclusion.
   NOTE(review): presumably this is one of the program arguments of a
   `refines` statement; confirm the indexing convention of nth_arg_conv. *)
val polish_refines =
  let
    fun on_refines_arg conv = Conv.arg_conv (Utils.nth_arg_conv 2 conv)
  in
    polish_arg on_refines_arg
  end
(* polish_arg with the conversion aimed at the argument of the argument of
   the conclusion (Conv.arg_conv applied twice). *)
val polish_eq = polish_arg (Conv.arg_conv o Conv.arg_conv)
(* Guarded tactic application: run `tac n` only when the count returned by
   Logic.count_prems for the theorem's term is at least n; otherwise fail
   with no_tac instead of calling `tac` on a subgoal index that is out of
   range. *)
fun handle_invalid_subgoals (tac : int -> tactic) n thm =
  let
    val premise_count = Logic.count_prems (term_of_thm thm)
  in
    if premise_count >= n then tac n thm else no_tac thm
  end
end