(* Theory Fixed_Length_Vector *)

theory Fixed_Length_Vector
  imports "HOL-Library.Numeral_Type" "HOL-Library.Code_Cardinality"
begin

(* Zipping two maps over the same list equals mapping the paired function over that list. *)
lemma zip_map_same: ‹zip (map f xs) (map g xs) = map (λx. (f x, g x)) xs›
  by (induction xs) auto

section ‹Type class for indexing›

text ‹
  The \<^class>‹index› class is used to define an isomorphism between some index types with fixed
  cardinality and a subset of the natural numbers. Crucially, infinite types can be made instances
  of this class too, since Isabelle defines the cardinality of an infinite type to be zero.

  The \<^class>‹index1› class then adds more properties, such as injectivity, which can only be
  satisfied for finite index types.

  This class is roughly similar to the \<^class>‹enum› class defined in the Isabelle library; the
  connection between the two is established at a later point. However, \<^class>‹enum› does not
  admit infinite types.
›

(* from_index maps into {0..<CARD('a)} surjectively and to_index is a right inverse on that range.
   For infinite types CARD('a) = 0, so both assumptions hold vacuously. *)
class index =
  fixes from_index :: ‹'a ⇒ nat›
    and to_index :: ‹nat ⇒ 'a›
  assumes to_from_index: ‹n < CARD('a) ⟹ from_index (to_index n) = n›
  assumes from_index_surj: ‹n < CARD('a) ⟹ ∃a. from_index a = n›
begin

text ‹A list of all possible indexes.›

definition indexes :: ‹'a list› where
‹indexes = map to_index [0..<CARD('a)]›

text ‹There are as many indexes as the cardinality of type \<^typ>‹'a›.›

lemma indexes_length[simp]: ‹length indexes = CARD('a)›
  unfolding indexes_def by auto

(* Two lists of length CARD('a) are equal if they agree at every position reached by from_index. *)
lemma list_cong_index:
  assumes ‹length xs = CARD('a)› ‹length ys = CARD('a)›
  assumes ‹⋀i. xs ! from_index i = ys ! from_index i›
  shows ‹xs = ys›
  apply (rule nth_equalityI)
  using assms from_index_surj by auto

(* Elimination-rule form of from_index_surj. *)
lemma from_indexE:
  assumes ‹n < CARD('a)›
  obtains a where ‹from_index a = n›
  using assms by (metis from_index_surj)

end

text ‹Restrict \<^class>‹index› to only admit finite types.›

(* Adds a range bound and injectivity for from_index; together these force a finite type. *)
class index1 = index +
  assumes from_index_bound[simp]: ‹from_index a < CARD('a)›
  assumes from_index_inj: ‹inj from_index›
begin

text ‹Finiteness follows from the class assumptions.›

lemma card_nonzero[simp]: ‹0 < CARD('a)›
  by (metis less_nat_zero_code from_index_bound neq0_conv)

lemma finite_type[simp]: ‹finite (UNIV :: 'a set)›
  by (metis card_nonzero card.infinite less_irrefl)

sublocale finite
  by standard simp

text ‹
  \<^const>‹to_index› and \<^const>‹from_index› form a bijection between \<^typ>‹'a› and
  \<^term>‹{0..<CARD('a)}›.
›

lemma from_to_index[simp]: ‹to_index (from_index i) = i›
  by (meson injD from_index_bound from_index_inj to_from_index)

lemma indexes_from_index[simp]: ‹indexes ! from_index i = i›
  unfolding indexes_def by auto

lemma to_index_inj_on: ‹inj_on to_index {0..<CARD('a)}›
  by (rule inj_onI) (force elim: from_indexE)


end

text ‹Finally, we instantiate the class for the pre-defined numeral types.›

instantiation num0 :: index
begin

(* num0 has cardinality zero, so the class assumptions are vacuous and both maps may be arbitrary. *)
definition from_index_num0 :: ‹num0 ⇒ nat› where ‹from_index_num0 _ = undefined›
definition to_index_num0 :: ‹nat ⇒ num0› where ‹to_index_num0 _ = undefined›

instance
  by standard auto

end


(* The index type 0 has no indexes. *)
lemma indexes_zero[simp]: ‹indexes = ([] :: 0 list)›
  by (auto simp: indexes_def)

instantiation num1 :: index1
begin

(* The single element of num1 has index 0. *)
definition from_index_num1 :: ‹num1 ⇒ nat› where [simp]: ‹from_index_num1 _ = 0›
definition to_index_num1 :: ‹nat ⇒ num1› where [simp]: ‹to_index_num1 _ = 1›

instance
  by standard (auto simp: inj_on_def)

end


(* The index type 1 has exactly one index. *)
lemma indexes_one[simp]: ‹indexes = [1 :: 1]›
  by (auto simp: indexes_def)

instantiation bit0 :: (finite) index1
begin

(* Elements of 'a bit0 are represented by integers in {0..<2 * CARD('a)}; the index is that
   representation converted to nat, and to_index converts back via of_nat. *)
definition from_index_bit0 :: ‹'a bit0 ⇒ nat› where ‹from_index_bit0 x = nat (Rep_bit0 x)›

definition to_index_bit0 :: ‹nat ⇒ 'a bit0› where ‹to_index_bit0 ≡ of_nat›

instance
  apply standard
  subgoal
    by (simp add: to_index_bit0_def from_index_bit0_def bit0.of_nat_eq Abs_bit0_inverse)
  subgoal for n
    unfolding from_index_bit0_def by (auto simp: Abs_bit0_inverse intro!: exI[where x = ‹Abs_bit0 (int n)›])
  subgoal for n
    using Rep_bit0[of n]
    by (simp add: from_index_bit0_def nat_less_iff)
  subgoal
    unfolding from_index_bit0_def inj_on_def
    by (metis Rep_bit0 Rep_bit0_inverse atLeastLessThan_iff int_nat_eq)
  done

end

instantiation bit1 :: (finite) index1
begin

(* Elements of 'a bit1 are represented by integers in {0..<1 + 2 * CARD('a)}; the index is that
   representation converted to nat, and to_index converts back via of_nat. *)
definition from_index_bit1 :: ‹'a bit1 ⇒ nat› where ‹from_index_bit1 x = nat (Rep_bit1 x)›

definition to_index_bit1 :: ‹nat ⇒ 'a bit1› where ‹to_index_bit1 ≡ of_nat›

instance
  apply standard
  subgoal
    by (simp add: to_index_bit1_def from_index_bit1_def bit1.of_nat_eq Abs_bit1_inverse)
  subgoal for n
    unfolding from_index_bit1_def by (auto simp: Abs_bit1_inverse intro!: exI[where x = ‹Abs_bit1 (int n)›])
  subgoal for n
    using Rep_bit1[of n]
    by (simp add: from_index_bit1_def nat_less_iff)
  subgoal
    unfolding from_index_bit1_def inj_on_def
    by (metis Rep_bit1 Rep_bit1_inverse atLeastLessThan_iff eq_nat_nat_iff)
  done

end

(* Explicit index lists for the binary numeral types. *)
lemma indexes_bit_simps:
  ‹(indexes :: 'a :: finite bit0 list) = map of_nat [0..<2 * CARD('a)]›
  ‹(indexes :: 'b :: finite bit1 list) = map of_nat [0..<2 * CARD('b) + 1]›
  unfolding indexes_def to_index_bit0_def to_index_bit1_def
  by simp+

text ‹The following class and instance definitions connect \<^const>‹indexes› to \<^const>‹enum_class.enum›.›

(* Index types whose index list coincides with the library's enumeration. *)
class index_enum = index1 + enum +
  assumes indexes_eq_enum: ‹indexes = enum_class.enum›

(* num1: both sides reduce to the one-element list once indexes and enum are unfolded. *)
instance num1 :: index_enum
  by standard (auto simp: indexes_def enum_num1_def)

(* bit0/bit1: unfold indexes, to_index (= of_nat) and the library's enum definition. *)
instance bit0 :: (finite) index_enum
  by standard (auto simp: indexes_def to_index_bit0_def enum_bit0_def Abs_bit0'_def bit0.of_nat_eq)


instance bit1 :: (finite) index_enum
  by standard (auto simp: indexes_def to_index_bit1_def enum_bit1_def Abs_bit1'_def bit1.of_nat_eq)

section ‹Type definition and BNF setup›

text ‹
  A vector is a list with a fixed length, where the length is given by the cardinality of the second
  type parameter. To obtain the unit vector (of length zero), we can choose \<^typ>‹0› or any
  infinite type as the index type. There is no reason to restrict the index type to a particular
  sort constraint at this point, even though \<^class>‹index› is frequently used later on.
›

(* Vectors are lists whose length is the cardinality of the index type 'b; the carrier is
   non-empty because a replicated list of that length exists. *)
typedef ('a, 'b) vec = "{xs. length xs = CARD('b)} :: 'a list set"
  morphisms list_of_vec vec_of_list
  by (rule exI[where x = ‹replicate CARD('b) undefined›]) simp

declare vec.list_of_vec_inverse[simp]

type_notation vec (infixl "^" 15)

setup_lifting type_definition_vec

(* Length-preserving list operations lifted to vectors; the "by auto" obligations show the
   result still has length CARD of the index type. *)
lift_definition map_vec :: ‹('a ⇒ 'b) ⇒ 'a ^ 'c ⇒ 'b ^ 'c› is map by auto

lift_definition set_vec :: ‹'a ^ 'b ⇒ 'a set› is set .

lift_definition rel_vec :: ‹('a ⇒ 'b ⇒ bool) ⇒ 'a ^ 'c ⇒ 'b ^ 'c ⇒ bool› is list_all2 .

lift_definition pred_vec :: ‹('a ⇒ bool) ⇒ 'a ^ 'b ⇒ bool› is list_all .

lift_definition zip_vec :: ‹'a ^ 'c ⇒ 'b ^ 'c ⇒ ('a × 'b) ^ 'c› is zip by auto

lift_definition replicate_vec :: ‹'a ⇒ 'a ^ 'b› is ‹replicate CARD('b)› by auto

(* Register vec as a bounded natural functor in its first type argument, with natLeq as bound
   and replicate_vec as witness. The subgoals below follow the order of obligations generated by
   the bnf command. *)
bnf ('a, 'b) vec
  map: map_vec
  sets: set_vec
  bd: natLeq
  wits: replicate_vec
  rel: rel_vec
  pred: pred_vec
  (* map identity *)
  subgoal
    apply (rule ext)
    by transfer' auto
  (* map composition *)
  subgoal
    apply (rule ext)
    by transfer' auto
  (* map congruence *)
  subgoal
    by transfer' auto
  (* set naturality *)
  subgoal
    apply (rule ext)
    by transfer' auto
  (* properties of the cardinal bound natLeq *)
  subgoal by (fact natLeq_card_order)
  subgoal by (fact natLeq_cinfinite)
  subgoal by (fact regularCard_natLeq)
  (* every element set is below the bound, i.e. finite *)
  subgoal
    apply transfer'
    apply (simp flip: finite_iff_ordLess_natLeq)
    done
  (* relator composition *)
  subgoal
    apply (rule predicate2I)
    apply transfer'
    by (smt (verit) list_all2_trans relcompp.relcompI)
  (* relator characterisation, discharged via the list version *)
  subgoal
    apply (rule ext)+
    apply transfer
    by (auto simp: list.in_rel)
  (* predicator in terms of the element set *)
  subgoal
    apply (rule ext)
    apply transfer'
    by (auto simp: list_all_iff)
  (* witness property of replicate_vec *)
  subgoal
    by transfer' auto
  done

section ‹Indexing›

(* Raw indexing by a natural number, without any bound check. *)
lift_definition nth_vec' :: ‹'a ^ 'b ⇒ nat ⇒ 'a› is nth .

(* Indexing by an element of the index type, via its position from_index. *)
lift_definition nth_vec :: ‹'a ^ 'b ⇒ 'b :: index1 ⇒ 'a› (infixl "$" 90)
  ― ‹We fix this to \<^class>‹index1› because indexing a unit vector makes no sense.›
  is ‹λxs. nth xs ∘ from_index› .

lemma nth_vec_alt_def: ‹nth_vec v = nth_vec' v ∘ from_index›
  by transfer' auto

text ‹
  We additionally define a notion of converting a function into a vector, by mapping it over all
  \<^const>‹indexes›.
›

(* Tabulate a function over all indexes; the length obligation follows from indexes_length. *)
lift_definition vec_lambda :: ‹('b :: index ⇒ 'a) ⇒ 'a ^ 'b› (binder "χ" 10)
  is ‹λf. map f indexes› by simp

lemma vec_lambda_nth[simp]: ‹vec_lambda f $ i = f i›
  by transfer auto

section ‹Unit vector›

text ‹
  The ‹unit vector› is the unique vector of length zero. We use \<^typ>‹0› as the index type, but
  \<^typ>‹nat› or any other infinite type would work just as well.
›

(* The only vector indexed by 0 is the empty list. *)
lift_definition unit_vec :: ‹'a ^ 0› is ‹[]› by simp

lemma unit_vec_unique: ‹v = unit_vec›
  by transfer auto

(* NO_MATCH stops the simplifier from rewriting unit_vec to itself, which would loop. *)
lemma unit_vec_unique_simp[simp]: ‹NO_MATCH v unit_vec ⟹ v = unit_vec›
  by (rule unit_vec_unique)

lemma set_unit_vec[simp]: ‹set_vec (v :: 'a ^ 0) = {}›
  by transfer auto

lemma map_unit_vec[simp]: ‹map_vec f v = unit_vec›
  by simp

lemma zip_unit_vec[simp]: ‹zip_vec u v = unit_vec›
  by simp

lemma rel_unit_vec[simp]: ‹rel_vec R (u :: 'a ^ 0) v ⟷ True›
  by transfer auto

lemma pred_unit_vec[simp]: ‹pred_vec P (v :: 'a ^ 0)›
  by (simp add: vec.pred_set)


section ‹General lemmas›

(* Representation equations of the lifted operations, made available to the simplifier. *)
lemmas vec_simps[simp] = map_vec.rep_eq zip_vec.rep_eq replicate_vec.rep_eq

(* Congruence rules obtained by transferring the corresponding list rules. *)
lemmas map_vec_cong[fundef_cong] = map_cong[Transfer.transferred]
lemmas rel_vec_cong = list.rel_cong[Transfer.transferred]
lemmas pred_vec_cong = list.pred_cong[Transfer.transferred]

(* Vectors are determined by their underlying list. *)
lemma vec_eq_if: "list_of_vec f = list_of_vec g ⟹ f = g"
  by (metis list_of_vec_inverse)

(* Extensionality: vectors that agree at every index are equal. *)
lemma vec_cong: "(⋀i. f $ i = g $ i) ⟹ f = g"
  by transfer (simp add: list_cong_index)

lemma finite_set_vec[intro, simp]: ‹finite (set_vec v)›
  by transfer' auto

lemma set_vec_in[intro, simp]: ‹v $ i ∈ set_vec v›
  by transfer auto

lemma set_vecE[elim]:
  assumes ‹x ∈ set_vec v›
  obtains i where ‹x = v $ i›
  using assms
  by transfer (auto simp: in_set_conv_nth elim: from_indexE)

lemma map_nth_vec[simp]: ‹map_vec f v $ i = f (v $ i)›
  by transfer auto

lemma replicate_nth_vec[simp]: ‹replicate_vec a $ i = a›
  by transfer auto

(* index1 guarantees a non-empty index type, so the set is exactly {a}. *)
lemma replicate_set_vec[simp]: ‹set_vec (replicate_vec a :: 'a ^ 'b :: index1) = {a}›
  by transfer simp

lemma vec_explode: ‹v = (χ i. v $ i)›
  by (rule vec_cong) auto

(* Every one-element vector is a constant tabulation. *)
lemma vec_explode1:
  fixes v :: ‹'a ^ 1›
  obtains a where ‹v = (χ _. a)›
  apply (rule that[of ‹v $ 1›])
  apply (subst vec_explode[of v])
  apply (rule arg_cong[where f = vec_lambda])
  apply (rule ext)
  apply (subst num1_eq1)
  by (rule refl)

(* Interaction of zip_vec with indexing, projections, tabulation and map. *)
lemma zip_nth_vec[simp]: ‹zip_vec u v $ i = (u $ i, v $ i)›
  by transfer auto

lemma zip_vec_fst[simp]: ‹map_vec fst (zip_vec u v) = u›
  by transfer auto

lemma zip_vec_snd[simp]: ‹map_vec snd (zip_vec u v) = v›
  by transfer auto

lemma zip_lambda_vec[simp]: ‹zip_vec (vec_lambda f) (vec_lambda g) = (χ i. (f i, g i))›
  by transfer' (simp add: zip_map_same)

lemma zip_map_vec: ‹zip_vec (map_vec f u) (map_vec g v) = map_vec (λ(x, y). (f x, g y)) (zip_vec u v)›
  by transfer' (auto simp: zip_map1 zip_map2)

lemma list_of_vec_length[simp]: ‹length (list_of_vec (v :: 'a ^ 'b)) = CARD('b)›
  using list_of_vec by blast

(* Round trips through vec_of_list, valid only for lists of the right length. *)
lemma list_vec_list: ‹length xs = CARD('n) ⟹ list_of_vec (vec_of_list xs :: 'a ^ 'n) = xs›
  by (subst vec.vec_of_list_inverse) auto

lemma map_vec_list: ‹length xs = CARD('n) ⟹ map_vec f (vec_of_list xs :: 'a ^ 'n) = vec_of_list (map f xs)›
  by (rule map_vec.abs_eq) (auto simp: eq_onp_def)

lemma set_vec_list: ‹length xs = CARD('n) ⟹ set_vec (vec_of_list xs :: 'a ^ 'n) = set xs›
  by (rule set_vec.abs_eq) (auto simp: eq_onp_def)

(* Relating predicators and relators on zipped or partially constant arguments, first for
   equal-length lists and then lifted to vectors. *)
lemma list_all_zip: ‹length xs = length ys ⟹ list_all P (zip xs ys) ⟷ list_all2 (λx y. P (x, y)) xs ys›
  by (erule list_induct2) auto

lemma pred_vec_zip: ‹pred_vec P (zip_vec xs ys) ⟷ rel_vec (λx y. P (x, y)) xs ys›
  by transfer (simp add: list_all_zip)

lemma list_all2_left: ‹length xs = length ys ⟹ list_all2 (λx y. P x) xs ys ⟷ list_all P xs›
  by (erule list_induct2) auto

lemma list_all2_right: ‹length xs = length ys ⟹ list_all2 (λ_. P) xs ys ⟷ list_all P ys›
  by (erule list_induct2) auto

lemma rel_vec_left: ‹rel_vec (λx y. P x) xs ys ⟷ pred_vec P xs›
  by transfer (simp add: list_all2_left)

lemma rel_vec_right: ‹rel_vec (λ_. P) xs ys ⟷ pred_vec P ys›
  by transfer (simp add: list_all2_right)


section ‹Instances›

(* Lists of exactly length n whose elements all lie in A. *)
definition bounded_lists :: ‹nat ⇒ 'a set ⇒ 'a list set› where
‹bounded_lists n A = {xs. length xs = n ∧ list_all (λx. x ∈ A) xs}›

(* Induction on n: each longer list is a cons of an element of A onto a shorter bounded list. *)
lemma bounded_lists_finite:
  assumes ‹finite A›
  shows ‹finite (bounded_lists n A)›
proof (induction n)
  case (Suc n)
  moreover have ‹bounded_lists (Suc n) A ⊆ (λ(x, xs). x # xs) ` (A × bounded_lists n A)›
    unfolding bounded_lists_def
    by (force simp: length_Suc_conv split_beta)
  ultimately show ?case
    using assms by (meson finite_SigmaI finite_imageI finite_subset)
qed (simp add: bounded_lists_def)

lemma bounded_lists_bij: ‹bij_betw list_of_vec (UNIV :: ('a ^ 'b) set) (bounded_lists CARD('b) UNIV)›
  unfolding bij_betw_def bounded_lists_def
  by (metis (no_types, lifting) Ball_set Collect_cong UNIV_I inj_def type_definition.Rep_range type_definition_vec vec_eq_if)


text ‹If the base type is \<^class>‹finite›, so is the vector type.›

(* list_of_vec is a bijection onto bounded_lists CARD('b) UNIV, and that set is finite because
   'a is finite. *)
instance vec :: (finite, type) finite
  apply standard
  apply (subst bij_betw_finite[OF bounded_lists_bij])
  apply (rule bounded_lists_finite)
  by simp

text ‹The \<^const>‹size› of the vector is the \<^const>‹length› of the underlying list.›

instantiation vec :: (type, type) size
begin

lift_definition size_vec :: ‹'a ^ 'b ⇒ nat› is length .

instance ..

end

(* The size is fixed by the index type. *)
lemma size_vec_alt_def[simp]: ‹size (v :: 'a ^ 'b) = CARD('b)›
  by transfer simp

text ‹Vectors can be compared for equality.›

instantiation vec :: (equal, type) equal
begin

(* Executable equality, inherited from equality on the underlying lists. *)
lift_definition equal_vec :: ‹'a ^ 'b ⇒ 'a ^ 'b ⇒ bool› is equal_class.equal .

instance
  apply standard
  apply transfer'
  by (simp add: equal_list_def)

end

section ‹Further operations›

subsection ‹Distinctness›

(* A vector is distinct iff its underlying list is. *)
lift_definition distinct_vec :: ‹'a ^ 'n ⇒ bool› is distinct .

(* Characterisation via indexing: entries at different indexes differ. *)
lemma distinct_vec_alt_def: ‹distinct_vec v ⟷ (∀i j. i ≠ j ⟶ v $ i ≠ v $ j)›
  apply transfer
  unfolding distinct_conv_nth comp_apply
  by (metis from_index_bound from_to_index to_from_index)

lemma distinct_vecI:
  assumes ‹⋀i j. i ≠ j ⟹ v $ i ≠ v $ j›
  shows ‹distinct_vec v›
  using assms unfolding distinct_vec_alt_def by simp

lemma distinct_vec_mapI: ‹distinct_vec xs ⟹ inj_on f (set_vec xs) ⟹ distinct_vec (map_vec f xs)›
  by transfer' (metis distinct_map)

lemma distinct_vec_zip_fst: ‹distinct_vec u ⟹ distinct_vec (zip_vec u v)›
  by transfer' (metis distinct_zipI1)

lemma distinct_vec_zip_snd: ‹distinct_vec v ⟹ distinct_vec (zip_vec u v)›
  by transfer' (metis distinct_zipI2)

lemma inj_set_of_vec: ‹distinct_vec (map_vec f v) ⟹ inj_on f (set_vec v)›
  by transfer' (metis distinct_map)

lemma distinct_vec_list: ‹length xs = CARD('n) ⟹ distinct_vec (vec_of_list xs :: 'a ^ 'n) ⟷ distinct xs›
  by (subst distinct_vec.rep_eq) (simp add: list_vec_list)

subsection ‹Summing›

(* Sum of all entries, inherited from sum_list. *)
lift_definition sum_vec :: ‹'b::comm_monoid_add ^ 'a ⇒ 'b› is sum_list .

lemma sum_vec_lambda: ‹sum_vec (vec_lambda v) = sum_list (map v indexes)›
  by transfer simp

(* In a canonically ordered monoid every entry is bounded by the total. *)
lemma elem_le_sum_vec:
  fixes f :: ‹'a :: canonically_ordered_monoid_add ^ 'b :: index1›
  shows "f $ i ≤ sum_vec f"
  by transfer (simp add: elem_le_sum_list)


section ‹Code setup›

text ‹
  Since \<^const>‹vec_of_list› cannot be used directly in code generation, we define a convenience
  wrapper that checks the length and aborts if necessary.
›

(* Placeholder list of the right length; declared as a code abort, so generated code fails when
   it is reached. *)
definition replicate' where ‹replicate' n = replicate n undefined›

declare [[code abort: replicate']]

(* Executable conversion: returns xs when its length is CARD('n), otherwise hits replicate'. *)
lift_definition vec_of_list' :: ‹'a list ⇒ 'a ^ 'n›
  is ‹λxs. if length xs ≠ CARD('n) then replicate' CARD('n) else xs›
  by (auto simp: replicate'_def)


experiment begin

(* Evaluation checks for the executable vector operations. *)
proposition
  ‹sum_vec (χ (i::2). (3::nat)) = 6›
  ‹distinct_vec (vec_of_list' [1::nat, 2] :: nat ^ 2)›
  ‹¬ distinct_vec (vec_of_list' [1::nat, 1] :: nat ^ 2)›
  by eval+

end

(* Check that the vector operations can be exported to SML. *)
export_code sum_vec map_vec rel_vec pred_vec set_vec zip_vec distinct_vec list_of_vec vec_of_list'
  checking SML

(* Save the lifting setup of vec under vec.lifting, then deactivate it for the rest of this
   theory and its importers. *)
lifting_update vec.lifting
lifting_forget vec.lifting

(* Bundles for switching the vector notation (^, $, χ) on and off. *)
bundle vec_syntax begin
type_notation
  vec (infixl "^" 15)
notation
  nth_vec (infixl "$" 90) and
  vec_lambda (binder "χ" 10)
end

bundle no_vec_syntax begin
no_type_notation
  vec (infixl "^" 15)
no_notation
  nth_vec (infixl "$" 90) and
  vec_lambda (binder "χ" 10)
end

(* The notation is switched off from here on; "unbundle vec_syntax" re-enables it. *)
unbundle no_vec_syntax

end