File ‹compare_code.ML›

signature COMPARE_CODE =
sig

  (* Changes the code equations of some constant such that
     two consecutive comparisons via <=, <, or = are turned into one
     invocation of the comparator.
     The difference to a standard code_unfold is that here the code equations themselves
     are changed, so that an additional sort constraint compare_order can be added to
     selected type variables. Without that constraint there would be no compare function. *)
  val change_compare_code :
    term                      (* the constant whose code equations are changed *)
    -> string list     (* the list of type parameters which should be constrained to @{sort compare_order} *)
    -> local_theory -> local_theory

end

structure Compare_Code : COMPARE_CODE =
struct

(* Removes a single leading "?" from a name; names without that prefix are returned as is. *)
fun drop_leading_qmark name =
  (case try (unprefix "?") name of
    SOME rest => rest
  | NONE => name)

(* Rewrites the code equations of the constant const:
     1. fetch the current code equations of const,
     2. for every schematic type variable of the head constant whose base name occurs in
        inst_names (after dropping a leading "?"), remove the classes ord, linorder and order
        from its sort and append the sort compare_order,
     3. unfold two_comparisons_into_compare in the instantiated equations,
     4. note the result as fact <base name of const>_compare_code and declare the noted
        theorems as default code equations.
   Raises an error if const is not a constant or if no code equations are found.
   Emits a warning if step 3 leaves the propositions of all equations unchanged. *)
fun change_compare_code const inst_names lthy =
  let
    val tnames = map drop_leading_qmark inst_names
    val shown_const = Syntax.pretty_term lthy const |> Pretty.string_of |> quote
    val cname =
      (case const of
        Const (c, _) => c
      | _ => error ("expected constant as input, but got " ^ shown_const))

    (* collect the theorems of those code equations that carry one *)
    val cert = Code.get_cert lthy [] cname
    val (_, raw_eqs) = Code.equations_of_cert (Proof_Context.theory_of lthy) cert
    val old_eqs = map_filter (fst o snd) (these raw_eqs)
    val _ = not (null old_eqs) orelse error "could not find code equations"

    (* adding sort-constraint compare_order within code equations:
       the type variables are taken from the head of the left-hand side of the first equation *)
    val (lhs, _) = Logic.dest_equals (Thm.concl_of (hd old_eqs))
    val head = Term.head_of lhs
    val tvars = Term.add_tvars head []
    val dropped_classes = @{sort ord} @ @{sort linorder} @ @{sort order}
    fun constrain (tvar as ((a, _), _)) =
      if member (op =) tnames a then
        let
          val (xi, srt) = tvar
          val srt' = filter_out (member (op =) dropped_classes) srt @ @{sort compare_order}
        in SOME (tvar, Thm.ctyp_of lthy (TVar (xi, srt'))) end
      else NONE
    val type_inst = TVars.make (map_filter constrain tvars)
    val constrained_eqs = map (Thm.instantiate (type_inst, Vars.empty)) old_eqs

    (* replace comparisons *)
    val unfold_compare = Local_Defs.unfold lthy @{thms two_comparisons_into_compare}
    val rewritten_eqs = map unfold_compare constrained_eqs
    val _ =
      if map Thm.prop_of rewritten_eqs <> map Thm.prop_of constrained_eqs then ()
      else
        warning ("Code equations for " ^ shown_const ^ " did not change\n" ^
          "Perhaps you have to provide some type variables which should be restricted to compare_order\n" ^
          (@{make_string} (map (fn v => (TVar v, snd v)) tvars, head, constrained_eqs)))

    (* register code eqns: note them as a named fact, then declare the noted theorems *)
    val fact_name = Binding.name (Long_Name.base_name cname ^ "_compare_code")
    val ((_, noted_eqs), lthy') = Local_Theory.note ((fact_name, []), rewritten_eqs) lthy
  in
    Code.declare_default_eqns (map (fn eq => (eq, true)) noted_eqs) lthy'
  end

(* Command-level variant: reads the constant from its source text in the given context. *)
fun change_compare_code_cmd raw_const tnames lthy =
  lthy |> change_compare_code (Syntax.read_term lthy raw_const) tnames

(* optional parenthesized, comma-separated list of type variable names; defaults to [] *)
val opt_tnames =
  Scan.optional (@{keyword "("} |-- Parse.list Parse.string --| @{keyword ")"}) []

(* registers the command compare_code: optional type variable list followed by a term *)
val _ =
  Outer_Syntax.local_theory @{command_keyword compare_code}
    "turn comparisons via <= and < into compare within code-equations"
    ((opt_tnames -- Parse.term) >>
      (fn (tnames, raw_const) => change_compare_code_cmd raw_const tnames))

end