File ‹autocorres_util.ML›
(* tac1 THEN_UNSOLVED tac2: run tac1 on subgoal i; afterwards run tac2 on
   subgoal i only when the number of subgoals has not dropped below the count
   observed before tac1 ran.  If it did drop (tac1 discharged a goal), the
   result of tac1 is returned as is. *)
infix 1 THEN_UNSOLVED
fun (tac1 THEN_UNSOLVED tac2) i st =
  let
    val goals_before = Thm.nprems_of st
    fun continue_unless_solved j st' =
      if Thm.nprems_of st' < goals_before
      then all_tac st'
      else tac2 j st'
  in
    (tac1 i THEN continue_unless_solved i) st
  end
structure AutoCorresUtil =
struct
(* Unqualified aliases for the logging and timing helpers of Utils. *)
val verbose = Utils.verbose
and verbose_msg = Utils.verbose_msg
and timing_msg = Utils.timing_msg
and timeit_msg = Utils.timeit_msg
and timeap_msg_tac = Utils.timeap_msg_tac
(* Run f bracketed by timing messages about converting the function described
   by finfo into the phase that follows its current phase.  Returns f's result. *)
fun timeit_conversion_msg level ctxt finfo f =
  let
    val target_phase = FunctionInfo.next_phase (FunctionInfo.get_phase finfo)
    val name = FunctionInfo.get_name finfo
    fun describe verb suffix =
      String.concat [verb, " (", FunctionInfo.string_of_phase target_phase,
        ") for function ", name, suffix]
    val _ = timing_msg level ctxt (fn _ => describe "Converting" " ...")
  in
    timeit_msg level ctxt (fn _ => describe "Converted" "") f
  end;
(* Same as timeit_conversion_msg, but the phase and function name are given
   directly instead of being read from a function_info. *)
fun timeit_conversion_msg' level ctxt phase fname f =
  let
    fun describe verb suffix =
      String.concat [verb, " (", FunctionInfo.string_of_phase phase,
        ") for function ", fname, suffix]
    val _ = timing_msg level ctxt (fn _ => describe "Converting" " ...")
  in
    timeit_msg level ctxt (fn _ => describe "Converted" "") f
  end;
(* Run f, which prepares the given clique of functions for a phase,
   bracketed by timing messages.  A "Preparing ... ..." message is emitted
   before f runs, and a "Prepared ..." completion message is emitted through
   timeit_msg.  Returns the result of f. *)
fun timeit_prepare_msg' level ctxt phase clique f =
  let
    fun mk_msg s e =
      s ^ " (" ^ FunctionInfo.string_of_phase phase ^ ") for clique " ^ commas clique ^ e
    val _ = timing_msg level ctxt (fn _ => mk_msg "Preparing" " ...")
  in
    (* The completion message used to repeat "Preparing"; use the past
       tense, consistent with the Converting/Converted pair used elsewhere. *)
    timeit_msg level ctxt (fn _ => mk_msg "Prepared" "") f
  end;
(* Run f bracketed by timing messages about converting the functions fnames
   in the TS phase.  Returns f's result. *)
fun timeit_ts_msg level ctxt fnames f =
  let
    val group_str = commas fnames
    fun describe verb suffix =
      String.concat [verb, " (", FunctionInfo.string_of_phase FunctionInfo.TS,
        ") for function(s) ", group_str, suffix]
    val _ = timing_msg level ctxt (fn _ => describe "Converting" " ...")
  in
    timeit_msg level ctxt (fn _ => describe "Converted" "") f
  end;
(* Remove the first prefix in the list that x starts with.  If none
   matches, x is returned unchanged. *)
fun safe_unprefix prefixes x =
  (case get_first (fn p => try (unprefix p) x) prefixes of
    SOME stripped => stripped
  | NONE => x)
(* Return the (non-recursive callees, recursive callees) of a function_info
   as plain name lists. *)
fun get_callees fn_info =
  let
    val callees = FunctionInfo.get_callees fn_info
    val rec_callees = FunctionInfo.get_rec_callees fn_info
  in
    (Symset.dest callees, Symset.dest rec_callees)
  end
(* get_callees for the entry of fn_name in fn_infos.  Raises Option if
   fn_name has no entry. *)
fun get_callees' fn_infos fn_name =
  fn_name |> Symtab.lookup fn_infos |> the |> get_callees
(* Subgoal-addressed variant of CHANGED: keep only those results of tac i
   that differ (by Thm.eq_thm) from the input state. *)
fun CHANGED' tac i st =
  Seq.filter (fn st' => not (Thm.eq_thm (st, st'))) (tac i st);
(* Subgoal-addressed variant of REPEAT_DETERM_N. *)
fun REPEAT_DETERM_N' n tac = REPEAT_DETERM_N n o tac
(* Strip a Trueprop wrapper if present; otherwise return t unchanged. *)
fun maybe_dest_Trueprop t = the_default t (try HOLogic.dest_Trueprop t)
(* Pair a phase with its predecessor (taking skips into account):
   (previous phase, phase). *)
fun prev_phase_pair skips phase =
  (phase |> FunctionInfo.prev_phase skips, phase)
(* Split a correspondence judgement into its two program arguments.
   Recognises L1corres, L2corres, IOcorres, L2Tcorres, corresTA and refines
   judgements and returns SOME {new, old}.  Note that for refines the old
   program comes first in argument order.  Any other term yields NONE.
   The argument is the judgement itself, without the Trueprop wrapper
   (is_var_new strips Trueprop before calling this). *)
fun dest_corres_progs @{term_pat "L1corres _ _ ?new ?old"} = SOME {new=new, old=old}
| dest_corres_progs @{term_pat "L2corres _ _ _ _ ?new ?old"} = SOME {new=new, old=old}
| dest_corres_progs @{term_pat "IOcorres _ _ _ _ _ ?new ?old"} = SOME {new=new, old=old}
| dest_corres_progs @{term_pat "L2Tcorres _ ?new ?old"} = SOME {new=new, old=old}
| dest_corres_progs @{term_pat "corresTA _ _ _ ?new ?old"} = SOME {new=new, old=old}
| dest_corres_progs @{term_pat "refines ?old ?new _ _ _"} = SOME {new=new, old=old}
| dest_corres_progs _ = NONE
(* True iff concl is a Trueprop-wrapped correspondence judgement (see
   dest_corres_progs) whose new program is exactly one of the terms bound in
   schematic_vars. *)
