Theory Word_Lib.More_Word_Operations

(*
 * Copyright Data61, CSIRO (ABN 41 687 119 230)
 *
 * SPDX-License-Identifier: BSD-2-Clause
 *)

section ‹Misc word operations›

theory More_Word_Operations
  imports
    "HOL-Library.Word"
    Aligned
    Reversed_Bit_Lists
    More_Misc
    Signed_Words
    Word_Lemmas
    Many_More
    Word_EqI
begin

context
  includes bit_operations_syntax
begin

(* ptr_add ptr n: advance the word ptr by the natural-number offset n.
   The offset is converted to a word with of_nat and then added with
   ordinary word addition. *)
definition
  ptr_add :: "'a :: len word  nat  'a word" where
  "ptr_add ptr n  ptr + of_nat n"

(* alignUp x n: round x up to a multiple of 2^n by adding 2^n - 1 and then
   clearing the low n bits. The addition is word addition and may wrap
   around; lemmas about the result (e.g. alignUp_ge) therefore assume that
   alignUp a n is nonzero. *)
definition
  alignUp :: "'a::len word  nat  'a word" where
 "alignUp x n  x + 2 ^ n - 1 AND NOT (2 ^ n - 1)"

(* The same operation written with mask n in place of 2 ^ n - 1. *)
lemma alignUp_unfold:
  alignUp w n = (w + mask n) AND NOT (mask n)
  by (simp add: alignUp_def mask_eq_exp_minus_1 add_mask_fold)

