(* Theory Containers.RBT_Set2 *)

(*  Title:      Containers/RBT_Set2.thy
    Author:     Andreas Lochbihler, KIT *)

theory RBT_Set2 
imports
  RBT_Mapping2
begin

section ‹Sets implemented by red-black trees›

(* Pairing every element of xs with the constant v and building an association
   map yields the map that sends each member of xs to Some v and every other
   element to None. *)
lemma map_of_map_Pair_const:
  "map_of (map (λx. (x, v)) xs) = (λx. if x ∈ set xs then Some v else None)"
by(induct xs) auto

(* For association lists whose values are all of type unit, reversing the list
   does not change the resulting map. Declared [simp] so that rev disappears
   automatically under map_of for such lists. *)
lemma map_of_rev_unit [simp]: 
  fixes xs :: "('a * unit) list" 
  shows "map_of (rev xs) = map_of xs"
by(induct xs rule: rev_induct)(auto simp add: map_add_def split: option.split)

(* A fold over pairs whose step function only looks at the first component is
   the same as folding over the list of first components. *)
lemma fold_split_conv_map_fst: "fold (λ(x, y). f x) xs = fold f (map fst xs)"
by(simp add: fold_map o_def split_def)

(* foldr analogue of fold_split_conv_map_fst: a right fold whose step function
   ignores the second component of each pair folds over the first components. *)
lemma foldr_split_conv_map_fst: "foldr (λ(x, y). f x) xs = foldr f (map fst xs)"
by(simp add: foldr_map o_def split_def fun_eq_iff)

(* A list built by foldr that conses x onto the accumulator xs only when P x xs
   holds contains no elements other than those of the input list as. *)
lemma set_foldr_Cons:
  "set (foldr (λx xs. if P x xs then x # xs else xs) as []) ⊆ set as"
by(induct as) auto

(* The conditional-cons foldr (see set_foldr_Cons) preserves distinctness of
   the images under f. In the Cons case, set_foldr_Cons guarantees that the
   recursively built tail only contains elements of the original tail. *)
lemma distinct_fst_foldr_Cons:
  "distinct (map f as) ⟹ distinct (map f (foldr (λx xs. if P x xs then x # xs else xs) as []))"
proof(induct as)
  case (Cons a as)
  with set_foldr_Cons[of P as] show ?case by auto
qed simp

(* filter expressed as a right fold that keeps x exactly when P x holds. *)
lemma filter_conv_foldr:
  "filter P xs = foldr (λx xs. if P x then x # xs else xs) xs []"
by(induct xs) simp_all

(* Filtering an association list by a predicate on its keys corresponds to
   restricting the resulting map to the keys satisfying P. *)
lemma map_of_filter: "map_of (filter (λx. P (fst x)) xs) = map_of xs |` Collect P"
by(induct xs)(simp_all add: fun_eq_iff restrict_map_def)

(* Pairing every key k of xs with f k yields a map that returns Some (f x) for
   members x of xs and None otherwise. *)
lemma map_of_map_Pair_key: "map_of (map (λk. (k, f k)) xs) x = (if x ∈ set xs then Some (f x) else None)"
by(induct xs) simp_all

(* A red-black tree is non-empty exactly when it is a Branch node. *)
lemma neq_Empty_conv: "t ≠ rbt.Empty ⟷ (∃c l k v r. t = Branch c l k v r)"
by(cases t) simp_all

(* Lemmas about folding insertion and deletion over red-black trees. They are
   stated inside the linorder context, so is_rbt, rbt_insert, rbt_delete and
   rbt_lookup refer to the ambient linear order. *)
context linorder begin

(* Inserting a list of key/value pairs with fold preserves the red-black
   invariant. *)
lemma is_rbt_RBT_fold_rbt_insert [simp]:
  "is_rbt t ⟹ is_rbt (fold (λ(k, v). rbt_insert k v) xs t)"
by(induct xs arbitrary: t)(simp_all add: split_beta)

(* After inserting the pairs of xs from left to right, lookup is the original
   lookup overridden by map_of (rev xs). The rev makes the entry inserted last
   win for duplicate keys. *)
