File ‹type_strengthen.ML›
structure TypeStrengthen =
struct
(* Local shorthands for the timing and verbose-message helpers from Utils. *)
val timeit_msg = Utils.timeit_msg
val verbose_msg = Utils.verbose_msg
(* Raised by lift_function when every candidate rule set has failed.
   Carries each function name paired with the definition theorem looked up in the l2_infos table. *)
exception AllLiftingFailed of (string * thm) list
(* Raised when lifting with one particular rule set fails; lift_function catches it and tries the next rule set. *)
exception LiftingFailed of unit
(* Further shorthands for Utils helpers. *)
val the' = Utils.the'
val apply_tac = Utils.apply_tac
(* State type of the exception monad denoted by the constant stored in finfo. *)
fun state_typ finfo =
  AutoCorresData.state_type_of_exn_monad (FunctionInfo.get_const finfo);
(* State type of the function fn_name as recorded in the l2_infos table.
   Raises Option if fn_name has no entry in the table. *)
fun get_l2_state_typ l2_infos fn_name =
  Symtab.lookup l2_infos fn_name |> the |> state_typ
(* Translate an L2 exception-monad type into the corresponding type of the target monad:
   the second component of FunctionInfo.dest_exn_monad_type is handed to the
   typ_from_L2 field of rule_set. *)
fun get_typ_from_L2 (rule_set : Monad_Types.monad_type) L2_typ =
  let
    val (_, monad_parts) = FunctionInfo.dest_exn_monad_type L2_typ
  in
    #typ_from_L2 rule_set monad_parts
  end
(* Association list from a name to a phase-indexed lookup of function-pointer information. *)
type fn_ptr_infos = (string * (FunctionInfo.phase -> AutoCorresData.fn_ptr_info)) list
(* Build the type-strengthening correspondence proposition for fn_name.

   The proposition has the shape
     refines old (lift new) s s (rel_prod relator (=))
   where old is the L2 constant of fn_name applied to args, new is rhs_term, and
   lift / relator are taken from the refines_nondet component of rule_set.

   Returns ([], (prop, [ts_corres_attr, synth_attr])): an empty list of extra propositions,
   the proposition itself and the attributes to attach to a theorem proving it.

   Raises LiftingFailed if type inference of the proposition fails with ERROR.
   Raises Option if fn_name has no entry in l2_infos.

   NOTE(review): the parameters dyn_call, prev_phase and state are not referenced in this body;
   the state variable of the proposition is always the free variable named s of type sT. *)
