(* Theory CorresXF *)

(*
 * Copyright 2020, Data61, CSIRO (ABN 41 687 119 230)
 * Copyright (c) 2022 Apple Inc. All rights reserved.
 *
 * SPDX-License-Identifier: BSD-2-Clause
 *)

(*
 * A stronger version of the "corres" framework, allowing return
 * relationships to reference state data.
 *)

theory CorresXF
imports
  CCorresE
begin

(*
 * Refinement with return extraction on the concrete side:
 *
 * For any step on the concrete side, there is an equivalent step on
 * the abstract side.
 *
 * If the abstract step fails, we don't need refinement to hold.
 *
 * Arguments:
 *   st  -- maps a concrete state to the corresponding abstract state
 *   xf  -- computes the abstract result from the concrete result r' and
 *          the concrete final state t'
 *   P   -- precondition on the concrete initial state
 *   M   -- abstract program (run on st s)
 *   M'  -- concrete program (run on s)
 *
 * Besides matching outcomes, the definition also requires the concrete
 * program to succeed whenever P holds and the abstract program succeeds.
 *)

definition "corresXF_simple st xf P M M' 
  s. (P s  succeeds M (st s))  (r' t'. reaches M' s r' t' 
        reaches M (st s) (xf r' t') (st t'))  succeeds M' s"

(*
 * A definition better suited to dealing with monads with exceptions.
 *
 * Compared to corresXF_simple, the extraction function is split in two:
 * ret_xf is applied to normal results (Result) and ex_xf to exceptional
 * results (Exn) of the concrete program C. A is the abstract program,
 * st the state abstraction and P the precondition on the concrete state.
 * As in corresXF_simple, success of C is part of the obligation.
 *)
definition "corresXF st ret_xf ex_xf P A C 
    s. P s  succeeds A (st s) 
        (r t. reaches C s r t  
          (case r of
             Exn r     reaches A (st s) (Exn (ex_xf r t)) (st t)
           | Result r  reaches A (st s) (Result (ret_xf r t)) (st t)))
         succeeds C s"


(*
 * Relation between a concrete outcome (r, t) and an abstract outcome
 * (r', t'), used to express corresXF in terms of refines:
 *   - the abstract state t' is the image st t of the concrete state;
 *   - the results are related by rel_xval, with exceptions mapped by ex_xf
 *     and normal values by ret_xf (both evaluated in the concrete state t);
 *   - Q is an additional condition on the concrete result and state.
 *)
definition "rel_XF st ret_xf ex_xf Q  λ(r, t) (r', t'). 
  t' = st t  
  rel_xval (λe e'. e' = ex_xf e t) (λv v'. v' = ret_xf v t) r r' 
  Q r t"

(*
 * Characterisation of corresXF via refines: for every concrete state s
 * satisfying P, C refines A (started in s and st s respectively) with
 * respect to rel_XF with the trivial condition (λ_ _. True).
 * Both directions are shown by unfolding the definitions.
 *)
lemma corresXF_refines_iff: 
  "corresXF st ret_xf ex_xf P A C 
      (s. P s  refines C A s (st s) (rel_XF st ret_xf ex_xf (λ_ _. True)))"
  apply standard
  (* first direction produced by "standard" *)
  subgoal
    apply (clarsimp simp add: corresXF_def refines_def_old rel_XF_def rel_xval.simps split: xval_splits)
    by (metis Exn_def default_option_def exception_or_result_cases not_Some_eq)
  (* converse direction *)
  subgoal
    apply (fastforce simp add: corresXF_def refines_def_old rel_XF_def rel_xval.simps split: xval_splits)
    done
  done

(*
 * Stronger variant of corresXF with a postcondition:
 * in addition to the matching abstract outcome, every concrete outcome
 * (r, t) reachable from the initial state s must satisfy Q s r t.
 *)
definition "corresXF_post st ret_xf ex_xf P Q A C 
    s. P s  succeeds A (st s) 
        (r t. reaches C s r t  Q s r t 
          (case r of
             Exn r     reaches A (st s) (Exn (ex_xf r t)) (st t)
           | Result r  reaches A (st s) (Result (ret_xf r t)) (st t)))
         succeeds C s"

(*
 * Characterisation of corresXF_post via refines, analogous to
 * corresXF_refines_iff: the postcondition Q, partially applied to the
 * initial state s, becomes the extra condition of rel_XF.
 *)
lemma corresXF_post_refines_iff: 
  "corresXF_post st ret_xf ex_xf P Q A C 
      (s. P s  refines C A s (st s) (rel_XF st ret_xf ex_xf (Q s)))"
  apply standard
  (* first direction produced by "standard" *)
  subgoal
    apply (clarsimp simp add: corresXF_post_def refines_def_old rel_XF_def rel_xval.simps split: xval_splits)
    by (metis Exn_def default_option_def exception_or_result_cases not_Some_eq)
  (* converse direction *)
  subgoal
    apply (fastforce simp add: corresXF_post_def refines_def_old rel_XF_def rel_xval.simps split: xval_splits)
    done
  done

(* Weakening: a corresXF_post statement yields the plain corresXF
   statement with the postcondition Q dropped. *)
lemma corresXF_post_to_corresXF: 
  "corresXF_post st ret_xf ex_xf P Q A C  corresXF st ret_xf ex_xf P A C"
  by (auto simp add: corresXF_def corresXF_post_def)

(* corresXF coincides with corresXF_post instantiated with the trivial
   postcondition (λ_ _ _. True). *)
lemma corresXF_corres_XF_post_conv: 
  "corresXF st ret_xf ex_xf P A C = corresXF_post st ret_xf ex_xf P (λ_ _ _. True) A C"
  by (auto simp add: corresXF_def corresXF_post_def)


(* corresXF can be defined in terms of corresXF_simple: use a single
   extraction function that dispatches on the result, applying ex_state
   to Exn values and ret_state to Result values. *)
lemma corresXF_simple_corresXF:
  "(corresXF_simple st
       (λx s. case x of
           Exn r  Exn (ex_state r s)
         | Result r  (Result (ret_state r s))) P M M')
  = (corresXF st ret_state ex_state P M M')"
 unfolding corresXF_simple_def corresXF_def
 by (auto split: xval_splits)


(*
 * Introduction rule for corresXF_simple. Two obligations, both under the
 * assumptions P s' and succeeds M (st s'):
 *   1. every outcome (r', t') of M' from s' has the matching outcome
 *      (xf r' t', st t') of M from st s';
 *   2. M' succeeds from s'.
 *)
lemma corresXF_simpleI: "
    s' t' r'. P s'; succeeds M (st s'); reaches M' s' r' t'
         reaches M (st s') (xf r' t') (st t');
    s'. P s'; succeeds M (st s')   succeeds M' s'
    corresXF_simple st xf P M M'"
  apply atomize
  apply (clarsimp simp: corresXF_simple_def)
  done

(*
 * Introduction rule for corresXF. Three obligations, each under the
 * assumptions P s' and succeeds M (st s'):
 *   1. Result outcomes of M' are matched by M via ret_state;
 *   2. Exn outcomes of M' are matched by M via ex_state;
 *   3. M' succeeds from s'.
 *)
lemma corresXF_I: "
    s' t' r'. P s'; succeeds M (st s'); reaches M' s' (Result r') t'
         reaches M (st s') (Result (ret_state r' t'))  (st t');
    s' t' r'. P s'; succeeds M (st s'); reaches M' s' (Exn r')  t'
         reaches M (st s') (Exn (ex_state r' t'))  (st t');
    s'. P s'; succeeds M (st s')   succeeds M' s'
    corresXF st ret_state ex_state P M M'"
  apply atomize
  apply (clarsimp simp: corresXF_def)
  subgoal for s r t
    (* instantiate each of the three atomized premises at the initial state s *)
    apply (erule_tac x=s in allE, erule (1) impE)
    apply (erule_tac x=s in allE, erule (1) impE)
    apply (erule_tac x=s in allE, erule (1) impE)
    (* case distinction on the kind of concrete result *)
    apply (clarsimp split: xval_splits)
    apply auto
    done
  done

(*
 * Pairs of complete-lattice elements form a ccpo when ordered
 * component-wise by the reversed order (≥) with the component-wise Inf
 * (prod_lub Inf Inf) as least upper bound. Follows from ccpo_rel_prodI
 * and ccpo_Inf.
 *)
lemma ccpo_prod_gfp_gfp:
  "class.ccpo
    (prod_lub Inf Inf :: (('a::complete_lattice * 'b :: complete_lattice) set  _))
    (rel_prod (≥) (≥)) (mk_less (rel_prod (≥) (≥)))"
  by (rule ccpo_rel_prodI ccpo_Inf)+

(* Admissibility, with respect to Inf and the reversed order (≥), of the
   predicate on A built from a fixed element x; proved directly from
   ccpo.admissible_def. *)
lemma admissible_mem: "ccpo.admissible Inf (≥) (λA. x  A)"
  by (auto simp: ccpo.admissible_def)

(*
 * corresXF, viewed as a predicate on its abstract program A, is
 * admissible with respect to Inf and the reversed order (≥).
 *
 * Proof outline: unfold corresXF and split it (admissible_all,
 * admissible_conj, admissible_imp) into its components; each component is
 * then proved on the underlying post_state representation via transfer.
 *)
lemma admissible_nondet_ord_corresXF:
  "ccpo.admissible Inf (≥) (λA. corresXF st R E P A C)"
  unfolding corresXF_def imp_conjL imp_conjR
  apply (intro admissible_all  admissible_conj admissible_imp)
  (* first component; the xval split yields two cases, handled identically *)
  subgoal for s
  apply (rule ccpo.admissibleI)
    apply (clarsimp simp add: ccpo.admissible_def chain_def 
        succeeds_def reaches_def split: prod.splits xval_splits)
    apply (intro conjI allI impI)
    subgoal
      apply transfer
      by (auto simp add: Inf_post_state_def top_post_state_def)
         (metis outcomes.simps(2) post_state.simps(3))
    subgoal
      apply transfer
      by (auto simp add: Inf_post_state_def top_post_state_def)
         (metis outcomes.simps(2) post_state.simps(3))
    done
  (* remaining component *)
  subgoal
    apply (rule ccpo.admissibleI)
    apply (clarsimp simp add: ccpo.admissible_def chain_def 
        succeeds_def reaches_def split: prod.splits xval_splits)
    apply transfer
    apply (auto simp add: Inf_post_state_def top_post_state_def image_def vimage_def 
        split: if_split_asm)
    done
  done



(* Base case matching the admissibility lemma: the corresXF statement for
   an arbitrary concrete program C, discharged by unfolding corresXF_def
   (presumably with the top element as abstract program, as the lemma name
   indicates -- confirm against the original statement). *)
lemma corresXF_top: "corresXF st ret_xf ex_xf P  C"
  by (auto simp add: corresXF_def)

(*
 * corresXF_post, viewed as a predicate on its abstract program A, is
 * admissible with respect to Inf and the reversed order (≥).
 *
 * Same proof structure as admissible_nondet_ord_corresXF: unfold the
 * definition, split it into components, and discharge each component on
 * the post_state representation via transfer. The extra postcondition Q
 * produces additional cases (four instead of two) in the first component.
 *)
lemma admissible_nondet_ord_corresXF_post:
  "ccpo.admissible Inf (≥) (λA. corresXF_post st R E P Q A C)"
  unfolding corresXF_post_def imp_conjL imp_conjR
  apply (intro admissible_all  admissible_conj admissible_imp)
  (* first component: four cases after splitting *)
  subgoal
  apply (rule ccpo.admissibleI)
    apply (clarsimp simp add: ccpo.admissible_def chain_def 
        succeeds_def reaches_def split: prod.splits xval_splits)
    apply (intro conjI allI impI)
    subgoal
      apply transfer
      apply (clarsimp simp add: Inf_post_state_def top_post_state_def)
      apply (metis (no_types, lifting) INF_top_conv(1) Inf_post_state_def top_post_state_def)
      done
    subgoal
      apply transfer
      by (auto simp add: Inf_post_state_def top_post_state_def)
         (metis outcomes.simps(2) post_state.simps(3))
   subgoal
     apply transfer
     apply (clarsimp simp add: Inf_post_state_def top_post_state_def)
     apply (metis (no_types, lifting) INF_top_conv(1) Inf_post_state_def top_post_state_def)
     done
   subgoal
     apply transfer
     apply (clarsimp simp add: Inf_post_state_def top_post_state_def image_def vimage_def 
         split: if_split_asm)
     by (metis outcomes.simps(2) post_state.simps(3))
    done
  (* remaining component *)
  subgoal
    apply (rule ccpo.admissibleI)
    apply (clarsimp simp add: ccpo.admissible_def chain_def 
        succeeds_def reaches_def split: prod.splits xval_splits)
    apply transfer
    apply (auto simp add: Inf_post_state_def top_post_state_def image_def vimage_def 
        split: if_split_asm)
    done
  done



(* Counterpart of corresXF_top for corresXF_post: the statement for an
   arbitrary concrete program C, discharged by unfolding corresXF_post_def
   (presumably with the top element as abstract program -- confirm against
   the original statement). *)
lemma corresXF_post_top: "corresXF_post st ret_xf ex_xf P Q  C"
  by (auto simp add: corresXF_post_def)

(* When proving corresXF, one may assume that the precondition P holds for
   some concrete state s' (with s naming its abstraction st s'). *)
lemma corresXF_assume_pre:
  " s s'.  P s'; s = st s'   corresXF st xf_normal xf_exception P L R   corresXF st xf_normal xf_exception P L R"
  apply atomize
  apply (clarsimp simp: corresXF_def)
  apply force
  done

(* Variant of corresXF_assume_pre that additionally fixes the concrete
   initial state: the premise only needs corresXF for the precondition
   restricted to the single state s' (λs. s = s' combined with P s). *)
lemma corresXF_assume_fix_pre:
  " s s'.  P s'; s = st s'   corresXF st xf_normal xf_exception (λs. s = s'  P s) L R   corresXF st xf_normal xf_exception P L R"
  apply atomize
  apply (clarsimp simp: corresXF_def)
  done

(* Precondition strengthening: a corresXF statement for precondition Q
   carries over to any precondition P that entails Q. *)
lemma corresXF_guard_imp:
  " corresXF st xf_normal xf_exception Q f g; s. P s  Q s 
       corresXF st xf_normal xf_exception P f g"
  apply (clarsimp simp: corresXF_def)
  done

(* return b corresponds to return a provided the normal-value extraction
   of b yields a in every state satisfying P. *)
lemma corresXF_return:
  " s.  P s   xf_normal b s = a  
     corresXF st xf_normal xf_exception P (return a) (return b)"
  apply (clarsimp simp: corresXF_def )
  done

(* gets g corresponds to gets f provided that, under P, extracting from
   the concrete value g s gives the abstract value f (st s). *)
lemma corresXF_gets:
  " s. P s  ret (g s) s = f (st s)  
     corresXF st ret ex P (gets f) (gets g)"
  apply (clarsimp simp: corresXF_def)
  done

(* A guard G may be inserted in front of the abstract program A. The
   second premise relates the new precondition P, the guard on the
   abstracted state G (st s), and the original precondition Q. *)
lemma corresXF_insert_guard:
  " corresXF st ret ex Q A C; s.  P s   G (st s)  Q s   
        corresXF st ret ex P (guard G >>= (λ_. A)) C "
  apply (auto  simp: corresXF_def succeeds_guard succeeds_bind reaches_bind reaches_guard)
  done

(* Executing a leading guard on the abstract side: to show corresXF for
   guard G >>= A it suffices to show it for A (), with G (st s) added to
   the precondition. *)
lemma corresXF_exec_abs_guard:
  "corresXF st ret_xf ex_xf (λs. P s  G (st s)) (A ()) C  corresXF st ret_xf ex_xf P (guard G >>= A) C"
  apply (auto simp: corresXF_def succeeds_guard succeeds_bind reaches_bind reaches_guard)
  done

(* Elimination form of corresXF_simple: an outcome (r', s') of the
   concrete program B yields the outcome (xf r' s', st s') of the abstract
   program A, given that A succeeds and P holds initially. *)
lemma corresXF_simple_exec:
  " corresXF_simple st xf P A B; reaches B s r' s'; succeeds A (st s); P s 
       reaches A (st s) (xf r' s') (st s')"
  apply (fastforce simp: corresXF_simple_def)
  done

(* If the concrete program B does not succeed from s (and P s holds), then
   the abstract program A does not succeed from st s either. *)
lemma corresXF_simple_fail:
  " corresXF_simple st xf P A B; ¬ succeeds B s; P s 
       ¬ succeeds A (st s)"
  apply (fastforce simp: corresXF_simple_def)
  done

(* Contrapositive reading of corresXF_simple_fail: success of the abstract
   program A from st s (with P s) implies success of the concrete B. *)
lemma corresXF_simple_no_fail:
  " corresXF_simple st xf P A B; succeeds A (st s); P s 
       succeeds B s"
  apply (fastforce simp: corresXF_simple_def)
  done

(* Elimination form of corresXF for normal results: a Result outcome of
   the concrete program B is matched by the abstract program A via ret. *)
lemma corresXF_exec_normal:
  " corresXF st ret ex P A B; reaches B s (Result r') s'; succeeds A (st s); P s 
       reaches A (st s) (Result (ret r' s')) (st s')"
  by (auto simp add: corresXF_def split: xval_splits)

(* Elimination form of corresXF for exceptional results: an Exn outcome of
   the concrete program B is matched by the abstract program A via ex. *)
lemma corresXF_exec_except:
  " corresXF st ret ex P A B; reaches B s (Exn r') s'; succeeds A (st s); P s 
       reaches A (st s) (Exn (ex r' s')) (st s')"
  by (auto simp add: corresXF_def split: xval_splits)

(* Failure propagation for corresXF: if the concrete program B does not
   succeed from s (and P s holds), the abstract A does not succeed from
   st s. *)
lemma corresXF_exec_fail:
  " corresXF st ret ex P A B; ¬ succeeds B s; P s 
       ¬ succeeds A (st s)"
  by (auto simp add: corresXF_def split: xval_splits)

lemma corresXF_intermediate:
    " corresXF st ret_xf ex_xf P A' C;
         corresXF id (λr s. r) (λr s. r) (λs. x. s = st x  P x) A A'  
        corresXF st ret_xf ex_xf P A C"
  apply (clarsimp simp: corresXF_def split: xval_splits)
  apply fast
  done

(*
 * Sequential composition (bind) rule for corresXF.
 *
 * Premises:
 *   - L' corresponds to L (value extraction V, precondition P);
 *   - for all x, y: R' y corresponds to R x under precondition P' x y;
 *   - a partial-correctness triple for the concrete L' from states
 *     satisfying Q, establishing P' (V v t) v t for Result outcomes;
 *   - Q entails P.
 * Conclusion: L' >>= R' corresponds to L >>= R under precondition Q.
 *
 * Proved via the refines formulation (corresXF_refines_iff) and
 * refines_bind_bind_exn.
 *)
lemma corresXF_join:
  " corresXF st V E P L L'; x y. corresXF st V' E (P' x y) (R x) (R' y); 
    s. Q s  L'  s ?⦃ λr t. case r of Exn _   | Result v  P' (V v t) v t ; 
    s. Q s  P s  
    corresXF st V' E Q (L >>= R) (L' >>= R')"
  apply (clarsimp simp add: corresXF_refines_iff split: xval_splits)
  apply (rule refines_bind_bind_exn[where Q="rel_XF st V E (λ_ _. True) 
    (λ(r, t) x. case r of Exn _   | Result v  P' (V v t) v t)"])
  (* the intermediate relation is rel_XF strengthened with the concrete-side fact *)
  subgoal
    apply (rule refines_strengthen1[where R="rel_XF st V E (λ_ _. True)"])
    by auto
  by (auto simp add: rel_XF_def)

(*
 * Variant of the bind rule for the case where concrete and abstract
 * programs share the same state space (state map λs. s) and the
 * extraction functions ignore the state (λr s. V r etc.).
 *
 * In contrast to corresXF_join, the partial-correctness premise is stated
 * for the abstract program L, and the intermediate fact is strengthened on
 * the abstract side (refines_strengthen2).
 *)
lemma corresXF_join_xf_state_independent_same_state:
  " corresXF (λs. s) (λr s. V r)  (λr s. E r) P L L'; 
    y. corresXF (λs. s) (λr s. V' r)  (λr s. E r) (P' (V y)) (R (V y)) (R' y); 
    s. Q s  L  s ?⦃ λr t. case r of Exn _   | Result v  P' v t; s. Q s  P s  
    corresXF (λs. s) (λr s. V' r) (λr s. E r) Q (L >>= R) (L' >>= R')"
  apply (clarsimp simp add: corresXF_refines_iff)
  apply (rule refines_bind_bind_exn[where
        Q="rel_XF (λs. s) (λr s. V r) (λr s. E r) (λ_ _. True) 
          (λx (r, t). case r of Exn _   | Result v  P' v t)"])
  subgoal
    apply (rule refines_strengthen2[where R="rel_XF (λs. s) (λr s. V r) (λr s. E r) (λ_ _. True)"])
    by auto
  by (auto simp add: rel_XF_def)

(*
 * Exception-handler (catch) rule for corresXF, dual to corresXF_join.
 *
 * Premises:
 *   - L' corresponds to L (exception extraction E, precondition P);
 *   - for all x, y: handler R' y corresponds to handler R x under
 *     precondition P' x y, with new exception extraction E';
 *   - a partial-correctness triple for the concrete L' from states
 *     satisfying Q, establishing P' (E r s) r s for Exn outcomes;
 *   - Q entails P.
 * Conclusion: L' <catch> R' corresponds to L <catch> R under Q.
 *)
lemma corresXF_except:
  " corresXF st V E P L L'; x y. corresXF st V E' (P' x y) (R x) (R' y); 
    s. Q s  L'  s ?⦃ λr s. case r of Exn r  P' (E r s) r s | Result _   ; 
    s. Q s  P s  
    corresXF st V E' Q ( L <catch> R) (L' <catch> R')"
  apply (clarsimp simp add: corresXF_refines_iff)
  apply (rule refines_catch [where Q="(rel_XF st V E (λ_ _. True)) 
    (λ(r, s) x. case r of Exn r  P' (E r s) r s | Result _  )"])
  subgoal 
    apply (rule refines_strengthen1[where R="rel_XF st V E (λ_ _. True)"])
    by auto
  apply (auto simp add: rel_XF_def)
  done

(* Rule for condition: both branches correspond and, under P, the abstract
   test A on st s agrees with the concrete test A' on s. Proved through the
   refines formulation using refines_condition. *)
lemma corresXF_cond:
  " corresXF st V E P L L'; corresXF st V E P R R'; s. P s  A (st s) = A' s  
    corresXF st V E P (condition A L R) (condition A' L' R')"
  apply (clarsimp simp add: corresXF_refines_iff)
  using refines_condition by metis

(* When proving refines f g s t R, one may assume that the abstract
   program g succeeds from t. *)
lemma refines_assume_succeeds: "(succeeds g t  refines f g s t R)  refines f g s t R"
  by (auto simp add: refines_def_old)

(*
 * whileLoop rule for corresXF.
 *
 * Assumptions:
 *   body_corres -- the concrete body B x corresponds to the abstract body
 *                  A y whenever the invariant P x holds and y = ret x s;
 *   cond_match  -- under P, the concrete condition C agrees with the
 *                  abstract condition C' on the extracted value and state;
 *   pred_inv    -- the concrete body re-establishes P for Result outcomes
 *                  (partial correctness), assuming P, C and success of the
 *                  abstract loop;
 *   init_match  -- the initial precondition P' x fixes y = ret x s;
 *   pred_imply  -- P' x entails the invariant P x.
 * Conclusion: the concrete loop started with x corresponds to the abstract
 * loop started with y, under precondition P' x.
 *
 * The proof goes through the refines formulation and refines_whileLoop',
 * using as loop relation rel_XF strengthened with the invariant P and with
 * success of the remaining abstract loop.
 *)
lemma corresXF_while:
  assumes body_corres: "x y. corresXF st ret ex (λs. P x s  y = ret x s) (A y) (B x)"
  and cond_match: "s r. P r s  C r s = C' (ret r s) (st s)"
  and pred_inv: "r s. P r s  C r s  succeeds (whileLoop C' A (ret r s)) (st s) 
                          B r  s ?⦃ λr s. case r of Exn _  True | Result r  P r s "
  and init_match: "s. P' x s  y = ret x s"
  and pred_imply: "s. P' x s  P x s"
shows "corresXF st ret ex (P' x) (whileLoop C' A y) (whileLoop C B x)"
  apply (clarsimp simp add: corresXF_refines_iff)
  (* success of the abstract loop may be assumed *)
  apply (rule refines_assume_succeeds)
  subgoal for s
    apply (rule refines_mono [OF _ refines_whileLoop' 
          [where R = "rel_XF st ret ex (λr s. case r of Exn _  True | Result r  P r s) 
            (λ(r, s) _. case r of Exn _  True | Result r 
              succeeds (whileLoop C' A (ret r s)) (st s))" 
          and C = C and C'=C' and B=B and B'=A and I=x and I'=y and s=s and s'="st s"]])
    subgoal by (auto simp add: rel_XF_def)
    subgoal by (auto simp add: rel_XF_def cond_match pred_imply)
    (* loop bodies: combine body_corres with the invariant from pred_inv *)
    subgoal using body_corres pred_inv
      apply (subst (asm) rel_XF_def)
      apply (clarsimp simp add: corresXF_refines_iff)
      apply (rule refines_strengthen[where R="rel_XF st ret ex (λ_ _. True)" and
        G="λr s. case r of Exn _  True | Result r  succeeds (whileLoop C' A r) s"])
      subgoal by auto
      apply assumption
      subgoal
        (* unroll the abstract loop once to obtain success of the remaining loop *)
        apply (subst (asm) (3) whileLoop_unroll)
        apply (auto simp add: succeeds_bind runs_to_partial_def_old split: xval_splits)
        done
      apply (auto simp: rel_XF_def rel_xval.simps)
      done
    subgoal
      by (auto simp add: rel_XF_def init_match pred_imply)
    subgoal
      by (auto simp add: rel_XF_def rel_xval.simps rel_exception_or_result.simps Exn_def default_option_def split: xval_splits )
    done
  done

(* To prove corresXF under P, it suffices to prove it for every fixed
   concrete initial state s' (precondition P s combined with s = s'). *)
lemma corresXF_name_pre:
  " s'. corresXF st ret ex (λs. P s  s = s') A C  
           corresXF st ret ex P A C"
  by (clarsimp simp: corresXF_def)

(* Appending a guard on the result to the abstract program (run A, check
   guard (G r), return r) preserves correspondence with the concrete
   program B. Used for the loop body in corresXF_guarded_while. *)
lemma corresXF_guarded_while_body:
  "corresXF st ret ex P A B 
      corresXF st ret ex P
             (do{ r  A; _  guard (G r); return r }) B"
  apply (clarsimp simp add: corresXF_refines_iff)
  apply (clarsimp simp add: refines_def_old succeeds_bind reaches_bind)
  by (smt (verit) Exn_def case_exception_or_result_Exn case_exception_or_result_Result 
      default_option_def the_Exception_simp the_Exception_Result exception_or_result_cases 
      is_Exception_simps(1) is_Exception_simps(2) not_None_eq)

(*
 * Transfers a property of run (whileLoop C B I) s (assumption termi) to
 * the loop with body B', where B' succeeds whenever B does (B_succeeds)
 * and every outcome of B' is an outcome of B provided B succeeds
 * (B_reaches). Proved by induction with whileLoop_ne_top_induct.
 *)
lemma whileLoop_succeeds_terminates_guard_body: 
  assumes B_succeeds: "i s. succeeds (B i) s  succeeds (B' i) s"
  assumes B_reaches: "i s r t. reaches (B' i) s r t  succeeds (B i) s  reaches (B i) s r t"
  assumes termi: "run (whileLoop C B I) s  "  
  shows "run (whileLoop C B' I) s  " 
  using termi
proof (induct rule: whileLoop_ne_top_induct)
  case step show ?case unfolding top_post_state_def
    apply (rule whileLoop_ne_Failure)
    using step B_succeeds B_reaches
    apply (clarsimp simp add: runs_to_def_old)
    by (metis top_post_state_def)
qed

(* Restatement of whileLoop_succeeds_terminates_guard_body in terms of
   succeeds: success of the loop with body B carries over to the loop with
   body B', under the same two assumptions on the bodies. *)
lemma whileLoop_succeeds_guard_body: 
  assumes B_succeeds: "i s. succeeds (B i) s  succeeds (B' i) s"
  assumes B_reaches: "i s r t. reaches (B' i) s r t  succeeds (B i) s  reaches (B i) s r t"
  assumes termi: "succeeds (whileLoop C B I) s"  
  shows "succeeds (whileLoop C B' I) s"
  using whileLoop_succeeds_terminates_guard_body [OF B_succeeds B_reaches] termi
  by (auto simp: succeeds_def top_post_state_def)


(*
 * whileLoop rule with guards on the abstract side.
 *
 * The abstract program first checks guard (G y) on the initial loop value
 * and re-checks guard (G r) on the result of every iteration of the body
 * A. In exchange, the side conditions (cond_match, pred_inv, pred_imply,
 * init_match) may assume G on the extracted value and abstracted state.
 *
 * Assumptions mirror those of corresXF_while, with G added:
 *   body_corres -- body correspondence under the invariant P;
 *   cond_match  -- loop conditions agree under P and G;
 *   pred_inv    -- the concrete body re-establishes both G and P for
 *                  Result outcomes;
 *   pred_imply  -- G and P' x entail P x;
 *   init_match  -- G and P' x fix y = ret x s.
 *
 * The proof reduces to corresXF_while with invariant "P and G", using
 * corresXF_guarded_while_body for the guarded body.
 *)
lemma corresXF_guarded_while:
  assumes body_corres: "x y. corresXF st ret ex (λs. P x s  y = ret x s) (A y) (B x)"
  and cond_match: "s r.  P r s; G (ret r s) (st s)   C r s = C' (ret r s) (st s)"
  and pred_inv: "r s. P r s  C r s  succeeds (whileLoop C' A (ret r s)) (st s)  G (ret r s) (st s) 
                          B r  s ?⦃ λr s. case r of Exn _  True | Result r  G (ret r s) (st s)  P r s "
  and pred_imply: "s.  G y (st s); P' x s   P x s"
  and init_match: "s.  G y (st s); P' x s   y = ret x s"
  shows "corresXF st ret ex (P' x)
      (do {
         _  guard (G y);
         whileLoop C' (λi. (do {
            r  A i;
            _  guard (G r);
            return r
          })) y
       })
      (whileLoop C B x)"
proof -

  (* Success of the loop with the guarded body implies success of the loop
     with the plain body A. *)
  {fix i s
    assume *: "succeeds (whileLoop C' (λi.
                       (do { r  A i;
                            _  guard (G r);
                            return r
                        })) i) s" 
    have "succeeds (whileLoop C' A i) s"
      apply (rule  whileLoop_succeeds_guard_body [OF _ _ *])
      subgoal by (auto simp add: succeeds_bind)

      subgoal apply (clarsimp simp add: reaches_bind succeeds_bind)
        by (smt (verit, best) dual_order.refl exception_or_result_split_asm le_boolD)
      done
  } note new_body_fails_more = this


  (* body correspondence lifted to the guarded abstract body *)
  note new_body_corres = body_corres [THEN corresXF_guarded_while_body]

  show ?thesis
    (* move the initial abstract guard into the precondition, fix the initial state *)
    apply (rule corresXF_exec_abs_guard)
    apply (rule corresXF_name_pre)
    apply (rule corresXF_assume_pre)
    apply clarsimp
    subgoal for s'
      apply (rule corresXF_guard_imp)
       (* apply the plain loop rule with the invariant strengthened by G *)
       apply (rule corresXF_while [where 
            P="λx s. P x s  G (ret x s) (st s)" and P'="λx s. P' x s  s = s'"])
           apply (rule corresXF_guard_imp)
            apply (rule new_body_corres)
           apply (clarsimp)
          apply (clarsimp)
          apply (rule cond_match, simp, simp)
      (* invariant preservation: combine pred_inv with one unrolling of the abstract loop *)
      subgoal for r s 
        using pred_inv [of r s, OF _ _ new_body_fails_more ]
        apply clarsimp
        apply (subst (asm)  whileLoop_unroll)
        apply (clarsimp simp add: cond_match)
        apply (clarsimp simp add: succeeds_bind)
        apply (clarsimp simp add: runs_to_partial_def_old split: xval_splits, intro conjI)
         apply (metis (mono_tags, lifting) body_corres corresXF_exec_normal)
        using body_corres corresXF_exec_normal by fastforce
      subgoal
        using init_match by auto
      subgoal
        using init_match pred_imply by auto
      subgoal by auto
      done
    done
qed


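(*
 * Commented-out alternative formulation of corresXF, kept for reference.
 * It is phrased directly over programs returning a pair that is accessed
 * with fst (outcomes) and snd (presumably a failure flag), with results
 * tagged by Inl (exception) and Inr (normal return). The active
 * definition of corresXF is the one given earlier in this theory.

definition "corresXF st ret_xf ex_xf P A C ≡
    ∀s. P s ∧ ¬ snd (A (st s)) ⟶
        (∀(r, t) ∈ fst (C s).
          case r of
             Inl r ⇒ (Inl (ex_xf r t), st t) ∈ fst (A (st s))
           | Inr r ⇒ (Inr (ret_xf r t), st t) ∈ fst (A (st s)))
        ∧ ¬ snd (C s)"*)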


(*
 * Merge of the notions ccorresE and corresXF: correspondence between an
 * abstract monadic program A and a program B executed under the
 * environment Γ, starting from Normal s.
 *
 * Under the guard G s and success of A from st s, every final state t of
 * B is classified as follows:
 *   Normal s'  -- A has the Result outcome rx s' with state st s';
 *   Abrupt s'  -- A has the Exn outcome ex s' with state st s';
 *   Fault e    -- only faults related to the set AF are permitted;
 *   otherwise  -- excluded (False).
 * If check_termination is set, an additional condition on B from Normal s
 * is required (presumably termination -- confirm against the original).
 *)
definition "ac_corres st check_termination AF Γ rx ex G 
  λA B. s. (G s  succeeds A (st s)) 
         (t. Γ  B, Normal s  t 
             (case t of
               Normal s'  reaches A (st s) (Result (rx s')) (st s')
             | Abrupt s'  reaches A (st s) (Exn (ex s')) (st s')
             | Fault e  e  AF
             | _  False))
           (check_termination  Γ  B  Normal s)"

(*
 * We can merge a ccorresE statement (B against the intermediate monadic
 * program M, state map st1, guard G1) with a corresXF statement (M against
 * A, state map st2, guard G2) to form an ac_corres statement for A and B.
 *
 * Side conditions: st is the composition of st2 after st1, the extraction
 * functions rx' / ex' agree with rx / ex on the st1-abstracted state, and
 * G entails both G1 (on s) and G2 (on st1 s).
 *)
lemma ccorresE_corresXF_merge:
  " ccorresE st1 ct AF Γ  G1 M B;
     corresXF st2 rx ex G2 A M;
     s. st s = st2 (st1 s);
     r s. rx' s = rx r (st1 s);
     r s. ex' s = ex r (st1 s);
     s. G s  (s  G1  G2 (st1 s))  
    ac_corres st ct AF Γ rx' ex' G A B"
  apply (unfold ac_corres_def)
  apply clarsimp
  apply (clarsimp simp: ccorresE_def)
  apply (clarsimp simp: corresXF_def)
  apply (erule allE, erule impE, force)
  apply (erule allE, erule impE, force)
  apply clarsimp
  apply (erule allE, erule impE, fastforce)
  (* case distinction on the final state t of B *)
  subgoal for s t by (cases t; fastforce)
  done


(*
 * We can also merge corresXF statements (transitivity): from B against A
 * and C against B we obtain C against A. State maps compose (st o st'),
 * extraction functions compose with the inner one evaluated first and the
 * outer one evaluated on the intermediate state st' s, and the precondition
 * combines P' on s with P on st' s.
 *)
lemma corresXF_corresXF_merge:
  " corresXF st rx ex P A B; corresXF st' rx' ex' P' B C  
    corresXF (st o st') (λrv s. rx (rx' rv s) (st' s))
             (λrv s. ex (ex' rv s) (st' s)) (λs. P' s  P (st' s)) A C "
  apply (clarsimp simp: corresXF_def split: xval_splits)
  apply fastforce
  done

(* Precondition strengthening for ac_corres: a statement for guard P
   carries over to any guard P' that entails P. *)
lemma ac_corres_guard_imp:
  " ac_corres st ct AF G rx ex P A C; s. P' s  P s   ac_corres st ct AF G rx ex P' A C"
  apply atomize
  apply (clarsimp simp: ac_corres_def)
  done

(*
 * Rules to use the corresXF definitions.
 *)

(* A concrete modify M that leaves the abstracted state unchanged
   (st s = st (M s)) corresponds to an abstract return x, where x is the
   value extracted from the modified state. *)
lemma corresXF_modify_local:
  " s. st s = st (M s); s. P s  ret () (M s) = x 
       corresXF st ret ex P (return x) (modify M)"
  by (auto simp add: corresXF_def)

(* A concrete modify M' corresponds to an abstract modify M when, under P,
   the state abstraction commutes with the updates:
   M (st s) = st (M' s). *)
lemma corresXF_modify_global:
  " s. P s  M (st s) = st (M' s)  
     corresXF st ret ex P (modify M) (modify M')"
  by (auto simp add: corresXF_def)

(* A concrete modify M that, under P, leaves the abstracted state
   unchanged corresponds to an abstract select x, given the second premise
   relating the extracted value ret () (M s) to the set x. *)
lemma corresXF_select_modify:
  " s. P s  st s = st (M s); s. P s  ret () (M s)  x  
     corresXF st ret ex P (select x) (modify M)"
  by (auto simp add: corresXF_def)

(* A concrete select b corresponds to an abstract select a, given the
   second premise relating the extraction ret x s of each chosen value x
   (under P s) to the set a.
   NOTE(review): the first premise mentions M, which occurs nowhere else
   in the statement -- confirm whether this premise is intended. *)
lemma corresXF_select_select:
  " s a. st s = st (M (a::('a  ('a::{type}))) s);
         s x.  P s; x  b  ret x s  a  
     corresXF st ret ex P (select a) (select b)"
  by (auto simp add: corresXF_def)

(* A concrete modify M that, under P, leaves the abstracted state
   unchanged corresponds to an abstract gets f, provided the value
   extracted from the modified state equals f on its abstraction. *)
lemma corresXF_modify_gets:
  " s. P s  st s = st (M s); s. P s  ret () (M s) = f (st (M s))  
     corresXF st ret ex P (gets f) (modify M)"
  by (auto simp add: corresXF_def)


(* Guards correspond when, under P, the concrete guard G' on s agrees with
   the abstract guard G on st s. *)
lemma corresXF_guard:
  " s. P s  G' s =  G (st s)   corresXF st ret ex P (guard G) (guard G')"
  by (auto simp add: corresXF_def)

(* Any concrete program X corresponds to the abstract program fail:
   corresXF only constrains runs in which the abstract program succeeds. *)
lemma corresXF_fail:
  "corresXF st return_xf exception_xf P fail X"
  by (auto simp add: corresXF_def)

(* state_select over a concrete relation A' corresponds to state_select
   over the abstract relation A when A' is exactly the preimage of A under
   st (first premise) and st is surjective. *)
lemma corresXF_spec:
  " s s'. ((s, s')  A') = ((st s, st s')  A); surj st 
      corresXF st ret ex P (state_select A) (state_select A')"
  apply (clarsimp simp add: corresXF_def)
  (* surjectivity provides a concrete preimage for an abstract state *)
  apply (frule_tac y=undefined in surjD)
  apply (clarsimp simp: image_def set_eq_UNIV)
  apply metis
  done

(* throw B corresponds to throw A provided the exception extraction of B
   yields A in every state satisfying P. *)
lemma corresXF_throw:
  " s. P s  E B s = A   corresXF st V E P (throw A) (throw B)"
  by (auto simp add: corresXF_def Exn_def)


(*
 * Changing the value extraction by appending a gets on the abstract side:
 * if R corresponds to L with extraction ret, and for every Result outcome
 * of R (partial correctness, from states satisfying P) applying M to the
 * old extracted value and abstracted state gives the new extraction ret',
 * then R corresponds to L >>= (λr. gets (M r)) with extraction ret'.
 *)
lemma corresXF_append_gets_abs:
  assumes corres: "corresXF st ret ex P L R"
  and consistent: "s. P s  R  s ?⦃λr s. case r of Exn _   | Result v  M (ret v s) (st s) = ret' v s "
shows "corresXF st ret' ex P (L >>= (λr. gets (M r))) R"
  using corres consistent
  apply (clarsimp simp add: corresXF_refines_iff runs_to_partial_def_old refines_def_old 
      succeeds_bind reaches_bind rel_XF_def rel_xval.simps split: xval_splits )
  using Exn_def by force

(* skip corresponds to skip, for any state map, extraction functions and
   precondition. *)
lemma corresXF_skipE:
  "corresXF st ret ex P skip skip"
  by (auto simp add: corresXF_def Exn_def)


(* Reflexivity: every program M corresponds to itself under the identity
   state map and identity extraction functions. *)
lemma corresXF_id:
    "corresXF id (λr s. r) (λr s. r) P M M"
  by (fastforce simp: corresXF_def split: xval_splits)

(*
 * Congruence rule for corresXF: the statement is unchanged when the state
 * map, the extraction functions and the precondition are replaced by
 * pointwise equal ones, and the abstract / concrete programs by programs
 * with equal run.
 * NOTE(review): the premise for A binds s' (used in P' s') but equates
 * the runs at s -- confirm that this quantification is intended.
 *)
lemma corresXF_cong:
  " s. st s = st' s;
     s r. ret_xf r s = ret_xf' r s;
     s r. ex_xf r s = ex_xf' r s;
     s. P s = P' s;
     s s'. P' s'  run A s = run A' s;
     s. P' s  run C s = run C' s  
           corresXF st ret_xf ex_xf P A C = corresXF st' ret_xf' ex_xf' P' A' C'"
   apply atomize
   apply (auto simp: corresXF_def reaches_def succeeds_def split: xval_splits)
   done

(* Resolving an abstract select: to show that A' corresponds to
   select Q >>= A (identity state map), it suffices to pick one x from Q
   and show that A' corresponds to A x. *)
lemma corresXF_exec_abs_select:
  " x  Q; x  Q  corresXF id rx ex P (A x) A'   corresXF id rx ex P (select Q >>= A) A'"
  by (fastforce simp add: corresXF_def succeeds_bind reaches_bind split: xval_splits)

end