(* Standard notation for the block of words {p .. p + mask n}, usually with
   p aligned to n bits. mask n (= 2^n - 1) is the offset of the LAST word,
   so when n < LENGTH('a) and p + mask n does not overflow the block holds
   2^n words. This is an abbreviation rather than a definition so that it
   simplifies directly. *)
abbreviation mask_range :: "'a::len word  nat  'a word set" where
  "mask_range p n  {p .. p + mask n}"

(* Convert a word of any length to an 8-bit word via the unsigned cast
   ucast. *)
definition
  w2byte :: "'a :: len word  8 word" where
  "w2byte  ucast"

(* Count leading zeros: the length of the run of False entries at the start
   of to_bl w, i.e. scanning from the most significant bit. For the zero
   word this is LENGTH('a) (lemma word_clz_0). *)
definition
  word_clz :: "'a::len word  nat"
where
  "word_clz w  length (takeWhile Not (to_bl w))"

(* Count trailing zeros: the same scan on the reversed bit list, i.e.
   starting from the least significant bit. For the zero word this is
   LENGTH('a). *)
definition
  word_ctz :: "'a::len word  nat"
where
  "word_ctz w  length (takeWhile Not (rev (to_bl w)))"

(* word_ctz as the length of the longest prefix of bit indices 0, 1, ...
   (below LENGTH('a)) at which w has no set bit. *)
lemma word_ctz_unfold:
  word_ctz w = length (takeWhile (Not  bit w) [0..<LENGTH('a)]) for w :: 'a::len word
  by (simp add: word_ctz_def rev_to_bl_eq takeWhile_map)

(* word_ctz w is the smallest index of a set bit of w, or LENGTH('a) when no
   bit is set. LENGTH('a) is inserted so that Min is always taken over a
   nonempty set. The proof splits on whether any bit of w is set. *)
lemma word_ctz_unfold':
  word_ctz w = Min (insert LENGTH('a) {n. bit w n}) for w :: 'a::len word
proof (cases n. bit w n)
  case True
  then obtain n where bit w n ..
  from bit w n show ?thesis
    apply (simp add: word_ctz_unfold)
    apply (subst Min_eq_length_takeWhile [symmetric])
      apply (auto simp add: bit_imp_le_length)
    apply (subst Min_insert)
      apply auto
    apply (subst min.absorb2)
     apply (subst Min_le_iff)
       apply auto
    apply (meson bit_imp_le_length order_less_le)
    done
next
  case False
  then have bit w = bot
    by auto
  then have word_ctz w = LENGTH('a)
    by (simp add: word_ctz_def rev_to_bl_eq bot_fun_def map_replicate_const)
  with bit w = bot show ?thesis
    by simp
qed

(* The trailing-zero count never exceeds the word length. *)
lemma word_ctz_le:
  "word_ctz (w :: ('a::len word))  LENGTH('a)"
  apply (clarsimp simp: word_ctz_def)
  using length_takeWhile_le apply (rule order_trans)
  apply simp
  done

(* For a nonzero word the trailing-zero count is strictly below the word
   length. *)
lemma word_ctz_less:
  "w  0  word_ctz (w :: ('a::len word)) < LENGTH('a)"
  apply (clarsimp simp: word_ctz_def eq_zero_set_bl)
  using length_takeWhile_less apply (rule less_le_trans)
  apply auto
  done

(* word_ctz w fits in LENGTH('a) bits, so truncating it with take_bit is the
   identity. *)
lemma take_bit_word_ctz_eq [simp]:
  take_bit LENGTH('a) (word_ctz w) = word_ctz w
  for w :: 'a::len word
  apply (simp add: take_bit_nat_eq_self_iff word_ctz_def to_bl_unfold)
  using length_takeWhile_le apply (rule le_less_trans)
  apply simp
  done

(* For words longer than one bit, word_ctz w converted to a word is never the
   all-ones word -1. The proof chains word_ctz w <= LENGTH('a) and
   LENGTH('a) < mask LENGTH('a), and uses the fact that mask LENGTH('a) as a
   word is -1. *)
lemma word_ctz_not_minus_1:
  word_of_nat (word_ctz (w :: 'a :: len word))  (- 1 :: 'a::len word) if 1 < LENGTH('a)
proof -
  note word_ctz_le
  also from that have LENGTH('a) < mask LENGTH('a)
    by (simp add: less_mask)
  finally have word_ctz w < mask LENGTH('a) .
  then have word_of_nat (word_ctz w) < (word_of_nat (mask LENGTH('a)) :: 'a word)
    by (simp add: of_nat_word_less_iff)
  also have  = - 1
    by (rule bit_word_eqI) (simp add: bit_simps)
  finally show ?thesis
    by simp
qed

(* Converting word_ctz w to a word of the same length and back with unat
   loses nothing. *)
lemma unat_of_nat_ctz_mw:
  "unat (of_nat (word_ctz (w :: 'a :: len word)) :: 'a :: len word) = word_ctz w"
  by (simp add: unsigned_of_nat)

(* The same round trip through the signed word type of the same length. *)
lemma unat_of_nat_ctz_smw:
  "unat (of_nat (word_ctz (w :: 'a :: len word)) :: 'a :: len signed word) = word_ctz w"
  by (simp add: unsigned_of_nat)

(* Index of the most significant set bit, computed as
   size w - 1 - word_clz w. Natural-number subtraction truncates at 0, so
   word_log2 0 = 0 (lemma word_log2_zero_eq). For nonzero w see
   word_log2_unfold. *)
definition
  word_log2 :: "'a::len word  nat"
where
  "word_log2 (w::'a::len word)  size w - 1 - word_clz w"

(* Bit population count: the number of True entries in to_bl w, i.e. the
   number of set bits. Equivalent of __builtin_popcount. *)
definition
  pop_count :: "('a::len) word  nat"
where
  "pop_count w  length (filter id (to_bl w))"

(* Sign extension from bit n, treating bit n as the sign bit. If bit n is
   set, every bit at position n and above is set (OR with NOT (mask n)).
   Otherwise every bit at position n and above is cleared (AND with mask n).
   Bits below n are unchanged. Equal to signed_take_bit n (lemma
   sign_extend_eq_signed_take_bit). *)
definition
  sign_extend :: "nat  'a::len word  'a word"
where
  "sign_extend n w  if bit w n then w OR NOT (mask n) else w AND mask n"

(* sign_extend coincides with the library operation signed_take_bit. The
   proof compares the two bit by bit. *)
lemma sign_extend_eq_signed_take_bit:
  sign_extend = signed_take_bit
proof (rule ext)+
  fix n and w :: 'a::len word
  show sign_extend n w = signed_take_bit n w
  proof (rule bit_word_eqI)
    fix q
    assume q < LENGTH('a)
    then show bit (sign_extend n w) q  bit (signed_take_bit n w) q
      by (auto simp add: bit_signed_take_bit_iff
        sign_extend_def bit_and_iff bit_or_iff bit_not_iff bit_mask_iff not_less
        exp_eq_0_imp_not_bit not_le min_def)
  qed
qed

(* sign_extended n w: every bit of w strictly above position n (and below
   the word size) equals bit n, i.e. w is already sign-extended from bit n
   (cf. sign_extended_iff_sign_extend). *)
definition
  sign_extended :: "nat  'a::len word  bool"
where
  "sign_extended n w  i. n < i  i < size w  bit w i = bit w n"

(* Adding a zero offset leaves the pointer unchanged. *)
lemma ptr_add_0 [simp]:
  "ptr_add ref 0 = ref "
  unfolding ptr_add_def by simp

(* The zero word has no set bits. *)
lemma pop_count_0[simp]:
  "pop_count 0 = 0"
  by (clarsimp simp:pop_count_def)

(* The word 1 has exactly one set bit. *)
lemma pop_count_1[simp]:
  "pop_count 1 = 1"
  by (clarsimp simp:pop_count_def to_bl_1)

(* Despite its name this is an equivalence: pop_count w = 0 exactly when
   w = 0. *)
lemma pop_count_0_imp_0:
  "(pop_count w = 0) = (w = 0)"
  apply (rule iffI)
   apply (clarsimp simp:pop_count_def)
   apply (subst (asm) filter_empty_conv)
   apply (clarsimp simp:eq_zero_set_bl)
   apply fast
  apply simp
  done

(* word_log2 of the zero word is 0, because the natural-number subtraction in
   word_log2_def truncates. *)
lemma word_log2_zero_eq [simp]:
  word_log2 0 = 0
  by (simp add: word_log2_def word_clz_def word_size)

(* Characterisation of word_log2: 0 for the zero word, otherwise the largest
   index of a set bit. *)
lemma word_log2_unfold:
  word_log2 w = (if w = 0 then 0 else Max {n. bit w n})
  for w :: 'a::len word
proof (cases w = 0)
  case True
  then show ?thesis
    by simp
next
  case False
  then obtain r where bit w r
    by (auto simp add: bit_eq_iff)
  then have Max {m. bit w m} = LENGTH('a) - Suc (length
    (takeWhile (Not  bit w) (rev [0..<LENGTH('a)])))
    by (subst Max_eq_length_takeWhile [of _ LENGTH('a)])
      (auto simp add: bit_imp_le_length)
  then have word_log2 w = Max {x. bit w x}
    by (simp add: word_log2_def word_clz_def word_size to_bl_unfold rev_map takeWhile_map)
  with w  0 show ?thesis
    by simp
qed

(* Introduction rule: for nonzero w, if bit n is set and every set bit m is
   bounded by n, then word_log2 w = n. *)
lemma word_log2_eqI:
  word_log2 w = n
  if w  0 bit w n m. bit w m  m  n
  for w :: 'a::len word
proof -
  from w  0 have word_log2 w = Max {n. bit w n}
    by (simp add: word_log2_unfold)
  also have Max {n. bit w n} = n
    using that by (auto intro: Max_eqI)
  finally show ?thesis .
qed

(* For nonzero w, the bit at position word_log2 w is set. *)
lemma bit_word_log2:
  bit w (word_log2 w) if w  0
proof -
  from w  0 have r. bit w r
    by (auto intro: bit_eqI)
  then obtain r where bit w r ..
  from w  0 have word_log2 w = Max {n. bit w n}
    by (simp add: word_log2_unfold)
  also have Max {n. bit w n}  {n. bit w n}
    using bit w r by (subst Max_in) auto
  finally show ?thesis
    by simp
qed

(* Every index of a set bit is bounded by word_log2 w. *)
lemma word_log2_maximum:
  n  word_log2 w if bit w n
proof -
  have n  Max {n. bit w n}
    using that by (auto intro: Max_ge)
  also from that have w  0
    by force
  then have Max {n. bit w n} = word_log2 w
    by (simp add: word_log2_unfold)
  finally show ?thesis .
qed

(* Rule form of bit_word_log2. *)
lemma word_log2_nth_same:
  "w  0  bit w (word_log2 w)"
  by (drule bit_word_log2) simp

(* No bit strictly above word_log2 w (and below the word size) is set. *)
lemma word_log2_nth_not_set:
  " word_log2 w < i ; i < size w   ¬ bit w i"
  using word_log2_maximum [of w i] by auto

(* Assumption-style variant of word_log2_maximum. *)
lemma word_log2_highest:
  assumes a: "bit w i"
  shows "i  word_log2 w"
  using a by (simp add: word_log2_maximum)

(* word_log2 w is always a valid bit index, i.e. below the word size. *)
lemma word_log2_max:
  "word_log2 w < size w"
  apply (cases w = 0)
   apply (simp_all add: word_size)
  apply (drule bit_word_log2)
  apply (fact bit_imp_le_length)
  done

(* The zero word consists entirely of leading zeros. *)
lemma word_clz_0[simp]:
  "word_clz (0::'a::len word) = LENGTH('a)"
  unfolding word_clz_def by simp

(* The all-ones word has no leading zeros. *)
lemma word_clz_minus_one[simp]:
  "word_clz (-1::'a::len word) = 0"
  unfolding word_clz_def by simp

(* alignUp p n is always aligned to n bits. *)
lemma is_aligned_alignUp[simp]:
  "is_aligned (alignUp p n) n"
  by (simp add: alignUp_def is_aligned_mask mask_eq_decr_exp word_bw_assocs)

(* alignUp p n is bounded by p + 2 ^ n - 1, since the final AND can only
   clear bits. *)
lemma alignUp_le[simp]:
  "alignUp p n  p + 2 ^ n - 1"
  unfolding alignUp_def by (rule word_and_le2)

(* Rounding an already aligned word up leaves it unchanged, provided
   n < LENGTH('a). *)
lemma alignUp_idem:
  fixes a :: "'a::len word"
  assumes "is_aligned a n" "n < LENGTH('a)"
  shows "alignUp a n = a"
  using assms unfolding alignUp_def
  by (metis add_cancel_right_right add_diff_eq and_mask_eq_iff_le_mask mask_eq_decr_exp mask_out_add_aligned order_refl word_plus_and_or_coroll2)

(* For an unaligned a and n < LENGTH('a), alignUp a n equals
   (a div 2 ^ n + 1) * 2 ^ n, computed in word arithmetic. The proof rewrites
   a + 2 ^ n - 1 as (a mod 2 ^ n - 1) + (a div 2 ^ n + 1) * 2 ^ n, and shows
   that the first summand is cleared by the mask (fact um). *)
lemma alignUp_not_aligned_eq:
  fixes a :: "'a :: len word"
  assumes al: "¬ is_aligned a n"
  and     sz: "n < LENGTH('a)"
  shows   "alignUp a n = (a div 2 ^ n + 1) * 2 ^ n"
proof -
  from n < LENGTH('a) have (2::int) ^ n < 2 ^ LENGTH('a)
    by simp
  with take_bit_int_less_exp [of n]
    have *: take_bit n k < 2 ^ LENGTH('a) for k :: int
    by (rule less_trans)
  have anz: "a mod 2 ^ n  0"
    by (rule not_aligned_mod_nz) fact+
  then have um: "unat (a mod 2 ^ n - 1) div 2 ^ n = 0"
    apply (transfer fixing: n) using sz
    apply (simp flip: take_bit_eq_mod add: div_eq_0_iff)
    apply (subst take_bit_int_eq_self)
    using *
     apply (auto simp add: diff_less_eq intro: less_imp_le)
    apply (simp add: less_le)
    done
  have "a + 2 ^ n - 1 = (a div 2 ^ n) * 2 ^ n + (a mod 2 ^ n) + 2 ^ n - 1"
    by (simp add: word_mod_div_equality)
  also have " = (a mod 2 ^ n - 1) + (a div 2 ^ n + 1) * 2 ^ n"
    by (simp add: field_simps)
  finally show "alignUp a n = (a div 2 ^ n + 1) * 2 ^ n" using sz
    unfolding alignUp_def
    apply (subst mask_eq_decr_exp [symmetric])
    apply (erule ssubst)
    apply (subst neg_mask_is_div)
    apply (simp add: word_arith_nat_div)
    apply (subst unat_word_ariths(1) unat_word_ariths(2))+
    apply (subst uno_simps)
    apply (subst unat_1)
    apply (subst mod_add_right_eq)
    apply simp
    apply (subst power_mod_div)
    apply (subst div_mult_self1)
     apply simp
    apply (subst um)
    apply simp
    apply (subst mod_mod_power)
    apply simp
    apply (subst word_unat_power, subst Abs_fnat_hom_mult)
    apply (subst mult_mod_left)
    apply (subst power_add [symmetric])
    apply simp
    apply (subst Abs_fnat_hom_1)
    apply (subst Abs_fnat_hom_add)
    apply (subst word_unat_power, subst Abs_fnat_hom_mult)
    apply (subst word_unat.Rep_inverse[symmetric], subst Abs_fnat_hom_mult)
    apply simp
    done
qed

(* If rounding up does not wrap around to 0 (assumption nowrap) and
   n < LENGTH('a), then alignUp a n is at least a. In the unaligned case,
   nowrap is used to show that the natural-number value of the rounded result
   stays below 2 ^ LENGTH('a) (fact lt). *)
lemma alignUp_ge:
  fixes a :: "'a :: len word"
  assumes sz: "n < LENGTH('a)"
  and nowrap: "alignUp a n  0"
  shows "a  alignUp a n"
proof (cases "is_aligned a n")
  case True
  then show ?thesis using sz
    by (subst alignUp_idem, simp_all)
next
  case False

  have lt0: "unat a div 2 ^ n < 2 ^ (LENGTH('a) - n)" using sz
    by (metis le_add_diff_inverse2 less_mult_imp_div_less order_less_imp_le power_add unsigned_less)

  have"2 ^ n * (unat a div 2 ^ n + 1)  2 ^ LENGTH('a)" using sz
    by (metis One_nat_def Suc_leI add.right_neutral add_Suc_right lt0 nat_le_power_trans nat_less_le)
  moreover have "2 ^ n * (unat a div 2 ^ n + 1)  2 ^ LENGTH('a)" using nowrap sz
    apply -
    apply (erule contrapos_nn)
    apply (subst alignUp_not_aligned_eq [OF False sz])
    apply (subst unat_arith_simps)
    apply (subst unat_word_ariths)
    apply (subst unat_word_ariths)
    apply simp
    apply (subst mult_mod_left)
    apply (simp add: unat_div field_simps power_add[symmetric] mod_mod_power)
    done
  ultimately have lt: "2 ^ n * (unat a div 2 ^ n + 1) < 2 ^ LENGTH('a)" by simp

  have "a = a div 2 ^ n * 2 ^ n + a mod 2 ^ n" by (rule word_mod_div_equality [symmetric])
  also have " < (a div 2 ^ n + 1) * 2 ^ n" using sz lt
    apply (simp add: field_simps)
    apply (rule word_add_less_mono1)
    apply (rule word_mod_less_divisor)
    apply (simp add: word_less_nat_alt)
    apply (subst unat_word_ariths)
    apply (simp add: unat_div)
    done
  also have " =  alignUp a n"
    by (rule alignUp_not_aligned_eq [symmetric]) fact+
  finally show ?thesis by (rule order_less_imp_le)
qed

(* Rounding a up never overshoots an aligned x lying at or above a
   (for n < LENGTH('a)). In the unaligned case, x is written as
   2 ^ n * of_nat k (via is_alignedE), and a div 2 ^ n < of_nat k is obtained
   from alignUp_div_helper. *)
lemma alignUp_le_greater_al:
  fixes x :: "'a :: len word"
  assumes le: "a  x"
  and     sz: "n < LENGTH('a)"
  and     al: "is_aligned x n"
  shows   "alignUp a n  x"
proof (cases "is_aligned a n")
  case True
  then show ?thesis using sz le by (simp add: alignUp_idem)
next
  case False

  then have anz: "a mod 2 ^ n  0"
    by (rule not_aligned_mod_nz)

  from al obtain k where xk: "x = 2 ^ n * of_nat k" and kv: "k < 2 ^ (LENGTH('a) - n)"
    by (auto elim!: is_alignedE)

  then have kn: "unat (of_nat k :: 'a word) * unat ((2::'a word) ^ n) < 2 ^ LENGTH('a)"
    using sz
    apply (subst unat_of_nat_eq)
     apply (erule order_less_le_trans)
     apply simp
    apply (subst mult.commute)
    apply simp
    apply (rule nat_less_power_trans)
     apply simp
    apply simp
    done

  have au: "alignUp a n = (a div 2 ^ n + 1) * 2 ^ n"
    by (rule alignUp_not_aligned_eq) fact+
  also have "  of_nat k * 2 ^ n"
  proof (rule word_mult_le_mono1 [OF inc_le _ kn])
    show "a div 2 ^ n < of_nat k" using kv xk le sz anz
      by (simp add: alignUp_div_helper)

    show "(0:: 'a word) < 2 ^ n" using sz by (simp add: p2_gt_0 sz)
  qed

  finally show ?thesis using xk by (simp add: field_simps)
qed

(* If a is nonzero and some n-aligned x lies at or above a, then rounding a
   up cannot wrap around to 0. The unaligned case is by contradiction: if the
   result were 0, then unat a div 2 ^ n would have to equal
   2 ^ (LENGTH('a) - n) - 1, which contradicts unat a div 2 ^ n < k. *)
lemma alignUp_is_aligned_nz:
  fixes a :: "'a :: len word"
  assumes al: "is_aligned x n"
  and     sz: "n < LENGTH('a)"
  and     ax: "a  x"
  and     az: "a  0"
  shows   "alignUp (a::'a :: len word) n  0"
proof (cases "is_aligned a n")
  case True
  then have "alignUp a n = a" using sz by (simp add: alignUp_idem)
  then show ?thesis using az by simp
next
  case False
  then have anz: "a mod 2 ^ n  0"
    by (rule not_aligned_mod_nz)

  {
    assume asm: "alignUp a n = 0"

    have lt0: "unat a div 2 ^ n < 2 ^ (LENGTH('a) - n)" using sz
      by (metis le_add_diff_inverse2 less_mult_imp_div_less order_less_imp_le power_add unsigned_less)

    have leq: "2 ^ n * (unat a div 2 ^ n + 1)  2 ^ LENGTH('a)" using sz
      by (metis One_nat_def Suc_leI add.right_neutral add_Suc_right lt0 nat_le_power_trans
                order_less_imp_le)

    from al obtain k where  kv: "k < 2 ^ (LENGTH('a) - n)" and xk: "x = 2 ^ n * of_nat k"
      by (auto elim!: is_alignedE)

    then have "a div 2 ^ n < of_nat k" using ax sz anz
      by (rule alignUp_div_helper)

    then have r: "unat a div 2 ^ n < k" using sz
      by (simp flip: drop_bit_eq_div unat_drop_bit_eq) (metis leI le_unat_uoi unat_mono)

    have "alignUp a n = (a div 2 ^ n + 1) * 2 ^ n"
      by (rule alignUp_not_aligned_eq) fact+

    then have " = 0" using asm by simp
    then have "2 ^ LENGTH('a) dvd 2 ^ n * (unat a div 2 ^ n + 1)"
      using sz by (simp add: unat_arith_simps ac_simps)
                  (simp add: unat_word_ariths mod_simps mod_eq_0_iff_dvd)
    with leq have "2 ^ n * (unat a div 2 ^ n + 1) = 2 ^ LENGTH('a)"
      by (force elim!: le_SucE)
    then have "unat a div 2 ^ n = 2 ^ LENGTH('a) div 2 ^ n - 1"
      by (metis (no_types, opaque_lifting) Groups.add_ac(2) add.right_neutral
                add_diff_cancel_left' div_le_dividend div_mult_self4 gr_implies_not0
                le_neq_implies_less power_eq_0_iff zero_neq_numeral)
    then have "unat a div 2 ^ n = 2 ^ (LENGTH('a) - n) - 1"
      using sz by (simp add: power_sub)
    then have "2 ^ (LENGTH('a) - n) - 1 < k" using r
      by simp
    then have False using kv by simp
  } then show ?thesis by clarsimp
qed

(* If the aligned block {x .. x + 2 ^ n - 1} lies inside {a .. b} and a is
   nonzero, then a <= alignUp a n, and the block starting at alignUp a n
   still ends at or before b. Combines alignUp_ge, alignUp_is_aligned_nz and
   alignUp_le_greater_al. *)
lemma alignUp_ar_helper:
  fixes a :: "'a :: len word"
  assumes al: "is_aligned x n"
  and     sz: "n < LENGTH('a)"
  and    sub: "{x..x + 2 ^ n - 1}  {a..b}"
  and    anz: "a  0"
  shows "a  alignUp a n  alignUp a n + 2 ^ n - 1  b"
proof
  from al have xl: "x  x + 2 ^ n - 1" by (simp add: is_aligned_no_overflow)

  from xl sub have ax: "a  x"
    by auto

  show "a  alignUp a n"
  proof (rule alignUp_ge)
    show "alignUp a n  0" using al sz ax anz
      by (rule alignUp_is_aligned_nz)
  qed fact+

  show "alignUp a n + 2 ^ n - 1  b"
  proof (rule order_trans)
    from xl show tp: "x + 2 ^ n - 1  b" using sub
      by auto

    from ax have "alignUp a n  x"
      by (rule alignUp_le_greater_al) fact+
    then have "alignUp a n + (2 ^ n - 1)  x + (2 ^ n - 1)"
      using xl al is_aligned_no_overflow' olen_add_eqv word_plus_mcs_3 by blast
    then show "alignUp a n + 2 ^ n - 1  x + 2 ^ n - 1"
      by (simp add: field_simps)
  qed
qed

(* alignUp with the cleared low bits written as mask sz. *)
lemma alignUp_def2:
  "alignUp a sz = a + 2 ^ sz - 1 AND NOT (mask sz)"
  by (simp add: alignUp_def flip: mask_eq_decr_exp)

(* Alternative form: 2 ^ sz plus (a - 1) with its low sz bits cleared. *)
lemma alignUp_def3:
  "alignUp a sz = 2^ sz + (a - 1 AND NOT (mask sz))"
  by (simp add: alignUp_def2 is_aligned_triv field_simps mask_out_add_aligned)

(* Rounding up commutes with adding an offset w aligned to us bits. *)
lemma  alignUp_plus:
  "is_aligned w us  alignUp (w + a) us  = w + alignUp a us"
  by (clarsimp simp: alignUp_def2 mask_out_add_aligned field_simps)

(* Rounding up moves q by at most mask sz (word difference). *)
lemma alignUp_distance:
  "alignUp (q :: 'a :: len word) sz - q  mask sz"
  by (metis (no_types) add.commute add_diff_cancel_left alignUp_def2 diff_add_cancel
                       mask_2pm1 subtract_mask(2) word_and_le1 word_sub_le_iff)

(* For aligned p, clearing the low sz bits of p - q gives the same result as
   subtracting the rounded-up q (with its low bits cleared) from p. *)
lemma is_aligned_diff_neg_mask:
  "is_aligned p sz  (p - q AND NOT (mask sz)) = (p - ((alignUp q sz) AND NOT (mask sz)))"
  apply (clarsimp simp only:word_and_le2 diff_conv_add_uminus)
  apply (subst mask_out_add_aligned[symmetric]; simp)
  apply (simp add: eq_neg_iff_add_eq_0)
  apply (subst add.commute)
  apply (simp add: alignUp_distance is_aligned_neg_mask_eq mask_out_add_aligned and_mask_eq_iff_le_mask flip: mask_eq_x_eq_0)
  done

(* The leading-zero count is at most the word size. *)
lemma word_clz_max:
  "word_clz w  size (w::'a::len word)"
  unfolding word_clz_def
  by (metis length_takeWhile_le word_size_bl)

(* For a nonzero word the leading-zero count is strictly below the word
   size. The proof is by contradiction: a count equal to the size would make
   every bit of to_bl w False, forcing w = 0. *)
lemma word_clz_nonzero_max:
  fixes w :: "'a::len word"
  assumes nz: "w  0"
  shows "word_clz w < size (w::'a::len word)"
proof -
  {
    assume a: "word_clz w = size (w::'a::len word)"
    hence "length (takeWhile Not (to_bl w)) = length (to_bl w)"
      by (simp add: word_clz_def word_size)
    hence allj: "jset(to_bl w). ¬ j"
      by (metis a length_takeWhile_less less_irrefl_nat word_clz_def)
    hence "to_bl w = replicate (length (to_bl w)) False"
      using eq_zero_set_bl nz by fastforce
    hence "w = 0"
      by (metis to_bl_0 word_bl.Rep_eqD word_bl_Rep')
    with nz have False by simp
  }
  thus ?thesis using word_clz_max
    by (fastforce intro: le_neq_trans)
qed

(* Sign extension from bit n: bitwise characterisations of sign_extend. *)

(* Bit i of sign_extend e w (for i < LENGTH('a)) is bit (min e i) of w:
   bits below e are kept, and every higher bit copies bit e. *)
lemma bin_sign_extend_iff [bit_simps]:
  bit (sign_extend e w) i  bit w (min e i)
  if i < LENGTH('a) for w :: 'a::len word
  using that by (simp add: sign_extend_def bit_simps min_def)

(* Case-split form of bin_sign_extend_iff, with the bound stated via
   size w. *)
lemma sign_extend_bitwise_if:
  "i < size w  bit (sign_extend e w) i  (if i < e then bit w i else bit w e)"
  by (simp add: word_size bit_simps)

(* The same with the bound stated as LENGTH('a); registered for the
   word_eqI proof method. *)
lemma sign_extend_bitwise_if'  [word_eqI_simps]:
  i < LENGTH('a)  bit (sign_extend e w) i  (if i < e then bit w i else bit w e)
  for w :: 'a::len word
  using sign_extend_bitwise_if [of i w e] by (simp add: word_size)

(* Disjunctive reformulation of sign_extend_bitwise_if. *)
lemma sign_extend_bitwise_disj:
  "i < size w  bit (sign_extend e w) i  i  e  bit w i  e  i  bit w e"
  by (auto simp: sign_extend_bitwise_if)

(* Reformulation of sign_extend_bitwise_if as a conjunction of two
   case conditions. *)
lemma sign_extend_bitwise_cases:
  "i < size w  bit (sign_extend e w) i  (i  e  bit w i)  (e  i  bit w e)"
  by (auto simp: sign_extend_bitwise_if)

(* Variants of the two lemmas above with size w rewritten by word_size. *)
lemmas sign_extend_bitwise_disj' = sign_extend_bitwise_disj[simplified word_size]
lemmas sign_extend_bitwise_cases' = sign_extend_bitwise_cases[simplified word_size]

(* Often, it is easier to reason about an operation which does not overwrite
   the bit which determines which mask operation to apply. This variant masks
   with mask (Suc n), so bits 0 .. n, including the sign bit n itself, are
   left untouched in both branches. *)
lemma sign_extend_def':
  "sign_extend n w = (if bit w n then w OR NOT (mask (Suc n)) else w AND mask (Suc n))"
  by (rule bit_word_eqI) (auto simp add: bit_simps sign_extend_eq_signed_take_bit min_def less_Suc_eq_le)

(* The result of sign_extend n is sign-extended from bit n. *)
lemma sign_extended_sign_extend:
  "sign_extended n (sign_extend n w)"
  by (clarsimp simp: sign_extended_def word_size sign_extend_bitwise_if)

(* A word is sign-extended from bit n exactly when sign_extend n leaves it
   unchanged. *)
lemma sign_extended_iff_sign_extend:
  "sign_extended n w  sign_extend n w = w"
  apply auto
   apply (auto simp add: bit_eq_iff)
    apply (simp_all add: bit_simps sign_extend_eq_signed_take_bit not_le min_def sign_extended_def word_size split: if_splits)
  using le_imp_less_or_eq apply auto
  done

(* Sign extension from bit n implies sign extension from any higher bit m
   (m at least n, per the second premise). *)
lemma sign_extended_weaken:
  "sign_extended n w  n  m  sign_extended m w"
  unfolding sign_extended_def by (cases "n < m") auto

(* Nested sign extensions collapse to a single one from the lower of the two
   bit positions. *)
lemma sign_extend_sign_extend_eq:
  "sign_extend m (sign_extend n w) = sign_extend (min m n) w"
  by (rule bit_word_eqI) (simp add: sign_extend_eq_signed_take_bit bit_simps)

(* In a word sign-extended from bit e, any two bit positions i < j with i
   at or above e and j below the word size agree. *)
lemma sign_extended_high_bits:
  " sign_extended e p; j < size p; e  i; i < j   bit p i = bit p j"
  by (drule (1) sign_extended_weaken; simp add: sign_extended_def)

(* sign_extend n only depends on the low Suc n bits of its argument. *)
lemma sign_extend_eq:
  "w AND mask (Suc n) = v AND mask (Suc n)  sign_extend n w = sign_extend n v"
  by (simp flip: take_bit_eq_mask add: sign_extend_eq_signed_take_bit signed_take_bit_eq_iff_take_bit_eq)

(* Adding a small offset f < 2 ^ n to an n-aligned word p that is
   sign-extended from bit e (related to n by premise e) keeps it
   sign-extended from bit e. Because p is aligned and f is small, p + f
   equals p OR f (is_aligned_add_or), and f has no bits at or above e. *)
lemma sign_extended_add:
  assumes p: "is_aligned p n"
  assumes f: "f < 2 ^ n"
  assumes e: "n  e"
  assumes "sign_extended e p"
  shows "sign_extended e (p + f)"
proof (cases "e < size p")
  case True
  note and_or = is_aligned_add_or[OF p f]
  have "¬ bit f e"
    using True e less_2p_is_upper_bits_unset[THEN iffD1, OF f]
    by (fastforce simp: word_size)
  hence i: "bit (p + f) e = bit p e"
    by (simp add: and_or bit_simps)
  have fm: "f AND mask e = f"
    by (fastforce intro: subst[where P="λf. f AND mask e = f", OF less_mask_eq[OF f]]
                  simp: mask_twice e)
  show ?thesis
    using assms
     apply (simp add: sign_extended_iff_sign_extend sign_extend_def i)
     apply (simp add: and_or word_bw_comms[of p f])
     apply (clarsimp simp: word_ao_dist fm word_bw_assocs split: if_splits)
    done
next
  case False thus ?thesis
    by (simp add: sign_extended_def word_size)
qed

(* Clearing the low m bits of a word that is sign-extended from bit n
   preserves sign extension, given the premise relating m and n. *)
lemma sign_extended_neq_mask:
  "sign_extended n ptr; m  n  sign_extended n (ptr AND NOT (mask m))"
  by (fastforce simp: sign_extended_def word_size neg_mask_test_bit bit_simps)

(* limited_and x y holds when x AND y = x, i.e. every set bit of x is also
   set in y. *)
definition
  "limited_and (x :: 'a :: len word) y  (x AND y = x)"

(* If the bits of x lie within z, and y has no bits in common with z, then
   x AND y = 0. *)
lemma limited_and_eq_0:
  " limited_and x z; y AND NOT z = y   x AND y = 0"
  unfolding limited_and_def
  apply (subst arg_cong2[where f="(AND)"])
    apply (erule sym)+
  apply (simp(no_asm) add: word_bw_assocs word_bw_comms word_bw_lcs)
  done

(* If the bits of x lie within z, and those of z lie within y, then
   x AND y = x. *)
lemma limited_and_eq_id:
  " limited_and x z; y AND z = z   x AND y = x"
  unfolding limited_and_def
  by (erule subst, fastforce simp: word_bw_lcs word_bw_assocs word_bw_comms)

(* limited_and is preserved by shifting both arguments left by the same
   amount. *)
lemma lshift_limited_and:
  "limited_and x z  limited_and (x << n) (z << n)"
  using push_bit_and [of n x z] by (simp add: limited_and_def shiftl_def)

(* limited_and is preserved by shifting both arguments right by the same
   amount. *)
lemma rshift_limited_and:
  "limited_and x z  limited_and (x >> n) (z >> n)"
  using drop_bit_and [of n x z] by (simp add: limited_and_def shiftr_def)

lemmas limited_and_simps1 = limited_and_eq_0 limited_and_eq_id

(* An n-aligned word is limited by NOT (2 ^ n - 1). *)
lemmas is_aligned_limited_and
    = is_aligned_neg_mask_eq[unfolded mask_eq_decr_exp, folded limited_and_def]

(* Simplification rules derived from limited_and, together with their
   instances for aligned words and for shifted arguments. *)
lemmas limited_and_simps = limited_and_simps1
       limited_and_simps1[OF is_aligned_limited_and]
       limited_and_simps1[OF lshift_limited_and]
       limited_and_simps1[OF rshift_limited_and]
       limited_and_simps1[OF rshift_limited_and, OF is_aligned_limited_and]
       not_one_eq

(* Convert a boolean to a word: 1 for True, 0 for False. *)
definition
  from_bool :: "bool  'a::len word" where
  "from_bool b  case b of True  of_nat 1
                         | False  of_nat 0"

(* from_bool is the generic of_bool. *)
lemma from_bool_eq:
  from_bool = of_bool
  by (simp add: fun_eq_iff from_bool_def)

(* from_bool x is zero exactly when x is False. *)
lemma from_bool_0:
  "(from_bool x = 0) = (¬ x)"
  by (simp add: from_bool_def split: bool.split)

(* Comparing an if-then-else 1/0 word with from_bool Q reduces to comparing
   the booleans. *)
lemma from_bool_eq_if':
  "((if P then 1 else 0) = from_bool Q) = (P = Q)"
  by (cases Q) (simp_all add: from_bool_def)

(* Interpret a word as a boolean: True exactly when the word is nonzero. *)
definition
  to_bool :: "'a::len word  bool" where
  "to_bool  (≠) 0"

(* to_bool of the lowest bit is a test of bit 0. *)
lemma to_bool_and_1:
  "to_bool (x AND 1)  bit x 0"
  by (simp add: to_bool_def word_and_1)

(* to_bool inverts from_bool. *)
lemma to_bool_from_bool [simp]:
  "to_bool (from_bool r) = r"
  unfolding from_bool_def to_bool_def
  by (simp split: bool.splits)

(* from_bool b is nonzero exactly when b holds. *)
lemma from_bool_neq_0 [simp]:
  "(from_bool b  0) = b"
  by (simp add: from_bool_def split: bool.splits)

(* from_bool only ever uses bit 0, so masking with 1 does nothing. *)
lemma from_bool_mask_simp [simp]:
  "(from_bool r :: 'a::len word) AND 1 = from_bool r"
  unfolding from_bool_def
  by (clarsimp split: bool.splits)

(* from_bool P is 1 exactly when P holds. *)
lemma from_bool_1 [simp]:
  "(from_bool P = 1) = P"
  by (simp add: from_bool_def split: bool.splits)

(* from_bool P is positive exactly when P holds. *)
lemma ge_0_from_bool [simp]:
  "(0 < from_bool P) = P"
  by (simp add: from_bool_def split: bool.splits)

(* The bits of from_bool b lie within the word 1. *)
lemma limited_and_from_bool:
  "limited_and (from_bool b) 1"
  by (simp add: from_bool_def limited_and_def split: bool.split)

(* to_bool on the constants 1 and 0. *)
lemma to_bool_1 [simp]: "to_bool 1" by (simp add: to_bool_def)
lemma to_bool_0 [simp]: "¬to_bool 0" by (simp add: to_bool_def)

(* Symmetric version of from_bool_eq_if'. *)
lemma from_bool_eq_if:
  "(from_bool Q = (if P then 1 else 0)) = (P = Q)"
  by (cases Q) (simp_all add: from_bool_def)

(* A word is false under to_bool exactly when it is zero. *)
lemma to_bool_eq_0:
  "(¬ to_bool x) = (x = 0)"
  by (simp add: to_bool_def)

(* A word is true under to_bool exactly when it is nonzero. *)
lemma to_bool_neq_0:
  "(to_bool x) = (x  0)"
  by (simp add: to_bool_def)

(* Rewrites a quantification over a boolean bool that is constrained by
   from_bool bool = val, replacing P bool with P applied to the condition
   that val is nonzero. *)
lemma from_bool_all_helper:
  "(bool. from_bool bool = val  P bool)
      = ((bool. from_bool bool = val)  P (val  0))"
  by (auto simp: from_bool_0)

(* Folds a comparison with zero into to_bool; this is the reverse direction
   of to_bool_eq_0. *)
lemma fold_eq_0_to_bool:
  "(v = 0) = (¬ to_bool v)"
  by (simp add: to_bool_def)

(* A word equals from_bool b exactly when it is 0 or 1 and to_bool maps it
   to b. *)
lemma from_bool_to_bool_iff:
  "w = from_bool b  to_bool w = b  (w = 0  w = 1)"
  by (cases b) (auto simp: from_bool_def to_bool_def)

(* from_bool is injective. *)
lemma from_bool_eqI:
  "from_bool x = from_bool y  x = y"
  unfolding from_bool_def
  by (auto split: bool.splits)

(* For aligned ptr, clearing the low bits of ptr' yields ptr exactly when
   ptr' lies in mask_range ptr bits. The proof splits on whether bits is
   below the word length (is_aligned_get_word_bits); otherwise the
   power_overflow case is handled. *)
lemma neg_mask_in_mask_range:
  "is_aligned ptr bits  (ptr' AND NOT(mask bits) = ptr) = (ptr'  mask_range ptr bits)"
  apply (erule is_aligned_get_word_bits)
   apply (rule iffI)
    apply (drule sym)
    apply (simp add: word_and_le2)
    apply (subst word_plus_and_or_coroll, word_eqI_solve)
    apply (metis bit.disj_ac(2) bit.disj_conj_distrib2 le_word_or2 word_and_max word_or_not)
   apply clarsimp
   apply (smt (verit) add.right_neutral eq_iff is_aligned_neg_mask_eq mask_out_add_aligned neg_mask_mono_le
              word_and_not)
  apply (simp add: power_overflow mask_eq_decr_exp)
  done

(* For x aligned to m bits and y < 2 ^ m, whether x + y lies in the block
   at an aligned p of size n (related to m by the premise) depends only on
   x: the small offset y cannot change the high bits. *)
lemma aligned_offset_in_range:
  " is_aligned (x :: 'a :: len word) m; y < 2 ^ m; is_aligned p n; n  m; n < LENGTH('a) 
    (x + y  {p .. p + mask n}) = (x  mask_range p n)"
  apply (subst disjunctive_add)
   apply (simp add: bit_simps)
  apply (erule is_alignedE')
   apply (auto simp add: bit_simps not_le)[1]
   apply (metis less_2p_is_upper_bits_unset)
  apply (simp only: is_aligned_add_or word_ao_dist flip: neg_mask_in_mask_range)
  apply (subgoal_tac y AND NOT (mask n) = 0)
   apply simp
  apply (metis (full_types) is_aligned_mask is_aligned_neg_mask less_mask_eq word_bw_comms(1) word_bw_lcs(1))
  done

(* For aligned ptr and bits < LENGTH('a), mask_range ptr bits is the set of
   words whose first LENGTH('a) - bits entries of to_bl (the high-order bits)
   agree with those of ptr. *)
lemma mask_range_to_bl':
  " is_aligned (ptr :: 'a :: len word) bits; bits < LENGTH('a) 
    mask_range ptr bits
       = {x. take (LENGTH('a) - bits) (to_bl x) = take (LENGTH('a) - bits) (to_bl ptr)}"
  apply (rule set_eqI, rule iffI)
   apply clarsimp
   apply (subgoal_tac "y. x = ptr + y  y < 2 ^ bits")
    apply clarsimp
    apply (subst is_aligned_add_conv)
       apply assumption
      apply simp
    apply simp
   apply (rule_tac x="x - ptr" in exI)
   apply (simp add: add_diff_eq[symmetric])
   apply (simp only: word_less_sub_le[symmetric])
   apply (rule word_diff_ls')
    apply (simp add: field_simps mask_eq_decr_exp)
   apply assumption
  apply simp
  apply (subgoal_tac "y. y < 2 ^ bits  to_bl (ptr + y) = to_bl x")
   apply clarsimp
   apply (rule conjI)
    apply (erule(1) is_aligned_no_wrap')
   apply (simp only: add_diff_eq[symmetric] mask_eq_decr_exp)
   apply (rule word_plus_mono_right)
    apply simp
   apply (erule is_aligned_no_wrap')
   apply simp
  apply (rule_tac x="of_bl (drop (LENGTH('a) - bits) (to_bl x))" in exI)
  apply (rule context_conjI)
   apply (rule order_less_le_trans [OF of_bl_length])
    apply simp
   apply simp
  apply (subst is_aligned_add_conv)
     apply assumption
    apply simp
  apply (drule sym)
  apply (simp add: word_rep_drop)
  done

(* The same characterisation without the bits < LENGTH('a) side condition.
   The remaining case (aligned ptr with bits at least the word length) is
   handled via is_aligned_get_word_bits and power_overflow. *)
lemma mask_range_to_bl:
  "is_aligned (ptr :: 'a :: len word) bits
    mask_range ptr bits
        = {x. take (LENGTH('a) - bits) (to_bl x) = take (LENGTH('a) - bits) (to_bl ptr)}"
  apply (erule is_aligned_get_word_bits)
   apply (erule(1) mask_range_to_bl')
  apply (rule set_eqI)
  apply (simp add: power_overflow mask_eq_decr_exp)
  done

(* Two aligned blocks are either disjoint, or one is contained in the other.
   The proof goes via the to_bl prefix characterisation mask_range_to_bl. *)
lemma aligned_mask_range_cases:
  " is_aligned (p :: 'a :: len word) n; is_aligned (p' :: 'a :: len word) n' 
    mask_range p n  mask_range p' n' = {} 
       mask_range p n  mask_range p' n' 
       mask_range p n  mask_range p' n'"
  apply (simp add: mask_range_to_bl)
  apply (rule Meson.disj_comm, rule disjCI)
  apply auto
  apply (subgoal_tac "(n''. LENGTH('a) - n = (LENGTH('a) - n') + n'')
                     (n''. LENGTH('a) - n' = (LENGTH('a) - n) + n'')")
   apply (fastforce simp: take_add)
  apply arith
  done

(* For sz-aligned ptr, an sz'-aligned offset x < 2 ^ sz (with sz' related to
   sz by premise szv) gives a block contained in the block of ptr. *)
lemma aligned_mask_range_offset_subset:
  assumes al: "is_aligned (ptr :: 'a :: len word) sz" and al': "is_aligned x sz'"
  and szv: "sz'  sz"
  and xsz: "x < 2 ^ sz"
  shows "mask_range (ptr+x) sz'  mask_range ptr sz"
  using al
proof (rule is_aligned_get_word_bits)
  assume p0: "ptr = 0" and szv': "LENGTH ('a)  sz"
  then have "(2 ::'a word) ^ sz = 0" by simp
  show ?thesis using p0
    by (simp add: 2 ^ sz = 0 mask_eq_decr_exp)
next
  assume szv': "sz < LENGTH('a)"

  hence blah: "2 ^ (sz - sz') < (2 :: nat) ^ LENGTH('a)"
    using szv by auto
  show ?thesis using szv szv'
    apply auto
    using al assms(4) is_aligned_no_wrap' apply blast
    apply (simp only: flip: add_diff_eq add_mask_fold)
    apply (subst add.assoc, rule word_plus_mono_right)
     using al' is_aligned_add_less_t2n xsz
     apply fastforce
    apply (simp add: field_simps szv al is_aligned_no_overflow)
    done
qed

(* Aligned blocks are disjoint under the two premises on the masked bases
   (each base, with the other block's low bits cleared, compared against the
   other base). Follows from aligned_mask_range_cases and
   neg_mask_in_mask_range. *)
lemma aligned_mask_ranges_disjoint:
  " is_aligned (p :: 'a :: len word) n; is_aligned (p' :: 'a :: len word) n';
     p AND NOT(mask n')  p'; p' AND NOT(mask n)  p 
    mask_range p n  mask_range p' n' = {}"
  using aligned_mask_range_cases
  by (auto simp: neg_mask_in_mask_range)

(* The n-bit block at an aligned p is disjoint from the aligned block at ptr
   when none of the sub-block starts p + (y << m), for y < 2 ^ (n - m),
   satisfies the stated membership condition, under the premises relating
   n, m and bits. A point x in the intersection is mapped to
   y = x AND mask n >> m. *)
lemma aligned_mask_ranges_disjoint2:
  " is_aligned p n; is_aligned ptr bits; n  m; n < size p; m  bits;
     (y < 2 ^ (n - m). p + (y << m)  mask_range ptr bits) 
    mask_range p n  mask_range ptr bits = {}"
  apply safe
  apply (simp only: flip: neg_mask_in_mask_range)
  apply (drule_tac x="x AND mask n >> m" in spec)
  apply (erule notE[OF mp])
   apply (simp flip: take_bit_eq_mask add: shiftr_def drop_bit_take_bit)
   apply transfer
  apply simp
   apply (simp add: word_size and_mask_less_size)
  apply (subst disjunctive_add)
   apply (auto simp add: bit_simps word_size intro!: bit_eqI)
  done

(* Under the premise relating LENGTH('a) to 3, word_clz w converted to a
   signed word of the same length has sint bounded by LENGTH('a). The premise
   is what small_powers_of_2 needs to show that the count fits below the
   sign bit. *)
lemma word_clz_sint_upper[simp]:
  "LENGTH('a)  3  sint (of_nat (word_clz (w :: 'a :: len word)) :: 'a sword)  int (LENGTH('a))"
  using word_clz_max [of w]
  apply (simp add: word_size signed_of_nat)
  apply (subst signed_take_bit_int_eq_self)
    apply simp_all
   apply (metis negative_zle of_nat_numeral semiring_1_class.of_nat_power)
  apply (drule small_powers_of_2)
  apply (erule le_less_trans)
  apply simp
  done

(* The matching lower bound: the negated sint of the same conversion is
   bounded by LENGTH('a). The proof shows the sign bit is clear, so sint
   equals uint. *)
lemma word_clz_sint_lower[simp]:
  "LENGTH('a)  3
    - sint (of_nat (word_clz (w :: 'a :: len word)) :: 'a signed word)  int (LENGTH('a))"
  apply (subst sint_eq_uint)
  using word_clz_max [of w]
   apply (simp_all add: word_size unsigned_of_nat)
  apply (rule not_msb_from_less)
  apply (simp add: word_less_nat_alt unsigned_of_nat)
  apply (subst take_bit_nat_eq_self)
   apply (simp add: le_less_trans)
  apply (drule small_powers_of_2)
  apply (erule le_less_trans)
  apply simp
  done

(* Membership propagates through nested aligned blocks: if p' lies in the
   block of p and x' lies in the block of p' (with the premise relating n' and
   n), then x' lies in the block of p. *)
lemma mask_range_subsetD:
  " p'  mask_range p n; x'  mask_range p' n'; n'  n; is_aligned p n; is_aligned p' n'  
   x'  mask_range p n"
  using aligned_mask_step by fastforce

(* For n-aligned base with n < LENGTH('a), base + x * 2^bits lies in the
   n-bit block at base whenever x < 2 ^ (n - bits), given the premise
   relating bits and n. *)
lemma add_mult_in_mask_range:
  " is_aligned (base :: 'a :: len word) n; n < LENGTH('a); bits  n; x < 2 ^ (n - bits) 
    base + x * 2^bits  mask_range base n"
  by (simp add: is_aligned_no_wrap' mask_2pm1 nasty_split_lt word_less_power_trans2
                word_plus_mono_right)

(* Round-tripping the lowest bit through to_bool and from_bool is the
   identity. *)
lemma from_to_bool_last_bit:
  "from_bool (to_bool (x AND 1)) = x AND 1"
  by (metis from_bool_to_bool_iff word_and_1)

(* For word lengths above 2, cast word_ctz x into the signed word type of the
   same length (via of_nat) and read it back with sint. The result is
   non-negative (?P) and at most LENGTH('a) (?Q). *)
lemma sint_ctz:
  0  sint (of_nat (word_ctz (x :: 'a :: len word)) :: 'a signed word)
      sint (of_nat (word_ctz x) :: 'a signed word)  int (LENGTH('a)) (is ?P  ?Q)
  if LENGTH('a) > 2
proof
  (* Shared bound, named by the star label: word_ctz x lies strictly below
     2 ^ (LENGTH('a) - 1). It is obtained from word_ctz_le via le_less_trans and
     small_powers_of_2. *)
  have *: word_ctz x < 2 ^ (LENGTH('a) - Suc 0)
    using word_ctz_le apply (rule le_less_trans)
    using that small_powers_of_2 [of LENGTH('a)] apply simp
    done
  (* Because of that bound, dividing by 2 ^ (LENGTH('a) - 1) gives 0. Together
     with signed_of_nat and bit_iff_odd, this yields ?P. *)
  have int (word_ctz x) div 2 ^ (LENGTH('a) - Suc 0) = 0
    apply (rule div_pos_pos_trivial)
     apply (simp_all add: *)
    done
  then show ?P by (simp add: signed_of_nat bit_iff_odd)
  (* For ?Q: unfold the signed cast, then show that signed_take_bit leaves the
     value unchanged, using word_ctz_le and the shared bound. *)
  show ?Q
    apply (auto simp add: signed_of_nat)
    apply (subst signed_take_bit_int_eq_self)
      apply (auto simp add: word_ctz_le * minus_le_iff [of _ int (word_ctz x)])
    apply (rule order.trans [of _ 0])
     apply simp_all
    done
qed

(* Casting word_log2 n (a nat) into a 'b word and reading it back with unat
   loses nothing, provided LENGTH('a) < 2 ^ LENGTH('b). word_log2_max presumably
   bounds word_log2 n by the word size of n, so the value fits in 'b and
   unat_of_nat_eq applies. *)
lemma unat_of_nat_word_log2:
  "LENGTH('a) < 2 ^ LENGTH('b)
    unat (of_nat (word_log2 (n :: 'a :: len word)) :: 'b :: len word) = word_log2 n"
  by (metis less_trans unat_of_nat_eq word_log2_max word_size)

(* Assume the following:
   - dest is bits-aligned,
   - ptr is sz-aligned,
   - bits is at most sz,
   - sz < LENGTH('a), and
   - dest < ptr.
   Then the end of dest's bits-block, mask bits + dest, is still strictly
   below ptr.
   Proof outline: aligned_mask_range_cases splits on how the two aligned
   blocks relate (presumably disjoint, or one contained in the other). The
   resulting cases are closed with the no-overflow facts
   is_aligned_no_overflow_mask and aligned_add_mask_less_eq. *)
lemma aligned_mask_diff:
  " is_aligned (dest :: 'a :: len word) bits; is_aligned (ptr :: 'a :: len word) sz;
     bits  sz; sz < LENGTH('a); dest < ptr 
    mask bits + dest < ptr"
  apply (frule_tac p' = ptr in aligned_mask_range_cases, assumption)
  apply (elim disjE)
    apply (drule_tac is_aligned_no_overflow_mask, simp)+
    apply (simp add: algebra_split_simps word_le_not_less)
   apply (drule is_aligned_no_overflow_mask; fastforce)
  apply (simp add: is_aligned_weaken algebra_split_simps)
  apply (auto simp add: not_le)
  using is_aligned_no_overflow_mask leD apply blast
  apply (meson aligned_add_mask_less_eq is_aligned_weaken le_less_trans)
  done

(* If bit n of a is clear, widening the mask from n to Suc n bits does not
   change a AND mask.
   NOTE(review): the proof goes through sign_extend_def and sign_extend_def',
   which is an indirect route for a pure mask fact. *)
lemma Suc_mask_eq_mask:
  "¬bit a n  a AND mask (Suc n) = a AND mask n" for a::"'a::len word"
  by (metis sign_extend_def sign_extend_def')

(* Comparing words whose high bits agree. Suppose a and b agree on every bit
   position above n, and the low Suc n bits of a are smaller than those of b.
   Then a < b. *)
lemma word_less_high_bits:
  fixes a::"'a::len word"
  assumes high_bits: "i > n. bit a i = bit b i"
  assumes less: "a AND mask (Suc n) < b AND mask (Suc n)"
  shows "a < b"
proof -
  (* Clearing the low Suc n bits keeps only the high bits, on which a and b
     agree, so the masked words are equal. *)
  let ?masked = "λx. x AND NOT (mask (Suc n))"
  from high_bits
  have "?masked a = ?masked b"
    by - word_eqI_solve
  (* Add the common high part to both low parts. Strictness is preserved by
     word_plus_strict_mono_right; its side condition is handled inside the
     metis call. *)
  then
  have "?masked a + (a AND mask (Suc n)) < ?masked b + (b AND mask (Suc n))"
    by (metis AND_NOT_mask_plus_AND_mask_eq less word_and_le2 word_plus_strict_mono_right)
  (* High part plus low part reassembles each word. *)
  then
  show ?thesis
    by (simp add: AND_NOT_mask_plus_AND_mask_eq)
qed

(* Introduction rule for the word order. a < b holds when all of the following
   are true:
   - a and b agree on every bit above position n,
   - bit n is clear in a,
   - bit n is set in b, and
   - n < LENGTH('a). *)
lemma word_less_bitI:
  fixes a :: "'a::len word"
  assumes hi_bits: "i > n. bit a i = bit b i"
  assumes a_bits: "¬bit a n"
  assumes b_bits: "bit b n" "n < LENGTH('a)"
  shows "a < b"
proof -
  (* Since bit n of b is set, b AND mask (Suc n) exceeds anything that fits in
     the low n bits. *)
  from b_bits
  have "a AND mask n < b AND mask (Suc n)"
    by (metis bit_mask_iff impossible_bit le2p_bits_unset leI lessI less_Suc_eq_le mask_eq_decr_exp
              word_and_less' word_ao_nth)
  (* Bit n of a is clear, so widening a's mask to Suc n changes nothing. *)
  with a_bits
  have "a AND mask (Suc n) < b AND mask (Suc n)"
    by (simp add: Suc_mask_eq_mask)
  with hi_bits
  show ?thesis
    by (rule word_less_high_bits)
qed

(* Converse of word_less_bitI. If a < b, there is a bit position n such that:
   - a and b agree on every bit above n,
   - bit n is clear in a, and
   - bit n is set in b.
   The witness n is built from the big-endian bit lists. tk is the length of
   the longest common prefix of to_bl a and to_bl b, and
   n = LENGTH('a) - Suc tk is the bit index of the first differing entry. *)
lemma word_less_bitD:
  fixes a::"'a::len word"
  assumes less: "a < b"
  shows "n. (i > n. bit a i = bit b i)  ¬bit a n  bit b n"
proof -
  define xs where "xs  zip (to_bl a) (to_bl b)"
  define tk where "tk  length (takeWhile (λ(x,y). x = y) xs)"
  define  n where  "n  LENGTH('a) - Suc tk"
  have n_less: "n < LENGTH('a)"
    by (simp add: n_def)
  moreover
  (* Every bit above n agrees. Out-of-range bits agree trivially; in-range bits
     above n correspond to entries inside the common prefix. *)
  { fix i
    have "¬ i < LENGTH('a)  bit a i = bit b i"
      using bit_imp_le_length by blast
    moreover
    assume "i > n"
    with n_less
    have "i < LENGTH('a)  LENGTH('a) - Suc i < tk"
      unfolding n_def by arith
    hence "i < LENGTH('a)  bit a i = bit b i"
      unfolding n_def tk_def xs_def
      by (fastforce dest: takeWhile_take_has_property_nth simp: rev_nth simp flip: nth_rev_to_bl)
    ultimately
    have "bit a i = bit b i"
      by blast
  }
  note all = this
  moreover
  (* a and b differ, so the common prefix does not cover the whole list and
     entry tk is a genuine mismatch. *)
  from less
  have "a  b" by simp
  then
  obtain i where "to_bl a ! i  to_bl b ! i"
    using nth_equalityI word_bl.Rep_eqD word_rotate.lbl_lbl by blast
  then
  have "tk  length xs"
    unfolding tk_def xs_def
    by (metis length_takeWhile_less list_eq_iff_zip_eq nat_neq_iff word_rotate.lbl_lbl)
  then
  have "tk < length xs"
    using length_takeWhile_le order_le_neq_trans tk_def by blast
  from nth_length_takeWhile[OF this[unfolded tk_def]]
  have "fst (xs ! tk)  snd (xs ! tk)"
    by (clarsimp simp: tk_def)
  (* Translate the list mismatch at tk into a bit mismatch at n. *)
  with `tk < length xs`
  have "bit a n  bit b n"
    by (clarsimp simp: xs_def n_def tk_def nth_rev simp flip: nth_rev_to_bl)
  (* The mismatch must be clear in a and set in b. The opposite orientation
     would give b < a by word_less_bitI, contradicting a < b. *)
  with less all
  have "¬bit a n  bit b n"
    by (metis n_less order.asym word_less_bitI)
  ultimately
  show ?thesis by blast
qed

(* Bitwise characterisation of the word order. a < b holds exactly when there
   is a bit position n below LENGTH('a) where all of the following hold:
   - a and b agree on every higher bit,
   - bit n is clear in a, and
   - bit n is set in b.
   The proof combines word_less_bitD and word_less_bitI. *)
lemma word_less_bit_eq:
  "(a < b) = (n < LENGTH('a). (i > n. bit a i = bit b i)  ¬bit a n  bit b n)" for a::"'a::len word"
  by (meson bit_imp_le_length word_less_bitD word_less_bitI)

end

end