(* Theory Computational_Model *)
section ‹Oracle combinators›
theory Computational_Model imports 
  Generative_Probabilistic_Value
begin
(* The security parameter is represented as a natural number. *)
type_synonym security = nat
(* An advantage assigns a real number to every security parameter. *)
type_synonym advantage = "security ⇒ real"
(* A stateful oracle: given the current state and a query, it yields a
   subprobability distribution over (answer, successor state) pairs. *)
type_synonym ('σ, 'call, 'ret) oracle' = "'σ ⇒ 'call ⇒ ('ret × 'σ) spmf"
(* A security-parametrised oracle: for each security parameter, an oracle
   function together with a state of the same type (presumably the initial
   state -- NOTE(review): confirm against users of this type). *)
type_synonym ('σ, 'call, 'ret) "oracle" = "security ⇒ ('σ, 'call, 'ret) oracle' × 'σ"
(* Pretty-printing support: fold the unfolded form of the "oracle" type
   synonym back so that such types are displayed as ('σ, 'call, 'ret) oracle.
   The pattern below matches function types of the shape
     nat ⇒ ('s1 ⇒ 'call ⇒ ('ret × 's2) option pmf) × 's3
   i.e. the expansion of "security ⇒ ('σ, 'call, 'ret) oracle' × 'σ" with
   spmf written out as option pmf. *)
print_translation  ‹
  let
    (* Abbreviate only when all three occurrences of the state type coincide;
       otherwise raise Match so that this translation does not apply. *)
    fun tr' [Const (@{type_syntax nat}, _), 
      Const (@{type_syntax prod}, _) $ 
        (Const (@{type_syntax fun}, _) $ s1 $ 
          (Const (@{type_syntax fun}, _) $ call $
            (Const (@{type_syntax pmf}, _) $
              (Const (@{type_syntax option}, _) $
                (Const (@{type_syntax prod}, _) $ ret $ s2))))) $
        s3] =
      if s1 = s2 andalso s1 = s3 then Syntax.const @{type_syntax oracle} $ s1 $ call $ ret
      else raise Match;
  (* The translation is attached to the function-type constructor. *)
  in [(@{type_syntax "fun"}, K tr')]
  end
›
(* Diagnostic: display the oracle type to check the translation above. *)
typ "('σ, 'call, 'ret) oracle"
subsection ‹Shared state›
(* The lifting setup of ℐ is activated locally so that plus_ℐ can be defined
   on the underlying response function via lift_definition. *)
context includes ℐ.lifting and lifting_syntax begin
(* Disjoint sum of two interfaces. A query Inl out' is answered with the
   Inl-tagged responses of the left interface, and a query Inr out' with the
   Inr-tagged responses of the right interface. *)
lift_definition plus_ℐ :: "('out, 'ret) ℐ ⇒ ('out', 'ret') ℐ ⇒ ('out + 'out', 'ret + 'ret') ℐ" (infix ‹⊕⇩ℐ› 500)
is "λresp1 resp2. λout. case out of Inl out' ⇒ Inl ` resp1 out' | Inr out' ⇒ Inr ` resp2 out'" .
(* Selector equations for plus_ℐ. The outputs are the disjoint sum of both
   output sets, and responses are looked up on the corresponding side and
   tagged with the same injection. *)
lemma plus_ℐ_sel [simp]:
  shows outs_plus_ℐ: "outs_ℐ (plus_ℐ ℐl ℐr) = outs_ℐ ℐl <+> outs_ℐ ℐr"
  and responses_plus_ℐ_Inl: "responses_ℐ (plus_ℐ ℐl ℐr) (Inl x) = Inl ` responses_ℐ ℐl x"
  and responses_plus_ℐ_Inr: "responses_ℐ (plus_ℐ ℐl ℐr) (Inr y) = Inr ` responses_ℐ ℐr y"
by(transfer; auto split: sum.split_asm; fail)+
(* The preimage of a disjoint sum under Inl (resp. Inr) is the left (resp.
   right) component. *)
lemma vimage_Inl_Plus [simp]: "Inl -` (A <+> B) = A" 
  and vimage_Inr_Plus [simp]: "Inr -` (A <+> B) = B"
by auto
(* The preimage under one injection of the image of the other injection is
   empty. *)
lemma vimage_Inl_image_Inr: "Inl -` Inr ` A = {}"
  and vimage_Inr_image_Inl: "Inr -` Inl ` A = {}"
by auto
(* Parametricity of plus_ℐ: related summands give related sums, with the
   output and response relations lifted by rel_sum. *)
lemma plus_ℐ_parametric [transfer_rule]:
  "(rel_ℐ C R ===> rel_ℐ C' R' ===> rel_ℐ (rel_sum C C') (rel_sum R R')) plus_ℐ plus_ℐ"
apply(rule rel_funI rel_ℐI)+
subgoal premises [transfer_rule] by(simp; rule conjI; transfer_prover)
apply(erule rel_sum.cases; clarsimp simp add: inj_vimage_image_eq vimage_Inl_image_Inr empty_transfer vimage_Inr_image_Inl)
subgoal premises [transfer_rule] by transfer_prover
subgoal premises [transfer_rule] by transfer_prover
done
(* Save the updated lifting setup of ℐ, then deactivate it again so the rest
   of the theory reasons about ℐ abstractly. *)
lifting_update ℐ.lifting
lifting_forget ℐ.lifting
(* A sum interface is trivial iff both summands are. *)
lemma ℐ_trivial_plus_ℐ [simp]: "ℐ_trivial (ℐ⇩1 ⊕⇩ℐ ℐ⇩2) ⟷ ℐ_trivial ℐ⇩1 ∧ ℐ_trivial ℐ⇩2"
by(auto simp add: ℐ_trivial_def)
end
(* Mapping a sum interface with map_sum on outputs and responses distributes
   over the two summands. *)
lemma map_ℐ_plus_ℐ [simp]: 
  "map_ℐ (map_sum f1 f2) (map_sum g1 g2) (ℐ1 ⊕⇩ℐ ℐ2) = map_ℐ f1 g1 ℐ1 ⊕⇩ℐ map_ℐ f2 g2 ℐ2"
proof(rule ℐ_eqI[OF Set.set_eqI], goal_cases)
  (* Set-extensionality goal for an arbitrary element x: split on the sum. *)
  case (1 x)
  then show ?case by(cases x) auto
qed (auto simp add: image_image)
(* The interface order is componentwise on sums. *)
lemma le_plus_ℐ_iff [simp]:
  "ℐ1 ⊕⇩ℐ ℐ2 ≤ ℐ1' ⊕⇩ℐ ℐ2' ⟷ ℐ1 ≤ ℐ1' ∧ ℐ2 ≤ ℐ2'"
  by(auto 4 4 simp add: le_ℐ_def dest: bspec[where x="Inl _"] bspec[where x="Inr _"])
(* If ℐ_full lies below both summands, it also lies below their sum. *)
lemma ℐ_full_le_plus_ℐ: "ℐ_full ≤ plus_ℐ ℐ1 ℐ2" if "ℐ_full ≤ ℐ1" "ℐ_full ≤ ℐ2"
  using that by(auto simp add: le_ℐ_def top_unique)
(* plus_ℐ is monotone in both arguments. *)
lemma plus_ℐ_mono: "plus_ℐ ℐ1 ℐ2 ≤ plus_ℐ ℐ1' ℐ2'" if "ℐ1 ≤ ℐ1'" "ℐ2 ≤ ℐ2'" 
  using that by(fastforce simp add: le_ℐ_def)
(* Sum of two oracles that share the same state type 's. Because left, right
   and s are fixed in this context, plus_oracle takes them as its first three
   arguments outside the context: plus_oracle left right s query. *)
context
  fixes left :: "('s, 'a, 'b) oracle'"
  and right :: "('s,'c, 'd) oracle'"
  and s :: "'s"
begin
(* Inl queries are answered by left and Inr queries by right, both at the
   shared state s. The answer is tagged with the matching injection and the
   successor state is passed through unchanged. *)
primrec plus_oracle :: "'a + 'c ⇒ (('b + 'd) × 's) spmf"
where
  "plus_oracle (Inl a) = map_spmf (apfst Inl) (left s a)"
| "plus_oracle (Inr b) = map_spmf (apfst Inr) (right s b)"
(* The sum oracle is lossless on x if the oracle answering x is lossless. *)
lemma lossless_plus_oracleI [intro, simp]:
  "⟦ ⋀a. x = Inl a ⟹ lossless_spmf (left s a); 
     ⋀b. x = Inr b ⟹ lossless_spmf (right s b) ⟧
  ⟹ lossless_spmf (plus_oracle x)"
by(cases x) simp_all
(* Case-split rules for plus_oracle, for use with the split: modifier of the
   simplifier. *)
lemma plus_oracle_split:
  "P (plus_oracle lr) ⟷
  (∀x. lr = Inl x ⟶ P (map_spmf (apfst Inl) (left s x))) ∧
  (∀y. lr = Inr y ⟶ P (map_spmf (apfst Inr) (right s y)))"
by(cases lr) auto
lemma plus_oracle_split_asm:
  "P (plus_oracle lr) ⟷
  ¬ ((∃x. lr = Inl x ∧ ¬ P (map_spmf (apfst Inl) (left s x))) ∨
     (∃y. lr = Inr y ∧ ¬ P (map_spmf (apfst Inr) (right s y))))"
by(cases lr) auto
end
(* Infix notation: (left ⊕⇩O right) s q = plus_oracle left right s q. *)
notation plus_oracle (infix ‹⊕⇩O› 500)
(* Typing, invariant and parametricity results for plus_oracle. Only the two
   oracles are fixed here; the state s is a free variable of each lemma. *)
context
  fixes left :: "('s, 'a, 'b) oracle'"
  and right :: "('s,'c, 'd) oracle'"
begin
(* If both oracles respect their interfaces at state s, then their sum
   respects the sum interface at s. *)
lemma WT_plus_oracleI [intro!]:
  "⟦ ℐl ⊢c left s √; ℐr ⊢c right s √ ⟧ ⟹ ℐl ⊕⇩ℐ ℐr ⊢c (left ⊕⇩O right) s √"
by(rule WT_calleeI)(auto elim!: WT_calleeD simp add: inj_image_mem_iff)
(* Converse of WT_plus_oracleI for the left summand. *)
lemma WT_plus_oracleD1:
  assumes "ℐl ⊕⇩ℐ ℐr ⊢c (left ⊕⇩O right) s √ " (is "?ℐ ⊢c ?callee s √")
  shows "ℐl ⊢c left s √"
proof(rule WT_calleeI)
  fix call ret s'
  assume "call ∈ outs_ℐ ℐl" "(ret, s') ∈ set_spmf (left s call)"
  (* The same answer appears Inl-tagged in the sum oracle for query Inl call. *)
  hence "(Inl ret, s') ∈ set_spmf (?callee s (Inl call))" "Inl call ∈ outs_ℐ (ℐl ⊕⇩ℐ ℐr)"
    by(auto intro: rev_image_eqI)
  hence "Inl ret ∈ responses_ℐ ?ℐ (Inl call)" by(rule WT_calleeD[OF assms])
  then show "ret ∈ responses_ℐ ℐl call" by(simp add: inj_image_mem_iff)
qed
(* Converse of WT_plus_oracleI for the right summand. *)
lemma WT_plus_oracleD2:
  assumes "ℐl ⊕⇩ℐ ℐr ⊢c (left ⊕⇩O right) s √ " (is "?ℐ ⊢c ?callee s √")
  shows "ℐr ⊢c right s √"
proof(rule WT_calleeI)
  fix call ret s'
  assume "call ∈ outs_ℐ ℐr" "(ret, s') ∈ set_spmf (right s call)"
  (* The same answer appears Inr-tagged in the sum oracle for query Inr call. *)
  hence "(Inr ret, s') ∈ set_spmf (?callee s (Inr call))" "Inr call ∈ outs_ℐ (ℐl ⊕⇩ℐ ℐr)"
    by(auto intro: rev_image_eqI)
  hence "Inr ret ∈ responses_ℐ ?ℐ (Inr call)" by(rule WT_calleeD[OF assms])
  then show "ret ∈ responses_ℐ ℐr call" by(simp add: inj_image_mem_iff)
qed
(* Well-typedness of the sum oracle is equivalent to well-typedness of both
   summands. *)
lemma WT_plus_oracle_iff [simp]: "ℐl ⊕⇩ℐ ℐr ⊢c (left ⊕⇩O right) s √ ⟷ ℐl ⊢c left s √ ∧ ℐr ⊢c right s √"
by(blast dest: WT_plus_oracleD1 WT_plus_oracleD2)
(* An invariant I holds for the sum oracle on the sum interface iff it holds
   for each summand on its own interface. *)
lemma callee_invariant_on_plus_oracle [simp]:
  "callee_invariant_on (left ⊕⇩O right) I (ℐl ⊕⇩ℐ ℐr) ⟷
   callee_invariant_on left I ℐl ∧ callee_invariant_on right I ℐr"
   (is "?lhs ⟷ ?rhs")
proof(intro iffI conjI)
  (* (⟹) Use the locale instance for the sum oracle and restrict it to each
     side by querying with Inl resp. Inr. *)
  assume ?lhs
  then interpret plus: callee_invariant_on "left ⊕⇩O right" I "ℐl ⊕⇩ℐ ℐr" .
  show "callee_invariant_on left I ℐl"
  proof
    fix s x y s'
    assume "(y, s') ∈ set_spmf (left s x)" and "I s" and "x ∈ outs_ℐ ℐl"
    then have "(Inl y, s') ∈ set_spmf ((left ⊕⇩O right) s (Inl x))"
      by(auto intro: rev_image_eqI)
    then show "I s'" using ‹I s› by(rule plus.callee_invariant)(simp add: ‹x ∈ outs_ℐ ℐl›)
  next
    show "ℐl ⊢c left s √" if "I s" for s using plus.WT_callee[OF that] by simp
  qed
  show "callee_invariant_on right I ℐr"
  proof
    fix s x y s'
    assume "(y, s') ∈ set_spmf (right s x)" and "I s" and "x ∈ outs_ℐ ℐr"
    then have "(Inr y, s') ∈ set_spmf ((left ⊕⇩O right) s (Inr x))"
      by(auto intro: rev_image_eqI)
    then show "I s'" using ‹I s› by(rule plus.callee_invariant)(simp add: ‹x ∈ outs_ℐ ℐr›)
  next
    show "ℐr ⊢c right s √" if "I s" for s using plus.WT_callee[OF that] by simp
  qed
next
  (* (⟸) Combine both component locales; a query to the sum oracle is
     dispatched by case analysis to the summand that answers it. *)
  assume ?rhs
  interpret left: callee_invariant_on left I ℐl using ‹?rhs› by simp
  interpret right: callee_invariant_on right I ℐr using ‹?rhs› by simp
  show ?lhs
  proof
    fix s x y s'
    assume "(y, s') ∈ set_spmf ((left ⊕⇩O right) s x)" and "I s" and "x ∈ outs_ℐ (ℐl ⊕⇩ℐ ℐr)"
    then have "(projl y, s') ∈ set_spmf (left s (projl x)) ∧ projl x ∈ outs_ℐ ℐl ∨
      (projr y, s') ∈ set_spmf (right s (projr x)) ∧ projr x ∈ outs_ℐ ℐr"
      by (cases x)  auto
    then show "I s'" using ‹I s› 
      by (auto dest: left.callee_invariant right.callee_invariant)
  next
    show "ℐl ⊕⇩ℐ ℐr ⊢c (left ⊕⇩O right) s √" if "I s" for s 
      using left.WT_callee[OF that] right.WT_callee[OF that] by simp
  qed
qed
(* Interface-free version: reduce to callee_invariant_on with the sum of two
   full interfaces and apply callee_invariant_on_plus_oracle. *)
lemma callee_invariant_plus_oracle [simp]:
  "callee_invariant (left ⊕⇩O right) I ⟷
   callee_invariant left I ∧ callee_invariant right I"
  (is "?lhs ⟷  ?rhs")
proof -
  have "?lhs ⟷ callee_invariant_on (left ⊕⇩O right) I (ℐ_full ⊕⇩ℐ ℐ_full)"
    by(rule callee_invariant_on_cong)(auto split: plus_oracle_split_asm)
  also have "… ⟷ ?rhs" by(rule callee_invariant_on_plus_oracle)
  finally show ?thesis .
qed
(* Parametricity of plus_oracle in the states, queries and answers. *)
lemma plus_oracle_parametric [transfer_rule]:
  includes lifting_syntax shows
  "((S ===> A ===> rel_spmf (rel_prod B S))
   ===> (S ===> C ===> rel_spmf (rel_prod D S))
   ===> S ===> rel_sum A C ===> rel_spmf (rel_prod (rel_sum B D) S))
   plus_oracle plus_oracle"
unfolding plus_oracle_def[abs_def] by transfer_prover
(* Pointwise relational rule: two sum oracles are related on related queries
   if the left (resp. right) summands are related on matching Inl (resp. Inr)
   queries at related states. *)
lemma rel_spmf_plus_oracle:
  "⟦ ⋀q1' q2'. ⟦ q1 = Inl q1'; q2 = Inl q2' ⟧ ⟹ rel_spmf (rel_prod B S) (left1 s1 q1') (left2 s2 q2');
    ⋀q1' q2'. ⟦ q1 = Inr q1'; q2 = Inr q2' ⟧ ⟹ rel_spmf (rel_prod D S) (right1 s1 q1') (right2 s2 q2');
    S s1 s2; rel_sum A C q1 q2 ⟧
  ⟹ rel_spmf (rel_prod (rel_sum B D) S) ((left1 ⊕⇩O right1) s1 q1) ((left2 ⊕⇩O right2) s2 q2)"
apply(erule rel_sum.cases; clarsimp)
 (* Inl/Inl case *)
 apply(erule meta_allE)+
 apply(erule meta_impE, rule refl)+
 subgoal premises [transfer_rule] by transfer_prover
(* Inr/Inr case *)
apply(erule meta_allE)+
apply(erule meta_impE, rule refl)+
subgoal premises [transfer_rule] by transfer_prover
done
end
subsection ‹Shared state with aborts›
(* Variant of plus_oracle for oracles whose answers are optional. The tag is
   applied inside the option with map_option, so a None answer stays None. *)
context
  fixes left :: "('s, 'a, 'b option) oracle'"
  and right :: "('s,'c, 'd option) oracle'"
  and s :: "'s"
begin
primrec plus_oracle_stop :: "'a + 'c ⇒ (('b + 'd) option × 's) spmf"
where
  "plus_oracle_stop (Inl a) = map_spmf (apfst (map_option Inl)) (left s a)"
| "plus_oracle_stop (Inr b) = map_spmf (apfst (map_option Inr)) (right s b)"
(* The combined oracle is lossless on x if the oracle answering x is. *)
lemma lossless_plus_oracle_stopI [intro, simp]:
  "⟦ ⋀a. x = Inl a ⟹ lossless_spmf (left s a); 
     ⋀b. x = Inr b ⟹ lossless_spmf (right s b) ⟧
  ⟹ lossless_spmf (plus_oracle_stop x)"
by(cases x) simp_all
(* Case-split rules for plus_oracle_stop, for use with the split: modifier. *)
lemma plus_oracle_stop_split:
  "P (plus_oracle_stop lr) ⟷
  (∀x. lr = Inl x ⟶ P (map_spmf (apfst (map_option Inl)) (left s x))) ∧
  (∀y. lr = Inr y ⟶ P (map_spmf (apfst (map_option Inr)) (right s y)))"
by(cases lr) auto
lemma plus_oracle_stop_split_asm:
  "P (plus_oracle_stop lr) ⟷
  ¬ ((∃x. lr = Inl x ∧ ¬ P (map_spmf (apfst (map_option Inl)) (left s x))) ∨
     (∃y. lr = Inr y ∧ ¬ P (map_spmf (apfst (map_option Inr)) (right s y))))"
by(cases lr) auto
end
(* Infix notation: (left ⊕⇩O⇧S right) s q = plus_oracle_stop left right s q. *)
notation plus_oracle_stop (infix ‹⊕⇩O⇧S› 500)
subsection ‹Disjoint state›
(* Parallel composition of two oracles with disjoint states. The combined
   state is a pair: Inl queries are answered by left using, and updating only,
   the first component; Inr queries are answered by right using only the
   second component. *)
context
  fixes left :: "('s1, 'a, 'b) oracle'"
  and right :: "('s2, 'c, 'd) oracle'"
begin
fun parallel_oracle :: "('s1 × 's2, 'a + 'c, 'b + 'd) oracle'"
where
  "parallel_oracle (s1, s2) (Inl a) = map_spmf (map_prod Inl (λs1'. (s1', s2))) (left s1 a)"
| "parallel_oracle (s1, s2) (Inr b) = map_spmf (map_prod Inr (Pair s1)) (right s2 b)"
(* Pattern-free characterisation of parallel_oracle. It is unfolded in the
   parametricity proof parallel_oracle_parametric. *)
lemma parallel_oracle_def:
  "parallel_oracle = (λ(s1, s2). case_sum (λa. map_spmf (map_prod Inl (λs1'. (s1', s2))) (left s1 a)) (λb. map_spmf (map_prod Inr (Pair s1)) (right s2 b)))"
by(auto intro!: ext split: sum.split)
(* The combined oracle is lossless on a query iff the oracle answering that
   query is lossless at its own state component. *)
lemma lossless_parallel_oracle [simp]:
  "lossless_spmf (parallel_oracle s12 xy) ⟷
   (∀x. xy = Inl x ⟶ lossless_spmf (left (fst s12) x)) ∧
   (∀y. xy = Inr y ⟶ lossless_spmf (right (snd s12) y))"
by(cases s12; cases xy) simp_all
(* Case-split rules for parallel_oracle, for use with the split: modifier. *)
lemma parallel_oracle_split:
  "P (parallel_oracle s1s2 lr) ⟷
  (∀s1 s2 x. s1s2 = (s1, s2) ⟶ lr = Inl x ⟶ P (map_spmf (map_prod Inl (λs1'. (s1', s2))) (left s1 x))) ∧
  (∀s1 s2 y. s1s2 = (s1, s2) ⟶ lr = Inr y ⟶ P (map_spmf (map_prod Inr (Pair s1)) (right s2 y)))"
by(cases s1s2; cases lr) auto
lemma parallel_oracle_split_asm:
  "P (parallel_oracle s1s2 lr) ⟷
  ¬ ((∃s1 s2 x. s1s2 = (s1, s2) ∧ lr = Inl x ∧ ¬ P (map_spmf (map_prod Inl (λs1'. (s1', s2))) (left s1 x))) ∨
     (∃s1 s2 y. s1s2 = (s1, s2) ∧ lr = Inr y ∧ ¬ P (map_spmf (map_prod Inr (Pair s1)) (right s2 y))))"
by(cases s1s2; cases lr) auto
(* Well-typedness of the components at sl and sr lifts to the sum interface
   at the pair state (sl, sr). *)
lemma WT_parallel_oracle [intro!, simp]:
  "⟦ ℐl ⊢c left sl √; ℐr ⊢c right sr √ ⟧ ⟹ plus_ℐ ℐl ℐr ⊢c parallel_oracle (sl, sr) √"
by(rule WT_calleeI)(auto elim!: WT_calleeD simp add: inj_image_mem_iff)
(* Invariants Il and Ir of the components combine into the componentwise
   invariant pred_prod Il Ir of the parallel oracle on the sum interface. *)
lemma callee_invariant_parallel_oracleI [simp, intro]:
  assumes "callee_invariant_on left Il ℐl" "callee_invariant_on right Ir ℐr"
  shows "callee_invariant_on parallel_oracle (pred_prod Il Ir) (ℐl ⊕⇩ℐ ℐr)"
proof
  interpret left: callee_invariant_on left Il ℐl by fact
  interpret right: callee_invariant_on right Ir ℐr by fact
  (* Preservation: split the pair state and the query once each, then use the
     invariant of the component that answers the query. *)
  show "pred_prod Il Ir s12'"
    if "(y, s12') ∈ set_spmf (parallel_oracle s12 x)" and "pred_prod Il Ir s12" and "x ∈ outs_ℐ (ℐl ⊕⇩ℐ ℐr)"
    for s12 x y s12' using that
    by(cases s12; cases x)(auto dest: left.callee_invariant right.callee_invariant)
  (* Well-typedness at every state that satisfies the combined invariant. *)
  show "ℐl ⊕⇩ℐ ℐr ⊢c local.parallel_oracle s √" if "pred_prod Il Ir s" for s using that
    by(cases s)(simp add: left.WT_callee right.WT_callee)
qed
end
(* Parametricity of parallel_oracle in the states and queries. The answers
   are related by equality. *)
lemma parallel_oracle_parametric:
  includes lifting_syntax shows
  "((S1 ===> CALL1 ===> rel_spmf (rel_prod (=) S1)) 
  ===> (S2 ===> CALL2 ===> rel_spmf (rel_prod (=) S2))
  ===> rel_prod S1 S2 ===> rel_sum CALL1 CALL2 ===> rel_spmf (rel_prod (=) (rel_prod S1 S2)))
  parallel_oracle parallel_oracle"
