File ‹local_var_extract.ML›
structure LocalVarExtract =
struct
(* Make the contents of structure Prog available unqualified.
   NOTE(review): presumably this is where the program-node constructors used by
   parse_l1 (Modify, Seq, Call, ...) and get_node_data come from -- confirm. *)
open Prog
(* Short local names for the timing and logging helpers of Utils. *)
val timeit_msg = Utils.timeit_msg
val timeap_msg_tac = Utils.timeap_msg_tac
val timing_msg' = Utils.timing_msg'
val verbose_msg = Utils.verbose_msg
(* Integer configuration option with default 0; it is read when an empty
   preservation cache is created (pres_cache_empty). The cache code in this
   file switches the cache off for negative values, searches for a cached
   superset of the requested variables for values > 0, and additionally joins
   all cached theorems of a pattern for values >= 2. *)
val preservation_cache_mode = Attrib.setup_config_int @{binding "preservation_cache_mode"} (K 0)
(* Tables keyed by lists of strings, ordered by fast_string_ord lifted to lists. *)
structure Symstab = Table(type key = string list val ord = list_ord fast_string_ord)
(* Infix shorthands and aliases for the Varset operations used in this file. *)
infix 1 INTER MINUS UNION

val empty_set = Varset.empty
val make_set = Varset.make
val union_sets = Varset.union_sets

(* Note that MINUS hands its operands to Varset.subtract in swapped order. *)
fun op INTER (a, b) = Varset.inter a b
fun op MINUS (a, b) = Varset.subtract b a
fun op UNION (a, b) = Varset.union a b
local
  (* Compare two (name, type) pairs with Term_Ord.var_ord, treating each name
     as an indexname with index 0. *)
  fun compare (v, w) =
    let
      fun as_var (n, T) = ((n, 0), T)
    in
      Term_Ord.var_ord (as_var v, as_var w)
    end

  (* Same comparison, but on the demangled names. *)
  fun extern_compare prog_info (v, w) =
    let
      fun demangled (n, T) = (ProgramInfo.demangle_name prog_info n, T)
    in
      compare (demangled v, demangled w)
    end
in
  (* Sort a list of variables by their demangled names (and types). *)
  fun sort_extern prog_info vars = sort (extern_compare prog_info) vars

  (* Elements of a variable set, sorted by demangled name; names stay mangled. *)
  fun dest_sort_extern prog_info s = sort_extern prog_info (Varset.dest s)

  (* Elements of a variable set with demangled names, as an ordered list. *)
  fun dest_extern prog_info s =
    let
      val demangled = map (apfst (ProgramInfo.demangle_name prog_info)) (Varset.dest s)
    in
      Ord_List.make compare demangled
    end
end
(* Local names for frequently used Utils helpers; from here on "warning"
   refers to Utils.ac_warning. *)
val warning = Utils.ac_warning
val apply_tac = Utils.apply_tac
val the' = Utils.the'
(* Internal name and display name of the exit status variable;
   var_set_to_isa_list replaces the former by the latter. *)
val exit_status_name = "__exit_status"
val exit_status_pretty_name = "exit_status"
(* The global exception variable: its name, its (name, type) pair, the
   corresponding Free term and the singleton variable set containing it. *)
val exn_name = NameGeneration.global_exn_var_name
val exn_name_type = (exn_name, HP_TermsTypes.c_exntype_ty)
val exn_free = Free exn_name_type;
val exn_var = make_set [exn_name_type];
(* Build the simplifier context used for this phase: start from base_ss in a
   context with position reports switched off, then add a fixed list of simp
   rules plus the "state_simp" named theorems, and the record, arg_cong and
   fun_cong simprocs plus the "state_simp" code simprocs. The named theorems
   are looked up in the context that was passed in. *)
fun setup_l2_ss base_ss ctxt =
  let
    val extra_simps = Named_Theorems.get ctxt @{named_theorems "state_simp"}
    val extra_procs = @{code_simprocs "state_simp"}
    val fixed_simps = @{thms
      globals_surj ucast_id pred_conj_def
      Hoare.Collect_False
      Set.mem_Collect_eq Set.Int_iff Set.empty_iff
      simp_thms HOL.implies_True_equals prod.sel
      Pure.triv_forall_equality comp_def K_eq_cong}
    val fixed_procs = [Record.simproc, @{simproc arg_cong}, @{simproc fun_cong}]
    val quiet_ctxt = Context_Position.set_visible false ctxt
  in
    put_simpset base_ss quiet_ctxt
    |> Simplifier.add_simps (fixed_simps @ extra_simps)
    |> fold Simplifier.add_proc (fixed_procs @ extra_procs)
  end
(* Turn a variable set into the name-hint term produced by CLocals.name_hints.
   The variables are taken in dest_sort_extern order and only their names are
   used; the exit status variable is shown under its pretty name, every other
   name is demangled via prog_info. *)
fun var_set_to_isa_list ctxt prog_info s =
  let
    fun display_name n =
      if n = exit_status_name then exit_status_pretty_name
      else ProgramInfo.demangle_name prog_info n
    val names = map (display_name o fst) (dest_sort_extern prog_info s)
  in
    CLocals.name_hints ctxt names
  end
(* Replace occurrences of the given variable terms inside term.
   For each (name, var_term) pair, in list order, whose var_term occurs in the
   current term, every occurrence is abstracted and replaced by
   name_map (name, type of var_term). Returns the (name, type) pairs of the
   variables that were found together with the rewritten term; a variable
   found earlier in the list appears later in the returned list. *)
fun convert_local_vars name_map term vars =
  case vars of
    [] => ([], term)
  | (var_name, var_term) :: rest =>
      if not (Utils.contains_subterm var_term term) then
        convert_local_vars name_map term rest
      else
        let
          val var_type = fastype_of var_term
          val replacement = name_map (var_name, var_type)
          val term' = betapply (Utils.abs_over var_name var_term term, replacement)
          val (found, final_term) = convert_local_vars name_map term' rest
        in
          (found @ [(var_name, var_type)], final_term)
        end
(* Arguments recorded for fn_name in the function-info table.
   Raises Option if fn_name has no entry. *)
fun get_args_info l1_infos fn_name =
  FunctionInfo.get_args (the (Symtab.lookup l1_infos fn_name))
(* Variable sets (inputs, locals, outputs) of function fn_name, taken from the
   plain arguments, locals and returns recorded in its function info.
   Raises Option if fn_name has no entry. *)
fun get_variables l1_infos fn_name =
  let
    val info = the (Symtab.lookup l1_infos fn_name)
    (* Keep the name and the type of an entry, drop its remaining component. *)
    fun name_and_type (n, (T, _)) = (n, T)
    fun typed_set entries = Varset.make (map name_and_type entries)
  in
    (Varset.make (FunctionInfo.get_plain_args info),
     typed_set (FunctionInfo.get_locals info),
     typed_set (FunctionInfo.get_returns info))
  end
(* Input and output variable sets of fn_name; the locals are dropped. *)
fun get_fn_input_output_vars l1_infos fn_name =
  let
    val (inputs, _, outputs) = get_variables l1_infos fn_name
  in
    (inputs, outputs)
  end
(* First output variable of the list, or the placeholder ("void", unit) when
   the list is empty. *)
fun get_ret_var' outputs =
  case outputs of
    [] => ("void", @{typ unit})
  | first :: _ => first
(* Return variable of fn_name: the first element of its output variable set,
   or the ("void", unit) placeholder if the set is empty. *)
fun get_ret_var l1_infos fn_name =
  get_ret_var' (Varset.dest (snd (get_fn_input_output_vars l1_infos fn_name)))
(* Projections of an L2corres application. mk_corresXF_prop builds such terms
   as "L2corres st ret_xf ex_xf P A C" with A the L2 term and C the L1 term, so
   the first function returns the fifth argument (A) and the second the sixth
   (C). Both have a single clause and are therefore undefined on other terms. *)
fun dest_L2corres_term_abs @{term_pat "L2corres _ _ _ _ ?t _"} = t
fun dest_L2corres_term_conc @{term_pat "L2corres _ _ _ _ _ ?t"} = t
(* Re-express a relation between two full states as a relation between two
   globals values.
   term is rewritten with Collect_prod_inter and Collect_prod_union and then
   matched against pat; the match must yield exactly one type instantiation and
   one term instantiation (the predicate f), otherwise the pattern binding
   below fails. f is applied to two dummy states and simplified, after which
   every occurrence of the globals getter applied to a dummy state is
   abstracted (bound names "s" and "t").
   Returns SOME pat instantiated with the globals type and the abstracted
   predicate, or NONE, after a warning, if a dummy state still occurs in the
   abstracted predicate. *)
fun extract_pair_of_globals pat ctxt prog_info term =
let
(* Normalise intersections and unions of pair-set comprehensions first. *)
val term = Simplifier.rewrite_term (Proof_Context.theory_of ctxt)
(map safe_mk_meta_eq @{thms Collect_prod_inter Collect_prod_union}) [] term
val dummy_s = Free ("_dummy_state1", ProgramInfo.get_state_type prog_info)
val dummy_t = Free ("_dummy_state2", ProgramInfo.get_state_type prog_info)
(* vstateT: the state type variable of pat; vf / f: the predicate variable and
   its instantiation. *)
val ([((vstateT, _), _)], [((vf, _), f)]) = Utils.match_insts ctxt pat term
val t = Envir.beta_eta_contract (Thm.term_of f $ dummy_s $ dummy_t)
val t = Simplifier.rewrite_term (Proof_Context.theory_of ctxt)
(map mk_meta_eq @{thms split_def fst_conv snd_conv mem_Collect_eq}) [] t
val globals_getter = ProgramInfo.get_globals_getter prog_info
(* Abstract the globals of the second state first, then those of the first, so
   the result takes its arguments in the order s, t. *)
val t = Utils.abs_over "t" (globals_getter $ dummy_t) t
|> Utils.abs_over "s" (globals_getter $ dummy_s)
in
if Utils.contains_subterm dummy_s t
orelse Utils.contains_subterm dummy_t t then
(warning ("Can't parse pair of globals term: "
^ (Utils.term_to_string ctxt term)); NONE)
else
SOME (subst_vars ([(vstateT, fastype_of (globals_getter $ dummy_t))], [(vf, t)]) pat)
end
local
(* Schematic state type shared by both patterns. *)
val stateT = TVar (("'state", 0), [])
(* Pattern for a set of state pairs: Collect (case_prod (%s t. ?f s t)). *)
val pat_spec =
\<^Const>‹Collect ‹HOLogic.mk_prodT (stateT, stateT)›› $
(HOLogic.mk_case_prod (Abs ("s", stateT, Abs ("t", stateT,
Var (("f", 0), stateT --> stateT --> HOLogic.boolT) $ Bound 1 $ Bound 0))))
(* Pattern for a state-dependent set of (unit, state) pairs:
   %s. Collect (case_prod (%u t. ?f s t)); Bound 2 refers to the outer s. *)
val pat_assume = Abs ("s", stateT, \<^Const>‹Collect ‹HOLogic.mk_prodT (HOLogic.unitT, stateT)›› $
(HOLogic.mk_case_prod (Abs ("u", HOLogic.unitT, Abs ("t", stateT,
Var (("f", 0), stateT --> stateT --> HOLogic.boolT) $ Bound 2 $ Bound 0)))))
in
(* Both parsers delegate to extract_pair_of_globals with the matching pattern
   and therefore return an option. *)
fun parse_spec ctxt prog_info term = extract_pair_of_globals pat_spec ctxt prog_info term
fun parse_assume ctxt prog_info term = extract_pair_of_globals pat_assume ctxt prog_info term
end
(* Parse a function of the state into a term over the globals and local
   variable terms.
   term is applied to a dummy state. Occurrences of the globals getter on that
   state are abstracted (bound name "s"), and every variable getter term that
   occurs is replaced via name_map; the candidates are the getters from
   ProgramInfo.all_var_getters followed by the positional ones found by
   HPInter.collect_positional.
   Returns (variables read, whether the globals getter occurred, result), where
   result is SOME of the parsed term, or NONE, after a warning, if the dummy
   state could not be eliminated. *)
