Theory Collections.Locale_Code
section ‹Code Generation from Locales›
theory Locale_Code
imports ICF_Tools Ord_Code_Preproc
begin
text ‹
Provides a simple mechanism to prepare code equations for
constants stemming from locale interpretations.
The usage pattern is as follows:
‹setup Locale_Code.open_block› is called before a series of
interpretations, and ‹setup Locale_Code.close_block› is called
after them. From then on, the code generator will correctly recognize
expressions involving terms from the locale interpretations.
›
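text ‹Tag marking a pattern for deletion: ‹Locale_Code.lc_decl_del› records
  the term ‹LC_DEL pat› as a spec rule, which is interpreted as the deletion of
  ‹pat› when the current block is closed.›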
(* Marker constant that ignores its argument and returns ().
   Terms of the form LC_DEL pat are recorded as spec rules by
   Locale_Code.lc_decl_del and turned into pattern deletions when
   Locale_Code.close_block processes the current block. *)
definition LC_DEL :: "'a ⇒ unit" where "LC_DEL a ≡ ()"
ML ‹
  (* Interface of Locale_Code: a per-theory "block" mechanism. The spec rules
     added between open_block and close_block are turned into pattern-eqs;
     close_block defines a constant for each pattern and declares the
     corresponding equations as code equations. *)
  signature LOCALE_CODE = sig
    (* A pattern together with the equations stated for it. *)
    type pat_eq = cterm * thm list
    val open_block: theory -> theory
    val close_block: theory -> theory
    val del_pat: cterm -> theory -> theory
    val add_pat_eq: cterm -> thm list -> theory -> theory
    val lc_decl_eq: thm list -> local_theory -> local_theory
    val lc_decl_del: term -> local_theory -> local_theory
    val setup: theory -> theory
    val get_unf_ss: theory -> simpset
    val tracing_enabled: bool Unsynchronized.ref
  end

  structure Locale_Code :LOCALE_CODE = struct
    open ICF_Tools

    (* When set, pattern-eqs are printed via tracing while they are processed. *)
    val tracing_enabled = Unsynchronized.ref false;

    type pat_eq = cterm * thm list

    (* Block state. idx is the number of spec rules present when the block
       was opened, or ~1 if no block is open. del_pats and add_pateqs
       collect explicit deletions/additions requested for the block. *)
    type block_data = {idx:int, del_pats: cterm list, add_pateqs: pat_eq list}

    val closed_block = {idx = ~1, del_pats=[], add_pateqs=[]};
    fun init_block idx = {idx = idx, del_pats=[], add_pateqs=[]};

    fun is_open ({idx,...}:block_data) = idx <> ~1;

    fun assert_open bd
      = if is_open bd then () else error "Locale_Code: No open block";
    fun assert_closed bd
      = if is_open bd then error "Locale_Code: Block already open" else ();

    (* Theory merge: merging while a block is open is an error; otherwise
       the merged state is always the closed block. *)
    fun merge_bd (bd1,bd2) =
      if is_open bd1 orelse is_open bd2 then
        error "Locale_Code: Merge with open block"
      else closed_block;

    fun bd_add_del_pats ps {idx,del_pats,add_pateqs}
      = {idx = idx, del_pats = ps@del_pats, add_pateqs = add_pateqs};
    fun bd_add_add_pateqs pes {idx,del_pats,add_pateqs}
      = {idx = idx, del_pats = del_pats, add_pateqs = pes@add_pateqs};

    structure BlockData = Theory_Data (
      type T = block_data
      val empty = (closed_block)
      val merge = merge_bd
    );

    (* Simpset holding the reversed definitions of the constants introduced
       by close_block. NOTE(review): presumably consumed by Ord_Code_Preproc
       as a code preprocessing simpset with priority 5 -- confirm there. *)
    structure FoldSSData = Oc_Simpset (
      val prio = 5;
      val name = "Locale_Code";
    );

    (* Add the symmetric versions of thms to the FoldSSData simpset. *)
    fun add_unf_thms thms thy = let
      val ctxt = Proof_Context.init_global thy
      val thms = map Thm.symmetric thms
    in
      FoldSSData.map (fn ss =>
        put_simpset ss ctxt
        |> sss_add thms
        |> simpset_of
      ) thy
    end

    val get_unf_ss = FoldSSData.get;

    (* Match pat against obj: first match the heads, then match the whole
       terms after instantiating pat with the head match. *)
    fun match_fixed_head (pat,obj) = let
      val inst = Thm.first_order_match (chead_of pat, chead_of obj);
      val pat = Thm.instantiate_cterm inst pat;
    in Thm.first_order_match (pat, obj) end;
    val matches_fixed_head = can match_fixed_head;

    fun match_heads (pat,obj) = Thm.first_order_match (chead_of pat, chead_of obj);
    val matches_heads = can match_heads;

    (* Number of arguments the head of a pattern is applied to. *)
    val pat_nargs = Thm.term_of #> strip_comb #> #2 #> length;

    (* Normalize thm with norm_def_thm, drop trailing arguments of its lhs so
       that it has as many arguments as pat, and instantiate thm so that this
       lhs becomes pat. Raises an exception if no match exists (callers wrap
       this in try). *)
    fun norm_thm_pat (thm,pat) = let
      val thm = norm_def_thm thm;
      val na_pat = pat_nargs pat;
      val lhs = Thm.lhs_of thm;
      val na_lhs = pat_nargs lhs;
      val lhs' = if na_lhs > na_pat then funpow (na_lhs - na_pat) Thm.dest_fun lhs
        else lhs;
      val inst = Thm.first_order_match (lhs',pat);
    in Thm.instantiate inst thm end;

    (* Does deletion pattern cpat select the pattern of a pattern-eq?
       A deletion pattern without arguments only has to match the head. *)
    fun del_pat_matches cpat (epat,_) = if pat_nargs cpat = 0 then
        matches_heads (cpat,epat)
      else
        matches_fixed_head (cpat,epat);

    local
      datatype action = ADD of (cterm * thm list)
                      | DEL of cterm

      (* Turn one pattern of an equational spec rule into an ADD action,
         keeping the theorems that normalize to that pattern. Patterns
         without arguments, or without any matching theorem, yield NONE. *)
      fun filter_pat_eq thy thms pat = let
        val cpat = Thm.global_cterm_of thy pat;
      in
        if (pat_nargs cpat = 0) then NONE
        else let
          val thms' = fold
            (fn thm => fn acc => case try norm_thm_pat (thm, cpat) of
              NONE => acc | SOME thm => thm::acc
            ) thms [];
        in case thms' of [] => NONE | _ => SOME (ADD (cpat,thms')) end
      end;

      (* Replay actions in order: ADD pushes a pattern-eq; DEL removes every
         previously pushed pattern-eq whose pattern renames the deleted one,
         warning if nothing was removed. *)
      fun process_actions acc [] = acc
        | process_actions acc (ADD peq::acts) = process_actions (peq::acc) acts
        | process_actions acc (DEL cpat::acts) = let
            val acc' = filter (not o curry renames_cterm cpat o fst) acc;
            val _ = if length acc = length acc' then
                warning ("Locale_Code: LC_DEL without effect: "
                  ^ @{make_string} cpat)
              else ();
          in process_actions acc' acts end;

      (* Equational spec rules yield ADD actions; Unknown spec rules whose
         only term is LC_DEL pat (see lc_decl_del) yield a DEL action. *)
      fun pat_eqs_of_spec thy
          {rough_classification = Spec_Rules.Equational _, terms = pats, rules = thms, ...} =
            map_filter (filter_pat_eq thy thms) pats
        | pat_eqs_of_spec thy
          {rough_classification = Spec_Rules.Unknown, terms = [Const (@{const_name LC_DEL},_)$pat], ...} =
            [(DEL (Thm.global_cterm_of thy pat))]
        | pat_eqs_of_spec _ _ = [];
    in
      (* The action list is reversed before replaying. NOTE(review): this
         assumes specs are given newest-first, as new_specs suggests. *)
      fun pat_eqs_of_specs thy specs = map (pat_eqs_of_spec thy) specs
        |> flat |> rev |> process_actions [];
    end;

    (* A proper pattern is a constant applied to at least one argument,
       whose last argument is not a schematic variable. *)
    fun is_proper_pat cpat = let
      val (f,args) = strip_comb (Thm.term_of cpat);
    in
      is_Const f
      andalso not (null args)
      andalso not (is_Var (List.last args))
    end;

    local
      (* Fresh constant name for a pattern: base name of the head constant
         followed by a rendering of each argument, joined by underscores and
         made unique with gen_variant. *)
      fun inst_name lthy pat = let
        val (fname,params) = case strip_comb pat of
            ((Const (fname,_)),params) => (fname,params)
          | _ => raise TERM ("inst_name: Expected pattern",[pat]);
        fun pname (Const (n,_)) = Long_Name.base_name n
          | pname (s$t) = pname s ^ "_" ^ pname t
          | pname _ = Name.uu;
      in
        space_implode "_" (Long_Name.base_name fname::map pname params)
        |> gen_variant (can (Proof_Context.read_const {proper = true, strict = false} lthy))
      end;
    in
      (* Define a constant for the pattern, fold the equations with that
         definition, add the reversed definition to the fold simpset and
         declare the folded equations as default code equations. *)
      fun inst_pat_eq (cpat,thms) =
        wrap_lthy_result_global
          (fn lthy => let
            val (((instT,inst),thms),lthy) = Variable.import true thms lthy;
            val cpat = Thm.instantiate_cterm (instT, inst) cpat;
            val pat = Thm.term_of cpat;
            val name = inst_name lthy pat;
            val ((_,(_,def_thm)),lthy)
              = Local_Theory.define ((Binding.name name,NoSyn),
                ((Binding.name (Thm.def_name name),[]),pat)) lthy;
            val thms' = map (Local_Defs.fold lthy [def_thm]) thms;
          in ((def_thm,thms'),lthy) end)
          (fn m => fn (def_thm,thms') =>
            (Morphism.thm m def_thm, map (Morphism.thm m) thms'))
        #> (fn ((def_thm,thms'),thy) =>
              thy
              |> add_unf_thms [def_thm]
              |> Code.declare_default_eqns_global (map (rpair true) thms'))
    end

    (* Spec rules added since the block was opened: the first
       (length - idx) entries of Spec_Rules.get. Requires an open block. *)
    fun new_specs thy = let
      val bd = BlockData.get thy;
      val _ = assert_open bd;
      val srules = Spec_Rules.get (Proof_Context.init_global thy);
    in take (length srules - #idx bd) srules end

    (* Open a block, remembering the current number of spec rules.
       Fails if a block is already open. *)
    fun open_block thy = let
      val _ = assert_closed (BlockData.get thy);
      val idx = length (Spec_Rules.get (Proof_Context.init_global thy));
    in BlockData.map (K (init_block idx)) thy end;

    (* Instantiate the pattern-eqs of block bd. These are the ones derived
       from new spec rules, minus those matched by a del_pats entry and those
       whose pattern renames an explicitly added pattern, plus the explicitly
       added pattern-eqs. *)
    fun process_block bd thy = let
      fun filter_del_pats cpat peqs = let
        val peqs' = filter (not o del_pat_matches cpat) peqs
        val _ = if length peqs = length peqs' then
            warning ("Locale_Code: No pattern-eqs matching filter: " ^
              @{make_string} cpat)
          else ();
      in peqs' end;

      fun filter_add_pats (orig_pat,_) = forall (fn (add_pat,_) =>
          not (renames_cterm (orig_pat,add_pat)))
        (#add_pateqs bd);

      val specs = new_specs thy;
      val peqs = pat_eqs_of_specs thy specs
        |> fold filter_del_pats (#del_pats bd)
        |> filter filter_add_pats;
      val peqs = rev (peqs @ #add_pateqs bd);
      val _ = if !tracing_enabled then
          List.app (fn peq => tracing (@{make_string} peq)) peqs
        else ();
    in fold inst_pat_eq peqs thy end;

    (* Process the open block, then mark it closed. *)
    fun close_block thy = let
      val bd = BlockData.get thy;
      val _ = assert_open bd;
    in
      thy |> process_block bd |> BlockData.map (K closed_block)
    end;

    (* Record a deletion pattern for the open block. *)
    fun del_pat cpat thy = let
      val bd = BlockData.get thy;
      val _ = assert_open bd;
    in BlockData.map (K (bd_add_del_pats [cpat] bd)) thy end;

    (* Record an explicit pattern-eq for the block. cpat must be a proper
       pattern, and every theorem must normalize to it.
       NOTE(review): unlike del_pat, this does not assert that a block is
       open. Additions made while closed are discarded by the next
       open_block. Confirm whether that is intended. *)
    fun add_pat_eq cpat thms thy = let
      val _ = is_proper_pat cpat
        orelse raise CTERM ("add_pat_eq: Not a proper pattern",[cpat]);
      fun ntp thm = case try norm_thm_pat (thm,cpat) of
          NONE => raise THM ("add_pat_eq: Theorem does not match pattern",~1,[thm])
        | SOME thm => thm;
      val thms = map ntp thms;
    in BlockData.map (bd_add_add_pateqs [(cpat,thms)]) thy end;

    local
      (* Pattern of an equation: its lhs with trailing schematic-variable
         arguments stripped. *)
      fun cpat_of_thm thm = let
        fun strip ct = case Thm.term_of ct of
            (_$Var _) => strip (Thm.dest_fun ct)
          | _ => ct;
      in
        strip (Thm.lhs_of thm)
      end;

      (* Drop trailing arguments from the longer pattern so that both
         patterns have the same number of arguments. *)
      fun adjust_length (cpat1,cpat2) = let
        val n1 = pat_nargs cpat1;
        val n2 = pat_nargs cpat2;
      in
        if n1>n2 then
          (funpow (n1-n2) Thm.dest_fun cpat1, cpat2)
        else
          (cpat1, funpow (n2-n1) Thm.dest_fun cpat2)
      end

      (* Drop arguments from both patterns until rename_cterm succeeds.
         Returns the remaining cpat and the instantiation from rename_cterm,
         or NONE. *)
      fun find_match cpat cpat' = SOME (cpat,rename_cterm (cpat',cpat))
        handle Pattern.MATCH => (case Thm.term_of cpat' of
            _$_ => find_match (Thm.dest_fun cpat) (Thm.dest_fun cpat')
          | _ => NONE
        );

      (* Common pattern of a list of equations. The equations are normalized,
         then the pattern of the first is narrowed against each further one,
         and that equation is instantiated accordingly. Returns NONE for an
         empty list or when no common pattern exists. The returned theorems
         are in reverse input order. *)
      fun comp_head thms = case map norm_def_thm thms of
          [] => NONE
        | thm::thms => let
            fun ch [] r = SOME r
              | ch (thm::thms) (cpat,acc) = let
                  val cpat' = cpat_of_thm thm;
                  val (cpat,cpat') = adjust_length (cpat,cpat')
                in case find_match cpat cpat' of NONE => NONE
                  | SOME (cpat,inst) =>
                    ch thms (cpat, Drule.instantiate_normalize inst thm :: acc)
                end;
          in ch thms (cpat_of_thm thm,[thm]) end;
    in
      (* Declare equations for the current block. Their common pattern is
         computed, and a declaration is registered that adds the
         morphism-transformed pattern-eq to the block data of the background
         theory (proof contexts are unchanged). Raises THM if the equations
         have no common pattern. *)
      fun lc_decl_eq thms lthy = case comp_head thms of
          SOME (cpat,thms) => let
            val _ = if !tracing_enabled then
                tracing ("decl_eq: " ^ @{make_string} cpat ^ ": "
                  ^ @{make_string} thms)
              else ();
            (* The pattern travels through the morphism as a term-theorem at
               the head of the fact list, so it is transformed together with
               the equations. *)
            fun decl m =
              (case Morphism.fact m (Drule.mk_term cpat :: thms) of
                cpat' :: thms' =>
                  Context.mapping
                    (BlockData.map
                      (bd_add_add_pateqs [(Drule.dest_term cpat',thms')])) I
              | [] => raise Fail "Locale_Code.lc_decl_eq: empty morphism result");
          in
            lthy |> Local_Theory.declaration {syntax = false, pervasive = false, pos = ⌂} decl
          end
        | NONE => raise THM ("Locale_Code.lc_decl_eq: No common pattern",~1,thms);
    end;

    (* Record a deletion of pat as an Unknown spec rule whose only term is
       LC_DEL pat. When the block is closed, pat_eqs_of_specs turns it into
       a DEL action. *)
    fun lc_decl_del pat = let
      val dpat = Const (@{const_name LC_DEL}, fastype_of pat --> @{typ unit}) $ pat;
    in
      Spec_Rules.add Binding.empty Spec_Rules.Unknown [dpat] []
    end

    val setup = FoldSSData.setup;
  end
›
(* Install the Locale_Code fold simpset (Locale_Code.setup is FoldSSData.setup). *)
setup Locale_Code.setup
attribute_setup lc_delete = ‹
  Parse.and_list1' ICF_Tools.parse_cpat >> (fn cpats =>
    let
      (* Apply Locale_Code.del_pat for each parsed pattern to the background
         theory; proof contexts are left unchanged. *)
      val del_all = fold Locale_Code.del_pat cpats
    in
      Thm.declaration_attribute (fn _ => Context.mapping del_all I)
    end)
› "Locale_Code: Delete patterns for current block"
attribute_setup lc_add = ‹
  Parse.and_list1' (ICF_Tools.parse_cpat -- Attrib.thms) >> (fn peqs =>
    let
      (* Apply Locale_Code.add_pat_eq for each parsed (pattern, theorems)
         pair to the background theory; proof contexts are left unchanged. *)
      fun add_one (cpat, thms) = Locale_Code.add_pat_eq cpat thms
    in
      Thm.declaration_attribute (fn _ => Context.mapping (fold add_one peqs) I)
    end)
› "Locale_Code: Add pattern-eqs for current block"
end