(* fold relator_eq presumably turns the (=) relations into relator form so
   that transfer_prover can match them -- NOTE(review): confirm. *)
unfolding parallel_oracle_def[abs_def] by (fold relator_eq)transfer_prover
subsection ‹Indexed oracles›
(* An indexed family of oracles sharing one function-valued state. A query
   (i, x) is answered by f i at state s i, and only the i-th state component
   is replaced by the successor state. *)
definition family_oracle :: "('i ⇒ ('s, 'a, 'b) oracle') ⇒ ('i ⇒ 's, 'i × 'a, 'b) oracle'"
where "family_oracle f s = (λ(i, x). map_spmf (λ(y, s'). (y, s(i := s'))) (f i (s i) x))"
(* Simplification rule for an explicit pair query, phrased with apsnd and
   fun_upd. *)
lemma family_oracle_apply [simp]:
  "family_oracle f s (i, x) = map_spmf (apsnd (fun_upd s i)) (f i (s i) x)"
by(simp add: family_oracle_def apsnd_def map_prod_def)
(* Losslessness is decided by the indexed oracle that answers the query. *)
lemma lossless_family_oracle:
  "lossless_spmf (family_oracle f s ix) ⟷ lossless_spmf (f (fst ix) (s (fst ix)) (snd ix))"