fun is_var_new (_, schematic_vars:cterm Vars.table) concl =
  (case try HOLogic.dest_Trueprop concl of
    NONE => false
  | SOME judgement =>
      (case dest_corres_progs judgement of
        NONE => false
      | SOME {new, ...} =>
          Vars.exists (fn (_, v) => Thm.term_of v = new) schematic_vars))
(* Chain the subgoal tactics with THEN_UNSOLVED, from left to right.
   The empty list gives all_tac. *)
fun EVERY_UNSOLVED tacs i =
  fold_rev (fn tac => fn rest => tac THEN_UNSOLVED rest) tacs (K all_tac) i
(* Collect the results of match on the maximal subterms of t for which it
   returns SOME.  The search does not descend into a matched subterm.
   - For an application, the function part is searched before the argument.
   - Abstraction bodies are opened with Variable.dest_abs, and the extended
     context is threaded through.
   New results are prepended to xs.  Returns the accumulated results and the
   final context. *)
fun add_matches match t (xs, ctxt) =
  (case match ctxt t of
    SOME x => (x :: xs, ctxt)
  | NONE =>
      (case t of
        u $ v => add_matches match v (add_matches match u (xs, ctxt))
      | Abs _ =>
          let val ((_, body), ctxt') = Variable.dest_abs t ctxt
          in add_matches match body (xs, ctxt') end
      | _ => (xs, ctxt)))
(* If the (open) conclusion of the subgoal is a non-membership statement
   "_ ∉ _", solve it completely with asm_full_simp_tac.  Otherwise fail. *)
fun notin_tac ctxt = SUBGOAL (fn (goal, i) =>
  (case Utils.concl_of_subgoal_open goal of
    @{term_pat "Trueprop (_ ∉ _)"} =>
      SOLVED' (asm_full_simp_tac ctxt THEN' Utils.print_subgoal_tac "notin end" ctxt) i
  | _ => no_tac))
(* Debug switches for the tracing in this structure.  d1 guards the general
   trace output; d2 guards additional, more verbose output. *)
val d1 = Unsynchronized.ref false
and d2 = Unsynchronized.ref false
(* Destructors re-exported from map_of_default_args. *)
val dest_map_of_default = map_of_default_args.dest_map_of_default
and dest_assoc = map_of_default_args.dest_assoc
(* Turn a non-empty ML list of (pointer, function) term pairs into a HOL list
   of pairs.  The pointer type is fixed to unit ptr, and the function type
   is taken from the first function.  Returns {xs, pT, fT}.  Raises an error
   on the empty list. *)
fun mk_assoc [] = error ("mk_assoc: empty")
  | mk_assoc (xs as ((_, f) :: _)) =
      let
        val pT = @{typ "unit ptr"}
        val fT = fastype_of f
        val pairT = HOLogic.mk_prodT (pT, fT)
      in
        {xs = HOLogic.mk_list pairT (map HOLogic.mk_prod xs), pT = pT, fT = fT}
      end
(* Build the term "map_of_default (λ_::unit ptr. d) xs" from a record as
   returned by mk_assoc.  Here xs is the HOL association list and fT its
   function type; the default d is AutoCorresData.mk_default fT. *)
fun mk_map_of_default {xs, fT, ...} =
let
val d = AutoCorresData.mk_default fT
in
\<^instantiate>‹'a = fT and d=d and xs = xs in term ‹map_of_default (λ_::unit ptr. d) xs››
end
(* Replace every Const or Free atom of t that appears as a key in the
   association list substs by its associated term.  Keys are compared with
   structural equality.  All other atoms and all term structure are kept
   unchanged.  ctxt is not used; it is kept for interface compatibility. *)
fun subst_atomic ctxt substs t =
  let
    fun rew a = the_default a (AList.lookup (op =) substs a)
    (* Term.map_aterms expects a total function, so unmatched atoms must be
       returned as they are.  Raising Same.SAME here (as Same.same does) is
       not caught by map_aterms and would abort the whole substitution. *)
    fun do_rew (a as Const _) = rew a
      | do_rew (a as Free _) = rew a
      | do_rew a = a
  in Term.map_aterms do_rew t end
(* Name of the head symbol of a term (via Term.term_name). *)
fun head_name t = Term.term_name (Term.head_of t)
(* The old and new program terms of a correspondence theorem, as returned
   by the dest_corres_funs destructor passed to get_first_corres. *)
type corres_funs = {old: term, new: term}
(* Find the first theorem in thms whose conclusion (split by dest_corres_funs)
   has an old program with the same head name as old.  Returns SOME (head of
   its new program, theorem), or NONE if there is none. *)
fun get_first_corres {dest_corres_funs: term -> corres_funs} thms old =
  let
    val target = head_name old
    fun check thm =
      let val {new, old = old'} = dest_corres_funs (Thm.concl_of thm)
      in if head_name old' = target then SOME (Term.head_of new, thm) else NONE end
  in
    get_first check thms
  end
(* Does the (open) conclusion of the subgoal ct contain a fully applied
   map_of_default term? *)
fun check_map_of_default ct =
  let
    val concl = Utils.concl_of_subgoal_open (Thm.term_of ct)
    fun is_map_of_default @{term_pat "map_of_default ?d ?xs ?p"} = true
      | is_map_of_default _ = false
  in
    exists_subterm is_map_of_default concl
  end
(* Reduce a map_of_default application to "map_of_default d fs p", dropping
   any further arguments.  Terms not recognised by dest_map_of_default are
   returned unchanged. *)
fun map_of_default_core t =
  (case try dest_map_of_default t of
    NONE => t
  | SOME {map_of_default, d, fs, p, ...} => map_of_default $ d $ fs $ p)
(* Remove the last n arguments of an application.  If there are fewer than n
   arguments, only the head remains. *)
fun strip_args n t =
  let
    val (head, args) = strip_comb t
    val keep = length args - n
  in
    list_comb (head, take keep args)
  end
(* Replace the last argument of the application t by x.  Raises TERM if t is
   not an application. *)
fun replace_arg t x =
  let val (f, _) = Term.dest_comb t in f $ x end
(* Build a correspondence theorem for a map_of_default term whose keys are
   unit ptr values (see mk_assoc).
   Steps:
   - Destruct map_of_default_old with dest_map_of_default.
   - Strip the last `length args` arguments from each function in its
     association list fs.
   - Look each stripped function up in thms via get_first_corres, which
     yields the head of the corresponding new function and the relating
     theorem.
   - Pack the new functions into a new association list (mk_assoc) and a
     new map_of_default term (mk_map_of_default).
   - Derive the goal from the first relevant theorem: substitute its old
     function by the old map_of_default (core) applied to a fresh pointer p,
     and its new function by map_of_default_new applied to p.
   - Prove the goal from an instance of map_of_default_list_all2_cases by
     repeated resolution with the relevant theorems, corres_top and the
     list_all2 introduction rules.
   Returns (map_of_default_new, thm).  The assertion fails if some old
   function has no corresponding theorem in thms. *)
fun mk_corres_map_of_default_thm {get_first_corres} ctxt thms map_of_default_old =
let
val {fs, d, args, ...} = map_of_default_old |> dest_map_of_default
val arity = length args
(* ptrs: the keys; olds: the stored functions without their trailing arguments *)
val (ptrs, olds) = dest_assoc fs |> split_list |> apsnd (map (strip_args arity))
val (news, relevant_thms) = map_filter (get_first_corres thms) olds |> split_list
val corres_top = Named_Theorems.get ctxt @{named_theorems corres_top}
(* every old function must have a correspondence theorem *)
val _ = @{assert} (length news = length olds)
val assoc = (ptrs ~~ news) |> mk_assoc
val fs' = #xs assoc
val d' = AutoCorresData.mk_default (@{typ "unit ptr"} --> #fT assoc)
val map_of_default_new = mk_map_of_default assoc
val ([p], ctxt1) = Utils.fix_variant_frees [("p", @{typ "unit ptr"})] ctxt
(* the statement is obtained from the first relevant theorem by replacing its
   old/new function by the old/new table lookup at p *)
val rewrs = [(hd olds, replace_arg (map_of_default_core map_of_default_old) p ), (hd news, map_of_default_new $ p)]
val prop = hd relevant_thms |> Thm.prop_of |> subst_atomic ctxt rewrs
val rule = Drule.infer_instantiate ctxt1
[ (("ys", 0), Thm.cterm_of ctxt1 fs'), (("d'", 0), Thm.cterm_of ctxt1 d'), (("p", 0), Thm.cterm_of ctxt1 p),
(("xs", 0), Thm.cterm_of ctxt1 fs), (("d", 0), Thm.cterm_of ctxt1 d)]
@{thm map_of_default_list_all2_cases}
val intros = corres_top @ @{thms list_all2_prod_cons list_all2_prod_nil conjI refl}
val ([prop'], ctxt2) = Variable.import_terms false [prop] ctxt1
(* NOTE(review): prove is not defined in this view; presumably Goal.prove or
   a local alias of it. Confirm. *)
val thm = prove ctxt2 [] [] prop' (fn {context, ...} =>
Utils.dprint_tac (!d1) "init:" context THEN
( resolve_tac context [rule] THEN_ALL_NEW
REPEAT_ALL_NEW (Utils.dprint_subgoal_tac (!d1) "repeat:" context THEN' resolve_tac context (relevant_thms @ intros))) 1 THEN
Utils.dprint_tac (!d1) "after intros:" context)
|> singleton (Proof_Context.export ctxt2 ctxt)
|> Goal.norm_result ctxt
val _ = if (!d2) then tracing ("mk_corres_map_of_default_thm: " ^ Thm.string_of_thm ctxt thm) else ()
in
(map_of_default_new, thm)
end
(* The HOL type nat under the name measureT (presumably the type of measure
   parameters; confirm against users). *)
val measureT = HOLogic.natT;
(* Declaration attribute.  It rewrites the theorem with the simplifier of the
   current context, splits the result into its conjuncts, and adds each
   conjunct to the named theorems AutoCorresData.corres_named_thms phase. *)
fun split_Ball_attr phase = Thm.declaration_attribute (fn thm => fn context =>
  let
    val simplified = Conv.fconv_rule (Simplifier.rewrite (Context.proof_of context)) thm
    val conjuncts = HOLogic.conj_elims simplified
    val add_to_phase = Named_Theorems.add_thm (AutoCorresData.corres_named_thms phase)
  in
    fold add_to_phase conjuncts context
  end)
(* Assume correspondence propositions for the recursively called functions.
   For each name in rec_callees:
   - fix a fresh free variable for the function (named by get_const_name,
     typed by get_fn_type) and fresh frees for its arguments (get_fn_args);
   - obtain the proposition via get_fn_prop (called with true), universally
     quantified over the argument frees;
   - assume it in the context and apply its attributes.
   Returns (extended context, [(fn_name, (function free, assumed theorems))]). *)
fun assume_called_functions_corres ctxt rec_callees
get_fn_type get_fn_prop get_fn_args get_const_name =
let
(* maybe_fn_ptr_info = SOME t reuses t as the function term instead of fixing
   a new free. Within this function it is always called with NONE. *)
fun assume_func ctxt fn_name maybe_fn_ptr_info =
let
val fn_args = get_fn_args fn_name
val (fn_free, ctxt') = case maybe_fn_ptr_info of
SOME l2_term => (l2_term, ctxt)
| NONE =>
let
val ([fixed_fn_name], ctxt') = Variable.variant_fixes [get_const_name fn_name] ctxt
val fn_free = Free (fixed_fn_name, get_fn_type fn_name)
in (fn_free, ctxt') end
val (arg_names, ctxt'')
= Variable.variant_fixes ((map fst fn_args)) ctxt'
val fn_arg_terms = map (fn (n, T) => Free (n, T)) (arg_names ~~ (map snd fn_args))
(* Quantify the proposition over the argument frees; the optional parameter
   list returned by get_fn_prop (params_opt) is ignored here. *)
val assumptions =
get_fn_prop ctxt'' true fn_name fn_free fn_arg_terms
|> (fn (current_phase_prop, params_opt) => (
[current_phase_prop |> apfst
(fold Logic.all (rev fn_arg_terms))]))
|> map (apfst (Sign.no_vars ctxt' #> Thm.cterm_of ctxt'))
val (thms, ctxt''') = Assumption.add_assumes (map fst assumptions) ctxt''
(* apply the attributes returned by get_fn_prop to the assumed theorems *)
val (thms, ctxt'''') = ctxt'''
|> fold_map (fn (thm, attrs) => Thm.proof_attributes attrs thm)
(thms ~~ map snd assumptions)
in
(fn_free, thms, ctxt'''')
end
val (res, ctxt') = ctxt |> fold_map (
fn (fn_name, is_fn_ptr_param) =>
fn ctxt =>
let
val (free, thms, ctxt') =
assume_func ctxt fn_name is_fn_ptr_param
in
((fn_name, (free, thms)), ctxt')
end)
(map (fn f => (f, NONE)) (Symset.dest rec_callees))
in
(ctxt', res)
end
(* Names of all callees whose constant term (a key of callee_consts) occurs as
   an atom in body. *)
fun get_body_callees
    (callee_consts: string Termtab.table)
    (body: term)
    : symset =
  Term.fold_aterms (fn atom => fn acc =>
      (case Termtab.lookup callee_consts atom of
        SOME callee => callee :: acc
      | NONE => acc))
    body []
  |> Symset.make;
(* Names of the callees in callee_terms whose term occurs in body.  Entries
   of callee_terms are (name, (term, theorems)). *)
fun get_rec_callees
    (callee_terms: (string * (term * thm list)) list)
    (body: term)
    : symset =
  let
    val by_term = Termtab.make (map (fn (callee, (const, _)) => (const, callee)) callee_terms)
  in
    get_body_callees by_term body
  end;
(* Whether the group is recursive according to infos.  An empty group counts
   as recursive.  All members are asserted to agree with the first one. *)
fun is_recursive_group infos group =
  let
    fun rec_of f = FunctionInfo.is_function_recursive (the (Symtab.lookup infos f))
    val recursive = (case group of [] => true | f :: _ => rec_of f)
    val _ = @{assert} (forall (fn f => rec_of f = recursive) group)
  in
    recursive
  end
(* Focus on subgoal i, apply the attribute lists attrss to the focused
   premises pairwise, and run tac on the resulting facts and context. *)
fun prove_induction_case tac (attrss:attribute list list) ctxt i =
  Subgoal.FOCUS (fn {context, prems, ...} =>
    let
      val attributed = prems ~~ attrss
      val (facts, context') =
        fold_map (fn (prem, attrs) => Thm.proof_attributes attrs prem) attributed context
    in
      tac facts context'
    end) ctxt i
(* Tactic for mcont goals at subgoal i.  It repeatedly does one of:
   - close the goal with mcont_id', or
   - resolve with mcont2mcont_call, then optionally rewrite with gfp_lub_fun
     and gfp_le_fun used right-to-left ([symmetric]). *)
fun mcont_tac ctxt i =
REPEAT (
resolve_tac ctxt @{thms mcont_id'} i
ORELSE
(resolve_tac ctxt @{thms mcont2mcont_call} i THEN
TRY (EqSubst.eqsubst_tac ctxt [0] @{thms gfp_lub_fun [symmetric]} i) THEN
TRY (EqSubst.eqsubst_tac ctxt [0] @{thms gfp_le_fun [symmetric]} i)))
(* Generic admissibility tactic, parameterised by basic_rules.
   Each stage runs only if the previous one did not reduce the number of
   subgoals (see THEN_UNSOLVED):
   1. repeatedly resolve with basic_rules;
   2. debug trace;
   3. repeatedly resolve with admissible_subst_fun_lub_fun_ord and
      basic_rules.  Then, on every new subgoal, either solve it completely
      with mcont_tac or repeatedly resolve with basic_rules.
   Precedence: THEN' binds tighter than ORELSE', so the trace belongs to the
   mcont_tac alternative. *)
fun gen_corres_admissible_tac basic_rules ctxt =
(REPEAT' (resolve_tac ctxt basic_rules)) THEN_UNSOLVED
(Utils.dprint_subgoal_tac (!d1) "after simp" ctxt) THEN_UNSOLVED
( (REPEAT' (resolve_tac ctxt (@{thms admissible_subst_fun_lub_fun_ord} @ basic_rules))) THEN_ALL_NEW
((Utils.dprint_subgoal_tac (!d1) "after resolve admissible_subst_fun_lub_fun_ord" ctxt) THEN' SOLVED' (mcont_tac ctxt) ORELSE'
REPEAT_ALL_NEW (resolve_tac ctxt basic_rules THEN' (Utils.dprint_subgoal_tac (!d1) "after solve resolve" ctxt))))
(* gen_corres_admissible_tac with the structural admissibility rules
   (admissible_imp, admissible_all, admissible_imp') followed by the named
   theorems corres_admissible of ctxt. *)
fun corres_admissible_tac ctxt =
  gen_corres_admissible_tac
    (@{thms admissible_imp admissible_all admissible_imp'} @
      Named_Theorems.get ctxt @{named_theorems corres_admissible})
    ctxt
(* Prove the correspondence propositions for a group of functions.
   Each element of props is (def, params, arbitrary_vars, (prop, attrs)).
   - Non-recursive: each prop is proved separately.  The first matching
     occurrence is rewritten with the abstracted definitions (occurrence 1),
     then solve_non_recursive runs.
   - Recursive: all props are proved simultaneously in this order:
     1. induction with induct_thms, generalising those params and
        arbitrary_vars that are Frees;
     2. corres_admissible_tac on subgoals 1..N;
     3. matching corres_top N times;
     4. solve_recursive attrss.
     The unifier search bound is multiplied by N for the induction and
     admissibility steps.
   Returns the list of proven theorems. *)
fun prove_functions is_recursive induct_thms solve_non_recursive solve_recursive ctxt props =
let
val defs = map (#1) props
val paramss = map (#2) props
val arbitrary_varss = map (#3) props
val preds = map (#1 o #4) props
val attrss = map (#2 o #4) props
val all_varss = map (fn (xs, ys) => xs @ ys) (paramss ~~ arbitrary_varss)
val N = length props
val top_thms = Named_Theorems.get ctxt @{named_theorems corres_top}
val defs = defs |> map (Local_Defs.abs_def_rule ctxt)
fun prove_non_recursive (def, _, _, (prop, attribs)) =
Goal.prove ctxt [] [] prop (fn {context, prems} =>
EqSubst.eqsubst_tac context [1] defs 1 THEN
solve_non_recursive context)
(* one admissibility goal per function, starting at subgoal i *)
fun admissibility_tac ctxt i = Seq.INTERVAL (corres_admissible_tac ctxt) i (i + N - 1)
(* larger groups need a larger unifier search bound *)
val bump_unify_bound = Config.map Unify.search_bound (fn n => n * N)
val _ = if not (!d1) then () else
let
val _ = tracing (big_list_of_terms "prove_functions preds:" ctxt preds)
val _ = tracing ("all_varss: " ^ @{make_string} all_varss)
val _ = if (!d2) then tracing (big_list_of_thms "induct_thms:" ctxt induct_thms) else ()
in () end
val thms =
if is_recursive then
let
in
Goal.prove_common ctxt NONE [] [] preds (fn {context, prems = _} =>
Utils.dprint_tac (!d1) "prove_functions before induct" context THEN
DETERM (Induct.induct_tac (bump_unify_bound context) false
[]
(map (map_filter (try dest_Free)) all_varss) []
(SOME induct_thms) [] 1) THEN
Utils.dprint_tac (!d1) "prove_functions after induct" context THEN
admissibility_tac (bump_unify_bound context) 1 THEN
Utils.dprint_tac (!d1) "prove_functions after admissiblity_tac" context THEN
REPEAT_DETERM_N N (match_tac context top_thms 1) THEN
Utils.dprint_tac (!d1) "prove_functions after top" context THEN
solve_recursive attrss context THEN
Utils.dprint_tac (!d1) "prove_functions after solve_recursive" context)
end
else
map prove_non_recursive props
in
thms
end
(* Apply tac, given the current number of subgoals, to st. *)
fun WITH_NSUBGOALS tac st =
  let val goal_count = Thm.nprems_of st
  in tac goal_count st end
(* Focus on a subgoal, with its premises as facts.  Resolve it
   deterministically with one of its own premises or one of all_prems, then
   recurse on every new subgoal with the accumulated premises.  The
   recursion stops where resolution leaves no new subgoals.  It fails if
   resolution fails on any of them. *)
fun subgoal_assm_tac all_prems = Subgoal.FOCUS_PREMS (fn {context, prems, ...} =>
let
val all_prems' = prems @ all_prems
in
(DETERM' (resolve_tac context all_prems') THEN_ALL_NEW subgoal_assm_tac all_prems' context) 1
end)
(* Focus on a subgoal and proceed as follows:
   1. resolve it deterministically with one of thms;
   2. resolve each new subgoal deterministically with one of the focused
      premises;
   3. finish each remaining subgoal with subgoal_assm_tac, passing those
      premises.
   (The previously computed but unused length of thms has been removed.) *)
fun subgoal_intro_tac thms = Subgoal.FOCUS (fn {context, prems, ...} =>
  (DETERM' (resolve_tac context thms)
    THEN_ALL_NEW (
      DETERM' (resolve_tac context prems)
      THEN_ALL_NEW subgoal_assm_tac prems context)) 1)
(* Focus on a subgoal and instantiate each rule in thms with the first
   length thms focused premises (via OF).  Resolve the subgoal with the
   instantiated rules, then finish every new subgoal with subgoal_assm_tac
   using all focused premises. *)
fun subgoal_intro_tac' thms = Subgoal.FOCUS (fn {context, prems, ...} =>
  let
    val hyps = take (length thms) prems
    val instantiated = map (fn rule => rule OF hyps) thms
    val finish = subgoal_assm_tac prems context
  in
    (resolve_tac context instantiated THEN_ALL_NEW finish) 1
  end)
(* Focus on a subgoal, resolve it deterministically with the rules, and then
   recurse on every new subgoal with instantiate=false.  The recursion uses
   the focused premises (plus, when instantiating, the premises that were not
   consumed) as the new rules.
   - instantiate = true: Subgoal.FOCUS is used, and each rule in thms is
     instantiated via OF with the first length thms focused premises.
   - instantiate = false: Subgoal.FOCUS_PREMS is used, and thms are resolved
     as given. *)
fun subgoal_intro_tac'' {instantiate} thms =
(if instantiate then Subgoal.FOCUS else Subgoal.FOCUS_PREMS) (fn {context, prems, ...} =>
let
val (thm_insts, other_prems) =
if instantiate then
let
val n = length thms
val hyps = take n prems
val rest = drop n prems
val inst_thms = map (fn thm => thm OF hyps) thms
in
(inst_thms, rest)
end
else (thms, [])
in
(DETERM' (resolve_tac context thm_insts)
THEN_ALL_NEW (
(subgoal_intro_tac'' {instantiate=false} (prems@other_prems) context))) 1
end)
(* Apply the term f to the arguments params, left to right. *)
fun apply f params = list_comb (f, params)
(* Abstract t over the terms xs.  The first element becomes the outermost
   lambda. *)
fun lambdas xs t = fold_rev Term.lambda xs t
(* Define the functions of one recursive group (or a single non-recursive
   function) in the given phase, and prove and store their correspondence
   theorems.
   Arguments:
   - functions: entries (name, (body, thm, _)).  The body takes the recursive
     callees (from the previous phase's info) as leading lambda-bound
     arguments; fill_body applies them to frees.  thm is used as an
     introduction rule in the proofs below.
   - get_fn_args_def / get_fn_args_prop: argument lists for the definitions
     and for the correspondence propositions, respectively.
   - get_fn_prop: yields the proposition with attributes, plus an optional
     replacement parameter list.
   Steps:
   1. define the constants via Utils.define_functions;
   2. look up the definitions and the induction rules;
   3. state the propositions;
   4. prove them with prove_functions;
   5. store each under AutoCorresData.corres_thm_name.
   Returns (fn_names, updated local theory). *)
fun define_funcs_single_recursive_group
(skips: FunctionInfo.skip_info)
(phase : FunctionInfo.phase)
(prog_info: ProgramInfo.prog_info)
(qualify: binding -> binding)
concealed
(get_const_name : string -> string)
(get_fn_type : string -> typ)
(get_fn_prop : Proof.context -> bool -> string -> term -> term list ->
((term * attribute list) * term list option))
(get_fn_args_def : string -> (string * typ) list)
(get_fn_args_prop : string -> (string * typ) list)
(functions : (string * (term * thm * (string * typ) list)) list)
(lthy : local_theory)
: string list * local_theory =
let
val _ = @{assert} (not (null functions));
val fn_names = map fst functions
val fn_bodies = map (snd #> #1) functions
val fn_thms = map (#2 o #2) functions
val N = length fn_names
val prev_phase = FunctionInfo.prev_phase skips phase
val fn_names_str = commas (map get_const_name fn_names);
val filename = ProgramInfo.get_prog_name prog_info
val _ = writeln ("Defining (" ^ FunctionInfo.string_of_phase phase ^ ") " ^ fn_names_str)
(* previous-phase info of a function; raises Option if it is missing *)
fun get_prev_info lthy name =
AutoCorresData.get_function_info (Context.Proof lthy) filename prev_phase name |> the
val is_recursive = FunctionInfo.is_function_recursive (get_prev_info lthy (hd fn_names))
val _ = assert (length fn_names = 1 orelse is_recursive)
"define_funcs passed multiple functions, but they don't appear to be recursive."
(* instantiate the leading lambdas of a body with frees standing for the
   recursive callees *)
fun fill_body fn_name body =
let
val fn_info = get_prev_info lthy fn_name
val rec_calls = map (fn x => Free (get_const_name x, get_fn_type x)) (Symset.dest (FunctionInfo.get_rec_callees fn_info))
in
body
|> (fn t => betapplys (t, rec_calls))
end
(* (name, constant name, argument frees, filled body) for Utils.define_functions *)
val defs = map (
fn (fn_name, fn_body) => let
val fn_args = get_fn_args_def fn_name
val (fn_free :: arg_frees, _) = Variable.variant_fixes
(get_const_name fn_name :: map fst fn_args) lthy
in (fn_name, get_const_name fn_name,
(arg_frees ~~ map snd fn_args),
fill_body fn_name fn_body) end)
(fn_names ~~ fn_bodies)
val _ = if (!d1) then tracing ("define_funcs_single_recursive_group: before definition") else ()
val lthy = lthy |> AutoCorresData.in_theory'
(fn lthy =>
let
val (_, lthy') = lthy |> Utils.define_functions defs qualify true is_recursive "spec_monad_gfp"
[AutoCorresData.define_function_attribute concealed filename skips phase] [] []
in lthy' end)
val _ = if (!d1) then tracing ("define_funcs_single_recursive_group: before fn_defs (0)") else ()
(* the definitions are read back from the function infos of the current phase *)
val fn_def_thms = map (FunctionInfo.get_definition o the o (AutoCorresData.get_function_info (Context.Proof lthy) filename phase)) fn_names
val _ = if (!d1) then tracing ("define_funcs_single_recursive_group: before fn_defs (1)") else ()
val fn_def_thms = fn_def_thms |> map (safe_mk_meta_eq)
val _ = if (!d1) then tracing ("define_funcs_single_recursive_group: before induct_thms") else ()
(* induction rules registered by Mutual_CCPO_Rec for the head constant of the
   first definition (empty if none are registered) *)
fun get_induct_thms () =
let
val c = hd fn_def_thms |> Thm.lhs_of |> Thm.term_of |> Term.head_of
in
Mutual_CCPO_Rec.lookup_info_trimmed (Context.Proof lthy) c |> the_list |> maps #inducts
end
val induct_thms = timeit_msg 1 lthy (fn _ => "induct_thms") (get_induct_thms)
val _ = if (!d1) then tracing ("define_funcs_single_recursive_group: before combined_thms") else ()
val combined_callees = map (get_callees o get_prev_info lthy) (map fst functions)
val _ = if (!d1) then tracing ("define_funcs_single_recursive_group: before combined_normal_calls") else ()
val combined_normal_calls =
map fst combined_callees |> flat |> sort_distinct fast_string_ord
(* NOTE(review): reads the current-phase info of each non-recursive callee;
   this presumes callees are processed before their callers. Confirm. *)
fun get_corres_thm name = AutoCorresData.get_function_info (Context.Proof lthy) filename phase name
|> the |> FunctionInfo.get_corres_thm
val _ = if (!d1) then tracing ("define_funcs_single_recursive_group: before nrec_corres_thms") else ()
val nrec_corres_thms = map get_corres_thm combined_normal_calls
val _ = if (!d1) then tracing ("define_funcs_single_recursive_group: before props") else ()
(* for each function: (definition, params, variables to generalise,
   (proposition, attributes)) as expected by prove_functions *)
val (props, ctxt') = lthy |> fold_map (
fn (fn_name, def) => fn ctxt =>
let
val fn_const = Utils.get_term lthy (get_const_name fn_name)
val (params, ctxt') = Utils.fix_variant_frees (get_fn_args_prop fn_name) ctxt
val ((corres_prop_current_phase, attrs), params_opt) = get_fn_prop ctxt' false fn_name fn_const params
val params' = the_default params params_opt
val changed_params = filter_out (member (op =) params) params'
val ((arbitrary_vars, corres_prop_current_phase), ctxt') = Utils.import_universal_prop corres_prop_current_phase ctxt'
in
((def, params', arbitrary_vars @ changed_params, (corres_prop_current_phase, attrs)), ctxt')
end) (fn_names ~~ fn_def_thms)
fun solve_recursive _ ctxt = REPEAT (subgoal_intro_tac'' {instantiate=false} fn_thms ctxt 1)
(* match a supplied theorem, then close the new subgoals with callee
   correspondence theorems and assumptions; on failure print the goal state *)
fun solve_non_recursive ctxt = (
match_tac ctxt fn_thms
THEN_ALL_NEW
(((EVERY' [match_tac ctxt (nrec_corres_thms)]) THEN_ALL_NEW Method.assm_tac ctxt)
ORELSE'
Method.assm_tac ctxt)
ORELSE' (K (print_tac ctxt "define_funcs_single_recursive_group final proof failed")))
1
val _ = if (!d1) then tracing ("define_funcs_single_recursive_group: before corres_thms") else ()
val corres_thms = prove_functions is_recursive induct_thms solve_non_recursive solve_recursive ctxt' props
val corres_thms =
corres_thms
|> Variable.export ctxt' lthy
|> map (Goal.norm_result lthy)
val _ = if (!d1) then tracing ("define_funcs_single_recursive_group: before define_lemma") else ()
val (corres_thms, lthy) = lthy
|> fold_map (fn (name, thm) => Utils.define_lemma (Binding.name (AutoCorresData.corres_thm_name prog_info phase name))
[AutoCorresData.corres_thm_attribute filename skips phase name] thm)
(fn_names ~~ corres_thms)
val _ = if (!d1) then tracing ("define_funcs_single_recursive_group: end") else ()
in
(fn_names, lthy)
end
(* Define the given functions in phase.
   - If the list is empty, or the first function is recursive according to
     its previous-phase info, the whole list is one group.
   - Otherwise each function forms its own group.
   Each group is handled by define_funcs_single_recursive_group.
   Returns the names of all defined functions, in order, and the updated
   context. *)
fun gen_define_funcs
(skips: FunctionInfo.skip_info)
(phase : FunctionInfo.phase)
(prog_info: ProgramInfo.prog_info)
(qualify: binding -> binding)
concealed
(get_const_name : string -> string)
(get_fn_type : string -> typ)
(get_fn_prop: Proof.context -> bool -> string -> term -> term list ->
((term * attribute list) * term list option))
(get_fn_args_def : string -> (string * typ) list)
(get_fn_args_prop : string -> (string * typ) list)
(functions : (string * (term * thm * (string * typ) list)) list)
(lthy : local_theory)
: string list * Proof.context =
let
val prev_phase = FunctionInfo.prev_phase skips phase
fun is_recursive name = AutoCorresData.get_function_info (Context.Proof lthy) (ProgramInfo.get_prog_name prog_info) prev_phase
name |> the |> FunctionInfo.is_function_recursive
val funcss = if null functions orelse is_recursive (fst (hd functions))
then [functions]
else map (fn x => [x]) functions
in
([], lthy) |> fold (fn funcs => fn (names, lthy) =>
let
val (new_names, lthy') = lthy |>
define_funcs_single_recursive_group
skips phase prog_info qualify concealed get_const_name get_fn_type get_fn_prop get_fn_args_def get_fn_args_prop
funcs
in (names @ new_names, lthy') end) funcss
end
(* gen_define_funcs with one argument-list function used for both the
   definitions and the correspondence propositions. *)
fun define_funcs
    (skips: FunctionInfo.skip_info)
    (phase : FunctionInfo.phase)
    (prog_info: ProgramInfo.prog_info)
    (qualify: binding -> binding)
    concealed
    (get_const_name : string -> string)
    (get_fn_type : string -> typ)
    (get_fn_prop: Proof.context -> bool -> string -> term -> term list ->
      ((term * attribute list) * term list option))
    (get_fn_args : string -> (string * typ) list)
    (functions : (string * (term * thm * (string * typ) list)) list)
    (lthy : local_theory) =
  let
    val args_for_defs = get_fn_args
    val args_for_props = get_fn_args
  in
    gen_define_funcs skips phase prog_info qualify concealed get_const_name get_fn_type
      get_fn_prop args_for_defs args_for_props functions lthy
  end
(* Result of converting a single function.
   - body: the converted body;
   - proof: theorem accompanying the converted body;
   - rec_callees: names of recursively called functions;
   - callee_consts: the term standing for each callee (abstract_fn_body
     abstracts over these terms for the recursive callees);
   - arg_frees: the function arguments as (name, type) of free variables. *)
type convert_result = {
body: term,
proof: thm,
rec_callees: symset,
callee_consts: term Symtab.table,
arg_frees: (string * typ) list
}
(* Run f inside the globals locale of filename (AutoCorresData.in_locale_result).
   prog_info, skips, phase and clique are accepted but unused. *)
fun in_corres_locale_result prog_info skips phase filename clique f lthy =
  AutoCorresData.in_locale_result
    (NameGeneration.intern_globals_locale_name (Proof_Context.theory_of lthy) filename)
    f lthy
(* Same as in_corres_locale_result: run f inside the globals locale of
   filename and return its result. *)
fun in_corres_locale prog_info skips phase filename clique f lthy =
  in_corres_locale_result prog_info skips phase filename clique f lthy
(* Close a converted body.  It is abstracted first over its argument frees
   and then over the terms of its recursive callees (as recorded in
   prev_fn_infos), so the callee lambdas end up outermost. *)
fun abstract_fn_body
    (prev_fn_infos: FunctionInfo.function_info Symtab.table)
    (fn_name, {body, callee_consts, arg_frees, ...} : convert_result) =
  let
    val (_, rec_callees) = get_callees' prev_fn_infos fn_name
    val rec_calls = map (the o Symtab.lookup callee_consts) rec_callees
  in
    body
    |> fold_rev lambda (map Free arg_frees)
    |> fold_rev lambda rec_calls
  end;
(* For each key of t1 that is also present in t2, take t2's value.
   Keys only in t2 are not added. *)
fun update_defined t1 t2 =
  Symtab.map (fn key => fn x => the_default x (Symtab.lookup t2 key)) t1
(* Restrict table t to the given keys: the result holds exactly the entries
   (k, v) of t whose key k is listed in keys.  Keys absent from t are
   ignored.  Duplicate keys are tolerated; Symtab.make used to raise DUP
   when a key present in t was listed twice. *)
fun restrict_domain keys t =
  Symtab.empty
  |> fold (fn k =>
      (case Symtab.lookup t k of
        NONE => I
      | SOME v => Symtab.update (k, v))) keys
(* Partition infos into (entries not in names, entries in names). *)
fun split_infos infos names =
  (fold Symtab.delete_safe names infos, restrict_domain names infos)
(* Default preparation step that does nothing: returns lthy unchanged.
   do_prepare compares against this exact value with pointer_eq to skip the
   preparation and its timing message entirely. *)
fun no_prepare (finfos: FunctionInfo.function_info Symtab.table) (clique: string list) (lthy:local_theory) = lthy
(* Run the preparation step for a clique with the previous phase's infos,
   wrapped in timing messages.  If prepare is (pointer-equal to)
   no_prepare, lthy is returned directly. *)
fun do_prepare skips prog_info phase prepare clique lthy =
  if pointer_eq (prepare, no_prepare) then lthy
  else
    let
      val infos =
        AutoCorresData.get_default_phase_info (Context.Proof lthy)
          (ProgramInfo.get_prog_name prog_info) (FunctionInfo.prev_phase skips phase)
    in
      timeit_prepare_msg' 1 lthy phase clique (fn _ => prepare infos clique lthy)
    end
(* Convert and define one clique of functions in the given phase.
   Steps:
   1. Convert each function of todo_clique with convert, using Par_List.map
      when parallel is set.
   2. Pass the results to define (skipped if there are none).
   3. Recompute the call-graph information of the new functions with
      FunctionInfo.recalc_callees and store it back via declarations.
   Returns the groups for the next phase: [todo_clique] if the clique is not
   recursive, otherwise the groups produced by recalc_callees. *)
fun convert_and_define_clique
(skips: FunctionInfo.skip_info)
(prog_info: ProgramInfo.prog_info)
(phase: FunctionInfo.phase)
(parallel: bool)
(convert: local_theory -> FunctionInfo.function_info Symtab.table ->
string -> convert_result)
(define: local_theory ->
convert_result Symtab.table ->
local_theory)
(todo_clique: string list)
(lthy: local_theory)
: (string list list * local_theory)
=
let
val prev_phase = FunctionInfo.prev_phase skips phase
val par_map = if parallel then Par_List.map else map
val filename = ProgramInfo.get_prog_name prog_info
val existing_infos_prev_phase = AutoCorresData.get_default_phase_info (Context.Proof lthy) filename prev_phase
val existing_infos_current_phase = AutoCorresData.get_default_phase_info (Context.Proof lthy) filename phase
val recursive_group = is_recursive_group existing_infos_prev_phase todo_clique
val loc = the (Named_Target.bottom_locale_of lthy)
val _ = case todo_clique of [] => () | stuff =>
verbose_msg 0 lthy (fn _ => "Conversions (" ^ FunctionInfo.string_of_phase phase ^
") for: " ^ commas stuff ^ " in locale " ^ loc)
val conv_results =
todo_clique
|> par_map (fn fname =>
(fname, timeit_conversion_msg' 1 lthy phase fname (fn () => convert lthy existing_infos_prev_phase fname)))
|> Symtab.make;
val lthy = if Symtab.is_empty conv_results then lthy else
define lthy conv_results;
val new_infos = AutoCorresData.get_phase_info (Context.Proof lthy) filename phase
|> the |> restrict_domain todo_clique
val new_infoss = FunctionInfo.recalc_callees existing_infos_current_phase new_infos
val new_groups = map (map fst o Symtab.dest) new_infoss
(* NOTE(review): group is not used in the body, so the same declaration
   (transferring all of new_infoss) is issued once per group. Confirm
   whether a single declaration was intended. *)
val lthy = lthy
|> fold (fn group =>
AutoCorresData.in_theory' (
Local_Theory.declaration {pervasive=true, syntax=false, pos=⌂} (fn phi =>
AutoCorresData.map_default_phase_info filename phase
(FunctionInfo.transfer_call_graph_infoss new_infoss))
))
new_groups
val groups_next_phase = if not recursive_group then [todo_clique] else new_groups
in (groups_next_phase, lthy) end
(* Convert and define a sequence of cliques in the given phase, in order.
   For each clique:
   - If none of its functions has an info in the current phase yet, run
     prepare (via do_prepare) and then convert_and_define_clique inside the
     program's globals locale.
   - Otherwise all of its functions must already be defined (asserted).  The
     clique is then only re-split with FunctionInfo.recalc_callees on the
     existing infos.
   Returns the concatenation of the resulting groups (to process in the next
   phase) and the updated local theory.  base_locale_opt is currently
   unused. *)
fun gen_convert_and_define_cliques
    (prepare: FunctionInfo.function_info Symtab.table -> string list -> local_theory -> local_theory)
    (skips: FunctionInfo.skip_info)
    (base_locale_opt: string option)
    (prog_info: ProgramInfo.prog_info)
    (phase: FunctionInfo.phase)
    (parallel: bool)
    (convert: local_theory -> FunctionInfo.function_info Symtab.table ->
      string -> convert_result)
    (define: local_theory ->
      convert_result Symtab.table ->
      local_theory)
    (cliques: string list list)
    (lthy: local_theory)
    : (string list list * local_theory)
  =
  let
    val filename = ProgramInfo.get_prog_name prog_info
    fun do_or_skip_group clique lthy =
      let
        val existing_infos_current_phase =
          AutoCorresData.get_default_phase_info (Context.Proof lthy) filename phase
        (* functions of the clique not yet processed in this phase; a direct
           table query instead of a linear scan over all keys *)
        val todo_clique = clique |> filter_out (Symtab.defined existing_infos_current_phase)
        val (groups_next_phase, lthy) =
          if todo_clique = clique
          then
            lthy
            |> do_prepare skips prog_info phase prepare clique
            |> in_corres_locale_result prog_info skips phase filename clique
                 (convert_and_define_clique skips prog_info phase parallel convert define clique)
          else
            let
              (* partially processed cliques are not supported *)
              val _ = @{assert} (null todo_clique)
              val (infos_without_names, infos_of_names) =
                split_infos existing_infos_current_phase clique
              val groups_next_phase =
                FunctionInfo.recalc_callees infos_without_names infos_of_names
                |> map (map fst o Symtab.dest)
            in
              (groups_next_phase, lthy)
            end
        val _ = verbose_msg 1 lthy (fn _ => "groups_next_phase: " ^ @{make_string} groups_next_phase)
      in
        (groups_next_phase, lthy)
      end
  in
    (* process the cliques in order, threading lthy, and concatenate their
       groups (avoids repeated list appends) *)
    fold_map do_or_skip_group cliques lthy |>> flat
  end
(* gen_convert_and_define_cliques without a preparation step (no_prepare). *)
val convert_and_define_cliques = gen_convert_and_define_cliques no_prepare
end