File ‹simps_case_conv.ML›
(* Interface for converting between a list of equations and a single
   case-expression equation. *)
signature SIMPS_CASE_CONV =
sig
(* to_case ctxt ths: turn the list of equations ths into a single theorem;
   fails with an error if the conversion does not succeed. *)
val to_case: Proof.context -> thm list -> thm
(* gen_to_simps ctxt splitthms thm: repeatedly apply the given split rules
   to thm and return all resulting theorems. *)
val gen_to_simps: Proof.context -> thm list -> thm -> thm list
(* to_simps ctxt thm: like gen_to_simps, but the split rules are looked up
   (via Ctr_Sugar) for the type constructors occurring in the type of the
   head of the left-hand side of thm. *)
val to_simps: Proof.context -> thm -> thm list
end
structure Simps_Case_Conv: SIMPS_CASE_CONV =
struct
(* Collect the names of all type constructors occurring in a type, in
   prefix order and with repetitions; type variables contribute nothing. *)
fun collect_Tcons T =
  (case T of
    Type (tyco, args) => tyco :: maps collect_Tcons args
  | TFree _ => []
  | TVar _ => [])
(* For a list of types, return the Ctr_Sugar records of every distinct type
   constructor occurring in them; constructors without a registered record
   are silently dropped. *)
fun get_type_infos ctxt Ts =
  let
    val tycos = distinct (op =) (maps collect_Tcons Ts)
  in
    map_filter (Ctr_Sugar.ctr_sugar_of ctxt) tycos
  end
(* The split rules of all Ctr_Sugar-registered type constructors occurring
   in the given list of types. *)