by(simp add: family_oracle_def split_beta)
subsection ‹State extension›
(* †callee: extend the state of a callee by an extra left component s' that
   is passed through unchanged. The callee itself only sees the right
   component s. *)
definition extend_state_oracle :: "('call, 'ret, 's) callee ⇒ ('call, 'ret, 's' × 's) callee" (‹†_› [1000] 1000)
where "extend_state_oracle callee = (λ(s', s) x. map_spmf (λ(y, s). (y, (s', s))) (callee s x))"
(* Simplification rule for an explicit pair state. *)
lemma extend_state_oracle_simps [simp]:
  "extend_state_oracle callee (s', s) x = map_spmf (λ(y, s). (y, (s', s))) (callee s x)"
by(simp add: extend_state_oracle_def)
context includes lifting_syntax begin
(* Parametricity of extend_state_oracle in states, calls and answers. *)
lemma extend_state_oracle_parametric [transfer_rule]:
  "((S ===> C ===> rel_spmf (rel_prod R S)) ===> rel_prod S' S ===> C ===> rel_spmf (rel_prod R (rel_prod S' S)))
  extend_state_oracle extend_state_oracle"
unfolding extend_state_oracle_def[abs_def] by transfer_prover
(* Transfer rule relating an oracle to its state extension. rel_prod2 is
   defined elsewhere; judging from its use here it relates a state s to a pair
   whose right component is S-related to s -- NOTE(review): confirm. *)
