(* File ‹~~/src/Provers/Arith/assoc_fold.ML› *)
(*Parameters that instantiate the Assoc_Fold functor.*)
signature ASSOC_FOLD_DATA =
sig
(*simpset that proc installs into the proof context (via put_simpset)
  before calling simp_tac to prove the rewritten equation*)
val assoc_ss: simpset
(*theorem that proc resolves against the goal before simplifying;
  presumably the object-logic rule turning an object-level equality into
  a meta-level one -- confirm against the instantiating theory*)
val eq_reflection: thm
(*predicate used by sift_terms to decide whether a term counts as a
  literal (as opposed to an "other" summand)*)
val is_numeral: term -> bool
end;
(*Result signature of the Assoc_Fold functor.*)
signature ASSOC_FOLD =
sig
(*proc ctxt lhs: returns SOME th, where th is the theorem obtained by
  proving the meta-equation between lhs and its regrouped form, or NONE
  when the internal Assoc_fail exception is raised (no regrouping
  possible)*)
val proc: Proof.context -> term -> thm option
end;
(*Regroups a nest of applications of one binary operator so that all
  terms recognised by Data.is_numeral are collected together, and proves
  the resulting equation with Data.assoc_ss.*)
functor Assoc_Fold(Data: ASSOC_FOLD_DATA): ASSOC_FOLD =
struct

(*Internal signal that no regrouping is possible; proc turns it into NONE.*)
exception Assoc_fail;

(*Combine a non-empty list of terms into one term by applying the
  operator pairwise through foldr1.  An empty list raises Assoc_fail.*)
fun mk_sum _ [] = raise Assoc_fail
  | mk_sum plus summands = foldr1 (fn (a, b) => plus $ a $ b) summands;

(*Split the operands of a nest of applications of plus into literals
  (terms accepted by Data.is_numeral) and all remaining terms, consing
  them onto the given accumulator pair.  Only applications whose head is
  a Const equal to plus are descended into; anything else is kept whole
  as an "other" term.*)
fun sift_terms plus (t, acc as (lits, others)) =
  if Data.is_numeral t then (t :: lits, others)
  else
    (case t of
      (head as Const _) $ left $ right =>
        if head = plus then
          (*right operand first, so that the left one ends up in front*)
          sift_terms plus (left, sift_terms plus (right, acc))
        else (lits, t :: others)
    | _ => (lits, t :: others));

(*proc ctxt lhs: try to rewrite lhs into
    plus $ (combined literals) $ (combined other terms)
  where plus is the head of lhs.
  Returns SOME of the proved meta-equation, or NONE when fewer than two
  literals are found or when there are no non-literal terms (mk_sum
  raises Assoc_fail on the empty list).
  Calls error if lhs is not a two-argument application.*)
fun proc ctxt lhs =
  (let
    val plus =
      (case lhs of
        f $ _ $ _ => f
      | _ => error "Assoc_fold: bad pattern")
    val (lits, others) = sift_terms plus (lhs, ([], []))
    (*with fewer than two literals there is nothing to combine*)
    val _ =
      (case lits of
        _ :: _ :: _ => ()
      | _ => raise Assoc_fail)
    val lit_part = mk_sum plus lits
    val other_part = mk_sum plus others
    val rhs = plus $ lit_part $ other_part
    val goal = Logic.mk_equals (lhs, rhs)
    fun tac _ =
      resolve_tac ctxt [Data.eq_reflection] 1 THEN
      simp_tac (put_simpset Data.assoc_ss ctxt) 1
    val th = Goal.prove ctxt [] [] goal tac
  in SOME th end)
  handle Assoc_fail => NONE;

end;