Theory WordAbstract

(*
 * Copyright 2020, Data61, CSIRO (ABN 41 687 119 230)
 * Copyright (c) 2022 Apple Inc. All rights reserved.
 *
 * SPDX-License-Identifier: BSD-2-Clause
 *)

chapter "WA phase: Word Abstraction"

theory WordAbstract
imports
  L2Peephole
  NatBitwise
begin

section "Basic Definitions"

(* Numeric bounds computed from the bit length `len_of x` of a type:
 *   WORD_MAX  x = 2 ^ (len_of x - 1) - 1    (an int)
 *   WORD_MIN  x = - (2 ^ (len_of x - 1))    (an int)
 *   UWORD_MAX x = 2 ^ (len_of x) - 1        (a nat)
 *)
definition "WORD_MAX x  ((2 ^ (len_of x - 1) - 1) :: int)"
definition "WORD_MIN x  (- (2 ^ (len_of x - 1)) :: int)"
definition "UWORD_MAX x  ((2 ^ (len_of x)) - 1 :: nat)"

(* Concrete values of WORD_MAX / WORD_MIN (at signed 8/16/32/64-bit types) and
   UWORD_MAX (at 8/16/32/64-bit types). The [simplified] attribute stores the
   simplified form of each equation. *)
lemma WORD_values [simplified]:
  "WORD_MAX (TYPE( 8 signed)) = (2 ^  7 - 1)"
  "WORD_MAX (TYPE(16 signed)) = (2 ^ 15 - 1)"
  "WORD_MAX (TYPE(32 signed)) = (2 ^ 31 - 1)"
  "WORD_MAX (TYPE(64 signed)) = (2 ^ 63 - 1)"

  "WORD_MIN (TYPE( 8 signed)) = - (2 ^  7)"
  "WORD_MIN (TYPE(16 signed)) = - (2 ^ 15)"
  "WORD_MIN (TYPE(32 signed)) = - (2 ^ 31)"
  "WORD_MIN (TYPE(64 signed)) = - (2 ^ 63)"

  "UWORD_MAX (TYPE( 8)) = (2 ^  8 - 1)"
  "UWORD_MAX (TYPE(16)) = (2 ^ 16 - 1)"
  "UWORD_MAX (TYPE(32)) = (2 ^ 32 - 1)"
  "UWORD_MAX (TYPE(64)) = (2 ^ 64 - 1)"
  by (auto simp: WORD_MAX_def WORD_MIN_def UWORD_MAX_def)

(* WORD_values with `+ 1` applied to both sides of every equation (via
   arg_cong), then normalised with semiring_norm and numeral_One. *)
lemmas WORD_values_add1 =
   WORD_values [THEN arg_cong [where f="λx. x + 1"],
    simplified semiring_norm, simplified numeral_One]

(* WORD_values with `- 1` applied to both sides of every equation (via
   arg_cong), then normalised with semiring_norm, numeral_One and nat_numeral. *)
lemmas WORD_values_minus1 =
   WORD_values [THEN arg_cong [where f="λx. x - 1"],
    simplified semiring_norm, simplified numeral_One nat_numeral]

(* The three families above, each turned around with [symmetric] so that the
   numeric literal is on the left and the named bound on the right.
   Registered under the L1unfold attribute.
   NOTE(review): presumably used to fold literals back into the named
   constants -- confirm against the users of L1unfold. *)
lemmas WORD_values_fold [L1unfold] =
  WORD_values [symmetric]
  WORD_values_add1 [symmetric]
  WORD_values_minus1 [symmetric]

(* The bounds at type 'a signed equal the bounds at type 'a, so the `signed`
   tag can be dropped from the type argument. Declared [simp]. *)
lemma WORD_signed_to_unsigned [simp]:
   "WORD_MAX TYPE('a signed) = WORD_MAX TYPE('a::len)"
   "WORD_MIN TYPE('a signed) = WORD_MIN TYPE('a::len)"
   "UWORD_MAX TYPE('a signed) = UWORD_MAX TYPE('a::len)"
  by (auto simp: WORD_MAX_def WORD_MIN_def UWORD_MAX_def)

(*
 * The following theorems allow us to discharge simple
 * comparisons involving INT_MIN, INT_MAX and UINT_MAX without
 * the constants being unfolded in the final output.
 *
 * For example:
 *
 *    (4 < INT_MAX)  becomes  True
 *    (x < INT_MAX)  remains  (x < INT_MAX)
 *)