fun parse_expr ctxt prog_info name_map term =
let
val dummy_state = Free ("_dummy_state", ProgramInfo.get_state_type prog_info)
val term = Envir.beta_eta_contract (term $ dummy_state)
val globals_getter = ProgramInfo.get_globals_getter prog_info $ dummy_state
val globals_used = Utils.contains_subterm globals_getter term
val t = Utils.abs_over "s" globals_getter term
val all_getters = ProgramInfo.all_var_getters ctxt prog_info dummy_state |> map (apsnd fst)
val ps = HPInter.collect_positional (HPInter.mk_locals ctxt dummy_state) t
val (v1, t) = convert_local_vars name_map t (all_getters @ ps)
val t = if Utils.contains_subterm dummy_state t then
(warning ("Can't parse expression: "
^ (Utils.term_to_string ctxt term)); NONE)
else
SOME t;
in
(v1, globals_used, t)
end
(* Parse a state update into a list of single-variable updates.
   params:
     write_scope  context transformer used when reading the current value of
                  the updated variable
     read_scope   context transformer used when parsing the assigned value
     two_state    if true, term first takes an additional "old" state from
                  which values are read
   Each list element is (updated variable, variables read, whether globals are
   read, parsed new value option). The function peels one variable update off
   the term, then recurses on the residual update until the term applied to the
   dummy state is the dummy state itself. A term that is not a recognisable
   variable update is reported through Utils.invalid_term'. *)
fun gen_parse_modify (params as {write_scope, read_scope, two_state}) ctxt prog_info name_map term =
let
val dummy_state_write = Free ("_dummy_state", ProgramInfo.get_state_type prog_info)
val dummy_state_read = if two_state then Free ("_dummy_state_old", ProgramInfo.get_state_type prog_info) else dummy_state_write
(* Used only for the termination test below. *)
val trm = if two_state then (term $ dummy_state_read) else term
fun parse_modify' term =
let
val trm = if two_state then (term $ dummy_state_read) else term
val modify_clause = Envir.beta_eta_contract (trm $ dummy_state_write)
(* Outermost variable update: the variable, the optional update function and
   the state term it is applied to (the residual). *)
val ((var_name, var_type), modify_val_opt, s) = case ProgramInfo.dest_var_update modify_clause of
SOME ((var_name, var_type), modify_val_opt, SOME s) => ((var_name, var_type), modify_val_opt, s)
| _ => Utils.invalid_term' ctxt "variable update" modify_clause;
val get_value = ProgramInfo.get_var_value (write_scope ctxt) prog_info var_name dummy_state_read
fun remove_dummy_state st t = Utils.abs_over "s" st t
val (vars, globals_used, modify_val) =
case modify_val_opt of
NONE => ([], false, NONE)
| SOME modify_val =>
let
(* Apply the update function to the current value of the variable to obtain
   the new value as an expression over the read state. *)
val modify_val = betapply (modify_val, get_value)
|> Envir.beta_eta_contract
val (vars, globals_used, modify_val) = parse_expr (read_scope ctxt) prog_info name_map (remove_dummy_state dummy_state_read modify_val)
in (vars, globals_used, modify_val) end
in
(* The residual is turned back into a function of the write state and, in
   two-state mode, also of the read state. *)
((var_name, var_type), vars, globals_used, modify_val,
remove_dummy_state dummy_state_write s |> two_state ? remove_dummy_state dummy_state_read)
end
in
(* Identity update: nothing left to parse. *)
if Envir.beta_eta_contract (trm $ dummy_state_write) = dummy_state_write then []
else
let
val (updated_var, read_vars, reads_globals, term, residual) = parse_modify' term
in
(updated_var, read_vars, reads_globals, term) :: gen_parse_modify params ctxt prog_info name_map residual
end
end
(* gen_parse_modify with the context left unchanged for reading and writing:
   once for updates of a single state, once for the two-state form. *)
val parse_modify = gen_parse_modify {write_scope = I, read_scope = I, two_state = false};
val parse_modify_two_state = gen_parse_modify {write_scope = I, read_scope = I, two_state = true};
(* Parse a string that consists solely of decimal digits.
   Returns SOME i when read_int consumes all of s, NONE otherwise.
   The empty string is rejected explicitly: read_int consumes nothing from an
   empty symbol list and reports 0 with no remaining input, which previously
   made "" parse as SOME 0 (and mk_loc_ref would then build a positional
   reference with index 0 for an empty name). *)
fun int_of_string s =
  case Symbol.explode s of
    [] => NONE
  | symbols =>
      (case read_int symbols of
        (i, []) => SOME i
      | _ => NONE)
(* Build a variable reference for name n of type T. The reference is
   positional when n is a positional "In" name according to
   NameGeneration.dest_positional_name or, failing that, when n parses as an
   integer; otherwise it is a named reference. *)
fun mk_loc_ref T n =
  let
    val position =
      case NameGeneration.dest_positional_name n of
        SOME (NameGeneration.In i, _) => SOME i
      | _ => int_of_string n
  in
    case position of
      SOME i => NameGeneration.Positional (i, T)
    | NONE => NameGeneration.Named n
  end
(* Term that reads variable (var, T) from state.
   The global exception variable at type exit_status is read through its
   getter from the getter table of prog_info and, if proj_status is set,
   wrapped in the_Nonlocal. Every other variable is read with
   ProgramInfo.get_var_value using the reference built by mk_loc_ref.
   An Option exception raised while building the term (for instance a missing
   getter) is reported through Utils.invalid_input, naming var. *)
fun var_getter ctxt (prog_info : ProgramInfo.prog_info) state proj_status (var, T) =
let
fun getter x = Symtab.lookup (ProgramInfo.get_var_getters prog_info) x |> the
fun proj x = if proj_status then @{const the_Nonlocal(exit_status)} $ x else x
fun get state =
if var = NameGeneration.global_exn_var_name andalso T = @{typ exit_status}
then proj (getter NameGeneration.global_exn_var_name $ state)
else ProgramInfo.get_var_value ctxt prog_info (mk_loc_ref T var) state
in get state
end handle Option => (Utils.invalid_input "valid local variable name" var)
(* Elements of a variable set, sorted by CLocals.positional_ord on their names. *)
fun dest_positional ctxt vars =
  let
    val name_ord = CLocals.positional_ord ctxt
    fun by_name pair = name_ord (apply2 fst pair)
  in
    sort by_name (Varset.dest vars)
  end
(* Precondition on the state saying that every variable in vars currently has
   the value given by name_map. One predicate "%s. <value of v in s> = name_map v"
   is built per variable, in dest_positional order, and the predicates are
   combined with Utils.chain_preds. *)
fun mk_precond ctxt prog_info name_map vars =
  let
    val stateT = ProgramInfo.get_state_type prog_info
    val state = Free ("_dummy_state", stateT)
    fun var_is_fixed v =
      let
        val current = var_getter ctxt prog_info state false v
        val fixed = name_map v
      in
        Utils.abs_over "s" state (HOLogic.mk_eq (current, fixed))
      end
  in
    Utils.chain_preds stateT (map var_is_fixed (dest_positional ctxt vars))
  end
(* Extraction function "%s. (v1 s, ..., vn s)" that reads the variables of
   vars from a state, in dest_sort_extern order. The getters are built with
   the status projection of var_getter switched on. *)
fun mk_xf ctxt (prog_info : ProgramInfo.prog_info) vars =
  let
    val state = Free ("_dummy_state", ProgramInfo.get_state_type prog_info)
    val projections = map (var_getter ctxt prog_info state true) (dest_sort_extern prog_info vars)
  in
    Utils.abs_over "s" state (HOLogic.mk_tuple projections)
  end
(* Build the proposition "L2corres st ret_xf ex_xf P A C" where
     st      is the globals getter of prog_info,
     ret_xf  extracts return_vars from a state (mk_xf),
     ex_xf   extracts except_vars from a state (mk_xf),
     P       fixes the values of precond_vars (mk_precond),
     A       is l2_term and C is l1_term.
   The names precond, return_xf and except_xf are referenced from inside the
   antiquotation below. *)
fun mk_corresXF_prop ctxt prog_info name_map return_vars except_vars precond_vars l2_term l1_term =
let
val precond = mk_precond ctxt prog_info name_map precond_vars
val return_xf = mk_xf ctxt prog_info return_vars
val except_xf = mk_xf ctxt prog_info except_vars
val corres = \<^infer_instantiate>‹st = ‹ProgramInfo.get_globals_getter prog_info› and ret_xf = return_xf and
ex_xf = except_xf and P = precond and A = l2_term and C = l1_term
in prop ‹L2corres st ret_xf ex_xf P A C›› ctxt
in
corres
end
(* State the L2corres proposition for the given variable sets and terms (see
   mk_corresXF_prop) and prove it with tac. Both steps run in ctxt extended by
   the simp rule split_def; building the proposition is timed. *)
fun mk_corresXF_thm ctxt prog_info name_map return_vars except_vars precond_vars l2_term l1_term tac =
  let
    val simp_ctxt = ctxt addsimps @{thms split_def}
    fun build_goal _ =
      mk_corresXF_prop simp_ctxt prog_info name_map
        return_vars except_vars precond_vars l2_term l1_term
    val goal = timeit_msg 2 simp_ctxt (fn _ => "mk_corresXF_prop: ") build_goal
  in
    Utils.simple_prove simp_ctxt goal tac
  end
(* Discharge the premises of thm by repeatedly running asm_full_simp_tac on
   the first subgoal for as long as it changes the goal state. All premises of
   thm are treated as subgoals (Goal.protect with the full premise count). *)
fun solve_simp_sideconditions ctxt thm =
  let
    val protected = Goal.protect (length (Thm.prems_of thm)) thm
    val simp_all = REPEAT (CHANGED (asm_full_simp_tac ctxt 1))
  in
    Utils.simple_cprove ctxt protected
      (timeap_msg_tac 2 ctxt (fn _ => "solve_simp_sideconditions") simp_all)
  end
(* Derive the L2corres theorem for the given variable sets and terms from thm.
   The conclusion of thm is matched or unified with the target proposition
   (see mk_corresXF_prop), thm is instantiated accordingly, and its premises
   are solved by assumption or by simplification with split_def added.
   The target proposition and the instantiation are computed in the context
   passed in; only the side conditions use the extended simpset. *)
fun mk_corresXF_thm_direct ctxt prog_info name_map return_vars except_vars precond_vars l2_term l1_term thm =
  let
    val target =
      mk_corresXF_prop ctxt prog_info name_map return_vars except_vars precond_vars l2_term l1_term
    val (ty_insts, trm_insts) = Utils.match_or_unify ctxt (Thm.concl_of thm) target
    val instantiated = Thm.instantiate (TVars.make ty_insts, Vars.make trm_insts) thm
    val protected = Goal.protect (length (Thm.prems_of thm)) instantiated
    val simp_ctxt = ctxt addsimps @{thms split_def}
    val side_tac =
      REPEAT (Method.assm_tac simp_ctxt 1 ORELSE CHANGED (asm_full_simp_tac simp_ctxt 1))
  in
    Utils.simple_cprove simp_ctxt protected
      (timeap_msg_tac 2 simp_ctxt (fn _ => "mk_corresXF_thm_direct - solve sideconditions") side_tac)
  end
(* Placeholder free variables: a state of the full state type (the same
   "_dummy_state" name is used by the parsers above), a state of the globals
   type, and a function pointer of type "unit ptr". *)
fun dummy_state_guards_l1 prog_info = Free ("_dummy_state", ProgramInfo.get_state_type prog_info)
fun dummy_state_guards_l2 prog_info = Free ("_dummy_state_l2", ProgramInfo.get_globals_type prog_info)
val dummy_fun_ptr = Free ("_p", @{typ "unit ptr"})
(* Constant that identifies the callee in a call term.
   For an application headed by a constant, the result is the last argument if
   that argument is a constant whose name ends in "_'proc", and the head
   constant otherwise. An abstraction without arguments is handled by
   descending into its body. Any other term raises TERM. *)
fun l1call_function_const t =
  let
    val (head, args) = strip_comb t
  in
    case (head, args) of
      (Const c, _) =>
        (case rev args of
          Const c' :: _ => if String.isSuffix "_'proc" (fst c') then Const c' else Const c
        | _ => Const c)
    | (Abs (_, _, body), []) => l1call_function_const body
    | _ => raise TERM ("l1call_function_const", [t])
  end
(* Switch ctxt to the local-variable scope of the function called by t.
   If the head of t is a constant whose base name can be mapped back to a
   function name by ProgramInfo.get_dest_fun_name (L1 phase, empty prefix
   argument), the scope is switched to that function; in every other case,
   including a failing lookup, ctxt is returned unchanged. *)
fun callee_scope prog_info t ctxt =
  case head_of t of
    Const (const_name, _) =>
      let
        val lookup = try (ProgramInfo.get_dest_fun_name prog_info FunctionInfo.L1 "")
      in
        case lookup (Long_Name.base_name const_name) of
          SOME fname => CLocals.switch_scope fname ctxt
        | NONE => ctxt
      end
  | _ => ctxt
(* Parse an L1 monad term into a program tree.
   Every node keeps the original L1 term. Expressions are stored as triples
   (parsed term option, variables read, whether globals are read) as produced
   from parse_expr; statements that write variables additionally carry the set
   of updated variables. Compound statements are parsed recursively. Terms
   that match none of the L1 constructs are tried as a with_fresh_stack_ptr
   block; if that fails with Match the term is rejected via
   Utils.invalid_term'. The order of the cases matters: the more specific
   L1_guarded pattern precedes the general one. *)
fun parse_l1 ctxt prog_info l1_infos l1_call_info name_map term =
case term of
(* Skip: a modify node that returns unit, reads nothing and updates nothing. *)
(Const (@{const_name "L1_skip"}, _)) =>
Modify (term,
(SOME (Abs ("s", ProgramInfo.get_globals_type prog_info, @{term "()"})), empty_set, false), empty_set)
(* Modify: must update exactly one variable. *)
| (Const (@{const_name "L1_modify"}, _) $ m) =>
let
val parsed_clause = parse_modify ctxt prog_info name_map m
val (updated_var, read_vars, is_globals_reader, parsed_expr) =
case parsed_clause of
[x] => x
| _ => Utils.invalid_term' ctxt "Modifies clause too complex." m
in
Modify (term, (parsed_expr, make_set read_vars, is_globals_reader),
make_set [apfst NameGeneration.the_named updated_var])
end
| (Const (@{const_name "L1_seq"}, _) $ lhs $ rhs) =>
Seq (term, parse_l1 ctxt prog_info l1_infos l1_call_info name_map lhs,
parse_l1 ctxt prog_info l1_infos l1_call_info name_map rhs)
| (Const (@{const_name "L1_catch"}, _) $ lhs $ rhs) =>
Catch (term, parse_l1 ctxt prog_info l1_infos l1_call_info name_map lhs,
parse_l1 ctxt prog_info l1_infos l1_call_info name_map rhs)
| (Const (@{const_name "L1_guard"}, _) $ c) =>
let
val (read_vars, is_globals_reader, parsed_expr) = parse_expr ctxt prog_info name_map c
in
Guard (term, (parsed_expr, make_set read_vars, is_globals_reader))
end
(* Guarded statement whose body first reads a value with gets: the guard and
   the read expression are parsed separately, and the continuation is parsed
   after applying it to the read expression on the dummy state. *)
| @{term_pat "L1_guarded ?g (gets ?dest >>= ?c)"} =>
let
val (read_vars, is_globals_reader, parsed_expr) = parse_expr ctxt prog_info name_map g
val (read_vars_dest, is_globals_reader_dest, parsed_expr_dest) = parse_expr ctxt prog_info name_map dest
val dummy_state = dummy_state_guards_l1 prog_info
val p = Envir.beta_eta_contract (dest $ dummy_state)
val c = Envir.beta_eta_contract (c $ p)
in
Guarded (term, (parsed_expr, make_set read_vars, is_globals_reader),
(parsed_expr_dest, make_set read_vars_dest, is_globals_reader_dest),
parse_l1 ctxt prog_info l1_infos l1_call_info name_map c)
end
(* Any other guarded statement: no read expression. *)
| @{term_pat "L1_guarded ?g ?c"} =>
let
val (read_vars, is_globals_reader, parsed_expr) = parse_expr ctxt prog_info name_map g
val emp = (NONE, empty_set, false)
in
Guarded (term, (parsed_expr, make_set read_vars, is_globals_reader),
emp, parse_l1 ctxt prog_info l1_infos l1_call_info name_map c)
end
| (Const (@{const_name "L1_throw"}, _)) =>
Throw term
| (Const (@{const_name "L1_condition"}, _) $ cond $ lhs $ rhs) =>
let
val (read_vars, is_globals_reader, parsed_expr) = parse_expr ctxt prog_info name_map cond
in
Condition (term, (parsed_expr, make_set read_vars, is_globals_reader),
parse_l1 ctxt prog_info l1_infos l1_call_info name_map lhs,
parse_l1 ctxt prog_info l1_infos l1_call_info name_map rhs)
end
(* Function call. *)
| (Const (@{const_name "L1_call"}, L1_call_type)
$ arg_setup $ callee_term $ return_norm $ return_exn $ ret_extract) =>
let
(* Argument setup: the written variables are looked up in the callee's scope,
   the assigned values are parsed in the caller's scope. Only the value
   expressions are kept. *)
val arg_setup_exprs =
gen_parse_modify {write_scope = callee_scope prog_info callee_term, read_scope = I, two_state = false}
ctxt prog_info name_map arg_setup
|> map (fn (_, read_vars, is_globals_reader, term) =>
(term, make_set read_vars, is_globals_reader))
(* The callee itself, parsed as an expression over the state. *)
val callee_expr =
let
val callee' = Utils.abs_over "s" (dummy_state_guards_l1 prog_info) callee_term
val (read_vars, is_globals_reader, term) = parse_expr ctxt prog_info name_map callee'
in (term, make_set read_vars, is_globals_reader) end
(* Result extraction: values are read in the callee's scope. The first
   variable read (or the "void" placeholder) is taken as the return variable,
   removed from the read set and abstracted from the expression as "ret". *)
val parsed_clause =
gen_parse_modify {write_scope = I, read_scope = callee_scope prog_info callee_term, two_state = false}
ctxt prog_info name_map (betapply (ret_extract, Free ("_dummy_state", ProgramInfo.get_state_type prog_info)))
|> map (fn (target_var, read_vars, globals_read, expr) =>
let
val ret_var = get_ret_var' read_vars
in
(target_var, (make_set read_vars) MINUS (make_set [ret_var]),
globals_read, Option.map (Utils.abs_over "ret" (name_map ret_var)) expr)
end)
(* At most one variable may receive the result. *)
val (ret_expr, updated_var) =
case parsed_clause of
[(target_var, read_vars, globals_read, expr)] =>
((expr, read_vars, globals_read), make_set [apfst NameGeneration.the_named target_var])
| [] => ((NONE, empty_set, false), empty_set)
| x => Utils.invalid_input "single return param" (@{make_string} x)
in
(* The exception variable always counts as updated by a call. *)
Call (term, callee_expr, arg_setup_exprs, ret_expr, (updated_var UNION exn_var), ())
end
| (Const (@{const_name "L1_exec_spec_monad"},_)$ upd_x $ st $ args $ f $ res) =>
let
(* Split an abstraction over a tuple into one abstraction per component. *)
fun dest_tuple_args (Abs (s,sT, b)) = map (fn b' => Abs (s,sT, b')) (HOLogic.strip_tuple b)
| dest_tuple_args t = [t]
fun e x =
let
val (read_vars, is_globals_reader, term) = parse_expr ctxt prog_info name_map x
in (term, make_set read_vars, is_globals_reader) end
(* The variable updated by the result setter, if it can be recognised. *)
fun m x = case ProgramInfo.dest_var_update_bare x of
SOME (updated_var, _ ,_) => [apfst NameGeneration.the_named updated_var]
| _ => []
val arg_exprs = dest_tuple_args args |> map e
val updated_var = m res |> make_set
in
Exec_Spec_Monad (term, arg_exprs, updated_var)
end
| (Const (@{const_name "L1_while"}, _) $ cond $ body) =>
let
val (read_vars, is_globals_reader, parsed_expr) = parse_expr ctxt prog_info name_map cond;
in
While (term, (parsed_expr, make_set read_vars, is_globals_reader),
parse_l1 ctxt prog_info l1_infos l1_call_info name_map body)
end
| (Const (@{const_name "L1_init"}, _) $ setter) =>
let
val updated_var = ProgramInfo.guess_var_name_type_from_setter_term setter
in
Init (term, make_set [updated_var])
end
(* Spec and assume: the relation is kept if it can be expressed over globals
   only; both are always marked as reading the globals. *)
