File ‹applicative.ML›
(*Interface of the applicative-functor package: registration and lookup of
  applicative functors ("afuns"), concrete instances of their polymorphic
  operations, normal-form conversions, lifting of equations, and the
  tactics/attributes/commands built on top of them.*)
signature APPLICATIVE =
sig
(*Registered functors live in a name space; intern/extern convert names.*)
type afun
val intern: Context.generic -> xstring -> string
val extern: Context.generic -> string -> xstring
(*Lookup by internal name (error if undeclared), or all registered functors
  whose patterns match a given term or type.*)
val afun_of_generic: Context.generic -> string -> afun
val afun_of: Proof.context -> string -> afun
val afuns_of_term_generic: Context.generic -> term -> afun list
val afuns_of_term: Proof.context -> term -> afun list
val afuns_of_typ_generic: Context.generic -> typ -> afun list
val afuns_of_typ: Proof.context -> typ -> afun list
val name_of_afun: afun -> binding
val unfolds_of_afun: afun -> thm list
(*A concrete instance of a functor's type constructor and its pure/ap/set
  operations, obtained by matching against a term or by importing fresh
  type variables into a context.*)
type afun_inst
val match_afun_inst: Proof.context -> afun -> term * int -> afun_inst
val import_afun_inst: afun -> Proof.context -> afun_inst * Proof.context
(*Constructors for types and terms of an instance.*)
val inner_sort_of: afun_inst -> sort
val mk_type: afun_inst -> typ -> typ
val mk_pure: afun_inst -> typ -> term
val lift_term: afun_inst -> term -> term
val mk_ap: afun_inst -> typ * typ -> term
val mk_comb: afun_inst -> typ -> term * term -> term
val mk_set: afun_inst -> typ -> term
(*Destructors; dest_type yields NONE when the argument type is not
  determined by the match, and dest_type' substitutes unit in that case.*)
val dest_type: Proof.context -> afun_inst -> typ -> typ option
val dest_type': Proof.context -> afun_inst -> typ -> typ
val dest_pure: Proof.context -> afun_inst -> term -> term
val dest_comb: Proof.context -> afun_inst -> term -> term * term
(*Building and lifting applicative expressions.*)
val infer_comb: Proof.context -> afun_inst -> term * term -> term
val subst_lift_term: afun_inst -> (term * term) list -> term -> term
val generalize_lift_terms: afun_inst -> term list -> Proof.context -> term list * Proof.context
(*Rewriting with a functor's registered unfold theorems.*)
val afun_unfold_tac: Proof.context -> afun -> int -> tactic
val afun_fold_tac: Proof.context -> afun -> int -> tactic
val unfold_all_tac: Proof.context -> int -> tactic
(*Normalisation of applicative expressions; the "general" variants treat
  both sides of a relation together.*)
val normalform_conv: Proof.context -> afun -> conv
val normalize_rel_tac: Proof.context -> afun -> int -> tactic
val general_normalform_conv: Proof.context -> afun -> cterm * cterm -> thm * thm
val general_normalize_rel_tac: Proof.context -> afun -> int -> tactic
val forward_lift_rule: Proof.context -> afun -> thm -> thm
(*Proof-method back ends; NONE means the functor is inferred from the goal.*)
val unfold_wrapper_tac: Proof.context -> afun option -> int -> tactic
val fold_wrapper_tac: Proof.context -> afun option -> int -> tactic
val normalize_wrapper_tac: Proof.context -> afun option -> int -> tactic
val lifting_wrapper_tac: Proof.context -> afun option -> int -> tactic
(*Setup: combinators, parsing, the declaration command, diagnostics, and
  attributes.  NOTE(review): applicative_cmd, print_afuns and
  add_unfold_attrib are implemented outside this excerpt.*)
val setup_combinators: (string * thm) list -> local_theory -> local_theory
val combinator_rule_attrib: string list option -> attribute
val parse_opt_afun: afun option context_parser
val applicative_cmd: (((((binding * string list) * string) * string) * string option) * string option) ->
local_theory -> Proof.state
val print_afuns: Proof.context -> unit
val add_unfold_attrib: xstring option -> attribute
val forward_lift_attrib: xstring -> attribute
end;
structure Applicative : APPLICATIVE =
struct
open Ctr_Sugar_Util
(*Collect the payloads of the SOME elements of a list of options.  The
  result is in reverse input order; the caller visible here sorts it.*)
fun fold_options opts = rev (map_filter I opts);
(*Destructure a two-element list; any other length raises General.Size.*)
fun the_pair xs =
  (case xs of
    [x, y] => (x, y)
  | _ => raise General.Size);

(*Split a binary application f $ x $ y into its head and both arguments;
  raises TERM for anything else.*)
fun strip_comb2 t =
  (case t of
    f $ x $ y => (f, (x, y))
  | _ => raise TERM ("strip_comb2", [t]));
(*Build a pattern that applies t to n fresh schematic variables.  The
  variables are all named Name.uu and take the first n binder types of t.
  Their indexes are maxidx_of_term t + 1 .. + n, so they cannot clash with
  the Vars of t.  Returns the variables and the beta-reduced application.*)
fun mk_comb_pattern (t, n) =
let
val Ts = take n (binder_types (fastype_of t));
val maxidx = maxidx_of_term t;
val vars = map (fn (T, i) => ((Name.uu, maxidx + i), T)) (Ts ~~ (1 upto n));
in (vars, Term.betapplys (t, map Var vars)) end;
(*Match the pattern from mk_comb_pattern tn against u using Pattern.match in
  the context's theory.  Returns the pattern variables together with the
  (type, term) environments.  A failed match is reported as TERM.*)
fun match_comb_pattern ctxt tn u =
let
val thy = Proof_Context.theory_of ctxt;
val (vars, pat) = mk_comb_pattern tn;
val envs = Pattern.match thy (pat, u) (Vartab.empty, Vartab.empty)
handle Pattern.MATCH => raise TERM ("match_comb_pattern", [u, pat]);
in (vars, envs) end;
(*The terms bound to the pattern variables when u is matched, in order.*)
fun dest_comb_pattern ctxt tn u =
let val (vars, (_, env)) = match_comb_pattern ctxt tn u;
in map (the o Envir.lookup1 env) vars end;
(*Apply a type environment to every type occurring in a term.
  Envir.norm_type_same reports an unchanged type by raising Same.SAME.
  Term.map_types does not catch that exception, so composing the two would
  abort on the first unchanged type, including every call with an empty
  environment.  Envir.norm_type commits the result instead: it returns the
  type unchanged, so the whole term is always rebuilt.*)
val norm_term_types = Term.map_types o Envir.norm_type;

(*n type variables of sort S, via mk_TFrees' (presumably the one from the
  opened Ctr_Sugar_Util), i.e. mk_TFrees' (replicate n S).*)
val mk_TFrees_of = mk_TFrees' oo replicate;

(*A Free of type typ whose name is a variant of name that is fresh in ctxt;
  the name is fixed in the returned context.*)
fun mk_Free name typ ctxt = yield_singleton Variable.variant_fixes name ctxt
  |>> (fn name' => Free (name', typ));
(*Right-nested HOL tuple of the given terms, terminated by the unit value.*)
fun mk_tuple' [] = HOLogic.unit
  | mk_tuple' (t :: ts) = HOLogic.mk_prod (t, mk_tuple' ts);