lemma rbt_lookup_RBT_fold_rbt_insert [simp]: 
  "is_rbt t ⟹ rbt_lookup (fold (λ(k, v). rbt_insert k v) xs t) = rbt_lookup t ++ map_of (rev xs)"
apply(induct xs arbitrary: t rule: rev_induct)
apply(simp_all add: split_beta fun_eq_iff rbt_lookup_rbt_insert)
done

(* Deleting a list of keys with fold preserves the red-black invariant. *)
lemma is_rbt_fold_rbt_delete [simp]:
  "is_rbt t ⟹ is_rbt (fold rbt_delete xs t)"
by(induct xs arbitrary: t)(simp_all)

(* After deleting every key of xs, lookup is the original lookup restricted to
   the keys outside set xs. *)
lemma rbt_lookup_fold_rbt_delete [simp]: 
  "is_rbt t ⟹ rbt_lookup (fold rbt_delete xs t) = rbt_lookup t |` (- set xs)"
apply(induct xs rule: rev_induct)
apply(simp_all add: rbt_lookup_rbt_delete ext)
apply(metis Un_insert_right compl_sup sup_bot_right)
done

(* Inserting every key k of xs with the value f k preserves the red-black
   invariant. *)
lemma is_rbt_fold_rbt_insert: "is_rbt t ⟹ is_rbt (fold (λk. rbt_insert k (f k)) xs t)"
by(induct xs rule: rev_induct) simp_all

(* After inserting every key k of xs with value f k, the keys of xs are mapped
   to f k and all other keys keep their original lookup. *)
lemma rbt_lookup_fold_rbt_insert: 
  "is_rbt t ⟹ 
  rbt_lookup (fold (λk. rbt_insert k (f k)) xs t) = 
  rbt_lookup t ++ map_of (map (λk. (k, f k)) xs)"
by(induct xs arbitrary: t)(auto simp add: rbt_lookup_rbt_insert map_add_def fun_eq_iff map_of_map_Pair_key split: option.splits)

end

(* fold_rev f t folds f over the key/value entries of t with List.foldr, i.e.
   starting from the last entry of RBT_Impl.entries t. *)
definition fold_rev :: "('a ⇒ 'b ⇒ 'c ⇒ 'c) ⇒ ('a, 'b) rbt ⇒ 'c ⇒ 'c"
where "fold_rev f t = List.foldr (λ(k, v). f k v) (RBT_Impl.entries t)"

(* Structural recursion equations for fold_rev, also registered as code
   equations: the empty tree gives the identity, and a branch composes the fold
   over the left subtree, the step for the node's own entry, and the fold over
   the right subtree. *)
lemma fold_rev_simps [simp, code]:
  "fold_rev f RBT_Impl.Empty = id"
  "fold_rev f (Branch c l k v r) = fold_rev f l o f k v o fold_rev f r"
by(simp_all add: fold_rev_def fun_eq_iff)

(* Sortedness counterpart of distinct_fst_foldr_Cons. It needs the linorder
   context because sorted refers to the ambient order. *)
context linorder begin

(* The conditional-cons foldr preserves sortedness of the images under f. The
   Cons case uses set_foldr_Cons to bound the elements of the rebuilt tail. *)
lemma sorted_fst_foldr_Cons:
  "sorted (map f as) ⟹ sorted (map f (foldr (λx xs. if P x xs then x # xs else xs) as []))"
proof(induct as)
  case (Cons a as)
  with set_foldr_Cons[of P as] show ?case by(auto)
qed simp

end

subsection ‹Type and operations›

(* A set over 'a is represented as a mapping_rbt from 'a to unit. *)
type_synonym 'a set_rbt = "('a, unit) mapping_rbt"