lemma extend_state_oracle_transfer:
  "((S ===> C ===> rel_spmf (rel_prod R S)) 
  ===> rel_prod2 S ===> C ===> rel_spmf (rel_prod R (rel_prod2 S)))
  (λoracle. oracle) extend_state_oracle"
unfolding extend_state_oracle_def[abs_def]
apply(rule rel_funI)+
apply clarsimp
apply(drule (1) rel_funD)+
apply(auto simp add: spmf_rel_map split_def dest: rel_funD intro: rel_spmf_mono)
done
end
(* † never changes the added left state component, so any predicate on that
   component alone is an invariant of the extended oracle. *)
lemma callee_invariant_extend_state_oracle_const [simp]:
  "callee_invariant †oracle (λ(s', s). I s')"
by unfold_locales auto
(* Same statement, with the predicate phrased via fst. *)
lemma callee_invariant_extend_state_oracle_const':
  "callee_invariant †oracle (λs. I (fst s))"
by unfold_locales auto
(* Lift a callee to one with optional answers by wrapping every answer in
   Some. The lifted callee never produces None. *)
definition lift_stop_oracle :: "('call, 'ret, 's) callee ⇒ ('call, 'ret option, 's) callee"
where "lift_stop_oracle oracle s x = map_spmf (apfst Some) (oracle s x)"
(* Unfolding rule, registered for the simplifier. *)
lemma lift_stop_oracle_apply [simp]: "lift_stop_oracle  oracle s x = map_spmf (apfst Some) (oracle s x)"
  by(fact lift_stop_oracle_def)
  
context includes lifting_syntax begin
(* Transfer rule relating an oracle to its lift_stop_oracle version. pcr_Some
   is defined elsewhere; presumably it relates a value to Some of a related
   value -- NOTE(review): confirm. *)
lemma lift_stop_oracle_transfer:
  "((S ===> C ===> rel_spmf (rel_prod R S)) ===> (S ===> C ===> rel_spmf (rel_prod (pcr_Some R) S)))
   (λx. x) lift_stop_oracle"
unfolding lift_stop_oracle_def
apply(rule rel_funI)+
apply(drule (1) rel_funD)+
apply(simp add: spmf_rel_map apfst_def prod.rel_map)
done
end
(* callee†: mirror image of †callee. The extra state component s' is added on
   the right and passed through unchanged; the callee only sees the left
   component. *)
definition extend_state_oracle2 :: "('call, 'ret, 's) callee ⇒ ('call, 'ret, 's × 's') callee" (‹_†› [1000] 1000)
  where "extend_state_oracle2 callee = (λ(s, s') x. map_spmf (λ(y, s). (y, (s, s'))) (callee s x))"
(* Simplification rule for an explicit pair state. *)
lemma extend_state_oracle2_simps [simp]:
  "extend_state_oracle2 callee (s, s') x = map_spmf (λ(y, s). (y, (s, s'))) (callee s x)"
  by(simp add: extend_state_oracle2_def)
(* Parametricity of extend_state_oracle2. *)
lemma extend_state_oracle2_parametric [transfer_rule]: includes lifting_syntax shows
  "((S ===> C ===> rel_spmf (rel_prod R S)) ===> rel_prod S S' ===> C ===> rel_spmf (rel_prod R (rel_prod S S')))
  extend_state_oracle2 extend_state_oracle2"
  unfolding extend_state_oracle2_def[abs_def] by transfer_prover
(* Any predicate on the added right component alone is an invariant of
   oracle†, because that component is never changed. *)
lemma callee_invariant_extend_state_oracle2_const [simp]:
  "callee_invariant oracle† (λ(s, s'). I s')"
  by unfold_locales auto
(* Same statement, with the predicate phrased via snd. *)
lemma callee_invariant_extend_state_oracle2_const':
  "callee_invariant oracle† (λs. I (snd s))"
  by unfold_locales auto
(* Right state extension distributes over the sum of oracles. *)
lemma extend_state_oracle2_plus_oracle: 
  "extend_state_oracle2 (plus_oracle oracle1 oracle2) = plus_oracle (extend_state_oracle2 oracle1) (extend_state_oracle2 oracle2)"
proof((rule ext)+; goal_cases)
  case (1 s q)
  then show ?case by (cases s; cases q) (simp_all add: apfst_def spmf.map_comp o_def split_def)
qed
(* Parallel composition is the shared-state sum in which each oracle ignores
   the other's component: oracle1 works on the left and oracle2 on the right
   component of the pair state. *)
lemma parallel_oracle_conv_plus_oracle:
  "parallel_oracle oracle1 oracle2 = plus_oracle (oracle1†) (†oracle2)"
proof((rule ext)+; goal_cases)
  case (1 s q)
  then show ?case by (cases s; cases q) (auto simp add: spmf.map_comp apfst_def o_def split_def map_prod_def)
