File ‹rewrite.ML›

(*  Title:      HOL/Library/rewrite.ML
    Author:     Christoph Traut, Lars Noschinski, TU Muenchen

This is a rewrite method that supports subterm selection based on patterns.

The patterns accepted by rewrite are of the following form:
  <atom>    ::= <term> | "concl" | "asm" | "for" "(" <names> ")"
  <pattern> ::= ("in" <atom> | "at" <atom>) [<pattern>]
  <args>    ::= [<pattern>] ["to" <term>] <thms>

This syntax was clearly inspired by Gonthier's and Tassi's language of
patterns but has diverged significantly during its development.

We also allow introduction of identifiers for bound variables,
which can then be used to match arbitrary subterms inside abstractions.
*)

(* Infix combinators on pattern conversions (defined in structure Rewrite).
   then_pconv (precedence 1) binds tighter than else_pconv (precedence 0), so
   "a then_pconv b else_pconv c" parses as "(a then_pconv b) else_pconv c". *)
infix 1 then_pconv;
infix 0 else_pconv;

signature REWRITE =
sig
  (* A pattern conversion: a conversion parameterized by the proof context and
     by the environment accumulated while matching a pattern, namely a type
     instantiation and a list of (identifier name, term) pairs. *)
  type patconv = Proof.context -> Type.tyenv * (string * term) list -> cconv

  (* Sequencing and fallback of pattern conversions, built on
     then_conv / else_conv. *)
  val then_pconv: patconv * patconv -> patconv
  val else_pconv: patconv * patconv -> patconv

  (* Descend into an abstraction, eta-expanding first if the term is not one.
     The string option names the bound variable as an identifier; the type is
     matched against the argument type (dummyT means: no constraint). Unlike
     forall_pconv, the type here is not optional. *)
  val abs_pconv:  patconv -> string option * typ -> patconv (*XXX*)

  (* Apply to the function part / the argument of an application; an
     abstraction of the shape %x. t x is entered first. *)
  val fun_pconv: patconv -> patconv
  val arg_pconv: patconv -> patconv

  (* Like arg_pconv, but goes through CConv.concl_cconv 1; used to step past
     the premise of A ==> B. *)
  val imp_pconv: patconv -> patconv

  (* Skip the outermost !!-binders without eta-expansion. *)
  val params_pconv: patconv -> patconv

  (* Enter the binder of a Pure.all application, optionally naming it and
     giving its type. *)
  val forall_pconv: patconv -> string option * typ option -> patconv

  (* Reflexivity: leaves the term unchanged. *)
  val all_pconv: patconv

  (* Enter a chain of !!-binders, naming the innermost ones after the given
     identifiers. *)
  val for_pconv: patconv -> (string option * typ option) list -> patconv

  (* The conclusion B of A_1 ==> ... ==> A_n ==> B. *)
  val concl_pconv: patconv -> patconv

  (* Premises of an implication: asm_pconv acts on A ==> B through
     CConv.with_prems_cconv ~1; asms_pconv tries the premises in turn. *)
  val asm_pconv: patconv -> patconv
  val asms_pconv: patconv -> patconv

  (* Strip an object-logic judgment, if present. *)
  val judgment_pconv: patconv -> patconv

  (* Depth-first search for a subterm (the term itself, then function part,
     argument, abstraction body) where the conversion succeeds. *)
  val in_pconv: patconv -> patconv

  (* Match a pattern term and continue at its hole. *)
  val match_pconv: patconv -> term * (string option * typ) list -> patconv

  (* Rewrite with the given rules; the optional term constrains the
     right-hand side of the rules. *)
  val rewrs_pconv: term option -> thm list -> patconv

  (* Elements of a parsed rewrite pattern. *)
  datatype ('a, 'b) pattern = At | In | Term of 'a | Concl | Asm | For of 'b list

  (* The hole variable with the given index and type. *)
  val mk_hole: int -> typ -> term

  (* Rewrite with the given rules at the position selected by the pattern
     list; the optional term is the "to" target of the rewrite. *)
  val rewrite_conv: Proof.context
    -> (term * (string * typ) list, string * typ option) pattern list * term option
    -> thm list
    -> conv
