Theory AutoCorres2.ML_Record_Antiquotation
section "Record Antiquotation"
theory ML_Record_Antiquotation
imports Main
begin
subsection ‹Motivation›
text ‹
A shortcoming of ML records is the lack of proper update / map functions for the record fields.
Manually defining those update functions is considered as painful, as it requires 'quadratic'
editing when adding a new field:
- add the new field to every record pattern (in every already defined update / map function)
- add a new update / map function for the field.
Various workarounds have been proposed. E.g.
- Using mutable references as fields to allow for destructive updates
- Fancy higher order function-combinators (Fold): 🌐‹http://mlton.org/FunctionalRecordUpdate›
Here we develop yet another solution. We provide a ML-Antiquotation that generates the definitions.
For a record specification as datatype ▩‹datatype 'a foo = Foo {f1:'a, f2:bool}›
it generates the datatype as specified with one constructor
▩‹Foo› wrapping the record and additionally all the 'canonical' functions:
▪ ▩‹make_foo›
▪ ▩‹dest_foo›
▪ ▩‹get_f1›
▪ ▩‹get_f2›
▪ ▩‹map_f1›
▪ ▩‹map_f2›
›
subsection ‹Example›
ML_val ‹
datatype 'a foo = Foo of {f1:'a, f2:bool};
val make_foo = Foo;
fun dest_foo (Foo r) = r;
fun get_f1 (Foo r) = #f1 r;
fun get_f2 (Foo r) = #f2 r;
fun map_f1 f (Foo {f1, f2}) =
Foo {f1 = f f1, f2 = f2};
fun map_f2 f (Foo {f1, f2}) =
Foo {f1 = f1, f2 = f f2};
›
subsection ‹Implementation›
ML ‹
structure ML_Record_Antiquotation =
struct
fun remove_comments s =
let
val unbalanced_comments_msg = "unbalanced ML comments (* ... *)"
fun err_unbalanced () = error unbalanced_comments_msg
fun rem n [] = if n = 0 then [] else err_unbalanced ()
| rem n ("("::"*"::xs) = rem (n + 1) xs
| rem n ("*"::")"::xs) = if n > 0 then rem (n - 1) xs else err_unbalanced ()
| rem n (x::xs) = if n = 0 then x::rem n xs else rem n xs
in
s |> Symbol.explode |> rem 0 |> implode
end
fun to_upper_first s =
case String.explode s of
[] => ""
| c::cs => String.implode ((Char.toUpper c)::cs)
fun split_first [] = raise List.Empty
| split_first (x :: xs) = (x, xs);
fun parse_record_declaration s =
let
fun is_whitespace c = member (op =) [" ", "\t", "\n"] (str c)
fun white_space_explode "" = []
| white_space_explode s = String.fields (is_whitespace) s |> filter_out (fn x => x = "")
fun remove p s
= String.explode s |> filter_out p |> String.implode
val normalize_spaces = space_implode " " o white_space_explode;
val remove_delimiter = remove (fn c => member (op =) ["{","(", "}",")",";"] (str c))
val trim = remove (fn c => str c = " ")
fun strip_word word s =
let
val (first, rest) = s |> white_space_explode |> split_first
in
if first = word then space_implode " " rest else error ("expecting '" ^ word ^ "'")
end
val [lhs, rhs] = s |> remove_comments |> normalize_spaces |> space_explode "=" |> map normalize_spaces
fun split_first_word s = s |> white_space_explode |> split_first ||> space_implode " "
fun split_last_word s = s |> white_space_explode |> split_last |>> space_implode " "
fun split_type s =
let
val (params, type_name) = s
|> strip_word "datatype"
|> split_last_word
|>> remove_delimiter
|>> space_explode ","
in (params, type_name) end
val (params, rec_name) = split_type lhs
fun is_type_parameter s = String.isPrefix "'" s orelse String.isPrefix "''" s
fun get_params s = s
|> white_space_explode
|> map remove_delimiter
|> map (space_explode ",")
|> flat
|> filter is_type_parameter
val (constr, field_names_params) =
rhs
|> split_first_word
||> strip_word "of"
||> remove_delimiter
||> space_explode ","
||> map (space_explode ":")
||> map (fn [name, typ] => (trim name, get_params typ))
in ((lhs, rhs), (params, rec_name), constr, field_names_params) end
fun mk_record fields = enclose "{" "}" (commas fields);
fun make_name record = "make_" ^ record;
fun dest_name record = "dest_" ^ record;
fun map_name field = "map_" ^ field;
fun get_name field = "get_" ^ field;
fun mk_record_ass fields values =
fields ~~ values
|> map (fn (f,v) => f ^ " = " ^ v)
|> commas
|> enclose "{" "}"
fun replace p q = map (fn x => if x = p then q else x)
fun mk_map constr fields n =
let
val (field, params) = nth fields n
val field_names = map fst fields
val [f] = Name.variant_list field_names ["f"]
in
"fun " ^ map_name field ^
" f (" ^ constr ^ " " ^ mk_record field_names ^ ") " ^ " = \n" ^
" " ^ constr ^ " " ^ mk_record_ass field_names (replace field (f ^ " " ^ field) field_names) ^ ";"
end
fun mk_maps constr fields =
0 upto (length fields) - 1
|> map (mk_map constr fields)
|> cat_lines
fun mk_get constr field_name =
"fun " ^ get_name field_name ^ " (" ^ constr ^ " r) = #" ^ field_name ^ " r;";
fun mk_gets constr fields =
fields |> map fst |> map (mk_get constr) |> cat_lines
fun mk_record_datatype_declaration add_decl s =
let
val ((lhs, rhs), record as (params, rec_name), constr, fields) = parse_record_declaration s;
val decl = lhs ^ " = " ^ rhs ^ ";";
val make = "val " ^ make_name rec_name ^ " = " ^ constr ^ ";";
val dest = "fun " ^ dest_name rec_name ^ " (" ^ constr ^ " r) = r;";
val gets = mk_gets constr fields;
val maps = mk_maps constr fields;
fun cond_cons P a xs = if P then a::xs else xs
in
cat_lines (cond_cons add_decl decl [make, dest, gets, maps])
end
end
›
text ‹Test the parser and its output.›
ML_val ‹
let
val test = "datatype 'a foo = Foo of {fld1:int, fld2:'a}"
in writeln (ML_Record_Antiquotation.mk_record_datatype_declaration true test) end
›
setup ‹
ML_Antiquotation.inline_embedded (Binding.make ("record",⌂))
(Scan.lift Parse.embedded_input >> (fn source =>
let
val _ = ML_Lex.read_source source;
val sanitized_str = fst (Input.source_content source) |> ML_Lex.tokenize |> ML_Lex.flatten
in ML_Record_Antiquotation.mk_record_datatype_declaration true sanitized_str end))
›
subsection ‹Examples›
ML_val ‹
structure JustAStruct = struct
datatype bar = Bar of int;
@{record ‹datatype 'a foo =
Foo of {f1:'a ,
f2:bar }
›}
val my_foo = Foo {f1 = 1, f2 = Bar 42};
val my_foo1 = map_f1 (K (Bar 44)) my_foo;
val t = get_f1 my_foo1
val t1 = make_foo {f1=3, f2=Bar 4}
val t2 = dest_foo t1
@{record ‹datatype foo2 = Foo2 of {f3:int foo}›}
end
›
text ‹Note that the markup of the datatype declaration is limited to basic Lex-markup.
Experiments with explicitly evaluating the datatype with e.g. the @{ML "ML"} function failed.
The issue there is that the @{ML "ML"} function is not evaluated at the point (aka context)
exactly in the text position, but with the context of the surrounding ML command.
The following examples try to illustrate the issues.
›
text ‹Works as expected, as every type referred to in the datatype spec is already known
at the beginnig of the @{command ML_val} command.›
ML_val ‹
val _ = ML ‹datatype foo = Foo of int›;
›
text ‹The following fails as 'bar' is not known at the beginning of the context.›
ML_val ‹
datatype bar = Bar of int
val _ = ML ‹datatype foo = Foo of bar›
handle ERROR x => warning x
›
text ‹Adding a semicolon after the definition of bar helps.
There the semicolon marks the end of an "evaluation / compilation chunk". So @{ML "ML"} is
evaluated within an already augmented ML-context that knows about ‹bar›.
›
ML_val ‹
datatype bar = Bar of int;
val _ = ML ‹datatype foo = Foo of bar›
›
text ‹In the context of a structure the semicolon is not enough. ›
ML_val ‹
structure InAStruct =
struct
datatype bar = Bar of int;
val _ = ML ‹datatype foo = Foo of bar›
handle ERROR x => warning x
end
›
end