Theory CCorresE

 * Copyright 2020, Data61, CSIRO (ABN 41 687 119 230)
 * Copyright (c) 2022 Apple Inc. All rights reserved.
 * SPDX-License-Identifier: BSD-2-Clause

 * A simple CCorres framework extension supporting exceptions on the monadic side.

 * A special form of "ccorres" in which either side may throw an
 * exception, provided the other side also throws one.
  (* ccorresE st check_term AF Γ G G' m c relates an abstract exception-monad
     program m to a concrete Simpl command c:
       st          maps a concrete state to the corresponding abstract state;
       check_term  when true, c must additionally terminate;
       AF          the set of faults c is allowed to end in;
       Γ           procedure environment used to execute c;
       G, G'       abstract predicate and concrete state set acting as
                   preconditions.
     For every concrete s with G (st s), s in G' and succeeds m (st s):
     each execution of c from Normal s ending in Normal s' (resp. Abrupt s')
     must be matched by m reaching Result () (resp. Exn ()) in state st s';
     an execution ending in Fault e requires e in AF; any other outcome
     (i.e. Stuck) is ruled out.  If check_term holds, c must also terminate
     from Normal s. *)
  ccorresE :: "('t  's)  bool  'ee set  ('p  ('t, 'p, 'ee) com option)
                         ('s  bool)  ('t set)
                         (unit, unit, 's) exn_monad  ('t, 'p, 'ee) com  bool"
  "ccorresE st check_term AF Γ G G' 
   λm c. s. G (st s)  (s  G')  succeeds m (st s) 
  ((t. Γ  c, Normal s  t 
   (case t of
         Normal s'  reaches m (st s) (Result ()) (st s')
       | Abrupt s'  reaches m (st s) (Exn ()) (st s')
       | Fault e  e  AF
       | _  False))
    (check_term  Γ  c  Normal s))"

(* Congruence rule for ccorresE: the preconditions may be replaced by
   equivalent ones, and the abstract program f by f' when run f and run f'
   agree on every state satisfying P'.
   NOTE(review): the conclusion keeps the concrete program g on both sides,
   so the fourth premise (equal executions of g and g' from Q' states) does
   not contribute to the statement.  Presumably g' is not used because equal
   exec relations would not give equal termination (the check_term conjunct
   of ccorresE_def); confirm this is intended. *)
