Theory Ix

(*<*)
theory Ix
imports
  Main
begin

(*>*)
subsection‹ A Haskell-like ‹Ix› class\label{sec:Ix} ›

text‹

We allow arbitrary indexing schemes for user-facing arrays via the
‹Ix› class, which essentially represents a bijection
between a subset of an arbitrary type and an initial segment of the
naturals.

Source materials:
  ▪ Haskell 2010 report: 🌐‹https://www.haskell.org/onlinereport/haskell2010/haskellch19.html›
  ▪ GHC implementation: 🌐‹https://hackage.haskell.org/package/base-4.16.0.0/docs/src/GHC.Ix.html›
  ▪ Haskell pure arrays (just for colour): 🌐‹https://www.haskell.org/onlinereport/haskell2010/haskellch14.html›
  ▪ SML 2D arrays: 🌐‹https://smlfamily.github.io/Basis/array2.html›

Observations:
  ▪ we follow the Haskell convention here: bounds are inclusive
  ▪ we could alternatively use an array of one-dimensional arrays, but those are not necessarily rectangular
  ▪ we can't use the ‹enum› class as that requires the whole type to be enumerable

Limitations:
  ▪ the basic design assumes laziness; we never want to materialise the full list of indices
    ▪ this can be improved either by tweaking the code generator setup or by changing the constants here
  ▪ array indices typically have partial predecessor and successor operations and are totally ordered on their domain
  ▪ there is no guarantee that ‹interval› is correct (nothing prevents off-by-one errors in instances)

›

(* Index types.  For bounds b, interval b lists the valid indices and
   index b maps each of them to its position in that list.  The two laws
   below make index b and (λn. interval b ! n) mutually inverse between
   set (interval b) and [0..<length (interval b)]. *)
class Ix =
  fixes interval :: "'a × 'a ⇒ 'a list"
  fixes index :: "'a × 'a ⇒ 'a ⇒ nat"
  (* index b i is the position at which i occurs in interval b *)
  assumes index: "i ∈ set (interval b) ⟹ interval b ! index b i = i"
  (* index b enumerates the interval as exactly the positions 0, 1, ..., n - 1 *)
  assumes interval: "map (index b) (interval b) = [0..<length (interval b)]"

