File ‹lazy_case.ML›
(* Interface of the lazy-case tool: defines a "case_lazy" constant for a
   datatype and registers the proved equations with the code generator. *)
signature LAZY_CASE = sig
(* Lazify the type described by the given ctr_sugar record.  Returns the
   local theory unchanged when there is nothing to do or when the type has
   already been processed (a warning is printed in the latter case). *)
val lazify: Ctr_Sugar.ctr_sugar -> local_theory -> local_theory
(* Look up the ctr_sugar record for the head type constructor of the given
   type and pass it to lazify. *)
val lazify_typ: typ -> local_theory -> local_theory
(* Read a type name from its source string and pass the result to lazify_typ. *)
val lazify_cmd: string -> local_theory -> local_theory
(* Name of the plugin declared under the binding lazy_case. *)
val lazy_case_plugin: string
(* Theory setup: registers lazification as a ctr_sugar interpretation under
   lazy_case_plugin. *)
val setup: theory -> theory
end
structure Lazy_Case : LAZY_CASE = struct
(* Context data: the set of type names for which lazify has already been run. *)
structure Data = Generic_Data
(
  type T = Symtab.set
  val empty : T = Symtab.empty
  (* Values are units, so entries with the same key always agree. *)
  fun merge (names1, names2) : T = Symtab.merge (op =) (names1, names2)
)
(* All elements of a list except the last one.
   Fails with error "empty list" when applied to the empty list. *)
fun init xs =
  (case xs of
    [] => error "empty list"
  | [_] => []
  | y :: ys => y :: init ys)
(* Define a constant "case_lazy" (under the mandatory path of the type name)
   from the case combinator casex of the given ctr_sugar record, prove
   equations about it, and note them in the resulting local theory:
     - one equation per constructor, with attribute [code];
     - an equation rewriting the original case combinator into case_lazy,
       with attribute [code_unfold];
     - a congruence rule, with attribute [fundef_cong].
   Finally the type name is recorded in Data.

   The local theory is returned unchanged when casex is a dummy pattern, when
   all argument types of casex except the last are already function types, or
   when the type name is already recorded in Data (in that last case a warning
   is printed first). *)
fun lazify {T, casex, ctrs, case_thms, case_cong, ...} lthy =
let
(* is_fun typ holds iff dest_funT succeeds on typ *)
val is_fun = can dest_funT
val typ_name = fst (dest_Type T)
val len = length ctrs
val idxs = 0 upto len - 1
(* name and type of the case combinator, the type passed through
   Logic.unvarifyT_global *)
val (name, typ) = dest_Const casex ||> Logic.unvarifyT_global
(* argument types of the case combinator; its result type is discarded *)
val (typs, _) = strip_type typ
(* has typ_name already been recorded in Data (see upd below)? *)
val exists = Symtab.defined (Data.get (Context.Proof lthy)) typ_name
val warn = Pretty.separate "" [Syntax.pretty_typ lthy T, Pretty.str "already lazified"]
|> Pretty.block
val _ = if exists then warning (Pretty.string_of warn) else ()
in
(* init typs drops the last argument type -- presumably the type of the
   scrutinee, leaving one type per case branch (TODO confirm) *)
if Term.is_dummy_pattern casex orelse forall is_fun (init typs) orelse exists then
lthy
else
let
val arg_typs = init typs
(* A branch whose type is already a function type is passed through as the
   bound variable itself; any other branch gets a new parameter of type
   unit => typ, which is applied to () before being handed to casex. *)
fun apply_to_unit typ idx =
if is_fun typ then
(typ, Bound idx)
else
(@{typ unit} --> typ, Bound idx $ @{term "()"})
(* idxs are reversed here; presumably because Bound indices count from the
   innermost of the abstractions added by fold_rev Term.abs below *)