(* Relates comparisons of `a` against the unfolded expression
   - (2 ^ (len_of TYPE('a) - 1)) to comparisons against WORD_MIN TYPE('a).
   Declared [simp]; proved by unfolding WORD_MIN_def. *)
lemma INT_MIN_comparisons [simp]:
  " a  - (2 ^ (len_of TYPE('a) - 1))   a  WORD_MIN (TYPE('a::len))"
  "a < - (2 ^ (len_of TYPE('a) - 1))  a < WORD_MIN (TYPE('a::len))"
  "a  - (2 ^ (len_of TYPE('a) - 1))  a  WORD_MIN (TYPE('a::len))"
  "a > - (2 ^ (len_of TYPE('a) - 1))  a  WORD_MIN (TYPE('a::len))"
  by (auto simp: WORD_MIN_def)

(* Relates comparisons of `a` against the unfolded expression
   (2 ^ (len_of TYPE('a) - 1)) - 1 to comparisons against WORD_MAX TYPE('a).
   Declared [simp]; proved by unfolding WORD_MAX_def. *)
lemma INT_MAX_comparisons [simp]:
  "a  (2 ^ (len_of TYPE('a) - 1)) - 1  a  WORD_MAX (TYPE('a::len))"
  "a < (2 ^ (len_of TYPE('a) - 1)) - 1  a < WORD_MAX (TYPE('a::len))"
  "a  (2 ^ (len_of TYPE('a) - 1)) - 1  a  WORD_MAX (TYPE('a::len))"
  "a > (2 ^ (len_of TYPE('a) - 1)) - 1  a  WORD_MAX (TYPE('a::len))"
  by (auto simp: WORD_MAX_def)

(* Relates comparisons of `x` against the unfolded expression
   (2 ^ (len_of TYPE('a))) - 1 to comparisons against UWORD_MAX TYPE('a).
   Declared [simp]; proved by unfolding UWORD_MAX_def. *)
lemma UINT_MAX_comparisons [simp]:
  "x  (2 ^ (len_of TYPE('a))) - 1  x  UWORD_MAX (TYPE('a::len))"
  "x < (2 ^ (len_of TYPE('a))) - 1  x  UWORD_MAX (TYPE('a::len))"
  "x  (2 ^ (len_of TYPE('a))) - 1  x  UWORD_MAX (TYPE('a::len))"
  "x > (2 ^ (len_of TYPE('a))) - 1  x > UWORD_MAX (TYPE('a::len))"
  by (auto simp: UWORD_MAX_def)

(* The SCAST from 'a word to 'a signed word satisfies is_up.
   Proved by unfolding is_up and simplifying. Declared [simp]. *)
lemma is_up_SCAST_same_signed [simp]: "is_up (SCAST (('a::len)  'a signed))"
  unfolding is_up
  by simp

(* sint is unchanged by the UCAST from 'a word to 'a signed word.
   Registered both as a simp rule and under the L2opt attribute. *)
lemma sint_ucast_signed [simp, L2opt]:"sint (UCAST ('a::len  'a signed) x) = sint x"
  using is_up_SCAST_same_signed scast_ucast_norm(2) sint_up_scast
  by (metis scast_scast_id(1))

section "Abstracting values and expressions"

(*
 * This definition is used when we are trying to introduce a new type
 * in the program text: it simply states that introducing a given
 * abstraction is desired in the current context.
 *
 * The right-hand side is just True, so the predicate carries no logical
 * content; it only records the abstraction function `f` in a premise.
 *)
definition "introduce_typ_abs_fn f  True"

(* Unfold the marker automatically during simplification. *)
declare introduce_typ_abs_fn_def [simp]

(* The marker holds for every f (it is defined as True). *)
lemma introduce_typ_abs_fn:
  "introduce_typ_abs_fn f"
  by simp

(*
 * Show that a binary operator "X" (of type "'a ⇒ 'a ⇒ bool") is an
 * abstraction (over function f) of "X'".
 *
 * For example, (a ≤nat b) could be an abstraction of (a ≤w32 b)
 * over the abstraction function "unat".
 *
 * Arguments:
 *   P  - side condition on the abstracted arguments (f a) and (f b)
 *   f  - abstraction function from the concrete type 'c to the abstract type 'a
 *   X  - abstract operator
 *   X' - concrete operator
 *)
definition
  abstract_bool_binop :: "('a  'a  bool)  ('c  'a)
                ('a  'a  bool)  ('c  'c  bool)  bool"
where
  "abstract_bool_binop P f X X'  a b. P (f a) (f b)  (X' a b = X (f a) (f b))"

(* Show that a binary operator "X" (of type "'a ⇒ 'a ⇒ 'a") abstracts "X'":
   under the side condition P (f a) (f b), applying f to the concrete result
   X' a b gives X (f a) (f b). *)
definition
  abstract_binop :: "('a  'a  bool)  ('c  'a)
                ('a  'a  'a)  ('c  'c  'c)  bool"
where
   "abstract_binop P f X X'  a b. P (f a) (f b)  (f (X' a b) = X (f a) (f b))"

(* The value "a" is the abstract version of "b" under precondition "P":
   the definition relates P to the equation a = f b, where f is the
   abstraction function. *)
definition "abstract_val P a f b  P  (a = f b)"

(* The variable "a" is the abstracted version of the variable "b":
   abstract_val with the precondition fixed to True. *)
definition "abs_var a f b  abstract_val True a f b"


(* Bundle of the four defining equations above, so proofs can unfold all of
   them with a single name. *)
lemmas basic_abstract_defs = 
  abstract_bool_binop_def 
  abstract_binop_def 
  abstract_val_def 
  abs_var_def

(* With precondition True, f b is an abstraction of b under f. *)
lemma abstract_val_trivial:
  "abstract_val True (f b) f b"
  by (simp add: basic_abstract_defs)

(* Restates abstract_binop in terms of abstract_val: the operator-level
   statement equals a pointwise abstract_val statement over a and b. *)
lemma abstract_binop_is_abstract_val:
    "abstract_binop P f X X' = (a b. abstract_val (P (f a) (f b)) (X (f a) (f b)) f (X' a b))"
  by (auto simp add: basic_abstract_defs)

(* Lifts an abstract_bool_binop fact to expressions: given that a and b
   abstract a' and b' via f (under P and Q respectively), the boolean result
   X a b abstracts X' a' b' via id, under a precondition built from P, Q and
   the operator's side condition E a b.
   The introduce_typ_abs_fn premise is the (trivially true) marker for f. *)
lemma abstract_expr_bool_binop:
  " abstract_bool_binop E f X X';
     introduce_typ_abs_fn f;
     abstract_val P a f a';
     abstract_val Q b f b'  
           abstract_val (P  Q  E a b) (X a b) id (X' a' b')"
  by (clarsimp simp add: basic_abstract_defs)

(* Lifts an abstract_binop fact to expressions: given that a and b abstract
   a' and b' via f (under P and Q respectively), X a b abstracts X' a' b'
   via the same f, under a precondition built from P, Q and the operator's
   side condition E a b. *)
lemma abstract_expr_binop:
  " abstract_binop E f X X';
     abstract_val P a f a';
     abstract_val Q b f b'  
           abstract_val (P  Q  E a b) (X a b) f (X' a' b')"
  by (clarsimp simp add: basic_abstract_defs)

(* The comparisons <, ≤ and = on words are abstracted by the same operators
   on nat via unat, with the trivial side condition (λ_ _. True). *)
lemma unat_abstract_bool_binops:
    "abstract_bool_binop (λ_ _. True) (unat :: ('a::len) word  nat) (<) (<)"
    "abstract_bool_binop (λ_ _. True) (unat :: ('a::len) word  nat) (≤) (≤)"
    "abstract_bool_binop (λ_ _. True) (unat :: ('a::len) word  nat) (=) (=)"
  by (auto simp:  word_less_nat_alt word_le_nat_alt eq_iff basic_abstract_defs)

(* One direction (iffD1) of unat_mult_lem, with word_bits_len_of unfolded. *)
lemmas unat_mult_simple = iffD1 [OF unat_mult_lem [unfolded word_bits_len_of]]

(* On nat, a comparison between a and b is rewritten as the strict comparison
   a < b + 1. Used to turn bounds stated with UWORD_MAX / WORD_MAX into strict
   bounds in the proofs that list this lemma in their simp sets. *)
lemma le_to_less_plus_one:
    "((a::nat)  b) = (a < b + 1)"
  by arith

(* Arithmetic on words abstracted to nat via unat:
     +, *     : side condition relates the nat result to UWORD_MAX TYPE('a)
     -        : side condition relates the two operands a and b
     div, mod : no side condition (λa b. True) *)
lemma unat_abstract_binops:
  "abstract_binop (λa b. a + b  UWORD_MAX TYPE('a::len)) (unat :: 'a word  nat) (+) (+)"
  "abstract_binop (λa b. a * b  UWORD_MAX TYPE('a)) (unat :: 'a word  nat) (*) (*)"
  "abstract_binop (λa b. a  b) (unat :: 'a word  nat) (-) (-)"
  "abstract_binop (λa b. True) (unat :: 'a word  nat) (div) (div)"
  "abstract_binop (λa b. True) (unat :: 'a word  nat) (mod) (mod)"
  by (auto simp: unat_plus_if' unat_div unat_mod UWORD_MAX_def le_to_less_plus_one
              WordAbstract.unat_mult_simple word_bits_def unat_sub word_le_nat_alt
              basic_abstract_defs)

(* For an int i satisfying the stated bounds (a lower bound against 0 and
   i < 2 ^ LENGTH('a)), converting i to an 'a word and back with unat yields
   nat i. *)
lemma unat_of_int:
  "i  0; i < 2 ^ LENGTH('a)  unat (of_int i :: 'a::len word) = nat i"
  by (metis nat_less_numeral_power_cancel_iff of_nat_nat unat_of_nat_len)

(* fixme generalises Word_Lemmas_32.unat_of_int_32 *)
(* Same statement as unat_of_int, at the type 'a signed word; follows from
   unat_of_int by simp. *)
lemma unat_of_int_signed:
  "i  0; i < 2 ^ LENGTH('a)  unat (of_int i :: 'a::len signed word) = nat i"
  by (simp add: unat_of_int)

(* For a signed word x with 0 <=s x, nat (sint x) equals unat x. *)
lemma nat_sint:
  "0 <=s (x :: 'a::len signed word)  nat (sint x) = unat x"
  (* use unat_of_int_signed right-to-left; the following steps discharge its
     two side conditions and then the remaining goal *)
  apply (subst unat_of_int_signed[where 'a='a, symmetric])
    apply (simp add: word_sle_def)
   apply (rule less_trans[OF sint_lt])
   apply simp
  apply simp
  done

(* For a signed word x with 0 <=s x, int (unat x) equals sint x. *)
lemma int_unat_nonneg:
  "0 <=s (x :: 'a::len signed word)  int (unat x) = sint x"
  by (simp add: int_unat word_sle_msb_le sint_eq_uint)

(* For a signed word x with 0 <=s x, uint x equals sint x.
   The proof uses the same simp set as int_unat_nonneg. *)
lemma uint_sint_nonneg:
  "0 <=s (x :: 'a::len signed word)  uint x = sint x"
  by (simp add: int_unat word_sle_msb_le sint_eq_uint)

(* Bitwise and / or / xor on words are abstracted by the same operations on
   nat via unat, with no side condition. One simp call per statement, each
   using the matching unsigned_*_eq rule. *)
lemma unat_bitwise_abstract_binops:
  "abstract_binop (λa b. True) (unat :: 'a::len word  nat) Bit_Operations.and Bit_Operations.and"
  "abstract_binop (λa b. True) (unat :: 'a::len word  nat) Bit_Operations.or Bit_Operations.or"
  "abstract_binop (λa b. True) (unat :: 'a::len word  nat) Bit_Operations.xor Bit_Operations.xor"
  apply (simp add: unsigned_and_eq uint_nat unat_of_int basic_abstract_defs)
  apply (simp add: unsigned_or_eq  uint_nat unat_of_int basic_abstract_defs)
  apply (simp add: unsigned_xor_eq uint_nat unat_of_int basic_abstract_defs)
  done

(* Bitwise NOT on an unsigned word: if x abstracts x' via unat, then
   UWORD_MAX TYPE('a) - x abstracts NOT x' via unat, under the same
   precondition P. *)
lemma abstract_val_unsigned_bitNOT:
  "abstract_val P x unat (x' :: 'a::len word) 
   abstract_val P (UWORD_MAX TYPE('a) - x) unat (NOT x')"
  apply (clarsimp simp: UWORD_MAX_def NOT_eq basic_abstract_defs)
  by (metis nat_le_Suc_less diff_Suc_eq_diff_pred mask_eq_sum_exp mask_eq_sum_exp_nat 
    minus_diff_commute unat_lt2p unat_minus_one_word unat_sub_if_size unsigned_0)


(* Signed comparisons abstracted to int via sint, with no side condition:
   word_sless is abstracted by <, word_sle by ≤, and = by =. *)
lemma snat_abstract_bool_binops:
    "abstract_bool_binop (λ_ _. True) (sint :: ('a::len) signed word  int) (<) (word_sless)"
    "abstract_bool_binop (λ_ _. True) (sint :: 'a signed word  int) (≤) (word_sle)"
    "abstract_bool_binop (λ_ _. True) (sint :: 'a signed word  int) (=) (=)"
  by (auto simp: word_sless_def word_sle_def less_le basic_abstract_defs)

(* Signed arithmetic (+, *, -, sdiv, smod) abstracted to int via sint.
   Each side condition constrains the int result of the operation using
   WORD_MIN TYPE('a) and WORD_MAX TYPE('a). *)
lemma snat_abstract_binops:
  "abstract_binop (λa b. WORD_MIN TYPE('a::len)  a + b  a + b  WORD_MAX TYPE('a)) (sint :: 'a signed word  int) (+) (+)"
  "abstract_binop (λa b. WORD_MIN TYPE('a)  a * b  a * b  WORD_MAX TYPE('a)) (sint :: 'a signed word  int) (*) (*)"
  "abstract_binop (λa b. WORD_MIN TYPE('a)  a - b  a - b  WORD_MAX TYPE('a)) (sint :: 'a signed word  int) (-) (-)"
  "abstract_binop (λa b. WORD_MIN TYPE('a)  a sdiv b  a sdiv b  WORD_MAX TYPE('a)) (sint :: 'a signed word  int) (sdiv) (sdiv)"
  "abstract_binop (λa b. WORD_MIN TYPE('a)  a smod b  a smod b  WORD_MAX TYPE('a)) (sint :: 'a signed word  int) (smod) (smod)"
  by (auto simp: signed_arith_sint word_size WORD_MIN_def WORD_MAX_def basic_abstract_defs)

(* Bitwise and / or / xor on signed words are abstracted by the same
   operations on int via sint, with no side condition. The same fastforce
   call is repeated over the goals. *)
lemma sint_bitwise_abstract_binops:
  "abstract_binop (λa b. True) (sint :: 'a::len signed word  int) Bit_Operations.and Bit_Operations.and"
  "abstract_binop (λa b. True) (sint :: 'a::len signed word  int) Bit_Operations.or Bit_Operations.or"
  "abstract_binop (λa b. True) (sint :: 'a::len signed word  int) Bit_Operations.xor Bit_Operations.xor"
  by (fastforce intro: int_eq_test_bitI
                simp: nth_sint bin_nth_ops basic_abstract_defs)+

(* Bitwise NOT on a signed word: if x abstracts x' via sint, then NOT x (on
   int) abstracts NOT x' via sint, under the same precondition P. *)
lemma abstract_val_signed_bitNOT:
  "abstract_val P x sint (x' :: 'a::len signed word) 
   abstract_val P (NOT x) sint (NOT x')"
  by (auto intro: int_eq_test_bitI 
      simp add: nth_sint bin_nth_ops word_nth_neq basic_abstract_defs min_less_iff_disj)

(* Unary minus on a signed word: if r abstracts r' via sint under P, then
   - r abstracts - r' via sint, under P together with a condition relating
   - r to WORD_MAX TYPE('a). *)
lemma abstract_val_signed_unary_minus:
  " abstract_val P r sint r'  
       abstract_val (P  (- r)  WORD_MAX TYPE('a)) (- r) sint ( - (r' :: ('a :: len) signed word))"
  apply (clarsimp simp add: basic_abstract_defs)
  (* bring the range fact for sint r' into the goal's assumptions *)
  using sint_range_size [where w=r']
  apply -
  apply (subst signed_arith_sint)
   apply (clarsimp simp: word_size WORD_MAX_def)
  apply simp
  done

(* For a signed word x with 0 <=s x, and n bounded as stated relative to
   size x - 1, bit n of x is False.
   The proof splits on whether n = size x - 1: that case goes through msb,
   the other case through test_bit_bl. *)
lemma bang_big_nonneg:
  " 0 <=s (x::'a::len signed word); n  size x - 1   (x !! n) = False"
  apply (cases "n = size x - 1")
   apply (simp add: word_size msb_nth[where 'a="'a signed", symmetric, simplified] word_sle_msb_le)
  apply (simp add: test_bit_bl)
  apply arith
  done

(* FIXME: move to Word_Lib *)

(* Bit m of an int shifted right by n is bit n + m of the original int.
   Declared [simp]. *)
lemma int_shiftr_nth[simp]:
  "(i >> n) !! m = i !! (n + m)" for i :: int
  by (simp add: shiftr_def bin_nth_shiftr)

(* FIXME: move to Word_Lib *)
(* Bit m of an int shifted left by n, expressed through a condition on n and
   m together with bit m - n of the original int. Declared [simp]. *)
lemma int_shiftl_nth[simp]:
  "(i << n) !! m = (n  m  i !! (m - n))" for i :: int
  by (simp add: shiftl_def bin_nth_shiftl)


(* For a signed word x with 0 <=s x and n < LENGTH('a), sint commutes with
   right shift: sint (x >> n) = sint x >> n. Proved bit-by-bit via
   int_eq_test_bitI. *)
lemma sint_shiftr_nonneg:
  " 0 <=s (x :: 'a::len signed word); 0  n; n < LENGTH('a)   sint (x >> n) = sint x >> n"
  apply (rule int_eq_test_bitI)
  apply (clarsimp simp: bang_big_nonneg[simplified word_size] nth_sint nth_shiftr field_simps 
      simp del: bit_signed_iff)
  done

(* Unary minus on an unsigned word: if r abstracts r' via unat, then - r' is
   abstracted by 0 when r = 0 and by UWORD_MAX TYPE('a) + 1 - r otherwise,
   under the same precondition P. *)
lemma abstract_val_unsigned_unary_minus:
  " abstract_val P r unat r'  
       abstract_val P (if r = 0 then 0 else UWORD_MAX TYPE('a::len) + 1 - r) unat ( - (r' :: 'a word))"
  by (clarsimp simp: unat_minus' word_size unat_eq_zero UWORD_MAX_def basic_abstract_defs)

(* Rules for shifts *)

(* Right shift of a signed word by a signed amount: x >> nat n (on int)
   abstracts x' >> unat n' via sint. Besides Px and Pn the precondition
   mentions 0, x and n, and requires n < int LENGTH('a). *)
lemma abstract_val_signed_shiftr_signed:
  " abstract_val Px x sint (x' :: ('a :: len) signed word);
     abstract_val Pn n sint (n' :: ('b :: len) signed word)  
   abstract_val (Px  Pn  0  x  0  n  n < int LENGTH('a))
                (x >> (nat n)) sint (x' >> (unat n'))"
  apply (clarsimp simp only: abstract_val_def)
  (* rewrite nat (sint n') using nat_sint *)
  apply (subst nat_sint, simp add: word_sle_def)
  apply (subst sint_shiftr_nonneg)
     apply (simp add: word_sle_def)
    apply simp
   apply (subst SMT.nat_int_comparison(2))
   apply (subst int_unat_nonneg)
    apply (simp add: word_sle_def)
   apply assumption
  apply (rule refl)
  done

(* Right shift of a signed word by an unsigned amount: x >> n abstracts
   x' >> unat n' via sint. Besides Px and Pn the precondition mentions 0 and
   x, and requires n < LENGTH('a). *)
lemma abstract_val_signed_shiftr_unsigned:
  " abstract_val Px x sint (x' :: ('a :: len) signed word);
     abstract_val Pn n unat (n' :: ('b :: len) word)  
   abstract_val (Px  Pn  0  x  n < LENGTH('a))
                (x >> n) sint (x' >> unat n')"
  apply (clarsimp simp: shiftr_int_def basic_abstract_defs)
  apply (subst sint_shiftr_nonneg)
     apply (simp add: word_sle_def)
    apply simp
   apply assumption
  apply (clarsimp simp: shiftr_int_def)
  done

(* Auxiliary arithmetic fact about n - i and n measured against
   LENGTH('a) - Suc 0; the simplifier alone discharges it. *)
lemma foo:  "¬ n - i < LENGTH('a::len) - Suc 0 
    n < LENGTH('a) - Suc 0 = False"
  by simp

(* For a signed word x with 0 <=s x, n < LENGTH('a), and the shifted int
   value sint x << n staying below 2^(LENGTH('a) - 1), sint commutes with
   left shift: sint (x << n) = sint x << n.
   Proved bit-by-bit (int_eq_test_bitI); `na` is the bit index. *)
lemma sint_shiftl_nonneg:
  " 0 <=s (x :: 'a::len signed word); n < LENGTH('a); sint x << n < 2^(LENGTH('a) - 1)  
   sint (x << n) = sint x << n"
  apply (rule int_eq_test_bitI)
  subgoal for na
    apply (simp add: nth_sint nth_shiftl word_sle_def int_shiftl_less_cancel int_2p_eq_shiftl
        bang_big_nonneg[simplified word_size]
        del:  bit_signed_iff shiftl_1)
      (* fixme: cleanup *)
    (* split into cases, closing those that simp solves outright *)
    apply (intro impI iffI conjI; (solves simp)?)
      apply (drule(1) int_shiftl_lt_2p_bits[rotated])
      apply (clarsimp simp: nth_sint)
      apply (drule_tac x="LENGTH('a) - 1 - n" in spec)
      apply (subgoal_tac "LENGTH('a) - 1 - n < LENGTH('a) - 1")
       apply simp
      apply arith
     apply (drule(1) int_shiftl_lt_2p_bits[rotated])
     apply (clarsimp simp: nth_sint)
     apply (drule_tac x="na - n" in spec)
     apply simp
    apply (cases "n = 0")
     apply (simp add: word_sle_msb_le[where x=0, simplified word_sle_def, simplified] msb_nth)
    apply (drule(1) int_shiftl_lt_2p_bits[rotated])
    apply (clarsimp simp: nth_sint)
    apply (drule_tac x="LENGTH('a) - 1 - n" in spec)
    apply (subgoal_tac "LENGTH('a) - 1 - n < LENGTH('a) - 1")
     apply simp
    apply simp
    done
  done

(* Left shift of a signed word by a signed amount: x << nat n abstracts
   x' << unat n' via sint. Besides Px and Pn the precondition mentions 0, x
   and n, and requires n < int LENGTH('a) and
   x << nat n < 2^(LENGTH('a) - 1). *)
lemma abstract_val_signed_shiftl_signed:
  " abstract_val Px x sint (x' :: ('a :: len) signed word);
     abstract_val Pn n sint (n' :: ('b :: len) signed word)  
   abstract_val (Px  Pn  0  x  0  n  n < int LENGTH('a)  x << nat n < 2^(LENGTH('a) - 1))
                (x << nat n) sint (x' << unat n')"
  apply (clarsimp simp add: basic_abstract_defs)
  by (metis One_nat_def len_gt_0 nat_int nat_sint signed_0 sint_shiftl_nonneg word_sle_eq 
    zless_nat_conj)

(* Left shift of a signed word by an unsigned amount: x << n abstracts
   x' << unat n' via sint. Besides Px and Pn the precondition mentions 0 and
   x, and requires n < LENGTH('a) and x << n < 2^(LENGTH('a) - 1). *)
lemma abstract_val_signed_shiftl_unsigned:
  " abstract_val Px x sint (x' :: ('a :: len) signed word);
     abstract_val Pn n unat (n' :: ('b :: len) word)  
   abstract_val (Px  Pn  0  x  n < LENGTH('a)  x << n < 2^(LENGTH('a) - 1))
                (x << n) sint (x' << unat n')"
  by (clarsimp simp: sint_shiftl_nonneg word_sle_def basic_abstract_defs
                     nat_less_eq_zless[where z="int LENGTH('a)", simplified])

(* Right shift of an unsigned word by an unsigned amount: x >> n abstracts
   x' >> unat n' via unat, with only Px and Pn in the precondition.
   Note that both x' and n' are stated at the same word type 'a here. *)
lemma abstract_val_unsigned_shiftr_unsigned:
  " abstract_val Px x unat (x' :: ('a :: len) word);
     abstract_val Pn n unat (n' :: ('a :: len) word)  
   abstract_val (Px  Pn) (x >> n) unat (x' >> unat n')"
  apply (clarsimp simp add: basic_abstract_defs)
  apply (simp add: shiftr_div_2n'  shiftr_int_def)
  using shiftr_eq_div by blast

(* Right shift of an unsigned word by a signed amount: x >> nat n abstracts
   x' >> unat n' via unat. Besides Px and Pn the precondition mentions 0
   and n. *)
lemma abstract_val_unsigned_shiftr_signed:
  " abstract_val Px x unat (x' :: ('a :: len) word);
     abstract_val Pn n sint (n' :: ('b :: len) signed word)  
   abstract_val (Px  Pn  0  n) (x >> nat n) unat (x' >> unat n')"
  apply (clarsimp simp: shiftr_div_2n' shiftr_int_def basic_abstract_defs)
  by (simp add: nat_sint shiftr_nat_def word_sle_eq)

(* Left shift of an unsigned word by an unsigned amount: x << n abstracts
   x' << unat n' via unat. Besides Px and Pn the precondition requires
   n < LENGTH('a) and x << n < 2^LENGTH('a). *)
lemma abstract_val_unsigned_shiftl_unsigned:
  " abstract_val Px x unat (x' :: ('a :: len) word);
     abstract_val Pn n unat (n' :: ('b :: len) word)  
   abstract_val (Px  Pn  n < LENGTH('a)  x << n < 2^LENGTH('a))
                (x << n) unat (x' << unat n')"
  by (clarsimp simp: shiftl_t2n Word_Lemmas_Internal.shiftl_nat_def unat_mult_simple field_simps basic_abstract_defs)

(* Left shift of an unsigned word by a signed amount: x << nat n abstracts
   x' << unat n' via unat. Besides Px and Pn the precondition mentions 0 and
   n, and requires n < int (LENGTH('a)) and x << nat n < 2^LENGTH('a). *)
lemma abstract_val_unsigned_shiftl_signed:
  " abstract_val Px x unat (x' :: ('a :: len) word);
     abstract_val Pn n sint (n' :: ('b :: len) signed word)  
   abstract_val (Px  Pn  0  n  n < int (LENGTH('a))  x << nat n < 2^LENGTH('a))
                (x << nat n) unat (x' << unat n')"
  apply (clarsimp simp: shiftl_t2n Word_Lemmas_Internal.shiftl_nat_def unat_mult_simple field_simps basic_abstract_defs)
  apply (simp add: sint_eq_uint word_msb_sint)
  by (metis Word.of_nat_unat len_gt_0 nat_int unat_of_nat_len unat_power_lower word_arith_nat_mult 
    zless_nat_conj)

(* TODO: this would be useful for simplifying signed left shift c_guards,
   which are already implied by the generated word abs guard (premise #2).

   However, the c_guard is translated before the new word abs guards,
   thus L2Opt (which only propagates guards forwards) is unable to
   make use of this rule at present. *)
(* Statement: given int bound < 2^LENGTH('a), a * 2^b < int bound and a
   premise relating 0 and a, the nat product
   unat (of_int a) * 2 ^ b is below bound.
   The L2flow attribute is commented out, so the rule is not registered. *)
lemma signed_shiftl_c_guard_simp (* [L2flow] *):
  " int bound < 2^LENGTH('a); a * 2^b < int bound; 0  a  
   unat (of_int a :: 'a::len word) * 2 ^ b < bound"
  (* replace unat (of_int a) by nat a; unat_of_int's side conditions come first *)
  apply (subst unat_of_int)
    apply assumption
   apply (drule(1) less_trans)
   apply (subgoal_tac "a * 2^b < 2^LENGTH('a) * 2^b")
    apply simp
   apply (erule less_le_trans)
   apply simp
  (* transfer the int inequality to nat *)
  apply (subgoal_tac "nat (a * 2^b) < nat (int bound)")
   apply (simp add: nat_power_eq nat_mult_distrib)
  apply (subst nat_mono_iff)
   apply (rule le_less_trans, assumption)
   apply (erule le_less_trans[rotated])
   apply (simp add: mult_left_mono[where a="1::int", simplified])
  apply simp
  done

(* Collected abstract_val rules for signed words (abstraction via sint):
   the comparison and arithmetic/bitwise operator facts instantiated into
   abstract_expr_bool_binop / abstract_expr_binop, followed by the rules for
   NOT, unary minus and the four shift combinations.
   [simplified simp_thms] simplifies the resulting statements with simp_thms. *)
lemmas abstract_val_signed_ops [simplified simp_thms] =
  abstract_expr_bool_binop [OF snat_abstract_bool_binops(1)]
  abstract_expr_bool_binop [OF snat_abstract_bool_binops(2)]
  abstract_expr_bool_binop [OF snat_abstract_bool_binops(3)]
  abstract_expr_binop [OF snat_abstract_binops(1)]
  abstract_expr_binop [OF snat_abstract_binops(2)]
  abstract_expr_binop [OF snat_abstract_binops(3)]
  abstract_expr_binop [OF snat_abstract_binops(4)]
  abstract_expr_binop [OF snat_abstract_binops(5)]
  abstract_expr_binop [OF sint_bitwise_abstract_binops(1)]
  abstract_expr_binop [OF sint_bitwise_abstract_binops(2)]
  abstract_expr_binop [OF sint_bitwise_abstract_binops(3)]
  abstract_val_signed_bitNOT
  abstract_val_signed_unary_minus
  abstract_val_signed_shiftr_signed
  abstract_val_signed_shiftr_unsigned
  abstract_val_signed_shiftl_signed
  abstract_val_signed_shiftl_unsigned

(* Collected abstract_val rules for unsigned words (abstraction via unat):
   the comparison and arithmetic/bitwise operator facts instantiated into
   abstract_expr_bool_binop / abstract_expr_binop, followed by the rules for
   NOT, unary minus and the four shift combinations.
   [simplified simp_thms] simplifies the resulting statements with simp_thms. *)
lemmas abstract_val_unsigned_ops [simplified simp_thms] =
  abstract_expr_bool_binop [OF unat_abstract_bool_binops(1)]
  abstract_expr_bool_binop [OF unat_abstract_bool_binops(2)]
  abstract_expr_bool_binop [OF unat_abstract_bool_binops(3)]
  abstract_expr_binop [OF unat_abstract_binops(1)]
  abstract_expr_binop [OF unat_abstract_binops(2)]
  abstract_expr_binop [OF unat_abstract_binops(3)]
  abstract_expr_binop [OF unat_abstract_binops(4)]
  abstract_expr_binop [OF unat_abstract_binops(5)]
  abstract_expr_binop [OF unat_bitwise_abstract_binops(1)]
  abstract_expr_binop [OF unat_bitwise_abstract_binops(2)]
  abstract_expr_binop [OF unat_bitwise_abstract_binops(3)]
  abstract_val_unsigned_bitNOT
  abstract_val_unsigned_unary_minus
  abstract_val_unsigned_shiftr_signed
  abstract_val_unsigned_shiftr_unsigned
  abstract_val_unsigned_shiftl_signed
  abstract_val_unsigned_shiftl_unsigned

(* On nat: if a < c then a mod b < c, for any b. *)
lemma mod_less:
  "(a :: nat) < c  a mod b < c"
  by (metis less_trans mod_less_eq_dividend order_leE)

(* Converting an unsigned word to the signed word of the same length: if v
   abstracts v' via unat under P, then int v abstracts ucast v' via sint,
   under P together with a condition relating v to nat (WORD_MAX TYPE('a)).
   The introduce_typ_abs_fn premise is the (trivially true) marker for unat. *)
lemma abstract_val_ucast:
    " introduce_typ_abs_fn (unat :: ('a::len) word  nat);
       abstract_val P v unat v' 
         abstract_val (P  v  nat (WORD_MAX TYPE('a)))
                  (int v) sint (ucast (v' :: 'a word) :: 'a signed word)"
  apply (clarsimp simp: uint_nat [symmetric] basic_abstract_defs)
  (* sint_eq_uint needs the msb to be clear; established via not_msb_from_less *)
  apply (subst sint_eq_uint)
   apply (rule not_msb_from_less)
   apply (clarsimp simp: word_less_nat_alt unat_ucast WORD_MAX_def le_to_less_plus_one)
   apply (subst (asm) nat_diff_distrib)
     apply simp
    apply clarsimp
   apply clarsimp
   apply (metis of_nat_numeral nat_numeral nat_power_eq of_nat_0_le_iff)
  apply (clarsimp simp: uint_up_ucast is_up)
  done

(* Base rule for heap-lifted signed words. See the function mk_sword_heap_get_rule. *)
(* If p' abstracts p via id under P, then reading through p' and through p
   (heap_get, followed by ucast to the signed word type) are related by sint
   under the same P. *)
lemma abstract_val_heap_sword_template:
  " introduce_typ_abs_fn (sint :: ('a::len) signed word  int);
     abstract_val P p' id p 
    abstract_val P (sint (ucast (heap_get s p' :: 'a word) :: 'a signed word))
                      sint (ucast (heap_get s p) :: 'a signed word)"
  by (simp add: basic_abstract_defs)

(* Converting a signed word to the unsigned word of the same length: if C'
   abstracts C via sint under P, then nat C' abstracts scast C via unat,
   under P together with a condition relating 0 and C'. *)
lemma abstract_val_scast:
    " introduce_typ_abs_fn (sint :: ('a::len) signed word  int);
       abstract_val P C' sint C 
              abstract_val (P  0  C') (nat C') unat (scast (C :: ('a::len) signed word) :: ('a::len) word)"
  apply (clarsimp simp: down_cast_same [symmetric] is_down unat_ucast basic_abstract_defs)
  apply (subst sint_eq_uint)
   apply (clarsimp simp: word_msb_sint)
  apply (clarsimp simp: unat_def [symmetric])
  apply (subst word_unat.norm_Rep [symmetric])
  apply clarsimp
  done

(* scast between signed word types 'a and 'b, under the stated relation
   between their lengths: the abstract int value C' is unchanged. *)
lemma abstract_val_scast_upcast:
    " len_of TYPE('a::len)  len_of TYPE('b::len);
       abstract_val P C' sint C 
              abstract_val P (C') sint (scast (C :: 'a signed word) :: 'b signed word)"
  by (clarsimp simp: down_cast_same [symmetric] sint_up_scast is_up basic_abstract_defs)

(* scast from 'a signed word to a shorter 'b signed word
   (len_of TYPE('b) < len_of TYPE('a)): the abstract int value becomes
   sbintrunc (len_of TYPE('b) - 1) C'. *)
lemma abstract_val_scast_downcast:
    " len_of TYPE('b) < len_of TYPE('a::len);
       abstract_val P C' sint C 
              abstract_val P (sbintrunc ((len_of TYPE('b::len) - 1)) C') sint (scast (C :: 'a signed word) :: 'b signed word)"
  by (metis Word.of_int_sint abstract_val_def len_signed word_sbin.inverse_norm)

(* ucast between word types 'a and 'b, under the stated relation between
   their lengths: the abstract nat value C' is unchanged. *)
lemma abstract_val_ucast_upcast:
    " len_of TYPE('a::len)  len_of TYPE('b::len);
       abstract_val P C' unat C 
              abstract_val P (C') unat (ucast (C :: 'a word) :: 'b word)"
  by (clarsimp simp: is_up unat_ucast_upcast basic_abstract_defs)

(* ucast from 'a word to a shorter 'b word
   (len_of TYPE('b) < len_of TYPE('a)): the abstract nat value becomes
   C' mod (UWORD_MAX TYPE('b) + 1). *)
lemma abstract_val_ucast_downcast:
    " len_of TYPE('b::len) < len_of TYPE('a::len);
       abstract_val P C' unat C 
              abstract_val P (C' mod (UWORD_MAX TYPE('b) + 1)) unat (ucast (C :: 'a word) :: 'b word)"
  apply (clarsimp simp: scast_def sint_uint UWORD_MAX_def basic_abstract_defs)
  unfolding ucast_def unat_def
  apply (subst int_word_uint)
  apply (metis (mono_tags) uint_mod uint_power_lower unat_def unat_mod unat_power_lower)
  done

(*
 * The pair A/C is a valid abstraction/concretisation function pair,
 * under the preconditions P and Q:
 *   - P v     guards the round trip  A (C v) = v  (on abstract values v);
 *   - Q (A v) guards the round trip  C (A v) = v  (on concrete values v).
 *)
definition
 "valid_typ_abs_fn (P :: 'a  bool) (Q :: 'a  bool) (A :: 'c  'a) (C :: 'a  'c) 
     (v. P v  A (C v) = v)  (v. Q (A v)  C (A v) = v)"

(* Unfold valid_typ_abs_fn automatically; the instance lemmas below are then
   proved by clarsimp / auto directly on the two round-trip statements. *)
declare valid_typ_abs_fn_def [simp]

(* id/id is a valid abstraction pair with trivial preconditions. *)
lemma valid_typ_abs_fn_id:
  "valid_typ_abs_fn (λ_. True) (λ_. True) id id"
  by clarsimp

(* The id/id instance with the concretisation function fixed to type
   unit ⇒ unit. *)
lemma valid_typ_abs_fn_unit:
  "valid_typ_abs_fn (λ_. True) (λ_. True) id (id :: unit  unit)"
  by clarsimp

(* unat / of_nat is a valid abstraction pair between 'a word and nat. The
   abstract-side precondition relates v to UWORD_MAX TYPE('a); the
   concrete-side precondition is trivial.
   unsigned_of_nat is removed from the simp set for this proof only. *)
lemma valid_typ_abs_fn_unat:
  "valid_typ_abs_fn (λv. v  UWORD_MAX TYPE('a::len)) (λ_. True) (unat :: 'a word  nat) (of_nat :: nat  'a word)" 
  supply unsigned_of_nat [simp del] 
  by (clarsimp simp: unat_of_nat_eq UWORD_MAX_def le_to_less_plus_one)

(* sint / of_int is a valid abstraction pair between 'a signed word and int.
   The abstract-side precondition constrains v using WORD_MIN TYPE('a) and
   WORD_MAX TYPE('a); the concrete-side precondition is trivial. *)
lemma valid_typ_abs_fn_sint:
  "valid_typ_abs_fn (λv. WORD_MIN TYPE('a::len)  v  v  WORD_MAX TYPE('a)) (λ_. True) (sint :: 'a signed word  int) (of_int :: int  'a signed word)"
  by (clarsimp simp: sint_of_int_eq WORD_MIN_def WORD_MAX_def)

(* Valid abstraction pairs compose componentwise over pairs: the functions
   are combined with map_prod and the preconditions are combined per
   component. *)
lemma valid_typ_abs_fn_tuple:
  " valid_typ_abs_fn P_a Q_a abs_a conc_a; valid_typ_abs_fn P_b Q_b abs_b conc_b  
          valid_typ_abs_fn (λ(a, b). P_a a  P_b b) (λ(a, b). Q_a a  Q_b b) (map_prod abs_a abs_b) (map_prod conc_a conc_b)"
  by clarsimp

(* Variant of valid_typ_abs_fn_tuple in which the abstraction function is
   written as an explicit pair-pattern lambda instead of map_prod. *)
lemma valid_typ_abs_fn_tuple_split:
  " valid_typ_abs_fn P_a Q_a abs_a conc_a; valid_typ_abs_fn P_b Q_b abs_b conc_b  
          valid_typ_abs_fn (λ(a, b). P_a a  P_b b) (λ(a, b). Q_a a  Q_b b) (λ(a, b). (abs_a a, abs_b b)) (map_prod conc_a conc_b)"
  by clarsimp

(* The introduce_typ_abs_fn marker for a map_prod of two abstraction
   functions, from the markers of its components. *)
lemma introduce_typ_abs_fn_tuple:
  " introduce_typ_abs_fn abs_a; introduce_typ_abs_fn abs_b  
         introduce_typ_abs_fn (map_prod abs_a abs_b)"
  by clarsimp

(* Valid abstraction pairs compose over sum types: the functions are
   combined with map_sum and the preconditions with case_sum. *)
lemma valid_typ_abs_fn_sum:
  " valid_typ_abs_fn P_a Q_a abs_a conc_a; valid_typ_abs_fn P_b Q_b abs_b conc_b  
          valid_typ_abs_fn (case_sum P_a P_b) (case_sum Q_a Q_b) (map_sum abs_a abs_b) (map_sum conc_a conc_b)"
  by (auto simp add: map_sum_def split: sum.splits)

(* The introduce_typ_abs_fn marker for a map_sum of two abstraction
   functions, from the markers of its components. *)
lemma introduce_typ_abs_fn_sum:
  " introduce_typ_abs_fn abs_a; introduce_typ_abs_fn abs_b  
         introduce_typ_abs_fn (map_sum abs_a abs_b)"
  by clarsimp


section "Refinement Lemmas"

(* Named theorem collection `word_abs`.
   NOTE(review): presumably holds the rules used by the word abstraction
   phase -- confirm against where the attribute is applied. *)
named_theorems word_abs

(* corresTA P rx ex A C: correspondence between abstract program A and
   concrete program C for word abstraction. It is corresXF with the state
   transformation fixed to the identity (λs. s) and with the return-value
   and exception transformations rx / ex ignoring the state; P is the
   precondition. *)
definition
  "corresTA P rx ex A C  corresXF (λs. s) (λr s. rx r) (λr s. ex r) P A C"

(* Relation between a concrete and an abstract outcome, built with rel_xval:
   the first relation maps the concrete value c through ex, the second
   through rx (a = ex c, respectively a = rx c). *)
definition "rel_word_abs ex rx  rel_xval (λc a. a = ex c) (λc a. a = rx c)" 

(* Evaluation rules for rel_word_abs on the four Result / Exn combinations:
   mixed combinations are False; matching ones reduce to the rx (Result) or
   ex (Exn) equation. Declared [simp]. *)
lemma rel_word_abs_simps[simp]:
  "rel_word_abs ex rx (Result rc) (Exn la) = False"
  "rel_word_abs ex rx (Exn lc) (Result ra) = False"
  "rel_word_abs ex rx (Result rc) (Result ra) = (ra = rx rc)"
  "rel_word_abs ex rx (Exn lc) (Exn la) = (la = ex lc)"
  by (auto simp add: rel_word_abs_def)

(* From corresTA and the precondition P at state s, obtain a `refines`
   statement between fc and fa started from the same state s, with results
   related by rel_word_abs ex rx and final states related by equality. *)
lemma corresTA_refines:
  "corresTA P rx ex fa fc  P s  refines fc fa s s (rel_prod (rel_word_abs ex rx) (=))"
  unfolding corresTA_def
  apply (clarsimp simp add: corresXF_refines_conv rel_word_abs_def rel_xval.simps)
  apply (clarsimp simp add: refines_def_old rel_xval.simps reaches_succeeds)
  by (smt (verit) le_boolE le_boolI' linorder_not_le xval_split)

(* Converse of corresTA_refines: a `refines` statement available for every
   state satisfying P yields corresTA. *)
lemma refines_corresTA:
  assumes sim: "s. P s  refines fc fa s s (rel_prod (rel_word_abs ex rx) (=))"
  shows "corresTA P rx ex fa fc"
  unfolding corresTA_def
  apply (clarsimp simp add: corresXF_refines_conv rel_word_abs_def rel_xval.simps)
  using sim
  apply (fastforce simp add: refines_def_old rel_xval.simps reaches_succeeds rel_word_abs_def split: xval_splits)
  done

(* Characterisation of corresTA through `refines`, combining
   corresTA_refines and refines_corresTA. Used below as a simp rule to turn
   corresTA goals into refines goals. *)
lemma corresTA_refines_conv: 
  "corresTA P rx ex fa fc  (s. P s  refines fc fa s s (rel_prod (rel_word_abs ex rx) (=)))"
  using corresTA_refines refines_corresTA by metis

(* Admissibility of corresTA in its abstract-program argument, inherited
   from the corresponding corresXF fact once corresTA_def is unfolded.
   Registered under the corres_admissible attribute. *)
lemma admissible_nondet_ord_corresTA [corres_admissible]:
  "ccpo.admissible Inf (≥) (λA. corresTA P rx ex  A C)"
  unfolding corresTA_def
  by (rule admissible_nondet_ord_corresXF)

(* Base case registered under the corres_top attribute; follows directly
   from corresTA_def and corresXF_def. *)
lemma corresTA_top [corres_top]: "corresTA P rx st  C"
  by (auto simp add: corresTA_def corresXF_def)

(* Change the precondition of a corresTA statement from Q to P, where
   (P_Q) P s gives Q s, and (A_C) the corresTA fact for Q may itself assume
   P s. Combines corresXF_assume_pre with corresXF_guard_imp. *)
lemma corresTA_assume_and_weaken_pre:
  assumes A_C: "s. P s  corresTA Q rt ex A C"
  assumes P_Q: "s. P s  Q s" 
  shows "corresTA P rt ex A C"
  unfolding corresTA_def
  apply (rule corresXF_assume_pre)
  apply (rule corresXF_guard_imp)
   apply (rule A_C [unfolded corresTA_def])
   apply assumption
  apply (rule P_Q)
  apply assumption
  done

(* L2_gets: if C s abstracts C' s via rx under Q s, then L2_gets of C
   corresponds to L2_gets of C' with precondition Q. *)
lemma corresTA_L2_gets:
  " s. abstract_val (Q s) (C s) rx (C' s)  
     corresTA Q rx ex (L2_gets (λs. C s) n) (L2_gets (λs. C' s) n)"
  unfolding L2_defs
  apply (clarsimp simp add: corresTA_refines_conv)
  apply (rule refines_gets)
  apply (simp add: abstract_val_def)
  done

(* L2_modify: if the abstract update m s abstracts the concrete update m' s
   via id under P s, then the two L2_modify programs correspond with
   precondition P. *)
lemma corresTA_L2_modify:
    " s. abstract_val (P s) (m s) id (m' s)  
            corresTA P rx ex (L2_modify (λs. m s)) (L2_modify (λs. m' s))"
  unfolding L2_defs
  apply (clarsimp simp add: corresTA_refines_conv)
  apply (rule refines_modify)
  apply (simp add: abstract_val_def)
  done

(* FIXME: move to spec monad *)

(* `throw x` refines `throw y` from states s and t whenever R relates the
   two exceptional outcomes (Exn x, s) and (Exn y, t). *)
lemma refines_throw: "R (Exn x, s) (Exn y, t)  refines (throw x) (throw y) s t R"
  by (auto simp add: refines_def_old Exn_def)

(* L2_throw: if C abstracts C' via ex under Q, then the two L2_throw programs
   correspond with the state-independent precondition λ_. Q. *)
lemma corresTA_L2_throw:
  " abstract_val Q C ex C'  
     corresTA (λ_. Q) rx ex (L2_throw C n) (L2_throw C' n)"
  unfolding L2_defs
  apply (clarsimp simp add: corresTA_refines_conv)
  apply (simp add: abstract_val_def)
  done

(* L2_skip corresponds to itself, with the trivial precondition. *)
lemma corresTA_L2_skip:
  "corresTA (λ_. True) rx ex L2_skip L2_skip"
  unfolding L2_defs
  by (auto simp add: corresTA_refines_conv refines_gets)


(* L2_fail corresponds to itself, with the trivial precondition. *)
lemma corresTA_L2_fail:
  "corresTA (λ_. True) rx ex L2_fail L2_fail"
  unfolding L2_defs
  by (auto simp add: corresTA_refines_conv)

(* Core sequencing rule. Given correspondence of L / L' (return abstraction
   rx1) and, for every concrete result r, correspondence of R (rx1 r) / R' r
   under precondition Q (rx1 r), the sequential compositions correspond.
   The abstract side inserts an L2_guard on Q r between L and R r, which
   carries the continuation's precondition. *)
lemma corresTA_L2_seq':
  fixes L' :: "('e, 'c1, 's) exn_monad"
  fixes R' :: "'c1  ('e, 'c2, 's) exn_monad"
  fixes L :: "('ea, 'a1, 's) exn_monad"
  fixes R :: "'a1  ('ea, 'a2, 's) exn_monad"
  shows
  " corresTA P rx1 ex L L';
     r. corresTA (Q (rx1 r)) rx2 ex (R (rx1 r)) (R' r)  
    corresTA P rx2 ex
       (L2_seq L (λr. L2_seq (L2_guard (λs. Q r s)) (λ_. R r)))
       (L2_seq L' (λr. R' r))"
  apply atomize
  apply (clarsimp simp: L2_seq_def L2_guard_def corresTA_def)
  (* join the two corresXF facts; P' links abstract and concrete results *)
  apply (erule corresXF_join [where P'="λx y s. rx1 y = x"])
  subgoal
    by (auto simp add: corresXF_def reaches_bind reaches_guard  succeeds_bind split: xval_splits)
  subgoal
    by (auto simp add: runs_to_partial_def_old split: xval_splits)
  subgoal by simp
  done

(* Sequencing rule in the form with introduce_typ_abs_fn / abs_var premises
   wrapped in PROP THIN; after unfolding THIN_def it reduces to
   corresTA_L2_seq'. *)
lemma corresTA_L2_seq:
  " introduce_typ_abs_fn rx1;
    PROP THIN (Trueprop (corresTA P (rx1 :: 'a  'b) ex L L'));
    PROP THIN  (r r'. abs_var r rx1 r'  corresTA (Q r) rx2 ex (R r) (R' r'))  
       corresTA P rx2 ex (L2_seq L (λr. L2_seq (L2_guard (Q r)) (λ_. R r))) (L2_seq L' R')"
  unfolding THIN_def
  by (rule corresTA_L2_seq', (simp add: basic_abstract_defs)+)

(* Sequencing rule for the case where the continuation ignores the result of
   the first program (R and R' do not depend on r). *)
lemma corresTA_L2_seq_unused_result:
  " introduce_typ_abs_fn rx1;
    PROP THIN (Trueprop (corresTA P (rx1 :: 'a  'b) ex L L'));
    PROP THIN (Trueprop (corresTA Q rx2 ex R R')) 
       corresTA P rx2 ex (L2_seq L (λr. L2_seq (L2_guard Q) (λ_. R))) (L2_seq L' (λ_. R'))"
  unfolding THIN_def
  by (rule corresTA_L2_seq', simp+)

(* Sequencing rule for a first program returning unit (return abstraction
   id); the concrete continuation R' is applied to (). *)
lemma corresTA_L2_seq_unit:
  fixes L' :: "('e, unit, 's) exn_monad"
  fixes R' :: "unit  ('e, 'r, 's) exn_monad"
  fixes L :: "('ea, unit, 's) exn_monad"
  fixes R :: "('ea, 'ra, 's) exn_monad"
  shows
  "PROP THIN (Trueprop (corresTA P id ex L L'));
    PROP THIN (Trueprop (corresTA Q rx ex R (R' ())))  
    corresTA P rx ex
       (L2_seq L (λr. L2_seq (L2_guard Q) (λ_. R)))
       (L2_seq L' R')"
  unfolding THIN_def
  by (rule corresTA_L2_seq', simp+)

(* Core catch rule, parallel to corresTA_L2_seq' but on the exception
   channel: given correspondence of L / L' (exception abstraction ex1) and,
   for every concrete exception r, correspondence of the handlers
   R (ex1 r) / R' r under Q (ex1 r), the L2_catch programs correspond.
   The abstract handler is prefixed with an L2_guard on Q r. *)
lemma corresTA_L2_catch':
  fixes L' :: "('e1, 'c, 's) exn_monad"
  fixes R' :: "'e1  ('e2, 'c, 's) exn_monad"
  fixes L :: "('e1a, 'ca, 's) exn_monad"
  fixes R :: "'e1a  ('e2a, 'ca, 's) exn_monad"
  shows
  "corresTA P rx ex1 L L';
    r. corresTA (Q (ex1 r)) rx ex2 (R (ex1 r)) (R' r)  
    corresTA P rx ex2 (L2_catch L (λr. L2_seq (L2_guard (λs. Q r s)) (λ_. R r))) (L2_catch L' (λr. R' r))"
  apply atomize
  apply (clarsimp simp: L2_seq_def L2_catch_def L2_guard_def corresTA_def)
  (* P' links abstract and concrete exception values *)
  apply (erule corresXF_except [where P'="λx y s. ex1 y = x"])
  subgoal
    by (auto simp add: corresXF_def reaches_bind reaches_guard  succeeds_bind split: xval_splits)
  subgoal
    by (auto simp add: runs_to_partial_def_old split: xval_splits)
  subgoal by simp
  done

(* Catch rule in the form with introduce_typ_abs_fn / abs_var premises
   wrapped in PROP THIN; after unfolding THIN_def it reduces to
   corresTA_L2_catch'. *)
lemma corresTA_L2_catch:
  " introduce_typ_abs_fn ex1;
    PROP THIN (Trueprop (corresTA P rx ex1 L L'));
    PROP THIN (r r'. abs_var r ex1 r'  corresTA (Q r) rx ex2 (R r) (R' r'))  
       corresTA P rx ex2 (L2_catch L (λr. L2_seq (L2_guard (λs. Q r s)) (λ_. R r))) (L2_catch L' (λr. R' r))"
  unfolding THIN_def
  by (rule corresTA_L2_catch', (simp add: basic_abstract_defs)+)


term "corresTA P rx ex f g" 


(* yield: if v' abstracts v via map_xval ex rx (with precondition True),
   then yield v' corresponds to yield v for any precondition P. *)
lemma corresTA_yield: 
  "abstract_val True v' (map_xval ex rx) v  corresTA P rx ex (yield v') (yield v)"
  by (auto simp add: corresTA_refines_conv rel_word_abs_def abstract_val_def rel_xval.simps map_xval_def 
      intro!: refines_yield split: xval_splits)


(* map_sum written out as an explicit case distinction on Inl / Inr. *)
lemma map_sum_apply: "map_sum ex rx = (λv. case v of Inl l  Inl (ex l) | Inr r  Inr (rx r))"
  by (simp add: fun_eq_iff split_sum_all)


(* FIXME: to spec monad *)
(* `try` preserves refinement: if f refines g with exceptions related by
   rel_sum L R and results by R, then try f refines try g with exceptions
   related by L and results by R (final states related by S in both).
   The proof unfolds refines and then splits on the outcome r of f. *)
lemma refines_try_rel_prod: 
  assumes "refines f g s t (rel_prod (rel_xval (rel_sum L R) R) S)"
  shows "refines (try f) (try g) s t (rel_prod (rel_xval L R) S)"
  using assms
  apply (clarsimp simp add: refines_def_old reaches_try unnest_exn_def rel_xval.simps split: xval_splits sum.splits)
  subgoal for r s' r'
    apply (erule_tac x=r' in allE)
    apply (erule_tac x=s' in allE)

    apply (cases r)
    subgoal for e
      apply (clarsimp simp add: default_option_def Exn_def, safe)
       apply (metis Exception_eq_Exception Exn_def Exn_neq_Result exception_or_result_cases not_None_eq sum_all_ex(2))
      by (smt (verit, ccfv_threshold) Exn_eq_Exception(2) Exn_neq_Result 
          rel_sum.cases rel_sum_simps(2) theLeft.simps the_Exn_Exn(2))

    subgoal for v
      apply (clarsimp simp add: default_option_def Exn_def, safe)
        apply (metis Exception_eq_Exception Exn_def Exn_neq_Result exception_or_result_cases not_None_eq sum_all_ex(2))
      apply (smt (verit, ccfv_threshold) Exn_def Result_neq_Exn rel_sum.cases rel_sum_simps(3) unnest_exn_eq_simps(3))
      by (metis Exn_def Result_eq_Result Result_neq_Exn)
    done
  done

(*
 * Rewrites the functional relation induced by map_sum ex rx on the
 * exception component of rel_xval into the rel_sum of the functional
 * relations induced by ex and rx.
 *)
lemma rel_map_xval_sum_rel_sum_conv: 
  "rel_xval (λc a. a = map_sum ex rx c) (λc a. a = rx c) =
       rel_xval (rel_sum (λc a. a = ex c) (λc a. a = rx c)) (λc a. a = rx c)"
  by (rule ext)+ (auto simp add: rel_xval.simps rel_sum.simps)

(*
 * Word-abstraction rule for L2_try: when L abstracts L' with exception
 * abstraction map_sum ex rx, L2_try L abstracts L2_try L' with exception
 * abstraction ex. The proof reduces to refines_try_rel_prod, using
 * rel_map_xval_sum_rel_sum_conv to bring the premise into the right shape.
 *)
lemma corresTA_L2_try':
  assumes corres_L_L': "corresTA P rx (map_sum ex rx) L L'" 
  shows "corresTA P rx ex (L2_try L) (L2_try L')"
  unfolding L2_defs 
  apply (clarsimp simp add: corresTA_refines_conv rel_word_abs_def)
  apply (rule refines_try_rel_prod)
  using corres_L_L' 
  apply (auto simp add: corresTA_refines_conv rel_word_abs_def rel_map_xval_sum_rel_sum_conv)
  done

(*
 * Word-abstraction rule for while loops.
 *
 * Premises:
 *   init_corres: under side condition Q, the abstract initial iterator i
 *                equals rx i'.
 *   cond_corres: for related iterators (abs_var r rx r'), the abstract
 *                condition C r s equals the concrete condition C' r' s
 *                under side condition G r s.
 *   body_corres: for related iterators, the abstract body B r abstracts
 *                the concrete body B' r' under precondition P r.
 *
 * The abstract loop is an L2_guarded_while with guard G whose body is
 * prefixed with a guard asserting P r. The precondition of the whole loop
 * is Q.
 *)
lemma corresTA_L2_while:
  assumes init_corres: "abstract_val Q i rx i'"
  and cond_corres: "PROP THIN (r r' s. abs_var r rx r'
                            abstract_val (G r s) (C r s) id (C' r' s))"
  and body_corres: "PROP THIN (r r'. abs_var r rx r'
                            corresTA (P r) rx ex (B r) (B' r'))"
  shows "corresTA (λ_. Q) rx ex
       (L2_guarded_while (λr s. G r s) (λr s. C r s) (λr. L2_seq (L2_guard (λs. P r s)) (λ_. B r)) i x)
       (L2_while (λr s. C' r s) B' i' x)"
proof -
  (* Strip the THIN wrappers so the premises can be used as ordinary rules. *)
  note cond_corres = cond_corres [unfolded THIN_def, rule_format]
  note body_corres = body_corres [unfolded THIN_def, rule_format]

  (* The initial-value premise with abstract_val unfolded. *)
  have init_corres':
    "Q  i = rx i'"
    using init_corres
    by (simp add: basic_abstract_defs)

  note basic_abstract_defs [simp]
  show ?thesis
    apply (clarsimp simp: L2_defs  corresTA_def gets_return)
    apply (rule corresXF_assume_pre)
    apply (rule corresXF_guarded_while [where P="λr s. G (rx r) s"])
    (* loop body *)
    subgoal for s s' x y
      apply (cut_tac r'=x in body_corres, simp)
      apply (fastforce simp add: corresTA_refines_conv corresXF_refines_conv refines_def_old reaches_bind succeeds_bind 
          rel_word_abs_def rel_xval.simps split: xval_splits)
      done
    (* loop condition *)
    subgoal 
      apply (insert cond_corres)[1]
      apply (clarsimp)
      done
    subgoal
      by (auto simp add: runs_to_partial_def_old split: xval_splits)
    (* initial iterator value *)
    subgoal using init_corres
      by (clarsimp)
    subgoal using init_corres'
      by clarsimp
    done
qed


(*
 * Abstraction of guards: if, for every state s, G s abstracts G' s under
 * side condition Q s, then the abstract guard built from both G and Q
 * abstracts the concrete guard G'.
 *)
lemma corresTA_L2_guard:
  " s. abstract_val (Q s) (G s) id (G' s) 
            corresTA (λ_. True) rx ex (L2_guard (λs. G s  Q s)) (L2_guard (λs. G' s))"
  unfolding L2_defs
  by (auto simp add: corresTA_refines_conv abstract_val_def intro!: refines_guard)


(*
 * Variant of corresTA_L2_guard in which the abstract guard is an arbitrary
 * predicate R, related to G and Q by the second premise.
 *)
lemma corresTA_L2_guard':
  "s. abstract_val (Q s) (G s) id (G' s); 
    s. R s  G s  Q s
            corresTA (λ_. True) rx ex (L2_guard (λs. R s)) (L2_guard (λs. G' s))"
  unfolding L2_defs
  by (auto simp add: corresTA_refines_conv abstract_val_def intro!: refines_guard)

(*
 * Abstraction of L2_guarded. The abstract guard combines the abstracted
 * guard G, the side condition Q of the guard abstraction (G_G') and the
 * precondition P of the body abstraction (f_f'). The body abstraction only
 * has to be established for states in which G', Q and G hold.
 *)
lemma corresTA_L2_guarded_simple:
  assumes G_G': "s. abstract_val (Q s) (G s) id (G' s)"
  assumes f_f': "s. G' s  Q s  G s  corresTA P rx ex f f'"
  shows "corresTA (λ_. True) rx ex (L2_guarded (λs. G s  Q s  P s) f) (L2_guarded G' f')"
  unfolding L2_defs L2_guarded_def
  apply (clarsimp simp add: corresTA_refines_conv)
  apply (rule refines_bind_guard_right)
  using G_G' f_f'
  by (auto simp add: corresTA_refines_conv refines_def_old succeeds_bind reaches_bind abstract_val_def)

(*
 * Abstraction of L2_spec: if the abstract relation P s t equals the
 * concrete relation P' s t under side condition Q s, then the abstract
 * spec statement abstracts the concrete one under precondition Q.
 *)
lemma corresTA_L2_spec:
  "(s t. abstract_val (Q s) (P s t) id (P' s t)) 
   corresTA Q rx ex (L2_spec {(s, t). P s t}) (L2_spec {(s, t). P' s t})"
  unfolding L2_defs
  by (auto simp add: corresTA_refines_conv abstract_val_def refines_def_old reaches_bind succeeds_bind)

(*
 * Abstraction of L2_assume: under side condition Q s, the abstract result
 * set P s must be the image of the concrete result set P' s under the map
 * that applies rx to the value component and keeps the second component.
 *)
lemma corresTA_L2_assume:
  "(s r t. abstract_val (Q s) (P s) (λX. (λ(x, y). (rx x, y)) ` X) (P' s)) 
   corresTA Q rx ex (L2_assume P) (L2_assume P')"
  unfolding L2_defs
  by (auto simp add: corresTA_refines_conv abstract_val_def refines_def_old reaches_bind succeeds_bind
    rel_word_abs_def rel_xval.simps)

(*
 * Abstraction of L2_condition. The branches L' and R' are abstracted to L
 * and R under preconditions P and Q; the condition C' is abstracted to C
 * under side condition T. In the abstract program each branch is prefixed
 * with a guard asserting its own precondition, and T becomes the
 * precondition of the whole conditional.
 *)
lemma corresTA_L2_condition:
  "PROP THIN (Trueprop (corresTA P rx ex L L'));
    PROP THIN (Trueprop (corresTA Q rx ex R R'));
     s. abstract_val (T s) (C s) id (C' s)  
    corresTA T rx ex
          (L2_condition (λs. C s)
            (L2_seq (L2_guard P) (λ_. L))
            (L2_seq (L2_guard Q) (λ_. R))
           ) (L2_condition (λs. C' s) L' R')"
  unfolding THIN_def L2_defs
  by (auto simp add: corresTA_refines_conv abstract_val_def intro!: refines_condition refines_bind_guard_right)


(*
 * Characterises L2_call in terms of the basic L2 combinators: run x and,
 * if it raises e, rethrow emb e via L2_throw.
 *)
lemma L2_call_L2_defs: "L2_call x emb ns = L2_catch x (λe. L2_throw (emb e) ns)"
  unfolding L2_defs L2_call_def
  apply (rule spec_monad_eqI)
  apply (clarsimp simp add: runs_to_iff)
  apply (auto simp add: runs_to_def_old map_exn_def split: xval_splits)
  done

(*
 * Abstraction of function calls. The callee B is abstracted to A with
 * exception abstraction ex'. The exception embedding emb' is abstracted to
 * emb under side condition Q, for exception values related by
 * abs_var r ex' r'. The precondition of the abstracted call combines the
 * callee's precondition P with Q.
 *)
lemma corresTA_L2_call:
  "corresTA P rx ex' A B; 
    r r'. abs_var r ex' r'   abstract_val Q (emb r) ex (emb' r')
     
        corresTA (λs. P s  Q) rx ex (L2_call A emb ns) (L2_call B emb' ns)"
  unfolding L2_call_def
  by (force simp add: corresTA_refines_conv abstract_val_def abs_var_def refines_def_old reaches_map_value map_exn_def
    rel_word_abs_def rel_xval.simps)

(*
 * Backup rule to corresTA_L2_call. Converts the return type of the function call.
 *
 * The callee is abstracted with return abstraction f1, while the call site
 * is stated with return abstraction f2. The abstract program therefore
 * runs the abstracted call, guards on Q1' of the returned value and then
 * yields f2 (f1' ret), i.e. the result mapped back through f1' and
 * re-abstracted with f2.
 *)
lemma corresTA_L2_call':
  " corresTA P f1 ex' A B;
     valid_typ_abs_fn Q1 Q1' f1 f1';
     valid_typ_abs_fn Q2 Q2' f2 f2';
    r r'. abs_var r ex' r'   abstract_val Q (emb r) ex (emb' r')
    
   corresTA (λs. P s  Q) f2 ex
       (L2_seq (L2_call A emb ns) (ETA_TUPLED (λret. (L2_seq (L2_guard (λ_. Q1' ret)) (λ_. L2_gets (λ_. f2 (f1' ret)) ns)))))
       (L2_call B emb' ns)"
  unfolding L2_call_def L2_defs
  apply (clarsimp simp add:  ETA_TUPLED_def corresTA_refines_conv abstract_val_def abs_var_def 
      refines_def_old reaches_map_value map_exn_def reaches_bind succeeds_bind reaches_succeeds
    rel_word_abs_def rel_xval.simps)
  subgoal for s s' r'
    apply (erule_tac x=s in allE)
    apply clarsimp
    (* split on the shape of the result r' *)
    apply (cases r')
    subgoal
      apply (simp add: default_option_def Exn_def [symmetric])
      by (smt (z3) Exn_def Exn_neq_Result case_exception_or_result_Exn case_xval_simps(1) the_Exn_Exn(1))
    subgoal
      apply simp
      by (metis (mono_tags, lifting) Exn_neq_Result case_exception_or_result_Result case_xval_simps(2) the_Result_simp)
    done
  done
  
(* L2_unknown x is abstracted by itself, with trivial precondition. *)
lemma corresTA_L2_unknown:
  "corresTA (λ_. True) rx ex (L2_unknown x) (L2_unknown x)"
  unfolding L2_defs
  by (auto simp add: corresTA_refines_conv intro!: refines_select)


(*
 * Abstraction of a function call wrapped in exec_concrete st. The premises
 * are those of corresTA_L2_call. The precondition of the result is stated
 * on a state s in terms of a state s' with s = st s', the callee's
 * precondition P on s', and Q.
 *)
lemma corresTA_L2_call_exec_concrete:
  " corresTA P rx ex' A B ; 
    r r'. abs_var r ex' r'   abstract_val Q (emb r) ex (emb' r') 
        corresTA (λs. s'. s = st s'  P s'  Q) rx ex
               (exec_concrete st (L2_call A emb ns))
               (exec_concrete st (L2_call B emb' ns))"
  unfolding L2_defs L2_call_def
  apply (clarsimp simp add: corresTA_refines_conv abstract_val_def abs_var_def)
  apply (clarsimp simp add: refines_def_old succeeds_exec_concrete_iff reaches_exec_concrete 
      reaches_map_value rel_word_abs_def rel_xval.simps map_exn_def split: xval_splits )
  subgoal for r t t' r'
    apply (erule_tac x=t in allE)
    apply clarsimp
    (* split on the shape of the result r' *)
    apply (cases r')
    subgoal
      apply (clarsimp simp add: default_option_def Exn_def [symmetric])
      by (metis Exn_eq_Exn Exn_neq_Result)
    subgoal
      apply simp
      by (metis Result_eq_Result Result_neq_Exn)
    done
  done
  

(*
 * Abstraction of a function call wrapped in exec_abstract st. The premises
 * are those of corresTA_L2_call. The callee's precondition P is evaluated
 * on the mapped state st s.
 *)
lemma corresTA_L2_call_exec_abstract:
  " corresTA P rx ex' A B; 
    r r'. abs_var r ex' r'   abstract_val Q (emb r) ex (emb' r')  
        corresTA (λs. P (st s)  Q) rx ex
               (exec_abstract st (L2_call A emb ns))
               (exec_abstract st (L2_call B emb' ns))"
  unfolding L2_defs L2_call_def
  apply (clarsimp simp add: corresTA_refines_conv abstract_val_def abs_var_def)
  apply (clarsimp simp add: refines_def_old succeeds_exec_abstract_iff reaches_exec_abstract 
      reaches_map_value rel_word_abs_def rel_xval.simps map_exn_def split: xval_splits )
  subgoal for s r s' r'
    (* use the callee correspondence at the mapped state *)
    apply (erule_tac x="st s" in allE)
    apply clarsimp
    apply (cases r')
    subgoal
      apply (clarsimp simp add: default_option_def Exn_def [symmetric])
      by (metis Exn_eq_Exn Exn_neq_Result)
    subgoal
      apply simp
      by (metis Result_eq_Result Result_neq_Exn)
    done
  done


(* More backup rules for calls. *)
(*
 * Return-type converting variant of corresTA_L2_call_exec_concrete,
 * analogous to corresTA_L2_call': after the wrapped call, the abstract
 * program guards on Q1' ret and yields f2 (f1' ret).
 *)
lemma corresTA_L2_call_exec_concrete':
  " corresTA P f1 ex' A B;
     valid_typ_abs_fn Q1 Q1' f1 f1';
     valid_typ_abs_fn Q2 Q2' f2 f2'; 
     r r'. abs_var r ex' r'   abstract_val Q (emb r) ex (emb' r')
    
   corresTA (λs. s'. s = st s'  P s'  Q) f2 ex
       (L2_seq (exec_concrete st (L2_call A emb ns)) (λret. (L2_seq (L2_guard (λ_. Q1' ret)) (λ_. L2_gets (λ_. f2 (f1' ret)) []))))
       (exec_concrete st (L2_call B emb' ns))"
  unfolding L2_defs L2_call_def
  apply (clarsimp simp add: corresTA_refines_conv abstract_val_def abs_var_def)
  apply (clarsimp simp add: refines_def_old succeeds_exec_concrete_iff reaches_exec_concrete 
      reaches_bind succeeds_bind
      reaches_map_value rel_word_abs_def rel_xval.simps map_exn_def split: xval_splits )
  subgoal for r t t' r'
    apply (erule_tac x=t in allE)
    apply clarsimp
    (* split on the shape of the result r' *)
    apply (cases r')
    subgoal
      apply (clarsimp simp add: default_option_def Exn_def [symmetric])
      by (smt (z3) Exn_def Exn_eq_Exn Exn_neq_Result case_exception_or_result_Exn)


    subgoal for v
      apply clarsimp
      apply (erule_tac x="Result v" in allE)
      apply (erule_tac x="t'" in allE)
      by (smt (z3) Exn_neq_Result Result_eq_Result case_exception_or_result_Result)
    done
  done

(*
 * Return-type converting variant of corresTA_L2_call_exec_abstract,
 * analogous to corresTA_L2_call': after the wrapped call, the abstract
 * program guards on Q1' ret and yields f2 (f1' ret).
 *)
lemma corresTA_L2_call_exec_abstract':
  " corresTA P f1 ex' A B;
     valid_typ_abs_fn Q1 Q1' f1 f1';
     valid_typ_abs_fn Q2 Q2' f2 f2';
    r r'. abs_var r ex' r'   abstract_val Q (emb r) ex (emb' r')
    
   corresTA (λs. P (st s)  Q) f2 ex
       (L2_seq (exec_abstract st (L2_call A emb ns)) (λret. (L2_seq (L2_guard (λ_. Q1' ret)) (λ_. L2_gets (λ_. f2 (f1' ret)) []))))
       (exec_abstract st (L2_call B emb' ns))"
  unfolding L2_defs L2_call_def
  apply (clarsimp simp add: corresTA_refines_conv abstract_val_def abs_var_def)
  apply (clarsimp simp add: refines_def_old succeeds_exec_abstract_iff reaches_exec_abstract 
      reaches_bind succeeds_bind
      reaches_map_value rel_word_abs_def rel_xval.simps map_exn_def split: xval_splits )
  subgoal for s r s' r'
    (* use the callee correspondence at the mapped state *)
    apply (erule_tac x="st s" in allE)
    apply clarsimp
    apply (cases r')
    subgoal
      apply (clarsimp simp add: default_option_def Exn_def [symmetric])
      by (smt (z3) Exn_def Exn_eq_Exn Exn_neq_Result case_exception_or_result_Exn)
    subgoal
      apply clarsimp
      by (smt (z3) Result_eq_Result Result_neq_Exn case_exception_or_result_Result)
    done
  done

 

text ‹Avoid higher-order unification issues by making the application explicit with @{term "($)"}:
 in the concrete program position it enforces the 'obvious' instantiation;
 in the abstract program position it enforces the introduction of two separate variables for
 @{term a} and @{term b} instead of a higher-order flex-flex pair.›

(*
 * Abstraction of an explicit application a' $ b' under an abstraction
 * function f: the function a' and argument b' are each abstracted with id,
 * and f is applied to the abstract application. The side conditions P and
 * Q of the two premises are combined.
 *)
lemma abstract_val_fun_app:
   " abstract_val Q b id b'; abstract_val P a id a'  
           abstract_val (P  Q) (f $ (a $ b)) f (a' $ b')"
  by (simp add: basic_abstract_defs)

(*
 * Turns the precondition P of a corresTA statement into an explicit guard
 * at the front of the abstract program, leaving a trivial precondition.
 *)
lemma corresTA_precond_to_guard:
  "corresTA (λs. P s) rx ex A A'  corresTA (λ_. True) rx ex (L2_seq (L2_guard (λs. P s)) (λ_. A)) A'"
  unfolding L2_defs
  by (auto simp add: corresTA_refines_conv intro: refines_bind_guard_right)


(*
 * To show corresTA under precondition P it suffices to show it with the
 * trivial precondition while assuming P s as a meta-level hypothesis.
 *)
lemma corresTA_precond_to_asm:
  " s. P s  corresTA (λ_. True) rx ex A A'   corresTA P rx ex A A'"
  by (clarsimp simp: corresXF_def corresTA_def)

(* A guard that always holds can be dropped from the front of a sequence. *)
lemma L2_guard_true: "L2_seq (L2_guard (λ_. True)) A = A ()"
  unfolding L2_defs
  by (rule spec_monad_ext) (auto simp add: run_bind run_guard)

(*
 * Relates corresTA for an abstract program prefixed by a trivially true
 * guard to corresTA for the program A () itself. Follows from L2_guard_true.
 *)
lemma corresTA_simp_trivial_guard:
  "corresTA P rx ex (L2_seq (L2_guard (λ_. True)) A) C  corresTA P rx ex (A ()) C"
  by (simp add: L2_guard_true)

(*
 * Initial step of the extract_preconds_of_call rule family: extends the
 * state-independent precondition P with a trailing True.
 *)
lemma corresTA_extract_preconds_of_call_init:
  " corresTA (λs. P) rx ex A A'   corresTA (λs. P  True) rx ex A A'"
  by simp

(*
 * Step rule: given abstract_val Y a f a', the leading abs_var a f a' in
 * the precondition is discharged and its side condition Y is collected
 * next to C. R is the remainder still to be processed.
 *)
lemma corresTA_extract_preconds_of_call_step:
  " corresTA (λs. (abs_var a f a'  R)  C) rx ex A A'; abstract_val Y a f a' 
            corresTA (λs. R  (Y  C)) rx ex A A'"
  by (clarsimp simp: corresXF_def corresTA_def basic_abstract_defs)

(*
 * Final rule: like the step rule, but for the last remaining
 * abs_var a f a' (no remainder R).
 *)
lemma corresTA_extract_preconds_of_call_final:
  " corresTA (λs. (abs_var a f a')  C) rx ex A A'; abstract_val Y a f a' 
            corresTA (λs. (Y  C)) rx ex A A'"
  by (clarsimp simp: corresXF_def corresTA_def basic_abstract_defs)

(* Final rule for the case where only a leading True remains: drops it. *)
lemma corresTA_extract_preconds_of_call_final':
  " corresTA (λs. True  C) rx ex A A' 
            corresTA (λs. C) rx ex A A'"
  by (clarsimp simp: corresXF_def corresTA_def)

(*
 * Converse direction of corresTA_extract_preconds_of_call_init: removes
 * the trailing True from the precondition.
 *)
lemma corresTA_extract_preconds_of_call_init_prems:
  " corresTA (λs. P  True) rx ex A A'  corresTA (λs. P) rx ex A A'"
  by simp

(*
 * Converse direction of the step rule: the premise may assume an
 * arbitrary side condition Y with abstract_val Y a f a', and the
 * conclusion is stated with abs_var a f a' in the precondition.
 *)
lemma corresTA_extract_preconds_of_call_step_prems:
  "Y.  abstract_val Y a f a'  corresTA (λs. R  (Y  C)) rx ex A A' 
            corresTA (λs. (abs_var a f a'  R)  C) rx ex A A'"
  by (clarsimp simp: corresXF_def corresTA_def basic_abstract_defs)

(* Converse direction of the final rule (no remainder R). *)
lemma corresTA_extract_preconds_of_call_final_prems:
  "Y. abstract_val Y a f a'  corresTA (λs. (Y  C)) rx ex A A'
            corresTA (λs. (abs_var a f a')  C) rx ex A A' "
  by (auto simp: corresXF_def corresTA_def basic_abstract_defs split: prod.splits sum.splits)

(* Converse direction of final': adds a leading True to the precondition. *)
lemma corresTA_extract_preconds_of_call_final'_prems:
  "corresTA (λs. C) rx ex A A' 
             corresTA (λs. True  C) rx ex A A'"
  by (clarsimp simp: corresXF_def corresTA_def)

(*
 * Abstraction of a program that splits a pair with case_prod. The pair x'
 * is abstracted component-wise (map_prod rx1 rx2) to x under side
 * condition Q x, and the continuation M' is abstracted to M for all
 * components related by abs_var. The resulting precondition combines the
 * continuation's precondition P a b with Q.
 *)
lemma corresTA_case_prod:
 " introduce_typ_abs_fn rx1;
    introduce_typ_abs_fn rx2;
    abstract_val (Q x) x (map_prod rx1 rx2) x';
      a b a' b'.  abs_var a rx1 a'; abs_var  b rx2 b' 
                        corresTA (P a b) rx ex (M a b) (M' a' b')   
    corresTA (λs. case x of (a, b)  P a b s  Q (a, b)) rx ex (case x of (a, b)  M a b) (case x' of (a, b)  M' a b)"
  apply (clarsimp simp add: corresTA_def)
  apply (rule corresXF_assume_pre)
  apply (clarsimp simp: split_def map_prod_def basic_abstract_defs)
  done

(*
 * Value-level counterpart of corresTA_case_prod: abstraction of an
 * expression that splits a pair. The pair is abstracted component-wise by
 * map_prod f g and the body by h; the side condition is P applied to the
 * components of the abstract pair.
 *)
lemma abstract_val_case_prod:
  " abstract_val True r (map_prod f g) r';
       a b a' b'.   abs_var a f a'; abs_var  b g b' 
                      abstract_val (P a b) (M a b) h (M' a' b') 
        abstract_val (P (fst r) (snd r))
            (case r of (a, b)  M a b) h
            (case r' of (a, b)  M' a b)"
  by (cases r, cases r') (clarsimp simp: map_prod_def basic_abstract_defs)

(*
 * Variant of abstract_val_case_prod in which the case expression is a
 * function that is applied to an extra argument s.
 *)
lemma abstract_val_case_prod_fun_app:
  " abstract_val True r (map_prod f g) r';
       a b a' b'.   abs_var a f a'; abs_var b g b' 
                      abstract_val (P a b) (M a b s) h (M' a' b' s) 
        abstract_val (P (fst r) (snd r))
            ((case r of (a, b)  M a b) s) h
            ((case r' of (a, b)  M' a b) s)"
  by (cases r, cases r') (clarsimp simp: map_prod_def basic_abstract_defs)

(*
 * The word of_nat r abstracts (via unat) to the natural number r,
 * provided r is bounded by UWORD_MAX for the word type.
 *)
lemma abstract_val_of_nat:
  "abstract_val (r  UWORD_MAX TYPE('a::len)) r unat (of_nat r :: 'a word)"
  by (clarsimp simp: unat_of_nat_eq UWORD_MAX_def le_to_less_plus_one basic_abstract_defs)

(*
 * The signed word of_int r abstracts (via sint) to the integer r,
 * provided r lies within the WORD_MIN / WORD_MAX bounds of the word type.
 *)
lemma abstract_val_of_int:
  "abstract_val (WORD_MIN TYPE('a::len)  r  r  WORD_MAX TYPE('a)) r sint (of_int r :: 'a signed word)"
  by (clarsimp simp: sint_of_int_eq WORD_MIN_def WORD_MAX_def basic_abstract_defs)


(*
 * Pairs are abstracted component-wise with map_prod; the side conditions
 * of the two components are combined.
 *)
lemma abstract_val_tuple:
  " abstract_val P a absL a';
     abstract_val Q b absR b'  
         abstract_val (P  Q) (a, b) (map_prod absL absR) (a', b')"
  by (clarsimp simp add: basic_abstract_defs)

(* Left injections are abstracted with the left component of map_sum. *)
lemma abstract_val_Inl:
  " abstract_val P a absL a' 
         abstract_val P (Inl a) (map_sum absL absR) (Inl a')"
  by (clarsimp simp add: basic_abstract_defs)

(* Right injections are abstracted with the right component of map_sum. *)
lemma abstract_val_Inr:
  " abstract_val P b absR b' 
         abstract_val P (Inr b) (map_sum absL absR) (Inr b')"
  by (clarsimp simp add: basic_abstract_defs)


(*
 * Congruence for a binary function f whose arguments are both abstracted
 * with id: the side conditions of the arguments are combined.
 *)
lemma abstract_val_func:
   " abstract_val P a id a'; abstract_val Q b id b' 
          abstract_val (P  Q) (f a b) id (f a' b')"
  by (clarsimp simp add: basic_abstract_defs)

(*
 * Abstraction of a conjunction of two booleans. The side condition Q of
 * the second operand appears in the result only in dependence on the first
 * operand a.
 *)
lemma abstract_val_conj:
  " abstract_val P a id a';
        abstract_val Q b id b'  
     abstract_val (P  (a  Q)) (a  b) id (a'  b')"
  by (clarsimp simp add: basic_abstract_defs) blast

(*
 * Abstraction of a disjunction of two booleans. The side condition Q of
 * the second operand appears in the result only in dependence on the
 * negation of the first operand a.
 *)
lemma abstract_val_disj:
  " abstract_val P a id a';
        abstract_val Q b id b'  
     abstract_val (P  (¬ a  Q)) (a  b) id (a'  b')"
  by (clarsimp simp add: basic_abstract_defs) blast

(*
 * If a abstracts b via f, then a abstracts the explicitly converted value
 * f b via id. The introduce_typ_abs_fn premise restricts the rule to
 * abstraction functions that have been enabled.
 *)
lemma abstract_val_unwrap:
  " introduce_typ_abs_fn f; abstract_val P a f b 
         abstract_val P a id (f b)"
  by (simp add: basic_abstract_defs)

(*
 * If the natural number x abstracts the word x' via unat, then int x
 * abstracts uint x' via id.
 *)
lemma abstract_val_uint:
  " introduce_typ_abs_fn unat; abstract_val P x unat x' 
       abstract_val P (int x) id (uint x')"
  by (clarsimp simp add: basic_abstract_defs)

(*
 * Abstraction under a lambda: if the bodies agree pointwise under side
 * condition P v, the functions agree under P quantified over v.
 *)
lemma abstract_val_lambda:
   " v. abstract_val (P v) (a v) id (a' v)  
           abstract_val (v. P v) (λv. a v) id (λv. a' v)"
  by (auto simp add: basic_abstract_defs)

(* Rules for translating simpl wrappers. *)
(*
 * L2_call_L1 is abstracted by itself (return and exception abstraction
 * id); only the argument predicate arg_xf' is replaced by an equal
 * arg_xf. The proof substitutes the equality and closes with corresXF_id.
 *)
lemma corresTA_call_L1:
  "abstract_val True arg_xf id arg_xf' 
   corresTA (λ_. True) id id
     (L2_call_L1 arg_xf gs ret_xf l1body)
     (L2_call_L1 arg_xf' gs ret_xf l1body)"
  apply (unfold corresTA_def abstract_val_def id_def)
  apply (subst (asm) simp_thms)
  apply (erule subst)
  apply (rule corresXF_id[simplified id_def])
  done


(* Word abstraction of with_fresh_stack_ptr in the stack_heap_state locale. *)
context stack_heap_state
begin
(*
 * The body fc p is abstracted to fa p under precondition Q for every
 * pointer p; the initialiser initc is abstracted to inita under side
 * condition P. In the abstract program the body is prefixed with a guard
 * asserting Q, and P becomes the overall precondition.
 *)
lemma corresTA_with_fresh_stack_ptr[word_abs]: 
  assumes f[simplified THIN_def, rule_format]: "PROP THIN (p. corresTA Q rx ex (fa p) (fc p))"
  assumes init: "s. abstract_val (P s) (inita s) id (initc s)"
  shows "corresTA P rx ex 
           (with_fresh_stack_ptr n inita (L2_VARS (λp. (L2_seq (L2_guard Q) (λ_.  fa p))) nm)) 
           (with_fresh_stack_ptr n initc (L2_VARS fc nm))"
  apply (rule refines_corresTA)
  apply (simp add: L2_seq_def L2_guard_def rel_word_abs_def)
  apply (rule refines_rel_prod_with_fresh_stack_ptr )
  (* initialisers agree *)
  subgoal for s
    using init [of s] 
    by (simp add: abstract_val_def)
  (* bodies correspond once the guard Q has been passed *)
  subgoal for s s' p
    apply (rule refines_bind_guard_right)
    apply (simp only: rel_word_abs_def [symmetric])
    apply (rule corresTA_refines)
     apply (rule f)
    apply simp
    done
  done
end

(*
 * Word abstraction of the fresh-stack-pointer combinators in the
 * typ_heap_typing locale. All three rules have the same shape: the body
 * fc p is abstracted to fa p under precondition Q, the initialiser initc
 * to inita under side condition P, and the abstract body is prefixed with
 * a guard asserting Q. They differ only in the combinator and in the
 * refinement rule used for it.
 *)
context typ_heap_typing
begin

(* Rule for guard_with_fresh_stack_ptr. *)
lemma corresTA_guard_with_fresh_stack_ptr[word_abs]: 
  assumes f[simplified THIN_def, rule_format]: "PROP THIN (p. corresTA Q rx ex (fa p) (fc p))"
  assumes init: "s. abstract_val (P s) (inita s) id (initc s)"
  shows "corresTA P rx ex 
           (guard_with_fresh_stack_ptr n inita (L2_VARS (λp. (L2_seq (L2_guard Q) (λ_.  fa p))) nm)) 
           (guard_with_fresh_stack_ptr n initc (L2_VARS fc nm))"
  apply (rule refines_corresTA)
  apply (simp add: L2_seq_def L2_guard_def rel_word_abs_def)
  apply (rule refines_rel_xval_guard_with_fresh_stack_ptr )
  subgoal for s
    using init [of s] 
    by (simp add: abstract_val_def)
  subgoal for s s' p
    apply (rule refines_bind_guard_right)
    apply (simp only: rel_word_abs_def [symmetric])
    apply (rule corresTA_refines)
     apply (rule f)
    apply simp
    done
  done

(* Rule for assume_with_fresh_stack_ptr. *)
lemma corresTA_assume_with_fresh_stack_ptr[word_abs]: 
  assumes f[simplified THIN_def, rule_format]: "PROP THIN (p. corresTA Q rx ex (fa p) (fc p))"
  assumes init: "s. abstract_val (P s) (inita s) id (initc s)"
  shows "corresTA P rx ex 
           (assume_with_fresh_stack_ptr n inita (L2_VARS (λp. (L2_seq (L2_guard Q) (λ_.  fa p))) nm)) 
           (assume_with_fresh_stack_ptr n initc (L2_VARS fc nm))"
  apply (rule refines_corresTA)
  apply (simp add: L2_seq_def L2_guard_def rel_word_abs_def)
  apply (rule refines_rel_xval_assume_with_fresh_stack_ptr)
  subgoal for s
    using init [of s] 
    by (simp add: abstract_val_def)
  subgoal for s s' p
    apply (rule refines_bind_guard_right)
    apply (simp only: rel_word_abs_def [symmetric])
    apply (rule corresTA_refines)
     apply (rule f)
    apply simp
    done
  done

(* Rule for with_fresh_stack_ptr. *)
lemma corresTA_with_fresh_stack_ptr[word_abs]: 
  assumes f[simplified THIN_def, rule_format]: "PROP THIN (p. corresTA Q rx ex (fa p) (fc p))"
  assumes init: "s. abstract_val (P s) (inita s) id (initc s)"
  shows "corresTA P rx ex 
           (with_fresh_stack_ptr n inita (L2_VARS (λp. (L2_seq (L2_guard Q) (λ_.  fa p))) nm)) 
           (with_fresh_stack_ptr n initc (L2_VARS fc nm))"
  apply (rule refines_corresTA)
  apply (simp add: L2_seq_def L2_guard_def rel_word_abs_def)
  apply (rule refines_rel_xval_with_fresh_stack_ptr)
  subgoal for s
    using init [of s] 
    by (simp add: abstract_val_def)
  subgoal for s s' p
    apply (rule refines_bind_guard_right)
    apply (simp only: rel_word_abs_def [symmetric])
    apply (rule corresTA_refines)
     apply (rule f)
    apply simp
    done
  done
end


(*
 * Abstraction of a combination (x' and y') of two argument predicates:
 * both parts are abstracted with id under the same side condition P.
 *)
lemma abstract_val_call_L1_args:
  "abstract_val P x id x'  abstract_val P y id y' 
   abstract_val P (x and y) id (x' and y')"
  by (simp add: basic_abstract_defs)

(*
 * Abstraction of a single argument predicate "f s = x'": a variable x
 * related to x' by abs_var with id may be substituted for x'.
 *)
lemma abstract_val_call_L1_arg:
  "abs_var x id x'  abstract_val P (λs. f s = x) id (λs. f s = x')"
  by (simp add: basic_abstract_defs)

(* Variable abstraction *)

(*
 * An abstracted variable (abs_var a f a') is an abstract value for a' via
 * f with trivial side condition.
 *)
lemma abstract_val_abs_var [consumes 1]:
  " abs_var a f a'   abstract_val True a f a'"
  by (clarsimp simp: fun_upd_def basic_abstract_defs split: if_splits)

(*
 * Converts an abstracted variable back to its concrete representation:
 * when a abstracts a' via A, the concretised value C a equals a' (via id)
 * under the side condition PC a supplied by valid_typ_abs_fn.
 *)
lemma abstract_val_abs_var_concretise [consumes 1]:
  " abs_var a A a'; introduce_typ_abs_fn A; valid_typ_abs_fn PA PC A (C :: 'a  'c)  
       abstract_val (PC a) (C a) id a'"
  by (clarsimp simp: fun_upd_def basic_abstract_defs split: if_splits)

(*
 * For a variable that is not abstracted (abs_var with id), any abstraction
 * function A can be met by applying A explicitly to the variable.
 *)
lemma abstract_val_abs_var_give_up [consumes 1]:
  " abs_var a id a'   abstract_val True (A a) A a'"
  by (clarsimp simp: fun_upd_def basic_abstract_defs split: if_splits)


(*
 * A variable abstracted as a signed integer (sint) can be used where the
 * unat view of the word is needed: nat a equals unat a' under a side
 * condition relating 0 and a.
 *)
lemma abstract_val_abs_var_sint_unat [consumes 1]:
  " abs_var a sint a'   abstract_val (0  a) (nat a) id (unat a')"
  apply (simp add: basic_abstract_defs )
  by (metis nat_uint_eq signed_0 sint_eq_uint word_msb_0 word_sle_eq word_sle_msb_le)

(*
 * A variable abstracted via uint can be used where the unat view of the
 * word is needed: nat a equals unat a' without any side condition.
 *)
lemma abstract_val_abs_var_uint_unat [consumes 1]:
 " abs_var a uint a'   abstract_val True (nat a) id (unat a')"
  by (simp add: basic_abstract_defs)

(* abs_var with the identity abstraction is plain equality. *)
lemma abs_var_id: "(abs_var a id a') = (a' = a)"
  by (auto simp add: abs_var_def abstract_val_def)

(* Every value abstracts itself via id, under any side condition P. *)
lemma abstract_val_id: "abstract_val P a id a"
  by (simp add: abstract_val_def)
(* Instance of abstract_val_id for unit ptr values with side condition True. *)
lemmas abstract_val_id_unit_ptr = abstract_val_id [where a= "a::unit ptr" and P = True] for a

(* abstract_val_id specialised to the trivial side condition True. *)
lemma abstract_val_id_True: "abstract_val True a id a"
  by (rule abstract_val_id)

(* Instance of abs_var_id for unit ptr values. *)
lemmas abs_var_id_unit_ptr = abs_var_id [where a= "a::unit ptr"] for a

(* Misc *)

(*
 * Pre-computed comparisons between the bit widths of the 8/16/32/64-bit
 * word types, the fact that a signed word type has the same length as its
 * underlying type, and reflexivity of length equality. Registered as L2opt
 * rules.
 *)
lemma len_of_word_comparisons [L2opt]:
  "len_of TYPE(64)  len_of TYPE(64)"
  "len_of TYPE(32)  len_of TYPE(64)"
  "len_of TYPE(16)  len_of TYPE(64)"
  "len_of TYPE( 8)  len_of TYPE(64)"
  "len_of TYPE(32)  len_of TYPE(32)"
  "len_of TYPE(16)  len_of TYPE(32)"
  "len_of TYPE( 8)  len_of TYPE(32)"
  "len_of TYPE(16)  len_of TYPE(16)"
  "len_of TYPE( 8)  len_of TYPE(16)"
  "len_of TYPE( 8)  len_of TYPE( 8)"

  "len_of TYPE(32) < len_of TYPE(64)"
  "len_of TYPE(16) < len_of TYPE(64)"
  "len_of TYPE( 8) < len_of TYPE(64)"
  "len_of TYPE(16) < len_of TYPE(32)"
  "len_of TYPE( 8) < len_of TYPE(32)"
  "len_of TYPE( 8) < len_of TYPE(16)"

  "len_of TYPE('a::len signed) = len_of TYPE('a)"
  "(len_of TYPE('a) = len_of TYPE('a)) = True"
  by auto

(*
 * Signed truncation to n bits is the identity on non-negative integers
 * below 2^n. Proved by induction on n, generalising over i.
 *)
lemma sbintrunc_eq: "0  i  i < 2^n  sbintrunc n i = i"
  by (induction n arbitrary: i) auto

(* uint of an unsigned cast is the original value modulo 2^(target length). *)
lemma uint_ucast:
  "uint (ucast x :: ('a :: len) word) = uint x mod 2 ^ LENGTH('a)"
  unfolding uint_nat unat_ucast by (simp add: zmod_int)

(*
 * uint of a signed cast is the signed value of the source modulo
 * 2^(target length).
 *)
lemma uint_scast':
  "uint (SCAST('a::len  'b::len) c) = sint c mod 2^LENGTH('b)"
  by (metis Word.of_int_sint word_uint.eq_norm)

(*
 * Same statement as uint_ucast, written with the UCAST notation that
 * names both source and target type.
 *)
lemma uint_ucast':
  "uint (UCAST('a::len  'b::len) c) = uint c mod 2^LENGTH('b)"
  by (metis Word.of_int_uint word_uint.eq_norm)

(*
 * Unsigned-casting to a strictly wider type yields a word whose signed
 * value is the unsigned value of the source.
 *)
lemma sint_ucast':
  "LENGTH('a) < LENGTH('b)  sint (UCAST('a::len  'b::len) c) = uint c"
  by (smt bintrunc_bintrunc_ge less_or_eq_imp_le signed_take_bit_eq uint_sint unsigned_ucast_eq 
    unsigned_word_eqI)

(*
 * A signed cast applied to a strictly widening unsigned cast collapses to
 * a single unsigned cast, under the stated length condition relating 'b
 * and 'c. Proved by comparing uint of both sides using the cast lemmas
 * above.
 *)
lemma scast_ucast_eq_ucast [simp, L2opt]:
  "LENGTH('a::len) < LENGTH('b::len)  LENGTH('b)  LENGTH('c::len) 
    SCAST('b  'c) (UCAST('a  'b) x) = UCAST ('a  'c) x"
  unfolding word_uint_eq_iff uint_scast' uint_ucast sint_ucast' ..


(*
 * Collection of rules that collapse two nested casts (any combination of
 * scast and ucast) into a single cast, each under length conditions on the
 * three word types involved. Derived from the scast_ucast, ucast_scast,
 * scast_scast and ucast_ucast lemmas named in the proof.
 *)
lemma scast_ucast_simps [simp, L2opt]:
  " len_of TYPE('b)  len_of TYPE('a); len_of TYPE('c)  len_of TYPE('b)  
         (scast (ucast (a :: 'a::len word) :: 'b::len word) :: 'c::len word) = ucast a"
  " len_of TYPE('c)  len_of TYPE('a); len_of TYPE('c)  len_of TYPE('b)  
         (scast (ucast (a :: 'a::len word) :: 'b::len word) :: 'c::len word) = ucast a"
  " len_of TYPE('a)  len_of TYPE('b); len_of TYPE('c)  len_of TYPE('b)  
         (scast (ucast (a :: 'a::len word) :: 'b::len word) :: 'c::len word) = ucast a"
  " len_of TYPE('a)  len_of TYPE('b)  
     (scast (scast (a :: 'a::len word) :: 'b::len word) :: 'c::len word) = scast a"
  " len_of TYPE('b)  len_of TYPE('a); len_of TYPE('c)  len_of TYPE('b)  
            (ucast (scast (a :: 'a::len word) :: 'b::len word) :: 'c::len word) = scast a"
  " len_of TYPE('c)  len_of TYPE('a); len_of TYPE('c)  len_of TYPE('b)  
     (ucast (scast (a :: 'a::len word) :: 'b::len word) :: 'c::len word) = scast a"
  " len_of TYPE('a)  len_of TYPE('b); len_of TYPE('c)  len_of TYPE('b)  
     (ucast (scast (a :: 'a::len word) :: 'b::len word) :: 'c::len word) = scast a"
  " len_of TYPE('c)  len_of TYPE('b)  
        (ucast (ucast (a :: 'a::len word) :: 'b::len word) :: 'c::len word) = ucast a"
  " len_of TYPE('a)  len_of TYPE('b)  
     (ucast (ucast (a :: 'a::len word) :: 'b::len word) :: 'c::len word) = ucast a"
  " len_of TYPE('a)  len_of TYPE('b)  
            (scast (scast (a :: 'a::len word) :: 'b::len word) :: 'c::len word) = scast a"
  by (auto simp: is_up is_down
      scast_ucast_1 scast_ucast_3 scast_ucast_4
      ucast_scast_1 ucast_scast_3 ucast_scast_4
      scast_scast_a scast_scast_b
      ucast_ucast_a ucast_ucast_b)

(* Register existing word lemmas as L2opt rules. *)
declare len_signed [L2opt]

lemmas [L2opt] = zero_sle_ucast_up

(*
 * Expresses the signed comparison "0 <=s ucast b" (cast to the signed word
 * type of the same length) as a bound on uint b in terms of WORD_MAX, so
 * that the constant is not unfolded.
 *)
lemma zero_sle_ucast_WORD_MAX [L2opt]:
  "(0 <=s ((ucast (b::('a::len) word)) :: ('a::len) signed word))
                = (uint b  WORD_MAX (TYPE('a)))"
  by (clarsimp simp: WORD_MAX_def zero_sle_ucast)

(* Up/down-cast predicates and value lemmas for up-casts, as L2opt rules. *)
lemmas [L2opt] =
    is_up is_down unat_ucast_upcast sint_ucast_eq_uint

(* Distribution of down-casts over addition, subtraction and multiplication. *)
lemmas [L2opt] =
    ucast_down_add scast_down_add
    ucast_down_minus scast_down_minus
    ucast_down_mult scast_down_mult

(*
 * Trivial rule concluding an equation from itself.
 * NOTE(review): presumably a helper for rule application in tactics; its
 * use is not visible in this part of the file.
 *)
lemma eq_trivI: "x = y  x = y"
  by simp
(*
 * Setup word abstraction rules.
 *)


(* Common word abstraction rules. *)

(*
 * corresTA rules for the L2 program constructs, registered in the word_abs
 * rule collection. For the call rules, the return-type converting
 * variants (primed names) are listed before the direct ones.
 *)
lemmas [word_abs] =
  corresTA_L2_gets
  corresTA_L2_modify
  corresTA_L2_throw
  corresTA_L2_skip
  corresTA_L2_fail
  corresTA_L2_seq
  corresTA_L2_seq_unit
  corresTA_L2_catch
  corresTA_L2_try'
  (*corresTA_return*)
  corresTA_L2_while
  corresTA_L2_guard
  corresTA_L2_guarded_simple
  corresTA_L2_spec
  corresTA_L2_assume
  corresTA_L2_condition
  corresTA_L2_unknown
  corresTA_case_prod
  corresTA_L2_call_exec_concrete'
  corresTA_L2_call_exec_concrete
  corresTA_L2_call_exec_abstract'
  corresTA_L2_call_exec_abstract
  corresTA_L2_call'
  corresTA_L2_call
  corresTA_call_L1

(* abstract_val rules for value-level constructs, registered in word_abs. *)
lemmas [word_abs] =
  abstract_val_tuple
  abstract_val_Inl
  abstract_val_Inr
  abstract_val_conj
  abstract_val_disj
  abstract_val_case_prod
  abstract_val_trivial
  abstract_val_of_int
  abstract_val_of_nat
  abstract_val_call_L1_args

(*
 * Rules for discharging abstract_val goals from abs_var facts.
 * The list follows the convention that later rules are preferred.
 *)
lemmas abs_var_rules = 
  abstract_val_call_L1_arg
  abstract_val_abs_var_sint_unat
  abstract_val_abs_var_uint_unat
  abstract_val_abs_var_give_up
  abstract_val_abs_var_concretise
  abstract_val_abs_var

(*
 * Base set of valid_typ_abs_fn facts (identity on c_type values, bool and
 * c_exntype; tuples, sums, unit, sint, unat) together with the word length
 * comparisons. Also registered in word_abs.
 *)
lemmas word_abs_base [word_abs] =
  valid_typ_abs_fn_id [where 'a="'a::c_type"]
  valid_typ_abs_fn_id [where 'a="bool"]
  valid_typ_abs_fn_id [where 'a="'gx c_exntype"]
  valid_typ_abs_fn_tuple_split
  valid_typ_abs_fn_tuple
  valid_typ_abs_fn_sum
  valid_typ_abs_fn_unit
  valid_typ_abs_fn_sint
  valid_typ_abs_fn_unat
  len_of_word_comparisons

(*
 * Signed word abstraction rules: 'a sword → int
 *
 * Collects the rules for signed operations and signed casts, the unwrap
 * rule instantiated at sint, and the introduce_typ_abs_fn facts enabling
 * sint for the 64-, 32-, 16- and 8-bit signed word types.
 *)

lemmas word_abs_sword =
  abstract_val_signed_ops
  abstract_val_scast
  abstract_val_scast_upcast
  abstract_val_scast_downcast
  abstract_val_unwrap [where f=sint]
  introduce_typ_abs_fn [where f="sint :: (sword64  int)"]
  introduce_typ_abs_fn [where f="sint :: (sword32  int)"]
  introduce_typ_abs_fn [where f="sint :: (sword16  int)"]
  introduce_typ_abs_fn [where f="sint :: (sword8  int)"]

(*
 * Unsigned word abstraction rules: 'a word → nat
 *
 * Collects the rules for unsigned operations, uint and unsigned casts, the
 * unwrap rule instantiated at unat, and the introduce_typ_abs_fn facts
 * enabling unat for the 64-, 32-, 16- and 8-bit word types.
 *)

lemmas word_abs_word =
  abstract_val_unsigned_ops
  abstract_val_uint
  abstract_val_ucast
  abstract_val_ucast_upcast
  abstract_val_ucast_downcast
  abstract_val_unwrap [where f=unat]
  introduce_typ_abs_fn [where f="unat :: (word64  nat)"]
  introduce_typ_abs_fn [where f="unat :: (word32  nat)"]
  introduce_typ_abs_fn [where f="unat :: (word16  nat)"]
  introduce_typ_abs_fn [where f="unat :: (word8  nat)"]

(*
 * Default (identity) abstraction rules: 'a → 'a
 *
 * Enables id as the abstraction function for c_type values, bool,
 * c_exntype and unit, plus the structural rules for tuples and sums.
 *)
lemmas word_abs_default =
  introduce_typ_abs_fn [where f="id :: ('a::c_type  'a)"]
  introduce_typ_abs_fn [where f="id :: (bool  bool)"]
  introduce_typ_abs_fn [where f="id :: ('gx c_exntype  'gx c_exntype)"]
  introduce_typ_abs_fn [where f="id :: (unit  unit)"]
  introduce_typ_abs_fn_tuple
  introduce_typ_abs_fn_sum

(*
 * Transfers an upper bound by a numeral from a non-negative integer n to
 * nat n.
 *)
lemma int_bounds_to_nat_boundsF: "(n::int) < numeral B  0  n  nat n < numeral B"
  by (simp add: nat_less_iff)

(* A non-negative integer below 1 converts to the natural number 0. *)
lemma int_bounds_one_to_nat: "(n::int) < 1  0  n  nat n = 0"
  by (simp add: nat_less_iff)

(* The identity on pairs, written as the component-wise map of identities. *)
lemma id_map_prod_unfold: "id = map_prod id id"
  by (simp add: prod.map_id0)

(* The identity on pairs, written as an explicit lambda over the components. *)
lemma id_tuple_unfold: "id = (λ(x, y). (id x, id y))"
  by simp

end