File ‹casify.ML›
signature CASIFY =
sig
  (* Configuration record for casify_tac.
     simp_all_cases:   after splitting parameters, also rewrite the subgoal with split_conv.
     split_right_only: when splitting a tuple parameter, try the split once per position
                       instead of repeating it.
     protect_subgoals: apply DISAMBIG_I, instantiated with the subgoal index, to every subgoal. *)
  datatype options = Options of {
    simp_all_cases: bool,
    split_right_only: bool,
    protect_subgoals: bool
  }

  (* Substitution of SPLIT premises *)
  val hyp_subst_tac: Proof.context -> int -> tactic
  val SPLIT_subst_tac: Proof.context -> int -> tactic

  (* Case extraction from labelled subgoals *)
  val extract_cases_tac: context_tactic

  (* Per-subgoal preprocessing of labels and labelled parameters *)
  val prepare_labels_tac: Proof.context -> int -> tactic
  val split_bind_all_tac:
    {right_only: bool, simp_all_cases: bool} -> Proof.context -> int -> tactic

  (* Entry points: the tactic, its option parser and the method parser *)
  val casify_tac: options -> context_tactic
  val casify_options: options -> options parser
  val casify_method_setup: options -> (Proof.context -> Method.method) context_parser
end
structure Casify : CASIFY =
struct

(* Name of the term binding for the conclusion of the innermost case. *)
val bind_unnamedN = "case"
(* Name of the assumption group that collects premises without a BIND label. *)
val case_premsN = "prems"
(* Name of the context block created when a conclusion carries no VC label. *)
val case_unnamedN = "unnamed"

(* One level of a program context: a block name and the variables fixed there. *)
datatype 'a prg_ctxt = Block of (string * 'a list)

(* Decode a VAR marker into the components of the tuple it wraps
   (as computed by the project helper Util.dest_tuple). *)
fun dest_var (Const (@{const_name Case_Labeling.VAR}, _) $ t) = Util.dest_tuple t
  | dest_var t = raise TERM ("dest_var", [t])

(* Decode a HOL list of VAR markers into one flat list of variable terms. *)
val dest_vars = HOLogic.dest_list #> maps dest_var

(* Decode the context list of a VC label. Each element is a nested pair
   (name, (index, vars)) of a HOL string, a HOL numeral and a HOL list of VAR
   markers. The list is reversed before decoding; every element becomes
   (index, Block (name, variables)). *)
val dest_ct =
  let
    fun mk_block na ic vs =
      let val (_, idx) = HOLogic.dest_number ic
      in (idx, Block (HOLogic.dest_string na, dest_vars vs)) end
    fun dest_block (Const (@{const_name Pair}, _) $ na $ (Const (@{const_name Pair}, _) $ ic $ vs)) =
          mk_block na ic vs
      | dest_block t = raise TERM ("dest_block", [t])
  in HOLogic.dest_list #> rev #> map dest_block end

(* Split a VC-labelled term into its decoded context and the labelled term.
   An unlabelled term yields the empty context. *)
fun dest_VC (Const (@{const_name Case_Labeling.VC}, _) $ ct $ t) = (dest_ct ct, t)
  | dest_VC t = ([], t)

(* Remove an outer Trueprop if present; any other term is returned unchanged. *)
fun try_dest_Trueprop t =
  (case try HOLogic.dest_Trueprop t of
    NONE => t
  | SOME t' => t')

(* Split a BIND label into SOME (name, index) and the labelled term.
   A term without a BIND head yields NONE. *)
fun dest_BIND (Const (@{const_name Case_Labeling.BIND}, _) $ na $ ic $ t) =
      let
        val s = HOLogic.dest_string na
        val (_, n) = HOLogic.dest_number ic
      in (SOME (s, n), t) end
  | dest_BIND t = (NONE, t)

(* Return the two arguments of a SPLIT term; raises TERM otherwise. *)
fun dest_SPLIT (Const (@{const_name Case_Labeling.SPLIT}, _) $ t $ u) = (t, u)
  | dest_SPLIT t = raise TERM ("dest_SPLIT", [t])