qed
(* Pre-composing queries with map_sum f g and post-composing answers with
   map_sum h k distributes into the two components of a parallel oracle. *)
lemma map_sum_parallel_oracle: includes lifting_syntax shows
  "(id ---> map_sum f g ---> map_spmf (map_prod (map_sum h k) id)) (parallel_oracle oracle1 oracle2)
  = parallel_oracle ((id ---> f ---> map_spmf (map_prod h id)) oracle1) ((id ---> g ---> map_spmf (map_prod k id)) oracle2)"
proof((rule ext)+; goal_cases)
  case (1 s q)
  then show ?case by (cases s; cases q) (simp_all add: spmf.map_comp o_def apfst_def prod.map_comp)
qed
(* The same distribution law for the shared-state sum plus_oracle. *)
lemma map_sum_plus_oracle: includes lifting_syntax shows
  "(id ---> map_sum f g ---> map_spmf (map_prod (map_sum h k) id)) (plus_oracle oracle1 oracle2)
  = plus_oracle ((id ---> f ---> map_spmf (map_prod h id)) oracle1) ((id ---> g ---> map_spmf (map_prod k id)) oracle2)"
proof((rule ext)+; goal_cases)
  case (1 s q)
  then show ?case by (cases q) (simp_all add: spmf.map_comp o_def apfst_def prod.map_comp)
qed
(* Associativity of ⊕⇩O up to reshaping sums. Judging from the types,
   rsuml converts a left-nested query ('a + 'b) + 'c into the right-nested
   form expected by oracle1 ⊕⇩O (oracle2 ⊕⇩O oracle3), and lsumr converts
   the right-nested answer back into the left-nested form. *)
lemma map_rsuml_plus_oracle: includes lifting_syntax shows
  "(id ---> rsuml ---> (map_spmf (map_prod lsumr id))) (oracle1 ⊕⇩O (oracle2 ⊕⇩O oracle3)) =
   ((oracle1 ⊕⇩O oracle2) ⊕⇩O oracle3)"
proof((rule ext)+; goal_cases)
  case (1 s q)
  then show ?case 
  proof(cases q)
    (* A left-nested Inl query needs a further split into oracle1/oracle2. *)
    case (Inl ql)
    then show ?thesis by(cases ql)(simp_all add: spmf.map_comp o_def apfst_def prod.map_comp)
  qed (simp add: spmf.map_comp o_def apfst_def prod.map_comp id_def)
qed
(* The converse reassociation, from the left-nested to the right-nested sum. *)
lemma map_lsumr_plus_oracle: includes lifting_syntax shows
  "(id ---> lsumr ---> (map_spmf (map_prod rsuml id))) ((oracle1 ⊕⇩O oracle2) ⊕⇩O oracle3) =
   (oracle1 ⊕⇩O (oracle2 ⊕⇩O oracle3))"
proof((rule ext)+; goal_cases)
  case (1 s q)
  then show ?case 
  proof(cases q)
    (* A right-nested Inr query needs a further split into oracle2/oracle3. *)
    case (Inr qr)
    then show ?thesis by(cases qr)(simp_all add: spmf.map_comp o_def apfst_def prod.map_comp)
  qed (simp add: spmf.map_comp o_def apfst_def prod.map_comp id_def)
qed
context includes lifting_syntax begin
(* Lift an oracle transformer F, which works on oracles with state 's, to
   oracles with an additional state component 't. The 't component is frozen
   into the oracle (Pair t) and returned as part of the answer, reshaped by
   lprodr, so that F only manipulates the 's/'s' state. rprodl then reshapes
   the result back into the ('b × 't × 's') form. The reshaping directions
   follow from the types of the two arguments. *)
definition lift_state_oracle
  :: "(('s ⇒ 'a ⇒ (('b × 't) × 's) spmf) ⇒ ('s' ⇒ 'a ⇒ (('b × 't) × 's') spmf)) 
  ⇒ ('t × 's ⇒ 'a ⇒ ('b × 't × 's) spmf) ⇒ ('t × 's' ⇒ 'a ⇒ ('b × 't × 's') spmf)" where
  "lift_state_oracle F oracle = 
   (λ(t, s') a. map_spmf rprodl (F ((Pair t ---> id ---> map_spmf lprodr) oracle) s' a))"
(* Simplification rule for an explicit pair state. *)
lemma lift_state_oracle_simps [simp]:
  "lift_state_oracle F oracle (t, s') a = map_spmf rprodl (F ((Pair t ---> id ---> map_spmf lprodr) oracle) s' a)"
  by(simp add: lift_state_oracle_def)
(* Parametricity of lift_state_oracle. *)
lemma lift_state_oracle_parametric [transfer_rule]: includes lifting_syntax shows
  "(((S ===> A ===> rel_spmf (rel_prod (rel_prod B T) S)) ===> S' ===> A ===> rel_spmf (rel_prod (rel_prod B T) S'))
  ===> (rel_prod T S ===> A ===> rel_spmf (rel_prod B (rel_prod T S)))
  ===> rel_prod T S' ===> A ===> rel_spmf (rel_prod B (rel_prod T S')))
  lift_state_oracle lift_state_oracle"
  unfolding lift_state_oracle_def map_fun_def o_def by transfer_prover
(* If F is related to G uniformly in the answer relation B (the Transfer.Rel
   assumption), then lifting F over a †-extended oracle equals extending
   G oracle. *)
lemma lift_state_oracle_extend_state_oracle:
  includes lifting_syntax
  assumes "⋀B. Transfer.Rel (((=) ===> (=) ===> rel_spmf (rel_prod B (=))) ===> (=) ===> (=) ===> rel_spmf (rel_prod B (=))) G F"
    
  shows "lift_state_oracle F (extend_state_oracle oracle) = extend_state_oracle (G oracle)"
  unfolding lift_state_oracle_def extend_state_oracle_def
  apply(clarsimp simp add: fun_eq_iff map_fun_def o_def spmf.map_comp split_def rprodl_def)
  subgoal for t s a
    apply(rule sym)
    apply(fold spmf_rel_eq)
    apply(simp add: spmf_rel_map)
    apply(rule rel_spmf_mono)
     (* Instantiate B so that an answer x of G is related to the pair (x, t)
        produced on the F side. *)
     apply(rule assms[unfolded Rel_def, where B="λx (y, z). x = y ∧ z = t", THEN rel_funD, THEN rel_funD, THEN rel_funD])
       apply(auto simp add: rel_fun_def spmf_rel_map intro!: rel_spmf_reflI)
    done
  done
(* Lifting composes: lifting G and then F equals lifting F ∘ G. *)
lemma lift_state_oracle_compose: 
  "lift_state_oracle F (lift_state_oracle G oracle) = lift_state_oracle (F ∘ G) oracle"
  by(simp add: lift_state_oracle_def map_fun_def o_def split_def spmf.map_comp)
(* Lifting the identity transformer yields the identity. *)
lemma lift_state_oracle_id [simp]: "lift_state_oracle id = id"
  by(simp add: fun_eq_iff spmf.map_comp o_def)