end

structure Rewrite : REWRITE =
struct

(* Elements of a rewrite pattern, as produced by the method parser and
   consumed by rewrite_conv: the "at"/"in" selectors, a term pattern,
   "concl", "asm", and "for (...)" with its identifiers. *)
datatype ('a, 'b) pattern = At | In | Term of 'a | Concl | Asm | For of 'b list

(* Raised inside rewrs_pconv when a rule's right-hand side cannot be unified
   with the "to" term; such rules are then dropped. *)
exception NO_TO_MATCH

(* Internal name of the schematic variables that represent holes. *)
val holeN = Name.internal "_hole"

(* Turn a theorem into meta-equations using the simplifier's mksimps for
   ctxt, renumbering the variable indexes of each result to start at zero. *)
fun prep_meta_eq ctxt = map Drule.zero_var_indexes o Simplifier.mksimps ctxt


(* holes *)

(* The hole variable with index idx and type ty. *)
fun mk_hole idx ty = Var ((holeN, idx), ty)

(* Does the term denote a hole variable (any index, any type)? *)
fun is_hole t =
  (case t of
    Var ((name, _), _) => name = holeN
  | _ => false)

(* Is the term the hole constant rewrite_HOLE, as it occurs in a freshly
   parsed pattern (before hole_syntax turns it into a hole variable)?
   The constant name must be a literal: a plain identifier in this position
   would be a variable pattern and make every constant count as a hole. *)
fun is_hole_const (Const (@{const_name rewrite_HOLE}, _)) = true
  | is_hole_const _ = false

(* Context modifier for checking pattern terms. It installs a term-check phase
   (priority 101, "hole_expansion") that replaces each occurrence of the hole
   constant rewrite_HOLE by a hole variable, numbered from 1 across all
   checked terms. Each hole is applied to Bound 0 ... Bound (n-1) for its n
   enclosing binders, so it may depend on them. The modifier also switches
   the context to pattern mode. *)