(* Print-only translation (<=): display ('a, unit) mapping_rbt as 'a set_rbt. *)
translations 
  (type) "'a set_rbt" <= (type) "('a, unit) mapping_rbt"

(* Input-only abbreviation: wraps a unit-valued red-black tree as a set_rbt via
   Mapping_RBT. *)
abbreviation (input) Set_RBT :: "('a :: ccompare, unit) RBT_Impl.rbt ⇒ 'a set_rbt"
where "Set_RBT ≡ Mapping_RBT"

subsection ‹Primitive operations›

(* x is a member iff looking x up in the underlying tree (ordered by ccomp)
   yields some value. *)
lift_definition member :: "'a :: ccompare set_rbt ⇒ 'a ⇒ bool" is
  "λt x. x ∈ dom (rbt_comp_lookup ccomp t)" .

(* The empty set is the empty mapping. *)
abbreviation empty :: "'a :: ccompare set_rbt"
where "empty ≡ RBT_Mapping2.empty"

(* Inserting k means binding k to () in the underlying mapping. *)
abbreviation insert :: "'a :: ccompare ⇒ 'a set_rbt ⇒ 'a set_rbt" 
where "insert k ≡ RBT_Mapping2.insert k ()"

(* Removing an element deletes its key from the underlying mapping. *)
abbreviation remove :: "'a :: ccompare ⇒ 'a set_rbt ⇒ 'a set_rbt"
where "remove ≡ RBT_Mapping2.delete"

(* Builds a set from a list: every element is paired with () and the pairs are
   bulk-loaded with ccomp. The proof discharges the mapping_rbt invariant
   using rbt_bulkload_is_rbt. *)
lift_definition bulkload :: "'a :: ccompare list ⇒ 'a set_rbt" is
  "rbt_comp_bulkload ccomp ∘ map (λx. (x, ()))"
by(auto 4 3 intro: linorder.rbt_bulkload_is_rbt ID_ccompare simp: rbt_comp_bulkload[OF ID_ccompare'])

(* Emptiness test, inherited from the underlying mapping. *)
abbreviation is_empty :: "'a :: ccompare set_rbt ⇒ bool"
where "is_empty ≡ RBT_Mapping2.is_empty"

(* Union is the join of the underlying mappings. All values have type unit, so
   the choice of combining function (λ_ _. id) cannot affect the result. *)
abbreviation union :: "'a :: ccompare set_rbt ⇒ 'a set_rbt ⇒ 'a set_rbt"
where "union ≡ RBT_Mapping2.join (λ_ _. id)"

(* Intersection is the meet of the underlying mappings. As for union, the
   combining function is irrelevant because all values are (). *)
abbreviation inter :: "'a :: ccompare set_rbt ⇒ 'a set_rbt ⇒ 'a set_rbt"
where "inter ≡ RBT_Mapping2.meet (λ_ _. id)"

(* Intersection of a set with a list: keeps the elements of xs whose lookup in
   t is defined and inserts them, bound to (), into an empty tree. *)
lift_definition inter_list :: "'a :: ccompare set_rbt ⇒ 'a list ⇒ 'a set_rbt" is
  "λt xs. fold (λk. rbt_comp_insert ccomp k ()) [x ← xs. rbt_comp_lookup ccomp t x ≠ None] RBT_Impl.Empty"
by(auto 4 3 intro: ID_ccompare linorder.is_rbt_fold_rbt_insert ord.Empty_is_rbt simp: rbt_comp_simps[OF ID_ccompare'])

(* Set difference, lifted from rbt_comp_minus on the underlying trees. The
   proof discharges the invariant via rbt_minus_is_rbt. *)
lift_definition minus :: "'a :: ccompare set_rbt ⇒ 'a set_rbt ⇒ 'a set_rbt" is 
  "rbt_comp_minus ccomp"
by(auto 4 3 intro: linorder.rbt_minus_is_rbt ID_ccompare simp: rbt_comp_minus[OF ID_ccompare'])

(* Filters the underlying mapping with a predicate on the first component of
   each (key, value) pair, i.e. on the element itself. *)
abbreviation filter :: "('a :: ccompare ⇒ bool) ⇒ 'a set_rbt ⇒ 'a set_rbt"
where "filter P ≡ RBT_Mapping2.filter (P ∘ fst)"

(* Folds f over the elements, ignoring the unit values stored in the tree. *)
lift_definition fold :: "('a :: ccompare ⇒ 'b ⇒ 'b) ⇒ 'a set_rbt ⇒ 'b ⇒ 'b" is "λf. RBT_Impl.fold (λa _. f a)" .

(* Fold without a start value, lifted from RBT_Impl_fold1; fold1_conv_fold
   characterises it in terms of keys. *)
lift_definition fold1 :: "('a :: ccompare ⇒ 'a ⇒ 'a) ⇒ 'a set_rbt ⇒ 'a" is "RBT_Impl_fold1" .

(* The list of elements, i.e. the keys of the underlying tree. *)
lift_definition keys :: "'a :: ccompare set_rbt ⇒ 'a list" is "RBT_Impl.keys" .

(* Universal quantification over the elements, ignoring the unit values. *)
abbreviation all :: "('a :: ccompare ⇒ bool) ⇒ 'a set_rbt ⇒ bool"
where "all P ≡ RBT_Mapping2.all (λk _. P k)"

(* Existential quantification over the elements, ignoring the unit values. *)
abbreviation ex :: "('a :: ccompare ⇒ bool) ⇒ 'a set_rbt ⇒ bool"
where "ex P ≡ RBT_Mapping2.ex (λk _. P k)"

(* Cartesian product of two sets, built with RBT_Mapping2.product. Every
   resulting pair is bound to (). *)
definition product :: "'a :: ccompare set_rbt ⇒ 'b :: ccompare set_rbt ⇒ ('a × 'b) set_rbt"
where "product rbt1 rbt2 = RBT_Mapping2.product (λ_ _ _ _. ()) rbt1 rbt2"

(* Identity relation on a set, via RBT_Mapping2.diag (see member_Id_on). *)
abbreviation Id_on :: "'a :: ccompare set_rbt ⇒ ('a × 'a) set_rbt"
where "Id_on ≡ RBT_Mapping2.diag"

(* Initial state of the tree generator. Unfolding rbt_keys_generator from it
   yields keys t (see unfoldr_rbt_keys_generator). *)
abbreviation init :: "'a :: ccompare set_rbt ⇒ ('a, unit, 'a) rbt_generator_state"
where "init ≡ RBT_Mapping2.init"

subsection ‹Properties›

(* The empty set has no members. Unlike the lemmas in the comparator context
   below, this needs no assumption on CCOMPARE. *)
lemma member_empty [simp]:
  "member empty = (λ_. False)"
by(simp add: member_def empty_def Mapping_RBT_inverse ord.Empty_is_rbt ord.rbt_lookup.simps fun_eq_iff)

(* Folding over a set is List.fold over its key list. *)
lemma fold_conv_fold_keys: "RBT_Set2.fold f rbt b = List.fold f (RBT_Set2.keys rbt) b"
by(simp add: RBT_Set2.fold_def RBT_Set2.keys_def RBT_Impl.fold_def RBT_Impl.keys_def fold_map o_def split_def)

(* Variant of fold_conv_fold_keys stated directly on the keys of the underlying
   implementation tree (RBT_Mapping2.impl_of t), without a start value. *)
lemma fold_conv_fold_keys':
  "fold f t = List.fold f (RBT_Impl.keys (RBT_Mapping2.impl_of t))"
by(simp add: fold.rep_eq RBT_Impl.fold_def RBT_Impl.keys_def fold_map o_def split_def)

(* Code equation for member: x is a member iff the mapping binds x to (). *)
lemma member_lookup [code]: "member t x ⟷ RBT_Mapping2.lookup t x = Some ()"
by transfer auto

(* Running rbt_keys_generator from init t produces exactly keys t. *)
lemma unfoldr_rbt_keys_generator:
  "list.unfoldr rbt_keys_generator (init t) = keys t"
by transfer(simp add: unfoldr_rbt_keys_generator)

(* The key list is empty exactly for the empty set. *)
lemma keys_eq_Nil_iff [simp]: "keys rbt = [] ⟷ rbt = empty"
by transfer(case_tac rbt, simp_all)

(* fold1 starts from the first key and folds f over the remaining keys. Note
   that hd and tl are unspecified on the empty list, so the statement carries
   no information for the empty set. *)
lemma fold1_conv_fold: "fold1 f rbt = List.fold f (tl (keys rbt)) (hd (keys rbt))"
by transfer(simp add: RBT_Impl_fold1_def)

(* Characterisation of the set operations in terms of member, valid for every
   element type 'a that has a comparator (ID CCOMPARE('a) ≠ None). *)
context assumes ID_ccompare_neq_None: "ID CCOMPARE('a :: ccompare) ≠ None"
begin

(* With a comparator available, cless_eq and cless form a linear order on 'a. *)
lemma set_linorder: "class.linorder (cless_eq :: 'a ⇒ 'a ⇒ bool) cless"
using ID_ccompare_neq_None by(clarsimp)(rule ID_ccompare)

(* ccomp satisfies the comparator laws on 'a. *)
lemma ccomp_comparator: "comparator (ccomp :: 'a comparator)"
  using ID_ccompare_neq_None by(clarsimp)(rule ID_ccompare')

(* rbt_comp_simps and rbt_comp_minus instantiated at ccomp. Used below to
   unfold the rbt_comp_* operations that occur in the lifted definitions. *)
lemmas rbt_comps = rbt_comp_simps[OF ccomp_comparator] rbt_comp_minus[OF ccomp_comparator] 

(* The tree underlying every 'a set_rbt satisfies the red-black invariant
   for cless. *)
lemma is_rbt_impl_of [simp, intro]:
  fixes t :: "'a set_rbt"
  shows "ord.is_rbt cless (RBT_Mapping2.impl_of t)"
  using ID_ccompare_neq_None impl_of [of t] by auto

(* For a well-formed tree t, membership in Set_RBT t is lookup of () in t. *)
lemma member_RBT:
  "ord.is_rbt cless t ⟹ member (Set_RBT t) (x :: 'a) ⟷ ord.rbt_lookup cless t x = Some ()"
by(auto simp add: member_def Mapping_RBT_inverse rbt_comps)

(* member t x holds exactly when x is bound to () in the underlying tree. *)
lemma member_impl_of:
  "ord.rbt_lookup cless (RBT_Mapping2.impl_of t) (x :: 'a) = Some () ⟷ member t x"
by transfer (auto simp: rbt_comps)

(* insert x makes x a member and leaves all other elements unchanged. *)
lemma member_insert [simp]:
  "member (insert x (t :: 'a set_rbt)) = (member t)(x := True)"
by transfer (simp add: fun_eq_iff linorder.rbt_lookup_rbt_insert[OF set_linorder] ID_ccompare_neq_None)

(* Folding insert over xs adds exactly the elements of xs. *)
lemma member_fold_insert [simp]:
  "member (List.fold insert xs (t :: 'a set_rbt)) = (λx. member t x ∨ x ∈ set xs)"
by(induct xs arbitrary: t) auto

(* remove x makes x a non-member and leaves all other elements unchanged. *)
lemma member_remove [simp]:
  "member (remove (x :: 'a) t) = (member t)(x := False)"
by transfer (simp add: linorder.rbt_lookup_rbt_delete[OF set_linorder] ID_ccompare_neq_None fun_eq_iff)

(* bulkload xs contains exactly the elements of xs. *)
lemma member_bulkload [simp]:
  "member (bulkload xs) (x :: 'a) ⟷ x ∈ set xs"
by transfer (auto simp add: linorder.rbt_lookup_rbt_bulkload[OF set_linorder] rbt_comps map_of_map_Pair_const split: if_split_asm)

(* Membership agrees with membership in the key list. *)
lemma member_conv_keys: "member t = (λx :: 'a. x ∈ set (keys t))"
by(transfer)(simp add: ID_ccompare_neq_None linorder.rbt_lookup_keys[OF set_linorder] ord.is_rbt_rbt_sorted)

(* is_empty t holds exactly for the empty set. *)
lemma is_empty_empty [simp]:
  "is_empty t ⟷ t = empty"
by transfer (simp split: rbt.split)

(* A unit-valued tree has the everywhere-undefined lookup iff it is the empty
   tree. The proof interprets linorder via set_linorder. *)
lemma RBT_lookup_empty [simp]:
  "ord.rbt_lookup cless (t :: ('a, unit) rbt) = Map.empty ⟷ t = RBT_Impl.Empty"
proof -
  interpret linorder "cless_eq :: 'a ⇒ 'a ⇒ bool" cless by(rule set_linorder)
  show ?thesis by(cases t)(auto simp add: fun_eq_iff)
qed

(* A set without members is the empty set. *)
lemma member_empty_empty [simp]:
  "member t = (λ_. False) ⟷ (t :: 'a set_rbt) = empty"
by transfer(simp add: ID_ccompare_neq_None fun_eq_iff RBT_lookup_empty[symmetric])

(* Members of a union are the members of either argument. *)
lemma member_union [simp]:
  "member (union (t1 :: 'a set_rbt) t2) = (λx. member t1 x ∨ member t2 x)"
by(auto simp add: member_lookup fun_eq_iff lookup_join[OF ID_ccompare_neq_None] split: option.split)

(* Members of minus t1 t2 are the members of t1 that are not members of t2. *)
lemma member_minus [simp]:
  "member (minus (t1 :: 'a set_rbt) t2) = (λx. member t1 x ∧ ¬ member t2 x)"
by(transfer)(auto simp add: ID_ccompare_neq_None fun_eq_iff rbt_comps linorder.rbt_lookup_rbt_minus[OF set_linorder] ord.is_rbt_rbt_sorted)

(* Members of an intersection are the members of both arguments. *)
lemma member_inter [simp]:
  "member (inter (t1 :: 'a set_rbt) t2) = (λx. member t1 x ∧ member t2 x)"
by(auto simp add: member_lookup fun_eq_iff lookup_meet[OF ID_ccompare_neq_None] split: option.split)

(* inter_list t xs contains the members of t that occur in xs. *)
lemma member_inter_list [simp]:
  "member (inter_list (t :: 'a set_rbt) xs) = (λx. member t x ∧ x ∈ set xs)"
by transfer(auto simp add: ID_ccompare_neq_None fun_eq_iff linorder.rbt_lookup_fold_rbt_insert[OF set_linorder] ord.Empty_is_rbt map_of_map_Pair_key ord.rbt_lookup.simps rel_option_iff split: if_split_asm option.split_asm)

(* filter P t contains the members of t that satisfy P. *)
lemma member_filter [simp]:
  "member (filter P (t :: 'a set_rbt)) = (λx. member t x ∧ P x)"
by(simp add: member_lookup fun_eq_iff lookup_filter[OF ID_ccompare_neq_None] split: option.split)

(* The key list of a set contains no duplicates. *)
lemma distinct_keys [simp]:
  "distinct (keys (rbt :: 'a set_rbt))"
by transfer(simp add: ID_ccompare_neq_None RBT_Impl.keys_def ord.is_rbt_rbt_sorted linorder.distinct_entries[OF set_linorder])

(* all P t holds iff every member of t satisfies P. *)
lemma all_conv_all_member:
  "all P t ⟷ (∀x :: 'a. member t x ⟶ P x)"
by(simp add: member_lookup all_conv_all_lookup[OF ID_ccompare_neq_None])

(* ex P t holds iff some member of t satisfies P. *)
lemma ex_conv_ex_member:
  "ex P t ⟷ (∃x :: 'a. member t x ∧ P x)"
by(simp add: member_lookup ex_conv_ex_lookup[OF ID_ccompare_neq_None])

(* Every set_rbt has finitely many members. *)
lemma finite_member: "finite (Collect (RBT_Set2.member (t :: 'a set_rbt)))"
by transfer (simp add: rbt_comps linorder.finite_dom_rbt_lookup[OF set_linorder])

(* Id_on t contains exactly the pairs (k, k) with k a member of t. *)
lemma member_Id_on: "member (Id_on t) = (λ(k :: 'a, k'). k = k' ∧ member t k)"
by(simp add: member_lookup[abs_def] diag_lookup[OF ID_ccompare_neq_None] fun_eq_iff)

(* Facts that additionally require a comparator on a second type 'b. *)
context assumes ID_ccompare_neq_None': "ID CCOMPARE('b :: ccompare) ≠ None"
begin

(* cless_eq and cless form a linear order on 'b. *)
lemma set_linorder': "class.linorder (cless_eq :: 'b ⇒ 'b ⇒ bool) cless"
using ID_ccompare_neq_None' by(clarsimp)(rule ID_ccompare)

(* product rbt1 rbt2 contains exactly the pairs of members of rbt1 and rbt2. *)
lemma member_product:
  "member (product rbt1 rbt2) = (λab :: 'a × 'b. ab ∈ Collect (member rbt1) × Collect (member rbt2))"
by(auto simp add: fun_eq_iff member_lookup product_def RBT_Mapping2.lookup_product ID_ccompare_neq_None ID_ccompare_neq_None' split: option.splits)

end

end

(* If 'a has the comparator c, the key list of any set is sorted with respect
   to the order le_of_comp c. *)
lemma sorted_RBT_Set_keys: 
  "ID CCOMPARE('a :: ccompare) = Some c 
  ⟹ linorder.sorted (le_of_comp c) (RBT_Set2.keys rbt)"
by transfer(auto simp add: RBT_Set2.keys.rep_eq RBT_Impl.keys_def linorder.rbt_sorted_entries[OF ID_ccompare] ord.is_rbt_rbt_sorted)

(* Same as set_linorder, but for element types that are additionally in the
   lattice class. *)
context assumes ID_ccompare_neq_None: "ID CCOMPARE('a :: {ccompare, lattice}) ≠ None"
begin

(* cless_eq and cless form a linear order on 'a. *)
lemma set_linorder2: "class.linorder (cless_eq :: 'a ⇒ 'a ⇒ bool) cless"
using ID_ccompare_neq_None by(clarsimp)(rule ID_ccompare)

end

(* For an arbitrary tree t, not necessarily well-formed, the set obtained via
   Mapping_RBT has the same keys (as a set) as t. The proof first splits on
   whether t is empty. For a branch, it then splits on whether a comparator
   exists while t is not a red-black tree. The True case relies on facts about
   folding rbt_insert, which suggests Mapping_RBT rebuilds such trees by
   re-insertion. NOTE(review): confirm against RBT_Mapping2. *)
lemma set_keys_Mapping_RBT: "set (keys (Mapping_RBT t)) = set (RBT_Impl.keys t)"
proof(cases t)
  case Empty thus ?thesis
    by(clarsimp simp add: Mapping_RBT_def keys.rep_eq is_ccompare_def Mapping_RBT'_inverse ord.is_rbt_def ord.rbt_sorted.simps)
next
  case (Branch c l k v r)
  show ?thesis
  proof(cases "is_ccompare TYPE('a) ∧ ¬ ord.is_rbt cless (Branch c l k v r)")
    case False thus ?thesis using Branch
      by(auto simp add: Mapping_RBT_def keys.rep_eq is_ccompare_def Mapping_RBT'_inverse simp del: not_None_eq)
  next
    case True
    thus ?thesis using Branch
      by(clarsimp simp add: Mapping_RBT_def keys.rep_eq is_ccompare_def Mapping_RBT'_inverse
           RBT_ext.linorder.is_rbt_fold_rbt_insert_impl[OF ID_ccompare]
           linorder.rbt_insert_is_rbt[OF ID_ccompare] ord.Empty_is_rbt)
        (subst linorder.rbt_lookup_keys[OF ID_ccompare, symmetric], assumption,
         auto simp add: linorder.rbt_sorted_fold_insert[OF ID_ccompare]
           RBT_ext.linorder.rbt_lookup_fold_rbt_insert_impl[OF ID_ccompare]
           RBT_ext.linorder.rbt_lookup_rbt_insert'[OF ID_ccompare]
           linorder.rbt_insert_rbt_sorted[OF ID_ccompare] ord.is_rbt_rbt_sorted
           ord.Empty_is_rbt dom_map_of_conv_image_fst RBT_Impl.keys_def
           ord.rbt_lookup.simps)
  qed
qed

(* Make these operations accessible only through their qualified names
   (e.g. RBT_Set2.member, RBT_Set2.fold), as used elsewhere in this theory. *)
hide_const (open)
  member empty insert remove bulkload
  union minus keys fold fold_rev
  filter all ex product Id_on init

end