(* Extending the state twice and reshaping the nested state with rprodl
   (and the answers' state back with lprodr) equals extending it once. *)
lemma rprodl_extend_state_oracle: includes lifting_syntax shows
  "(rprodl ---> id ---> map_spmf (map_prod id lprodr)) (extend_state_oracle (extend_state_oracle oracle)) = 
  extend_state_oracle oracle"
  by(simp add: fun_eq_iff spmf.map_comp o_def split_def)
end
section ‹Combining GPVs›
subsection ‹Shared state without interrupts›
(* Sum of two gpv-valued callees that share the state type 's and the
   call/response types. An Inl query runs left and an Inr query runs right;
   the result value is tagged with the matching injection. *)
context
  fixes left :: "'s ⇒ 'x1 ⇒ ('y1 × 's, 'call, 'ret) gpv"
  and right :: "'s ⇒ 'x2 ⇒ ('y2 × 's, 'call, 'ret) gpv"
begin
primrec plus_intercept :: "'s ⇒ 'x1 + 'x2 ⇒ (('y1 + 'y2) × 's, 'call, 'ret) gpv"
where
  "plus_intercept s (Inl x) = map_gpv (apfst Inl) id (left s x)"
| "plus_intercept s (Inr x) = map_gpv (apfst Inr) id (right s x)"
end
(* Parametricity of plus_intercept in the states, queries and results; the
   call relation C is shared by both callees. *)
lemma plus_intercept_parametric [transfer_rule]:
  includes lifting_syntax shows
  "((S ===> X1 ===> rel_gpv (rel_prod Y1 S) C)
  ===> (S ===> X2 ===> rel_gpv (rel_prod Y2 S) C)
  ===> S ===> rel_sum X1 X2 ===> rel_gpv (rel_prod (rel_sum Y1 Y2) S) C)
  plus_intercept plus_intercept"
unfolding plus_intercept_def[abs_def] by transfer_prover
(* Interaction bound for plus_intercept: the bound of the callee that answers
   the query (n on the left, m on the right). Registered as an
   [interaction_bound] rule. *)
lemma interaction_bounded_by_plus_intercept [interaction_bound]:
  fixes left right
  shows "⟦ ⋀x'. x = Inl x' ⟹ interaction_bounded_by P (left s x') (n x');
    ⋀y. x = Inr y ⟹ interaction_bounded_by P (right s y) (m y) ⟧
  ⟹ interaction_bounded_by P (plus_intercept left right s x) (case x of Inl x ⇒ n x | Inr y ⇒ m y)"
by(simp split!: sum.split add: interaction_bounded_by_map_gpv_id)
subsection ‹Shared state with interrupts›
(* Variant of plus_intercept for callees whose results are optional. The tag
   is applied inside the option with map_option, so a None result stays
   None. *)
context 
  fixes left :: "'s ⇒ 'x1 ⇒ ('y1 option × 's, 'call, 'ret) gpv"
  and right :: "'s ⇒ 'x2 ⇒ ('y2 option × 's, 'call, 'ret) gpv"
begin
primrec plus_intercept_stop :: "'s ⇒ 'x1 + 'x2 ⇒ (('y1 + 'y2) option × 's, 'call, 'ret) gpv"
where
  "plus_intercept_stop s (Inl x) = map_gpv (apfst (map_option Inl)) id (left s x)"
| "plus_intercept_stop s (Inr x) = map_gpv (apfst (map_option Inr)) id (right s x)"
end
(* Parametricity of plus_intercept_stop. *)
lemma plus_intercept_stop_parametric [transfer_rule]:
  includes lifting_syntax shows
  "((S ===> X1 ===> rel_gpv (rel_prod (rel_option Y1) S) C)
  ===> (S ===> X2 ===> rel_gpv (rel_prod (rel_option Y2) S) C)
  ===> S ===> rel_sum X1 X2 ===> rel_gpv (rel_prod (rel_option (rel_sum Y1 Y2)) S) C)
  plus_intercept_stop plus_intercept_stop"
unfolding plus_intercept_stop_def by transfer_prover
subsection ‹One-sided shifts›
(* left_gpv embeds a gpv into the left summand of a sum interface. Every call
   is tagged with Inl; a response Inl input' continues the original
   continuation, and a response from the other summand leads to Fail. The
   (transfer) option also generates the parametricity rule left_gpv.transfer. *)
primcorec (transfer) left_gpv :: "('a, 'out, 'in) gpv ⇒ ('a, 'out + 'out', 'in + 'in') gpv" where
  "the_gpv (left_gpv gpv) = 
   map_spmf (map_generat id Inl (λrpv input. case input of Inl input' ⇒ left_gpv (rpv input') | _ ⇒ Fail)) (the_gpv gpv)"
(* The continuation produced by left_gpv for a continuation rpv. *)
abbreviation left_rpv :: "('a, 'out, 'in) rpv ⇒ ('a, 'out + 'out', 'in + 'in') rpv" where
  "left_rpv rpv ≡ λinput. case input of Inl input' ⇒ left_gpv (rpv input') | _ ⇒ Fail"
(* right_gpv: mirror image of left_gpv for the right summand (tag Inr). *)
primcorec (transfer) right_gpv :: "('a, 'out, 'in) gpv ⇒ ('a, 'out' + 'out, 'in' + 'in) gpv" where
  "the_gpv (right_gpv gpv) =
   map_spmf (map_generat id Inr (λrpv input. case input of Inr input' ⇒ right_gpv (rpv input') | _ ⇒ Fail)) (the_gpv gpv)"
(* The continuation produced by right_gpv for a continuation rpv. *)
abbreviation right_rpv :: "('a, 'out, 'in) rpv ⇒ ('a, 'out' + 'out, 'in' + 'in) rpv" where
  "right_rpv rpv ≡ λinput. case input of Inr input' ⇒ right_gpv (rpv input') | _ ⇒ Fail"
(* Parametricity of left_gpv/right_gpv. The generated .transfer rules are
   re-exported under *_parametric names, and rel_gpv'' versions are proved
   via transfer_prover using the rules listed in the notes below. *)
context 
  includes lifting_syntax
  notes [transfer_rule] = corec_gpv_parametric' Fail_parametric' the_gpv_parametric'
begin
lemmas left_gpv_parametric = left_gpv.transfer
lemma left_gpv_parametric':
  "(rel_gpv'' A C R ===> rel_gpv'' A (rel_sum C C') (rel_sum R R')) left_gpv left_gpv"
  unfolding left_gpv_def by transfer_prover
lemmas right_gpv_parametric = right_gpv.transfer
lemma right_gpv_parametric':
  "(rel_gpv'' A C' R' ===> rel_gpv'' A (rel_sum C C') (rel_sum R R')) right_gpv right_gpv"
  unfolding right_gpv_def by transfer_prover
end
(* left_gpv/right_gpv leave terminated computations unchanged. *)
lemma left_gpv_Done [simp]: "left_gpv (Done x) = Done x"
  by(rule gpv.expand) simp
lemma right_gpv_Done [simp]: "right_gpv (Done x) = Done x"
  by(rule gpv.expand) simp
(* On a Pause, the call is tagged and the continuation is wrapped as in the
   corecursive definitions above. *)
lemma left_gpv_Pause [simp]:
  "left_gpv (Pause x rpv) = Pause (Inl x) (λinput. case input of Inl input' ⇒ left_gpv (rpv input') | _ ⇒ Fail)"
  by(rule gpv.expand) simp
lemma right_gpv_Pause [simp]:
  "right_gpv (Pause x rpv) = Pause (Inr x) (λinput. case input of Inr input' ⇒ right_gpv (rpv input') | _ ⇒ Fail)"
  by(rule gpv.expand) simp
(* left_gpv commutes with map_gpv. The map h on the other summand is
   arbitrary because left_gpv only issues Inl calls. Proved by instantiating
   the transfer rule with graph relations (BNF_Def.Grp). *)
lemma left_gpv_map: "left_gpv (map_gpv f g gpv) = map_gpv f (map_sum g h) (left_gpv gpv)"
  using left_gpv.transfer[of "BNF_Def.Grp UNIV f" "BNF_Def.Grp UNIV g" "BNF_Def.Grp UNIV h"]
  unfolding sum.rel_Grp gpv.rel_Grp
  by(auto simp add: rel_fun_def Grp_def)
(* Mirror image of left_gpv_map for right_gpv. *)
lemma right_gpv_map: "right_gpv (map_gpv f g gpv) = map_gpv f (map_sum h g) (right_gpv gpv)"
  using right_gpv.transfer[of "BNF_Def.Grp UNIV f" "BNF_Def.Grp UNIV g" "BNF_Def.Grp UNIV h"]
  unfolding sum.rel_Grp gpv.rel_Grp
  by(auto simp add: rel_fun_def Grp_def)