lemma ccorresE_cong:
  " s. P s = P' s;
     s. (s  Q) = (s  Q');
     s. P' s  run f s = run f' s;
     s x. s  Q'  Γ g, Normal s  x = Γ g', Normal s  x
  ccorresE st ct AF Γ P Q f g = ccorresE st ct AF Γ P' Q' f' g"
  apply atomize
  apply (clarsimp simp: ccorresE_def split: xstate.splits)
  by (auto simp add: succeeds_def reaches_def)

(* Precondition strengthening: ccorresE under the abstract/concrete
   preconditions Q and Q' implies ccorresE under P and P', provided P implies
   Q on every abstract state and every member of P' is a member of Q'. *)
lemma ccorresE_guard_imp:
  " ccorresE st ct AF Γ Q Q' A B; s. P s  Q s; t. t  P'  t  Q'    ccorresE st ct AF Γ P P' A B"
  apply atomize
  apply (clarsimp simp: ccorresE_def split: xstate.splits)

(* Variant of ccorresE_guard_imp with weaker side conditions: Q (st s) and
   s in Q' only need to follow for concrete states s that satisfy both the
   new abstract precondition P (st s) and the new concrete one s in P'. *)
lemma ccorresE_guard_imp_stronger:
  " ccorresE st ct AF Γ Q Q' A B;
     s.  P (st s); s  P'   Q (st s);
     s.  P (st s); s  P'   s  Q'  
  ccorresE st ct AF Γ P P' A B"
  apply atomize
  apply (clarsimp simp: ccorresE_def split_def split: xstate.splits)

(* To establish ccorresE under G and G', it suffices to prove, for each
   concrete state s with G (st s) and s in G', the correspondence with both
   preconditions narrowed to that single state: G conjoined with equality to
   st s on the abstract side, and G' intersected with {s} on the concrete
   side. *)
lemma ccorresE_assume_pre:
  " s.  G (st s); s  G'  
         ccorresE st ct AF Γ (G and (λs'. s' = st s)) (G'  {t'. t' = s}) A B  
     ccorresE st ct AF Γ G G' A B"
  apply atomize
  apply (simp add: ccorresE_def pred_conj_def)

(* Sequential composition: if L corresponds to L' and R to R' (each with the
   concrete precondition UNIV), then the abstract bind do { _ <- L; R }
   corresponds to the concrete L' ;; R'.
   The proof handles the execution conjunct (splitting the Seq execution with
   exec_Normal_elim_cases) and the termination conjunct (terminates.Seq)
   separately.
   NOTE(review): in this copy a by step in the first case is followed by
   further apply steps, and no closing done is present, so the original
   subgoal structure of the proof seems to have been lost; check against the
   upstream theory. *)
lemma ccorresE_Seq:
  " ccorresE st ct AF Γ  UNIV L L';
     ccorresE st ct AF Γ  UNIV R R'  
   ccorresE st ct AF Γ  UNIV (do {_  L; R }) (L' ;; R')"
  apply (clarsimp simp: ccorresE_def)
  apply (rule conjI)
   apply clarsimp
   apply (erule exec_Normal_elim_cases)
    apply (clarsimp split: xstate.splits
        simp add: succeeds_bind reaches_bind, intro conjI)
    by (metis (mono_tags, lifting) Normal_resultE case_exception_or_result_Result)
      (metis (mono_tags, lifting) case_exception_or_result_Exn 
        case_exception_or_result_Result exec_Normal_elim_cases(1) 
        exec_Normal_elim_cases(3) xstate.exhaust)+
    apply (clarsimp split: xstate.splits
        simp add: succeeds_bind reaches_bind)
    by (metis Abrupt Fault terminates.Seq xstate.exhaust)
(* Concrete Cond C L' R' against a single abstract program A: it suffices
   that A corresponds to L' under the concrete precondition C and to R' under
   UNIV - C.  The exec conjunct splits on which branch ran; the termination
   conjunct splits on s in C and uses terminates.CondTrue / CondFalse. *)
lemma ccorresE_Cond:
  " ccorresE st ct AF Γ  C A L';
     ccorresE st ct AF Γ  (UNIV - C) A R'  
   ccorresE st ct AF Γ  UNIV A (Cond C L' R')"
  apply (clarsimp simp: ccorresE_def pred_neg_def)
  subgoal for s
    apply (rule conjI)
     apply clarsimp
     apply (erule exec_Normal_elim_cases)
      apply (erule_tac x=s in allE, erule impE, fastforce, fastforce)
     apply (erule_tac x=s in allE, erule impE, fastforce, fastforce)
    (* termination: case split on the branch condition *)
    apply clarsimp
    apply (cases "s  C")
     apply (rule terminates.CondTrue, assumption)
     apply (erule allE, erule impE, fastforce)
     apply clarsimp
    apply (rule terminates.CondFalse, assumption)
    apply (erule allE, erule impE, fastforce)
    apply clarsimp

(* Abstract condition C L R corresponds to concrete Cond C' L' R' when the
   two conditions agree through st (C (st s) = (s in C')).  The then-branches
   must correspond under C / C' and the else-branches under not C / UNIV - C'.
   NOTE(review): in this copy the subgoal for the exec conjunct has no closing
   done, and the two by steps after cases are not wrapped in subgoal blocks;
   the proof structure looks incomplete. *)
lemma ccorresE_Cond_match:
  " ccorresE st ct AF Γ C C' L L';
     ccorresE st ct AF Γ (not C) (UNIV - C') R R';
     s. C (st s) = (s  C')  
   ccorresE st ct AF Γ  UNIV (condition C L R) (Cond C' L' R')"
  apply atomize
  apply (simp add: ccorresE_def pred_neg_def)
  apply clarify
  apply (intro conjI allI impI)

  (* outcome of each concrete execution *)
  subgoal for s t
    apply (auto elim!: exec_Normal_elim_cases split: xstate.splits)
  (* termination, by case split on the concrete condition *)
  subgoal for s
    apply (cases "s  C'")
      by (simp add: terminates.CondTrue)
      by (simp add: terminates.CondFalse)

(* Adding a concrete Guard F G around Y is harmless when the concrete
   precondition is already G: under that precondition the guard holds, so
   both the exec obligation and the termination obligation (terminates.Guard)
   reduce to those of Y. *)
lemma ccorresE_Guard:
  " ccorresE st ct AF Γ  G X Y    ccorresE st ct AF Γ  G X (Guard F G Y)"
  apply (clarsimp simp: ccorresE_def)
  apply (rule conjI)
   apply clarsimp
   apply (erule exec_Normal_elim_cases, auto)[1]
  apply clarsimp
  apply (rule terminates.Guard, assumption)
  apply force

(* Exception handling: if A corresponds to A' and B to B' (concrete
   precondition UNIV), then the abstract A <catch> (λ_. B) corresponds to
   TRY A' CATCH B' END.  The proof relates the abstract catch to its parts
   via reaches_catch and succeeds_catch, and treats each possible outcome t
   of the concrete execution separately; termination uses terminates.Catch. *)
lemma ccorresE_Catch:
  "ccorresE st ct AF Γ  UNIV A A'; ccorresE st ct AF Γ  UNIV B B' 
    ccorresE st ct AF Γ  UNIV (A <catch> (λ_. B)) (TRY A' CATCH B' END)"
  apply (clarsimp simp: ccorresE_def)
  apply (rule conjI)
   apply clarsimp
   apply (erule_tac x=s in allE)
   apply (erule exec_Normal_elim_cases)
  subgoal for s t s'
    apply (cases t)
      using reaches_catch
      by (metis case_xval_simps(1) succeeds_catch xstate.simps(16) xstate.simps(17))
      using reaches_catch
      by (metis case_xval_simps(1) succeeds_catch xstate.simps(17))
      using reaches_catch
      by (metis case_xval_simps(1) succeeds_catch xstate.simps(17) xstate.simps(18))
      using reaches_catch
      by (metis case_xval_simps(1) succeeds_catch xstate.simps(17) xstate.simps(19))
  subgoal for s t

    apply (cases t)
       apply (fastforce simp add: reaches_catch succeeds_catch)+
  subgoal for s
    by (metis case_xval_simps(1) succeeds_catch terminates.Catch xstate.simps(17))

(* Procedure call: if Γ maps X' to the body Z' and Z corresponds to Z', then
   Z also corresponds to Call X'.  Termination goes through terminates.Call. *)
lemma ccorresE_Call:
  " Γ X' = Some Z'; ccorresE st ct AF Γ  UNIV Z Z'  
    ccorresE st ct AF Γ  UNIV Z (Call X')"
  apply (clarsimp simp: ccorresE_def)
  apply (rule conjI)
   apply clarsimp
   apply (erule exec_Normal_elim_cases)
    apply (clarsimp)
   apply clarsimp
  apply clarify
  apply (erule terminates.Call)
  apply (erule allE, erule (1) impE)
  apply clarsimp

(* Destruction rule: under the preconditions and succeeds B (st s), a
   concrete execution of B' from Normal s ending in Normal t is matched by B
   reaching Result () in state st t. *)
lemma ccorresE_exec_Normal:
    " ccorresE st ct AF Γ G G' B B'; Γ B', Normal s  Normal t; s  G'; G (st s); succeeds B (st s)  
   reaches B (st s) (Result ()) (st t)"
  apply (clarsimp simp: ccorresE_def)
  apply force

(* Destruction rule: under the preconditions and succeeds B (st s), a
   concrete execution of B' from Normal s ending in Abrupt t is matched by B
   reaching Exn () in state st t. *)
lemma ccorresE_exec_Abrupt:
    " ccorresE st ct AF Γ G G' B B'; Γ B', Normal s  Abrupt t; s  G'; G (st s); succeeds B (st s)  
   reaches B (st s) (Exn ()) (st t)"
  apply (clarsimp simp: ccorresE_def)
  apply force

(* Destruction rule: under the preconditions and succeeds B (st s), a
   concrete execution of B' ending in a fault f outside AF cannot happen, so
   any conclusion P follows. *)
lemma ccorresE_exec_Fault:
    " ccorresE st ct AF Γ G G' B B'; Γ B', Normal s  Fault f; f  AF; s  G'; G (st s); succeeds B (st s)   P"
  apply (clarsimp simp: ccorresE_def)
  apply force

(* Destruction rule: under the preconditions and succeeds B (st s), a
   concrete execution of B' ending in Stuck cannot happen, so any conclusion
   P follows. *)
lemma ccorresE_exec_Stuck:
    " ccorresE st ct AF Γ G G' B B'; Γ B', Normal s  Stuck; s  G'; G (st s); succeeds B (st s)   P"
  apply (clarsimp simp: ccorresE_def)
  apply force

(* Case-analysis rule built from ccorresE_exec_Normal / _Abrupt / _Fault /
   _Stuck.  A concrete execution of B' from Normal s to s' (under the
   preconditions and succeeds B (st s)) ends in one of three ways:
   Normal t' with B reaching Result () in st t', Abrupt t' with B reaching
   Exn () in st t', or Fault f with f in AF.
   [consumes 5] lets the five leading premises be consumed by facts when the
   rule is used as a cases rule.
   NOTE(review): the final conclusion of the statement is not visible in this
   copy (the premise list ends without it); confirm against upstream. *)
lemma ccorresE_exec_cases [consumes 5]:
    " ccorresE st ct AF Γ G G' B B'; Γ B', Normal s  s'; s  G'; G (st s); succeeds B (st s);
                  t'.  s' = Normal t'; reaches B (st s) (Result ()) (st t')  R;
                  t'.  s' = Abrupt t'; reaches B (st s) (Exn ()) (st t')   R;
                  f.  s' = Fault f; f  AF   R
  apply atomize
  apply (cases s')
     apply (drule ccorresE_exec_Normal, auto)[1]
    apply (drule ccorresE_exec_Abrupt, auto)[1]
  subgoal for f
    apply (cases "f  AF")
      by auto
      by (drule ccorresE_exec_Fault, auto)[1]
  apply (drule ccorresE_exec_Stuck, auto)[1]

(* With termination checking enabled (ct), a correspondence with concrete
   precondition UNIV yields termination of B' from Normal s whenever B
   succeeds from st s. *)
lemma ccorresE_terminates:
  " ccorresE st ct AF Γ  UNIV B B'; succeeds B (st s); ct   Γ  B'  Normal s"
   by (clarsimp simp: ccorresE_def)

(* Invariant-style elimination for executions of a While loop.  It is stated
   for an arbitrary command b and initial state x so that exec.induct
   applies; the premises b = While C B and x = Normal s pin them down.
   To show I s s' for an execution of While C B from Normal s to s', it
   suffices to show:
     - I s (Normal s) for states s where the loop exits;
     - for t entering the body: I t s' follows from I t' s' when the body B
       runs from Normal t to Normal t' (backwards propagation);
     - I t (Abrupt t'), I t Stuck and I t (Fault f) when the body ends in
       Abrupt t', Stuck or Fault f respectively. *)
lemma exec_While_final_inv':
  assumes exec: "Γ b, x  s'"
  " b = While C B; x = Normal s;
    s.  s  C   I s (Normal s);
    t t'.  t  C; Γ B, Normal t  Normal t'; I t' s'   I t s';
    t t'.  t  C; Γ B, Normal t  Abrupt t'   I t (Abrupt t');
    t.  t  C; Γ B, Normal t  Stuck   I t Stuck;
    t f.  t  C; Γ B, Normal t  Fault f   I t (Fault f) 
     I s s'"
  using exec
  apply (induct arbitrary: s rule: exec.induct, simp_all)
  apply clarsimp
  apply atomize
  apply clarsimp
  apply (erule allE, erule (1) impE)
  apply (erule exec_elim_cases, auto)

(* Specialisation of exec_While_final_inv' to b = While C B and
   x = Normal s; both equations are discharged by refl. *)
lemma exec_While_final_inv:
  " Γ While C B, Normal s  s';
    s.  s  C   I s (Normal s);
    t t'.  t  C; Γ B, Normal t  Normal t'; I t' s'   I t s';
    t t'.  t  C; Γ B, Normal t  Abrupt t'   I t (Abrupt t');
    t.  t  C; Γ B, Normal t  Stuck   I t Stuck;
    t f.  t  C; Γ B, Normal t  Fault f   I t (Fault f) 
     I s s'"
   apply (erule exec_While_final_inv', (rule refl)+, simp_all)

(* Concrete termination of While C' B' from Normal s'.  The hypotheses are:
   the abstract whileLoop CC BB r succeeds from s; s = st s'; CC and BB are
   the constant functions λ_. C and λ_. B; the body correspondence holds
   (concrete precondition UNIV); the loop conditions match; and ct holds.
   The equations are bundled in s_match, presumably so that the induction
   over whileLoop_ne_top_induct can generalise over them.
   NOTE(review): in this copy the closing qed lines of the nested proofs are
   absent, and the step case mixes by and apply steps without subgoal
   wrappers. *)
lemma ccorresE_termination':
  assumes no_fail: "succeeds (whileLoop CC BB r) s"
  and s_match: "s = st s'  CC = (λ_. C)  BB = (λ_. B)"
  and corres: "ccorresE st ct AF Γ  UNIV B B'"
  and cond_match: "s. C (st s) = (s  C')"
  and ct: "ct"
shows "Γ While C' B'  Normal s'"
proof -
  from no_fail have "run (whileLoop CC BB r) s  "
    by (simp add: succeeds_def)
  from this show ?thesis
    using s_match
  proof (induct arbitrary: s' rule: whileLoop_ne_top_induct)
    case (step a s)
    then show ?case using corres cond_match ct
      apply (cases "s'  C'")
      apply (simp add: terminates.WhileFalse)
      apply (rule terminates.WhileTrue)
      apply simp
        by (simp add: runs_to_def_old ccorresE_terminates)
        apply (clarsimp simp: runs_to_def_old)
        by (metis Abrupt Fault ccorresE_exec_Normal ccorresE_exec_Stuck 
                  iso_tuple_UNIV_I top1I xstate.exhaust)

(* Convenience form of ccorresE_termination': the loop condition and body are
   given directly as λ_. C and λ_. B, and the abstract state equals st s'. *)
lemma ccorresE_termination:
  assumes no_fail: "succeeds (whileLoop (λ_. C) (λ_. B) r) s"
  and s_match: "s = st s'"
  and corres: "ccorresE st ct AF Γ  UNIV B B'"
  and cond_match: "s. C (st s) = (s  C')"
  and ct: "ct"
  shows "Γ While C' B'  Normal s'"
  apply (auto intro: ccorresE_termination' [OF no_fail _ corres _ ct] simp: s_match cond_match)

(* Loop rule: if the body B corresponds to B' (concrete precondition UNIV)
   and the loop conditions agree through st, then whileLoop (λ_. C) (λ_. B) ()
   corresponds to While C' B' under arbitrary preconditions G and G'.
   The proof establishes the two conjuncts of ccorresE_def separately:
     - the outcome of every concrete execution, via exec_While_final_inv,
       unrolling whileLoop once per iteration with whileLoop_unroll;
     - termination when ct holds, via ccorresE_termination.
   NOTE(review): the block scaffolding that should separate the two fix/assume
   sections and feed ultimately (presumably braces and moreover) is not
   present in this copy; check against upstream. *)
lemma ccorresE_While:
  assumes body_refines: "ccorresE st ct AF Γ  UNIV B B'"
      and cond_match: "s. C (st s) = (s  C')"
    shows "ccorresE st ct AF Γ G G' (whileLoop (λ_. C) (λ_. B) ()) (While C' B')"
proof -
    (* Part 1: outcome of each concrete execution from a state meeting the
       preconditions. *)
    fix s t
    assume guard_abs: "G (st s)"
    assume guard_conc: "s  G'"

    assume no_fail: "succeeds (whileLoop (λ_. C) (λ_. B) ()) (st s)"
    assume conc_steps: "Γ While C' B', Normal s  t"
    have "case t of
        Normal s'  reaches (whileLoop (λ_. C) (λ_. B) ()) (st s) (Result ()) (st s')
      | Abrupt s'  reaches (whileLoop (λ_. C) (λ_. B) ()) (st s) (Exn ()) (st s')
      | Fault e  e  AF
      | _  False"
      apply (insert no_fail, erule rev_mp)
      apply (rule exec_While_final_inv [OF conc_steps])
        using cond_match
        apply clarsimp
        apply (subst whileLoop_unroll)
        apply (simp add: reaches_condition_iff)
      subgoal for t1 t'
        apply (subst (1 2 3) whileLoop_unroll)
        apply (clarsimp simp add: cond_match)
        using ccorresE_exec_Normal [OF body_refines, of t1 t']
        using cond_match
        apply (force split: xstate.splits simp add: succeeds_bind reaches_bind )
      subgoal for t1 t'
        apply (subst (1 2 3) whileLoop_unroll)
        apply (clarsimp simp add: cond_match)
        using ccorresE_exec_Abrupt [OF body_refines, of t1 t']
        using cond_match
        apply (clarsimp simp add: succeeds_bind reaches_bind)
        using Exn_def by force
      subgoal for t
        apply (subst (1 2 3) whileLoop_unroll)
        apply (clarsimp simp add: cond_match)
        using ccorresE_exec_Stuck [OF body_refines, of t]
        apply (metis UNIV_I succeeds_bind top1I)
      subgoal for t
        apply (subst (1 2 3) whileLoop_unroll)
        apply (clarsimp simp add: cond_match)
        using ccorresE_exec_Fault [OF body_refines, of t]
        apply (metis UNIV_I succeeds_bind top1I)
    (* Part 2: termination of the concrete loop when ct holds. *)
    fix s
    assume guard_abs: "G (st s)"
    assume guard_conc: "s  G'"
    assume no_fail: "succeeds (whileLoop (λ_. C) (λ_. B) ()) (st s)"
    have "ct  ΓWhile C' B'  Normal s"
      apply clarify
      apply (rule ccorresE_termination [OF no_fail])
         apply (rule refl)
        apply (rule body_refines)
       apply (rule cond_match)
      apply simp
  ultimately show ?thesis
    by (auto simp: ccorresE_def)

(* Reading the abstract state: get_state >>= L corresponds to R under P and Q
   if, for every abstract state s, L s corresponds to R with the abstract
   precondition narrowed to P conjoined with equality to s. *)
lemma ccorresE_get:
  "(s. ccorresE st ct AF Γ (P and (λs'. s' = s)) Q (L s) R)  ccorresE st ct AF Γ P Q ((get_state) >>= L) R"
  apply atomize
  apply (auto simp add: ccorresE_def succeeds_bind reaches_bind pred_conj_def split: xstate.splits )

(* fail corresponds to any concrete command under any preconditions.  The
   statement follows from ccorresE_def by clarsimp, presumably because fail
   never succeeds, so the succeeds hypothesis of the definition is never
   met. *)
lemma ccorresE_fail:
  "ccorresE st ct AF Γ P Q fail R"
  apply (clarsimp simp: ccorresE_def)

(* Dynamic command: A corresponds to DynCom B if, for every concrete state t
   in P', A corresponds to B t with the concrete precondition narrowed to P'
   intersected with {t}.  Termination uses terminates.DynCom. *)
lemma ccorresE_DynCom:
  " t.  t  P'   ccorresE st ct AF Γ P (P'  {t'. t' = t}) A (B t)   ccorresE st ct AF Γ P P' A (DynCom B)"
  apply atomize
  apply (clarsimp simp: ccorresE_def)
  apply (rule conjI)
   apply clarsimp
   apply (erule exec_Normal_elim_cases)
   apply (erule allE, erule(1) impE)
   apply clarsimp
  apply clarify
  apply (rule terminates.DynCom)
  apply clarsimp

(* If A' cannot throw (not exceptions_thrown A'), then wrapping it in
   TRY A' CATCH B' END preserves the correspondence with A for any handler
   B'.  exceptions_thrown_not_abrupt rules out the Abrupt outcomes of A', so
   the handler is never run, in both the exec and the termination case. *)
lemma ccorresE_Catch_nothrow:
  "ccorresE st ct AF Γ  UNIV A A'; ¬ exceptions_thrown A' 
    ccorresE st ct AF Γ  UNIV A (TRY A' CATCH B' END)"
  apply (clarsimp simp: ccorresE_def)
  apply (rule conjI)
   apply clarsimp
   apply (erule exec_Normal_elim_cases)
    apply (frule exceptions_thrown_not_abrupt, simp, simp)
    apply simp
   apply simp
  apply clarify
  apply (rule terminates.Catch)
   apply clarsimp
  apply clarsimp
  apply (drule (1) exceptions_thrown_not_abrupt)
   apply simp
  apply simp

(* The remaining definitions are made in the stack_heap_state context.
   Names used below such as htd, htd_upd, hmem_upd and 𝒮 are presumably
   provided by it; confirm against its declaration. *)
context stack_heap_state

(* with_fresh_stack_ptr n I c runs c p on a freshly allocated pointer p to n
   consecutive objects of type 'a:
     - (p, d) is taken from stack_allocs n 𝒮 TYPE('a) (htd s), and the heap
       type description is replaced by d;
     - the objects are initialised from a list vs in I s of length n; the
       padding of object i comes from the i-th size_of TYPE('a) chunk of an
       arbitrary byte list bs of length n * size_of TYPE('a);
     - c p is wrapped in on_exit, whose relation releases the n objects
       (stack_releases n p) and overwrites the memory at ptr_val p with an
       arbitrary byte list of length n * size_of TYPE('a).
   The pair (p, t) is chosen via assume_result_and_state; see that definition
   for how the choice is resolved. *)
definition with_fresh_stack_ptr :: "nat  ('s  'a list set)  ('a::mem_type ptr  ('e::default, 'v, 's) spec_monad)  ('e::default, 'v, 's) spec_monad"
  "with_fresh_stack_ptr n I c 
    do {
      p  assume_result_and_state (λs. {(p, t). d vs bs. (p, d)  stack_allocs n 𝒮 TYPE('a::mem_type) (htd s)  
           vs  I s  length vs = n  length bs = n * size_of TYPE('a) 
           t = hmem_upd (fold (λi. heap_update_padding (p +p int i) (vs!i) (take (size_of TYPE('a)) (drop (i * size_of TYPE('a)) bs))) [0..<n]) (htd_upd (λ_. d) s)});
      on_exit (c p)
        ({(s,t). bs. length bs = n * size_of TYPE('a)  t = hmem_upd (heap_update_list (ptr_val p) bs) (htd_upd (stack_releases n p) s)})

(* partial_function monotonicity rule: with_fresh_stack_ptr n I (c f) is
   monotone in f with respect to (≤) whenever each c f p is.  Proved by
   unfolding with_fresh_stack_ptr and on_exit and applying the existing
   partial_function_mono rules. *)
lemma monotone_with_fresh_stack_ptr_le[partial_function_mono]:
  assumes [partial_function_mono]: "p. monotone R (≤) (λf. c f p)"  
  shows "monotone R (≤) (λf. with_fresh_stack_ptr n I (c f))"
  unfolding with_fresh_stack_ptr_def on_exit_def
  by (intro partial_function_mono)

(* Dual of monotone_with_fresh_stack_ptr_le: with_fresh_stack_ptr n I (c f)
   is monotone in f with respect to (≥) whenever each c f p is. *)
lemma monotone_with_fresh_stack_ptr_ge[partial_function_mono]:
  assumes [partial_function_mono]: "p. monotone R (≥) (λf. c f p)"  
  shows "monotone R (≥) (λf. with_fresh_stack_ptr n I (c f))"
  unfolding with_fresh_stack_ptr_def on_exit_def
  by (intro partial_function_mono)

(* ML registry for with_fresh_stack_ptr: generic context data holding
   matchers that destruct applications of the constant and a function that
   yields the constant for a given state type.
   NOTE(review): the struct/end delimiters of the structure (and the closing
   of the Generic_Data application) are not present in this copy. *)
ML structure with_fresh_stack_ptr =

(* match / cterm_match destruct a term (resp. cterm) into the arguments n,
   init and c, plus a ct_ component whose role is not visible here, and an
   instantiate function that rebuilds a term from new arguments.  term maps
   a type to the constant.  Each function presumably raises Match when it
   does not apply, as the functions in the empty data do. *)
type data = {
  match: term -> {n:term, init:term, c:term, ct_: term, instantiate: {n:term, init:term, c:term} -> term},
  cterm_match: cterm -> {n:cterm, init:cterm, c:cterm, ct_: cterm, instantiate: {n:cterm, init:cterm, c:cterm} -> cterm},
  term: typ -> term}

(* Functional updaters, each replacing one field of a data record. *)
fun map_match f ({match, cterm_match, term}:data) = 
  {match = f match, cterm_match = cterm_match, term = term}:data
fun map_cterm_match f ({match, cterm_match, term}:data) = 
  {match = match, cterm_match = f cterm_match, term = term}:data
fun map_term f ({match, cterm_match, term}:data) = 
  {match = match, cterm_match = cterm_match, term = f term}:data

(* Combine two partial functions.  Presumably Utils.first_match tries them in
   order and Utils.fast_merge short-circuits identical arguments; confirm in
   Utils. *)
fun merge_match (f, g) = Utils.fast_merge (fn (f, g) => Utils.first_match [f, g]) (f, g)

(* Generic (theory or proof) data holding the registered functions; in the
   empty data every function raises Match.  Merging combines each field with
   merge_match. *)
structure Data = Generic_Data (
  type T = data;
  val empty = {match = fn _ => raise Match, cterm_match =  fn _ => raise Match, term = fn _ => raise Match}
  val merge = Utils.fast_merge (fn ({match = f1, cterm_match = g1, term = t1}, {match = f2, cterm_match = g2, term = t2}) =>
         {match = merge_match (f1, f2), cterm_match = merge_match (g1, g2), term = merge_match (t1, t2)}); 

(* Accessors: fetch the currently registered functions from a proof context. *)
fun match ctxt = #match (Data.get (Context.Proof ctxt))
fun cterm_match ctxt = #cterm_match (Data.get (Context.Proof ctxt))
fun term ctxt = #term (Data.get (Context.Proof ctxt))

(* Register an additional function for the corresponding field, combined with
   the existing one via Utils.add_match.  These return a data transformer.
   The parameter of add_term is named match although it is a term builder. *)
fun add_match match = (map_match (Utils.add_match match))
fun add_cterm_match cterm_match = (map_cterm_match (Utils.add_match cterm_match))
fun add_term match = (map_term (Utils.add_match match))


(* Registers, for each morphism phi this declaration is applied with (e.g. on
   interpretation of the context), a term builder and matchers for the
   morphed constant with_fresh_stack_ptr:
     - term T' returns the morphed constant when the morphed state type T
       matches T', and raises Match otherwise;
     - match / cterm_match use the morph_match antiquotations with pattern
       with_fresh_stack_ptr ?n ?init ?c and turn Pattern.MATCH into Match.
   NOTE(review): the let/in/end around the local declarations, and the
   context value piped into the add_* functions, are not present in this
   copy. *)
declaration fn phi => fn context =>
    val T = Morphism.typ phi @{typ 's}
    val t = Morphism.term phi @{term with_fresh_stack_ptr}
    val thy = Context.theory_of context
    fun term T' = 
        if can (Sign.typ_match thy (T, T')) Vartab.empty then t else raise Match
    fun match t = @{morph_match (fo) with_fresh_stack_ptr ?n ?init ?c} phi (Context.theory_of context) t
        handle Pattern.MATCH => raise Match
    fun cterm_match ct = @{cterm_morph_match (fo) with_fresh_stack_ptr ?n ?init ?c} phi ct
        handle Pattern.MATCH => raise Match
   |> with_fresh_stack_ptr.add_match match 
   |> with_fresh_stack_ptr.add_cterm_match cterm_match 
   |> with_fresh_stack_ptr.add_term term