| (Const (@{const_name "L1_spec"}, _) $ c) =>
(case parse_spec ctxt prog_info c of
SOME x =>
Spec (term, (SOME x, empty_set, true))
| NONE =>
Spec (term, (NONE, empty_set, true)))
| (Const (@{const_name "L1_assume"}, _) $ c) =>
(case parse_assume ctxt prog_info c of
SOME x =>
Assume (term, (SOME x, empty_set, true))
| NONE =>
Assume (term, (NONE, empty_set, true)))
| (Const (@{const_name "L1_fail"}, _)) =>
Fail term
(* Fallback: a block that allocates a fresh stack pointer. The pointer bound
   by the body is fixed as a fresh variable in an extended context before the
   body is parsed. *)
| other =>
let
val {init, c, ...} = with_fresh_stack_ptr.match ctxt other
val Abs (p, pT, _) = c
val sT = fastype_of init |> domain_type
val (read_vars, is_globals_reader, parsed_expr) = parse_expr ctxt prog_info name_map init
val ((p, pT), bdy) = Term.dest_abs_fresh p c
val ([p], ctxt') = Utils.gen_fix_variant_frees true [(p, pT)] ctxt
val bdy' = parse_l1 ctxt' prog_info l1_infos l1_call_info name_map bdy
in
Stack (term, (parsed_expr, make_set read_vars, is_globals_reader), bdy')
end handle Match => Utils.invalid_term' ctxt "a L1 term" other
(* Data needed by the preservation-proof cache:
     phi_export       morphism applied to a cached theorem before it is
                      weakened to a smaller variable set (weaken_superset)
     dummy_value      name handed to HPInter.subst_var_update when a cache
                      pattern is built (mk_pattern)
     dummy_init       name of the free variable that replaces the argument
                      setup of a call in a cache pattern (mk_pattern)
     phi_cache        morphism applied to freshly proved modify and call
                      theorems (mk_preservation_proof_atomic)
     weaken_superset  turns a theorem cached for a superset of variables into
                      one for the given variable set (lookup in the cache) *)
type export_info =
{phi_export: Morphism.morphism,
dummy_value: string,
dummy_init: string,
phi_cache: Morphism.morphism,
weaken_superset : varset -> thm -> thm};
(* Generalise an L1 term into the pattern used as key of the preservation
   cache. In a modify statement the update function is rewritten with
   HPInter.subst_var_update and dummy_value; in a call the argument setup is
   replaced by a free variable named dummy_init and the result extraction is
   rewritten in the same way. Sequences, loops, conditions and catch blocks
   are traversed (conditions of loops and branches stay as they are); every
   other term is returned unchanged. *)
fun mk_pattern ctxt dummy_value dummy_init t =
  let
    val recurse = mk_pattern ctxt dummy_value dummy_init
    fun generalize_update f = HPInter.subst_var_update ctxt dummy_value f
  in
    case t of
      (c as Const (@{const_name "L1_modify"}, _)) $ f =>
        c $ generalize_update f
    | (c as Const (@{const_name "L1_call"}, _)) $ init $ n $ exit $ res_exn $ res_ret =>
        c $ Free (dummy_init, fastype_of init) $ n $ exit $ res_exn $ generalize_update res_ret
    | (c as Const (@{const_name "L1_seq"}, _)) $ lhs $ rhs =>
        c $ recurse lhs $ recurse rhs
    | (c as Const (@{const_name "L1_while"}, _)) $ cond $ body =>
        c $ cond $ recurse body
    | (c as Const (@{const_name "L1_condition"}, _)) $ cond $ lhs $ rhs =>
        c $ cond $ recurse lhs $ recurse rhs
    | (c as Const (@{const_name "L1_catch"}, _)) $ lhs $ rhs =>
        c $ recurse lhs $ recurse rhs
    | _ => t
  end
(* Fold a non-empty list with f, using its first element as the start value;
   an empty list yields the default d. *)
fun fold_default d f xs =
  case xs of
    [] => d
  | first :: rest => fold f rest first
(* Combine a list of preservation theorems into one with combine_validE; the
   empty list yields hoareE_TrueI. *)
fun join_preserve_thms thms =
  let
    fun combine x y = @{thm combine_validE} OF [x, y]
  in
    fold_default @{thm hoareE_TrueI} combine thms
  end
(* Cache of preservation theorems, declared through the record antiquotation
   (the code below uses the accessors get_mode / get_tab, the map_<field>
   updaters and make_pres_cache).
     tab       per statement pattern, the theorems proved so far, keyed by the
               list of preserved variable names
     mode      value of preservation_cache_mode at creation time
     hits      number of exact cache hits
     misses    number of failed lookups
     superset  number of results derived from a cached superset
     join      number of results that required joining several theorems *)
@{record ‹datatype pres_cache = Cache of
{tab: thm Symstab.table Termtab.table,
mode: int,
hits: int,
misses: int,
superset: int,
join: int}›}
(* Fresh cache without entries and with all counters at zero; the mode is
   read from the preservation_cache_mode option of ctxt. *)
fun pres_cache_empty ctxt =
  make_pres_cache
    {tab = Termtab.empty,
     mode = Config.get ctxt preservation_cache_mode,
     hits = 0,
     misses = 0,
     superset = 0,
     join = 0}
local
  (* Cache key of a variable set: the list of its variable names. *)
  fun key var_set = map fst (Varset.dest var_set)

  (* Find cached theorems that cover var_set. For a positive mode the first
     entry whose variables include var_set is returned as a singleton; for
     mode >= 2 and no such entry, all theorems of the table are returned if
     their variables together include var_set. *)
  fun find_superset mode tab var_set =
    if Varset.card var_set = 0 orelse mode <= 0 then NONE
    else
      let
        val wanted = Symset.make (key var_set)
        val entries = Symstab.dest tab
        fun covers (names, _) = Symset.subset (wanted, Symset.make names)
      in
        case find_first covers entries of
          SOME (_, thm) => SOME [thm]
        | NONE =>
            if mode < 2 then NONE
            else if Symset.subset (wanted, Symset.make (maps fst entries)) then
              SOME (map snd entries)
            else NONE
      end
in
  (* Store thm for (pat, var_set); does nothing when the mode is negative. *)
  fun update_pres_cache_pattern pat var_set thm cache =
    if get_mode cache < 0 then cache
    else
      map_tab
        (Termtab.map_default (pat, Symstab.empty) (Symstab.update (key var_set, thm)))
        cache

  (* Look up the theorem for (pat, var_set). Returns the optional theorem, the
     pattern and the cache with updated counters. A result derived from a
     cached superset is weakened with the weaken_superset field of
     export_info and stored under var_set. *)
  fun lookup_pres_cache_pattern export_info cache pat var_set =
    let
      fun miss () = (NONE, pat, map_misses (fn i => i + 1) cache)
      fun derived found cache0 =
        let
          val thm = (#weaken_superset export_info) var_set found
          val cache1 = map_superset (fn i => i + 1) cache0
        in
          (SOME thm, pat, update_pres_cache_pattern pat var_set thm cache1)
        end
    in
      if get_mode cache < 0 then (NONE, pat, cache)
      else
        case Termtab.lookup (get_tab cache) pat of
          NONE => miss ()
        | SOME by_vars =>
            (case Symstab.lookup by_vars (key var_set) of
              SOME thm => (SOME thm, pat, map_hits (fn i => i + 1) cache)
            | NONE =>
                (case find_superset (get_mode cache) by_vars var_set of
                  NONE => miss ()
                | SOME [single] => derived single cache
                | SOME several =>
                    derived (join_preserve_thms several) (map_join (fn i => i + 1) cache)))
    end

  (* Same lookup for an L1 term: the term is first generalised to its cache
     pattern with mk_pattern. *)
  fun lookup_pres_cache ctxt export_info cache term var_set =
    lookup_pres_cache_pattern export_info cache
      (mk_pattern ctxt (#dummy_value export_info) (#dummy_init export_info) term)
      var_set
end
(* Weaken a preservation theorem to the precondition for var_set.
   thm is passed through phi_export and plugged into
   validE_weaken_dependent_same instantiated with the precondition built by
   mk_precond. The remaining side conditions are solved by simplification
   followed by splitting conjunctions in the assumptions and closing each
   conjunct by assumption. The whole step is timed. *)
fun weaken_superset ctxt phi_export prog_info name_map var_set thm =
  let
    fun instantiate rule args =
      Drule.infer_instantiate' ctxt (map (SOME o Thm.cterm_of ctxt) args) rule
    fun weaken _ =
      let
        val precond = mk_precond ctxt prog_info name_map var_set
        val weakened =
          instantiate @{thm validE_weaken_dependent_same} [precond] OF [Morphism.thm phi_export thm]
        val split_assms = REPEAT (eresolve_tac ctxt @{thms conjE} 1)
        val close_conjs = REPEAT (TRY (resolve_tac ctxt @{thms conjI} 1) THEN assume_tac ctxt 1)
        val side_tac =
          SOLVES (EVERY [CHANGED (asm_full_simp_tac ctxt 1), split_assms, close_conjs])
      in
        Utils.solve_sideconditions ctxt weakened (REPEAT1 side_tac)
      end
  in
    timeit_msg 2 ctxt (fn _ => "weaken_superset: ") weaken
  end
(* Preservation theorem for a single variable var and an atomic statement
   pattern pat, together with the updated cache.
   The cache is consulted first. On a miss the rule matching the kind of
   statement is instantiated with the precondition for the singleton set of
   var (and the statement's arguments), its side conditions are simplified,
   and the result is stored in the cache. Theorems for modify and call
   statements are additionally passed through the phi_cache morphism.
   Compound statements are not handled here and raise an error. *)
fun mk_preservation_proof_atomic ctxt export_info prog_info name_map var pat (cache : pres_cache) =
case lookup_pres_cache_pattern export_info cache pat (make_set [var]) of
(SOME t, _, cache) => (t, cache)
| (NONE, pat, cache) =>
let
(* s: simplify away side conditions; i: instantiate a rule with terms;
   e: apply the cache morphism. *)
fun s thm =
Utils.solve_sideconditions ctxt thm (TRY (REPEAT (CHANGED (asm_full_simp_tac ctxt 1))))
fun i thm args = Drule.infer_instantiate' ctxt (map (SOME o Thm.cterm_of ctxt) args) thm
val e = Morphism.thm (#phi_cache export_info)
val var_set = make_set [var]
val precond = mk_precond ctxt prog_info name_map var_set
val (thm, cache) =
(case pat of
(Const (@{const_name "L1_skip"}, _)) =>
(i @{thm L1_skip_lp_same_pre_post} [precond], cache)
| (Const (@{const_name "L1_init"}, _) $ f) =>
(i @{thm L1_init_lp_same_pre_post} [precond, f], cache)
| (Const (@{const_name "L1_modify"}, _) $ f) =>
(e (i @{thm L1_modify_lp_same_pre_post} [precond, f]), cache)
| (Const (@{const_name "L1_call"}, _) $ init $ n $ exit $ res_exn $ res_ret) =>
(e (i @{thm L1_call_lp_same_pre_post} [precond, res_ret, exit, res_exn, init, n]), cache)
| (Const (@{const_name "L1_guard"}, _) $ g) =>
(i @{thm L1_guard_lp_same_pre_post} [precond, g], cache)
| (Const (@{const_name "L1_throw"}, _)) =>
(i @{thm L1_throw_lp_same_pre_post} [precond], cache)
(* Spec and assume use the trivial rule hoareE_TrueI. *)
| (Const (@{const_name "L1_spec"}, _) $ _) =>
(i @{thm hoareE_TrueI} [precond, pat], cache)
| (Const (@{const_name "L1_assume"}, _) $ _) =>
(i @{thm hoareE_TrueI} [precond, pat], cache)
| (Const (@{const_name "L1_fail"}, _)) =>
(i @{thm L1_fail_lp} [precond], cache)
| other => error ("mk_preservation_proof_atomic does not handle compound statements: " ^ Syntax.string_of_term ctxt other))
val thm = s thm
in
(thm, update_pres_cache_pattern pat var_set thm cache)
end
(* Preservation theorem for all variables of var_set and an atomic statement.
   One theorem per variable is obtained from mk_preservation_proof_atomic,
   visiting the variables in reverse dest_sort_extern order and threading the
   cache, and the theorems are joined with join_preserve_thms. A timing
   message is set up with a threshold of ten seconds. An Option exception
   raised anywhere in this function is turned into an error naming var_set. *)
fun mk_multivar_preservation_proof_atomic ctxt export_info prog_info name_map term var_set cache =
  let
    val slow_msg =
      Utils.threshold_msg (seconds 10.0)
        (fn _ => "preservation_proof longrunning: " ^ Syntax.string_of_term ctxt term)
    fun prove_var v = mk_preservation_proof_atomic ctxt export_info prog_info name_map v term
    fun prove_all _ = fold_map prove_var (rev (dest_sort_extern prog_info var_set)) cache
    val (proofs, cache') = Utils.dep_timeit_msg 0 slow_msg prove_all
  in
    (join_preserve_thms proofs, cache')
  end
  handle Option => error ("Preservation proof failed for " ^ quote (@{make_string} var_set))
(* Preservation theorem for the variables of var_set across an arbitrary L1
   term, together with the updated cache.
   An empty variable set is answered with hoareE_TrueI. Otherwise the cache is
   consulted with the pattern of the term; on a miss the proof is assembled
   structurally from the proofs of the sub-statements (loops, conditions,
   sequences, catch blocks, guarded gets, abstractions and fresh-stack-pointer
   blocks), falling back to the atomic prover when the term does not match the
   fresh-stack-pointer shape. The result is stored in the cache.
   Note that the recursion works on the pattern returned by the lookup, not on
   the original term. *)
fun mk_multivar_preservation_proof ctxt (export_info: export_info) prog_info name_map term var_set (cache : pres_cache) =
if Varset.card var_set = 0 then (@{thm hoareE_TrueI}, cache) else
case lookup_pres_cache ctxt export_info cache term var_set of
(SOME t, _, cache) => (t, cache)
| (NONE, pat, cache) =>
let
val (thm, cache) =
(case pat of
(Const (@{const_name "L1_while"}, _) $ _ $ body) =>
let
val (body', cache) = mk_multivar_preservation_proof ctxt export_info prog_info name_map body var_set cache
in
(@{thm L1_while_lp_same_pre_post} OF [body'], cache)
end
| (Const (@{const_name "L1_condition"}, _) $ _ $ lhs $ rhs) =>
let
val (lhs', cache) = mk_multivar_preservation_proof ctxt export_info prog_info name_map lhs var_set cache
val (rhs', cache) = mk_multivar_preservation_proof ctxt export_info prog_info name_map rhs var_set cache
in
(@{thm L1_condition_lp_same_pre_post} OF [lhs', rhs'], cache)
end
| (Const (@{const_name "L1_seq"}, _) $ lhs $ rhs) =>
let
val (lhs', cache) = mk_multivar_preservation_proof ctxt export_info prog_info name_map lhs var_set cache
val (rhs', cache) = mk_multivar_preservation_proof ctxt export_info prog_info name_map rhs var_set cache
in
(@{thm L1_seq_lp_same_pre_post} OF [lhs', rhs'], cache)
end
| (Const (@{const_name "L1_catch"}, _) $ lhs $ rhs) =>
let
val (lhs', cache) = mk_multivar_preservation_proof ctxt export_info prog_info name_map lhs var_set cache
val (rhs', cache) = mk_multivar_preservation_proof ctxt export_info prog_info name_map rhs var_set cache
in
(@{thm L1_catch_lp_same_pre_post} OF [lhs', rhs'], cache)
end
(* Guarded gets: only the continuation needs a proof. *)
| @{term_pat ‹L1_guarded _ (gets ?dest ⤜ ?body)›} =>
let
val (body', cache) = mk_multivar_preservation_proof ctxt export_info prog_info name_map body var_set cache
val thm = @{thm L1_guarded_lp_gets} OF [body']
in
(thm, cache)
end
(* Abstraction: fix the bound variable, prove the body in the extended
   context and export the result. strip_x is not used in this branch. *)
| (abs as Abs (x, xT, body)) =>
let
val ([x], ctxt') = Utils.fix_variant_frees [(x,xT)] ctxt
fun strip_x (t as (bdy $ x')) = if x=x' then bdy else t
| strip_x t = t
val body = betapply(abs, x)
val (body', cache) = mk_multivar_preservation_proof ctxt' export_info prog_info name_map body var_set cache
val [body'] = Proof_Context.export ctxt' ctxt [body']
in
(body', cache)
end
(* Fresh-stack-pointer block: prove the body with the pointer fixed, export,
   resolve with the first applicable registered rule and simplify its side
   conditions. If the term does not have this shape (Match), it is treated
   as an atomic statement; the handler covers the whole branch. *)
| other =>
let
val {init, c, ...} = with_fresh_stack_ptr.match ctxt other
val Abs (p, pT, _) = c
val ((p, pT), bdy) = Term.dest_abs_fresh p c
val ([p], ctxt') = Utils.gen_fix_variant_frees true [(p, pT)] ctxt
val (bdy', cache) = mk_multivar_preservation_proof ctxt' export_info prog_info name_map bdy var_set cache
val [bdy'] = Proof_Context.export ctxt' ctxt [bdy']
val (rule::_) = Named_Theorems.get ctxt @{named_theorems with_fresh_stack_ptr_lp_same_pre_post} |> Utils.OFs [bdy']
val thm = solve_simp_sideconditions ctxt rule
in
(thm, cache)
end handle Match => mk_multivar_preservation_proof_atomic ctxt export_info prog_info name_map other var_set cache)
in
(thm, update_pres_cache_pattern pat var_set thm cache)
end
(* Replace the C exception type by exit_status; other types are unchanged. *)
fun status_ty T =
  if T <> HP_TermsTypes.c_exntype_ty then T else @{typ exit_status}

(* Type of a variable, with status_ty applied only for the global exception
   variable. *)
fun exn_status_ty (n, T) =
  if n <> NameGeneration.global_exn_var_name then T else status_ty T
(* Apply the L2 monad constant named const to params. The constant is given
   the type "types of params ---> monad type", where the monad type is built
   by AutoCorresData.mk_l2monadT from the globals type and the tuple types of
   the ret and throw variable sets (in dest_sort_extern order). *)
fun mk_l2monad ctxt (prog_info : ProgramInfo.prog_info) const ret throw params =
  let
    fun tupleT_of vars = HOLogic.mk_tupleT (map snd (dest_sort_extern prog_info vars))
    val monadT =
      AutoCorresData.mk_l2monadT (ProgramInfo.get_globals_type prog_info)
        (tupleT_of ret) (tupleT_of throw)
    val head = Const (const, map fastype_of params ---> monadT)
  in
    betapplys (head, params)
  end
(* Abstraction, via Utils.abs_over_tuple, over the tuple of the variables in
   vars (in dest_sort_extern order). Each variable is bound under its
   demangled name and stands for the term name_map assigns to it. *)
fun abs_over_tuple_vars prog_info (name_map : (string * typ) -> term) (vars : varset) =
  let
    fun binder (v as (n, _)) = (ProgramInfo.demangle_name prog_info n, name_map v)
  in
    Utils.abs_over_tuple (map binder (dest_sort_extern prog_info vars))
  end
(* Make a converted statement return exactly the variables in needed_returns.
   The last argument is the conversion result so far: (variables read,
   variables returned, L2 term, correspondence theorem, preservation cache).
   It is returned unchanged if the returned variables already equal
   needed_returns, or include them and allow_excess is set. Otherwise the L2
   term is followed by an L2_gets that returns the needed variables; variables
   that are needed but not returned by the statement must be preserved by it,
   which is established by a preservation proof, and they are added to the
   variables read.
   The cache produced by that preservation proof is returned to the caller.
   (Previously it was bound only inside the local proof block, so the stale
   input cache was returned and the new entries and counters were lost.) *)
fun inject_return_vals ctxt (export_info: export_info) prog_info name_map needed_returns allow_excess throw_vars
    term (vars_read, vars_returned, output_monad, thm, cache) =
  if needed_returns = vars_returned then
    (vars_read, vars_returned, output_monad, thm, cache)
  else if (allow_excess andalso Varset.subset (needed_returns, vars_returned)) then
    (vars_read, vars_returned, output_monad, thm, cache)
  else
    let
      val (l1_term, _, _) = get_node_data term
      (* "%(returned vars). L2_gets (%_. (needed vars)) names" *)
      val injected_return =
        mk_l2monad ctxt prog_info @{const_name L2_gets} needed_returns throw_vars
          [absdummy (ProgramInfo.get_globals_type prog_info)
             (HOLogic.mk_tuple (dest_sort_extern prog_info needed_returns |> map name_map)),
           var_set_to_isa_list ctxt prog_info needed_returns]
        |> abs_over_tuple_vars prog_info name_map vars_returned
      val generated_term = mk_l2monad ctxt prog_info @{const_name L2_seq}
        needed_returns throw_vars [output_monad, injected_return]
      (* Variables that are needed afterwards but not produced by the statement. *)
      val preserved_vals = needed_returns MINUS vars_returned
      val (preserve_proof, cache') =
        mk_multivar_preservation_proof ctxt export_info prog_info name_map l1_term preserved_vals cache
      val generated_thm =
        mk_corresXF_thm_direct ctxt prog_info name_map needed_returns throw_vars (vars_read UNION preserved_vals)
          generated_term l1_term
          (@{thm L2corres_inject_return'} OF [thm, @{thm asm_rl}, @{thm validE_weaken} OF [preserve_proof]])
    in
      (vars_read UNION preserved_vals, needed_returns, generated_term, generated_thm, cache')
    end
(* Integer-indexed caches, stored in an Inttab, of tuple-split variants of the
   L2corres rules for sequencing, catch and while; the variables named in the
   string lists are the ones handed to Tuple_Tools.split_rule.
   NOTE(review): the integer argument is presumably the tuple arity -- confirm
   against Tuple_Tools.split_rule. *)
val corres_seq_split = Fun_Cache.create @{binding "L2corres_seq_split"}
(fn c => @{make_string} c) Inttab.empty Inttab.lookup Inttab.update
(Tuple_Tools.split_rule @{context} ["P'", "B"] @{thm L2corres_seq})
val corres_catch_split = Fun_Cache.create @{binding "L2corres_catch_split"}
(fn c => @{make_string} c) Inttab.empty Inttab.lookup Inttab.update
(Tuple_Tools.split_rule @{context} ["P'", "R"] @{thm L2corres_catch})
val corres_while_split = Fun_Cache.create @{binding "L2corres_while_split"}
(fn c => @{make_string} c) Inttab.empty Inttab.lookup Inttab.update
(Tuple_Tools.split_rule @{context} ["P'", "A"] @{thm L2corres_while})
(* Evaluate each cached function once for the arguments 1 to 5 when this file
   is loaded; the results are discarded here. *)
val _ = map corres_seq_split (1 upto 5)
val _ = map corres_catch_split (1 upto 5)
val _ = map corres_while_split (1 upto 5)
(* resolve_tac with a single theorem; if do_trace is set, the theorem is
   printed as a tracing message before the tactic is applied. *)
fun trace_resolve_tac do_trace ctxt thm i st =
  (if do_trace then tracing ("trying to resolve: " ^ Thm.string_of_thm ctxt thm) else ();
   resolve_tac ctxt [thm] i st)
(* Print a list of terms as one string, laid out by Pretty.strs. *)
fun string_of_terms ctxt ts =
  Pretty.string_of (Pretty.strs (map (Syntax.string_of_term ctxt) ts))
(* Split an L2corres application into its six arguments, looking through an
   outer Trueprop. Raises TERM for any other term. *)
fun dest_L2corres @{term_pat "L2corres ?st ?ret ?ex ?P ?new ?old"} = {st = st, ret = ret, ex = ex , P = P, new = new, old = old}
| dest_L2corres @{term_pat "Trueprop ?X"} = dest_L2corres X
| dest_L2corres t = raise TERM ("dest_L2corres", [t])
(* Only the old and new program terms of an L2corres statement. *)
val dest_L2corres_funs = dest_L2corres #> (fn {old, new, ...} => {old = old, new = new})
(* Instances of the generic AutoCorresUtil helpers for L2corres statements. *)
val get_first_L2corres = AutoCorresUtil.get_first_corres {dest_corres_funs = dest_L2corres_funs}
(* Mutable boolean flags, initially false; they are not referenced in the part
   of the file visible here. NOTE(review): presumably debugging switches. *)
val d1 = Unsynchronized.ref false
val d2 = Unsynchronized.ref false
val mk_L2corres_map_of_default_thm = AutoCorresUtil.mk_corres_map_of_default_thm {get_first_corres = get_first_L2corres}
(* Body of an abstraction over a unit-typed variable; raises TERM otherwise.
   The body is returned as is, so a loose bound variable may remain if the
   body refers to the abstracted variable. *)
fun dest_unit_abs (Abs (_, \<^Type>‹unit›, bdy)) = bdy
| dest_unit_abs x = raise TERM ("dest_unit_abs: ", [x])
(* Every function application of unit type equals the unit value. *)
val unit_range_eq = @{lemma "f x = ()" by auto}
(* do_conv: convert one node of the annotated L1 program `term` into L2 form.

   Every node of `term` carries the triple (L1 term, live variables,
   modified variables) as node data.  The result is the tuple

     (variables read, variables returned, generated L2 term,
      L2corres theorem, updated preservation cache).

   Selected arguments, as used in the body below:
     grds          guard theorems collected so far; the Guarded case appends
                   to them and the Call case uses them as simp rules.
     needed_vars,
     allow_excess,
     throw_vars    handed to inject_return_vals (bound to `inject`) and used
                   when building the generated monads.
     cache         preservation-proof cache, threaded through all recursive
                   calls and through mk_multivar_preservation_proof.

   Most cases pass their result through `inject`; the Throw, Guard-NONE and
   Fail cases return their tuple directly.  Each case emits a verbose "begin"
   message on entry and a timing message under the case name on exit. *)
fun do_conv
(ctxt : Proof.context)
(export_info: export_info)
skips
prog_info
(l1_infos : FunctionInfo.function_info Symtab.table)
(l1_call_info : FunctionInfo.call_graph_info)
name_map
fname
recursive_fun_ptrs
(fn_vars : varset)
(callee_proofs : (term * thm list) Symtab.table)
(grds: thm list)
(needed_vars : varset)
(allow_excess : bool)
(throw_vars : varset)
(term : (term * varset * varset, term option * varset * bool, (string * typ) list, unit) prog)
(cache : pres_cache)
: (varset * varset * term * thm * pres_cache) =
let
(* Node data of the current node: L1 term, live variables, modified variables. *)
val l1_term = get_node_data term |> #1
val live_vars = get_node_data term |> #2
val modified_vars = get_node_data term |> #3
val inject =
inject_return_vals ctxt export_info prog_info name_map needed_vars allow_excess throw_vars term
fun mkthm read_vars ret_vars generated_term thm =
mk_corresXF_thm_direct ctxt prog_info name_map ret_vars throw_vars read_vars generated_term l1_term thm
val mk_monad = mk_l2monad ctxt prog_info
(* do_conv' allows recursion under an extended context; the local do_conv
   below shadows the outer function and fixes the current context. *)
fun do_conv' ctxt = do_conv ctxt export_info skips prog_info l1_infos l1_call_info name_map fname recursive_fun_ptrs fn_vars callee_proofs
val do_conv = do_conv' ctxt
fun read_vars_of_call (Call (_, expr_f, expr_args, (ret_expr, ret_read_vars, _), ret_var, _)) =
union_sets (map #2 (expr_f::expr_args)) UNION (throw_vars MINUS exn_var)
| read_vars_of_call t = error ("read_vars_of_call: only works for call statements: " ^ @{make_string} t)
in
case term of
Init (_, [output_var]) =>
let
val _ = verbose_msg 3 ctxt (fn _ => "Init_SOME begin " ^ fname)
val start = Timing.start ();
val out_vars = make_set [output_var]
val generated_term = mk_monad @{const_name L2_unknown} out_vars throw_vars
[var_set_to_isa_list ctxt prog_info out_vars]
val thm = mkthm empty_set out_vars generated_term @{thm L2corres_spec_unknown}
in
inject (empty_set, out_vars, generated_term, thm, cache)
before (timing_msg' 2 ctxt (fn _ => "Init_SOME") start)
end
| Modify (_, (SOME expr, _, _), []) =>
let
val _ = verbose_msg 3 ctxt (fn _ => "Modify_SOME begin " ^ fname)
val start = Timing.start ();
val generated_term = mk_monad @{const_name L2_gets}
empty_set throw_vars [expr, var_set_to_isa_list ctxt prog_info empty_set]
val thm = mkthm empty_set empty_set generated_term @{thm L2corres_gets_skip}
in
inject (empty_set, empty_set, generated_term, thm, cache)
before (timing_msg' 2 ctxt (fn _ => "Modify_SOME") start)
end
| Modify (_, (NONE, _, _), [output_var]) =>
let
val _ = verbose_msg 3 ctxt (fn _ => "Modify_NONE begin " ^ fname)
val start = Timing.start ();
val out_vars = make_set [output_var]
val generated_term = mk_monad @{const_name L2_unknown} out_vars throw_vars []
val thm = mkthm empty_set out_vars generated_term @{thm L2corres_modify_unknown}
in
inject (empty_set, out_vars, generated_term, thm, cache)
before (timing_msg' 2 ctxt (fn _ => "Modify_NONE") start)
end
| Modify (_, (SOME expr, read_vars, _), [("globals'", _)]) =>
let
val _ = verbose_msg 3 ctxt (fn _ => "Modify_SOME_globals begin " ^ fname)
val start = Timing.start ();
val generated_term = mk_monad @{const_name L2_modify} empty_set throw_vars [expr]
val thm = mkthm read_vars empty_set generated_term @{thm L2corres_modify_global}
in
inject (read_vars, empty_set, generated_term, thm, cache)
before (timing_msg' 2 ctxt (fn _ => "Modify_SOME_globals") start)
end
| Modify (_, (SOME expr, read_vars, _), [output_var]) =>
let
val _ = verbose_msg 3 ctxt (fn _ => "Modify_SOME_SOME begin " ^ fname)
val start = Timing.start ();
val generated_term = mk_monad @{const_name L2_gets}
(make_set [output_var]) throw_vars [expr, var_set_to_isa_list ctxt prog_info (make_set [output_var])]
val thm = mkthm read_vars (make_set [output_var]) generated_term @{thm L2corres_modify_gets}
in
inject (read_vars, make_set [output_var], generated_term, thm, cache)
before (timing_msg' 2 ctxt (fn _ => "Modify_SOME_SOME") start)
end
| Throw _ =>
let
val _ = verbose_msg 3 ctxt (fn _ => "Throw begin " ^ fname)
val start = Timing.start ();
val generated_term = mk_monad @{const_name L2_throw} needed_vars throw_vars
[HOLogic.mk_tuple (dest_sort_extern prog_info throw_vars |> map name_map),
var_set_to_isa_list ctxt prog_info throw_vars]
val thm = mkthm throw_vars needed_vars generated_term @{thm L2corres_throw}
in
(* Returned directly, without going through inject. *)
(throw_vars, needed_vars, generated_term, thm, cache)
before (timing_msg' 2 ctxt (fn _ => "Throw") start)
end
| Spec (_, (SOME expr, read_vars, _)) =>
let
val _ = verbose_msg 3 ctxt (fn _ => "Spec_SOME begin " ^ fname)
val start = Timing.start ();
val generated_term = mk_monad @{const_name "L2_spec"} needed_vars throw_vars [expr]
val thm = mkthm read_vars needed_vars generated_term @{thm L2corres_spec}
in
inject (read_vars, needed_vars, generated_term, thm, cache)
before (timing_msg' 2 ctxt (fn _ => "Spec_SOME") start)
end
| Spec (_, (NONE, _, _)) =>
let
val _ = verbose_msg 3 ctxt (fn _ => "Spec_NONE begin " ^ fname)
val start = Timing.start ();
val generated_term = mk_monad @{const_name "L2_fail"} needed_vars throw_vars []
val thm = mkthm empty_set needed_vars generated_term @{thm L2corres_fail}
in
inject (empty_set, needed_vars, generated_term, thm, cache)
before (timing_msg' 2 ctxt (fn _ => "Spec_NONE") start)
end
| Assume (_, (SOME expr, read_vars, _)) =>
let
val _ = verbose_msg 3 ctxt (fn _ => "Assume_SOME begin " ^ fname)
val start = Timing.start ();
val generated_term = mk_monad @{const_name "L2_assume"} needed_vars throw_vars [expr]
val thm = mkthm read_vars needed_vars generated_term @{thm L2corres_assume}
in
inject (read_vars, needed_vars, generated_term, thm, cache)
before (timing_msg' 2 ctxt (fn _ => "Assume_SOME") start)
end
| Assume (_, (NONE, _, _)) =>
let
val _ = verbose_msg 3 ctxt (fn _ => "Assume_NONE begin " ^ fname)
val start = Timing.start ();
val generated_term = mk_monad @{const_name "L2_fail"} needed_vars throw_vars []
val thm = mkthm empty_set needed_vars generated_term @{thm L2corres_fail}
in
inject (empty_set, needed_vars, generated_term, thm, cache)
(* Timing label matches this case (it previously reported "Spec_NONE"). *)
before (timing_msg' 2 ctxt (fn _ => "Assume_NONE") start)
end
| Guard (_, (SOME expr, read_vars, _)) =>
let
val _ = verbose_msg 3 ctxt (fn _ => "Guard_SOME begin " ^ fname)
val start = Timing.start ();
val generated_term = mk_monad @{const_name "L2_guard"} empty_set throw_vars [expr]
val thm = mkthm read_vars empty_set generated_term @{thm L2corres_guard}
in
inject (read_vars, empty_set, generated_term, thm, cache)
before (timing_msg' 2 ctxt (fn _ => "Guard_SOME") start)
end
| Guard (_, (NONE, _, _)) =>
let
val _ = verbose_msg 3 ctxt (fn _ => "Guard_NONE begin " ^ fname)
val start = Timing.start ();
val generated_term = mk_monad @{const_name "L2_fail"} needed_vars throw_vars []
val thm = mkthm empty_set needed_vars generated_term @{thm L2corres_fail}
in
(* Returned directly, without going through inject. *)
(empty_set, needed_vars, generated_term, thm, cache)
before (timing_msg' 2 ctxt (fn _=> "Guard_NONE") start)
end
| Fail _ =>
let
val _ = verbose_msg 3 ctxt (fn _ => "Fail begin " ^ fname)
val start = Timing.start ();
val generated_term = mk_monad @{const_name "L2_fail"} needed_vars throw_vars []
val thm = mkthm empty_set needed_vars generated_term @{thm L2corres_fail}
in
(* Returned directly, without going through inject. *)
(empty_set, needed_vars, generated_term, thm, cache)
before (timing_msg' 2 ctxt (fn _ => "Fail") start)
end
| Seq (_, lhs, rhs) =>
let
val _ = verbose_msg 3 ctxt (fn _ => "Seq begin " ^ fname)
val start = Timing.start ();
val (_, rhs_live, rhs_modified) = get_node_data rhs
val (lhs_term, _, lhs_modified) = get_node_data lhs
(* The lhs has to return what it modifies and the rhs still needs. *)
val ret_vars = rhs_live INTER lhs_modified
val (lhs_reads, lhs_rets, new_lhs, lhs_thm, cache)
= do_conv grds ret_vars true throw_vars lhs cache
val (rhs_reads, rhs_rets, new_rhs, rhs_thm, cache)
= do_conv grds needed_vars allow_excess throw_vars rhs cache
val start_montage = Timing.start ();
val block_reads = lhs_reads UNION (rhs_reads MINUS lhs_modified)
val rhs_thm = timeit_msg 2 ctxt (fn _ => "Seq export rhs_thm: ")
(fn _ => Morphism.thm (#phi_export export_info) rhs_thm)
val new_rhs = timeit_msg 2 ctxt (fn _ => "Seq abs_over_tuple_vars: ")
(fn _ => abs_over_tuple_vars prog_info name_map lhs_rets new_rhs);
val generated_term = timeit_msg 2 ctxt (fn _ => "Seq mk_monad: ")
(fn _ => mk_monad @{const_name L2_seq} rhs_rets throw_vars [new_lhs, new_rhs])
val (thm, cache) =
let
(* Variables read by the rhs that the lhs does not modify need a
   preservation proof over the lhs. *)
val needed_preserves = (rhs_reads MINUS lhs_modified)
val (preserve_proof, cache) = timeit_msg 2 ctxt (fn _ => "preserve_proof: ") (fn _ =>
mk_multivar_preservation_proof ctxt export_info prog_info name_map lhs_term needed_preserves cache);
val weaken = timeit_msg 2 ctxt (fn _ => "weaken: ") (fn _ =>
@{thm validE_weaken} OF [preserve_proof])
val seq_split = corres_seq_split (Varset.card lhs_rets)
val seq = timeit_msg 2 ctxt (fn _ => "seq: ") (fn _ => seq_split OF [lhs_thm, rhs_thm, weaken])
in
timeit_msg 2 ctxt (fn _ => "Seq mkthm: ") (fn _ => (mkthm block_reads rhs_rets generated_term seq, cache))
before (timing_msg' 2 ctxt (fn _ => "Seq_montage") start_montage)
end
in
inject (block_reads, rhs_rets, generated_term, thm, cache)
before (timing_msg' 2 ctxt (fn _ => "Seq") start)
end
| Catch (_, lhs, rhs) =>
let
val _ = verbose_msg 3 ctxt (fn _ => "Catch begin " ^ fname)
val start = Timing.start ();
val (lhs_term, _, lhs_modified) = get_node_data lhs
val (_, rhs_live, _) = get_node_data rhs
(* The lhs throws what it modifies and the handler still needs. *)
val lhs_throws = rhs_live INTER lhs_modified
val (lhs_reads, lhs_rets, new_lhs, lhs_thm, cache)
= do_conv grds needed_vars false lhs_throws lhs cache
val (rhs_reads, _, new_rhs, rhs_thm, cache)
= do_conv grds needed_vars false throw_vars rhs cache
val block_reads = lhs_reads UNION (rhs_reads MINUS lhs_throws)
val rhs_thm = timeit_msg 2 ctxt (fn _ => "Catch export rhs_thm: ")
(fn _ => Morphism.thm (#phi_export export_info) rhs_thm)
val new_rhs = abs_over_tuple_vars prog_info name_map lhs_throws new_rhs
val generated_term = mk_monad @{const_name L2_catch} needed_vars throw_vars [new_lhs, new_rhs]
val (thm, cache) =
let
val needed_preserves = (rhs_reads MINUS lhs_modified)
val (preserve_proof, cache) = mk_multivar_preservation_proof ctxt export_info prog_info name_map lhs_term needed_preserves cache
val catch_split = corres_catch_split (Varset.card lhs_throws)
in
(mkthm block_reads needed_vars generated_term
(catch_split OF [lhs_thm, rhs_thm, @{thm validE_weaken} OF [preserve_proof]]), cache)
end
in
inject (block_reads, needed_vars, generated_term, thm, cache)
before (timing_msg' 2 ctxt (fn _ => "Catch") start)
end
| Condition (_, (SOME expr, read_vars, _), lhs, rhs) =>
let
val _ = verbose_msg 3 ctxt (fn _ => "Condition begin " ^ fname)
val start = Timing.start ();
(* Both branches are asked for the same set of return variables. *)
val requested_vars = needed_vars INTER modified_vars
val (lhs_reads, _, new_lhs, lhs_thm, cache)
= do_conv grds requested_vars false throw_vars lhs cache
val (rhs_reads, _, new_rhs, rhs_thm, cache)
= do_conv grds requested_vars false throw_vars rhs cache
val block_reads = lhs_reads UNION rhs_reads UNION read_vars
val generated_term = mk_monad @{const_name "L2_condition"}
requested_vars throw_vars [expr, new_lhs, new_rhs]
val thm = mkthm block_reads requested_vars generated_term
(@{thm L2corres_cond} OF [lhs_thm, rhs_thm])
in
inject (block_reads, requested_vars, generated_term, thm, cache)
before (timing_msg' 2 ctxt (fn _ => "Condition") start)
end
| While (_, (SOME expr, read_vars, _), body) =>
let
val _ = verbose_msg 3 ctxt (fn _ => "While begin " ^ fname)
val start = Timing.start ();
(* Loop state: modified variables that are needed afterwards or live. *)
val loop_iterators = (needed_vars UNION live_vars) INTER modified_vars
val (body_reads, _, new_body, body_thm, cache) =
do_conv grds loop_iterators false throw_vars body cache
val (body_term, body_live, body_modifies) = get_node_data body
val new_body = abs_over_tuple_vars prog_info name_map loop_iterators new_body
val body_thm = timeit_msg 2 ctxt (fn _ => "While export body_thm: ")
(fn _ => Morphism.thm (#phi_export export_info) body_thm)
val generated_term =
mk_monad @{const_name "L2_while"} loop_iterators throw_vars [
abs_over_tuple_vars prog_info name_map loop_iterators expr,
new_body,
HOLogic.mk_tuple (dest_sort_extern prog_info loop_iterators |> map name_map),
var_set_to_isa_list ctxt prog_info loop_iterators]
val (thm, cache) =
let
val needed_preserves = ((body_reads UNION read_vars) MINUS body_modifies)
val (preserve_proof, cache) = mk_multivar_preservation_proof ctxt export_info prog_info name_map body_term needed_preserves cache
val tracked_vars = (body_reads UNION read_vars UNION loop_iterators)
val invariant_precond = abs_over_tuple_vars prog_info name_map loop_iterators
(mk_precond ctxt prog_info name_map tracked_vars)
val while_split = corres_while_split (Varset.card loop_iterators)
val base_thm = Utils.named_cterm_instantiate ctxt [
("P", Thm.cterm_of ctxt invariant_precond)
] while_split
in
(mkthm (body_reads UNION read_vars UNION loop_iterators) loop_iterators generated_term
(base_thm OF [body_thm, @{thm validE_weaken} OF [preserve_proof]]), cache)
end
in
inject (body_reads UNION read_vars UNION loop_iterators, loop_iterators, generated_term, thm, cache)
before (timing_msg' 2 ctxt (fn _ => "While") start)
end
| Guarded ((_, guarded_reads, guarded_modifies), (SOME g, _, _), (dest_opt, _, _), bdy) =>
let
val _ = verbose_msg 3 ctxt (fn _ => "Guarded begin " ^ fname)
val start = Timing.start ();
val (_, bdy_reads0, bdy_modifies0) = get_node_data bdy
val bdy_reads = read_vars_of_call bdy
val @{term_pat "L1_guarded ?g' ?C"} = l1_term
val s' = dummy_state_guards_l1 prog_info
val s = dummy_state_guards_l2 prog_info
val bdy_precond = mk_precond ctxt prog_info name_map (bdy_reads0)
val st = ProgramInfo.get_globals_getter prog_info
(* Fix the two dummy states and assume the guard, the state equation and
   the body precondition; the body is converted under these assumptions
   and its theorem is exported back to ctxt afterwards. *)
val ([_, _], ctxt') = Utils.gen_fix_variant_frees true (map dest_Free [s', s]) ctxt
val ([g_thm, st_eq, bdy_precond_thm], ctxt') = ctxt'
|> Assumption.add_assumes (map (Thm.cterm_of ctxt') [
HOLogic.Trueprop $ (g $ s),
HOLogic.Trueprop $ (HOLogic.mk_eq (s, st $ s')),
HOLogic.Trueprop $ (bdy_precond $ s')])
val g_thms = Utils.split_conj g_thm
val (bdy_reads, bdy_rets, new_bdy, bdy_thm, cache)
= do_conv' ctxt' (grds @ [g_thm, st_eq, bdy_precond_thm]) (needed_vars) allow_excess throw_vars bdy cache
val @{term_pat "L2corres ?st ?ret ?ex ?P ?c ?c'"} = bdy_thm |> Thm.concl_of |> HOLogic.dest_Trueprop
val [bdy_thm] = Proof_Context.export ctxt' ctxt [bdy_thm]
in
case dest_opt of
SOME dest =>
let
val @{term_pat "(gets ?dest' ⤜ ?c0')"} = C
val p' = Envir.beta_eta_contract (dest' $ s')
val p = Envir.beta_eta_contract (dest $ s)
val c = Utils.abs_over "p" p c
val c' = Utils.abs_over "p" p' c'
val new_bdy = Utils.abs_over "p" p new_bdy
val ns = CLocals.name_hints ctxt ["p"]
val new_bdy = \<^infer_instantiate>‹dest = dest and bdy = new_bdy and ns=ns in term ‹L2_seq (L2_gets dest ns) bdy›› ctxt
val generated_term = mk_monad @{const_name "L2_guarded"} bdy_rets throw_vars [g, new_bdy]
val rule = @{thm L2corres_guarded_impl''} |> Drule.infer_instantiate' ctxt
(map (SOME o Thm.cterm_of ctxt) [g, st, bdy_precond, ret, ex, c, dest, c', dest', g', ns])
val thm = mkthm bdy_reads bdy_rets generated_term (rule OF [bdy_thm])
in
inject (bdy_reads, bdy_rets, generated_term, thm, cache)
before (timing_msg' 2 ctxt (fn _ => "Guarded") start)
end
| NONE =>
let
val generated_term = mk_monad @{const_name "L2_guarded"} bdy_rets throw_vars [g, new_bdy]
val rule = @{thm L2corres_guarded_impl} |> Drule.infer_instantiate' ctxt
(map (SOME o Thm.cterm_of ctxt) [g, st, bdy_precond, ret, ex, c, c', g'])
val inst_rule = rule OF [bdy_thm]
val thm = mkthm bdy_reads bdy_rets generated_term (inst_rule)
in
inject (bdy_reads, bdy_rets, generated_term, thm, cache)
before (timing_msg' 2 ctxt (fn _ => "Guarded") start)
end
end
| (call_t as Call ((_, call_reads, call_modifies), expr_f, expr_list, (ret_expr, ret_read_vars, _), ret_var, _)) =>
let
val _ = verbose_msg 3 ctxt (fn _ => "Call begin " ^ fname)
val start = Timing.start ();
val @{term_pat "_ ?arg_setup ?callee ?return_norm ?return_exn ?ret_extract"} = l1_term
val l2_callee_thms = Named_Theorems.get ctxt @{named_theorems "l2_corres"}
(* Classify the callee: L1_call_simpl or map_of_default callees are
   treated as function pointer calls, anything else as a direct call. *)
val (callee, is_fun_ptr, is_fun_ptr_param, ps_opt) = case callee of
(Const (@{const_name "L1_call_simpl"}, _) $ ct $ Gamma $ f') => (f', true, true, NONE)
| @{term_pat "map_of_default ?P ?ps ?f'"} => (f', true, false, SOME callee)
| _ => (callee, false, false, NONE)
val callee_scope = if is_fun_ptr then I else callee_scope prog_info callee
val arg_setup_vals = gen_parse_modify {read_scope = I, write_scope = callee_scope, two_state = false}
ctxt prog_info name_map arg_setup |> List.rev
val arg_rets = gen_parse_modify {read_scope = callee_scope, write_scope = I, two_state = true}
ctxt prog_info Free ret_extract |> map #2 |> flat
val (callee_trm, args, callee_thms, method_as_fun_ptr_param) =
if is_fun_ptr
then
let
(* NOTE(review): `the ps_opt` raises Option when ps_opt is NONE, which is
   what the L1_call_simpl alternative above produces - confirm that this
   combination cannot reach this point. *)
val (map_of_default_new, map_of_default_thm) = mk_L2corres_map_of_default_thm ctxt l2_callee_thms (the ps_opt)
val args = arg_setup_vals |> map #1
val (_, _, SOME callee') = parse_expr ctxt prog_info name_map
(Utils.abs_over "s" (dummy_state_guards_l1 prog_info) callee)
val callee' = betapply(callee', dummy_state_guards_l2 prog_info)
in
(map_of_default_new $ callee', args, [map_of_default_thm], false)
end
else
let
val callee_name = Termtab.lookup (#const_to_function l1_call_info) (l1call_function_const callee) |> the
val callee' = Symtab.lookup l1_infos callee_name
val callee_proof = Option.mapPartial
(Symtab.lookup callee_proofs) (Option.map FunctionInfo.get_name callee')
val callee_method_callers = ProgramAnalysis.callers_via_method_or_non_param_of_fun_ptr_param_of (ProgramInfo.get_csenv prog_info) callee_name
val method_as_fun_ptr_param = member (op =) callee_method_callers fname
in
(* Prefer a proof from callee_proofs; otherwise fall back to the L2
   function info stored in AutoCorresData. *)
case callee_proof of
SOME (callee_free, callee_thms) =>
(callee_free,
map (apfst NameGeneration.Named) (FunctionInfo.get_plain_args (the callee')),
callee_thms,
method_as_fun_ptr_param)
| NONE =>
(case AutoCorresData.get_function_info (Context.Proof ctxt)
(ProgramInfo.get_prog_name prog_info) (FunctionInfo.L2) callee_name of
SOME callee_info => (FunctionInfo.get_const callee_info,
map (apfst NameGeneration.Named) (FunctionInfo.get_plain_args callee_info),
FunctionInfo.get_corres_thm callee_info |> single,
method_as_fun_ptr_param)
| NONE => error ("do_conv: could not retrieve callee theorem for: " ^
quote (Syntax.string_of_term ctxt callee)))
end
(* Every argument expression must have been parsed successfully. *)
val arg_setup_vals =
map (fn (a, b, c, parsed_expr) =>
case parsed_expr of
NONE =>
raise Utils.InvalidInput ("Could not parse function parameter '" ^ @{make_string} (fst a) ^ "'")
| SOME x =>
(a, b, c, x)
) arg_setup_vals
val _ = if length arg_setup_vals <> length args then
raise TERM ("Argument list length does not match function definition.", [arg_setup])
else
()
(* Give each parameter a variant name carrying the suffix 'param. *)
fun new_name (NameGeneration.Named x) sfx = Variable.variant_fixes [suffix sfx x] ctxt |> fst |> hd
| new_name (NameGeneration.Positional (i, T)) sfx =
Variable.variant_fixes [suffix sfx (string_of_int i)] ctxt |> fst |> hd
val arg_setup_vals = map (fn ((a, T), b, c, d) => ((new_name a "'param", T), b, c, d)) arg_setup_vals
val args = map (Free o #1) arg_setup_vals
val call_args =
(betapplys (callee_trm, args))
val exn_var_term = name_map exn_name_type
fun mk_exn t = if t = exn_var_term then
\<^instantiate>‹e = exn_var_term in term ‹Nonlocal (the_Nonlocal e)› for e::‹exit_status CProof.c_exntype››
else t
val emb = Utils.abs_over exn_name exn_var_term (HOLogic.mk_tuple (dest_sort_extern prog_info throw_vars |> map (mk_exn o name_map )))
(* Pick the L2 call combinator from the shape of the return variables
   (ignoring the exception variable) and the return expression. *)
val (call, ret_vars) =
case (filter_out (fn x => x = exn_name_type) ret_var, ret_expr) of
([("globals'", _)], SOME e) =>
(mk_monad @{const_name L2_modifycall} empty_set throw_vars
[call_args, e, emb, CLocals.name_hints ctxt ["ret"]], empty_set)
| ([x], SOME e) =>
(mk_monad @{const_name L2_returncall} (make_set [x]) throw_vars
[call_args, e, emb, CLocals.name_hints ctxt [fst x]], make_set [x])
| ([], _) =>
(mk_monad @{const_name L2_voidcall} empty_set throw_vars
[call_args, emb, CLocals.name_hints ctxt ["ret"]], empty_set)
| _ => error ("LocalVarExtract.do_conv unexpected input for call")
(* Wrap the call in one L2_seq / L2_folded_gets per argument, binding the
   renamed parameter over the rest of the term. *)
val extractors = foldr (
fn ((updated_var, read_vars, is_globals_reader, expr), rest) =>
let
val ret_type = (make_set [("x'", fastype_of expr |> body_type)])
val rest_type = (make_set [("x'", AutoCorresData.res_type_of_exn_monad rest)])
val getter = mk_monad @{const_name L2_folded_gets} ret_type throw_vars
[expr,
var_set_to_isa_list (callee_scope ctxt) prog_info
(make_set [apfst (unsuffix "'param") updated_var])]
in
mk_monad @{const_name "L2_seq"} rest_type throw_vars [
getter,
Utils.abs_over (fst updated_var) (Free updated_var) rest]
end
)
call
arg_setup_vals
val read_vars = call_reads;
(* Goal printing between proof steps is enabled by the d1 switch. *)
val my_debug_tac = if !d1 then print_tac ctxt else fn _ => all_tac
val L2_call_thms = @{thms L2corres_returncall L2corres_voidcall L2corres_modifycall}
val _ = if (!d1) then tracing ("is_fun_ptr, method_as_fun_ptr_param: " ^ @{make_string} (is_fun_ptr, method_as_fun_ptr_param)) else ()
fun callee_tac ctxt thms =
let
val _ = if not (!d1) then () else tracing (big_list_of_thms "callee_tac thms: " ctxt thms)
in
let
in
SOLVES_debug ctxt ("callee_tac (2): " ^ fname) (
REPEAT1 (EVERY [
my_debug_tac "after only_fun_ptr_simps",
resolve_tac ctxt thms 1, my_debug_tac "after resolve"]))
end
end
val all_callee_thms = callee_thms @
Named_Theorems.get ctxt @{named_theorems "l2_corres"}
val grds' = map (Simplifier.asm_full_simplify (ctxt addsimps @{thms More_Lib.pred_conj_def})) grds
val _ = if !d1 then tracing (big_list_of_thms "grds': " ctxt grds') else ()
(* Simplifier tracing is enabled by the d2 switch. *)
fun dtrace ctxt = if !d2 then Config.put Simplifier.simp_trace true ctxt else ctxt
val thm =
mk_corresXF_thm ctxt prog_info name_map ret_vars throw_vars read_vars extractors l1_term (fn {context=ctxt, ...} =>
(my_debug_tac "unfold folded_gets"
THEN (REPEAT (resolve_tac ctxt @{thms L2corres_folded_gets} 1)))
THEN (my_debug_tac "propagate fixed function pointer parameters"
THEN (asm_full_simp_tac (dtrace ctxt addsimps (@{thms More_Lib.pred_conj_def} @ grds')
|> Simplifier.add_cong @{thm L2corres_l2_propagate_fixed_cong''} ) 1)
THEN (my_debug_tac "apply callee proof")
THEN FIRST (map_index (fn (i, thm) =>
trace_resolve_tac false ctxt thm 1 THEN
my_debug_tac ("resolved L2_call_thms (" ^ string_of_int i ^ ")") THEN
callee_tac ctxt all_callee_thms)
L2_call_thms))
THEN (my_debug_tac "final simp"
THEN (REPEAT (CHANGED (asm_full_simp_tac ctxt 1))))
)
in
inject (read_vars, ret_vars, extractors, thm, cache)
before (timing_msg' 2 ctxt (fn _ => "Call") start)
end
| Exec_Spec_Monad ((t, read_vars, ret_vars), arg_exprs, Y) =>
let
val _ = verbose_msg 3 ctxt (fn _ => "Exec_Spec_Monad begin " ^ fname)
val start = Timing.start ();
val @{term_pat ‹L1_exec_spec_monad ?upd_x ?st ?args ?f ?res›} = t
(* Each argument must carry an expression; it is paired with the first
   element of its read set. *)
val args' = arg_exprs |> map (fn (SOME t, reads, _) => (t, hd (Varset.dest reads))
| x => error ("Exec_Spec_Monad unexpected: " ^ (@{make_string} x)))
val abs_g = Tuple_Tools.dest_case_prod_abs_body f
val app_g =
if null args' then
dest_unit_abs abs_g
else
fold_rev (fn i => fn x => betapply (x, Bound i)) (0 upto (length args' - 1)) abs_g
val sT = range_type (fastype_of st)
val st' = case st of
@{term_pat globals} => \<^Const>‹id sT›
| @{term_pat "(?lift o globals)"} => lift
| @{term_pat "λs. ?lift (globals s)"} => lift
| t => error ("Exec_Spec_Monad: unexpected state lifting" ^ Syntax.string_of_term ctxt t)
val l2_exec = mk_monad @{const_name "L2_exec_spec_monad"} ret_vars throw_vars [st', app_g]
val l2 = l2_exec |> fold_rev (fn (expr, var as (name, T)) => fn t =>
let
val ret = (make_set [("x'", T)])
val ret_t = (make_set [("x'", AutoCorresData.res_type_of_exn_monad t)])
val get = mk_monad @{const_name L2_folded_gets} ret throw_vars
[expr, var_set_to_isa_list ctxt prog_info (make_set [var])]
in mk_monad @{const_name "L2_seq"} ret_t throw_vars [get, Abs (name, T, t)] end) args'
val thm = mk_corresXF_thm ctxt prog_info name_map ret_vars throw_vars read_vars l2 t (fn {context=ctxt, ...} =>
simp_tac (Simplifier.clear_simpset ctxt addsimps @{thms L2_remove_scaffolding_1}) 1 THEN
match_tac ctxt @{thms L2corres_exec_spec_monad_globals' L2corres_exec_spec_monad'} 1 THEN
ALLGOALS (asm_full_simp_tac (ctxt addsimps @{thms refines_right_eq_id} @ [unit_range_eq])))
in
inject (read_vars, ret_vars, l2, thm, cache)
before (timing_msg' 2 ctxt (fn _ => "Exec_Spec_Monad") start)
end
| Stack ((t, _,_), (SOME expr, read_vars, _), bdy) =>
let
val _ = verbose_msg 3 ctxt (fn _ => "Stack begin " ^ fname)
val start = Timing.start ();
val requested_vars = needed_vars UNION read_vars
val {n, init, c, ...} = with_fresh_stack_ptr.match ctxt t
val Abs (pn, pT, _) = c
val sT = fastype_of init |> domain_type
(* Fix the bound pointer variable, convert the body under the extended
   context and export the resulting theorem back to ctxt. *)
val ([p], ctxt') = Utils.gen_fix_variant_frees true [(pn, pT)] ctxt
val (bdy_reads, _, new_bdy, bdy_thm, cache)
= do_conv' ctxt' grds requested_vars false throw_vars bdy cache
val [bdy_thm] = [bdy_thm] |> Proof_Context.export ctxt' ctxt
val rules = Named_Theorems.get ctxt @{named_theorems L2corres_with_fresh_stack_ptr}
|> Utils.OFs [bdy_thm]
val with_fresh_stack_ptr = with_fresh_stack_ptr.term ctxt (ProgramInfo.get_globals_type prog_info)
val new_bdy = Term.lambda_name (pn, p) new_bdy
val name_hint = CLocals.name_hint ctxt (TermsTypes.dest_local_ptr_name pn)
|> single |> HOLogic.mk_list \<^typ>‹nat›
val generated_term = \<^infer_instantiate>‹w=with_fresh_stack_ptr and init = expr and c = new_bdy and
nm = name_hint and n=n
in term ‹w n init (L2_VARS c nm)›› ctxt
val block_reads = bdy_reads UNION read_vars
val thm = mkthm block_reads requested_vars generated_term (hd rules)
in
inject (block_reads, requested_vars, generated_term, thm, cache)
before (timing_msg' 2 ctxt (fn _ => "Stack") start)
end
| _ => Utils.invalid_input "a parsed L1 term" (l1_term |> head_of |> @{make_string})
end
(* Prefix used for the internal names of local variables. *)
val internalN = "lvar'"

(* Prepend internalN to the given name. *)
fun internal_name n = prefix internalN n
(* Type expected for the L2 version of fn_name: a function from the types of
   its plain arguments to the L2 monad type built from the globals type, the
   function's return type and the C exception type.  fn_name must be present
   in l1_infos. *)
fun get_expected_l2_fn_type prog_info l1_infos fn_name =
  let
    val info = the (Symtab.lookup l1_infos fn_name)
    val plain_args = FunctionInfo.get_plain_args info
    val result_type = FunctionInfo.get_return_type info
    val arg_types = map #2 plain_args
    val monad_type =
      AutoCorresData.mk_l2monadT
        (ProgramInfo.get_globals_type prog_info) result_type HP_TermsTypes.c_exntype_ty
  in
    arg_types ---> monad_type
  end
(* Plain arguments of fn_name with every argument name passed through
   ProgramInfo.demangle_name.  The lthy parameter is not used.  fn_name must
   be present in l1_infos. *)
fun get_expected_l2_fn_args lthy prog_info l1_infos fn_name =
  Symtab.lookup l1_infos fn_name
  |> the
  |> FunctionInfo.get_plain_args
  |> map (apfst (ProgramInfo.demangle_name prog_info))
(* AutoCorresData.mk_fn_ptr_infos with the ts_monad_name field fixed to the
   empty string. *)
fun mk_fn_ptr_infos ctxt prog_info fn_args info =
  AutoCorresData.mk_fn_ptr_infos
    ctxt prog_info {ts_monad_name = ""} fn_args info
(* Build the L2 correspondence proposition for fn_name.

   The L1 constant of the function is related to fn_free applied to fn_args,
   with the function's output variables as return values, exn_var as thrown
   variables and its input variables as read variables.  The proposition is
   paired with the L2 corres_thm_attribute, and the overall result has NONE
   as its second component.  The assume parameter is not used. *)
fun get_l2_corres_prop skips prog_info l1_infos ctxt assume fn_name fn_free fn_args =
  let
    val prog_name = ProgramInfo.get_prog_name prog_info
    val scoped_ctxt = HPInter.enter_scope prog_name fn_name ctxt
    val (input_params, output_params) = get_fn_input_output_vars l1_infos fn_name
    val corres_attr =
      AutoCorresData.corres_thm_attribute prog_name skips FunctionInfo.L2 fn_name
    val info = the (Symtab.lookup l1_infos fn_name)
    val l1_fun = FunctionInfo.get_const info
    val arg_names = map fst (FunctionInfo.get_plain_args info)
    (* Plain argument names are mapped positionally onto fn_args. *)
    val arg_table = Symtab.make (arg_names ~~ fn_args)
    fun name_map (n, _) = the (Symtab.lookup arg_table n)
    val corres_prop =
      mk_corresXF_prop scoped_ctxt prog_info name_map
        output_params exn_var input_params
        (betapplys (fn_free, fn_args))
        l1_fun
  in
    ((Logic.list_implies ([], corres_prop), [corres_attr]), NONE)
  end
(* Apply Variable.gen_all to the theorem, take its conclusion, strip Trueprop
   and hand the remaining term to dest_L2corres_term_abs. *)
fun get_body_of_thm ctxt thm =
  let
    val concl = Thm.concl_of (Variable.gen_all ctxt thm)
  in
    dest_L2corres_term_abs (HOLogic.dest_Trueprop concl)
  end
(* get_l2corres_thm: produce the L2corres theorem for the function fn_name.

   Steps, in the order they appear below:
     1. unfold l1_term with init_unfold and parse the result with parse_l1;
     2. compute liveness and modification data and zip them onto the parsed
        program as node data;
     3. fix internal names for all variables in an extended context and run
        do_conv on the annotated program;
     4. export the theorem, fold the initial unfolding back, instantiate the
        parameters with fn_args and weaken the precondition to the canonical
        one for the function's input variables;
     5. post-process the L2 program inside the theorem: scaffolding removal,
        L2Opt cleanup (L2opt), try/catch abstraction (L2exn) and projection
        of used tuple components (L2prj).
   Returns the final theorem. *)
fun get_l2corres_thm ctxt skips prog_info l1_infos fn_ptr_infos l1_call_info L2_opt trace_opt fn_name
callee_terms fn_args l1_term init_unfold = let
val ctxt = HPInter.enter_scope (ProgramInfo.get_prog_name prog_info) fn_name ctxt
val fn_info = the (Symtab.lookup l1_infos fn_name)
val (fn_input_vars, fn_local_vars, fn_ret_vars) = get_variables l1_infos fn_name
(* External name map: plain argument names are mapped positionally to fn_args. *)
val m = Symtab.make (map fst (FunctionInfo.get_plain_args fn_info) ~~ fn_args)
fun name_map_ext (n, T) = Symtab.lookup m n |> the
val fn_ptr_param_map = fn_ptr_infos
|> map (fn (n, info) => (NameGeneration.un_varname n, #1 (#ptr_val (info FunctionInfo.L2))))
|> AList.lookup (op =)
fun remove_fn_ptr_params vars = vars
|> filter_out (is_some o fn_ptr_param_map o fst)
val fn_input_vars_wo_fn_ptr_params = fn_input_vars |> Varset.dest |> remove_fn_ptr_params |> Varset.make
(* Internal name map: names found in fn_ptr_param_map use the mapped name,
   all other variables get the internal_name prefix. *)
fun name_map_internal (n, T) =
case fn_ptr_param_map n of
SOME n' => Free (n', T)
| NONE => Free (internal_name n, T)
(* Unfold the L1 term with init_unfold and parse the right-hand side. *)
val init_rule = Thm.cterm_of ctxt l1_term
|> Conv.rewr_conv (safe_mk_meta_eq init_unfold)
val source_term = Thm.concl_of init_rule |> Utils.rhs_of_eq
val parsed_term = parse_l1 ctxt prog_info l1_infos l1_call_info name_map_internal source_term
val _ = verbose_msg 4 ctxt (fn _ => "parsed_term " ^ fn_name ^ ": " ^ @{make_string}
(parsed_term |> Prog.map_prog (fn _ => Bound 0) I I I))
(* Union of the variable sets collected by folding over the parsed program. *)
val all_vars = Prog.fold_prog
(K I)
(fn (_, vars, _) => fn old_vars => vars UNION old_vars)
(fn mod_var => fn old_vars => mod_var UNION old_vars)
(K I)
parsed_term empty_set
val liveness_data = calc_live_vars exn_var parsed_term (union_sets [fn_ret_vars]) exn_var
val _ = verbose_msg 3 ctxt (fn _ => "liveness_data " ^ fn_name ^ ": " ^ @{make_string} liveness_data)
(* Nodes without modification data fall back to all_vars. *)
val modification_data =
get_modified_vars parsed_term
|> map_prog (fn x => Option.getOpt (x, all_vars)) I I I
fun zip_node_data a b c =
zip_progs a (zip_progs b c)
|> map_prog (fn (a, (b, c)) => (a, b, c)) fst (dest_sort_extern prog_info o fst) fst
val input_term = zip_node_data parsed_term liveness_data modification_data
(* Variables live at function entry that are not parameters are reported
   with a warning. *)
val fn_inputs = get_node_data liveness_data
val fn_params = FunctionInfo.get_plain_args fn_info
val excess_inputs = fn_inputs MINUS (make_set fn_params)
val _ =
if excess_inputs <> empty_set then
warning
("Input function '" ^ fn_name ^ "' has unresolved variables: "
^ @{make_string} (dest_sort_extern prog_info excess_inputs))
else
()
(* Fix the internal names of all variables (this all_vars shadows the
   varset above and is a plain list) and map names through the variants
   chosen by the context. *)
val all_vars = map Varset.dest [fn_input_vars_wo_fn_ptr_params, fn_local_vars, fn_ret_vars, exn_var] |> flat
val all_vars_internal = map (fn (n,T) => (Term.dest_Free (name_map_internal (n,T)))) all_vars;
val _ = verbose_msg 2 ctxt (fn _ => "all_vars_internal: " ^ @{make_string} all_vars_internal)
val (all_vars_internal', ctxt_internal) = Utils.fix_variant_frees all_vars_internal ctxt;
val _ = verbose_msg 2 ctxt (fn _ => "all_vars_internal': " ^ @{make_string} all_vars_internal')
val phi_import = perhaps (dest_Free #> fst
#> AList.lookup (op =) (map fst all_vars_internal ~~ all_vars_internal'))
val name_map_internal = phi_import o name_map_internal
val ([dummy_value, dummy_init], ctxt') = Variable.variant_fixes ["dummy_val", "dummy_init"] ctxt_internal
(* phi_export goes from ctxt' back to ctxt, phi_cache from ctxt' back to
   ctxt_internal. *)
val phi_export = Variable.export_morphism ctxt' ctxt
val phi_cache = Variable.export_morphism ctxt' ctxt_internal
(* Wrap a name map so that it fails when the resulting free variable is
   not fixed in the given context. *)
fun assert_fixed ctxt name_map = fn x =>
let
val res as Free(n, T) = name_map x
val _ = if Variable.is_fixed ctxt n then () else
error ("unexpected local variable: " ^ quote n ^ " for " ^ @{make_string} x)
in res end;
val checked_name_map_internal = assert_fixed ctxt' name_map_internal
val export_info = {phi_export = phi_export, phi_cache = phi_cache,
dummy_value = dummy_value, dummy_init = dummy_init,
weaken_superset = weaken_superset ctxt' phi_export prog_info checked_name_map_internal}
val ctxt_ss = setup_l2_ss HOL_basic_ss ctxt'
val _ = verbose_msg 2 ctxt' (fn _ => "input_term " ^ fn_name ^ ": " ^ Syntax.string_of_term ctxt' (#1 (get_node_data input_term)))
val _ = verbose_msg 2 ctxt' (fn _ => "input_term (raw) " ^ fn_name ^ ": " ^ @{make_string} (#1 (get_node_data input_term)))
(* Run the conversion: no guard theorems, fn_ret_vars as needed variables,
   allow_excess = false, exn_var as throw variables and an empty
   preservation cache. *)
val (_, _, term, thm, cache) = timeit_msg 1 ctxt_ss (fn _ => "Conversion L2 (do_conv) " ^ fn_name ^ ": ") (fn _ =>
do_conv ctxt_ss export_info skips prog_info l1_infos l1_call_info checked_name_map_internal
fn_name [] fn_input_vars
callee_terms [] fn_ret_vars false exn_var input_term (pres_cache_empty ctxt'));
val [thm] = Variable.export ctxt' ctxt [thm];
val _ = verbose_msg 0 ctxt' (fn _ => "preservation_cache (hits: " ^ string_of_int (get_hits cache) ^
", misses: " ^ string_of_int (get_misses cache) ^
", superset: " ^ string_of_int (get_superset cache) ^
", join: " ^ string_of_int (get_join cache) ^
", mode: " ^ string_of_int (get_mode cache) ^ ")")
(* Instantiate the exported parameter variables (function pointer
   parameters excluded) with the external argument terms. *)
val fn_params_wo_fn_ptr = remove_fn_ptr_params fn_params
val replacements = (map (dest_Var o Morphism.term phi_export o name_map_internal) fn_params_wo_fn_ptr) ~~
(map (Thm.cterm_of ctxt o name_map_ext) fn_params_wo_fn_ptr)
val inst_extern = AList.lookup (op =) replacements
val thm_folded = timeit_msg 2 ctxt (fn _ => "fold: ") (fn _ => Local_Defs.fold ctxt [init_rule] thm);
val new_thm = Utils.instantiate_thm_vars ctxt inst_extern thm_folded
(* Replace the theorem's precondition by the canonical precondition over
   fn_input_vars, after proving that the canonical one implies it. *)
val @{term_pat "Trueprop (L2corres _ _ _ ?precond _ _)"} = Thm.prop_of new_thm
val canonical_precond = mk_precond ctxt prog_info name_map_ext fn_input_vars
val generalize_precond = \<^infer_instantiate>‹P = canonical_precond and Q = precond in prop ‹pred_imp P Q›› ctxt
val generalize_precond_thm = Goal.prove ctxt [] [] generalize_precond (fn {context, ...} => EVERY [
resolve_tac context @{thms pred_impI} 1,
TRY (resolve_tac context @{thms TrueI} 1),
REPEAT (eresolve_tac context @{thms pred_andE} 1),
REPEAT (TRY (resolve_tac context @{thms pred_andI} 1) THEN assume_tac context 1)])
val new_thm = @{thm L2corres_guard_imp} OF [new_thm, generalize_precond_thm]
(* Apply a conversion to the argument of the theorem selected by
   Utils.nth_arg_conv 5. *)
fun corres_prog_conv conv = Conv.fconv_rule (Utils.remove_meta_conv (fn ctxt =>
Utils.nth_arg_conv 5 (conv ctxt)) ctxt)
val new_thm = new_thm |> corres_prog_conv (fn ctxt =>
Simplifier.rewrite_wrt ctxt false @{thms L2_remove_scaffolding_1}
then_conv
Simplifier.rewrite_wrt ctxt false @{thms L2_remove_scaffolding_2})
(* L2opt: cleanup and simplification through L2Opt.cleanup_thm_tagged. *)
val _ = writeln ("Simplifying (L2opt) " ^ fn_name)
val _ = verbose_msg 1 ctxt (fn _ => "L2 (raw) - " ^ fn_name ^ ": " ^ Thm.string_of_thm ctxt new_thm)
val cleanup_del = @{thms ptr_coerce.simps ptr_add_0_id}
val fn_ptr_guard_simps = callee_terms |> Symtab.dest |> map (#2 o #2) |> flat
val ctxt = ctxt |> AutoCorresTrace.put_trace_info fn_name FunctionInfo.L2 FunctionInfo.PEEP;
val new_thm = timeit_msg 1 ctxt (fn _ => "Simplification (L2opt): " ^ fn_name) (fn _ =>
L2Opt.cleanup_thm_tagged prog_info (ctxt delsimps cleanup_del) fn_ptr_guard_simps []
(SOME map_of_default_args.unfold_map_of_default_conv) new_thm
L2_opt 5 trace_opt FunctionInfo.L2)
val _ = verbose_msg 1 ctxt (fn _ => "L2 (L2opt) - " ^ fn_name ^ ": " ^ Thm.string_of_thm ctxt new_thm)
(* L2exn: rewrite with L2_Exception_Rewrite.abstract_try_catch_conv. *)
val _ = writeln ("Introduce nested exceptions (L2exn) " ^ fn_name)
val _ = Utils.verbose_fn 2 ctxt (fn _ => Synthesize_Rules.print_rules (Context.Proof ctxt) @{synthesize_rules_name L2_rel_spec_monad} NONE)
val new_thm = timeit_msg 1 ctxt (fn _ => "Nested exceptions (L2exn): " ^ fn_name) (fn _ =>
new_thm
|> corres_prog_conv (fn ctxt => (L2_Exception_Rewrite.abstract_try_catch_conv ctxt))
)
val _ = verbose_msg 1 ctxt (fn _ => "L2 (L2exn) - " ^ fn_name ^ ": " ^ Thm.string_of_thm ctxt new_thm)
(* L2prj: rewrite with L2_Exception_Rewrite.project_used_components_conv. *)
val _ = writeln ("Remove unused tuple components (L2prj) " ^ fn_name)
val new_thm = timeit_msg 1 ctxt (fn _ => "Remove unused tuple components (L2prj): " ^ fn_name) (fn _ =>
new_thm
|> corres_prog_conv (fn ctxt => (L2_Exception_Rewrite.project_used_components_conv ctxt))
)
val _ = verbose_msg 1 ctxt (fn _ => "L2 (L2prj) - " ^ fn_name ^ ": " ^ Thm.string_of_thm ctxt new_thm)
in
new_thm
end
(* Instantiate the rule L2corres_L2_call_simpl for fn_name and discharge its
   first premise with the function's definition.

   The schematic variables l1_f, ex_xf, gs, ret_xf and arg_xf are set from
   the function's constant (applied to a free variable rec_measure' of type
   nat), the exception variable, the globals getter, the function's return
   variables and the precondition over its input variables.  fn_name must be
   present in l1_infos. *)
fun mk_l2corres_call_simpl_thm prog_info l1_infos ctxt fn_name fn_args =
  let
    val fn_def = the (Symtab.lookup l1_infos fn_name)
    val l1_const = FunctionInfo.get_const fn_def
    val arg_names = map fst (FunctionInfo.get_plain_args fn_def)
    val (fn_input_vars, fn_ret_vars) = get_fn_input_output_vars l1_infos fn_name
    (* Plain argument names are mapped positionally onto fn_args. *)
    val arg_table = Symtab.make (arg_names ~~ fn_args)
    fun name_map_ext (n, _) = the (Symtab.lookup arg_table n)
    val arg_xf = mk_precond ctxt prog_info name_map_ext fn_input_vars
    val ret_xf = mk_xf ctxt prog_info fn_ret_vars
    val ex_xf = mk_xf ctxt prog_info exn_var
    val instantiations =
      [("l1_f", betapply (l1_const, Free ("rec_measure'", @{typ "nat"}))),
       ("ex_xf", ex_xf),
       ("gs", ProgramInfo.get_globals_getter prog_info),
       ("ret_xf", ret_xf),
       ("arg_xf", arg_xf)]
      |> map (apsnd (Thm.cterm_of ctxt))
    val instantiated_rule =
      Utils.named_cterm_instantiate ctxt instantiations @{thm L2corres_L2_call_simpl}
  in
    instantiated_rule OF [FunctionInfo.get_definition fn_def]
  end
(* Insert a variable called name, with type "unit ptr", into the variable set vars. *)
fun insert_fn_ptr name vars =
  Varset.insert (name, @{typ "unit ptr"}) vars
(* Apply insert_fn_ptr to every name in names, threading the variable set vars through. *)
fun insert_fn_ptrs names vars = fold insert_fn_ptr names vars
(*
 * Convert the L1 function f_name into its L2 form.
 *
 * The call graph of l1_infos is recomputed first and the updated table is
 * used for everything that follows. The result record contains:
 *   body          : the abstracted program taken from the conclusion of the
 *                   L2corres theorem
 *   proof         : that theorem, exported from the context returned by
 *                   AutoCorresUtil.assume_called_functions_corres back to lthy
 *   rec_callees   : as computed by AutoCorresUtil.get_rec_callees
 *   callee_consts : table from callee name to its constant term
 *   arg_frees     : the fixed argument variables as (name, type) pairs
 *
 * Raises an error if f_name, or any of its callees, has no entry in the
 * function info table.
 *)
fun convert
    (lthy: local_theory)
    (skips: FunctionInfo.skip_info)
    (prog_info: ProgramInfo.prog_info)
    (l1_infos: FunctionInfo.function_info Symtab.table)
    (L2_opt: FunctionInfo.stage)
    (trace_opt: bool)
    (l2_function_name: string -> string)
    (f_name: string)
    : AutoCorresUtil.convert_result =
  let
    val (call_info, infos) = FunctionInfo.calc_call_graph l1_infos
    val info =
      Utils.the' ("L2 conversion missing info for " ^ f_name) (Symtab.lookup infos f_name)

    (* Every callee must already have an entry in the info table. *)
    val callees = FunctionInfo.all_callees info
    val missing = filter (fn c => is_none (Symtab.lookup infos c)) (Symset.dest callees)
    val _ =
      if null missing then ()
      else error ("L2 conversion missing callees for " ^ f_name ^ ": " ^ commas missing)

    (* Fix free variables for the function arguments, using demangled names. *)
    val plain_args =
      FunctionInfo.get_plain_args info |> map (apfst (ProgramInfo.demangle_name prog_info))
    val (frees, ctxt1) = Utils.fix_variant_frees plain_args lthy

    val ptr_infos = mk_fn_ptr_infos lthy prog_info frees info
    val clique = FunctionInfo.get_recursive_clique info

    val (ctxt2, callee_terms) =
      AutoCorresUtil.assume_called_functions_corres ctxt1 clique
        (get_expected_l2_fn_type prog_info infos)
        (get_l2_corres_prop skips prog_info infos)
        (get_expected_l2_fn_args lthy prog_info infos)
        l2_function_name

    val l1_def = FunctionInfo.get_definition info

    (* SIMPL wrappers get a directly instantiated theorem; everything else
     * goes through the regular L2 conversion. *)
    val corres_thm =
      if FunctionInfo.get_is_simpl_wrapper info then
        mk_l2corres_call_simpl_thm prog_info infos ctxt2 f_name frees
      else
        get_l2corres_thm ctxt2 skips prog_info infos ptr_infos call_info L2_opt trace_opt f_name
          (Symtab.make callee_terms) frees (FunctionInfo.get_const info) l1_def

    val body = dest_L2corres_term_abs (HOLogic.dest_Trueprop (Thm.concl_of corres_thm))
    val recursive = AutoCorresUtil.get_rec_callees callee_terms body
    val consts = Symtab.make (map (apsnd fst) callee_terms)
  in
    { body = body,
      proof = hd (Proof_Context.export ctxt2 lthy [corres_thm]),
      rec_callees = recursive,
      callee_consts = consts,
      arg_frees = map dest_Free frees
    }
  end
(*
 * Define the L2 functions for a group of converted functions.
 *
 * funcs maps each function name to its conversion result. Each result is
 * reshaped into (name, (abstracted body, proof, arg_frees)) and the list is
 * handed to AutoCorresUtil.define_funcs for phase FunctionInfo.L2. Only the
 * resulting local theory is returned; the first component of the result of
 * define_funcs is discarded.
 *)
fun define
    (skips: FunctionInfo.skip_info)
    (prog_info: ProgramInfo.prog_info)
    (l2_function_name: string -> string)
    (funcs: AutoCorresUtil.convert_result Symtab.table)
    (lthy: local_theory)
    : local_theory =
  let
    (* L1 function infos are read from the context data for this program. *)
    val prog_name = ProgramInfo.get_prog_name prog_info
    val l1_infos =
      AutoCorresData.get_default_phase_info (Context.Proof lthy) prog_name FunctionInfo.L1

    val prepared =
      map (fn entry as (name, {proof = corres_thm, arg_frees = frees, ...}) =>
            (name, (AutoCorresUtil.abstract_fn_body l1_infos entry, corres_thm, frees)))
        (Symtab.dest funcs)
  in
    AutoCorresUtil.define_funcs skips FunctionInfo.L2 prog_info I
      {concealed_named_theorems=false} l2_function_name
      (get_expected_l2_fn_type prog_info l1_infos)
      (get_l2_corres_prop skips prog_info l1_infos)
      (get_expected_l2_fn_args lthy prog_info l1_infos)
      prepared lthy
    |> snd
  end;
(*
 * Entry point of the L2 phase: run conversion and definition over the given
 * cliques by delegating to AutoCorresUtil.convert_and_define_cliques with
 * phase FunctionInfo.L2.
 *
 * The name function handed to both convert and define is the phase's
 * function-name maker (ProgramInfo.get_mk_fun_name) applied to the empty
 * string.
 *)
fun translate
    (skips: FunctionInfo.skip_info)
    (base_locale_opt: string option)
    (prog_info: ProgramInfo.prog_info)
    (L2_opt: FunctionInfo.stage)
    (trace_opt: bool)
    (parallel: bool)
    (cliques: string list list)
    (lthy: local_theory)
    : string list list * local_theory =
  let
    val phase = FunctionInfo.L2
    val l2_function_name = ProgramInfo.get_mk_fun_name prog_info phase

    (* Per-function conversion; still awaits the function name. *)
    fun convert_worker ctxt l1_infos =
      convert ctxt skips prog_info l1_infos L2_opt trace_opt (l2_function_name "")

    (* Definition of one batch of converted functions. *)
    fun define_worker ctxt f_convs =
      define skips prog_info (l2_function_name "") f_convs ctxt
  in
    AutoCorresUtil.convert_and_define_cliques skips base_locale_opt prog_info
      phase parallel convert_worker define_worker cliques lthy
  end
end