(* Embedding into one summand does not change the set results'_gpv. *)
lemma results'_gpv_left_gpv [simp]: 
  "results'_gpv (left_gpv gpv :: ('a, 'out + 'out', 'in + 'in') gpv) = results'_gpv gpv" (is "?lhs = ?rhs")
proof(rule Set.set_eqI iffI)+
  (* ⊆: induction over the derivation of results' of the embedded gpv. *)
  show "x ∈ ?rhs" if "x ∈ ?lhs" for x using that
    by(induction gpv'≡"left_gpv gpv :: ('a, 'out + 'out', 'in + 'in') gpv" arbitrary: gpv)
      (fastforce simp add: elim!: generat.set_cases intro: results'_gpvI split: sum.splits)+
  (* ⊇: induction over results' of gpv, choosing Inl responses. *)
  show "x ∈ ?lhs" if "x ∈ ?rhs" for x using that
    by(induction)
      (auto 4 3 elim!: generat.set_cases intro: results'_gpv_Pure rev_image_eqI results'_gpv_Cont[where input="Inl _"])
qed
(* Mirror image of results'_gpv_left_gpv for right_gpv. *)
lemma results'_gpv_right_gpv [simp]: 
  "results'_gpv (right_gpv gpv :: ('a, 'out' + 'out, 'in' + 'in) gpv) = results'_gpv gpv" (is "?lhs = ?rhs")
proof(rule Set.set_eqI iffI)+
  show "x ∈ ?rhs" if "x ∈ ?lhs" for x using that
    by(induction gpv'≡"right_gpv gpv :: ('a, 'out' + 'out, 'in' + 'in) gpv" arbitrary: gpv)
      (fastforce simp add: elim!: generat.set_cases intro: results'_gpvI split: sum.splits)+
  show "x ∈ ?lhs" if "x ∈ ?rhs" for x using that
    by(induction)
      (auto 4 3 elim!: generat.set_cases intro: results'_gpv_Pure rev_image_eqI results'_gpv_Cont[where input="Inr _"])
qed
(* left_gpv gpv is related to gpv when calls and responses on the left are
   the Inl images of those on the right. *)
lemma left_gpv_Inl_transfer: "rel_gpv'' (=) (λl r. l = Inl r) (λl r. l = Inl r) (left_gpv gpv) gpv"
  by(coinduction arbitrary: gpv)
    (auto simp add: spmf_rel_map generat.rel_map del: rel_funI intro!: rel_spmf_reflI generat.rel_refl_strong rel_funI)
(* Mirror image of left_gpv_Inl_transfer with Inr. *)
lemma right_gpv_Inr_transfer: "rel_gpv'' (=) (λl r. l = Inr r) (λl r. l = Inr r) (right_gpv gpv) gpv"
  by(coinduction arbitrary: gpv)
    (auto simp add: spmf_rel_map generat.rel_map del: rel_funI intro!: rel_spmf_reflI generat.rel_refl_strong rel_funI)
(* Running left_gpv gpv against a sum oracle only uses the left oracle, so it
   behaves exactly like running gpv against oracle1. Proved from the
   parametricity of exec_gpv with the Inl-relations above. *)
lemma exec_gpv_plus_oracle_left: "exec_gpv (plus_oracle oracle1 oracle2) (left_gpv gpv) s = exec_gpv oracle1 gpv s"
  unfolding spmf_rel_eq[symmetric] prod.rel_eq[symmetric]
  by(rule exec_gpv_parametric'[where A="(=)" and S="(=)" and CALL="λl r. l = Inl r" and R="λl r. l = Inl r", THEN rel_funD, THEN rel_funD, THEN rel_funD])
    (auto intro!: rel_funI simp add: spmf_rel_map apfst_def map_prod_def rel_prod_conv intro: rel_spmf_reflI left_gpv_Inl_transfer)
(* Mirror image for right_gpv and oracle2. *)
lemma exec_gpv_plus_oracle_right: "exec_gpv (plus_oracle oracle1 oracle2) (right_gpv gpv) s = exec_gpv oracle2 gpv s"
  unfolding spmf_rel_eq[symmetric] prod.rel_eq[symmetric]
  by(rule exec_gpv_parametric'[where A="(=)" and S="(=)" and CALL="λl r. l = Inr r" and R="λl r. l = Inr r", THEN rel_funD, THEN rel_funD, THEN rel_funD])
    (auto intro!: rel_funI simp add: spmf_rel_map apfst_def map_prod_def rel_prod_conv intro: rel_spmf_reflI right_gpv_Inr_transfer)
(* left_gpv distributes over bind_gpv. *)
lemma left_gpv_bind_gpv: "left_gpv (bind_gpv gpv f) = bind_gpv (left_gpv gpv) (left_gpv ∘ f)"
  by(coinduction arbitrary:gpv f rule: gpv.coinduct_strong)
    (auto 4 4 simp add: bind_map_spmf spmf_rel_map intro!: rel_spmf_reflI rel_spmf_bindI[of "(=)"] generat.rel_refl rel_funI split: sum.split)
(* inline1 with a left_gpv-embedded callee equals inline1 with the original
   callee, with the outcome reshaped: calls tagged Inl and continuations
   wrapped by left_rpv. Proved by parallel fixpoint induction over the two
   inline1 instances. *)
lemma inline1_left_gpv:
  "inline1 (λs q. left_gpv (callee s q)) gpv s = 
   map_spmf (map_sum id (map_prod Inl (map_prod left_rpv id))) (inline1 callee gpv s)"
proof(induction arbitrary: gpv s rule: parallel_fixp_induct_2_2[OF partial_function_definitions_spmf partial_function_definitions_spmf inline1.mono inline1.mono inline1_def inline1_def, unfolded lub_spmf_empty, case_names adm bottom step])
  case adm show ?case by simp
  case bottom show ?case by simp
  case (step inline1' inline1'')
  then show ?case
    by(auto simp add: map_spmf_bind_spmf o_def bind_map_spmf intro!: ext bind_spmf_cong split: generat.split)
qed
(* left_gpv commutes with inline: embedding the inlined result equals inlining
   with an embedded callee. *)
lemma left_gpv_inline: "left_gpv (inline callee gpv s) = inline (λs q. left_gpv (callee s q)) gpv s"
  by(coinduction arbitrary: callee gpv s rule: gpv_coinduct_bind)
    (fastforce simp add: inline_sel spmf_rel_map inline1_left_gpv left_gpv_bind_gpv o_def split_def intro!: rel_spmf_reflI split: sum.split intro!: rel_funI gpv.rel_refl_strong)
(* Mirror images of the three lemmas above for right_gpv. *)
lemma right_gpv_bind_gpv: "right_gpv (bind_gpv gpv f) = bind_gpv (right_gpv gpv) (right_gpv ∘ f)"
  by(coinduction arbitrary:gpv f rule: gpv.coinduct_strong)
    (auto 4 4 simp add: bind_map_spmf spmf_rel_map intro!: rel_spmf_reflI rel_spmf_bindI[of "(=)"] generat.rel_refl rel_funI split: sum.splits)
lemma inline1_right_gpv:
  "inline1 (λs q. right_gpv (callee s q)) gpv s = 
   map_spmf (map_sum id (map_prod Inr (map_prod right_rpv id))) (inline1 callee gpv s)"
proof(induction arbitrary: gpv s rule: parallel_fixp_induct_2_2[OF partial_function_definitions_spmf partial_function_definitions_spmf inline1.mono inline1.mono inline1_def inline1_def, unfolded lub_spmf_empty, case_names adm bottom step])
  case adm show ?case by simp
  case bottom show ?case by simp
  case (step inline1' inline1'')
  then show ?case
    by(auto simp add: map_spmf_bind_spmf o_def bind_map_spmf intro!: ext bind_spmf_cong split: generat.split)
qed
lemma right_gpv_inline: "right_gpv (inline callee gpv s) = inline (λs q. right_gpv (callee s q)) gpv s"
  by(coinduction arbitrary: callee gpv s rule: gpv_coinduct_bind)
    (fastforce simp add: inline_sel spmf_rel_map inline1_right_gpv right_gpv_bind_gpv o_def split_def intro!: rel_spmf_reflI split: sum.split intro!: rel_funI gpv.rel_refl_strong)
(* Well-typedness (WT_gpv) is preserved by embedding into the matching
   summand of a sum interface. *)
lemma WT_gpv_left_gpv: "ℐ1 ⊢g gpv √ ⟹ ℐ1 ⊕⇩ℐ ℐ2 ⊢g left_gpv gpv √"
  by(coinduction arbitrary: gpv)(auto 4 4 dest: WT_gpvD)
lemma WT_gpv_right_gpv: "ℐ2 ⊢g gpv √ ⟹ ℐ1 ⊕⇩ℐ ℐ2 ⊢g right_gpv gpv √"
  by(coinduction arbitrary: gpv)(auto 4 4 dest: WT_gpvD)
(* The results w.r.t. the sum interface of an embedded gpv are the results
   of the original gpv w.r.t. the corresponding summand interface. *)
lemma results_gpv_left_gpv [simp]: "results_gpv (ℐ1 ⊕⇩ℐ ℐ2) (left_gpv gpv) = results_gpv ℐ1 gpv"
  (is "?lhs = ?rhs")
proof(rule Set.set_eqI iffI)+
  show "x ∈ ?rhs" if "x ∈ ?lhs" for x using that
    by(induction gpv'≡"left_gpv gpv :: ('a, 'b + 'c, 'd + 'e) gpv" arbitrary: gpv rule: results_gpv.induct)
      (fastforce intro: results_gpv.intros)+
  show "x ∈ ?lhs" if "x ∈ ?rhs" for x using that
    by(induction)(fastforce intro: results_gpv.intros)+
qed
lemma results_gpv_right_gpv [simp]: "results_gpv (ℐ1 ⊕⇩ℐ ℐ2) (right_gpv gpv) = results_gpv ℐ2 gpv"
  (is "?lhs = ?rhs")
proof(rule Set.set_eqI iffI)+
  show "x ∈ ?rhs" if "x ∈ ?lhs" for x using that
    by(induction gpv'≡"right_gpv gpv :: ('a, 'b + 'c, 'd + 'e) gpv" arbitrary: gpv rule: results_gpv.induct)
      (fastforce intro: results_gpv.intros)+
  show "x ∈ ?lhs" if "x ∈ ?rhs" for x using that
    by(induction)(fastforce intro: results_gpv.intros)+
qed
(* Fail is a fixed point of both embeddings. *)
lemma left_gpv_Fail [simp]: "left_gpv Fail = Fail"
  by(rule gpv.expand) auto
lemma right_gpv_Fail [simp]: "right_gpv Fail = Fail"
  by(rule gpv.expand) auto
(* Reassociation of nested left_gpv/right_gpv embeddings. In map_gpv' id g h,
   judging from the types used here, g reshapes the calls (rsuml: left-nested
   to right-nested sum) and h converts the new responses back for the
   original gpv (lsumr). *)
lemma rsuml_lsumr_left_gpv_left_gpv:"map_gpv' id rsuml lsumr (left_gpv (left_gpv gpv)) = left_gpv gpv"
  by(coinduction arbitrary: gpv)
    (auto 4 3 simp add: spmf_rel_map generat.rel_map intro!: rel_spmf_reflI rel_generat_reflI rel_funI split!: sum.split elim!: lsumr.elims intro: exI[where x=Fail])
lemma rsuml_lsumr_left_gpv_right_gpv: "map_gpv' id rsuml lsumr (left_gpv (right_gpv gpv)) = right_gpv (left_gpv gpv)"
  by(coinduction arbitrary: gpv)
    (auto 4 3 simp add: spmf_rel_map generat.rel_map intro!: rel_spmf_reflI rel_generat_reflI rel_funI split!: sum.split elim!: lsumr.elims intro: exI[where x=Fail])
lemma rsuml_lsumr_right_gpv: "map_gpv' id rsuml lsumr (right_gpv gpv) = right_gpv (right_gpv gpv)"
  by(coinduction arbitrary: gpv)
    (auto 4 3 simp add: spmf_rel_map generat.rel_map intro!: rel_spmf_reflI rel_generat_reflI rel_funI split!: sum.split elim!: lsumr.elims intro: exI[where x=Fail])
(* A result map applied by map_gpv can be moved past map_gpv' onto its
   result argument. *)
lemma map_gpv'_map_gpv_swap:
  "map_gpv' f g h (map_gpv f' id gpv) = map_gpv (f ∘ f') id (map_gpv' id g h gpv)"
  by(simp add: map_gpv_conv_map_gpv' map_gpv'_comp)
(* The converse reassociations, using lsumr on calls and rsuml on responses. *)
lemma lsumr_rsuml_left_gpv: "map_gpv' id lsumr rsuml (left_gpv gpv) = left_gpv (left_gpv gpv)"
  by(coinduction arbitrary: gpv)
    (auto 4 3 simp add: spmf_rel_map generat.rel_map intro!: rel_spmf_reflI rel_generat_reflI rel_funI split!: sum.split intro: exI[where x=Fail])
lemma lsumr_rsuml_right_gpv_left_gpv:
  "map_gpv' id lsumr rsuml (right_gpv (left_gpv gpv)) = left_gpv (right_gpv gpv)"
  by(coinduction arbitrary: gpv)
    (auto 4 3 simp add: spmf_rel_map generat.rel_map intro!: rel_spmf_reflI rel_generat_reflI rel_funI split!: sum.split intro: exI[where x=Fail])
lemma lsumr_rsuml_right_gpv_right_gpv:
  "map_gpv' id lsumr rsuml (right_gpv (right_gpv gpv)) = right_gpv gpv"
  by(coinduction arbitrary: gpv)
    (auto 4 3 simp add: spmf_rel_map generat.rel_map intro!: rel_spmf_reflI rel_generat_reflI rel_funI split!: sum.split elim!: rsuml.elims intro: exI[where x=Fail])
(* Support of the †-extended oracle: the added component is unchanged, and
   the remaining (answer, state) pair is in the support of the original
   oracle. *)
lemma in_set_spmf_extend_state_oracle [simp]:
  "x ∈ set_spmf (extend_state_oracle oracle s y) ⟷
   fst (snd x) = fst s ∧ (fst x, snd (snd x)) ∈ set_spmf (oracle (snd s) y)"
  by(auto 4 4 simp add: extend_state_oracle_def split_beta intro: rev_image_eqI prod.expand)
(* Left state extension distributes over the sum of oracles. *)
lemma extend_state_oracle_plus_oracle: 
  "extend_state_oracle (plus_oracle oracle1 oracle2) = plus_oracle (extend_state_oracle oracle1) (extend_state_oracle oracle2)"
proof ((rule ext)+; goal_cases)
  case (1 s q)
  then show ?case by (cases s; cases q) (simp_all add: apfst_def spmf.map_comp o_def split_def)
qed
(* Turn a stateless gpv-valued callee into a stateful one. The given state s
   is returned unchanged alongside every result. *)
definition stateless_callee :: "('a ⇒ ('b, 'out, 'in) gpv) ⇒ ('s ⇒ 'a ⇒ ('b × 's, 'out, 'in) gpv)" where
  "stateless_callee callee s = map_gpv (λb. (b, s)) id ∘ callee"
(* Parametricity of stateless_callee for rel_gpv''. *)
lemma stateless_callee_parametric': 
  includes lifting_syntax notes [transfer_rule] = map_gpv_parametric' shows
    "((A ===> rel_gpv'' B C R) ===> S ===> A ===> (rel_gpv'' (rel_prod B S) C R))
   stateless_callee stateless_callee"
  unfolding stateless_callee_def by transfer_prover
(* id_oracle is the stateless callee that issues each query x as a call
   (Pause x Done) and returns the response as its result.
   The lemma name contains a typo ("oralce"). It is kept for backward
   compatibility, and the correctly spelled alias id_oracle_alt_def is added
   below. *)
lemma id_oralce_alt_def: "id_oracle = stateless_callee (λx. Pause x Done)"
  by(simp add: id_oracle_def fun_eq_iff stateless_callee_def)
lemmas id_oracle_alt_def = id_oralce_alt_def
(* Parallel composition of two gpv-valued callees with disjoint states and
   disjoint call interfaces. An Inl query runs left on the first state
   component and embeds its calls with left_gpv; an Inr query runs right on
   the second component and embeds its calls with right_gpv. The untouched
   state component is carried along unchanged. *)
context
  fixes left :: "'s1 ⇒ 'x1 ⇒ ('y1 × 's1, 'call1, 'ret1) gpv"
    and right :: "'s2 ⇒ 'x2 ⇒ ('y2 × 's2, 'call2, 'ret2) gpv"
begin
fun parallel_intercept :: "'s1 × 's2 ⇒ 'x1 + 'x2 ⇒ (('y1 + 'y2) × ('s1 × 's2), 'call1 + 'call2, 'ret1 + 'ret2) gpv"
  where
    "parallel_intercept (s1, s2) (Inl a) = left_gpv (map_gpv (map_prod Inl (λs1'. (s1', s2))) id (left s1 a))"
  | "parallel_intercept (s1, s2) (Inr b) = right_gpv (map_gpv (map_prod Inr (Pair s1)) id (right s2 b))"
end
end