(* File: lazy_eval.ML *)

(* Interface of the lazy evaluator: argument patterns, function equations, evaluation
   contexts, and head evaluation of terms (whnf, presumably "weak head normal form")
   that returns a conversion alongside every evaluated term. *)
signature LAZY_EVAL = sig

(* Argument patterns of function equations.
   AnyPat x matches any term and records it under the variable name x.
   ConsPat (c, ps) matches a term that evaluates to the constant c applied to arguments
   matching the sub-patterns ps. *)
datatype pat = AnyPat of indexname | ConsPat of (string * pat list)

(* A constructor is given by its constant name and its arity. *)
type constructor = string * int

(* A preprocessed function equation:
     function - head term of the left-hand side
     thm      - the equation as a theorem
     rhs      - right-hand side term
     pats     - patterns for the arguments on the left-hand side *)
type equation = {
    function : term, 
    thm : thm, 
    rhs : term, 
    pats : pat list
  }

(* The hook-independent part of an evaluation context: known equations and constructors,
   the underlying proof context, a net of facts, and a verbosity flag. *)
type eval_ctxt' = {
    equations : equation list, 
    constructors : constructor list,
    pctxt : Proof.context, 
    facts : thm Net.net,
    verbose : bool
  }

(* A hook receives the inner context and a term; returning SOME (t', conv) replaces the
   term by t' (with conv as the accompanying conversion), NONE declines. *)
type eval_hook = eval_ctxt' -> term -> (term * conv) option

(* Full evaluation context: inner context plus the registered hooks. *)
type eval_ctxt = {
    ctxt : eval_ctxt', 
    hooks : eval_hook list
  }

(* Queries on a constructor table: membership of a name, and the arity stored for it. *)
val is_constructor_name : constructor list -> string -> bool
val constructor_arity : constructor list -> string -> int option

(* Construction of evaluation contexts and access to their components. *)
val mk_eval_ctxt : Proof.context -> constructor list -> thm list -> eval_ctxt
val add_facts : thm list -> eval_ctxt -> eval_ctxt
val get_facts : eval_ctxt -> thm list
val get_ctxt : eval_ctxt -> Proof.context
val add_hook : eval_hook -> eval_ctxt -> eval_ctxt
val get_verbose : eval_ctxt -> bool
val set_verbose : bool -> eval_ctxt -> eval_ctxt
val get_constructors : eval_ctxt -> constructor list
val set_constructors : constructor list -> eval_ctxt -> eval_ctxt

(* Evaluation. Each function returns the (possibly further evaluated) term(s) together
   with a conversion. match and match_all thread an optional binding table through:
   NONE as input or output means that matching has failed. *)
val whnf : eval_ctxt -> term -> term * conv
val match : eval_ctxt -> pat -> term -> 
  (indexname * term) list option -> (indexname * term) list option * term * conv
val match_all : eval_ctxt -> pat list -> term list -> 
  (indexname * term) list option -> (indexname * term) list option * term list * conv

end

structure Lazy_Eval : LAZY_EVAL = struct

(* Argument patterns: AnyPat x binds the matched term to x; ConsPat (c, ps) demands the
   constant c applied to arguments matching ps. *)
datatype pat = AnyPat of indexname | ConsPat of (string * pat list)

(* Constructor name paired with its arity. *)
type constructor = string * int

(* A preprocessed function equation: head term of the left-hand side, the equation
   theorem, its right-hand side, and the patterns of the left-hand side arguments. *)
type equation = {
    function : term, 
    thm : thm, 
    rhs : term, 
    pats : pat list
  }

(* Hook-independent part of the evaluation context; this is what hooks get to see. *)
type eval_ctxt' = {
    equations : equation list, 
    constructors : constructor list,
    pctxt : Proof.context, 
    facts : thm Net.net,
    verbose : bool
  }

(* A hook may replace a term: SOME (t', conv) to rewrite, NONE to decline. *)
type eval_hook = eval_ctxt' -> term -> (term * conv) option

(* Full evaluation context: inner context plus the list of registered hooks. *)
type eval_ctxt = {
    ctxt : eval_ctxt', 
    hooks : eval_hook list
  }

(* Register an evaluation hook. The hook is prepended to the existing hook list; the
   inner context is left untouched. *)
fun add_hook h (ectxt : eval_ctxt) : eval_ctxt =
  let
    val {hooks = old_hooks, ctxt = inner} = ectxt
  in
    {ctxt = inner, hooks = h :: old_hooks}
  end

(* Read the verbosity flag stored in the inner context.
   The argument is annotated with eval_ctxt, like the other accessors in this structure,
   so that the flexible record pattern ("...") is resolved right here instead of
   depending on type information from elsewhere. *)
fun get_verbose ({ctxt = {verbose, ...}, ...} : eval_ctxt) = verbose

(* Return a copy of the evaluation context whose verbosity flag is b; every other
   component (equations, constructors, proof context, facts, hooks) is carried over. *)
fun set_verbose b (ectxt : eval_ctxt) : eval_ctxt =
  let
    val {ctxt = {equations, constructors, pctxt, facts, verbose = _}, hooks} = ectxt
  in
    {hooks = hooks,
     ctxt = {equations = equations, constructors = constructors, pctxt = pctxt,
       facts = facts, verbose = b}}
  end

(* Read the constructor table stored in the inner context. *)
fun get_constructors (ectxt : eval_ctxt) = #constructors (#ctxt ectxt)

(* Return a copy of the evaluation context whose constructor table is cs; every other
   component (equations, proof context, facts, verbosity, hooks) is carried over.
   The stored equations are not re-analyzed against the new table. *)
fun set_constructors cs (ectxt : eval_ctxt) : eval_ctxt =
  let
    val {ctxt = {equations, pctxt, facts, verbose, constructors = _}, hooks} = ectxt
  in
    {hooks = hooks,
     ctxt = {equations = equations, constructors = cs, pctxt = pctxt,
       facts = facts, verbose = verbose}}
  end

(* Constructor name paired with its arity (restates the alias already declared at the
   top of this structure). *)
type constructor = string * int

(* Does some entry of the constructor table carry the given constant name? *)
fun is_constructor_name (constructors : constructor list) (name : string) : bool =
  List.exists (fn (c, _) => name = c) constructors

(* Arity stored for the given constant name (first matching entry), or NONE if the name
   does not occur in the table. *)
fun constructor_arity (constructors : constructor list) (name : string) : int option =
  AList.lookup (op =) constructors name

(* Translate a term into a pattern.
   A schematic variable becomes AnyPat (carrying the variable's indexname). A known
   constructor constant applied to exactly as many arguments as its arity becomes
   ConsPat, with the arguments translated recursively.
   Raises TERM for anything else: a head that is not a constant, a constant that is not
   in the constructor table, or a wrong number of arguments. *)
fun stream_pat_of_term constructors t =
  let
    fun invalid () = raise TERM ("Not a valid pattern.", [t])
  in
    case t of
      Var (xi, _) => AnyPat xi
    | _ =>
        (case strip_comb t of
           (Const (c, _), args) =>
             (case constructor_arity constructors c of
                SOME n =>
                  if length args <> n then invalid ()
                  else ConsPat (c, map (stream_pat_of_term constructors) args)
              | NONE => invalid ())
         | _ => invalid ())
  end

(* Preprocess a function equation given as a HOL theorem "f p1 ... pn = rhs".
   The conclusion is split into head f, argument patterns and right-hand side; the stored
   theorem is the equation resolved with eq_reflection.
   Raises THM if the conclusion is not a HOL equation or if some argument is not a valid
   pattern (any TERM exception raised during the analysis is converted). *)
fun analyze_eq constructors thm =
  let
    fun dissect () =
      let
        val (lhs, rhs) = HOLogic.dest_eq (HOLogic.dest_Trueprop (Thm.concl_of thm))
        val (head, arg_terms) = strip_comb lhs
      in
        (head, map (stream_pat_of_term constructors) arg_terms, rhs)
      end
    val (head, pats, rhs) =
      dissect () handle TERM _ => raise THM ("Not a valid function equation.", 0, [thm])
    val meta_eq = thm RS @{thm eq_reflection}
  in
    {function = head, thm = meta_eq, rhs = rhs, pats = pats} : equation
  end

(* Build a fresh evaluation context from a proof context, a constructor table and a list
   of function equations. The equations are analyzed against the constructor table; the
   new context starts with an empty fact net, no hooks, and verbosity switched off. *)
fun mk_eval_ctxt ctxt (constructors : constructor list) thms : eval_ctxt =
  let
    val eqs = map (analyze_eq constructors) thms
    val inner : eval_ctxt' =
      {equations = eqs, constructors = constructors, pctxt = ctxt,
       facts = Net.empty, verbose = false}
  in
    {ctxt = inner, hooks = []}
  end

(* Add theorems to the fact net of an evaluation context. Each theorem is indexed by its
   proposition, and two theorems count as equal if their propositions are equal. A
   theorem whose insertion raises Net.INSERT is skipped, leaving the net as it was. *)
fun add_facts new_thms ({ctxt = inner, hooks} : eval_ctxt) : eval_ctxt =
  let
    val {equations, constructors, pctxt, facts = net0, verbose} = inner
    fun same_prop (th1, th2) = Thm.prop_of th1 = Thm.prop_of th2
    fun insert th net =
      Net.insert_term same_prop (Thm.prop_of th, th) net
        handle Net.INSERT => net
    val net' = fold insert new_thms net0
  in
    {hooks = hooks,
     ctxt = {equations = equations, constructors = constructors, pctxt = pctxt,
       facts = net', verbose = verbose}}
  end

(* List all theorems currently stored in the fact net of the context. *)
fun get_facts (ectxt : eval_ctxt) = Net.content (#facts (#ctxt ectxt))

(* Extract the underlying proof context of an evaluation context. *)
fun get_ctxt (ectxt : eval_ctxt) : Proof.context = #pctxt (#ctxt ectxt)

(* Select, in their stored order, the equations whose head is a constant with the same
   name as the constant f. Types are ignored in the comparison. If f is not a constant,
   or an equation's head is not a constant, there is no match. *)
fun find_eqs (ectxt : eval_ctxt) f =
  let
    val {ctxt = {equations, ...}, ...} = ectxt
    fun heads_f ({function = Const (c', _), ...} : equation) =
          (case f of Const (c, _) => c = c' | _ => false)
      | heads_f _ = false
  in
    List.filter heads_f equations
  end

(* Two-way sum type, local to this structure (it is not part of the signature).
   Used by apply_eq inside whnf_aux2: Inr carries the result of a successful rewrite,
   Inl carries the (possibly evaluated) arguments when an equation did not apply. *)
datatype ('a, 'b) either = Inl of 'a | Inr of 'b

(* Head-evaluate a term. The term is beta-normalized and handed to whnf_aux1 for a single
   round of evaluation. If the outcome is alpha-equivalent to the original input,
   evaluation stops; otherwise the outcome is evaluated again and the conversions of the
   two rounds are chained. Returns the final term together with the chained conversion. *)
fun whnf (ctxt : eval_ctxt) t =
  let
    val (t1, conv1) = whnf_aux1 ctxt (Envir.beta_norm t)
  in
    if not (t aconv t1) then
      let
        val (t2, conv2) = whnf ctxt t1
      in
        (t2, conv1 then_conv conv2)
      end
    else
      (t1, conv1)
  end
  
(* One evaluation round with hooks: the hooks are tried in list order on the inner
   context and the term. The result of the first hook that answers is evaluated further
   with whnf, chaining the hook's conversion with that of the continued evaluation.
   If no hook answers, the equation-based step whnf_aux2 takes over. *)
and whnf_aux1 (ectxt as {hooks, ctxt = inner}) t =
  let
    fun try_hook hook = hook inner t
  in
    (case get_first try_hook hooks of
       SOME (t1, conv1) =>
         let
           val (t2, conv2) = whnf ectxt t1
         in
           (t2, conv1 then_conv conv2)
         end
     | NONE => whnf_aux2 ectxt t)
  end
(* One equation-based evaluation step.
   The term is split into head f and arguments. If f is not a constant, or is a
   constructor constant, the term is returned unchanged with the identity conversion.
   Otherwise the equations for f are tried in order; the first one that applies yields
   the instantiated, beta-normalized right-hand side. Trying an equation may evaluate
   some arguments (through match_all); those evaluated arguments and their conversion
   are carried over to the next equation. If no equation applies, the result is f
   applied to the possibly evaluated arguments. *)
and whnf_aux2 ctxt t =
  let
    val (f, args) = strip_comb t

    (* Try a single equation on the current arguments, where conv is the conversion
       accumulated so far. Inr: the equation fired. Inl: it did not; hand back the
       arguments and conversion as evaluated so far. *)
    fun apply_eq ({thm, pats, ...} : equation) conv args =
      let
        val (table, args', conv') = match_all ctxt pats args (SOME [])
        val arg_conv = conv then_conv conv'
      in
        case table of
          NONE => Inl (args', arg_conv)
        | SOME _ =>
            (* Only success/failure of the pattern match is used; the actual
               instantiation is recomputed by matching the theorem's left-hand side. *)
            (let
              val thy = Proof_Context.theory_of (get_ctxt ctxt)
              val t' = list_comb (f, args')
              val lhs = Thm.term_of (Thm.lhs_of thm)
              val env = Pattern.match thy (lhs, t') (Vartab.empty, Vartab.empty)
              val rhs = Thm.term_of (Thm.rhs_of thm)
              val rhs = Envir.subst_term env rhs |> Envir.beta_norm
            in
              Inr (rhs, arg_conv then_conv Conv.rewr_conv thm)
            end
              handle Pattern.MATCH => Inl (args', arg_conv))
      end

    (* Try the equations one after another, threading arguments and conversion. *)
    fun apply_eqs [] args conv = (list_comb (f, args), conv)
      | apply_eqs (eq :: eqs) args conv =
          (case apply_eq eq conv args of
             Inr res => res
           | Inl (args', conv') => apply_eqs eqs args' conv')

  in
    case f of
      Const (f', _) =>
        if is_constructor_name (get_constructors ctxt) f' then
          (t, Conv.all_conv)
        else
          apply_eqs (find_eqs ctxt f) args Conv.all_conv
    | _ => (t, Conv.all_conv)
  end
(* Match a list of argument terms against a list of patterns, left to right, threading
   the optional binding table through.
   Returns the resulting table (NONE on failure), the arguments as evaluated during
   matching, and a conversion for the whole application: it is assembled with
   Conv.combination_conv / Conv.fun_conv, one layer per argument.
   Once the table is NONE, the remaining arguments are passed through unevaluated.
   If the two lists differ in length, matching fails at once and the arguments are
   returned as given, with the identity conversion. *)
and match_all ctxt pats args table =
  let
    fun go [] [] seen conv tab = (tab, rev seen, conv)
      | go (_ :: ps) (a :: rest) seen conv NONE =
          go ps rest (a :: seen) (Conv.fun_conv conv) NONE
      | go (p :: ps) (a :: rest) seen conv (tab as SOME _) =
          let
            val (tab', a', conv_a) = match ctxt p a tab
          in
            go ps rest (a' :: seen) (Conv.combination_conv conv conv_a) tab'
          end
      | go _ _ _ _ _ = raise Match  (* not reached: the lengths are checked below *)
  in
    if length pats <> length args then
      (NONE, args, Conv.all_conv)
    else
      go pats args [] Conv.all_conv table
  end
(* Match a single term against a pattern.
   - With no table (an earlier failure), the term is returned untouched.
   - AnyPat v always succeeds: (v, t) is added to the table, t is not evaluated.
   - ConsPat (c, subpats): the term is head-evaluated with whnf first. If its head is
     the constant c, the arguments are matched against subpats and the term is rebuilt
     from the evaluated arguments; otherwise matching fails and the head-evaluated term
     is returned.
   The returned conversion covers exactly the evaluation performed here. *)
and match ctxt pat t table =
  (case (pat, table) of
     (_, NONE) => (NONE, t, Conv.all_conv)
   | (AnyPat v, SOME bindings) => (SOME ((v, t) :: bindings), t, Conv.all_conv)
   | (ConsPat (c, subpats), SOME bindings) =>
       let
         val (t_whnf, conv) = whnf ctxt t
         val (head, args) = strip_comb t_whnf
         val same_head = (case head of Const (c', _) => c = c' | _ => false)
       in
         if same_head then
           let
             val (bindings', args', conv') = match_all ctxt subpats args (SOME bindings)
           in
             (bindings', list_comb (head, args'), conv then_conv conv')
           end
         else
           (NONE, t_whnf, conv)
       end)

end