File ‹swf_util.ML›
(* Helper functions around rankings and profiles: list combinatorics, orders
   and inversion counts, an integer encoding of rankings, small
   simplifier-based proof utilities, and readers for profile / clause files. *)
signature SWF_UTIL =
sig
(* Raised by nat_to_ranking n k when k is not in the range 0 .. fact n - 1;
   carries the offending k. *)
exception RANKING_INDEX of int
(* Raised by read_profiles_gen (and hence read_profiles_file) on malformed input. *)
exception PARSE
(* A ranking is stored as a list of integers. *)
type ranking = int list
(* One parsed clause record: two lists of rankings plus a list of ranking pairs. *)
type sp_clauses = ranking list * ranking list * (ranking * ranking) list
(* A stateful source of lines: each call yields the next line, NONE when exhausted. *)
type line_iterator = unit -> string option
(* fold1 f (x :: xs) folds f over xs with x as the start value; raises Empty on []. *)
val fold1 : ('a -> 'a -> 'a) -> 'a list -> 'a
(* bind_list xs f applies f to every element of xs and concatenates the results. *)
val bind_list : 'a list -> ('a -> 'b list) -> 'b list
(* Same result shape as bind_list, but the mapping step uses Par_List.map. *)
val par_bind_list : 'a list -> ('a -> 'b list) -> 'b list
(* subtract1 eq xs ys applies remove1 eq x to ys for every x in xs, in order. *)
val subtract1 : ('a * 'b -> bool) -> 'a list -> 'b list -> 'b list
(* lists_of_length n xs: all lists of length n with entries drawn from xs
   (repetition allowed). *)
val lists_of_length : int -> 'a list -> 'a list list
(* multisets_of_length n xs: all length-n selections from xs with repetition whose
   entries follow the order of xs; for duplicate-free xs this gives one list per
   multiset of size n. *)
val multisets_of_length : int -> 'a list -> 'a list list
(* All lists obtained by picking the elements of xs one after another in every
   possible order. *)
val permutations : 'a list -> 'a list list
(* majority_ord ords (x, y) counts the orders that rate (x, y) as LESS or EQUAL
   against the remaining ones; the result is LESS when the first group is larger,
   GREATER when it is smaller, EQUAL on a tie. *)
val majority_ord : 'a ord list -> 'a ord
(* Factorial; returns 1 for every n <= 1. *)
val fact : int -> int
(* inversions ord xs: number of position pairs i < j where the later element is
   strictly less than the earlier one w.r.t. ord. *)
val inversions : 'a ord -> 'a list -> int
(* swap_dist xs ys: inversion count of the list of positions (in xs) of the
   elements of ys. *)
val swap_dist : ''a list -> ''a list -> int
(* Order induced by a ranking: of two different elements, the one found first in
   the list is the greater one.  Raises Match if neither element occurs. *)
val ord_of_ranking : ''a list -> ''a ord
(* Encodes a ranking as a single integer (mixed-radix code).
   NOTE(review): presumably the argument is a permutation of 0 .. n-1 - confirm. *)
val ranking_to_nat : ranking -> int
(* nat_to_ranking n k decodes index k into a list of length n; raises
   RANKING_INDEX k unless 0 <= k < fact n. *)
val nat_to_ranking : int -> int -> ranking
(* Unfolds distinct in the given theorem with auxiliary rewrite rules and splits
   the result using HOLogic.conj_elims. *)
val mk_distinct_thms : Proof.context -> thm -> thm list
(* mk_inversion_cache ctxt n: array with one entry per index i in 0 .. fact n - 1,
   holding the inversion count of nat_to_ranking n i and a theorem of the form
   inversion_number xs = res for that list. *)
val mk_inversion_cache : Proof.context -> int -> (int * thm) array
(* Normalizes both multiset terms and chains the two resulting equations by
   transitivity. *)
val prove_mset_eq : Proof.context -> cterm * cterm -> thm
(* Proves the given proposition by running the simplifier on all goals; raises
   CTERM when the tactic yields no result. *)
val simp_prove : Proof.context -> cterm -> thm
(* Discharges the first premise of a theorem with simp_prove. *)
val simp_discharge : Proof.context -> thm -> thm
(* Repeats simp_discharge until the theorem has no premises left. *)
val simp_discharge_all : Proof.context -> thm -> thm
(* Iterator over the given list of lines. *)
val list_iterator : string list -> line_iterator
(* Iterator over the lines of a file; the file is read via File.read_lines. *)
val file_iterator : Path.T -> line_iterator
(* Reads profiles: the first data line holds two integers m and n, every further
   data line holds exactly n indices that are decoded with nat_to_ranking m.
   Empty lines and lines starting with # are skipped. *)
val read_profiles_gen : line_iterator -> ranking list list
(* Reads profiles from a file; a true flag means the file is XZ-compressed. *)
val read_profiles_file : Path.T * bool -> ranking list list
(* Reads clause records from a file; a true flag means the file is XZ-compressed. *)
val read_sp_clauses_file : Path.T * bool -> sp_clauses list
(* Reads clause records: header line m n, then per data line n indices, n more
   indices and a sequence of index pairs, all decoded with nat_to_ranking m. *)
val read_sp_clauses_gen : line_iterator -> sp_clauses list
end
structure SWF_Util : SWF_UTIL =
struct
(* A ranking is stored as a list of integers. *)
type ranking = int list
(* One parsed clause record: two lists of rankings plus a list of ranking pairs. *)
type sp_clauses = ranking list * ranking list * (ranking * ranking) list
(* A stateful source of lines: each call yields the next line, NONE when exhausted. *)
type line_iterator = unit -> string option
(* Rewrite rules that unfold distinct on an explicit list: the empty list gives
   True, a cons cell gives a list_all of inequalities (stated in both
   orientations, x ~= y and y ~= x) together with distinct of the tail. *)
val distinct_unfold_thms =
[@{lemma "distinct [] ⟷ True" by simp},
@{lemma "distinct (x # xs) ⟷ list_all (λy. x ≠ y ∧ y ≠ x) xs ∧ distinct xs"
by (auto simp: list.pred_set)}]
(* Unfold distinct (and the resulting list_all / trivial connectives) in thm and
   split the outcome with HOLogic.conj_elims. *)
fun mk_distinct_thms ctxt thm =
  let
    val rewrite_rules = distinct_unfold_thms @ @{thms HOL.simp_thms list.pred_inject}
    val unfolded = Local_Defs.unfold ctxt rewrite_rules thm
  in
    HOLogic.conj_elims unfolded
  end
(* lists_of_length n xs: every list of length n with entries from xs, ordered
   lexicographically w.r.t. the order of xs (first position most significant). *)
fun lists_of_length n xs =
  if n = 0 then [[]]
  else if null xs then []
  else
    let
      (* all lists that are one element shorter; shared by every choice of head *)
      val shorter = lists_of_length (n - 1) xs
    in
      maps (fn x => map (fn rest => x :: rest) shorter) xs
    end
(* multisets_of_length n xs: all length-n selections with repetition whose
   entries respect the order of xs. *)
fun multisets_of_length n xs =
  let
    (* chosen: elements picked so far, most recent first;
       candidates: elements that may still be picked *)
    fun build chosen remaining candidates =
      if remaining = 0 then [rev chosen]
      else
        (case candidates of
          [] => []
        | c :: later =>
            (* either take c (it stays available) or drop it for good *)
            build (c :: chosen) (remaining - 1) candidates @ build chosen remaining later)
  in
    build [] n xs
  end
(* pick1 xs: for every position of xs, the element at that position paired with
   the list of all other elements (original order kept). *)
fun pick1 xs =
  let
    (* seen: elements already passed, most recent first *)
    fun walk seen remaining =
      (case remaining of
        [] => []
      | y :: later => (y, List.revAppend (seen, later)) :: walk (y :: seen) later)
  in
    walk [] xs
  end
(* permutations xs: all arrangements of the elements of xs.  The first element is
   chosen via pick1, the remainder is permuted recursively. *)
fun permutations [] = [[]]
  | permutations xs =
      maps (fn (x, others) => map (fn p => x :: p) (permutations others)) (pick1 xs)
(* count p xs: number of elements of xs that satisfy p. *)
fun count p xs = length (filter p xs)
(* inversions ord xs: number of pairs of positions i < j such that the element at
   j is strictly less than the element at i w.r.t. ord. *)
fun inversions ord xs =
  let
    (* below x y holds when y is strictly smaller than x *)
    fun below x y = is_less (ord (y, x))
    fun scan total remaining =
      (case remaining of
        [] => total
      | x :: later => scan (total + count (below x) later) later)
  in
    scan 0 xs
  end
(* swap_dist xs ys: replace every element of ys by its position in xs (as given
   by find_index) and count the inversions of the resulting index list. *)
fun swap_dist xs ys =
  let
    fun position y = find_index (fn x => x = y) xs
  in
    inversions int_ord (map position ys)
  end
(* ord_of_ranking zs (x, y): walk down the ranking until x or y is met.
   Meeting y first gives LESS, meeting x first gives GREATER, meeting an element
   equal to both gives EQUAL.  Raises Match if the list ends before either
   element is found. *)
fun ord_of_ranking zs (x, y) =
  let
    fun scan [] = raise Match
      | scan (z :: rest) =
          (case (z = x, z = y) of
            (true, true) => EQUAL
          | (false, true) => LESS
          | (true, false) => GREATER
          | (false, false) => scan rest)
  in
    scan zs
  end
(* majority_ord ords (x, y): count the orders that rate (x, y) as LESS or EQUAL
   and compare that number with the count of the remaining orders.  The result is
   LESS when the first group is larger, GREATER when it is smaller, EQUAL on a tie. *)
fun majority_ord zss (x, y) =
  let
    val in_favour = length (filter (fn ord => is_less_equal (ord (x, y))) zss)
    val against = length zss - in_favour
  in
    int_ord (against, in_favour)
  end
(* Factorial of n; every n <= 1 (including negative numbers) yields 1. *)
fun fact n =
  let
    fun loop acc k = if k <= 1 then acc else loop (acc * k) (k - 1)
  in
    loop 1 n
  end
(* ranking_to_nat xs: mixed-radix encoding of a ranking.  The head is used as the
   next digit (radix = number of remaining elements); afterwards every remaining
   entry that is not smaller than the head is decremented so that the tail is
   again expressed relative to the elements still available. *)
fun ranking_to_nat xs =
  let
    fun renumber pivot y = if y < pivot then y else y - 1
    fun encode code radix remaining =
      (case remaining of
        [] => code
      | d :: later => encode (code * radix + d) (radix - 1) (map (renumber d) later))
  in
    encode 0 (length xs) xs
  end
(* Raised by nat_to_ranking when the requested index is out of range; carries
   the offending index. *)
exception RANKING_INDEX of int

(* nat_to_ranking n k: decode index k into a list of length n.

   The index is read as a mixed-radix number: dividing by fact (len - 1) yields
   the head, the remainder encodes the tail.  Tail entries that are not smaller
   than the head are incremented so that they skip the value already used.

   Raises RANKING_INDEX k unless 0 <= k < fact n.
   Any n <= 0 yields the empty ranking (fact treats such n like 0, so only k = 0
   passes the range check); without this guard a negative n would recurse
   forever because the length counter never reaches 0. *)
fun nat_to_ranking n k =
  let
    fun decode len idx =
      if len <= 0 then []
      else
        let
          val (head, rest_idx) = Integer.div_mod idx (fact (len - 1))
          fun skip_head y = if y < head then y else y + 1
        in
          head :: map skip_head (decode (len - 1) rest_idx)
        end
  in
    if k < 0 orelse k >= fact n then raise RANKING_INDEX k else decode n k
  end
(* Simpset used by mk_inversion_cache: the simpset of the surrounding context
   with snd_sort_and_count_inversions removed, and with its symmetric form, the
   sort_and_count_inversions equations and split_list_def added. *)
val inversion_cache_simpset = simpset_of (
@{context}
delsimps @{thms snd_sort_and_count_inversions}
addsimps @{thms snd_sort_and_count_inversions[symmetric]
sort_and_count_inversions.simps split_list_def})
(* mk_inversion_cache ctxt n: for every index i in 0 .. fact n - 1 decode the
   list nat_to_ranking n i, count its inversions in ML and prove the matching
   statement inversion_number xs = res with the dedicated simpset.  The entries
   are computed with Par_List.map and returned as an array indexed by i. *)
fun mk_inversion_cache ctxt n =
  let
    val nat_literal = HOLogic.mk_number HOLogic.natT
    fun prove_entry idx =
      let
        val ranking = nat_to_ranking n idx
        val inv_count = inversions int_ord ranking
        (* HOL terms for the list and for the expected count *)
        val t_xs = HOLogic.mk_list HOLogic.natT (map nat_literal ranking)
        val t_res = nat_literal inv_count
        val goal =
          \<^instantiate>‹xs = t_xs and res = t_res in prop "inversion_number (xs :: nat list) = res"›
        val entry_thm =
          Goal.prove ctxt [] [] goal (fn {context = goal_ctxt, ...} =>
            ALLGOALS (Simplifier.simp_tac (put_simpset inversion_cache_simpset goal_ctxt)))
      in
        (inv_count, entry_thm)
      end
    val indices = 0 upto (fact n - 1)
  in
    Array.fromList (Par_List.map prove_entry indices)
  end
(* Conversion that reorders a term built from nested add_mset applications so
   that the elements appear in ascending order w.r.t. Term_Ord.term_ord.
   It works like insertion sort: the tail is normalized first, then the head is
   moved inwards with add_mset_commute until it is not greater than its successor.
   Terms that are neither the constant zero_class.zero (presumably the empty
   multiset - confirm) nor an add_mset application make the conversion fail via
   Conv.no_conv.  The ctxt argument is only passed along in the recursion. *)
fun normalize_mset_conv (ctxt : Proof.context) ct =
let
(* Insert the head of add_mset x (add_mset y M) into the tail: leave the term
   alone when x <= y, otherwise swap x and y and continue inside the argument. *)
fun add_mset_conv ct =
ct |> (
case Thm.term_of ct of
Const (\<^const_name>‹add_mset›, _) $ t_x $
(Const (\<^const_name>‹add_mset›, _) $ t_y $ _) =>
if is_less_equal (Term_Ord.term_ord (t_x, t_y)) then
Conv.all_conv
else
Conv.rewr_conv @{thm add_mset_commute[THEN eq_reflection]}
then_conv Conv.arg_conv add_mset_conv
(* fewer than two leading elements: nothing to reorder *)
| _ => Conv.all_conv)
in
ct |> (
case Thm.term_of ct of
Const (\<^const_name>‹zero_class.zero›, _) => Conv.all_conv
| Const (\<^const_name>‹add_mset›, _) $ _ $ _ =>
(* normalize the tail first, then insert the head into it *)
Conv.arg_conv (normalize_mset_conv ctxt)
then_conv add_mset_conv
| _ => Conv.no_conv)
end
(* prove_mset_eq ctxt (ct1, ct2): normalize both terms with normalize_mset_conv
   and join the two resulting equations by transitivity (the second one
   reversed), relating ct1 to ct2. *)
fun prove_mset_eq ctxt (ct1, ct2) =
  let
    val normalize = normalize_mset_conv ctxt
    val lhs_eq = normalize ct1
    val rhs_eq = normalize ct2
  in
    Thm.transitive lhs_eq (Thm.symmetric rhs_eq)
  end
(* simp_prove ctxt ct: prove the proposition ct by applying the simplifier of
   ctxt to all goals of the initial goal state and finishing the result with
   Goal.finish.

   Raises CTERM when the tactic produces no result.  The exception carries ct so
   that the failing proposition shows up in the error report. *)
fun simp_prove ctxt ct =
  (case SINGLE (ALLGOALS (Simplifier.simp_tac ctxt)) (Goal.init ct) of
    SOME st => Goal.finish ctxt st
  | NONE => raise CTERM ("simp_prove: Failed to finish proof", [ct]))
(* simp_discharge ctxt thm: prove the first premise of thm with simp_prove and
   eliminate it by implication elimination. *)
fun simp_discharge ctxt thm =
  Thm.implies_elim thm (simp_prove ctxt (Thm.cprem_of thm 1))
(* simp_discharge_all ctxt thm: apply simp_discharge until no premise is left. *)
fun simp_discharge_all ctxt thm =
  let
    fun loop th =
      (case Thm.nprems_of th of
        0 => th
      | _ => loop (simp_discharge ctxt th))
  in
    loop thm
  end
(* fold1 f xs: fold f over the tail of xs, starting from its head.
   Raises Empty for the empty list. *)
fun fold1 f xs =
  (case xs of
    [] => raise Empty
  | first :: rest => fold f rest first)
(* bind_list xs f: apply f to each element of xs and concatenate the results. *)
fun bind_list xs f = flat (map f xs)
(* par_bind_list xs f: like bind_list, but the mapping step is done by
   Par_List.map. *)
fun par_bind_list xs f =
  xs
  |> Par_List.map f
  |> flat
(* subtract1 eq xs ys: for every x in xs (front to back) apply remove1 eq x to
   the current remainder of ys. *)
fun subtract1 eq xs ys =
  (case xs of
    [] => ys
  | x :: rest => subtract1 eq rest (remove1 eq x ys))
(* list_iterator ls: iterator that hands out the elements of ls one per call and
   returns NONE once the list is used up.  The position is kept in an
   unsynchronized reference that is private to the returned closure. *)
fun list_iterator ls =
  let
    val remaining = Unsynchronized.ref ls
  in
    fn () =>
      (case ! remaining of
        [] => NONE
      | next :: rest => (remaining := rest; SOME next))
  end
(* file_iterator path: read the lines of the file with File.read_lines and
   iterate over them. *)
val file_iterator = list_iterator o File.read_lines
(* Raised by the readers of this structure on malformed input.  It is declared
   ahead of all readers so that every one of them raises this exported
   exception; a reader-local declaration would create a different exception
   that clients cannot catch as PARSE. *)
exception PARSE

(* read_sp_clauses_gen it: parse clause records from a line iterator.

   Input format (empty lines and lines starting with # are skipped; numbers are
   separated by single spaces):
     - header line with exactly two integers m and n;
     - every further line: n indices, n more indices, then a sequence of index
       pairs.  Each index is decoded with nat_to_ranking m.

   Raises PARSE if a token is not an integer, if the header is missing or does
   not consist of exactly two integers, or if the pair section has odd length.
   Raises RANKING_INDEX for an index rejected by nat_to_ranking.
   NOTE(review): a line with fewer than 2 * n numbers is accepted and simply
   yields shorter lists (chop does not complain) - confirm this is intended. *)
fun read_sp_clauses_gen it =
  let
    fun ints_of_line l =
      map (the o Int.fromString) (String.tokens (fn c => c = #" ") l)
        handle Option.Option => raise PARSE
    (* next non-empty, non-comment line as a list of integers *)
    fun next_ints () =
      (case it () of
        NONE => NONE
      | SOME l =>
          if String.size l = 0 orelse String.sub (l, 0) = #"#" then next_ints ()
          else SOME (ints_of_line l))
    (* matched with case: a handler attached to the right-hand side of a val
       binding does not catch the Bind raised by a failing pattern *)
    val (m, n) =
      (case next_ints () of
        SOME [m, n] => (m, n)
      | _ => raise PARSE)
    val f = nat_to_ranking m
    fun go acc =
      (case next_ints () of
        NONE => rev acc
      | SOME xs =>
          let
            val (p1, xs) = chop n xs
            val (p2, xs) = chop n xs
            (* an incomplete trailing group means an odd number of indices *)
            val ys = map (fn [x, y] => (f x, f y) | _ => raise PARSE) (chop_groups 2 xs)
          in
            go ((map f p1, map f p2, ys) :: acc)
          end)
  in
    go []
  end

(* read_sp_clauses_file (path, compressed): read clause records from a file.
   With compressed = true the file content is XZ-uncompressed before it is split
   into lines; otherwise the file is read as plain text. *)
fun read_sp_clauses_file (path, false) =
      read_sp_clauses_gen (file_iterator path)
  | read_sp_clauses_file (path, true) =
      path
      |> Bytes.read
      |> XZ.uncompress
      |> Bytes.trim_split_lines
      |> list_iterator
      |> read_sp_clauses_gen
(* read_profiles_gen it: parse profiles from a line iterator.

   Empty lines and lines starting with # are skipped; numbers are separated by
   single spaces.  The first data line must consist of two integers m and n.
   Every further data line must hold exactly n indices, each of which is decoded
   with nat_to_ranking m.

   Raises PARSE for non-numeric tokens, for a missing or ill-formed header and
   for lines of the wrong length; RANKING_INDEX comes from nat_to_ranking. *)
fun read_profiles_gen it =
  let
    fun ints_of_line l =
      map (the o Int.fromString) (String.tokens (fn c => c = #" ") l)
        handle Option.Option => raise PARSE
    fun next_ints () =
      (case it () of
        NONE => NONE
      | SOME l =>
          if l = "" orelse String.isPrefix "#" l then next_ints ()
          else SOME (ints_of_line l))
    (* a header of any other shape raises Bind here, which is turned into PARSE
       by the handler around the whole body *)
    val SOME [m, n] = next_ints ()
    val decode = nat_to_ranking m
    fun collect acc =
      (case next_ints () of
        NONE => rev acc
      | SOME idxs =>
          if length idxs = n then collect (map decode idxs :: acc)
          else raise PARSE)
  in
    collect []
  end
  handle Bind => raise PARSE
(* read_profiles_file (path, compressed): read profiles from a file.  With
   compressed = true the content is XZ-uncompressed and then split into lines;
   otherwise the lines are read directly from the file. *)
fun read_profiles_file (path, compressed) =
  let
    val lines =
      if compressed then
        list_iterator (Bytes.trim_split_lines (XZ.uncompress (Bytes.read path)))
      else file_iterator path
  in
    read_profiles_gen lines
  end
end