val (arg_typs', args) = split_list (map2 apply_to_unit arg_typs (rev idxs))
(* right-hand side of the definition: abstractions over the new parameters
   around casex applied to args (no scrutinee argument is supplied) *)
val def = list_comb (Const (name, typ), args)
|> fold_rev Term.abs (map (pair "c") arg_typs')
val binding = Binding.name "case_lazy"
(* NB: this pattern rebinds lthy, shadowing the function parameter.  From here
   on, lthy is the nested context in which the definition was made, and lthy'
   is the result of applying Local_Theory.end_nested to it. *)
val ((term, thm), (lthy', lthy)) =
(snd o Local_Theory.begin_nested) lthy
|> Proof_Context.concealed
|> Local_Theory.map_background_naming (Name_Space.mandatory_path typ_name)
|> Local_Theory.define ((binding, NoSyn), ((Thm.def_binding binding, []), def))
|>> apsnd snd
||> `Local_Theory.end_nested
(* transport the defined term and its defining theorem from the nested
   context to lthy' *)
val phi = Proof_Context.export_morphism lthy lthy'
val thm' = Morphism.thm phi thm
val term' = Logic.unvarify_global (Morphism.term phi term)
(* tactic for the per-constructor equations: unfold the definition and the
   idx-th case theorem, then close the goal with refl *)
fun defs_tac ctxt idx =
Local_Defs.unfold_tac ctxt [thm', nth case_thms idx] THEN
HEADGOAL (resolve_tac ctxt @{thms refl})
(* argument types of term' except the last one; used to type the free
   variables f0... and g0... below *)
val frees = fastype_of term' |> strip_type |> fst |> init
val frees_f = Name.invent_names_global "f0" frees
val frees_g = Name.invent_names_global "g0" frees
val fs = map Free frees_f
val gs = map Free frees_g
(* Goal for constructor ctr at position idx:
     term' fs (ctr args) = (nth fs idx) args
   where for a constructor without arguments the right-hand side is
   (nth fs idx) ().  Terms are built with dummy types and completed by
   Syntax.check_term.  Returns the names of the constructor argument
   variables together with the goal. *)
fun mk_def_goal ctr idx =
let
val (name, len) = dest_Const ctr ||> strip_type ||> fst ||> length
val frees = Name.invent_global "p0" len
val args = map (Free o rpair dummyT) frees
val ctr_val = list_comb (Const (name, dummyT), args)
val lhs = list_comb (term', fs) $ ctr_val
val rhs =
if len = 0 then
nth fs idx $ @{term "()"}
else
list_comb (nth fs idx, args)
in
(frees, HOLogic.mk_Trueprop (Syntax.check_term lthy' (HOLogic.mk_eq (lhs, rhs))))
end
(* prove one per-constructor goal, fixing the f-variables and the
   constructor argument variables of that goal *)
fun prove_defs (frees', goal) idx =
Goal.prove_future lthy' (map fst frees_f @ frees') [] goal (fn {context, ...} =>
defs_tac context idx)
val def_thms = map2 prove_defs (map2 mk_def_goal ctrs idxs) idxs
(* NB: rebinds frees (shadowing the binding above), now as variables q0...
   typed with the branch types of the original case combinator *)
val frees = Name.invent_names_global "q0" arg_typs
(* Goal relating the original case combinator to the new constant:
     casex frees = term' frees'
   where each non-function variable is wrapped in an abstraction over a
   unit-typed argument on the right-hand side. *)
val unfold_goal =
let
val lhs = list_comb (Const (name, typ), map Free frees)
fun mk_abs (name, typ) =
if is_fun typ then
Free (name, typ)
else
Abs ("u", @{typ unit}, Free (name, typ))
val rhs = list_comb (Const (fst (dest_Const term'), dummyT), map mk_abs frees)
in
HOLogic.mk_Trueprop (Syntax.check_term lthy' (HOLogic.mk_eq (lhs, rhs)))
end
(* tactic for unfold_goal: unfold the definition, then close with refl *)
fun unfold_tac ctxt =
Local_Defs.unfold_tac ctxt [thm'] THEN
HEADGOAL (resolve_tac ctxt @{thms refl})
val unfold_thm =
Goal.prove_future lthy' (map fst frees) [] unfold_goal (fn {context, ...} =>
unfold_tac context)
(* Congruence premise for constructor ctr and branch functions (f, g):
     !!args. t = ctr args ==> f args = g args
   built with de Bruijn indices and dummy types.  For a constructor without
   arguments a single bound variable is still introduced (args' = [Bound 0])
   -- presumably standing for the unit argument of the branch (TODO confirm). *)
fun mk_cong_prem t ctr (f, g) =
let
fun mk_all t = Logic.all_const dummyT $ Abs ("", dummyT, t)
val len = dest_Const ctr |> snd |> strip_type |> fst |> length
val args = map Bound (len - 1 downto 0)
val ctr_val = list_comb (Logic.unvarify_global ctr, args)
val args' = if len = 0 then [Bound 0] else args
val lhs = list_comb (f, args')
val rhs = list_comb (g, args')
val concl = HOLogic.mk_Trueprop (HOLogic.mk_eq (lhs, rhs))
val prem = HOLogic.mk_Trueprop (HOLogic.mk_eq (t, ctr_val))
in fold (K mk_all) args' (Logic.mk_implies (prem, concl)) end
(* Congruence goal:
     t1 = t2 ==> (one premise per constructor, stated for t2) ==>
       term' fs t1 = term' gs t2 *)
val cong_goal =
let
val t1 = Free ("t1", Logic.unvarifyT_global T)
val t2 = Free ("t2", Logic.unvarifyT_global T)
val prems =
HOLogic.mk_Trueprop (HOLogic.mk_eq (t1, t2)) ::
map2 (mk_cong_prem t2) ctrs (fs ~~ gs)
val lhs = list_comb (term', fs) $ t1
val rhs = list_comb (term', gs) $ t2
val concl = HOLogic.mk_Trueprop (HOLogic.mk_eq (lhs, rhs))
in
Logic.list_implies (prems, concl)
|> Syntax.check_term lthy'
end
(* tactic for cong_goal: unfold the definition, apply case_cong as an
   elimination rule, then in every remaining subgoal resolve twice with the
   premises of the focused subgoal *)
fun cong_tac ctxt =
Local_Defs.unfold_tac ctxt [thm'] THEN
HEADGOAL (eresolve_tac ctxt [case_cong]) THEN
ALLGOALS (ctxt |> Subgoal.FOCUS
(fn {context = ctxt, prems, ...} =>
HEADGOAL (resolve_tac ctxt prems THEN' resolve_tac ctxt prems)))
val cong_thm =
Goal.prove_future lthy' ("t1" :: "t2" :: map fst frees_f @ map fst frees_g) [] cong_goal
(fn {context, ...} => cong_tac context)
(* record typ_name in Data; its absence was checked via exists above *)
val upd = Data.map (Symtab.update_new (typ_name, ()))
in
lthy'
|> Local_Theory.note ((Binding.empty, @{attributes [code]}), def_thms) |> snd
|> Local_Theory.note ((Binding.empty, @{attributes [code_unfold]}), [unfold_thm]) |> snd
|> Local_Theory.note ((Binding.empty, @{attributes [fundef_cong]}), [cong_thm]) |> snd
|> Local_Theory.declaration {syntax = false, pervasive = true, pos = ⌂} (K upd)
end
end
(* Lazify the datatype whose head type constructor is that of typ.

   The ctr_sugar record is looked up by the name of the type constructor.
   dest_Type fails as before if typ is not a type constructor application.
   If no ctr_sugar record is found for the name, an error naming the type is
   raised instead of the bare Option exception that applying "the" to NONE
   would produce. *)
fun lazify_typ typ lthy =
  let
    val typ_name = fst (dest_Type typ)
  in
    (case Ctr_Sugar.ctr_sugar_of lthy typ_name of
      SOME ctr_sugar => lazify ctr_sugar lthy
    | NONE =>
        error ("Cannot lazify " ^ quote typ_name ^
          ": no constructor sugar found for this type"))
  end
(* Read the type name given as source string s in lthy and lazify the
   resulting type. *)
fun lazify_cmd s lthy =
  let
    val parsed_typ =
      Proof_Context.read_type_name {proper = true, strict = false} lthy s
  in
    lazify_typ parsed_typ lthy
  end
(* Plugin name obtained by declaring the binding lazy_case. *)
val lazy_case_plugin =
  @{binding lazy_case} |> Plugin_Name.declare_setup
(* Outer syntax command "lazify": takes one or more type names and applies
   lazify_cmd to each of them in turn. *)
val _ =
  let
    val parse_type_names = Scan.repeat1 Parse.embedded_inner_syntax
  in
    Outer_Syntax.local_theory
      @{command_keyword "lazify"}
      "defines a lazy case constant and sets up the code generator"
      (parse_type_names >> (fn type_names => fold lazify_cmd type_names))
  end
(* Theory setup: for every ctr_sugar record handed to the interpretation under
   lazy_case_plugin, lazify the type stored in its T field. *)
val setup =
  Ctr_Sugar.ctr_sugar_interpretation lazy_case_plugin
    (fn (sugar: Ctr_Sugar.ctr_sugar) => lazify_typ (#T sugar))
end