val hole_syntax =
  let
    (* Modified variant of Term.replace_hole: Ts are the types of the
       enclosing binders (innermost first), i is the next hole index.
       The constant must be matched by its literal name; a bare identifier
       here would match (and replace) every constant. *)
    fun replace_hole Ts (Const (@{const_name rewrite_HOLE}, T)) i =
          (list_comb (mk_hole i (Ts ---> T), map_range Bound (length Ts)), i + 1)
      | replace_hole Ts (Abs (x, T, t)) i =
          let val (t', i') = replace_hole (T :: Ts) t i
          in (Abs (x, T, t'), i') end
      | replace_hole Ts (t $ u) i =
          let
            val (t', i') = replace_hole Ts t i
            val (u', i'') = replace_hole Ts u i'
          in (t' $ u', i'') end
      | replace_hole _ a i = (a, i)
    fun prep_holes ts = #1 (fold_map (replace_hole []) ts 1)
  in
    Context.proof_map (Syntax_Phases.term_check 101 "hole_expansion" (K prep_holes))
    #> Proof_Context.set_mode Proof_Context.mode_pattern
  end


(* pattern conversions *)

(* A pattern conversion: receives the proof context and the environment built
   up so far, namely the type instantiation collected while matching patterns
   and the list of named identifiers with the terms they stand for. It then
   acts as a conversion on the given cterm. *)
type patconv = Proof.context -> Type.tyenv * (string * term) list -> cterm -> thm

(* Run cv1, then cv2 on its result (through then_conv). Both pattern
   conversions see the same context and environment. *)
fun (cv1 then_pconv cv2) ctxt tytenv ct =
  let
    val conv1 = cv1 ctxt tytenv
    val conv2 = cv2 ctxt tytenv
  in (conv1 then_conv conv2) ct end

(* Try cv1; if else_conv treats its failure as recoverable, use cv2 instead.
   The conversions are only run once ct is supplied, inside else_conv. *)
fun (cv1 else_pconv cv2) ctxt tytenv ct =
  let
    val primary = cv1 ctxt tytenv
    val fallback = cv2 ctxt tytenv
  in (primary else_conv fallback) ct end

(* Apply cv below the binder of an abstraction, without eta-expansion. cv
   first receives the value that CConv.abs_cconv supplies for the bound
   variable, then the inner context. Raises TERM on any non-abstraction. *)
fun raw_abs_pconv cv ctxt tytenv ct =
  let
    fun under_binder (x, ctxt') = cv x ctxt' tytenv
  in
    (case Thm.term_of ct of
      Abs _ => CConv.abs_cconv under_binder ctxt ct
    | t => raise TERM ("raw_abs_pconv", [t]))
  end

local

(* Common shape of the two functions below: if ct is an application, pass
   cv to the given congruence combinator; otherwise raise TERM tagged with
   the caller's name. *)
fun raw_app_pconv name app_cconv cv ctxt tytenv ct =
  (case Thm.term_of ct of
    _ $ _ => app_cconv (cv ctxt tytenv) ct
  | t => raise TERM (name, [t]))

in

(* Apply cv to the function part f of f $ x (no eta-expansion). *)
fun raw_fun_pconv cv = raw_app_pconv "raw_fun_pconv" CConv.fun_cconv cv

(* Apply cv to the argument x of f $ x (no eta-expansion). *)
fun raw_arg_pconv cv = raw_app_pconv "raw_arg_pconv" CConv.arg_cconv cv

end

(* Descend into the body of the function-valued term ct and apply cv there.
   If ct is not already an abstraction, it is first eta-expanded by rewriting
   with eta_expand.
   (s, T): if s = SOME name, the bound variable is recorded under that name in
   the identifier list passed to cv. T is matched against the argument type U
   of ct, extending the type environment, unless T = dummyT.
   Raises TERM if ct does not have a function type, and TYPE if T does not
   match U. *)
fun abs_pconv cv (s,T) ctxt (tyenv, ts) ct =
  let val u = Thm.term_of ct
  in
    case try (fastype_of #> dest_funT) u of
      NONE => raise TERM ("abs_pconv: no function type", [u])
    | SOME (U, _) =>
        let
          (* Sign.typ_match reports failure with Type.TYPE_MATCH, which the
             Pattern.MATCH handler below does not catch; translate it here
             so callers (and else_conv) see the intended TYPE error. *)
          val tyenv' =
            if T = dummyT then tyenv
            else
              (Sign.typ_match (Proof_Context.theory_of ctxt) (T, U) tyenv
                handle Type.TYPE_MATCH =>
                  raise TYPE ("abs_pconv: types don't match", [T,U], [u]))
          val eta_expand_cconv =
            case u of
              Abs _ => Thm.reflexive
            | _ => CConv.rewr_cconv @{thm eta_expand}
          (* Record the bound variable under the given name, if any. *)
          fun add_ident NONE _ l = l
            | add_ident (SOME name) ct l = (name, Thm.term_of ct) :: l
          val abs_cv = CConv.abs_cconv (fn (ct, ctxt) => cv ctxt (tyenv', add_ident s ct ts)) ctxt
        in (eta_expand_cconv then_conv abs_cv) ct end
        (* A Pattern.MATCH escaping from the conversion (including from cv)
           is likewise reported as a type mismatch. *)
        handle Pattern.MATCH => raise TYPE ("abs_pconv: types don't match", [T,U], [u])
  end

(* Apply cv to the function part of an application. An abstraction whose body
   has the shape t $ Bound 0 is entered first (via abs_pconv) and the search
   continues inside. Raises TERM for any other term. *)
fun fun_pconv cv ctxt tytenv ct =
  let val t = Thm.term_of ct in
    (case t of
      Abs (_, T, _ $ Bound 0) => abs_pconv (fun_pconv cv) (NONE, T) ctxt tytenv ct
    | _ $ _ => CConv.fun_cconv (cv ctxt tytenv) ct
    | _ => raise TERM ("fun_pconv", [t]))
  end

local

(* Common driver for arg_pconv and imp_pconv. On an application, apply the
   congruence combinator cv0 to cv. On an abstraction whose body has the
   shape t $ Bound 0, enter it via abs_pconv and retry. Raises TERM otherwise. *)
fun arg_pconv_gen cv0 cv ctxt tytenv ct =
  let val t = Thm.term_of ct in
    (case t of
      Abs (_, T, _ $ Bound 0) => abs_pconv (arg_pconv_gen cv0 cv) (NONE, T) ctxt tytenv ct
    | _ $ _ => cv0 (cv ctxt tytenv) ct
    | _ => raise TERM ("arg_pconv_gen", [t]))
  end

in

(* Apply cv to the argument of an application. *)
fun arg_pconv cv = arg_pconv_gen CConv.arg_cconv cv

(* Like arg_pconv, but through CConv.concl_cconv 1; concl_pconv uses it to
   step over the premise of A ==> B. *)
fun imp_pconv cv = arg_pconv_gen (CConv.concl_cconv 1) cv

end

(* Move to B in !!x_1 ... x_n. B and apply cv there. Do not eta-expand: a
   binder is entered only if the argument of Pure.all is literally an
   abstraction; any other term is handed to cv as it stands.
   NOTE(review): the second case matches the bare constant Pure.all (not an
   application), on which raw_arg_pconv raises TERM; an application pattern
   "Const (...) $ _" may have been intended -- confirm before changing. *)
fun params_pconv cv ctxt tytenv ct =
  let
    val pconv =
      case Thm.term_of ct of
        Const (@{const_name Pure.all}, _) $ Abs _ =>
          (raw_arg_pconv o raw_abs_pconv) (fn _ => params_pconv cv)
      | Const (@{const_name Pure.all}, _) => raw_arg_pconv (params_pconv cv)
      | _ => cv
  in pconv ctxt tytenv ct end

(* Enter the binder of a Pure.all application !!x. P and apply cv to the body
   (eta-expanding P through abs_pconv if it is not an abstraction).
   ident = (name option, typ option): the name, if any, records the bound
   variable as an identifier; a missing type defaults to the binder's domain
   type, read off the type of the Pure.all constant. Raises TERM otherwise. *)
fun forall_pconv cv ident ctxt tytenv ct =
  case Thm.term_of ct of
    Const (@{const_name Pure.all}, T) $ _ =>
      let
        (* T has the shape (U => _) => _; extract the binder type U. *)
        val def_U = T |> dest_funT |> fst |> dest_funT |> fst
        val ident' = apsnd (the_default def_U) ident
      in arg_pconv (abs_pconv cv ident') ctxt tytenv ct end
  | t => raise TERM ("forall_pconv", [t])

(* The trivial pattern conversion: ignores context and environment and
   proves ct == ct by reflexivity. *)
fun all_pconv _ _ ct = Thm.reflexive ct

(* Enter the outermost chain of !!-binders of ct and apply cv to the body.
   The identifiers in idents name the innermost (length idents) binders, in
   order; any further outer binders are entered anonymously with their own
   types. Each binder is entered via forall_pconv. Raises CTERM if ct has
   fewer binders than there are identifiers. *)
fun for_pconv cv idents ctxt tytenv ct =
  let
    (* Returns the identifiers not yet consumed (reversed) and the pattern
       conversion entering all binders below this point. Identifiers are
       consumed from the innermost binder outwards. *)
    fun enter_binders rev_idents (Const (@{const_name Pure.all}, _) $ t) =
        let
          val (rev_idents', cv') =
            enter_binders rev_idents (case t of Abs (_, _, u) => u | _ => t)
        in
          case rev_idents' of
            [] => ([], forall_pconv cv' (NONE, NONE))
          | (x :: xs) => (xs, forall_pconv cv' x)
        end
      | enter_binders rev_idents _ = (rev_idents, cv)
  in
    case enter_binders (rev idents) (Thm.term_of ct) of
      ([], cv') => cv' ctxt tytenv ct
    | _ => raise CTERM ("for_pconv", [ct])
  end

(* Apply cv to the conclusion B of A_1 ==> ... ==> A_n ==> B, stepping over
   each premise with imp_pconv. A term that is not an implication is itself
   the conclusion. *)
fun concl_pconv cv ctxt tytenv ct =
  case Thm.term_of ct of
    (Const (@{const_name Pure.imp}, _) $ _) $ _ => imp_pconv (concl_pconv cv) ctxt tytenv ct
  | _ => cv ctxt tytenv ct

(* On an implication A ==> B, apply cv through CConv.with_prems_cconv ~1
   (presumably rewriting the premise A -- confirm against CConv).
   Raises TERM on any other term. *)
fun asm_pconv cv ctxt tytenv ct =
  case Thm.term_of ct of
    (Const (@{const_name Pure.imp}, _) $ _) $ _ => CConv.with_prems_cconv ~1 (cv ctxt tytenv) ct
  | t => raise TERM ("asm_pconv", [t])

(* Try cv on the premises of A_1 ==> ... ==> A_n ==> B from left to right.
   It tries cv through CConv.with_prems_cconv ~1 on the current implication;
   if that fails, it steps past the premise with imp_pconv and retries.
   Raises TERM once no implication is left. *)
fun asms_pconv cv ctxt tytenv ct =
  case Thm.term_of ct of
    (Const (@{const_name Pure.imp}, _) $ _) $ _ =>
      ((CConv.with_prems_cconv ~1 oo cv) else_pconv imp_pconv (asms_pconv cv)) ctxt tytenv ct
  | t => raise TERM ("asms_pconv", [t])

(* If ct is an object-logic judgment (per Object_Logic.is_judgment), apply
   cv to its argument; otherwise apply cv to ct itself. *)
fun judgment_pconv cv ctxt tytenv ct =
  let
    val target =
      if Object_Logic.is_judgment ctxt (Thm.term_of ct) then arg_pconv cv else cv
  in target ctxt tytenv ct end

(* Depth-first search for a position where cv succeeds. It tries cv on ct
   itself, then recursively the function part and the argument of an
   application, then the body of an abstraction. Binders are entered
   without eta-expansion. *)
fun in_pconv cv ctxt tytenv ct =
  let
    val recurse = in_pconv cv
    val search =
      ((cv else_pconv raw_fun_pconv recurse)
        else_pconv raw_arg_pconv recurse)
        else_pconv raw_abs_pconv (K recurse)
  in search ctxt tytenv ct end

(* Replace every free variable of t whose name appears in idents by the term
   associated with that name (the first entry wins); everything else is
   left unchanged. *)
fun replace_idents idents t =
  let
    fun lookup _ [] = NONE
      | lookup name ((n, s) :: rest) = if n = name then SOME s else lookup name rest
    fun subst (free as Free (name, _)) = the_default free (lookup name idents)
      | subst a = a
  in Term.map_aterms subst t end

(* Match the pattern (t, fixes) against ct and continue with cv at the hole.
   Identifiers in t are first replaced by the terms they stand for (env_ts).
   Both sides are wrapped with Logic.mk_term and matched starting from tyenv;
   only the resulting type environment is kept.
   fixes supplies (name option, type) for the abstractions on the path to the
   hole, presumably outermost first as built by the method parser -- confirm.
   They are consumed from the end: the innermost abstraction takes the last
   entry.
   If t contains no hole, cv is applied to the whole of ct.
   Raises TERM if the pattern does not match. *)
fun match_pconv cv (t,fixes) ctxt (tyenv, env_ts) ct =
  let
    val t' = replace_idents env_ts t
    val thy = Proof_Context.theory_of ctxt
    val u = Thm.term_of ct

    (* Build the pattern conversion leading from the root of the pattern to
       the first hole found (function part searched before argument),
       together with the unused fixes. NONE if there is no hole. *)
    fun descend_hole fixes (Abs (_, _, t)) =
        (case descend_hole fixes t of
          NONE => NONE
        | SOME (fix :: fixes', pos) => SOME (fixes', abs_pconv pos fix)
        | SOME ([], _) => raise Match (* less fixes than abstractions on path to hole *))
      | descend_hole fixes (t as l $ r) =
        let val (f, _) = strip_comb t
        in
          (* A hole applied to arguments (as hole_syntax produces) is a hole. *)
          if is_hole f
          then SOME (fixes, cv)
          else
            (case descend_hole fixes l of
              SOME (fixes', pos) => SOME (fixes', fun_pconv pos)
            | NONE =>
              (case descend_hole fixes r of
                SOME (fixes', pos) => SOME (fixes', arg_pconv pos)
              | NONE => NONE))
        end
      | descend_hole fixes t =
        if is_hole t then SOME (fixes, cv) else NONE

    (* No hole: fall back to applying cv at the root. *)
    val to_hole = descend_hole (rev fixes) #> the_default ([], cv) #> snd
  in
    case try (Pattern.match thy (apply2 Logic.mk_term (t',u))) (tyenv, Vartab.empty) of
      NONE => raise TERM ("match_pconv: Does not match pattern", [t, t',u])
    | SOME (tyenv', _) => to_hole t ctxt (tyenv', env_ts) ct
  end

(* The innermost pattern conversion: rewrite ct with the rules thms via
   CConv.rewrs_cconv.
   to: optional target term. If given, identifiers in it are replaced by the
   terms they stand for (env_ts). Each rule is then instantiated so that its
   right-hand side unifies with the target, starting from the type
   environment tyenv. Rules whose right-hand side does not unify are
   dropped. *)
fun rewrs_pconv to thms ctxt (tyenv, env_ts) =
  let
    (* Instantiate all schematic type and term variables of thm according to
       env, then normalize. *)
    fun instantiate_normalize_env env thm =
      let
        val prop = Thm.prop_of thm
        val norm_type = Envir.norm_type o Envir.type_env
        val insts = Term.add_vars prop []
          |> map (fn x as (s, T) =>
              ((s, norm_type env T), Thm.cterm_of ctxt (Envir.norm_term env (Var x))))
        val tyinsts = Term.add_tvars prop []
          |> map (fn x => (x, Thm.ctyp_of ctxt (norm_type env (TVar x))))
      in Drule.instantiate_normalize (TVars.make tyinsts, Vars.make insts) thm end
    
    (* Unify the target with the right-hand side of the rule's conclusion,
       which must be a meta-equation; raise NO_TO_MATCH if they do not unify. *)
    fun unify_with_rhs to env thm =
      let
        val (_, rhs) = thm |> Thm.concl_of |> Logic.dest_equals
        val env' = Pattern.unify (Context.Proof ctxt) (Logic.mk_term to, Logic.mk_term rhs) env
          handle Pattern.Unif => raise NO_TO_MATCH
      in env' end
    
    (* Without a target the rule is used unchanged. *)
    fun inst_thm_to (NONE, _) thm = thm
      | inst_thm_to (SOME to, env) thm =
          instantiate_normalize_env (unify_with_rhs to env thm) thm
    
    (* Prepare one rule: shift its variable indexes above those occurring in
       the type environment and in the target, so they cannot clash, then
       instantiate it against the target. NONE if the rule cannot produce
       the target. *)
    fun inst_thm idents (to, tyenv) thm =
      let
        val maxidx = Term.maxidx_typs (map (snd o snd) (Vartab.dest tyenv)) 0
        val env = Envir.Envir {maxidx = maxidx, tenv = Vartab.empty, tyenv = tyenv}
        val maxidx = Envir.maxidx_of env |> fold Term.maxidx_term (the_list to)
        val thm' = Thm.incr_indexes (maxidx + 1) thm
        (* Replace any identifiers with their corresponding bound variables. *)
      in SOME (inst_thm_to (Option.map (replace_idents idents) to, env) thm') end
      handle NO_TO_MATCH => NONE
    
  in CConv.rewrs_cconv (map_filter (inst_thm env_ts (to, tyenv)) thms) end

(* Rewrite ct with thms (turned into meta-equations) at the position selected
   by pattern. The first pattern element acts outermost. to is the optional
   target of the rewrite. Duplicate premises of the result are removed when
   distinct_subgoals_tac yields a result. *)
fun rewrite_conv ctxt (pattern, to) thms ct =
  let
    (* Translate one pattern element into a combinator on pattern conversions. *)
    fun combinator_of At = judgment_pconv
      | combinator_of In = in_pconv
      | combinator_of Asm = params_pconv o asms_pconv
      | combinator_of Concl = params_pconv o concl_pconv
      | combinator_of (For idents) = (fn cv => for_pconv cv (map (apfst SOME) idents))
      | combinator_of (Term x) = (fn cv => match_pconv cv (apsnd (map (apfst SOME)) x))

    val rules = maps (prep_meta_eq ctxt) thms
    val conv = fold_rev combinator_of pattern (rewrs_pconv to rules)

    fun drop_duplicate_prems th =
      the_default th (Option.map fst (Seq.pull (distinct_subgoals_tac th)))
  in drop_duplicate_prems (conv ctxt (Vartab.empty, []) ct) end

(* Build the CCONVERSION tactic that runs rewrite_conv. When the pattern was
   prepared in its own context, the resulting theorem is exported from that
   context into ctxt. *)
fun rewrite_export_tac ctxt (pat, pat_ctxt) thms =
  let
    val export =
      (case pat_ctxt of
        SOME ctxt' => singleton (Proof_Context.export ctxt' ctxt)
      | NONE => I)
    val conv = rewrite_conv ctxt pat thms
  in CCONVERSION (export o conv) end

(* Register the proof method "rewrite": parse the pattern, the optional "to"
   term and the rules, prepare them in the proof context, and run
   rewrite_export_tac on the chosen subgoal. *)
val _ =
  Theory.setup
  let
    (* Fix declaration (name only, no type, no mixfix) for the bound-variable
       name of an abstraction in a pattern. *)
    fun mk_fix s = (Binding.name s, NONE, NoSyn)

    (* Parser for the pattern part. Selector/atom pairs ("at"/"in" followed
       by an atom) are collected and the list is reversed; append_default
       then inserts the implicit Concl/In elements for the listed shapes. *)
    val raw_pattern : (string, binding * string option * mixfix) pattern list parser =
      let
        val sep = (Args.$$$ "at" >> K At) || (Args.$$$ "in" >> K In)
        val atom =  (Args.$$$ "asm" >> K Asm) ||
          (Args.$$$ "concl" >> K Concl) ||
          (Args.$$$ "for" |-- Args.parens (Scan.optional Parse.vars []) >> For) ||
          (Parse.term >> Term)
        val sep_atom = sep -- atom >> (fn (s,a) => [s,a])

        fun append_default [] = [Concl, In]
          | append_default (ps as Term _ :: _) = Concl :: In :: ps
          | append_default [For x, In] = [For x, Concl, In]
          | append_default (For x :: (ps as In :: Term _:: _)) = For x :: Concl :: ps
          | append_default ps = ps

      in Scan.repeats sep_atom >> (rev #> append_default) end

    (* Run a token parser, then post-process its result with f in the proof
       context of the generic context, threading the updated context. *)
    fun context_lift (scan : 'a parser) f = fn (context : Context.generic, toks) =>
      let
        val (r, toks') = scan toks
        val (r', context') = Context.map_proof_result (fn ctxt => f ctxt r) context
      in (r', (context', toks')) end

    (* Add fixes for the variables of "for (...)", reading their optional
       types in ctxt. *)
    fun read_fixes fixes ctxt =
      let fun read_typ (b, rawT, mx) = (b, Option.map (Syntax.read_typ ctxt) rawT, mx)
      in Proof_Context.add_fixes (map read_typ fixes) ctxt end

    (* Parse the pattern terms. For each abstraction on the path to a hole
       constant, fix a variable for its bound name and constrain the
       abstraction to a fresh type parameter for its argument; the resulting
       (fixed name, type) pairs are returned with the term. *)
    fun prep_pats ctxt (ps : (string, binding * string option * mixfix) pattern list) =
      let
        fun add_constrs ctxt n (Abs (x, T, t)) =
            let
              val (x', ctxt') = yield_singleton Proof_Context.add_fixes (mk_fix x) ctxt
            in
              (case add_constrs ctxt' (n+1) t of
                NONE => NONE
              | SOME ((ctxt'', n', xs), t') =>
                  let
                    val U = Type_Infer.mk_param n []
                    val u = Type.constraint (U --> dummyT) (Abs (x, T, t'))
                  in SOME ((ctxt'', n', (x', U) :: xs), u) end)
            end
          | add_constrs ctxt n (l $ r) =
            (case add_constrs ctxt n l of
              SOME (c, l') => SOME (c, l' $ r)
            | NONE =>
              (case add_constrs ctxt n r of
                SOME (c, r') => SOME (c, l $ r')
              | NONE => NONE))
          | add_constrs ctxt n t =
            if is_hole_const t then SOME ((ctxt, n, []), t) else NONE

        fun prep (Term s) (n, ctxt) =
            let
              val t = Syntax.parse_term ctxt s
              val ((ctxt', n', bs), t') =
                the_default ((ctxt, n, []), t) (add_constrs ctxt (n+1) t)
            in (Term (t', bs), (n', ctxt')) end
          | prep (For ss) (n, ctxt) =
            let val (ns, ctxt') = read_fixes ss ctxt
            in (For ns, (n, ctxt')) end
          | prep At (n,ctxt) = (At, (n, ctxt))
          | prep In (n,ctxt) = (In, (n, ctxt))
          | prep Concl (n,ctxt) = (Concl, (n, ctxt))
          | prep Asm (n,ctxt) = (Asm, (n, ctxt))

        val (xs, (_, ctxt')) = fold_map prep ps (0, ctxt)

      in (xs, ctxt') end

    (* Prepare the parsed arguments: patterns, the rules, and the optional
       "to" term. Returns the method arguments together with the pattern
       context they were prepared in. *)
    fun prep_args ctxt (((raw_pats, raw_to), raw_ths)) =
      let

        (* Type-check all pattern terms, the free-variable constraints for
           their fixes and the "to" term in one go (with hole_syntax), then
           redistribute the checked terms back into the pattern list. *)
        fun check_terms ctxt ps to =
          let
            fun safe_chop (0: int) xs = ([], xs)
              | safe_chop n (x :: xs) = chop (n - 1) xs |>> cons x
              | safe_chop _ _ = raise Match

            fun reinsert_pat _ (Term (_, cs)) (t :: ts) =
                let val (cs', ts') = safe_chop (length cs) ts
                in (Term (t, map dest_Free cs'), ts') end
              | reinsert_pat _ (Term _) [] = raise Match
              | reinsert_pat ctxt (For ss) ts =
                let val fixes = map (fn s => (s, Variable.default_type ctxt s)) ss
                in (For fixes, ts) end
              | reinsert_pat _ At ts = (At, ts)
              | reinsert_pat _ In ts = (In, ts)
              | reinsert_pat _ Concl ts = (Concl, ts)
              | reinsert_pat _ Asm ts = (Asm, ts)

            fun free_constr (s,T) = Type.constraint T (Free (s, dummyT))
            fun mk_free_constrs (Term (t, cs)) = t :: map free_constr cs
              | mk_free_constrs _ = []

            val ts = maps mk_free_constrs ps @ the_list to
              |> Syntax.check_terms (hole_syntax ctxt)
            val ctxt' = fold Variable.declare_term ts ctxt
            val (ps', (to', ts')) = fold_map (reinsert_pat ctxt') ps ts
              ||> (fn xs => case to of NONE => (NONE, xs) | SOME _ => (SOME (hd xs), tl xs))
            (* Every checked term must have been consumed. *)
            val _ = case ts' of (_ :: _) => raise Match | [] => ()
          in ((ps', to'), ctxt') end

        val (pats, ctxt') = prep_pats ctxt raw_pats

        val ths = Attrib.eval_thms ctxt' raw_ths
        val to = Option.map (Syntax.parse_term ctxt') raw_to

        val ((pats', to'), ctxt'') = check_terms ctxt' pats to

      in ((pats', ths, (to', ctxt)), ctxt'') end

    val to_parser = Scan.option ((Args.$$$ "to") |-- Parse.term)

    val subst_parser =
      let val scan = raw_pattern -- to_parser -- Parse.thms1
      in context_lift scan prep_args end
  in
    (* The binding must be a literal antiquotation; a plain identifier here is
       unbound. *)
    Method.setup @{binding rewrite} (subst_parser >>
      (fn (pattern, inthms, (to, pat_ctxt)) => fn orig_ctxt =>
        SIMPLE_METHOD' (rewrite_export_tac orig_ctxt ((pattern, to), SOME pat_ctxt) inthms)))
      "single-step rewriting, allowing subterm selection via patterns"
  end
end