(* Every element of the interval is mapped strictly below the interval's length. *)
lemma index_length:
  assumes "i ∈ set (interval b)"
  shows "index b i < length (interval b)"
proof -
  from assms[unfolded in_set_conv_nth]
  obtain j where "j < length (interval b)" and "interval b ! j = i"
    by blast
  (* take the j-th element of both sides of the interval law *)
  with arg_cong[where f="λx. List.nth x j", OF interval[of b]] show ?thesis
    by simp
qed

(* The interval law equates map (index b) (interval b) with the duplicate-free
   list [0..<n]; by distinct_map the interval therefore has no duplicates ... *)
lemma distinct_interval:
  shows "distinct (interval b)"
by (metis distinct_map distinct_upt interval)

(* ... and index b is injective on its elements. *)
lemma inj_on_index:
  shows "inj_on (index b) (set (interval b))"
by (metis distinct_map distinct_upt interval)

(* On elements of the interval, equal positions mean equal indices. *)
lemma index_eq_conv:
  assumes "i ∈ set (interval b)"
  assumes "j ∈ set (interval b)"
  shows "index b i = index b j ⟷ i = j"
by (metis assms index)

(* For any position below the interval's length, inverting index b over the
   interval yields an element of the interval. *)
lemma index_inv_into:
  assumes "i < length (interval b)"
  shows "inv_into (set (interval b)) (index b) i ∈ set (interval b)"
by (metis assms add.left_neutral inv_into_into length_map list.set_map interval nth_mem nth_upt)

(* Comparing positions under index b is a linear order on the interval's elements. *)
lemma linear_order_on:
  shows "linear_order_on (set (interval b)) {(i, j). {i, j} ⊆ set (interval b) ∧ index b i ≤ index b j}"
by (force simp: linear_order_on_def partial_order_on_def preorder_on_def refl_on_def total_on_def
         intro: transI antisymI
          dest: index)

(* Mapping f over the interval by position is the same as mapping f over it directly. *)
lemma interval_map:
  shows "map (λi. f (interval b ! i)) [0..<length (interval b)] = map f (interval b)"
by (simp add: map_equality_iff)

(* Every position below the interval's length is the index of some element
   of the interval (surjectivity of index b onto [0..<length (interval b)]). *)
lemma index_forE:
  assumes "i < length (interval b)"
  obtains j where "j ∈ set (interval b)" and "index b j = i"
using assms index index_length nth_eq_iff_index_eq[OF distinct_interval] nth_mem[OF assms] by blast

(* unit: whatever the bounds, the interval is [()] and () sits at position 0. *)
instantiation unit :: Ix
begin

definition "interval_unit = (λ(x::unit, y::unit). [()])"
definition "index_unit = (λ(x::unit, y::unit) _::unit. 0::nat)"

instance by standard (auto simp: interval_unit_def index_unit_def)

end

(* nat: bounds (l, u) denote the list [l..<Suc u], i.e. l to u inclusive, and
   an index is its offset i - l from the lower bound. *)
instantiation nat :: Ix
begin

definition "interval_nat = (λ(l, u::nat). [l..<Suc u])" ―‹ bounds are inclusive ›
definition "index_nat = (λ(l, u::nat) i::nat. i - l)"

(* Shifting [l..<u] down by l yields [0..<u - l]; used by the instance proof below. *)
lemma upt_minus:
  shows "map (λi. i - l) [l..<u] = [0..<u - l]"
by (induct u) (auto simp: Suc_diff_le)

instance by standard (auto simp: interval_nat_def index_nat_def upt_minus nth_append)

end

(* int: bounds (l, u) denote the list [l..u], l to u inclusive, and an index
   is its offset from the lower bound, converted with nat (i - l). *)
instantiation int :: Ix
begin

definition "interval_int = (λ(l, u::int). [l..u])" ―‹ bounds are inclusive ›
definition "index_int = (λ(l, u::int) i::int. nat (i - l))"

(* Shifting [l..u] down by l yields the positions [0..<nat (u - l + 1)];
   proved by induction on that length, peeling off the upper bound u. *)
lemma upto_minus:
  shows "map (λi. nat (i - l)) [l..u] = [0..<nat (u - l + 1)]"
proof(induct "nat(u - l + 1)" arbitrary: u)
  case (Suc i)
  from Suc.hyps(1)[of "u - 1"] Suc.hyps(2) show ?case
    by (simp add: upto_rec2 ac_simps Suc_nat_eq_nat_zadd1 flip: upt_Suc_append)
qed simp

instance by standard (auto simp: interval_int_def index_int_def upto_minus)

end

(* Bounds for pair indices, destructured below as ((l, l'), (u, u')): the
   first component of each pair bounds 'i, the second bounds 'j. *)
type_synonym ('i, 'j) two_dim = "('i × 'j) × ('i × 'j)"

instantiation prod :: (Ix, Ix) Ix
begin

(* Pairs: bounds ((l, l'), (u, u')) range the first component over
   interval (l, u) and the second over interval (l', u').  Positions are
   row-major: (i, i') sits at index (l, u) i * length (interval (l', u'))
   + index (l', u') i', so the first component varies slowest, matching
   the order of List.product. *)
definition "interval_prod = (λ((l, l'), (u, u')). List.product (interval (l, u)) (interval (l', u')))"
definition "index_prod = (λ((l, l'), (u, u')) (i, i'). index (l, u) i * length (interval (l', u')) + index (l', u') i')"

(* Projections of two-dimensional bounds onto the bounds of each coordinate. *)
abbreviation (input) fst_bounds :: "('a × 'b) × ('a × 'b) ⇒ ('a × 'a)" where
  "fst_bounds b ≡ (fst (fst b), fst (snd b))"

abbreviation (input) snd_bounds :: "('a × 'b) × ('a × 'b) ⇒ ('b × 'b)" where
  "snd_bounds b ≡ (snd (fst b), snd (snd b))"

lemma inj_on_index_prod:
  shows "inj_on (index ((l, l'), (u, u'))) (set (interval ((l, l'), (u, u'))))"
by (clarsimp simp: inj_on_def interval_prod_def index_prod_def)
   (metis index index_length length_pos_if_in_set add_diff_cancel_right'
          div_mult_self_is_m mod_less mod_mult_self3)

instance
proof
  show "interval b ! index b i = i" if "i ∈ set (interval b)" for b and i :: "'a × 'b"
  proof -
    (* a row-major position stays below the size of the product *)
    have *: "i * n + j < m * n" if "i < m"  and "j < n"
     for i j m n :: nat
      using that by (metis bot_nat_0.extremum_strict div_less div_less_iff_less_mult div_mult_self3 nat_arith.rule0 not_gr_zero)
    from that
    have "index (fst_bounds b) (fst i) * length (interval (snd_bounds b))
            + index (snd_bounds b) (snd i)
        < length (interval (fst_bounds b)) * length (interval (snd_bounds b))"
      by (clarsimp simp: interval_prod_def index_prod_def * dest!: index_length)
    then show ?thesis
      using that length_pos_if_in_set
      by (fastforce simp: interval_prod_def index_prod_def List.product_nth index index_length)
  qed
  show "map (index b) (interval b) = [0..<length (interval b)]" for b :: "('a × 'b) × ('a × 'b)"
    by (rule iffD2[OF list_eq_iff_nth_eq])
       (clarsimp simp: interval_prod_def index_prod_def split_def product_nth ac_simps;
        metis (no_types, lifting) distinct_interval index index_length length_pos_if_in_set nth_mem
              less_mult_imp_div_less mod_div_mult_eq mod_less_divisor mult.commute nth_eq_iff_index_eq)
qed

end

setup ‹Sign.mandatory_path "Ix"›

setup ‹Sign.mandatory_path "prod"›

(* Membership in a product interval is componentwise membership. *)
lemma interval_conv:
  shows "(x, y) ∈ set (interval b) ⟷ x ∈ set (interval (fst_bounds b)) ∧ y ∈ set (interval (snd_bounds b))"
by (force simp: interval_prod_def)

setup ‹Sign.parent_path›

(* Two-dimensional bounds whose two coordinates share the index type. *)
type_synonym 'i square = "('i, 'i) two_dim"

(* Bounds are square when both coordinates range over the same interval. *)
definition square :: "'i::Ix Ix.square ⇒ bool" where
  "square = (λ((l, l'), (u, u')). Ix.interval (l, u) = Ix.interval (l', u'))"

setup ‹Sign.mandatory_path "square"›

(* For square bounds, an index is valid in one coordinate iff it is valid in the other. *)
lemma conv:
  assumes "Ix.square b"
  shows "i ∈ set (Ix.interval (fst_bounds b)) ⟷ i ∈ set (Ix.interval (snd_bounds b))"
using assms by (clarsimp simp: Ix.square_def)

setup ‹Sign.parent_path›

setup ‹Sign.parent_path›

hide_const (open) interval index
(*<*)

end
(*>*)