fun get_split_ths ctxt Ts =
  map (fn info => #split info) (get_type_infos ctxt Ts)
(* Destruct a theorem of the form "Trueprop (lhs = rhs)" into (lhs, rhs);
   the HOLogic destructors fail on any other shape. *)
fun strip_eq thm =
  HOLogic.dest_eq (HOLogic.dest_Trueprop (Thm.prop_of thm))
(* to_case ctxt ths: convert the list of equations ths into a single theorem
   by delegating to Case_Converter.to_case (keeping constructor context).

   Fails with an error if
     - a pattern mentions a constant whose result type constructor has no
       Ctr_Sugar record ("Pattern match on non-constructor constant ..."), or
     - the converter yields no theorem at all, i.e. NONE or an empty list
       ("Conversion to case expression failed."). *)
fun to_case ctxt ths =
  let
    (* Number of constructors of the datatype that the constant ctr (of
       type T) constructs; the datatype is read off the result type of T. *)
    fun ctr_count (ctr, T) =
      let
        val tyco = body_type T |> dest_Type |> fst
      in
        (case Ctr_Sugar.ctr_sugar_of ctxt tyco of
          SOME info => length (#ctrs info)
        | NONE => error ("Pattern match on non-constructor constant " ^ ctr))
      end
  in
    (* Match on the list shape instead of using hd, so that an empty result
       is reported as a conversion failure rather than a bare List.Empty. *)
    (case Case_Converter.to_case ctxt Case_Converter.keep_constructor_context ctr_count ths of
      SOME (thm :: _) => thm
    | _ => error ("Conversion to case expression failed."))
  end
local
(* Check whether the proposition t has the shape produced by a split rule:
   a conjunction whose conjuncts are (possibly universally quantified)
   implications with a premise of the form "x = ..." where x is a free
   variable. Any term that cannot be destructed this way yields false. *)
fun was_split t =
  let
    fun is_free_eq_imp conj =
      let
        val body = Term.strip_qnt_body \<^const_name>‹All› conj
        val (prem, _) = HOLogic.dest_imp body
        val (lhs, _) = HOLogic.dest_eq prem
      in
        is_Free lhs
      end
    val conjs = HOLogic.dest_conj (HOLogic.dest_Trueprop t)
  in
    forall is_free_eq_imp conjs
  end
  handle TERM _ => false
(* Resolve thm (with its schematic variables fixed via Variable.import)
   against the rule split, keep only those results whose proposition passes
   was_split, export them back to ctxt and return them as a sequence. *)
fun apply_split ctxt split thm =
  let
    val ((_, imported), ctxt') = Variable.import false [thm] ctxt
    val resolvents = imported RL [split]
    val split_results = filter (fn th => was_split (Thm.prop_of th)) resolvents
  in
    Seq.of_list (Variable.export ctxt' ctxt split_results)
  end
(* Forward resolution on a single theorem: the sequence of all results of
   resolving thm into the given rules (via RL). *)
fun forward_tac rules thm =
  let
    val resolvents = [thm] RL rules
  in
    Seq.of_list resolvents
  end
(* mp with refl resolved into its second premise; with HOL's standard
   mp and refl this presumably reads "(t = t) --> Q ==> Q" - confirm. *)
val refl_imp = refl RSN (2, mp)

(* Post-process a split theorem: take conjuncts apart, strip universal
   quantifiers with spec, then discharge the equation premise via
   refl_imp. *)
val get_rules_once_split =
  let
    val take_conjuncts = REPEAT (forward_tac [conjunct1, conjunct2])
    val strip_alls = REPEAT (forward_tac [spec])
    val discharge_eq = forward_tac [refl_imp]
  in
    take_conjuncts THEN strip_alls THEN discharge_eq
  end
(* Build the splitting step for one split rule. The rule is first turned
   into its forward direction (split RS iffD1); its conclusion must pass
   was_split, otherwise TERM "malformed split rule" is raised. The
   resulting step applies the split deterministically and then breaks the
   result up with get_rules_once_split. *)
fun do_split ctxt split =
  let
    fun malformed t = raise TERM ("malformed split rule", [t])
  in
    (case try (fn rule => rule RS iffD1) split of
      NONE => malformed (Thm.prop_of split)
    | SOME split' =>
        let
          val ((_, imported), _) = Variable.import false [split'] ctxt
          val split_rhs = Thm.concl_of (hd imported)
        in
          if not (was_split split_rhs) then malformed split_rhs
          else DETERM (apply_split ctxt split') THEN get_rules_once_split
        end)
  end
(* Turn a meta-level equation into an object-level one by forward
   resolution with meta_eq_to_obj_eq; yields the empty sequence when the
   rule does not apply. *)
fun atomize_meta_eq thm = forward_tac [meta_eq_to_obj_eq] thm
in
(* gen_to_simps ctxt splitthms thm: optionally atomize a meta-equation,
   then repeatedly apply the first applicable split rule from splitthms and
   return all resulting theorems. Occurrences of Drule.dummy_thm among the
   split rules are ignored. *)
fun gen_to_simps ctxt splitthms thm =
  let
    fun is_dummy th = Thm.eq_thm (th, Drule.dummy_thm)
    val split_steps = map (do_split ctxt) (filter (not o is_dummy) splitthms)
    val convert = TRY atomize_meta_eq THEN (REPEAT (FIRST split_steps))
  in
    Seq.list_of (convert thm)
  end
(* to_simps ctxt thm: split thm using the split rules of all
   Ctr_Sugar-registered type constructors that occur in the type of the
   head of its left-hand side.

   The theorem may be an object-level or a meta-level equation. A
   meta-equation is converted with meta_eq_to_obj_eq before the left-hand
   side is inspected, since strip_eq only accepts "Trueprop (_ = _)".
   This matches gen_to_simps, which accepts both forms. *)
fun to_simps ctxt thm =
  let
    (* RS raises THM on anything that is not a meta-equation; keep the
       theorem unchanged in that case. *)
    val thm' =
      (case try (fn th => th RS meta_eq_to_obj_eq) thm of
        SOME eq => eq
      | NONE => thm)
    val (lhs, _) = strip_eq thm'
    val T = fastype_of (fst (strip_comb lhs))
    val splitthms = get_split_ths ctxt [T]
  in
    gen_to_simps ctxt splitthms thm'
  end
end
(* Implementation of the case_of_simps command: evaluate the referenced
   equations, convert them with to_case, store the result under the given
   binding and print it. Returns the updated local theory. *)
fun case_of_simps_cmd (bind, thms_ref) int lthy =
  let
    val checked_bind = apsnd (map (Attrib.check_src lthy)) bind
    val case_thm = to_case lthy (Attrib.eval_thms lthy thms_ref)
    val (res, lthy') = Local_Theory.note (checked_bind, [case_thm]) lthy
    val print_opts =
      {interactive = int, pos = Position.thread_data (), proof_state = false}
    val _ = Proof_Display.print_results print_opts lthy' ((Thm.theoremK, ""), [res])
  in
    lthy'
  end
(* Implementation of the simps_of_case command: evaluate the single
   referenced theorem and split it, using the explicitly given split rules
   if there are any and the automatically determined ones (to_simps)
   otherwise. The results are stored under the given binding and printed.
   Returns the updated local theory. *)
fun simps_of_case_cmd ((bind, thm_ref), splits_ref) int lthy =
  let
    val checked_bind = apsnd (map (Attrib.check_src lthy)) bind
    val thm = singleton (Attrib.eval_thms lthy) thm_ref
    val simps =
      (case splits_ref of
        [] => to_simps lthy thm
      | _ => gen_to_simps lthy (Attrib.eval_thms lthy splits_ref) thm)
    val (res, lthy') = Local_Theory.note (checked_bind, simps) lthy
    val print_opts =
      {interactive = int, pos = Position.thread_data (), proof_state = false}
    val _ = Proof_Display.print_results print_opts lthy' ((Thm.theoremK, ""), [res])
  in
    lthy'
  end
(* Register the case_of_simps command: an optional theorem name followed
   by a non-empty list of theorem references. *)
val _ =
  let
    val parse_args = Parse_Spec.opt_thm_name ":" -- Parse.thms1
  in
    Outer_Syntax.local_theory' \<^command_keyword>‹case_of_simps›
      "turn a list of equations into a case expression"
      (parse_args >> case_of_simps_cmd)
  end
(* Parser for the optional "(splits: thm ...)" argument; returns the list
   of theorem references between the colon and the closing parenthesis. *)
val parse_splits =
  let
    val header = \<^keyword>‹(› |-- Parse.reserved "splits" |-- \<^keyword>‹:›
  in
    header |-- Parse.thms1 --| \<^keyword>‹)›
  end
(* Register the simps_of_case command: an optional theorem name, one
   theorem reference and an optional "(splits: ...)" argument that
   defaults to the empty list. *)
val _ =
  let
    val parse_args =
      Parse_Spec.opt_thm_name ":" -- Parse.thm -- Scan.optional parse_splits []
  in
    Outer_Syntax.local_theory' \<^command_keyword>‹simps_of_case›
      "perform case split on rule"
      (parse_args >> simps_of_case_cmd)
  end
end