(*Inverse of mk_tuple'; raises TERM if the term is not such a tuple.*)
fun strip_tuple' t =
  (case t of
    Const (@{const_name Unity}, _) => []
  | Const (@{const_name Pair}, _) $ first $ rest => first :: strip_tuple' rest
  | _ => raise TERM ("strip_tuple'", [t]));

(*eq_on S, with eq_on instantiated for the element type of the set S.*)
fun mk_eq_on S =
  let
    val setT = fastype_of S;
    val eltT = HOLogic.dest_setT setT;
    val eq_onT = setT --> BNF_Util.mk_pred2T eltT eltT;
  in Const (@{const_name eq_on}, eq_onT) $ S end;
(*A polymorphic type/term: its bound type variables together with the body.*)
type poly_type = typ list * typ;
type poly_term = typ list * term;

(*Instantiate the bound type variables positionally with insts.*)
fun instantiate_poly_type (tvars, T) insts =
  let val subst = tvars ~~ insts in typ_subst_atomic subst T end;
fun instantiate_poly_term (tvars, t) insts =
  let val subst = tvars ~~ insts in subst_atomic_types subst t end;

(*Match the poly type against U and return the instantiation of each bound
  type variable (NONE if the match leaves it unbound).  Raises TYPE if U
  does not match.*)
fun dest_poly_type ctxt (tvars, T) U =
  let
    val thy = Proof_Context.theory_of ctxt;
    val tyenv =
      Sign.typ_match thy (T, U) Vartab.empty
        handle Type.TYPE_MATCH => raise TYPE ("dest_poly_type", [U, T], []);
    fun lookup_tvar tv = Type.lookup tyenv (dest_TVar tv);
  in map lookup_tvar tvars end;

(*Encode a poly type as a poly term (type as a Logic.mk_type term) and back.*)
fun poly_type_to_term (vs, ty) = (vs, Logic.mk_type ty);
fun poly_type_of_term (vs, tm) = (vs, Logic.dest_type tm);

(*Pack a poly term into a single HOL pair of (tuple of type tags, body).*)
fun pack_poly_term (tvars, t) =
  let val type_tags = mk_tuple' (map Logic.mk_type tvars)
  in HOLogic.mk_prod (type_tags, t) end;

fun unpack_poly_term t =
  let
    val (packed_tvars, body) = HOLogic.dest_prod t;
    val tvars = map Logic.dest_type (strip_tuple' packed_tvars);
  in (tvars, body) end;

(*Lists of poly terms, packed into a HOL tuple.*)
fun pack_poly_terms pts = mk_tuple' (map pack_poly_term pts);
fun unpack_poly_terms t = map unpack_poly_term (strip_tuple' t);
(*Match the type of the i-th component of a packed list of poly terms pt
  against U, then instantiate the whole list accordingly.  The steps are:
  - shift pt's schematic indexes above maxidx, so they are disjoint from
    those of the matched type;
  - drop the component's own bound type variables from the environment,
    so they remain schematic and only the shared type parameters are
    instantiated;
  - unpack the instantiated list.
  Raises TYPE if the types do not match.*)
fun match_poly_terms_type ctxt (pt, i) (U, maxidx) =
let
val thy = Proof_Context.theory_of ctxt;
val pt' = Logic.incr_indexes ([], maxidx + 1) pt;
val (tvars, T) = poly_type_of_term (nth (unpack_poly_terms pt') i);
val tyenv = Sign.typ_match thy (T, U) Vartab.empty
handle Type.TYPE_MATCH => raise TYPE ("match_poly_terms", [U, T], []);
val tyenv' = fold Vartab.delete_safe (map (#1 o dest_TVar) tvars) tyenv;
val pt'' = Envir.subst_term_types tyenv' pt';
in unpack_poly_terms pt'' end;
(*As match_poly_terms_type, matching against the type of term t.*)
fun match_poly_terms ctxt (pt, i) (t, maxidx) =
match_poly_terms_type ctxt (pt, i) (fastype_of t, maxidx);
(*Import a packed list of poly terms into ctxt.  Type variables that are not
  one of a component's own bound type variables are shared parameters; they
  become fresh TFrees.  Every schematic term variable becomes a fresh Free
  named after it (via Name.clean).  Returns the unpacked, instantiated
  components and the extended context.*)
fun import_poly_terms pt ctxt =
let
(*Collect the shared type parameters of one component: TVars in its types
  that are not among its own bound tvars.*)
fun insert_paramTs (tvars, t) = fold_types (fold_atyps
(fn TVar v => if member (op =) tvars (TVar v) then I else insert (op =) v
| _ => I)) t;
(*insert prepends, so rev restores order of first occurrence.*)
val paramTs = rev (fold insert_paramTs (unpack_poly_terms pt) []);
val (tfrees, ctxt') = Variable.invent_types (map #2 paramTs) ctxt;
val instT = TVars.make (paramTs ~~ map TFree tfrees);
(*Term variables, with their types already instantiated by instT.*)
val params = map (apsnd (Term_Subst.instantiateT instT)) (rev (Term.add_vars pt []));
val (frees, ctxt'') = Variable.variant_fixes (map (Name.clean o #1 o #1) params) ctxt';
val inst = Vars.make (params ~~ map Free (frees ~~ map #2 params));
val pt' = Term_Subst.instantiate (instT, inst) pt;
in (unpack_poly_terms pt', ctxt'') end;
(*Theorems connecting a functor with a relator.*)
type rel_thms = {
  pure_transfer: thm,
  ap_rel_fun: thm
};

(*Apply f to both theorems of a rel_thms record.*)
fun map_rel_thms f {pure_transfer = pt, ap_rel_fun = apr} =
  {pure_transfer = f pt, ap_rel_fun = f apr};

(*Theorems registered for a functor.  reds maps combinator names to their
  reduction theorems, and rel_thms is optional.*)
type afun_thms = {
  hom: thm,
  ichng: thm,
  reds: thm Symtab.table,
  rel_thms: rel_thms option,
  rel_intros: thm list,
  pure_comp_conv: thm
};

(*Apply f to every theorem stored in an afun_thms record.*)
fun map_afun_thms f {hom = hom_thm, ichng = ichng_thm, reds = red_tab,
    rel_thms = rel, rel_intros = intros, pure_comp_conv = comp_thm} =
  {hom = f hom_thm,
   ichng = f ichng_thm,
   reds = Symtab.map (fn _ => f) red_tab,
   rel_thms = Option.map (map_rel_thms f) rel,
   rel_intros = map f intros,
   pure_comp_conv = f comp_thm};
(*A registered applicative functor.  terms holds the packed poly terms of
  its operations; unfolds is a net of unfold theorems.*)
datatype afun = AFun of {
  name: binding,
  terms: term,
  rel: term option,
  thms: afun_thms,
  unfolds: thm Item_Net.T
};

fun rep_afun (AFun record) = record;

(*Field accessors.*)
fun name_of_afun af = #name (rep_afun af);
fun terms_of_afun af = #terms (rep_afun af);
fun rel_of_afun af = #rel (rep_afun af);
fun thms_of_afun af = #thms (rep_afun af);
fun unfolds_of_afun af = Item_Net.content (#unfolds (rep_afun af));

(*The reduction theorem registered for a combinator name, if any.*)
fun red_of_afun af = Symtab.lookup (#reds (thms_of_afun af));
fun has_red_afun af comb = is_some (red_of_afun af comb);

(*A functor with no unfold theorems yet.*)
fun mk_afun name terms rel thms =
  AFun {name = name, terms = terms, rel = rel, thms = thms, unfolds = Thm.item_net};

fun map_afun f_name f_terms f_rel f_thms f_unfolds (AFun {name, terms, rel, thms, unfolds}) =
  AFun {name = f_name name, terms = f_terms terms, rel = f_rel rel, thms = f_thms thms,
    unfolds = f_unfolds unfolds};

(*Rebuild a theorem net from the images of its contents under f.*)
fun map_unfolds f net =
  Thm.item_net |> fold Item_Net.update (map f (Item_Net.content net));
(*Transform every component of a functor along a morphism.*)
fun morph_afun phi =
  let
    val morph_binding = Morphism.binding phi;
    val morph_term = Morphism.term phi;
    val morph_thm = Morphism.thm phi;
  in
    map_afun morph_binding morph_term (Option.map morph_term)
      (map_afun_thms morph_thm) (map_unfolds morph_thm)
  end;

(*Transfer a functor's theorems into a theory.*)
fun transfer_afun thy = morph_afun (Morphism.transfer_morphism thy);

(*Add theorems to a functor's unfold net.*)
fun add_unfolds_afun thms af =
  map_afun I I I I (fn net => fold Item_Net.update thms net) af;

(*Net patterns identifying a functor: pure x, ap f x, and its type.*)
fun patterns_of_afun af =
  let
    val [type_term, (_, pure), (_, ap), _] = unpack_poly_terms (terms_of_afun af);
    val T = #2 (poly_type_of_term type_term);
    val pure_pattern = #2 (mk_comb_pattern (pure, 1));
    val ap_pattern = #2 (mk_comb_pattern (ap, 2));
  in [pure_pattern, ap_pattern, Net.encode_type T] end;
(*A rule deriving one combinator from others.  strong_premises are the
  combinators required in all cases; weak_premises records whether the rule
  was declared with a weak premise list; eq_thm is the defining equation.*)
datatype combinator_rule = Combinator_Rule of {
  strong_premises: string Ord_List.T,
  weak_premises: bool,
  conclusion: string,
  eq_thm: thm
};

fun rep_combinator_rule (Combinator_Rule r) = r;
fun conclusion_of_rule rule = #conclusion (rep_combinator_rule rule);
fun thm_of_rule rule = #eq_thm (rep_combinator_rule rule);

(*Rules are identified by pointer or by their equation.*)
fun eq_combinator_rule (rule1, rule2) =
  if pointer_eq (rule1, rule2) then true
  else Thm.eq_thm (thm_of_rule rule1, thm_of_rule rule2);

(*A rule with weak premises needs have_weak; otherwise applicability is
  decided by have_premises on the strong premises.*)
fun is_applicable_rule rule have_weak have_premises =
  let val {strong_premises, weak_premises, ...} = rep_combinator_rule rule in
    if weak_premises andalso not have_weak then false
    else have_premises strong_premises
  end;

fun map_combinator_rule f1 f2 f3 f4
    (Combinator_Rule {strong_premises, weak_premises, conclusion, eq_thm}) =
  Combinator_Rule {strong_premises = f1 strong_premises, weak_premises = f2 weak_premises,
    conclusion = f3 conclusion, eq_thm = f4 eq_thm};

fun transfer_combinator_rule thy rule = map_combinator_rule I I I (Thm.transfer thy) rule;

(*Build a rule from an equation "c == rhs".  comb_names maps constant names
  to combinator names; c must be registered.  The premises are the
  registered combinators occurring in rhs, minus the weak ones.*)
fun mk_combinator_rule comb_names weak_premises thm =
  let
    val (lhs, rhs) = Logic.dest_equals (Thm.prop_of thm);
    val lookup_comb = Symtab.lookup comb_names;
    val conclusion = the (lookup_comb (#1 (dest_Const lhs)));
    val rhs_const_names = map #1 (Term.add_consts rhs []);
    val premises =
      Ord_List.make fast_string_ord (fold_options (map lookup_comb rhs_const_names));
    val weak = Ord_List.make fast_string_ord (these weak_premises);
  in
    Combinator_Rule {
      strong_premises = Ord_List.subtract fast_string_ord weak premises,
      weak_premises = is_some weak_premises,
      conclusion = conclusion,
      eq_thm = thm}
  end;
(*Merge two entries for the same functor name.  If both are the same
  object, signal "unchanged"; otherwise keep af1 with the union of both
  unfold nets.*)
fun merge_afuns _ (af1, af2) =
  if pointer_eq (af1, af2) then raise Change_Table.SAME
  else
    let val unfolds2 = #unfolds (rep_afun af2)
    in map_afun I I I I (fn unfolds1 => Item_Net.merge (unfolds1, unfolds2)) af1 end;
(*Generic context data of the package:
  - combinators: combinator definitions by name, plus the combinator rules;
  - afuns: name space of registered applicative functors;
  - patterns: net of (functor name, patterns) entries, indexed by the
    patterns and identified by name.*)
structure Data = Generic_Data
(
type T = {
combinators: thm Symtab.table * combinator_rule list,
afuns: afun Name_Space.table,
patterns: (string * term list) Item_Net.T
};
val empty = {
combinators = (Symtab.empty, []),
afuns = Name_Space.empty_table "applicative functor",
patterns = Item_Net.init (op = o apply2 #1) #2
};
(*Definitions are merged with K true, so a duplicate name never counts as a
  conflict.  Rules are unioned modulo eq_combinator_rule, and functors are
  joined with merge_afuns.*)
fun merge ({combinators = (cd1, cr1), afuns = a1, patterns = p1},
{combinators = (cd2, cr2), afuns = a2, patterns = p2}) =
{combinators = (Symtab.merge (K true) (cd1, cd2), Library.merge eq_combinator_rule (cr1, cr2)),
afuns = Name_Space.join_tables merge_afuns (a1, a2),
patterns = Item_Net.merge (p1, p2)};
);
(*Combinator definitions and rules, transferred to the context's theory.*)
fun get_combinators context =
  let
    val thy = Context.theory_of context;
    val {combinators = (defs, rules), ...} = Data.get context;
    val transfer = Thm.transfer thy;
  in (Symtab.map (K transfer) defs, map (transfer_combinator_rule thy) rules) end;

fun get_afun_table context = #afuns (Data.get context);
fun get_afun_space context = Name_Space.space_of_table (get_afun_table context);
fun get_patterns context = #patterns (Data.get context);

(*Map the three components of the data record independently.*)
fun map_data f1 f2 f3 {combinators, afuns, patterns} =
  {combinators = f1 combinators, afuns = f2 afuns, patterns = f3 patterns};

(*Name-space conversion for functor names.*)
fun intern context = Name_Space.intern (get_afun_space context);
fun extern context =
  Name_Space.extern (Context.proof_of context) (get_afun_space context);
(*Lookup and update of registered functors by internal name; an unknown
  name raises an error via undeclared.*)
local fun undeclared name = error ("Undeclared applicative functor " ^ quote name);
in
(*The registered functor, transferred to the context's theory.*)
fun afun_of_generic context name = case Name_Space.lookup (get_afun_table context) name of
SOME af => transfer_afun (Context.theory_of context) af
| NONE => undeclared name;
val afun_of = afun_of_generic o Context.Proof;
(*Replace the table entry for name by its image under f.*)
fun update_afun name f context = if Name_Space.defined (get_afun_table context) name
then Data.map (map_data I (Name_Space.map_table_entry name f) I) context
else undeclared name;
end;
(*Names of functors whose registered patterns match a term (or a type,
  encoded as a term).*)
fun match_term context tm =
  map #1 (Item_Net.retrieve_matching (get_patterns context) tm);
fun match_typ context T = match_term context (Net.encode_type T);

fun afuns_of_term_generic context tm =
  map (afun_of_generic context) (match_term context tm);
fun afuns_of_term ctxt = afuns_of_term_generic (Context.Proof ctxt);
fun afuns_of_typ_generic context T =
  map (afun_of_generic context) (match_typ context T);
fun afuns_of_typ ctxt = afuns_of_typ_generic (Context.Proof ctxt);

(*The unfold theorems of all registered functors, transferred to context.*)
fun all_unfolds_of_generic context =
  let
    fun unfolds_of af = map (Thm.transfer'' context) (unfolds_of_afun af);
    fun add (_, af) acc = unfolds_of af @ acc;
  in Name_Space.fold_table add (get_afun_table context) [] end;
fun all_unfolds_of ctxt = all_unfolds_of_generic (Context.Proof ctxt);
(*A functor instance: the type constructor and the pure, ap and set
  operations, each with its own bound type variables.*)
type afun_inst = {
T: poly_type,
pure: poly_term,
ap: poly_term,
set: poly_term
};
(*Conversion from and to the unpacked list [T, pure, ap, set].*)
fun mk_afun_inst [T, pure, ap, set] = {T = poly_type_of_term T, pure = pure, ap = ap, set = set};
fun pack_afun_inst {T, pure, ap, set} = pack_poly_terms [poly_type_to_term T, pure, ap, set];
(*Instance whose shared type parameters are matched against a term (index
  0 selects the type component), or imported as fresh TFrees.*)
fun match_afun_inst ctxt af = match_poly_terms ctxt (terms_of_afun af, 0) #> mk_afun_inst;
fun import_afun_inst_raw terms = import_poly_terms terms #>> mk_afun_inst;
val import_afun_inst = import_afun_inst_raw o terms_of_afun;
(*Sort of the functor's single argument type variable (the_single fails if
  the type component does not have exactly one bound variable).*)
fun inner_sort_of {T = (tvars, _), ...} = Type.sort_of_atyp (the_single tvars);
(*Instantiate the operations: the functor type at T, pure at T, ap at T1 and
  T2, and set at T.*)
fun mk_type {T, ...} = instantiate_poly_type T o single;
fun mk_pure {pure, ...} = instantiate_poly_term pure o single;
fun mk_ap {ap, ...} (T1, T2) = instantiate_poly_term ap [T1, T2];
fun mk_set {set, ...} = instantiate_poly_term set o single;
(*pure applied to t, beta-reduced.*)
fun lift_term af_inst t = Term.betapply (mk_pure af_inst (Term.fastype_of t), t);
(*ap applied to t1 and t2, with funT the type T1 => T2 of the wrapped function.*)
fun mk_comb af_inst funT (t1, t2) = Term.betapplys (mk_ap af_inst (dest_funT funT), [t1, t2]);
(*Argument type of a functor type.  NONE if the match leaves it unbound;
  TYPE if the type is not an instance of the functor.*)
fun dest_type ctxt {T, ...} = the_single o dest_poly_type ctxt T;
val dest_type' = the_default HOLogic.unitT ooo dest_type;
(*Arguments of pure x / ap f x; TERM if the term does not match.*)
fun dest_pure ctxt {pure = (_, pure), ...} = the_single o dest_comb_pattern ctxt (pure, 1);
fun dest_comb ctxt {ap = (_, ap), ...} = the_pair o dest_comb_pattern ctxt (ap, 2);
(*ap t1 t2, with the function type taken from t1's type; dummyT => dummyT
  is used if it cannot be determined.*)
fun infer_comb ctxt af_inst (t1, t2) =
let val funT = the_default (dummyT --> dummyT) (dest_type ctxt af_inst (fastype_of t1));
in mk_comb af_inst funT (t1, t2) end;
(*Lift tm into the functor.  Subterms found in subst (compared with aconv)
  are replaced by their given functor-level counterparts.  Applications with
  a replaced part become ap combinations, with the other part lifted via
  pure.  Everything else, including abstractions (which are not descended
  into), is lifted via pure as a whole.*)
fun subst_lift_term af_inst subst tm =
let
(*NONE: no substitution below this node, so the caller lifts it whole.*)
fun subst_lift (s $ t) =
(case (subst_lift s, subst_lift t) of
(NONE, NONE) => NONE
| (SOME s', NONE) => SOME (mk_comb af_inst (fastype_of s) (s', lift_term af_inst t))
| (NONE, SOME t') => SOME (mk_comb af_inst (fastype_of s) (lift_term af_inst s, t'))
| (SOME s', SOME t') => SOME (mk_comb af_inst (fastype_of s) (s', t')))
| subst_lift t = AList.lookup (op aconv) subst t;
in
(case subst_lift tm of
NONE => lift_term af_inst tm
| SOME tm' => tm')
end;
(*Schematic variables occurring inside an abstraction reachable through
  applications.*)
fun add_lifted_vars (s $ t) = add_lifted_vars s #> add_lifted_vars t
| add_lifted_vars (Abs (_, _, t)) = Term.add_vars t
| add_lifted_vars _ = I;
(*Lift the terms ts.  Each Var that never occurs under an abstraction is
  replaced by a fresh Free of the corresponding functor type.  Returns the
  lifted terms and the context with those Frees fixed.*)
fun generalize_lift_terms af_inst ts ctxt =
let
val vars = subtract (op =) (fold add_lifted_vars ts []) (fold Term.add_vars ts []);
val (var_names, Ts) = split_list vars;
val (free_names, ctxt') = Variable.variant_fixes (map #1 var_names) ctxt;
val Ts' = map (mk_type af_inst) Ts;
val subst = map Var vars ~~ map Free (free_names ~~ Ts');
in (map (subst_lift_term af_inst subst) ts, ctxt') end;
(*Strip skolem and/or internal name decorations where present.*)
val clean_name =
  perhaps (perhaps_apply [try Name.dest_skolem, try Name.dest_internal]);

(*A readable variable name suggested by a term; "x" as fallback.*)
fun term_to_vname tm =
  (case tm of
    Const (c, _) => Long_Name.base_name c
  | Free (x, _) => clean_name x
  | Var ((x, _), _) => clean_name x
  | _ => "x");
(*Candidate functors for a goal whose conclusion is Trueprop (R lhs rhs),
  after beta-eta contraction.  If precise, match lhs against the term
  patterns, falling back to rhs when lhs yields nothing.  Otherwise match
  the type of lhs.  Raises TERM if the goal does not have this shape.*)
fun afuns_of_rel precise ctxt t =
let val (_, (lhs, rhs)) = Variable.focus NONE t ctxt
|> #1 |> #2
|> Logic.strip_imp_concl
|> Envir.beta_eta_contract
|> HOLogic.dest_Trueprop
|> strip_comb2;
in if precise
then (case afuns_of_term ctxt lhs of
[] => afuns_of_term ctxt rhs
| afs => afs)
else afuns_of_typ ctxt (fastype_of lhs) end;
(*Run tac with the given functor.  If none is given, use the functors
  inferred from the subgoal.  Fails (no_tac) if none are found or if TERM is
  raised, e.g. for a subgoal of the wrong shape.*)
fun AUTO_AFUNS precise tac ctxt opt_af = case opt_af of
SOME af => tac [af]
| NONE => SUBGOAL (fn (goal, i) => (case afuns_of_rel precise ctxt goal of
[] => no_tac
| afs => tac afs i) handle TERM _ => no_tac);
(*As AUTO_AFUNS, but uses only the first candidate functor.*)
fun AUTO_AFUN precise tac = AUTO_AFUNS precise (tac o hd);
(*Rewrite both arguments of a binary operator at once: cv receives the
  argument pair and returns one equation for each argument.*)
fun binop_par_conv cv ct =
  let
    val (binop_arg1, arg2) = Thm.dest_comb ct;
    val (binop, arg1) = Thm.dest_comb binop_arg1;
    val (th1, th2) = cv (arg1, arg2);
  in Drule.binop_cong_rule binop th1 th2 end;

(*binop_par_conv applied beneath Trueprop of a subgoal.*)
fun binop_par_conv_tac cv =
  CONVERSION (HOLogic.Trueprop_conv (binop_par_conv cv));

fun fold_goal_tac ctxt thms = SELECT_GOAL (Raw_Simplifier.fold_goals_tac ctxt thms);

(*Unfold or fold with one functor's unfold theorems, or unfold with all
  registered ones.*)
fun afun_unfold_tac ctxt af = Raw_Simplifier.rewrite_goal_tac ctxt (unfolds_of_afun af);
fun afun_fold_tac ctxt af = fold_goal_tac ctxt (unfolds_of_afun af);
fun unfold_all_tac ctxt = Raw_Simplifier.rewrite_goal_tac ctxt (all_unfolds_of ctxt);
(*Apply cv to x in a term pure x (matched against the instance's pure
  pattern; raises TERM otherwise).  If cv returns a reflexive theorem, the
  result is reflexivity on ct.  Otherwise it is the congruence with pure
  instantiated at the matched types.*)
fun pure_conv ctxt {pure = (_, pure), ...} cv ct =
let
val ([var], (tyenv, env)) = match_comb_pattern ctxt (pure, 1) (Thm.term_of ct);
val arg = the (Envir.lookup1 env var);
val thm = cv (Thm.cterm_of ctxt arg);
in
if Thm.is_reflexive thm then Conv.all_conv ct
else
let val pure_inst = Envir.subst_term_types tyenv pure;
in Drule.arg_cong_rule (Thm.cterm_of ctxt pure_inst) thm end
end;
(*Apply cv1 and cv2 to f and x in a term ap f x (matched against the ap
  pattern; raises TERM otherwise).  The result is reflexive if both are,
  otherwise the congruence with the instantiated ap.*)
fun ap_conv ctxt {ap = (_, ap), ...} cv1 cv2 ct =
let
val ([var1, var2], (tyenv, env)) = match_comb_pattern ctxt (ap, 2) (Thm.term_of ct);
val (arg1, arg2) = apply2 (the o Envir.lookup1 env) (var1, var2);
val thm1 = cv1 (Thm.cterm_of ctxt arg1);
val thm2 = cv2 (Thm.cterm_of ctxt arg2);
in
if Thm.is_reflexive thm1 andalso Thm.is_reflexive thm2 then Conv.all_conv ct
else
let val ap_inst = Envir.subst_term_types tyenv ap;
in Drule.binop_cong_rule (Thm.cterm_of ctxt ap_inst) thm1 thm2 end
end;
(*Normalise an applicative expression of functor af.  The result is built
  from pure terms and ap applications over the opaque leaves, using the
  functor's hom, ichng and pure_comp_conv theorems and its I and B
  reductions.  NOTE(review): the exact shape of the normal form depends on
  those theorem statements, which are not visible here.  The functor must
  have I and B reductions: the_red is applied eagerly.*)
fun normalform_conv ctxt af ct =
let
val {hom, ichng, pure_comp_conv, ...} = thms_of_afun af;
val the_red = the o red_of_afun af;
(*Reduction theorems used right-to-left reintroduce I and B.*)
val leaf_conv = Conv.rewr_conv (mk_meta_eq (the_red "I") |> Thm.symmetric);
val merge_conv = Conv.rewr_conv (mk_meta_eq hom);
val swap_conv = Conv.rewr_conv (mk_meta_eq ichng);
val rotate_conv = Conv.rewr_conv (mk_meta_eq (the_red "B") |> Thm.symmetric);
val pure_rotate_conv = Conv.rewr_conv (mk_meta_eq pure_comp_conv);
val af_inst = match_afun_inst ctxt af (Thm.term_of ct, Thm.maxidx_of_cterm ct);
(*Rewrite only the function part of an ap application.*)
fun left_conv cv = ap_conv ctxt af_inst cv Conv.all_conv;
(*Combine a pure term (left) with a normal form (right).*)
fun norm_pure_nf ct =
((pure_rotate_conv then_conv left_conv norm_pure_nf) else_conv merge_conv) ct;
(*Combine a normal form (left) with a pure term (right): swap, then as above.*)
val norm_nf_pure = swap_conv then_conv norm_pure_nf;
(*Combine two normal forms.*)
fun norm_nf_nf ct = ((rotate_conv then_conv
left_conv (left_conv norm_pure_nf then_conv norm_nf_nf)) else_conv
norm_nf_pure) ct;
(*For ap, normalise both arguments and combine.  A pure term stays as is;
  any other term is treated as a leaf.*)
fun normalize ct = ((ap_conv ctxt af_inst normalize normalize then_conv norm_nf_nf) else_conv
pure_conv ctxt af_inst Conv.all_conv else_conv
leaf_conv) ct;
in normalize ct end;
(*Normalise both arguments of the binary relation in a Trueprop subgoal.*)
val normalize_rel_tac = binop_par_conv_tac o apply2 oo normalform_conv;
(*Abstract view of an applicative expression: pure terms, numbered opaque
  leaves, and ap applications.*)
datatype apterm =
    Pure of term
  | ApVar of int * term
  | Ap of apterm * apterm;

(*Cons each opaque leaf onto the accumulator while traversing left to
  right; the resulting list is therefore in reverse traversal order.*)
fun apterm_vars tt =
  (case tt of
    Pure _ => I
  | ApVar v => cons v
  | Ap (tt1, tt2) => apterm_vars tt1 #> apterm_vars tt2);

(*Does any leaf with an index in vs occur?*)
fun occurs_any vs tt =
  (case tt of
    Pure _ => false
  | ApVar (i, _) => member (op =) vs i
  | Ap (tt1, tt2) => occurs_any vs tt1 orelse occurs_any vs tt2);

(*Rebuild the HOL term of an apterm.*)
fun term_of_apterm ctxt af_inst tt =
  let
    fun build (Pure t) = t
      | build (ApVar (_, t)) = t
      | build (Ap (tt1, tt2)) =
          let
            val t1 = build tt1;
            val t2 = build tt2;
          in infer_comb ctxt af_inst (t1, t2) end;
  in build tt end;

(*Parse a term into an apterm, numbering leaves from the given counter;
  returns the apterm and the next free number.*)
fun apterm_of_term ctxt af_inst t =
  let
    fun aptm_of tm next =
      (case try (dest_comb ctxt af_inst) tm of
        SOME (tm1, tm2) =>
          let
            val (tt1, next1) = aptm_of tm1 next;
            val (tt2, next2) = aptm_of tm2 next1;
          in (Ap (tt1, tt2), next2) end
      | NONE =>
          if can (dest_pure ctxt af_inst) tm then (Pure tm, next)
          else (ApVar (next, tm), next + 1));
  in aptm_of t end;
(*Determine which opaque leaves of two apterms must be eliminated together.
  Each leaf term gets a common id, with equal terms sharing an id across
  both sides.  The allowed reshaping depends on the functor's reductions:
    C  - leaves on each side are sorted by id;
    W  - adjacent leaves with the same id are merged;
    K  - the two sides are merged, tolerating leaves present on one side;
    otherwise both sides must have the same id sequence (error if not).
  Returns the groups as (leaf indices, term) pairs.*)
fun consolidate ctxt af (t1, t2) =
let
(*Attach to a leaf the id of its term, allocating a new id if unseen.*)
fun common_inst (i, t) (j, insts) = case Termtab.lookup insts t of
SOME k => (((i, t), k), (j, insts))
| NONE => (((i, t), j), (j + 1, Termtab.update (t, j) insts));
val (vars, _) = (0, Termtab.empty)
|> fold_map common_inst (apterm_vars t1 [])
||>> fold_map common_inst (apterm_vars t2 []);
(*Group consecutive leaves with equal id into one entry with index list.*)
fun merge_adjacent (([], _), _) [] = []
| merge_adjacent ((is, t), d) [] = [((is, t), d)]
| merge_adjacent (([], _), _) (((i, t), d)::xs) = merge_adjacent (([i], t), d) xs
| merge_adjacent ((is, t), d) (((i', t'), d')::xs) = if d = d'
then merge_adjacent ((i'::is, t), d) xs
else ((is, t), d) :: merge_adjacent (([i'], t'), d') xs;
(*Find the first entry with id d in a list; return it joined with the
  given entry, together with the entries before and after it.*)
fun align _ [] = NONE
| align ((i, t), d) (((i', t'), d')::xs) = if d = d'
then SOME ([((i @ i', t), d)], xs)
else Option.map (apfst (cons ((i', t'), d'))) (align ((i, t), d) xs);
(*Merge the two sequences; entries without a partner are kept.*)
fun merge ([], ys) = ys
| merge (xs, []) = xs
| merge ((xs as ((is1, t1), d1)::xs'), ys as (((is2, t2), d2)::ys')) = if d1 = d2
then ((is1 @ is2, t1), d1) :: merge (xs', ys')
else case (align ((is2, t2), d2) xs, align ((is1, t1), d1) ys) of
(SOME (zs, xs''), NONE) => zs @ merge (xs'', ys')
| (NONE, SOME (zs, ys'')) => zs @ merge (xs', ys'')
| _ => ((is1, t1), d1) :: ((is2, t2), d2) :: merge (xs', ys');
fun unbalanced vs = error ("Unbalanced opaque terms " ^
commas_quote (map (Syntax.string_of_term ctxt o #2 o #1) vs));
fun mismatch (t1, t2) = error ("Mismatched opaque terms " ^
quote (Syntax.string_of_term ctxt t1) ^ " and " ^ quote (Syntax.string_of_term ctxt t2));
(*Strict pairing: both sequences must agree id by id.*)
fun same ([], []) = []
| same ([], ys) = unbalanced ys
| same (xs, []) = unbalanced xs
| same ((((i1, t1), d1)::xs), (((i2, t2), d2)::ys)) = if d1 = d2
then ((i1 @ i2, t1), d1) :: same (xs, ys)
else mismatch (t1, t2);
in vars
|> has_red_afun af "C" ? apply2 (sort (int_ord o apply2 #2))
|> apply2 (if has_red_afun af "W"
then merge_adjacent (([], Term.dummy), 0)
else map (apfst (apfst single)))
|> (if has_red_afun af "K" then merge else same)
|> map #1
end;
(*Congruence of ap: from f == f' and x == x' derive ap f x == ap f' x'.
  The ap instance is determined from the type of the left-hand side of
  thm1, with dummyT => dummyT as fallback.*)
fun ap_cong ctxt af_inst thm1 thm2 =
  let
    val lhs_typ = Thm.typ_of_cterm (Thm.lhs_of thm1);
    val funT =
      (case dest_type ctxt af_inst lhs_typ of
        SOME T => T
      | NONE => dummyT --> dummyT);
    val ap_inst = Thm.cterm_of ctxt (mk_ap af_inst (dest_funT funT));
  in Drule.binop_cong_rule ap_inst thm1 thm2 end;

(*ap_cong followed by one rewrite step with rewr on the result.*)
fun rewr_subst_ap ctxt af_inst rewr thm1 thm2 =
  let
    val congr = ap_cong ctxt af_inst thm1 thm2;
    val rewritten = Conv.rewr_conv rewr (Thm.rhs_of congr);
  in Thm.transitive congr rewritten end;

(*If an apterm consists of pure terms only, collapse it into one pure term
  by rewriting with merge_thm at every ap node; NONE if a leaf occurs.*)
fun merge_pures ctxt af_inst merge_thm tt =
  let
    fun merge (Pure t) = SOME (Thm.reflexive (Thm.cterm_of ctxt t))
      | merge (ApVar _) = NONE
      | merge (Ap (tt1, tt2)) =
          (case merge tt1 of
            NONE => NONE
          | SOME thm1 =>
              Option.map (rewr_subst_ap ctxt af_inst merge_thm thm1) (merge tt2));
  in merge tt end;
(*Raised when an internal invariant of eliminate is violated.*)
exception ASSERT of string;
(*Abstract the opaque leaves with indices v, all standing for v_tm, out of
  apterm tt using combinator reductions.  Returns a new apterm together
  with a theorem equating the term of tt with the rewritten term.  The
  caller expects the rewritten term to have the form ap F v_tm, where F is
  the term of the returned apterm.  If v does not occur, the K reduction
  is used (and is required).*)
fun eliminate ctxt (af, af_inst) tt (v, v_tm) =
let
val {hom, ichng, ...} = thms_of_afun af;
val the_red = the o red_of_afun af;
val hom_conv = mk_meta_eq hom;
val ichng_conv = mk_meta_eq ichng;
(*Reduction theorems reversed, so that rewriting introduces combinators.
  I and B are required; C, K and W are optional (NONE if not registered).*)
val mk_combI = Thm.symmetric o mk_meta_eq;
val id_conv = mk_combI (the_red "I");
val comp_conv = mk_combI (the_red "B");
val flip_conv = Option.map mk_combI (red_of_afun af "C");
val const_conv = Option.map mk_combI (red_of_afun af "K");
val dup_conv = Option.map mk_combI (red_of_afun af "W");
val rewr_subst_ap = rewr_subst_ap ctxt af_inst;
(*The subterm reached by n dest_arg1 steps into the rhs of thm, as Pure.*)
fun extract_comb n thm = Pure (thm |> Thm.rhs_of |> funpow n Thm.dest_arg1 |> Thm.term_of);
fun refl_step tt = (tt, Thm.reflexive (Thm.cterm_of ctxt (term_of_apterm ctxt af_inst tt)));
(*Combine two steps with ap and rewrite the result with def.*)
fun comb2_step def (tt1, thm1) (tt2, thm2) =
let val thm = rewr_subst_ap def thm1 thm2;
in (Ap (Ap (extract_comb 3 thm, tt1), tt2), thm) end;
val B_step = comb2_step comp_conv;
(*Left part contains v, right part is pure: interchange, then composition.*)
fun swap_B_step (tt1, thm1) thm2 =
let
val thm3 = rewr_subst_ap ichng_conv thm1 thm2;
val thm4 = Thm.transitive thm3 (Conv.rewr_conv comp_conv (Thm.rhs_of thm3));
in (Ap (Ap (extract_comb 3 thm4, extract_comb 1 thm3), tt1), thm4) end;
(*A leaf that is eliminated: introduce I.*)
fun I_step tm =
let val thm = Conv.rewr_conv id_conv (Thm.cterm_of ctxt tm)
in (extract_comb 1 thm, thm) end;
(*Both parts contain v and there is no C: compose, then use W to share the
  duplicated argument.*)
fun W_step s1 s2 =
let
val (Ap (Ap (tt1, tt2), tt3), thm1) = B_step s1 s2;
val thm2 = Conv.rewr_conv comp_conv (Thm.rhs_of thm1 |> funpow 2 Thm.dest_arg1);
val thm3 = merge_pures ctxt af_inst hom_conv tt3 |> the;
val (tt4, thm4) = swap_B_step (Ap (Ap (extract_comb 3 thm2, tt1), tt2), thm2) thm3;
val var = Thm.rhs_of thm1 |> Thm.dest_arg;
val thm5 = rewr_subst_ap (the dup_conv) thm4 (Thm.reflexive var);
val thm6 = Thm.transitive thm1 thm5;
in (Ap (extract_comb 2 thm6, tt4), thm6) end;
(*Both parts contain v and C is available: C, then B, then W.*)
fun S_step s1 s2 =
let
val (Ap (Ap (tt1, tt2), tt3), thm1) = comb2_step (the flip_conv) s1 s2;
val thm2 = Conv.rewr_conv comp_conv (Thm.rhs_of thm1 |> Thm.dest_arg1);
val var = Thm.rhs_of thm1 |> Thm.dest_arg;
val thm3 = rewr_subst_ap (the dup_conv) thm2 (Thm.reflexive var);
val thm4 = Thm.transitive thm1 thm3;
val tt = Ap (extract_comb 2 thm4, Ap (Ap (extract_comb 3 thm2, Ap (tt1, tt2)), tt3));
in (tt, thm4) end;
(*v does not occur in tt: introduce K with tm as the ignored argument.*)
fun K_step tt tm =
let
val ct = Thm.cterm_of ctxt tm;
val T_opt = Term.fastype_of tm |> dest_type ctxt af_inst |> Option.map (Thm.ctyp_of ctxt);
val thm = Thm.instantiate' [T_opt] [SOME ct]
(Conv.rewr_conv (the const_conv) (term_of_apterm ctxt af_inst tt |> Thm.cterm_of ctxt))
in (Ap (extract_comb 2 thm, tt), thm) end;
fun unreachable _ = raise ASSERT "eliminate: assertion failed";
(*Case analysis on where v occurs; only called on subterms containing v.*)
fun elim (Pure _) = unreachable ()
| elim (ApVar (i, t)) = if exists (fn x => x = i) v then I_step t else unreachable ()
| elim (Ap (t1, t2)) = (case (occurs_any v t1, occurs_any v t2) of
(false, false) => unreachable ()
| (false, true) => B_step (refl_step t1) (elim t2)
| (true, false) => (case merge_pures ctxt af_inst hom_conv t2 of
SOME thm => swap_B_step (elim t1) thm
| NONE => comb2_step (the flip_conv) (elim t1) (refl_step t2))
| (true, true) => if is_some flip_conv
then S_step (elim t1) (elim t2)
else W_step (elim t1) (elim t2));
in if occurs_any v tt
then elim tt
else K_step tt v_tm
end;
(*Jointly normalise two applicative expressions so that both become a pure
  term applied to the same sequence of opaque terms.  consolidate decides
  which leaves are grouped.  Each group is eliminated in turn, and the fully
  pure remainder is collapsed with hom.  Returns one equation per side.*)
fun general_normalform_conv ctxt af cts =
let
val (t1, t2) = apply2 (Thm.term_of) cts;
val maxidx = Int.max (apply2 Thm.maxidx_of_cterm cts);
val af_inst = match_afun_inst ctxt af (t1, maxidx);
(*Leaves of both sides are numbered from one shared counter.*)
val ((apt1, apt2), _) = 0 |> apterm_of_term ctxt af_inst t1 ||>> apterm_of_term ctxt af_inst t2;
val vs = consolidate ctxt af (apt1, apt2);
val merge_thm = mk_meta_eq (#hom (thms_of_afun af));
fun elim_all tt [] = the (merge_pures ctxt af_inst merge_thm tt)
| elim_all tt (v::vs) =
let
val (tt', rule1) = eliminate ctxt (af, af_inst) tt v;
val rule2 = elim_all tt' vs;
(*Keep the eliminated argument and recurse into the function part.*)
val (_, vartm) = dest_comb ctxt af_inst (Thm.term_of (Thm.rhs_of rule1));
val rule3 = ap_cong ctxt af_inst rule2 (Thm.reflexive (Thm.cterm_of ctxt vartm));
in Thm.transitive rule1 rule3 end;
in (elim_all apt1 vs, elim_all apt2 vs) end;
(*Apply general_normalform_conv to both arguments of a Trueprop relation.*)
val general_normalize_rel_tac = binop_par_conv_tac oo general_normalform_conv;
(*Rename the outermost parameters of subgoal i to names, then rebuild the
  proof state from the pieces returned by Thm.dest_state.*)
fun rename_params names i st =
  let
    val (_, prems, goal_i, concl) = Thm.dest_state (st, i);
    val renamed_goal = Logic.list_rename_params names goal_i;
    val new_prop = Logic.list_implies (prems @ [renamed_goal], concl);
  in Thm.renamed_prop new_prop st end;
(*Reduce a relation between two normal forms to a relation between their
  pure heads.  The steps are:
  - resolve repeatedly with the functor's rel_intros;
  - then with ext or rel_fun_eq_onI, or eliminate with UNIV_E;
  - finally rename the resulting parameters after the arguments of the
    original right-hand side.
  renames supplies preferred names for specific terms.*)
fun head_cong_tac ctxt af renames =
let
val {rel_intros, ...} = thms_of_afun af;
fun term_name tm = case AList.lookup (op aconv) renames tm of
SOME n => n
| NONE => term_to_vname tm;
(*Names of the ap arguments, innermost application first.*)
fun gather_vars' af_inst tm = case try (dest_comb ctxt af_inst) tm of
SOME (t1, t2) => term_name t2 :: gather_vars' af_inst t1
| NONE => [];
(*For Trueprop (R lhs rhs): the argument names of rhs in left-to-right order.*)
fun gather_vars prop = case prop of
Const (@{const_name Trueprop}, _) $ (_ $ rhs) =>
rev (gather_vars' (match_afun_inst ctxt af (rhs, maxidx_of_term prop)) rhs)
| _ => [];
in SUBGOAL (fn (subgoal, i) =>
(REPEAT_DETERM (resolve_tac ctxt rel_intros i) THEN
REPEAT_DETERM (resolve_tac ctxt [ext, @{thm rel_fun_eq_onI}] i ORELSE
eresolve_tac ctxt [@{thm UNIV_E}] i) THEN
PRIMITIVE (rename_params (gather_vars subgoal) i)))
end;
(*Lift a HOL equation thm (rulified first) into functor af.  The steps are:
  - import a fresh instance and fix the equation's type variables;
  - replace lifted variables by functor-valued Frees;
  - prove the lifted equation by joint normalisation, head_cong_tac, and
    thm itself with its premises discharged by the goal premises;
  - export the result and fold the functor's unfold theorems.*)
fun forward_lift_rule ctxt af thm =
let
val thm = Object_Logic.rulify ctxt thm;
val (af_inst, ctxt_inst) = import_afun_inst af ctxt;
val (prop, ctxt_Ts) = yield_singleton Variable.importT_terms (Thm.prop_of thm) ctxt_inst;
val (lhs, rhs) = prop |> HOLogic.dest_Trueprop |> HOLogic.dest_eq;
val ([lhs', rhs'], ctxt_lifted) = generalize_lift_terms af_inst [lhs, rhs] ctxt_Ts;
val lifted = HOLogic.mk_Trueprop (HOLogic.mk_eq (lhs', rhs'));
val (lifted', ctxt') = yield_singleton (Variable.import_terms true) lifted ctxt_lifted;
fun tac {prems, context} = HEADGOAL (general_normalize_rel_tac context af THEN'
head_cong_tac context af [] THEN'
resolve_tac context [prems MRS thm]);
val thm' = singleton (Variable.export ctxt' ctxt)
(Goal.prove ctxt' [] [] lifted' tac);
val thm'' = Raw_Simplifier.fold_rule ctxt (unfolds_of_afun af) thm';
in thm'' end;
(*Rule attribute: lift the theorem into the functor called name.*)
fun forward_lift_attrib name =
  let
    fun lift context thm =
      let val af = afun_of_generic context (intern context name)
      in forward_lift_rule (Context.proof_of context) af thm end;
  in Thm.rule_attribute [] lift end;
(*Simplify with the unfold theorems of all candidate functors (matched by
  type).*)
fun unfold_wrapper_tac ctxt =
  AUTO_AFUNS false
    (fn afs => Simplifier.safe_asm_full_simp_tac (ctxt addsimps maps unfolds_of_afun afs))
    ctxt;

(*Fold the unfold theorems of the first precisely matched functor.*)
fun fold_wrapper_tac ctxt =
  AUTO_AFUN true (fn af => fold_goal_tac ctxt (unfolds_of_afun af)) ctxt;
(*Common skeleton of the normalisation and lifting proof methods.
  - Strip HOL universal quantifiers with allI.
  - Focus the subgoal and unfold with all candidate functors (by type).
  - For the first precisely matched functor: unfold, beta-eta normalise,
    run tac, and apply head_cong_tac, naming parameters after the focus
    parameters.
  - Remove trivial meta-level quantifiers.*)
fun WRAPPER tac ctxt opt_af =
REPEAT_DETERM o resolve_tac ctxt [@{thm allI}] THEN'
Subgoal.FOCUS (fn {context = ctxt, params, ...} =>
let val renames = map (swap o apsnd Thm.term_of) params
in
AUTO_AFUNS false (EVERY' o map (afun_unfold_tac ctxt)) ctxt opt_af 1 THEN
AUTO_AFUN true (fn af =>
afun_unfold_tac ctxt af THEN'
CONVERSION Drule.beta_eta_conversion THEN'
tac ctxt af THEN'
head_cong_tac ctxt af renames) ctxt opt_af 1
end) ctxt THEN'
Raw_Simplifier.rewrite_goal_tac ctxt [Drule.triv_forall_equality];
val normalize_wrapper_tac = WRAPPER normalize_rel_tac;
val lifting_wrapper_tac = WRAPPER general_normalize_rel_tac;
(*Parse an optional functor name and resolve it in the current context.*)
val parse_opt_afun =
  Scan.peek (fn context =>
    Scan.option Parse.name >> Option.map (afun_of_generic context o intern context));
(*Record combinator definitions (name, theorem), transformed by the
  declaration morphism.  Symtab.insert with K false treats every existing
  key as a conflict; presumably this raises Symtab.DUP on a duplicate
  name.*)
fun declare_combinators combs phi =
let
val (names, thms) = split_list combs;
val thms' = map (Morphism.thm phi) thms;
fun add_combs (defs, rules) = (fold (Symtab.insert (K false)) (names ~~ thms') defs, rules);
in Data.map (map_data add_combs I I) end;
(*Non-syntax, non-pervasive local theory declaration of the combinators.*)
val setup_combinators =
Local_Theory.declaration {syntax = false, pervasive = false, pos = ⌂} o declare_combinators;
(*Name of the head constant on the left-hand side of a meta-equation.*)
fun combinator_of_red thm =
let
val (lhs, _) = Logic.dest_equals (Thm.prop_of thm);
val (head, _) = strip_comb lhs;
in #1 (dest_Const head) end;
(*Register an equation between combinators as a combinator rule.  Type
  variables on the right-hand side must also occur on the left (error
  otherwise).  Constants are mapped to combinator names via the heads of
  the registered definitions.*)
fun register_combinator_rule weak_premises thm context =
let
val (lhs, rhs) = Logic.dest_equals (Thm.prop_of thm);
val ltvars = Term.add_tvars lhs [];
val rtvars = Term.add_tvars rhs [];
val _ = if exists (not o member op = ltvars) rtvars
then Pretty.breaks
[Pretty.str "Combinator equation",
Pretty.quote (Syntax.pretty_term (Context.proof_of context) (Thm.prop_of thm)),
Pretty.str "has additional type variables on right-hand side."]
|> Pretty.block |> Pretty.string_of |> error
else ();
val (defs, _) = #combinators (Data.get context);
val comb_names =
Symtab.make (map (fn (name, thm) => (combinator_of_red thm, name)) (Symtab.dest defs));
val rule = mk_combinator_rule comb_names weak_premises thm;
fun add_rule (defs, rules) = (defs, insert eq_combinator_rule rule rules);
in Data.map (map_data add_rule I I) context end;
(*Declaration attribute; the optional argument lists weak premises.*)
val combinator_rule_attrib = Thm.declaration_attribute o register_combinator_rule;
(*Saturate a sorted list of available combinators under rules.  A rule adds
  its conclusion when that is still missing and all its strong premises
  are available (rules with weak premises need have_weak).*)
fun combinator_closure rules have_weak combs =
  let
    fun step rule (acc as (cs, _)) =
      let val concl = conclusion_of_rule rule in
        if not (Ord_List.member fast_string_ord cs concl) andalso
          is_applicable_rule rule have_weak
            (fn prems => Ord_List.subset fast_string_ord (prems, cs))
        then (Ord_List.insert fast_string_ord concl cs, true)
        else acc
      end;
    fun saturate cs =
      let val (cs', changed) = fold step rules (cs, false)
      in if changed then saturate cs' else cs end;
  in saturate combs end;
(*Derive the functor-level reduction of a combinator from its definition.
  The steps are:
  - instantiate the type variables of base_thm with fresh TFrees of the
    functor's inner sort;
  - lift both sides of the equation;
  - prove the lifted equation by rewriting its left side with eq_thm, then
    rewriting with red_thms, and closing with refl.
  The theorem is exported back to ctxt.*)
fun derive_combinator_red ctxt af_inst red_thms (base_thm, eq_thm) =
let
val base_prop = Thm.prop_of base_thm;
val tvars = Term.add_tvars base_prop [];
val (Ts, ctxt_Ts) = mk_TFrees_of (length tvars) (inner_sort_of af_inst) ctxt;
val base_prop' = base_prop |> Term_Subst.instantiate (TVars.make (tvars ~~ Ts), Vars.empty);
val (lhs, rhs) = Logic.dest_equals base_prop';
val ([lhs', rhs'], ctxt') = generalize_lift_terms af_inst [lhs, rhs] ctxt_Ts;
val lifted_prop = (lhs', rhs') |> HOLogic.mk_eq |> HOLogic.mk_Trueprop;
(*Rewrite with eq_thm only on the left-hand side of the HOL equation.*)
val unfold_comb_conv = HOLogic.Trueprop_conv
(HOLogic.eq_conv (Conv.top_sweep_rewrs_conv [eq_thm] ctxt') Conv.all_conv);
fun tac goal_ctxt =
HEADGOAL (CONVERSION unfold_comb_conv THEN'
Raw_Simplifier.rewrite_goal_tac goal_ctxt red_thms THEN'
resolve_tac goal_ctxt [@{thm refl}]);
in
singleton (Variable.export ctxt' ctxt) (Goal.prove ctxt' [] [] lifted_prop (tac o #context))
end;
(*Variants of a reduction rule with pure arguments.  The functor-typed
  schematic variables of its left-hand side are instantiated one after the
  other with pure v for a fresh Free v, and each result is simplified with
  merge_thm.  Returns every intermediate rule, including strong_red itself.*)
fun weak_red_closure ctxt (af_inst, merge_thm) strong_red =
let
val (lhs, _) = Thm.prop_of strong_red |> Logic.dest_equals;
val vars = rev (Term.add_vars lhs []);
fun closure [] prev thms = (prev::thms)
| closure ((v, af_T)::vs) prev thms =
(case try (dest_type ctxt af_inst) af_T of
(*Not a functor type: leave this variable alone.*)
NONE => closure vs prev thms
| SOME T_opt =>
let
(*An undetermined argument type becomes a fresh TFree of inner sort.*)
val (T, ctxt') = (case T_opt of
NONE => yield_singleton Variable.invent_types (inner_sort_of af_inst) ctxt
|>> TFree
| SOME T => (T, ctxt));
val (v', ctxt'') = mk_Free (#1 v) T ctxt';
val pure_v = Thm.cterm_of ctxt'' (lift_term af_inst v');
val next = Drule.instantiate_normalize (TVars.empty, Vars.make [((v, af_T), pure_v)]) prev;
val next' = Raw_Simplifier.rewrite_rule ctxt'' [merge_thm] next;
val next'' = singleton (Variable.export ctxt'' ctxt) next';
in closure vs next'' (prev::thms) end);
in closure vars strong_red [] end;
(*Saturate the table combs (combinator name to reduction theorem).
  Whenever a rule's conclusion is missing and its strong premises are all
  present, derive its reduction from the definition and the reductions
  known so far.  Rules with weak premises apply only if weak_reds is
  non-empty.  Returns the extended table.*)
fun combinator_red_closure ctxt (comb_defs, rules) (af_inst, merge_thm) weak_reds combs =
let
val have_weak = not (null weak_reds);
(*Initial rewrite set: the known reductions (as meta-equations) plus
  weak_reds, each extended with its pure-argument variants.*)
val red_thms0 = Symtab.fold (fn (_, thm) => cons (mk_meta_eq thm)) combs weak_reds;
val red_thms = flat (map (weak_red_closure ctxt (af_inst, merge_thm)) red_thms0);
fun apply rule ((cs, rs), changed) =
if not (Symtab.defined cs (conclusion_of_rule rule)) andalso
is_applicable_rule rule have_weak (forall (Symtab.defined cs))
then
let
val conclusion = conclusion_of_rule rule;
val def = the (Symtab.lookup comb_defs conclusion);
val new_red_thm = derive_combinator_red ctxt af_inst rs (def, thm_of_rule rule);
val new_red_thms = weak_red_closure ctxt (af_inst, merge_thm) (mk_meta_eq new_red_thm);
in ((Symtab.update (conclusion, new_red_thm) cs, new_red_thms @ rs), true) end
else ((cs, rs), changed);
fun loop xs =
(case fold apply rules (xs, false) of
(xs', true) => loop xs'
| (_, false) => xs);
in #1 (loop (combs, red_thms)) end;
(*Check and normalize the raw terms of an applicative functor declaration.

  raw_pure and raw_ap are mandatory; raw_rel and raw_set are optional.  Every
  term is first generalized by closed_poly_term, which rejects terms with
  locally free variables or hidden type variables.

  - pure must have a type 'a => F whose domain is a type variable 'a;
    F (the range of pure) is the functor type.
  - ap must have a type T23 => T2 => T3 where T2, T3 and T23 unify with F at
    fresh argument types 'b, 'c and 'b => 'c respectively.
  - The sorts of 'a, 'b and 'c are intersected; the result must be closed
    under function types, and all three variables are instantiated with
    fresh type variables of that sort.
  - set, if given, must have type F 'a => 'a set for the same 'a.  If it is
    absent, the constant function mapping every F 'a value to UNIV is used;
    its bound variable has the functor type F 'a, the same domain that a
    user-supplied set is checked against.
  - rel, if given, is checked by mk_rel below.

  Returns the packed polymorphic terms [functor type, pure, ap, set], with
  zeroed variable indexes, and the optional packed relator.  Every violation
  is reported via error.*)
fun mk_terms ctxt (raw_pure, raw_ap, raw_rel, raw_set) =
  let
    val thy = Proof_Context.theory_of ctxt;
    val show_typ = quote o Syntax.string_of_typ ctxt;
    val show_term = quote o Syntax.string_of_term ctxt;

    (*generalize t to a polymorphic term; reject t if exporting it from a
      context that knows its free variables still leaves variables, or if it
      has type variables that do not occur in its type*)
    fun closed_poly_term t =
      let val poly_t = singleton (Variable.polymorphic ctxt) t;
      in
        case Term.add_vars (singleton
            (Variable.export_terms (Proof_Context.augment t ctxt) ctxt) t) [] of
          [] =>
            (case (Term.hidden_polymorphism poly_t) of
              [] => poly_t
            | _ => error ("Hidden type variables in term " ^ show_term t))
        | _ => error ("Locally free variables in term " ^ show_term t)
      end;

    val pure = closed_poly_term raw_pure;
    val (tvar, T1) = fastype_of pure |> dest_funT |>> dest_TVar
      handle TYPE _ => error ("Bad type for pure: " ^ show_typ (fastype_of pure));
    val maxidx_pure = maxidx_of_term pure;
    (*shift the indexes of ap above those of pure to keep their variables apart*)
    val ap = Logic.incr_indexes ([], maxidx_pure + 1) (closed_poly_term raw_ap);
    fun bad_ap _ = error ("Bad type for ap: " ^ show_typ (fastype_of ap));
    val (T23, (T2, T3)) = fastype_of ap |> dest_funT ||> dest_funT
      handle TYPE _ => bad_ap ();
    val maxidx_common = Term.maxidx_term ap maxidx_pure;
    fun no_unifier (T, U) = error ("Unable to infer common functor type from " ^
      commas (map show_typ [T, U]));
    (*unify T with the functor type at a fresh argument type variable;
      returns that variable and the extended type environment*)
    fun unify_ap_type T (tyenv, maxidx) =
      let
        val argT = TVar ((Name.aT, maxidx + 1), []);
        val T1' = Term_Subst.instantiateT (TVars.make [(tvar, argT)]) T1;
        val (tyenv', maxidx') = Sign.typ_unify thy (T1', T) (tyenv, maxidx + 1)
          handle Type.TUNIFY => no_unifier (T1', T);
      in (argT, (tyenv', maxidx')) end;
    val (ap_args, (ap_env, maxidx_env)) =
      fold_map unify_ap_type [T2, T3, T23] (Vartab.empty, maxidx_common);
    val [T2_arg, T3_arg, T23_arg] = map (Envir.norm_type ap_env) ap_args;
    val (tvar2, tvar3) = (dest_TVar T2_arg, dest_TVar T3_arg) handle TYPE _ => bad_ap ();
    val _ = if T23_arg = T2_arg --> T3_arg then () else bad_ap ();
    (*one common sort for all argument type variables; because ap also works
      at function types, the sort must be closed under the function type*)
    val sort = foldl1 (Sign.inter_sort thy) (map #2 [tvar, tvar2, tvar3]);
    val _ = Sign.of_sort thy (Term.aT sort --> Term.aT sort, sort) orelse
      error ("Sort constraint " ^ quote (Syntax.string_of_sort ctxt sort) ^
        " not closed under function types");
    fun update_sort (v, S) (tyenv, maxidx) =
      (Vartab.update_new (v, (S, TVar ((Name.aT, maxidx + 1), sort))) tyenv, maxidx + 1);
    val (common_env, _) = fold update_sort [tvar, tvar2, tvar3] (ap_env, maxidx_env);
    val tvar' = Envir.norm_type common_env (TVar tvar);
    val pure' = norm_term_types common_env pure;
    (*the functor type F 'a, i.e. the range of the normalized pure*)
    val af_T = range_type (fastype_of pure');
    val (tvar2', tvar3') = apply2 (Envir.norm_type common_env) (T2_arg, T3_arg);
    val ap' = norm_term_types common_env ap;

    fun bad_set set = error ("Bad type for set: " ^ show_typ (fastype_of set));
    (*a user-supplied set must map the functor type F 'a to 'a set*)
    fun mk_set set =
      let
        val tyenv = Sign.typ_match thy (domain_type (fastype_of set), af_T)
            Vartab.empty
          handle Type.TYPE_MATCH => bad_set set;
        val set' = Envir.subst_term_types tyenv set;
        val set_tvar = fastype_of set' |> range_type |> HOLogic.dest_setT |> dest_TVar
          handle TYPE _ => bad_set set;
        val _ = if Term.eq_tvar (dest_TVar tvar', set_tvar) then () else bad_set set;
      in ([tvar'], set') end;
    (*default set: %x :: F 'a. UNIV :: 'a set -- same type as a user set*)
    val set = (case raw_set of
        NONE => ([tvar'], Abs ("x", af_T, HOLogic.mk_UNIV tvar'))
      | SOME t => mk_set (closed_poly_term t));
    val terms = Term_Subst.zero_var_indexes (pack_poly_terms
      [poly_type_to_term ([tvar'], af_T),
        ([tvar'], pure'), ([tvar2', tvar3'], ap'), set]);

    fun bad_rel rel = error ("Bad type for rel: " ^ show_typ (fastype_of rel));
    (*rel must have type ('a => 'b => bool) => T1_af => T2_af => bool with
      distinct type variables 'a and 'b; dest_type (for the functor instance
      matched at T1_af) must yield NONE for both T1_af and T2_af, or exactly
      'a and 'b*)
    fun mk_rel rel =
      let
        val ((T1, T2), (T1_af, T2_af)) = fastype_of rel
          |> dest_funT
          |>> BNF_Util.dest_pred2T
          ||> BNF_Util.dest_pred2T;
        val _ = (dest_TVar T1; dest_TVar T2);
        val _ = if T1 = T2 then bad_rel rel else ();
        val af_inst = mk_afun_inst (match_poly_terms_type ctxt (terms, 0) (T1_af, maxidx_of_term rel));
        val (T1', T2') = apply2 (dest_type ctxt af_inst) (T1_af, T2_af);
        val _ = if (is_none T1' andalso is_none T2') orelse (T1' = SOME T1 andalso T2' = SOME T2)
          then () else bad_rel rel;
      in Term_Subst.zero_var_indexes (pack_poly_terms [([T1, T2], rel)]) end
      handle TYPE _ => bad_rel rel;
    val rel = Option.map (mk_rel o closed_poly_term) raw_rel;
  in (terms, rel) end;
(*Introduction rules derived from the relator axioms: pure_transfer turned
  into an applicable rule via rel_funD, followed by ap_rel_fun unchanged.*)
fun mk_rel_intros {pure_transfer, ap_rel_fun} =
  [pure_transfer RS @{thm rel_funD}, ap_rel_fun];
(*Assemble the theorem record of an applicative functor instance.
  hom_thm and ichng_thm are the homomorphism and interchange laws; reds is a
  table of combinator reductions, which must contain an entry for B because
  the lookup below fails otherwise; rel_axioms optionally holds the relator
  axioms.  Besides the inputs, the record contains
  - rel_intros: congruence rules for pure and ap (eq_intros), followed by the
    rules of mk_rel_intros when rel_axioms is present;
  - pure_comp_conv: the equation
      ap (pure g) (ap f x) = ap (ap (pure (%f x. g (f x))) f) x
    where ap stands for mk_comb af_inst at the appropriate type.*)
fun mk_afun_thms ctxt af_inst (hom_thm, ichng_thm, reds, rel_axioms) =
let
val pure_comp_conv =
let
(*fresh type variables of the inner sort; g is a plain function, while
  f and x are values of the functor type*)
val ([T1, T2, T3], ctxt_Ts) = mk_TFrees_of 3 (inner_sort_of af_inst) ctxt;
val (((g, f), x), ctxt') = ctxt_Ts
|> mk_Free "g" (T2 --> T3)
||>> mk_Free "f" (mk_type af_inst (T1 --> T2))
||>> mk_Free "x" (mk_type af_inst T1);
val comb = mk_comb af_inst;
val lhs = comb (T2 --> T3) (lift_term af_inst g, comb (T1 --> T2) (f, x));
(*g composed with its argument: %f x. g (f x), beta-reduced when g is an abstraction*)
val B_g = Abs ("f", T1 --> T2, Abs ("x", T1, Term.betapply (g, Bound 1 $ Bound 0)));
val rhs = comb (T1 --> T3)
(comb ((T1 --> T2) --> T1 --> T3) (lift_term af_inst B_g, f), x);
val prop = HOLogic.mk_eq (lhs, rhs) |> HOLogic.mk_Trueprop;
val merge_rule = mk_meta_eq hom_thm;
val B_intro = the (Symtab.lookup reds "B") |> mk_meta_eq |> Thm.symmetric;
(*rewrite with the reversed B reduction and the homomorphism law, then
  close the goal by reflexivity*)
fun tac goal_ctxt =
HEADGOAL (Raw_Simplifier.rewrite_goal_tac goal_ctxt [B_intro, merge_rule] THEN'
resolve_tac goal_ctxt [@{thm refl}]);
in
singleton (Variable.export ctxt' ctxt) (Goal.prove ctxt' [] [] prop (tac o #context))
end;
(*congruence rules for pure and ap: arg_cong instantiated with x, y and pure,
  and arg1_cong instantiated with f, g, ap and x' (presumably
  f = g ==> ap f x' = ap g x' -- TODO confirm against arg1_cong)*)
val eq_intros =
let
val ([T1, T2], ctxt_Ts) = mk_TFrees_of 2 (inner_sort_of af_inst) ctxt;
val T12 = mk_type af_inst (T1 --> T2);
val (((((x, y), x'), f), g), ctxt') = ctxt_Ts
|> mk_Free "x" T1
||>> mk_Free "y" T1
||>> mk_Free "x" (mk_type af_inst T1)
||>> mk_Free "f" T12
||>> mk_Free "g" T12;
val pure_fun = mk_pure af_inst T1;
val pure_cong = Drule.infer_instantiate' ctxt'
(map (SOME o Thm.cterm_of ctxt') [x, y, pure_fun]) @{thm arg_cong};
val ap_fun = mk_ap af_inst (T1, T2);
val ap_cong1 = Drule.infer_instantiate' ctxt'
(map (SOME o Thm.cterm_of ctxt') [f, g, ap_fun, x']) @{thm arg1_cong};
in Variable.export ctxt' ctxt [pure_cong, ap_cong1] end;
val rel_intros = case rel_axioms of
NONE => []
| SOME axioms => mk_rel_intros axioms;
in
{hom = hom_thm,
ichng = ichng_thm,
reds = reds,
rel_thms = rel_axioms,
rel_intros = eq_intros @ rel_intros,
pure_comp_conv = pure_comp_conv}
end;
(*Provide n type variables of sort S.  As many as possible (up to n) are
  taken from the front of the pool Ts; the remainder are created with
  mk_TFrees_of in ctxt.  Returns the selected types together with the updated
  context and the pool extended by the newly created types.*)
fun reuse_TFrees n S (ctxt, Ts) =
  let
    val reusable = Int.min (n, length Ts);
    val (fresh, ctxt') = mk_TFrees_of (n - reusable) S ctxt;
    val selected = take reusable Ts @ fresh;
    val pool = Ts @ fresh;
  in (selected, (ctxt', pool)) end;
(*Build the lifted reduction proposition of a combinator.
  thm is expected to be the combinator definition, a meta-equality lhs == rhs
  whose left-hand side applies the combinator to its arguments.  Its type
  variables are instantiated with type variables obtained from ctxt_Ts via
  reuse_TFrees (inner sort of af_inst).  Each argument whose 0-based position
  is NOT in lift_pos must be a schematic variable (dest_Var) and is replaced
  by a variable of the same name at the functor type; arguments at positions
  in lift_pos keep their type.  Both sides are transformed by
  subst_lift_term with that substitution, and the resulting HOL equation is
  universally quantified over the new argument list.
  Returns the proposition and the updated (context, types) pair.*)
fun mk_comb_prop lift_pos thm af_inst ctxt_Ts =
let
val base = Thm.prop_of thm;
val tvars = Term.add_tvars base [];
val (Ts, (ctxt', Ts')) = reuse_TFrees (length tvars) (inner_sort_of af_inst) ctxt_Ts;
val base' = base
|> Term_Subst.instantiate (TVars.make (tvars ~~ Ts), Vars.empty);
val (lhs, rhs) = Logic.dest_equals base';
val (_, lhs_args) = strip_comb lhs;
(*same variable, but with its type T replaced by the functor type over T*)
val lift_var = Var o apsnd (mk_type af_inst) o dest_Var;
(*collect quantified arguments (in reverse) and the substitution for the
  non-lift_pos arguments*)
val (lhs_args', subst) = fold_index (fn (i, v) =>
if member (op =) lift_pos i then apfst (cons v)
else map_prod (cons (lift_var v)) (cons (v, lift_var v))) lhs_args ([], []);
val (lhs', rhs') = apply2 (subst_lift_term af_inst subst) (lhs, rhs);
val lifted = (lhs', rhs') |> HOLogic.mk_eq |> HOLogic.mk_Trueprop;
in (fold Logic.all lhs_args' lifted, (ctxt', Ts')) end;
(*Homomorphism law as a proposition:
    !!f x. ap (pure f) (pure x) = pure (f x)
  where ap stands for mk_comb af_inst, f :: T1 => T2 and x :: T1 for two type
  variables of the inner sort obtained via reuse_TFrees.  Returns the
  proposition with the updated (context, types) pair.*)
fun mk_homomorphism_prop af_inst ctxt_Ts =
  let
    val ([T1, T2], (ctxt', Ts')) = reuse_TFrees 2 (inner_sort_of af_inst) ctxt_Ts;
    val (f, ctxt'') = mk_Free "f" (T1 --> T2) ctxt';
    val (x, _) = mk_Free "x" T1 ctxt'';
    val lift = lift_term af_inst;
    val applied = mk_comb af_inst (T1 --> T2) (lift f, lift x);
    val eq = HOLogic.mk_eq (applied, lift (f $ x));
  in (fold_rev Logic.all [f, x] (HOLogic.mk_Trueprop eq), (ctxt', Ts')) end;
(*Interchange law as a proposition:
    !!f x. ap f (pure x) = ap (pure (%g. g x)) f
  where ap stands for mk_comb af_inst, f is a value of the functor type over
  T1 => T2 and x :: T1.  Returns it with the updated (context, types) pair.*)
fun mk_interchange_prop af_inst ctxt_Ts =
  let
    val ([T1, T2], (ctxt', Ts')) = reuse_TFrees 2 (inner_sort_of af_inst) ctxt_Ts;
    val fT = T1 --> T2;
    val (f, ctxt'') = mk_Free "f" (mk_type af_inst fT) ctxt';
    val (x, _) = mk_Free "x" T1 ctxt'';
    val apply_to_x = lift_term af_inst (Abs ("f", fT, Bound 0 $ x));
    val left = mk_comb af_inst fT (f, lift_term af_inst x);
    val right = mk_comb af_inst (fT --> T2) (apply_to_x, f);
  in
    (fold_rev Logic.all [f, x] (HOLogic.mk_Trueprop (HOLogic.mk_eq (left, right))),
      (ctxt', Ts'))
  end;
(*Proof obligations for the relator rel_inst of the functor af_inst.
  Two propositions are produced (ap stands for mk_comb af_inst, af_rel for
  mk_af_rel below, set for mk_set af_inst T1):
  - pure_prop:  !!R. rel_fun R (af_rel R) pure pure
  - ap_prop:    !!R f g x. af_rel (rel_fun (mk_eq_on (set x)) R) f g ==>
                  af_rel R (ap f x) (ap g x)
  Type variables are obtained via reuse_TFrees; the updated (context, types)
  pair is returned alongside.*)
fun mk_rel_props (af_inst, rel_inst) ctxt_Ts =
let
(*instantiate the polymorphic relator at the argument types of the
  predicate tm and apply it to tm*)
fun mk_af_rel tm =
let val (T1, T2) = BNF_Util.dest_pred2T (fastype_of tm);
in betapply (instantiate_poly_term rel_inst [T1, T2], tm) end;
val ([T1, T2, T3], (ctxt', Ts')) = reuse_TFrees 3 (inner_sort_of af_inst) ctxt_Ts;
val (pure_R, _) = mk_Free "R" (T1 --> T2 --> @{typ bool}) ctxt';
val rel_pure = BNF_Util.mk_rel_fun pure_R (mk_af_rel pure_R) $ mk_pure af_inst T1 $
mk_pure af_inst T2;
val pure_prop = Logic.all pure_R (HOLogic.mk_Trueprop rel_pure);
(*variables for ap_prop are drawn from ctxt' again; they live in a separate
  proposition from pure_R*)
val ((((f, g), x), ap_R), _) = ctxt'
|> mk_Free "f" (mk_type af_inst (T1 --> T2))
||>> mk_Free "g" (mk_type af_inst (T1 --> T3))
||>> mk_Free "x" (mk_type af_inst T1)
||>> mk_Free "R" (T2 --> T3 --> @{typ bool});
val fun_rel = BNF_Util.mk_rel_fun (mk_eq_on (mk_set af_inst T1 $ x)) ap_R;
val rel_ap = Logic.mk_implies (HOLogic.mk_Trueprop (mk_af_rel fun_rel $ f $ g),
HOLogic.mk_Trueprop (mk_af_rel ap_R $ mk_comb af_inst (T1 --> T2) (f, x) $
mk_comb af_inst (T1 --> T3) (g, x)));
val ap_prop = fold_rev Logic.all [ap_R, f, g, x] rel_ap;
in ([pure_prop, ap_prop], (ctxt', Ts')) end;
(*Derive the interchange law from the reduction theorem of combinator T.
  Both comb_defs and reds must contain T (the lookups fail otherwise).  The
  lifted T proposition with argument position 0 kept at its base type (see
  mk_comb_prop) is proved by rewriting with the reversed merge_thm and
  resolving with T_red.  The exported result is rewritten with comb_unfolds
  and flipped with sym.*)
fun mk_interchange ctxt ((comb_defs, _), comb_unfolds) (af_inst, merge_thm) reds =
let
val T_def = the (Symtab.lookup comb_defs "T");
val T_red = the (Symtab.lookup reds "T");
val (weak_prop, (ctxt', _)) = mk_comb_prop [0] T_def af_inst (ctxt, []);
fun tac goal_ctxt =
HEADGOAL (Raw_Simplifier.rewrite_goal_tac goal_ctxt [Thm.symmetric merge_thm] THEN'
resolve_tac goal_ctxt [T_red]);
val weak_red =
singleton (Variable.export ctxt' ctxt) (Goal.prove ctxt' [] [] weak_prop (tac o #context));
in Raw_Simplifier.rewrite_rule ctxt (comb_unfolds) weak_red RS sym end;
(*Derive additional (weak) reduction theorems with a temporary functor.
  An unnamed afun is built from the given laws, with the reductions in reds
  rewritten by comb_unfolds (mk_afun_thms looks up B in them, so reds is
  expected to contain B).  With it, the lifted propositions of combinator C
  are proved twice, with lift_pos [1] and [2] respectively (see
  mk_comb_prop), each by normalize_wrapper_tac, rewriting with comb_unfolds,
  and reflexivity.  The third result is uncurry_pair lifted by
  forward_lift_rule.  All three are returned as meta-equalities.*)
fun mk_weak_reds ctxt ((comb_defs, _), comb_unfolds) af_inst (hom_thm, ichng_thm, reds) =
let
val unfolded_reds =
Symtab.map (K (Raw_Simplifier.rewrite_rule ctxt comb_unfolds)) reds;
val af_thms = mk_afun_thms ctxt af_inst (hom_thm, ichng_thm, unfolded_reds, NONE);
val af = mk_afun Binding.empty (pack_afun_inst af_inst) NONE af_thms;
fun tac goal_ctxt =
HEADGOAL (normalize_wrapper_tac goal_ctxt (SOME af) THEN'
Raw_Simplifier.rewrite_goal_tac goal_ctxt comb_unfolds THEN'
resolve_tac goal_ctxt [refl]);
(*prove the lifted reduction of comb for the given lift positions*)
fun mk comb lift_pos =
let
val def = the (Symtab.lookup comb_defs comb);
val (prop, (ctxt', _)) = mk_comb_prop lift_pos def af_inst (ctxt, []);
val hol_thm =
singleton (Variable.export ctxt' ctxt) (Goal.prove ctxt' [] [] prop (tac o #context));
in mk_meta_eq hol_thm end;
val uncurry_thm = mk_meta_eq (forward_lift_rule ctxt af @{thm uncurry_pair});
in
[mk "C" [1], mk "C" [2], uncurry_thm]
end;
(*Compute the final table of combinator reductions for a new functor.
  user_combs and user_thms are the proved user combinator reductions, paired
  positionally.  Their closure under the combinator rules is computed first.
  If ichng_thms is empty, the interchange law is derived from that closure
  via mk_interchange; otherwise its single element is used (more than one
  element is not handled).  Weak reductions are then derived, the closure is
  recomputed including them, and every theorem in the result is rewritten
  with comb_unfolds.  Returns the table and the interchange theorem.*)
fun mk_comb_reds ctxt combss af_inst user_combs (hom_thm, user_thms, ichng_thms) =
let
val ((comb_defs, comb_rules), comb_unfolds) = combss;
val merge_thm = mk_meta_eq hom_thm;
val user_reds = Symtab.make (user_combs ~~ user_thms);
val reds0 = combinator_red_closure ctxt (comb_defs, comb_rules) (af_inst, merge_thm) [] user_reds;
(*NOTE(review): the export below goes from ctxt to ctxt itself; confirm
  whether it has any effect*)
val ichng_thm = case ichng_thms of
[] => singleton (Variable.export ctxt ctxt) (mk_interchange ctxt combss (af_inst, merge_thm) reds0)
| [thm] => thm;
val weak_reds = mk_weak_reds ctxt combss af_inst (hom_thm, ichng_thm, reds0);
val reds1 = combinator_red_closure ctxt (comb_defs, comb_rules) (af_inst, merge_thm) weak_reds reds0;
val unfold = Raw_Simplifier.rewrite_rule ctxt comb_unfolds;
in (Symtab.map (K unfold) reds1, ichng_thm) end;
(*Note the theorems of af in the local theory, each qualified by the base
  name of af: homomorphism, interchange, afun_rel_intros, one pure_C_conv
  per combinator C in the reduction table, and, if relator theorems exist,
  pure_transfer and ap_rel_fun_cong.*)
fun note_afun_thms af =
  let
    val thms = thms_of_afun af;
    val qualifier = Binding.name_of (name_of_afun af);
    fun qualified_note (suffix, facts) =
      ((Binding.qualify true qualifier (Binding.name suffix), []), [(facts, [])]);
    val basic_facts =
      [("homomorphism", [#hom thms]),
       ("interchange", [#ichng thms]),
       ("afun_rel_intros", #rel_intros thms)];
    val red_facts =
      Symtab.dest (#reds thms)
      |> map (fn (comb, red) => ("pure_" ^ comb ^ "_conv", [red]));
    val rel_facts =
      (case #rel_thms thms of
        SOME axioms =>
          [("pure_transfer", [#pure_transfer axioms]),
           ("ap_rel_fun_cong", [#ap_rel_fun axioms])]
      | NONE => []);
    val all_notes = map qualified_note (basic_facts @ red_facts @ rel_facts);
  in Local_Theory.notes all_notes #> #2 end;
(*Declare af in the local theory.  Under the declaration morphism phi, the
  transformed afun is defined in the afun name space and its patterns are
  added to the pattern net under the resulting name; the combinators field
  is kept as is.*)
fun register_afun af =
  let
    fun add_to_data phi context =
      let
        fun update {combinators, afuns, patterns} =
          let
            val af_morphed = morph_afun phi af;
            val (registered_name, afuns') =
              Name_Space.define context true (name_of_afun af_morphed, af_morphed) afuns;
            val patterns' =
              Item_Net.update (registered_name, patterns_of_afun af_morphed) patterns;
          in {combinators = combinators, afuns = afuns', patterns = patterns'} end;
      in Data.map update context end;
  in Local_Theory.declaration {syntax = false, pervasive = false, pos = ⌂} add_to_data end;
(*Implementation of the applicative command.
  Checks the combinator flags and the raw terms (via mk_terms), then chooses
  the combinators whose reductions the user must prove: the given flags plus
  B and I unless they are derivable from the flags.  Redundant flags only
  produce a warning.  The proof obligations are, in this order: the
  homomorphism law, one lifted reduction per user combinator, the interchange
  law if T is not derivable, and the relator properties if rel is given.
  Goals are first rewritten with the combinator_repr theorems.  After the
  proof, all derived theorems are computed, exported, and the new functor is
  registered and its theorems noted.  Returns the resulting proof state.*)
fun applicative_cmd (((((name, flags), raw_pure), raw_ap), raw_rel), raw_set) lthy =
let
val comb_unfolds = Named_Theorems.get lthy @{named_theorems combinator_unfold};
val comb_reprs = Named_Theorems.get lthy @{named_theorems combinator_repr};
val (comb_defs, comb_rules) = get_combinators (Context.Proof lthy);
(*every flag must name a known combinator*)
val _ = fold (fn name =>
if Symtab.defined comb_defs name then I else error ("Unknown combinator " ^ quote name))
flags ();
val _ = if has_duplicates op = flags
then warning "Ignoring duplicate combinators"
else ();
val user_combs0 = Ord_List.make fast_string_ord flags;
val raw_pure' = Syntax.read_term lthy raw_pure;
val raw_ap' = Syntax.read_term lthy raw_ap;
val raw_rel' = Option.map (Syntax.read_term lthy) raw_rel;
val raw_set' = Option.map (Syntax.read_term lthy) raw_set;
val (terms, rel) = mk_terms lthy (raw_pure', raw_ap', raw_rel', raw_set');
(*B and I are always required; add those not derivable from the flags*)
val derived_combs0 = combinator_closure comb_rules false user_combs0;
val required_combs = Ord_List.make fast_string_ord ["B", "I"];
val user_combs = Ord_List.union fast_string_ord user_combs0
(Ord_List.subtract fast_string_ord derived_combs0 required_combs);
val derived_combs1 = combinator_closure comb_rules false user_combs;
val derived_combs2 = combinator_closure comb_rules true derived_combs1;
(*a combinator is redundant if dropping it does not shrink the closure*)
fun is_redundant comb = eq_list (op =) (derived_combs2,
(combinator_closure comb_rules true (Ord_List.remove fast_string_ord comb user_combs)));
val redundant_combs = filter is_redundant user_combs;
val _ = if null redundant_combs then () else
warning ("Redundant combinators: " ^ commas redundant_combs);
(*the interchange law is a separate obligation only if T is not derivable*)
val prove_interchange = not (Ord_List.member fast_string_ord derived_combs1 "T");
val (af_inst, ctxt_af) = import_afun_inst_raw terms lthy;
(*for a relator, import it and match the functor at the type of its second
  argument*)
val (rel_insts, ctxt_inst) = (case rel of
NONE => (NONE, ctxt_af)
| SOME r =>
let
val (rel_inst, ctxt') = import_poly_terms r ctxt_af |>> the_single;
val T = fastype_of (#2 rel_inst) |> range_type |> domain_type;
val af_inst = match_poly_terms_type ctxt' (terms, 0) (T, ~1) |> mk_afun_inst;
in (SOME (af_inst, rel_inst), ctxt') end);
(*goal groups: homomorphism, user combinators, interchange, relator; the
  order must match the pattern in after_qed*)
val mk_propss = [apfst single o mk_homomorphism_prop af_inst,
fold_map (fn comb => mk_comb_prop [] (the (Symtab.lookup comb_defs comb)) af_inst) user_combs,
if prove_interchange then apfst single o mk_interchange_prop af_inst else pair [],
if is_some rel then mk_rel_props (the rel_insts) else pair []];
val (propss, (ctxt_Ts, _)) = fold_map I mk_propss (ctxt_inst, []);
fun repr_tac ctxt = Raw_Simplifier.rewrite_goals_tac ctxt comb_reprs;
(*NOTE(review): the result extends the outer lthy rather than lthy';
  presumably intended to drop the auxiliary fixes of ctxt_Ts -- confirm*)
fun after_qed thmss lthy' =
let
val [[hom_thm], user_thms, ichng_thms, rel_thms] = map (Variable.export lthy' ctxt_inst) thmss;
val (reds, ichng_thm) = mk_comb_reds ctxt_inst ((comb_defs, comb_rules), comb_unfolds)
af_inst user_combs (hom_thm, user_thms, ichng_thms);
val rel_axioms = case rel_thms of
[] => NONE
| [thm1, thm2] => SOME {pure_transfer = thm1, ap_rel_fun = thm2};
val af_thms = mk_afun_thms ctxt_inst af_inst (hom_thm, ichng_thm, reds, rel_axioms);
val af_thms = map_afun_thms (singleton (Variable.export ctxt_inst lthy)) af_thms;
val af = mk_afun name terms rel af_thms;
in lthy
|> register_afun af
|> note_afun_thms af
end;
in
Proof.theorem NONE after_qed ((map o map) (rpair []) propss) ctxt_Ts
|> Proof.refine (Method.Basic (SIMPLE_METHOD o repr_tac))
|> Seq.the_result ""
end;
(*Print all registered applicative functors, sorted by name.  Each entry shows
  the functor type and its argument type variable, pure, ap, set, the relator
  if present, and the names of the combinators in the reduction table.*)
fun print_afuns ctxt =
let
fun pretty_afun (name, af) =
let
val [pT, (_, pure), (_, ap), (_, set)] = unpack_poly_terms (terms_of_afun af);
val ([tvar], T) = poly_type_of_term pT;
val rel = Option.map (#2 o the_single o unpack_poly_terms) (rel_of_afun af);
val combinators = Symtab.keys (#reds (thms_of_afun af));
in Pretty.block (Pretty.fbreaks ([Pretty.block [Pretty.str name, Pretty.str ":", Pretty.brk 1,
Pretty.quote (Syntax.pretty_typ ctxt T), Pretty.brk 1, Pretty.str "of", Pretty.brk 1,
Syntax.pretty_typ ctxt tvar],
Pretty.block [Pretty.str "pure:", Pretty.brk 1, Pretty.quote (Syntax.pretty_term ctxt pure)],
Pretty.block [Pretty.str "ap:", Pretty.brk 1, Pretty.quote (Syntax.pretty_term ctxt ap)],
Pretty.block [Pretty.str "set:", Pretty.brk 1, Pretty.quote (Syntax.pretty_term ctxt set)]] @
(case rel of
NONE => []
| SOME rel' => [Pretty.block [Pretty.str "rel:", Pretty.brk 1,
Pretty.quote (Syntax.pretty_term ctxt rel')]]) @
[Pretty.block ([Pretty.str "combinators:", Pretty.brk 1] @
Pretty.commas (map Pretty.str combinators))])) end;
val afuns = sort_by #1 (Name_Space.fold_table cons (get_afun_table (Context.Proof ctxt)) []);
in Pretty.writeln (Pretty.big_list "Registered applicative functors:" (map pretty_afun afuns)) end;
(*Register an unfolding theorem with one or more applicative functors.
  thm must be a HOL equation lhs = rhs.  If name is given, it is interned and
  used as the target functor.  Otherwise the targets are found by match_typ
  on the type of lhs, and it is an error if there are none.  Each target is
  looked up with afun_of_generic before anything is changed (presumably
  raising an error for unknown functors).  The meta-equality form of thm is
  then added to every target.*)
fun add_unfold_thm name thm context =
  let
    val (lhs, _) = Thm.prop_of thm |> HOLogic.dest_Trueprop |> HOLogic.dest_eq
      handle TERM _ => error "Not an equation";
    (*nested case is parenthesized so later clauses cannot attach to it*)
    val names =
      (case name of
        SOME n => [intern context n]
      | NONE =>
          (case match_typ context (Term.fastype_of lhs) of
            ns as (_ :: _) => ns
          | [] => error "Unable to determine applicative functor instance"));
    (*validate all targets up front; the looked-up functors are not needed*)
    val _ = List.app (ignore o afun_of_generic context) names;
    val thm' = mk_meta_eq thm;
  in fold (fn n => update_afun n (add_unfolds_afun [thm'])) names context end;
(*Attribute that registers the annotated theorem via add_unfold_thm, for the
  functor given by name or, if NONE, the one inferred from the theorem.*)
fun add_unfold_attrib name =
  Thm.declaration_attribute (fn thm => add_unfold_thm name thm);
end;