fun get_ts_corres_prop dyn_call ctxt skips prog_info prev_phase l2_infos fn_name
(rule_set : Monad_Types.monad_type) state args rhs_term =
let
(* Attribute obtained from AutoCorresData.corres_thm_attribute for phase TS and fn_name. *)
val ts_corres_attr = AutoCorresData.corres_thm_attribute (ProgramInfo.get_prog_name prog_info) skips FunctionInfo.TS fn_name
(* Attribute adding the theorem to the synthesis rules of this rule set under the name fn_name ^ "_recursion". *)
val synth_attr = Synthesize_Rules.add_rule_attrib (#rules_name (#refines_nondet rule_set)) {only_schematic_goal = false}
(Binding.make (fn_name ^ "_recursion" , ⌂)) 10
(* old_fn is the L2 constant of fn_name; new and args are passed through unchanged. *)
val (old_fn, new, args) =
let
val fn_def = the (Symtab.lookup l2_infos fn_name)
in (FunctionInfo.get_const fn_def, rhs_term, args) end
val lift = #lift (#refines_nondet rule_set)
val old = betapplys (old_fn, args)
(* State, result and exception types of the applied L2 term. *)
val sT = AutoCorresData.state_type_of_exn_monad old
val resT = AutoCorresData.res_type_of_exn_monad old
val exT = AutoCorresData.ex_type_of_exn_monad old
(* A c_exntype exception type selects the dedicated relator; any other type uses the default one. *)
val relator = (case exT of
\<^Type>‹c_exntype _› => Monad_Types.relator_from_c_exntype (#refines_nondet rule_set)
| _ => #relator (#refines_nondet rule_set))
(* lift and relator are inserted via Utils.dummy and the target types are dummyT,
   so the final types are determined by Utils.infer_types_simple. *)
val term = (
\<^instantiate>‹
's = sT and 'a = resT and 'f = exT and 'x = dummyT and 'e = dummyT and 'b = dummyT and
s = ‹Free("s", sT)› and old = old and lift=‹Utils.dummy lift› and
new = new and relator=‹Utils.dummy relator›
in
prop ‹refines old (lift new) s s (rel_prod relator (=))›
for s::'s and old::‹('f, 'a, 's) exn_monad› and
lift::‹'x ⇒ ('e::default, 'b, 's) spec_monad› and new::'x and
relator::‹('f, 'a) xval ⇒ ('e::default, 'b) exception_or_result ⇒ bool›› |> Utils.infer_types_simple ctxt)
handle ERROR str =>
(Utils.verbose_msg 1 ctxt (fn _ => "type strengthening into " ^ quote (#name rule_set) ^ " failed:\n" ^ str);
raise LiftingFailed ())
in
([],
(term, [ts_corres_attr, synth_attr]))
end
(* Assume that every function of a recursive clique has already been lifted.

   For each pair (callee, name) from recursive_calls zipped with rec_fn_fixes, a free variable
   called name is created whose type is the lifted type of callee, and the correspondence
   proposition between callee and that free variable (universally quantified over state and the
   arguments) is added to ctxt as an assumption with its attributes applied.

   Returns (ctxt_asms, thms, rec_frees, export_assms):
     ctxt_asms    - ctxt extended by the assumptions,
     thms         - the assumed theorems after applying their attributes,
     rec_frees    - the free variables standing for the lifted functions,
     export_assms - morphism exporting from ctxt_asms back to ctxt.

   NOTE(review): make_function_name is not referenced in this body. *)
fun assume_rec_lifted ctxt skips prog_info l2_infos prev_phase make_function_name rule_set
state rec_fn_fixes recursive_calls fn_name =
let
val (rec_frees, assumptions_rec) = map (fn (callee, name) =>
let
val fn_def' = the (Symtab.lookup l2_infos callee)
(* Argument variables of the callee, renamed so that they do not clash with names in ctxt. *)
val args = FunctionInfo.get_plain_args fn_def' |> (fn xs => Utils.fix_variant_frees xs ctxt) |> fst |> map dest_Free
(* Type of the lifted function: argument types followed by the translated L2 type. *)
val T = map snd args
---> (fastype_of (FunctionInfo.get_const fn_def') |> get_typ_from_L2 rule_set)
val args = map Free args
(* NOTE(review): state_typ is bound here but never used afterwards. *)
val state_typ = get_l2_state_typ l2_infos fn_name
val free = Free (name, T)
val (prev_props, prop) =
get_ts_corres_prop true ctxt skips prog_info prev_phase l2_infos callee
rule_set state args (betapplys(free, args))
in
(* The main proposition is closed under state and args before being certified. *)
(free, map (apfst (Thm.cterm_of ctxt)) (prev_props @ [apfst (fold Logic.all (rev (state::args))) prop]))
end) (recursive_calls ~~ rec_fn_fixes)
|> split_list |> apsnd flat
val assumptions = assumptions_rec
val (thms, ctxt_asms) = Assumption.add_assumes (map fst assumptions) ctxt
(* Apply the attributes that accompany each assumption to the assumed theorem. *)
val (thms, ctxt_asms) = ctxt_asms
|> fold_map (fn (thm, attrs) => Thm.proof_attributes attrs thm)
(thms ~~ map snd assumptions)
val export_assms = Assumption.export_morphism ctxt_asms ctxt
in
(ctxt_asms,
thms,
rec_frees,
export_assms)
end
(* Derive the theorem  DYN_CALL corres ==> corres  for a function-pointer parameter.

   P_prev and P are the constants of the previous phase and of this phase; both take the pointer
   as first argument. corres is the proposition
     refines (P_prev p args) (lift (P' p args)) s s (rel_prod relator (=))
   where P' is P itself, or map_of_default P ptr_assoc when there are matching entries of
   rec_ptrs in rec_funs.

   Returns [(monad_type, thm)]. Returns [] whenever a Match exception is raised, which happens
   for arguments that do not fit the patterns below and via the explicit raise when no pointer
   matches although rec_funs is non-empty. *)
fun mk_corresTS_fun_ptr_thm prog_info (rec_funs, rec_ptrs) ctxt ((P_prev as Const (_, T_prev), _), (P as Const (Pname, T), monad_type)) =
let
val (ptrT::prev_argTs, ret_prevT) = strip_type T_prev
(* Type of P without its leading pointer argument. *)
val funT = let val (ptrT::argTs, retT) = strip_type T in argTs ---> retT end
fun mk_fun_ptr fname = HP_TermsTypes.mk_fun_ptr ctxt (ProgramInfo.get_prog_name prog_info) fname
(* Entries of rec_funs named in rec_ptrs whose free variable has type funT, turned into an
   HOL list of (function pointer, function) pairs; empty records whether there are none. *)
val (empty, ptr_assoc) = map_filter (fn fname => find_first (fn (n, _) => n = fname) rec_funs) rec_ptrs
|> filter (fn (_, Free (_, fT)) => fT = funT)
|> `null
||> map (apfst mk_fun_ptr) ||> map HOLogic.mk_prod
||> HOLogic.mk_list (HOLogic.mk_prodT (@{typ "unit ptr"}, funT))
val _ = if empty andalso not (null rec_funs) then raise Match else ()
val {exT=ex_prevT, resT= ret_prevT, stateT} = AutoCorresData.dest_exn_monad_result_type ret_prevT
(* Raises Option if monad_type is not a known monad type in ctxt. *)
val mt = Monad_Types.get_monad_type monad_type (Context.Proof ctxt) |> the
val lift = #lift (#refines_nondet mt)
(* A c_exntype exception type selects the dedicated relator; any other type uses the default one. *)
val relator = (case ex_prevT of
\<^Type>‹c_exntype _› => Monad_Types.relator_from_c_exntype (#refines_nondet mt)
| _ => #relator (#refines_nondet mt))
val args = map (fn T => ("x", T)) prev_argTs
(* Fresh state s, pointer p and arguments, fixed in ctxt'. *)
val (s::p::args, ctxt') = Utils.fix_variant_frees ([("s", stateT), ("p", ptrT)] @ args) ctxt
val old = betapplys (P_prev, p::args)
val P = if empty then P else \<^infer_instantiate>‹P = P and xs = ptr_assoc in term ‹map_of_default P xs›› ctxt'
val new = betapplys (P, p::args)
val resT = AutoCorresData.res_type_of_exn_monad old
val exT = AutoCorresData.ex_type_of_exn_monad old
(* NOTE(review): type inference below runs in ctxt although s, p and args were fixed in ctxt' -
   confirm that this is intended. *)
val corres = (
\<^instantiate>‹
's = stateT and 'a = resT and 'f = exT and 'x = dummyT and 'e = dummyT and 'b = dummyT and
s = ‹s› and old = old and lift=‹Utils.dummy lift› and
new = new and relator=‹Utils.dummy relator›
in
prop ‹refines old (lift new) s s (rel_prod relator (=))›
for s::'s and old::‹('f, 'a, 's) exn_monad› and
lift::‹'x ⇒ ('e::default, 'b, 's) spec_monad› and new::'x and
relator::‹('f, 'a) xval ⇒ ('e::default, 'b) exception_or_result ⇒ bool›› |> Utils.infer_types_simple ctxt)
val corres_pre = @{term DYN_CALL} $ corres
val goal = Logic.mk_implies (corres_pre, corres)
(* Proved by simplification with DYN_CALL_def while keeping map_of_default folded; the result is exported to ctxt. *)
val [thm] = Goal.prove ctxt' [] [] goal (fn {context, ...} =>
asm_full_simp_tac (context addsimps @{thms DYN_CALL_def} delsimps @{thms map_of_default.simps}) 1)
|> single |> Proof_Context.export ctxt' ctxt
in
[(monad_type, thm)]
end
handle Match => []
(* Try to lift the function fn_name into the monad described by rule_set.

   The names of all functions in the recursive clique of fn_name, a state variable and the
   argument variables are fixed in successive contexts, the recursive calls are assumed to be
   lifted already (assume_rec_lifted), and Monad_Convert.sim_nondet is run on the correspondence
   goal whose right-hand side is the schematic variable ?_p.

   Returns NONE (after emitting a warning) if the rewrite fails. Otherwise returns
   SOME (thm, export_fun_names), where thm has been exported through the assumption, argument
   and state morphisms, and export_fun_names is the morphism that still has to be applied to
   leave the context in which the clique's function names are fixed. *)
fun perform_lift ctxt skips prog_info l2_infos prev_phase make_function_name rule_set fn_name =
let
val f_info = the (Symtab.lookup l2_infos fn_name)
val recursive_calls = Symset.dest (FunctionInfo.get_recursive_clique f_info)
val rec_names = map (make_function_name) recursive_calls
val (rec_fn_fixes, ctxt1_fun_names)
= Variable.add_fixes (map (make_function_name) recursive_calls) ctxt
(* The fixed names must come back unchanged, i.e. without any renaming. *)
val _ = @{assert} (rec_fn_fixes = rec_names)
val ([state], ctxt2_state) = ctxt1_fun_names |> Utils.fix_variant_frees [("s", state_typ f_info)];
val args = FunctionInfo.get_plain_args f_info;
val (arg_frees, ctxt3_args) = Utils.fix_variant_frees args ctxt2_state;
(* One export morphism per context extension, applied in reverse order further down. *)
val export_fun_names = Variable.export_morphism ctxt1_fun_names ctxt
val export_measure = Variable.export_morphism ctxt2_state ctxt1_fun_names
val export_args = Variable.export_morphism ctxt3_args ctxt2_state
val (ctxt4_rec_assms, thms, rec_frees, export_assms)
= assume_rec_lifted ctxt3_args skips prog_info l2_infos prev_phase make_function_name rule_set
state rec_fn_fixes recursive_calls fn_name
val fn_def = FunctionInfo.get_definition f_info
(* Schematic placeholder for the lifted function body. *)
val synth = Var (("_p", 0), (fastype_of (FunctionInfo.get_const f_info) |> get_typ_from_L2 rule_set))
val (_, (goal,_)) = get_ts_corres_prop true ctxt4_rec_assms skips prog_info prev_phase l2_infos fn_name rule_set
state arg_frees synth
val rewrite = Monad_Convert.sim_nondet prog_info FunctionInfo.TS ctxt4_rec_assms rule_set fn_def
val maybe_thm = rewrite goal
|> Option.map (Morphism.thm (export_assms $> export_args $> export_measure))
val _ = case maybe_thm of NONE => warning ("lifting failed for (" ^ #name rule_set ^ "): " ^ fn_name) | _ => ()
in
maybe_thm |> Option.map (rpair export_fun_names)
end
(* Run perform_lift for fn_name (timed and logged at verbosity level 2) and, on success,
   post-process the resulting theorem with Monad_Convert.polish_refines.

   Returns NONE if lifting failed, otherwise SOME (polished_thm, export_fun_names).
   The l2 info table is passed in the parameter named fn_info. *)
fun perform_lift_and_polish ctxt skips prog_info fn_info prev_phase make_function_name rule_set do_polish keep_going fn_name =
case (timeit_msg 2 ctxt (fn _ => "trying type strengthening to '" ^ #name rule_set ^ "'-monad for function: " ^ fn_name)
(fn () => perform_lift ctxt skips prog_info fn_info prev_phase make_function_name rule_set fn_name))
of NONE => NONE
| SOME (thm, export_fun_names) => SOME let
val _ = verbose_msg 3 ctxt (fn _ => "before polish thm: " ^ Thm.string_of_thm ctxt thm)
(* Constant names of the generated functions of the whole program: each generated name is read
   as a term, and names whose head is not a constant are dropped. *)
val fun_names = ProgramInfo.get_csenv prog_info |> ProgramAnalysis.get_functions |> map make_function_name
|> map (Syntax.read_term ctxt) |> map_filter (try (fst o dest_Const o head_of))
fun pretty_bounds_conv ctxt =
(PrettyBoundVarNames.pretty_bound_vars_thm keep_going ctxt fun_names)
val polish_thm = timeit_msg 1 ctxt (fn _ => "Polish - " ^ fn_name)
(fn _ => Monad_Convert.polish_refines ctxt rule_set do_polish
pretty_bounds_conv
map_of_default_args.fold_map_of_default_conv thm)
in (polish_thm, export_fun_names) end
(* Split a proposition of the form  refines f new _ _ _  into (f, body), where body is new with
   the lift of the monad type mt stripped off when dest_lift succeeds, and new itself otherwise. *)
fun get_body ctxt (mt:Monad_Types.monad_type) @{term_pat "refines ?f ?new _ _ _"} =
  let
    val dest_lift = #dest_lift (#refines_nondet mt)
  in
    (f, the_default new (dest_lift new))
  end
(* Discharge subgoals 1 up to i by repeatedly matching against the standard admissibility
   rules extended with thms. *)
fun admissibility_tac ctxt thms i =
  let
    val rules = @{thms admissible_imp admissible_all admissible_imp'} @ thms
    val solve_one = REPEAT_ALL_NEW (match_tac ctxt rules)
  in
    Seq.INTERVAL solve_one 1 i
  end
(* Focus on a subgoal, instantiate every theorem of thms with the first (length thms) premises
   of the focused goal, resolve with the instantiated theorems and solve all resulting subgoals
   with AutoCorresUtil.subgoal_assm_tac. *)
fun subgoal_intro_tac' thms = Subgoal.FOCUS (fn {context = focus_ctxt, prems = focus_prems, ...} =>
  let
    val leading_prems = take (length thms) focus_prems
    val rules = map (fn rule => rule OF leading_prems) thms
  in
    (resolve_tac focus_ctxt rules
      THEN_ALL_NEW (AutoCorresUtil.subgoal_assm_tac focus_prems focus_ctxt)) 1
  end)
(* Lift the functions fn_names (one clique) into the monad of rule_set, define the lifted
   functions and prove and store their correspondence theorems.

   Steps:
     1. lift and polish every function (perform_lift_and_polish),
     2. extract the new function bodies from the resulting theorems and define them,
     3. prove the final correspondence propositions for the defined constants,
     4. store these theorems with their attributes.

   Returns (name of the monad, updated local theory).

   Raises LiftingFailed if the functions are recursive but rule_set has an empty ccpo_name,
   or if any single function cannot be lifted.

   NOTE(review): the parameter ts_infos is not referenced in this body. *)
fun lift_function_rewrite rule_set filename skips prog_info prev_phase l2_infos ts_infos
fn_names make_function_name do_polish keep_going lthy =
let
val ts_monad_name = #name rule_set
val _ = verbose_msg 2 lthy (fn _ => "TS trying rule set: " ^ ts_monad_name)
val these_l2_infos = map (the o Symtab.lookup l2_infos) fn_names
val is_recursive = exists FunctionInfo.is_function_recursive these_l2_infos
(* Recursive functions can only be defined in a monad that provides a ccpo name. *)
val _ = if is_recursive andalso #ccpo_name rule_set = "" then raise LiftingFailed () else ()
val lifted_functions =
map (perform_lift_and_polish lthy skips prog_info l2_infos prev_phase make_function_name rule_set do_polish keep_going) fn_names
(* A single failure makes the whole group fail for this rule set. *)
val lifted_functions = map (fn x =>
case x of
SOME a => a
| NONE => raise LiftingFailed ())
lifted_functions
(* Simplify with exactly the given theorems, starting from a cleared simpset. *)
fun simplified ctxt thms = Simplifier.full_simplify ((Raw_Simplifier.clear_simpset ctxt) addsimps thms)
val thms = map (simplified lthy @{thms DYN_CALL_def} o #1) lifted_functions
val morphs = map #2 lifted_functions
(* From a lifted theorem compute (original name, new name, argument variables, body), where the
   body is the lifted term abstracted over the last nargs arguments of the original function. *)
fun gen_fun_def_term (fn_name, thm) =
let
val nargs = Symtab.lookup l2_infos fn_name |> the |> FunctionInfo.get_args |> length
val (((typ_inst, var_inst), [imp_thm]), _) = Variable.import true [thm] lthy
val (orig, new) = imp_thm |> Thm.concl_of |> HOLogic.dest_Trueprop |> get_body lthy rule_set
val args = strip_comb orig |> snd |> rev |> take nargs |> rev
val term = foldr (fn (v, t) => Utils.abs_over "" v t) new args
in
(fn_name, make_function_name fn_name, map dest_Free args, term)
end
val phase = FunctionInfo.TS
val input_defs = map gen_fun_def_term (fn_names ~~ thms)
(* Original function name belonging to a generated name, if any. *)
fun src_name n = get_first (fn (orig, n', _, _ ) => if n = n' then SOME orig else NONE) input_defs
(* For the "nondet" and "exit" monads the definitions are stored under the qualifier "raw". *)
val do_guard_simps = member (op =) ["nondet", "exit"] ts_monad_name
(* Per-binding test used by qualify; it is shadowed by the program-wide boolean of the same name below. *)
fun has_fun_pointers b = the_default false (
src_name (Binding.name_of b)
|> Option.map (ProgramAnalysis.has_fun_ptr_calls (ProgramInfo.get_csenv prog_info)))
fun qualify b =
if do_guard_simps then
Binding.qualify true "raw" b
else
Binding.qualify (has_fun_pointers b) "ts" b
val has_fun_pointers = ProgramAnalysis.program_has_fun_ptr_calls (ProgramInfo.get_csenv prog_info)
(* Definitions are added to final_defs only if the program has no function-pointer calls and
   the target monad is neither "nondet" nor "exit". *)
val final_attr =
if has_fun_pointers orelse do_guard_simps then []
else
[K (Named_Theorems.add @{named_theorems final_defs})]
val lthy = lthy |> AutoCorresData.in_theory' (
Utils.define_functions input_defs qualify false is_recursive (#ccpo_name rule_set)
(final_attr @ [AutoCorresData.define_function_attribute {concealed_named_theorems = do_guard_simps} filename skips phase])
[]
[]
#> snd)
(* Constants and definition theorems of the functions that now have TS-phase info. *)
val (fs, ts_defs) =
let
val finfos = fn_names |> map_filter (AutoCorresData.get_function_info (Context.Proof lthy) filename phase)
|> map (fn info => (FunctionInfo.get_const info, FunctionInfo.get_definition info))
in
split_list finfos
end
val ([state], lthy') = Utils.fix_variant_frees [("s", get_l2_state_typ l2_infos (hd fn_names))] lthy
(* Final correspondence propositions, universally quantified over the function arguments. *)
val final_props' = (map (fn (fn_name, fn_trm) =>
let
val finfo = the (Symtab.lookup l2_infos fn_name)
val args = FunctionInfo.get_plain_args finfo |> (fn xs => Utils.fix_variant_frees xs lthy') |> fst
val prop = get_ts_corres_prop false lthy' skips prog_info prev_phase l2_infos fn_name
rule_set state args (betapplys (fn_trm, args)) |> snd |> fst
in
fold Logic.all (rev (args)) prop
end) (fn_names ~~ fs))
val ((paramss, props), lthy') = lthy' |> fold_map Utils.import_universal_prop final_props' |> apfst split_list
val simps =
@{thms gets_bind_ign L2_call_fail HOL.simp_thms}
(* Lifted theorems exported out of the context in which the function names were fixed. *)
val exp_thms = map (fn (thm, export_fun_names) => Morphism.thm export_fun_names thm) (thms ~~ morphs)
(* Induction rules recorded by Mutual_CCPO_Rec for the head constant of the first definition;
   empty if there is no such record. *)
fun get_induct_thms () =
let
val c = hd ts_defs |> Thm.concl_of |> Utils.lhs_of_eq |> Term.head_of
in
Mutual_CCPO_Rec.lookup_info_trimmed (Context.Proof lthy) c |> the_list |> maps #inducts
end
val induct_thms = get_induct_thms ()
val admissibility_thms = Named_Theorems.get lthy @{named_theorems corres_admissible}
val top_thms = Named_Theorems.get lthy @{named_theorems corres_top}
val N = length props
val arbitrary_varss = replicate N [state]
val all_varss = map (fn (xs, ys) => xs @ ys) (paramss ~~ arbitrary_varss)
(* Local tactic shadowing the top-level admissibility_tac: handles subgoals i up to i + N - 1. *)
fun admissibility_tac ctxt i = Seq.INTERVAL (AutoCorresUtil.corres_admissible_tac ctxt) i (i + N - 1)
(* The unifier search bound is scaled by the number of simultaneously proved propositions. *)
val bump_unify_bound = Config.map Unify.search_bound (fn n => n * N)
val rewrite_thms =
if is_recursive then
(* Recursive case: induction, then admissibility, then the top cases, then the step cases. *)
Goal.prove_common lthy' NONE [] [] props (fn {context,...} =>
DETERM (Induct.induct_tac (bump_unify_bound context) false
[]
(map (map dest_Free) all_varss) [] (SOME induct_thms) [] 1) THEN
admissibility_tac (bump_unify_bound context) 1 THEN
REPEAT_DETERM_N N (match_tac context top_thms 1) THEN
REPEAT (subgoal_intro_tac' exp_thms context 1)
)
else
(* Non-recursive case: unfold the first definition once, resolve with a lifted theorem and
   solve what remains by simplification or assumption. *)
Goal.prove_common lthy' NONE [] [] props (fn {context,...} =>
EVERY [
EqSubst.eqsubst_tac lthy' [0] [hd ts_defs] 1,
resolve_tac lthy' exp_thms 1,
(REPEAT (
FIRST [
CHANGED (asm_simp_tac (put_simpset HOL_ss (Context_Position.set_visible false lthy') addsimps simps) 1),
Method.assm_tac lthy' 1]))
]
)
val new_thms = rewrite_thms
|> Proof_Context.export lthy' lthy
(* Store each theorem under its TS correspondence name together with the call-rule attributes of rule_set. *)
val (ctxt_new_thms, lthy) = lthy
|> fold_map (fn (name, thm) =>
let
val thm_name = AutoCorresData.corres_thm_name prog_info FunctionInfo.TS name
in thm |> Utils.define_lemma (Binding.name thm_name) (
AutoCorresData.corres_thm_attribute filename skips phase name::
Monad_Types.add_call_rule_attribs (Context.Proof lthy) rule_set {only_schematic_goal = false}
(Binding.make (thm_name, ⌂)) 10)
end)
(fn_names ~~ new_thms)
in
(ts_monad_name, lthy)
end
(* Determine the rule sets to try for the functions fn_names.

   If none of the functions has an entry in force_lift, all rules are returned. If some have,
   their forced rules must all carry the same name, and only that rule is returned; otherwise
   an error is raised naming the functions with forced rules. *)
fun compute_lift_rules rules force_lift fn_names =
  let
    (* Functions with a forced rule, in the order of fn_names. *)
    val forced =
      fn_names
      |> map_filter (fn func => Option.map (pair func) (Symtab.lookup force_lift func))
  in
    case forced of
      [] => rules
    | (_, rule) :: rest =>
        if forall (fn (_, rule') => #name rule = #name rule') rest
        then [rule]
        else error ("autocorres: this set of mutually recursive functions " ^
                    "cannot be lifted to different monads: " ^
                    commas_quote (map fst forced))
  end
(* Lift the functions fn_names with the first rule set that succeeds.

   The candidates come from compute_lift_rules and are tried in order; a LiftingFailed raised
   by lift_function_rewrite moves on to the next candidate. If all candidates fail,
   AllLiftingFailed is raised with every function name paired with its definition from l2_infos. *)
fun lift_function rules force_lift filename skips prog_info prev_phase l2_infos ts_infos
    fn_names make_function_name do_polish keep_going lthy =
  let
    val candidates = compute_lift_rules rules force_lift fn_names
    fun attempt rule =
      lift_function_rewrite rule filename skips prog_info prev_phase l2_infos ts_infos
        fn_names make_function_name do_polish keep_going lthy
    fun l2_definition f = FunctionInfo.get_definition (the (Symtab.lookup l2_infos f))
    fun try_rules [] = raise AllLiftingFailed (map (fn f => (f, l2_definition f)) fn_names)
      | try_rules (rule :: remaining) =
          (attempt rule
            handle LiftingFailed _ =>
              (Utils.verbose_msg 4 lthy (fn _ => "LiftingFailed: " ^ #name rule);
               try_rules remaining))
  in
    try_rules candidates
  end
(* Print how often each name occurs in results, one line per distinct name in sorted order.

   results: list of strings; equal strings are counted together.

   The counts are gathered by a run-length fold over the sorted list. No sentinel entry is used,
   so every string is counted, whatever its spelling. *)
fun print_statistics results =
  let
    (* Fold step over the sorted names: extend the current run or start a new one.
       The accumulator holds the runs in reverse order. *)
    fun tally name ((prev, count) :: done) =
          if prev = name then (prev, count + 1) :: done
          else (name, 1) :: (prev, count) :: done
      | tally name [] = [(name, 1)]
    val tabulated = rev (fold tally (sort_strings results) [])
    val data =
      tabulated
      |> map (fn (name, count) => " " ^ name ^ ": " ^ (@{make_string} count) ^ "\n")
      |> String.concat
  in
    writeln ("Type Strengthening Statistics: \n" ^ data)
  end
(* Remove the longest prefix of the list whose elements all satisfy P. *)
fun drop_while P xs =
  case xs of
    [] => []
  | y :: ys => if P y then drop_while P ys else xs
(* Build the unchanged-typing proposition for the lifted function fn_name applied to fn_args:
     for all s: runs_to_partial C s (%r t. unchanged UNIV s t)
   where C is the TS constant of fn_name applied to fn_args.

   Returns (([], prop), attrs): no premises, the proposition, and the runs_to_vcg attribute.
   Raises Option if fn_name has no entry in ts_infos.

   NOTE(review): the parameter monad_name is not referenced in this body. *)
fun get_unchanged_typing_prop prog_info ts_infos monad_name
ctxt fn_name fn_args =
let
val heap_abs = ProgramInfo.get_heap_abs (ProgramInfo.get_fun_options prog_info fn_name)
(* The constant is read by name: unqualified when heap abstraction is enabled for fn_name
   (the qualifier list is empty), otherwise qualified by the globals record name and "typing".
   The free type variable 'a of the result is instantiated with unit. *)
val unchanged_typing_on =
(if heap_abs then
Syntax.read_term ctxt
(fold_rev Long_Name.qualify []
(Long_Name.base_name @{const_name heap_typing_state.unchanged_typing_on}))
else
Syntax.read_term ctxt
(fold_rev Long_Name.qualify [NameGeneration.global_rcd_name, "typing"]
(Long_Name.base_name @{const_name heap_typing_state.unchanged_typing_on})))
|> Term_Subst.instantiate_frees (TFrees.make [(("'a", @{sort type}), @{typ unit})], Frees.empty)
val attrs = map (Attrib.attribute ctxt) @{attributes [runs_to_vcg]}
val ts_fun = (the (Symtab.lookup ts_infos fn_name) |> FunctionInfo.get_const)
val ts_term = betapplys (ts_fun, fn_args)
in
(([],
\<^infer_instantiate>‹C = ts_term and unchanged = ‹unchanged_typing_on›
in prop ‹⋀s. Spec_Monad.runs_to_partial C s (λr t. unchanged (UNIV::addr set) s t)›› ctxt),
attrs)
end
(* Entry point of the type-strengthening phase: translate every group of functions in groups.

   A group is skipped if all of its functions already have TS-phase info. Otherwise the group is
   lifted with lift_function inside the correspondence locale. If the chosen monad is "nondet"
   or "exit", an unchanged_typing theorem is additionally attempted (a failure only produces a
   warning) and the guards in the definitions are simplified and stored as the visible
   definition theorems.

   Returns (groups, updated local theory); groups is passed through unchanged.

   NOTE(review): base_locale_opt is not referenced in this body. *)
fun translate
(skips: FunctionInfo.skip_info)
(base_locale_opt: string option)
(rules : Monad_Types.monad_type list)
(force_lift : Monad_Types.monad_type Symtab.table)
(prog_info : ProgramInfo.prog_info)
(keep_going : bool)
(do_polish : bool)
(groups: string list list)
(lthy: local_theory)
: string list list * local_theory =
let
val phase = FunctionInfo.TS
val prev_phase = FunctionInfo.prev_phase skips phase
val filename = ProgramInfo.get_prog_name prog_info
val make_function_name = ProgramInfo.get_mk_fun_name prog_info phase
(* TS info present before this run; used to skip groups that are already translated. *)
val existing_ts_infos = AutoCorresData.get_default_phase_info (Context.Proof lthy) filename phase
(* NOTE(review): infos_prev_phase is bound but not used below. *)
val infos_prev_phase = AutoCorresData.get_default_phase_info (Context.Proof lthy) filename prev_phase
fun translate_group fn_names lthy =
if forall (Symtab.defined existing_ts_infos) fn_names then
lthy
else
let
val _ = writeln ("Translating (type strengthen) " ^ commas fn_names);
(* The phase infos are looked up afresh inside the locale context. *)
val (monad_name, lthy) = lthy |> AutoCorresUtil.timeit_ts_msg 1 lthy fn_names (fn () =>
AutoCorresUtil.in_corres_locale_result prog_info skips phase filename fn_names (fn lthy => lthy |>
let
val l2_infos = AutoCorresData.get_default_phase_info (Context.Proof lthy) filename prev_phase
val ts_infos = AutoCorresData.get_default_phase_info (Context.Proof lthy) filename phase
in lift_function rules force_lift filename skips prog_info prev_phase l2_infos ts_infos
fn_names (make_function_name "") do_polish keep_going
end))
val _ = writeln (" --> " ^ monad_name);
val heap_abs = ProgramInfo.get_heap_abs (ProgramInfo.get_fun_options prog_info (hd fn_names))
val stateT =
if heap_abs then
the (ProgramInfo.get_lifted_globals_type prog_info)
else ProgramInfo.get_globals_type prog_info
val all_rules = Monad_Types.get_ordered_rules [] (Context.Proof lthy)
(* Rules from monad_name onwards, each paired with the function lifting from its predecessor
   (identity for monad_name itself). *)
val monad_infos = all_rules |> drop_while (fn r => #name r <> monad_name)
|> map (fn {name, lift_from_previous_monad,...} =>
(name, if name = monad_name then I else lift_from_previous_monad lthy stateT))
(* Compose the lifts cumulatively and keep only monads whose name occurs in rules.
   NOTE(review): monad_infos is not referenced after this point. *)
val monad_infos = (I, [])
|> fold (fn (n, current_lift) => fn (lift, xs) =>
let val new_lift = current_lift o lift in (new_lift, (n, new_lift)::xs) end)
monad_infos
|> snd |> rev
|> filter (fn (n, _) => member (op =) (map #name rules) n)
(* Extra work that is done only for the "nondet" and "exit" monads. *)
val lthy = lthy |> member (op =) ["nondet", "exit"] monad_name ? (fn lthy =>
let
val ts_infos = AutoCorresData.get_default_phase_info (Context.Proof lthy) filename phase
fun finfo f = (the (Symtab.lookup ts_infos f))
val is_recursive = FunctionInfo.is_function_recursive (finfo (hd fn_names))
(* Induction rules recorded by Mutual_CCPO_Rec for the first function's constant; empty if none. *)
fun get_induct_thms () =
let
val c = (finfo (hd fn_names)) |> FunctionInfo.get_const |> Term.head_of
in
Mutual_CCPO_Rec.lookup_info_trimmed (Context.Proof lthy) c |> the_list |> maps #inducts
end
val induct_thms = get_induct_thms ()
(* Set up the unchanged-typing proof obligation for f: fix its parameters, import the
   proposition and assume its premises with their attributes applied. *)
fun prop f ctxt =
let
val info = finfo f
val args = FunctionInfo.get_plain_args info
val def = FunctionInfo.get_definition info
val (params, ctxt) = Utils.fix_variant_frees args ctxt;
val ((prems, prop), attrs) = get_unchanged_typing_prop prog_info ts_infos monad_name lthy f params
val ((arbitrary_vars, prop), ctxt) = Utils.import_universal_prop prop ctxt
val (prems, ctxt) = Assumption.add_assumes (map (Thm.cterm_of ctxt) prems) ctxt
val (_, ctxt) = fold_map (Thm.proof_attributes attrs) prems ctxt
in
((def, params, arbitrary_vars, (prop, attrs)), ctxt)
end
val heap_syntax_defs = Named_Theorems.get lthy @{named_theorems heap_update_syntax}
|> map (Utils.abs_def lthy)
val (props, ctxt) = lthy |> fold_map prop fn_names
(* A proof failing with ERROR is reported as a warning and yields no theorems. *)
val thms = Utils.timeit_label 1 lthy ("Trying unchanged typing proof for " ^ commas fn_names) (fn _ =>
AutoCorresUtil.prove_functions is_recursive induct_thms
(fn ctxt => Unchanged_Typing.unchanged_typing_tac NONE (ctxt addsimps heap_syntax_defs))
(fn attrss => fn ctxt => ALLGOALS (AutoCorresUtil.prove_induction_case
(K (Unchanged_Typing.unchanged_typing_tac NONE)) attrss (ctxt addsimps heap_syntax_defs)))
ctxt props
handle ERROR msg =>
(warning ("Could not prove 'unchanged_typing' for " ^ commas fn_names ^ "\n " ^ msg); []))
val thms = thms |> (Proof_Context.export ctxt lthy)
(* The theorems are noted under the function names joined by "_" plus the suffix
   "_unchanged_typing", but only if the proof produced any. *)
val lthy = lthy |> not (null thms)?
(Local_Theory.note ((Binding.make (suffix "_unchanged_typing" (space_implode "_" fn_names), ⌂),
@{attributes [unchanged_typing]}), thms) #> snd)
(* Simplify the guards in the definition of f and store the result as its definition theorem:
   base_name.simps for recursive functions, base_name ^ "_def" otherwise. *)
fun simplify_def f lthy =
let
val has_fun_pointers = ProgramAnalysis.has_fun_ptr_calls (ProgramInfo.get_csenv prog_info) f
val final_attr = if has_fun_pointers then [] else [Named_Theorems.add @{named_theorems final_defs}]
val info = finfo f
val def = FunctionInfo.get_definition info
(* At this point ctxt still denotes the context of the unchanged-typing props above. *)
val _ = Utils.verbose_msg 3 ctxt (fn _ => "before guard simplification:\n " ^ Thm.string_of_thm lthy def)
val size_simps = Named_Theorems.get lthy @{named_theorems size_simps}
(* Simplification context without the map_of_default, replicate and size simp rules. *)
val ctxt = lthy delsimps
@{thms map_of_default.simps replicate_0 replicate_Suc replicate_numeral} @
size_simps
val def' = timeit_msg 1 lthy (fn _ => "Simplifying guards within " ^ f) (fn _ =>
Monad_Cong_Simp.monad_simplify_import ctxt def)
val _ = Utils.verbose_msg 3 ctxt (fn _ => "after guard simplification:\n " ^ Thm.string_of_thm lthy def')
val base_name = make_function_name "" f
val b =
if is_recursive then
Binding.name "simps" |> Binding.qualify true base_name
else
Binding.name (base_name ^ "_def")
in
lthy
|> Utils.define_lemma (Binding.qualify has_fun_pointers "ts" (Binding.set_pos ⌂ b))
(final_attr @ [AutoCorresData.define_function_attribute {concealed_named_theorems=false} filename skips phase f])
def'
|> snd
end
val lthy = lthy |> fold simplify_def fn_names
in
lthy
end)
in lthy end;
val lthy = lthy
|> fold translate_group groups
in
(groups, lthy)
end
end