(* Return the de Bruijn index of a bound variable as a singleton list;
   raises TERM for any other term. *)
fun dest_Bound (Bound i) = [i]
  | dest_Bound t = raise TERM ("dest_Bound", [t])

(* Split a HIER label into SOME depth and the labelled term. The depth is the
   length of the label's list argument. An unlabelled term yields NONE. *)
fun dest_HIER (Const (@{const_name Case_Labeling.HIER}, _) $ ct $ t) =
      (SOME (length (HOLogic.dest_list ct)), t)
  | dest_HIER t = (NONE, t)

(* Decompose a subgoal into its outermost parameters (vars), its premises
   (prems) and the decoded VC context of its conclusion (label). The premises
   still refer to the parameters through loose bound variables. *)
fun parse_label prop =
  let
    val vars = Term.strip_all_vars prop
    val ((prems, label), _) = prop
      |> (Logic.strip_horn o Term.strip_all_body)
      ||> try_dest_Trueprop
      ||>> dest_VC
  in {vars = vars, label = label, prems = prems} end

(* Strip an optional HIER label from the conclusion of a premise. Returns the
   hierarchy depth (NONE if unlabelled) and the premise rebuilt without the
   label, with its parameters and premises restored. *)
fun parse_prem prop =
  let
    val vars = Term.strip_all_vars prop
    val (prems, (hier, concl)) = prop
      |> (Logic.strip_horn o Term.strip_all_body)
      ||> try_dest_Trueprop
      ||> dest_HIER
    val prop' = Logic.list_all (vars, Logic.list_implies (prems, HOLogic.mk_Trueprop concl))
  in (hier, prop') end

(* Analyse one subgoal as focused by Variable.focus: params are the focused
   parameters (original name, (fixed name, type)) and t is the focused subgoal.
   n is used, negated, as the index of the implicit block that is created when
   the conclusion has no VC label.
   Returns the program context and, paired with the conclusion, the premises.
   In the program context every focused parameter is assigned to exactly one
   block; parameters that no block mentions are appended to the last block.
   Each premise is tagged with its HIER depth; the default depth is the number
   of blocks. *)
fun strip_prg_ctxt n ((params, t), _) =
  let
    (* Remove the first entry whose key matches y from an association list.
       Returns its value and the remaining list; fails if no key matches. *)
    fun lookup_delete eq y =
      let
        fun aux _ [] = error ("Term is not a parameter: " ^ @{make_string} y)
          | aux acc ((k, v) :: xs) =
              if eq (k, y) then (v, rev acc @ xs)
              else aux ((k, v) :: acc) xs
      in aux [] end
    (* Replace the variables of each block by their parameter records. *)
    fun upd_label_vars ps [] = [(~n, Block (case_unnamedN, map snd ps))]
      | upd_label_vars ps [(idx, Block (s, vs))] =
          let
            val (vs', ps') = fold_map (lookup_delete (op aconv)) vs ps
          in [(idx, Block (s, vs' @ map snd ps'))] end
      | upd_label_vars ps ((idx, Block (s, vs)) :: ct) =
          let val (vs', ps') = fold_map (lookup_delete (op aconv)) vs ps
          in (idx, Block (s, vs')) :: upd_label_vars ps' ct end
    (* Key each parameter by its focused Free; remember both of its names. *)
    val build_param_map = map (fn (s, (s', T)) => (Free (s', T), {fix = (s, T), abs = (s', T)}))
    val (asms, (label, concl)) = t
      |> Logic.strip_horn
      ||> try_dest_Trueprop
      ||> dest_VC
    val prg_ctxt = upd_label_vars (build_param_map params) label
    val ctxt_len = length prg_ctxt
    val asms' = map (parse_prem #> apfst (the_default ctxt_len)) asms
  in
    (prg_ctxt, (asms', concl))
  end

(* Analyse all focused subgoals. Subgoal i (counting from 0) of n gets the
   implicit block number n - i. *)
fun strip_prg_ctxts xs =
  let val n = length xs
  in map_index (fn (i, x) => strip_prg_ctxt (n - i) x) xs end

(* Intermediate form of a (nested) named case, before equal keys are merged. *)
datatype ('a, 'b) precase = Precase of
  {fixes: (binding * typ) list,
   assumes: ('b * term list) list,
   binds: (indexname * term option) list,
   cases: ('a * ('a, 'b) precase) list}

fun bindings args = map (apfst Binding.name) args

(* Sort an association list by key and merge the entries with equal keys. *)
fun coalesce_order ord = sort (Util.fst_ord ord) #> AList.coalesce (is_equal o ord)

(* Presumably renames the keys of an association list apart with Name.variant
   (depends on the project helper Util.infst). *)
fun unique_names xs = fst (fold_map (Util.infst Name.variant) xs Name.context)

(* Turn one analysed subgoal into a chain of nested precases, one level per
   block of its program context. Premises go to the level given by their HIER
   depth and are grouped by their BIND label. BIND markers in the premises,
   and at the innermost level also in the conclusion, become term bindings. *)
fun build_precase (prg_ctxt, (prems, t)) =
  let
    (* Tag every premise with its BIND label (if any); sort by depth, then label. *)
    val sorted_prems = prems
      |> map (apsnd (fn t => (fst (dest_BIND (try_dest_Trueprop t)), t)))
      |> sort (prod_ord int_ord (prod_ord (option_ord (prod_ord fast_string_ord int_ord)) (K EQUAL)))
    (* Remove every BIND marker from a term. *)
    fun drop_labels (Const (@{const_name "Case_Labeling.BIND"}, _) $ _ $ _ $ t) = drop_labels t
      | drop_labels (t1 $ t2) = drop_labels t1 $ drop_labels t2
      | drop_labels (Abs (x, T, t)) = Abs (x, T, drop_labels t)
      | drop_labels t = t
    (* Collect (index, (name, term)) for every BIND marker in a term. *)
    fun find_binds (t as Const (@{const_name "Case_Labeling.BIND"}, _) $ _ $ _ $ _) =
          (case dest_BIND t of
            (NONE, t) => find_binds t
          | (SOME (s, n), t) => (n, (s, t)) :: find_binds t)
      | find_binds (t1 $ t2) = find_binds t1 @ find_binds t2
      | find_binds (Abs (_, _, t)) = find_binds t
      | find_binds _ = []
    (* Bindings whose term still has loose bound variables are dropped with a warning. *)
    fun has_loose_bounds t =
      (case loose_bnos t of
        [] => false
      | _ :: _ => (warning "loose bounds in term"; true))
    (* The k-th occurrence (counting from 0) of a name s gets the indexname (s, k). *)
    fun unique_binds _ [] = []
      | unique_binds acc ((s, x) :: xs) =
          (case AList.lookup (op =) acc s of
            NONE => ((s, 0), x) :: unique_binds ((s, 0) :: acc) xs
          | SOME n => ((s, n + 1), x) :: unique_binds (AList.update (op =) (s, n + 1) acc) xs)
    (* Build the precase for the block at depth n, with the deeper blocks nested
       inside it. abs_ofixes abstracts a term over the parameters of all
       enclosing blocks; at the outermost level it only drops BIND markers. *)
    fun mk_precase _ _ [] [] = []
      | mk_precase _ _ [] (_ :: _) = error "premise with a too long HIERarchy label"
      | mk_precase n abs_ofixes (bl :: prg_ctxt) prems =
          let
            val (m, Block (s, vars)) = bl
            val fixes = map #fix vars
            val params = map #abs vars
            (* premises of depth n belong here; the rest are passed to deeper levels *)
            val (prems1, prems2) = chop_prefix (fn (m, _) => m = n) prems
            val abs_fixes = abs_ofixes o fold_rev Term.absfree params
            (* group this level's premises by BIND label; unlabelled ones go to prems *)
            val prems' = prems1
              |> map (fn (_, (x, t)) => (swap (the_default (case_premsN, ~1) x), t))
              |> coalesce_order (prod_ord int_ord string_ord)
              |> map (apfst snd)
              |> unique_names
            val assumes = map (apsnd (map abs_fixes)) prems'
            val fixes' = bindings fixes
            val binds =
              let
                val prem_bs = maps find_binds (maps snd prems')
                (* only the innermost level binds the conclusion *)
                val (concl_cases, concl_bs) =
                  (case prg_ctxt of
                    [] => ([(bind_unnamedN, abs_fixes t)], find_binds t)
                  | _ :: _ => ([], []))
                val bs = concl_bs @ prem_bs
                  |> sort (prod_ord int_ord (prod_ord string_ord Term_Ord.fast_term_ord))
                  |> map snd
                  |> map (apsnd abs_fixes)
                  |> filter_out (has_loose_bounds o snd)
              in
                concl_cases @ bs
                |> unique_binds []
                |> map (apsnd SOME)
              end
            val precase = Precase {
              fixes = fixes',
              assumes = assumes,
              binds = binds,
              cases = mk_precase (n + 1) abs_fixes prg_ctxt prems2}
          in [((m, s), precase)] end
  in mk_precase 1 drop_labels prg_ctxt sorted_prems end

(* Merge precases with the same (index, name) key into a single
   Rule_Cases.Case, merging their subcases recursively. For a merged group the
   fixes, assumes and binds of the first entry are kept. Raises an error on
   an empty group. *)
fun merge_precases precases =
  let
    val _ = precases : ((int * string) * (int * string, string) precase) list
    fun merge_p (_, []) = error "empty case"
      | merge_p (s, Precase {fixes, assumes, binds, cases} :: pcs) =
          let
            fun sel_cases (Precase {cases, ...}) = cases
            val cases' = merge_precases (cases @ maps sel_cases pcs)
          in (s, Rule_Cases.Case {fixes = fixes, assumes = assumes, binds = binds, cases = cases'}) end
    val cases = precases
      |> coalesce_order (prod_ord int_ord string_ord)
      |> map (apfst snd)
      |> unique_names
      |> map merge_p
  in cases end

(* Compute the named cases for all subgoals of a goal state. *)
fun mk_cases' ctxt = Thm.prems_of
  #> map (fn t => Variable.focus NONE t ctxt)
  #> strip_prg_ctxts
  #> maps build_precase
  #> merge_precases
  #> map (apsnd SOME)

(* Rewrite bottom-up with cv, re-normalising every subterm that cv changed. *)
fun normalize_conv cv ctxt ct =
  Conv.bottom_conv (fn ctxt => Conv.try_conv (cv then_conv normalize_conv cv ctxt)) ctxt ct

(* Normalise the arguments of the BIND, HIER and VC labels in subgoal i. They
   are rewritten with Suc_numeral_simps (as the name suc_to_num_conv
   suggests, presumably to turn Suc-terms into numerals) and with append.simps. *)
fun prepare_labels_tac ctxt =
  let
    (* each rule set is listed exactly once *)
    val suc_numeral_simps = @{thms Suc_numeral_simps[THEN eq_reflection]}
    val app_simps = @{thms append.simps[THEN eq_reflection]}
    val suc_to_num_conv = normalize_conv (Conv.rewrs_conv (suc_numeral_simps @ app_simps))
    (* rewrite the label arguments only; the labelled term is the last argument *)
    val label_fun_conv = Conv.fun_conv o suc_to_num_conv
    fun label_conv ctxt ct =
      (case Thm.term_of ct of
        Const (@{const_name Case_Labeling.BIND}, _) $ _ $ _ $ _ => label_fun_conv ctxt ct
      | Const (@{const_name Case_Labeling.HIER}, _) $ _ $ _ => label_fun_conv ctxt ct
      | Const (@{const_name Case_Labeling.VC}, _) $ _ $ _ => label_fun_conv ctxt ct
      | _ => Conv.no_conv ct)
    fun norm_labels_conv ctxt ct =
      Conv.bottom_conv (Conv.try_conv o label_conv) ctxt ct
  in
    CONVERSION (norm_labels_conv ctxt)
  end

(* Install the cases computed from the current goal state, then unfold the
   label markers with LABEL_simps. *)
fun extract_cases_tac (ctxt, st) =
  let val tac = unfold_tac ctxt @{thms LABEL_simps}
  in CONTEXT_CASES (mk_cases' ctxt st) tac (ctxt, st) end

(* Hypothesis substitution that works on SPLIT premises instead of equalities. *)
structure Splitsubst = Hypsubst
(
  val dest_Trueprop = HOLogic.dest_Trueprop
  val dest_eq = dest_SPLIT
  val dest_imp = HOLogic.dest_imp
  val eq_reflection = @{thm SPLIT_reflection}
  val rev_eq_reflection = @{thm rev_SPLIT_reflection}
  val imp_intr = @{thm impI}
  val rev_mp = @{thm rev_mp}
  val subst = @{thm SPLIT_subst}
  val sym = @{thm SPLIT_sym}
  val thin_refl = @{thm SPLIT_thin_refl}
)

val hyp_subst_tac = Splitsubst.hyp_subst_tac

(* Repeatedly eliminate premises with SPLIT_prodE and substitute SPLIT premises away. *)
fun SPLIT_subst_tac ctxt =
  REPEAT_ALL_NEW (REPEAT_ALL_NEW (ematch_tac ctxt @{thms SPLIT_prodE}) THEN' Splitsubst.hyp_subst_tac ctxt)

local
  (* split_conv as a meta-level rewrite rule *)
  val case_prod_th = @{thm split_conv[THEN eq_reflection]}

  (* Below the first two parameters x and y of a goal, rewrite bottom-up with
     split_conv instantiated to x and y. *)
  fun case_prod_conv ctxt =
    Conv.forall_conv (fn (x, ctxt) =>
      Conv.forall_conv (fn (y, ctxt) => fn ct =>
        let
          val insts = [NONE, SOME x, SOME y]
          val typ_insts = map (Option.map Thm.ctyp_of_cterm) insts
          val th = Thm.instantiate' typ_insts insts case_prod_th
        in Conv.bottom_conv (fn _ => Conv.try_conv (Conv.rewr_conv th)) ctxt ct end
      ) ctxt) ctxt
in

(* Split the parameter at position n (0 = outermost) of a goal with
   split_paired_all. At that position, the split is followed by case_prod_conv
   and continues at position 1 of the result. The whole step is repeated, or
   tried only once if right_only is set. *)
fun split_nth_all_conv {right_only: bool} =
  let
    val desc_conv = if right_only then Conv.try_conv else Conv.repeat_conv
    fun conv ctxt 0 ct =
          desc_conv (
            Conv.rewr_conv @{thm split_paired_all}
            then_conv case_prod_conv ctxt
            then_conv Conv.try_conv (conv ctxt 1)
          ) ct
      | conv ctxt n ct =
          Conv.forall_conv (fn (_, ctxt) => conv ctxt (n - 1)) ctxt ct
  in conv end

(* Split every parameter of the goal that occurs as a variable of a VC block
   (or as a variable in the second argument of a SPLIT premise). *)
fun split_bind_all_conv right_only ctxt ct =
  let
    val {vars, label, prems} = parse_label (Thm.term_of ct)
    val nvars = length vars
    val bv_ts = maps (fn (_, Block (_, vs)) => vs) label
    (* NOTE(review): these premises still carry their Trueprop wrapper, which
       dest_SPLIT does not strip, so this list looks always empty. Confirm
       whether try_dest_Trueprop should be applied before dest_SPLIT. *)
    val sv_ts = prems
      |> map_filter (try dest_SPLIT)
      |> map snd
      |> maps Util.dest_tuple
    (* Convert de Bruijn indices to parameter positions (0 = outermost). Sorting
       the indices ascending yields the positions in descending order. Inner
       parameters are therefore split first, and splitting them does not shift
       the positions of the outer parameters that remain to be split. *)
    val lvars = bv_ts @ sv_ts
      |> maps dest_Bound
      |> sort_distinct int_ord
      |> map (fn n => nvars - n - 1)
  in Conv.every_conv (map (split_nth_all_conv right_only ctxt) lvars) ct end

(* Apply the goal conversion cv to subgoal i. No exception handler is
   installed, so any exception raised by cv propagates to the caller. *)
fun RAWCONV cv i st = Seq.single (Conv.gconv_rule cv i st)

(* Split the labelled parameters of subgoal i. If simp_all_cases is set, also
   rewrite the subgoal with split_conv. *)
fun split_bind_all_tac {right_only: bool, simp_all_cases: bool} ctxt =
  let
    val rewr_tac =
      if simp_all_cases
      then Raw_Simplifier.rewrite_goal_tac ctxt [case_prod_th]
      else K all_tac
  in
    RAWCONV (split_bind_all_conv {right_only = right_only} ctxt) THEN' rewr_tac
  end

end

datatype options = Options of {simp_all_cases: bool, split_right_only: bool, protect_subgoals: bool}

(* Prepare every subgoal, then extract the named cases. Preparation
   normalises the labels, splits the labelled parameters, eliminates SPLIT
   premises and, if protect_subgoals is set, applies DISAMBIG_I with the
   subgoal index. *)
fun casify_tac (Options {simp_all_cases, split_right_only, protect_subgoals}) (ctxt, st) =
  let
    (* DISAMBIG_I with its schematic variable ?n (of type nat) instantiated to ct *)
    fun inst_disambig ct =
      Thm.instantiate (TVars.empty, Vars.make [((("n", 0), @{typ nat}), ct)]) @{thm DISAMBIG_I}
    fun disambig_tac i =
      if protect_subgoals
      then match_tac ctxt [inst_disambig (Thm.cterm_of ctxt (HOLogic.mk_number @{typ nat} i))]
      else K all_tac
    val prep_tac = DETERM (ALLGOALS (fn i => EVERY'
      [ prepare_labels_tac ctxt,
        split_bind_all_tac {simp_all_cases = simp_all_cases, right_only = split_right_only} ctxt,
        TRY o SPLIT_subst_tac ctxt,
        disambig_tac i
      ] i))
  in (ctxt, st) |> (prep_tac THEN_CONTEXT extract_cases_tac) end

local
  fun set_simp_all_cases simp_all_cases
      (Options {simp_all_cases = _, split_right_only, protect_subgoals}) =
    Options {simp_all_cases = simp_all_cases, split_right_only = split_right_only,
      protect_subgoals = protect_subgoals}
  fun set_protect_subgoals protect_subgoals
      (Options {simp_all_cases, split_right_only, protect_subgoals = _}) =
    Options {simp_all_cases = simp_all_cases, split_right_only = split_right_only,
      protect_subgoals = protect_subgoals}
  fun set_split_right_only split_right_only
      (Options {simp_all_cases, split_right_only = _, protect_subgoals}) =
    Options {simp_all_cases = simp_all_cases, split_right_only = split_right_only,
      protect_subgoals = protect_subgoals}
  (* Each option is a parenthesised keyword that parses to an update of the record. *)
  val options = map (fn (s, f) => Args.parens (Args.$$$ s) >> K f)
    [ ("simp", set_simp_all_cases true),
      ("no_simp", set_simp_all_cases false),
      ("disambig_subgoals", set_protect_subgoals true),
      ("no_disambig_subgoals", set_protect_subgoals false),
      ("split_right_only", set_split_right_only true),
      ("no_split_right_only", set_split_right_only false)
    ]
in
  (* Parse any number of options and apply them to def from left to right,
     so a later option overrides an earlier one. *)
  fun casify_options def =
    Scan.repeat (Scan.first options) >> (fn xs => fold I xs def)
end

(* Method parser: options (starting from def), then casify_tac as a cases method. *)
fun casify_method_setup def =
  Scan.lift (casify_options def) >>
    (fn opt => fn _ => Util.SIMPLE_METHOD_CASES (casify_tac opt))

end