(* Theory GPV_Bisim *)
theory GPV_Bisim imports
GPV_Expectation
begin
subsection ‹Bisimulation for oracles›
text ‹Bisimulation is a consequence of parametricity›
(* Bisimulation from parametricity of exec_gpv: if the two oracle states start out
   related by X, and on X-related states the oracles return equal answers together
   with X-related successor states, then executing the same gpv against both oracles
   yields equal results together with X-related final states. *)
lemma exec_gpv_oracle_bisim':
  assumes init: "X s1 s2"
    and sim: "⋀s1 s2 x. X s1 s2 ⟹ rel_spmf (λ(a, s1') (b, s2'). a = b ∧ X s1' s2') (oracle1 s1 x) (oracle2 s2 x)"
  shows "rel_spmf (λ(a, s1') (b, s2'). a = b ∧ X s1' s2') (exec_gpv oracle1 gpv s1) (exec_gpv oracle2 gpv s2)"
proof -
  (* Parametricity theorem of exec_gpv, instantiated with X and two equality relations,
     with the resulting relators on gpvs and pairs unfolded. *)
  note param = exec_gpv_parametric[of X "(=)" "(=)", unfolded gpv.rel_eq rel_prod_conv]
  (* Peel off the three function relators; the remaining oracle obligation follows from sim. *)
  show ?thesis
    by(rule param[THEN rel_funD, THEN rel_funD, THEN rel_funD, OF rel_funI refl, OF rel_funI init])
      (simp add: sim)
qed
(* Like exec_gpv_oracle_bisim', but with an arbitrary target relation R: it suffices that
   R holds for all pairs of results with equal first component and X-related states that
   actually occur in the supports of the two executions. *)
lemma exec_gpv_oracle_bisim:
  assumes init: "X s1 s2"
    and sim: "⋀s1 s2 x. X s1 s2 ⟹ rel_spmf (λ(a, s1') (b, s2'). a = b ∧ X s1' s2') (oracle1 s1 x) (oracle2 s2 x)"
    and post: "⋀x s1' s2'. ⟦ X s1' s2'; (x, s1') ∈ set_spmf (exec_gpv oracle1 gpv s1); (x, s2') ∈ set_spmf (exec_gpv oracle2 gpv s2) ⟧ ⟹ R (x, s1') (x, s2')"
  shows "rel_spmf R (exec_gpv oracle1 gpv s1) (exec_gpv oracle2 gpv s2)"
  (* Weaken the relation obtained from exec_gpv_oracle_bisim' to R on the supports. *)
  by(rule spmf_rel_mono_strong, rule exec_gpv_oracle_bisim'[OF init sim])
    (auto dest: post)
(* Bisimilar oracles are indistinguishable for run_gpv: under the hypotheses of
   exec_gpv_oracle_bisim', running the same gpv against either oracle gives the same spmf. *)
lemma run_gpv_oracle_bisim:
  assumes init: "X s1 s2"
    and sim: "⋀s1 s2 x. X s1 s2 ⟹ rel_spmf (λ(a, s1') (b, s2'). a = b ∧ X s1' s2') (oracle1 s1 x) (oracle2 s2 x)"
  shows "run_gpv oracle1 gpv s1 = run_gpv oracle2 gpv s2"
  using exec_gpv_oracle_bisim'[OF init sim]
proof(fold spmf_rel_eq)
  (* The equation is now stated with rel_spmf (=); the chained bisimulation fact
     is weakened to it by the terminal method. *)
qed(fastforce simp add: spmf_rel_map intro: rel_spmf_mono)
(* Local context for the joint execution: two stateful oracles oracle1 / oracle2 with
   their "bad" predicates bad1 / bad2 on states, and a joint oracle that, for a pair of
   states and a query, returns a pair of (answer, successor state) results. *)
context
fixes joint_oracle :: "('s1 × 's2) ⇒ 'a ⇒ (('b × 's1) × ('b × 's2)) spmf"
and oracle1 :: "'s1 ⇒ 'a ⇒ ('b × 's1) spmf"
and bad1 :: "'s1 ⇒ bool"
and oracle2 :: "'s2 ⇒ 'a ⇒ ('b × 's2) spmf"
and bad2 :: "'s2 ⇒ bool"
begin
(* exec_until_bad gpv s1 s2 executes gpv against both oracles at once:
   - if bad1 s1 or bad2 s2 holds, the two executions are run independently (pair_spmf);
   - otherwise one step of gpv is taken; a Pure result x is returned on both sides with
     the unchanged states; an IO step queries joint_oracle on the state pair, and then
     either continues independently (if one of the successor states is bad, using the
     respective answers x and y) or recurses with the first answer x. *)
partial_function (spmf) exec_until_bad :: "('x, 'a, 'b) gpv ⇒ 's1 ⇒ 's2 ⇒ (('x × 's1) × ('x × 's2)) spmf"
where
"exec_until_bad gpv s1 s2 =
(if bad1 s1 ∨ bad2 s2 then pair_spmf (exec_gpv oracle1 gpv s1) (exec_gpv oracle2 gpv s2)
else bind_spmf (the_gpv gpv) (λgenerat.
case generat of Pure x ⇒ return_spmf ((x, s1), (x, s2))
| IO out f ⇒ bind_spmf (joint_oracle (s1, s2) out) (λ((x, s1'), (y, s2')).
if bad1 s1' ∨ bad2 s2' then pair_spmf (exec_gpv oracle1 (f x) s1') (exec_gpv oracle2 (f y) s2')
else exec_until_bad (f x) s1' s2')))"
(* Fixpoint induction rule for exec_until_bad in curried form, with named cases
   adm (admissibility of P), bottom (P holds for the everywhere-failing function)
   and step (P is preserved by one unfolding of the defining equation).
   It is the generated rule exec_until_bad.fixp_induct with curry_conv unfolded. *)
lemma exec_until_bad_fixp_induct [case_names adm bottom step]:
assumes "ccpo.admissible (fun_lub lub_spmf) (fun_ord (ord_spmf (=))) (λf. P (λgpv s1 s2. f ((gpv, s1), s2)))"
and "P (λ_ _ _. return_pmf None)"
and "⋀exec_until_bad'. P exec_until_bad' ⟹
P (λgpv s1 s2. if bad1 s1 ∨ bad2 s2 then pair_spmf (exec_gpv oracle1 gpv s1) (exec_gpv oracle2 gpv s2)
else bind_spmf (the_gpv gpv) (λgenerat.
case generat of Pure x ⇒ return_spmf ((x, s1), (x, s2))
| IO out f ⇒ bind_spmf (joint_oracle (s1, s2) out) (λ((x, s1'), (y, s2')).
if bad1 s1' ∨ bad2 s2' then pair_spmf (exec_gpv oracle1 (f x) s1') (exec_gpv oracle2 (f y) s2')
else exec_until_bad' (f x) s1' s2')))"
shows "P exec_until_bad"
using assms by(rule exec_until_bad.fixp_induct[unfolded curry_conv[abs_def]])
end
(* Bisimulation "up to bad" for a plossless gpv.
   Hypotheses (see the assumes below):
   - the start states are related by X, or by X_bad if bad2 s2 holds, and agree on badness;
   - bisim: on X-related states and queries in outs_ℐ ℐ, the oracle results are related so
     that the successor states agree on badness, are X_bad-related when bad, and otherwise
     carry equal answers and X-related states;
   - bad_sticky1/2: for a fixed bad state of the other side, "bad ∧ X_bad" is a callee
     invariant of the oracle;
   - lossless1/2: the oracles are lossless in bad states on queries in outs_ℐ ℐ;
   - gpv is plossless and well-typed, and both oracles are well-typed callees.
   Conclusion: the two executions of gpv are related by the same relation as in bisim. *)
lemma exec_gpv_oracle_bisim_bad_plossless:
fixes s1 :: 's1 and s2 :: 's2 and X :: "'s1 ⇒ 's2 ⇒ bool"
and oracle1 :: "'s1 ⇒ 'a ⇒ ('b × 's1) spmf"
and oracle2 :: "'s2 ⇒ 'a ⇒ ('b × 's2) spmf"
assumes *: "if bad2 s2 then X_bad s1 s2 else X s1 s2"
and bad: "bad1 s1 = bad2 s2"
and bisim: "⋀s1 s2 x. ⟦ X s1 s2; x ∈ outs_ℐ ℐ ⟧ ⟹ rel_spmf (λ(a, s1') (b, s2'). bad1 s1' = bad2 s2' ∧ (if bad2 s2' then X_bad s1' s2' else a = b ∧ X s1' s2')) (oracle1 s1 x) (oracle2 s2 x)"
and bad_sticky1: "⋀s2. bad2 s2 ⟹ callee_invariant_on oracle1 (λs1. bad1 s1 ∧ X_bad s1 s2) ℐ"
and bad_sticky2: "⋀s1. bad1 s1 ⟹ callee_invariant_on oracle2 (λs2. bad2 s2 ∧ X_bad s1 s2) ℐ"
and lossless1: "⋀s1 x. ⟦ bad1 s1; x ∈ outs_ℐ ℐ ⟧ ⟹ lossless_spmf (oracle1 s1 x)"
and lossless2: "⋀s2 x. ⟦ bad2 s2; x ∈ outs_ℐ ℐ ⟧ ⟹ lossless_spmf (oracle2 s2 x)"
and lossless: "plossless_gpv ℐ gpv"
and WT_oracle1: "⋀s1. ℐ ⊢c oracle1 s1 √"
and WT_oracle2: "⋀s2. ℐ ⊢c oracle2 s2 √"
and WT_gpv: "ℐ ⊢g gpv √"
shows "rel_spmf (λ(a, s1') (b, s2'). bad1 s1' = bad2 s2' ∧ (if bad2 s2' then X_bad s1' s2' else a = b ∧ X s1' s2')) (exec_gpv oracle1 gpv s1) (exec_gpv oracle2 gpv s2)"
(is "rel_spmf ?R ?p ?q")
proof -
(* ?R' spells out the relation from the goal once more for use in the statements below. *)
let ?R' = "λ(a, s1') (b, s2'). bad1 s1' = bad2 s2' ∧ (if bad2 s2' then X_bad s1' s2' else a = b ∧ X s1' s2')"
(* Choose, uniformly in s1 s2 x, a joint distribution witnessing bisim. The obtained
   joint_oracle has oracle1 and oracle2 as marginals (facts oracle1 / oracle2, stored in
   symmetric orientation) and its support satisfies the relation (fact 3). *)
from bisim have "∀s1 s2. ∀x ∈ outs_ℐ ℐ. X s1 s2 ⟶ rel_spmf ?R' (oracle1 s1 x) (oracle2 s2 x)" by blast
then obtain joint_oracle
where oracle1 [symmetric]: "⋀s1 s2 x. ⟦ X s1 s2; x ∈ outs_ℐ ℐ ⟧ ⟹ map_spmf fst (joint_oracle s1 s2 x) = oracle1 s1 x"
and oracle2 [symmetric]: "⋀s1 s2 x. ⟦ X s1 s2; x ∈ outs_ℐ ℐ ⟧ ⟹ map_spmf snd (joint_oracle s1 s2 x) = oracle2 s2 x"
and 3 [rotated 2]: "⋀s1 s2 x y y' s1' s2'. ⟦ X s1 s2; x ∈ outs_ℐ ℐ; ((y, s1'), (y', s2')) ∈ set_spmf (joint_oracle s1 s2 x) ⟧
⟹ bad1 s1' = bad2 s2' ∧ (if bad2 s2' then X_bad s1' s2' else y = y' ∧ X s1' s2')"
apply atomize_elim
apply(unfold rel_spmf_simps all_conj_distrib[symmetric] all_simps(6) imp_conjR[symmetric])
apply(subst choice_iff[symmetric] ex_simps(6))+
apply fastforce
done
(* ?pq is the candidate coupling: the joint execution until a bad state is reached. *)
let ?joint_oracle = "λ(s1, s2). joint_oracle s1 s2"
let ?pq = "exec_until_bad ?joint_oracle oracle1 bad1 oracle2 bad2 gpv s1 s2"
(* Elements of the support of the joint oracle project into the supports of oracle1 and oracle2. *)
have setD: "⋀s1 s2 x y y' s1' s2'. ⟦ X s1 s2; x ∈ outs_ℐ ℐ; ((y, s1'), (y', s2')) ∈ set_spmf (joint_oracle s1 s2 x) ⟧
⟹ (y, s1') ∈ set_spmf (oracle1 s1 x) ∧ (y', s2') ∈ set_spmf (oracle2 s2 x)"
unfolding oracle1 oracle2 by(auto intro: rev_image_eqI)
(* The goal is shown with ?pq as witness: first marginal, second marginal, then the support. *)
show ?thesis
proof
(* First marginal of ?pq equals the execution with oracle1; shown by antisymmetry of ord_spmf (=). *)
show "map_spmf fst ?pq = exec_gpv oracle1 gpv s1"
proof(rule spmf.leq_antisym)
(* Direction "≤": fixpoint induction on exec_until_bad. *)
show "ord_spmf (=) (map_spmf fst ?pq) (exec_gpv oracle1 gpv s1)" using * bad WT_gpv lossless
proof(induction arbitrary: s1 s2 gpv rule: exec_until_bad_fixp_induct)
case adm show ?case by simp
case bottom show ?case by simp
case (step exec_until_bad')
show ?case
proof(cases "bad2 s2")
case True
(* Already bad: the independent execution with oracle2 has weight 1. *)
then have "weight_spmf (exec_gpv oracle2 gpv s2) = 1"
using callee_invariant_on.weight_exec_gpv[OF bad_sticky2 lossless2, of s1 gpv s2]
step.prems weight_spmf_le_1[of "exec_gpv oracle2 gpv s2"]
by(simp add: pgen_lossless_gpv_def weight_gpv_def)
then show ?thesis using True by simp
next
case False
hence "¬ bad1 s1" using step.prems(2) by simp
(* Facts about one joint oracle call from non-bad states: if the successor is bad, the
   oracle2 continuation has weight 1; otherwise the induction hypothesis applies. *)
moreover {
fix out c r1 s1' r2 s2'
assume IO: "IO out c ∈ set_spmf (the_gpv gpv)"
and joint: "((r1, s1'), (r2, s2')) ∈ set_spmf (joint_oracle s1 s2 out)"
from step.prems(3) IO have out: "out ∈ outs_ℐ ℐ" by(rule WT_gpvD)
from setD[OF _ out joint] step.prems(1) False
have 1: "(r1, s1') ∈ set_spmf (oracle1 s1 out)"
and 2: "(r2, s2') ∈ set_spmf (oracle2 s2 out)" by simp_all
hence r1: "r1 ∈ responses_ℐ ℐ out" and r2: "r2 ∈ responses_ℐ ℐ out"
using WT_oracle1 WT_oracle2 out by(blast dest: WT_calleeD)+
have *: "plossless_gpv ℐ (c r2)" using step.prems(4) IO r2 step.prems(3)
by(rule plossless_gpv_ContD)
then have "bad2 s2' ⟹ weight_spmf (exec_gpv oracle2 (c r2) s2') = 1"
and "¬ bad2 s2' ⟹ ord_spmf (=) (map_spmf fst (exec_until_bad' (c r2) s1' s2')) (exec_gpv oracle1 (c r2) s1')"
using callee_invariant_on.weight_exec_gpv[OF bad_sticky2 lossless2, of s1' "c r2" s2']
weight_spmf_le_1[of "exec_gpv oracle2 (c r2) s2'"] WT_gpv_ContD[OF step.prems(3) IO r2]
3[OF joint _ out] step.prems(1) False
by(simp_all add: pgen_lossless_gpv_def weight_gpv_def step.IH) }
ultimately show ?thesis using False step.prems(1)
by(rewrite in "ord_spmf _ _ ⌑" exec_gpv.simps)
(fastforce simp add: split_def bind_map_spmf map_spmf_bind_spmf oracle1 WT_gpv_OutD[OF step.prems(3)] intro!: ord_spmf_bind_reflI split!: generat.split dest: 3)
qed
qed
(* Direction "≥": fixpoint induction on exec_gpv. *)
show "ord_spmf (=) (exec_gpv oracle1 gpv s1) (map_spmf fst ?pq)" using * bad WT_gpv lossless
proof(induction arbitrary: gpv s1 s2 rule: exec_gpv_fixp_induct_strong)
case adm show ?case by simp
case bottom show ?case by simp
case (step exec_gpv')
then show ?case
proof(cases "bad2 s2")
case True
then have "weight_spmf (exec_gpv oracle2 gpv s2) = 1"
using callee_invariant_on.weight_exec_gpv[OF bad_sticky2 lossless2, of s1 gpv s2]
step.prems weight_spmf_le_1[of "exec_gpv oracle2 gpv s2"]
by(simp add: pgen_lossless_gpv_def weight_gpv_def)
then show ?thesis using True
by(rewrite exec_until_bad.simps; rewrite exec_gpv.simps)
(clarsimp intro!: ord_spmf_bind_reflI split!: generat.split simp add: step.hyps)
next
case False
hence "¬ bad1 s1" using step.prems(2) by simp
moreover {
fix out c r1 s1' r2 s2'
assume IO: "IO out c ∈ set_spmf (the_gpv gpv)"
and joint: "((r1, s1'), (r2, s2')) ∈ set_spmf (joint_oracle s1 s2 out)"
from step.prems(3) IO have out: "out ∈ outs_ℐ ℐ" by(rule WT_gpvD)
from setD[OF _ out joint] step.prems(1) False
have 1: "(r1, s1') ∈ set_spmf (oracle1 s1 out)"
and 2: "(r2, s2') ∈ set_spmf (oracle2 s2 out)" by simp_all
hence r1: "r1 ∈ responses_ℐ ℐ out" and r2: "r2 ∈ responses_ℐ ℐ out"
using WT_oracle1 WT_oracle2 out by(blast dest: WT_calleeD)+
have *: "plossless_gpv ℐ (c r2)" using step.prems(4) IO r2 step.prems(3)
by(rule plossless_gpv_ContD)
then have "bad2 s2' ⟹ weight_spmf (exec_gpv oracle2 (c r2) s2') = 1"
and "¬ bad2 s2' ⟹ ord_spmf (=) (exec_gpv' (c r2) s1') (map_spmf fst (exec_until_bad (λ(x, y). joint_oracle x y) oracle1 bad1 oracle2 bad2 (c r2) s1' s2'))"
using callee_invariant_on.weight_exec_gpv[OF bad_sticky2 lossless2, of s1' "c r2" s2']
weight_spmf_le_1[of "exec_gpv oracle2 (c r2) s2'"] WT_gpv_ContD[OF step.prems(3) IO r2]
3[OF joint _ out] step.prems(1) False
by(simp_all add: pgen_lossless_gpv_def weight_gpv_def step.IH) }
ultimately show ?thesis using False step.prems(1)
by(rewrite exec_until_bad.simps)
(fastforce simp add: map_spmf_bind_spmf WT_gpv_OutD[OF step.prems(3)] oracle1 bind_map_spmf step.hyps intro!: ord_spmf_bind_reflI split!: generat.split dest: 3)
qed
qed
qed
(* Second marginal of ?pq equals the execution with oracle2; same structure as above with
   the roles of the two sides exchanged (bad_sticky1 / lossless1 give weight 1 for oracle1). *)
show "map_spmf snd ?pq = exec_gpv oracle2 gpv s2"
proof(rule spmf.leq_antisym)
(* Direction "≤": fixpoint induction on exec_until_bad. *)
show "ord_spmf (=) (map_spmf snd ?pq) (exec_gpv oracle2 gpv s2)" using * bad WT_gpv lossless
proof(induction arbitrary: s1 s2 gpv rule: exec_until_bad_fixp_induct)
case adm show ?case by simp
case bottom show ?case by simp
case (step exec_until_bad')
show ?case
proof(cases "bad2 s2")
case True
then have "weight_spmf (exec_gpv oracle1 gpv s1) = 1"
using callee_invariant_on.weight_exec_gpv[OF bad_sticky1 lossless1, of s2 gpv s1]
step.prems weight_spmf_le_1[of "exec_gpv oracle1 gpv s1"]
by(simp add: pgen_lossless_gpv_def weight_gpv_def)
then show ?thesis using True by simp
next
case False
hence "¬ bad1 s1" using step.prems(2) by simp
moreover {
fix out c r1 s1' r2 s2'
assume IO: "IO out c ∈ set_spmf (the_gpv gpv)"
and joint: "((r1, s1'), (r2, s2')) ∈ set_spmf (joint_oracle s1 s2 out)"
from step.prems(3) IO have out: "out ∈ outs_ℐ ℐ" by(rule WT_gpvD)
from setD[OF _ out joint] step.prems(1) False
have 1: "(r1, s1') ∈ set_spmf (oracle1 s1 out)"
and 2: "(r2, s2') ∈ set_spmf (oracle2 s2 out)" by simp_all
hence r1: "r1 ∈ responses_ℐ ℐ out" and r2: "r2 ∈ responses_ℐ ℐ out"
using WT_oracle1 WT_oracle2 out by(blast dest: WT_calleeD)+
have *: "plossless_gpv ℐ (c r1)" using step.prems(4) IO r1 step.prems(3)
by(rule plossless_gpv_ContD)
then have "bad2 s2' ⟹ weight_spmf (exec_gpv oracle1 (c r1) s1') = 1"
and "¬ bad2 s2' ⟹ ord_spmf (=) (map_spmf snd (exec_until_bad' (c r2) s1' s2')) (exec_gpv oracle2 (c r2) s2')"
using callee_invariant_on.weight_exec_gpv[OF bad_sticky1 lossless1, of s2' "c r1" s1']
weight_spmf_le_1[of "exec_gpv oracle1 (c r1) s1'"] WT_gpv_ContD[OF step.prems(3) IO r1]
3[OF joint _ out] step.prems(1) False
by(simp_all add: pgen_lossless_gpv_def weight_gpv_def step.IH) }
ultimately show ?thesis using False step.prems(1)
by(rewrite in "ord_spmf _ _ ⌑" exec_gpv.simps)
(fastforce simp add: split_def bind_map_spmf map_spmf_bind_spmf oracle2 WT_gpv_OutD[OF step.prems(3)] intro!: ord_spmf_bind_reflI split!: generat.split dest: 3)
qed
qed
(* Direction "≥": fixpoint induction on exec_gpv. *)
show "ord_spmf (=) (exec_gpv oracle2 gpv s2) (map_spmf snd ?pq)" using * bad WT_gpv lossless
proof(induction arbitrary: gpv s1 s2 rule: exec_gpv_fixp_induct_strong)
case adm show ?case by simp
case bottom show ?case by simp
case (step exec_gpv')
then show ?case
proof(cases "bad2 s2")
case True
then have "weight_spmf (exec_gpv oracle1 gpv s1) = 1"
using callee_invariant_on.weight_exec_gpv[OF bad_sticky1 lossless1, of s2 gpv s1]
step.prems weight_spmf_le_1[of "exec_gpv oracle1 gpv s1"]
by(simp add: pgen_lossless_gpv_def weight_gpv_def)
then show ?thesis using True
by(rewrite exec_until_bad.simps; subst (2) exec_gpv.simps)
(clarsimp intro!: ord_spmf_bind_reflI split!: generat.split simp add: step.hyps)
next
case False
hence "¬ bad1 s1" using step.prems(2) by simp
moreover {
fix out c r1 s1' r2 s2'
assume IO: "IO out c ∈ set_spmf (the_gpv gpv)"
and joint: "((r1, s1'), (r2, s2')) ∈ set_spmf (joint_oracle s1 s2 out)"
from step.prems(3) IO have out: "out ∈ outs_ℐ ℐ" by(rule WT_gpvD)
from setD[OF _ out joint] step.prems(1) False
have 1: "(r1, s1') ∈ set_spmf (oracle1 s1 out)"
and 2: "(r2, s2') ∈ set_spmf (oracle2 s2 out)" by simp_all
hence r1: "r1 ∈ responses_ℐ ℐ out" and r2: "r2 ∈ responses_ℐ ℐ out"
using WT_oracle1 WT_oracle2 out by(blast dest: WT_calleeD)+
have *: "plossless_gpv ℐ (c r1)" using step.prems(4) IO r1 step.prems(3)
by(rule plossless_gpv_ContD)
then have "bad2 s2' ⟹ weight_spmf (exec_gpv oracle1 (c r1) s1') = 1"
and "¬ bad2 s2' ⟹ ord_spmf (=) (exec_gpv' (c r2) s2') (map_spmf snd (exec_until_bad (λ(x, y). joint_oracle x y) oracle1 bad1 oracle2 bad2 (c r2) s1' s2'))"
using callee_invariant_on.weight_exec_gpv[OF bad_sticky1 lossless1, of s2' "c r1" s1']
weight_spmf_le_1[of "exec_gpv oracle1 (c r1) s1'"] WT_gpv_ContD[OF step.prems(3) IO r1]
3[OF joint _ out] step.prems(1) False
by(simp_all add: pgen_lossless_gpv_def step.IH weight_gpv_def) }
ultimately show ?thesis using False step.prems(1)
by(rewrite exec_until_bad.simps)
(fastforce simp add: map_spmf_bind_spmf WT_gpv_OutD[OF step.prems(3)] oracle2 bind_map_spmf step.hyps intro!: ord_spmf_bind_reflI split!: generat.split dest: 3)
qed
qed
qed
(* Support: every pair in the support of ?pq satisfies ?R'; fixpoint induction on exec_until_bad. *)
have "set_spmf ?pq ⊆ {(as1, bs2). ?R' as1 bs2}" using * bad WT_gpv
proof(induction arbitrary: gpv s1 s2 rule: exec_until_bad_fixp_induct)
case adm show ?case by(intro cont_intro ccpo_class.admissible_leI)
case bottom show ?case by simp
case step
(* switch: if a joint oracle call from non-bad, X-related states ends in a bad state,
   then all pairs of results of the two independent continuations satisfy the relation;
   this uses the invariants bad_sticky1 and bad_sticky2 one after the other. *)
have switch: "set_spmf (exec_gpv oracle1 (c r1) s1') × set_spmf (exec_gpv oracle2 (c r2) s2')
⊆ {((a, s1'), b, s2'). bad1 s1' = bad2 s2' ∧ (if bad2 s2' then X_bad s1' s2' else a = b ∧ X s1' s2')}"
if "¬ bad1 s1" "ℐ ⊢g gpv √" "¬ bad2 s2" and X: "X s1 s2" and out: "IO out c ∈ set_spmf (the_gpv gpv)"
and joint: "((r1, s1'), (r2, s2')) ∈ set_spmf (joint_oracle s1 s2 out)"
and bad2: "bad2 s2'"
for out c r1 s1' r2 s2'
proof(clarify; rule conjI)
from step.prems(3) out have outs: "out ∈ outs_ℐ ℐ" by(rule WT_gpv_OutD)
from bad2 3[OF joint X this] have bad1: "bad1 s1' ∧ X_bad s1' s2'" by simp_all
have s1': "(r1, s1') ∈ set_spmf (oracle1 s1 out)" and s2': "(r2, s2') ∈ set_spmf (oracle2 s2 out)"
using setD[OF X outs joint] by simp_all
have resp: "r1 ∈ responses_ℐ ℐ out" using WT_oracle1 s1' outs by(rule WT_calleeD)
with step.prems(3) out have WT1: "ℐ ⊢g c r1 √" by(rule WT_gpv_ContD)
have resp: "r2 ∈ responses_ℐ ℐ out" using WT_oracle2 s2' outs by(rule WT_calleeD)
with step.prems(3) out have WT2: "ℐ ⊢g c r2 √" by(rule WT_gpv_ContD)
fix r1' s1'' r2' s2''
assume s1'': "(r1', s1'') ∈ set_spmf (exec_gpv oracle1 (c r1) s1')"
and s2'': "(r2', s2'') ∈ set_spmf (exec_gpv oracle2 (c r2) s2')"
have *: "bad1 s1'' ∧ X_bad s1'' s2'" using bad2 s1'' bad1 WT1
by(rule callee_invariant_on.exec_gpv_invariant[OF bad_sticky1])
have "bad2 s2'' ∧ X_bad s1'' s2''" using _ s2'' _ WT2
by(rule callee_invariant_on.exec_gpv_invariant[OF bad_sticky2])(simp_all add: bad2 *)
then show "bad1 s1'' = bad2 s2''" "if bad2 s2'' then X_bad s1'' s2'' else r1' = r2' ∧ X s1'' s2''"
using * by(simp_all)
qed
show ?case using step.prems
apply(clarsimp simp add: bind_UNION step.IH 3 WT_gpv_OutD WT_gpv_ContD del: subsetI intro!: UN_least split: generat.split if_split_asm)
subgoal by(auto 4 3 dest: callee_invariant_on.exec_gpv_invariant[OF bad_sticky1, rotated] callee_invariant_on.exec_gpv_invariant[OF bad_sticky2, rotated] 3)
apply(intro strip conjI)
subgoal by(drule (6) switch) auto
subgoal by(auto 4 3 intro!: step.IH[THEN order.trans] del: subsetI dest: 3 setD[rotated 2] simp add: WT_gpv_OutD WT_gpv_ContD intro: WT_gpv_ContD intro!: WT_calleeD[OF WT_oracle1])
done
qed
then show "⋀x y. (x, y) ∈ set_spmf ?pq ⟹ ?R x y" by auto
qed
qed
(* Variant of exec_gpv_oracle_bisim_bad_plossless for a gpv that is lossless (rather than
   plossless) with respect to ℐ; all other hypotheses and the conclusion are the same. *)
lemma exec_gpv_oracle_bisim_bad':
  fixes s1 :: 's1 and s2 :: 's2 and X :: "'s1 ⇒ 's2 ⇒ bool"
    and oracle1 :: "'s1 ⇒ 'a ⇒ ('b × 's1) spmf"
    and oracle2 :: "'s2 ⇒ 'a ⇒ ('b × 's2) spmf"
  assumes *: "if bad2 s2 then X_bad s1 s2 else X s1 s2"
    and bad: "bad1 s1 = bad2 s2"
    and bisim: "⋀s1 s2 x. ⟦ X s1 s2; x ∈ outs_ℐ ℐ ⟧ ⟹ rel_spmf (λ(a, s1') (b, s2'). bad1 s1' = bad2 s2' ∧ (if bad2 s2' then X_bad s1' s2' else a = b ∧ X s1' s2')) (oracle1 s1 x) (oracle2 s2 x)"
    and bad_sticky1: "⋀s2. bad2 s2 ⟹ callee_invariant_on oracle1 (λs1. bad1 s1 ∧ X_bad s1 s2) ℐ"
    and bad_sticky2: "⋀s1. bad1 s1 ⟹ callee_invariant_on oracle2 (λs2. bad2 s2 ∧ X_bad s1 s2) ℐ"
    and lossless1: "⋀s1 x. ⟦ bad1 s1; x ∈ outs_ℐ ℐ ⟧ ⟹ lossless_spmf (oracle1 s1 x)"
    and lossless2: "⋀s2 x. ⟦ bad2 s2; x ∈ outs_ℐ ℐ ⟧ ⟹ lossless_spmf (oracle2 s2 x)"
    and lossless: "lossless_gpv ℐ gpv"
    and WT_oracle1: "⋀s1. ℐ ⊢c oracle1 s1 √"
    and WT_oracle2: "⋀s2. ℐ ⊢c oracle2 s2 √"
    and WT_gpv: "ℐ ⊢g gpv √"
  shows "rel_spmf (λ(a, s1') (b, s2'). bad1 s1' = bad2 s2' ∧ (if bad2 s2' then X_bad s1' s2' else a = b ∧ X s1' s2')) (exec_gpv oracle1 gpv s1) (exec_gpv oracle2 gpv s2)"
proof -
  (* Replace the losslessness hypothesis by the fact derived from it and well-typedness. *)
  note ploss = lossless_imp_plossless_gpv[OF lossless WT_gpv]
  (* All hypotheses of the plossless version, in its order of premises. *)
  from * bad bisim bad_sticky1 bad_sticky2 lossless1 lossless2 ploss WT_oracle1 WT_oracle2 WT_gpv
  show ?thesis by(rule exec_gpv_oracle_bisim_bad_plossless)
qed
(* Version of exec_gpv_oracle_bisim_bad' in which the bisimulation, stickiness and
   losslessness hypotheses are only required for states satisfying I1 and I2, where I1 and
   I2 are callee invariants of oracle1 and oracle2 that hold for the start states.
   Proof outline: assume types 's1' and 's2' that are copies of {s. I1 s} and {s. I2 s}
   (type_definition assumptions), lift the oracles, bad flags and relations to these types,
   apply exec_gpv_oracle_bisim_bad' there, transfer the result back, and finally remove the
   type_definition assumptions with cancel_type_definition. *)
lemma exec_gpv_oracle_bisim_bad_invariant:
fixes s1 :: 's1 and s2 :: 's2 and X :: "'s1 ⇒ 's2 ⇒ bool" and I1 :: "'s1 ⇒ bool" and I2 :: "'s2 ⇒ bool"
and oracle1 :: "'s1 ⇒ 'a ⇒ ('b × 's1) spmf"
and oracle2 :: "'s2 ⇒ 'a ⇒ ('b × 's2) spmf"
assumes *: "if bad2 s2 then X_bad s1 s2 else X s1 s2"
and bad: "bad1 s1 = bad2 s2"
and bisim: "⋀s1 s2 x. ⟦ X s1 s2; x ∈ outs_ℐ ℐ; I1 s1; I2 s2 ⟧ ⟹ rel_spmf (λ(a, s1') (b, s2'). bad1 s1' = bad2 s2' ∧ (if bad2 s2' then X_bad s1' s2' else a = b ∧ X s1' s2')) (oracle1 s1 x) (oracle2 s2 x)"
and bad_sticky1: "⋀s2. ⟦ bad2 s2; I2 s2 ⟧ ⟹ callee_invariant_on oracle1 (λs1. bad1 s1 ∧ X_bad s1 s2) ℐ"
and bad_sticky2: "⋀s1. ⟦ bad1 s1; I1 s1 ⟧ ⟹ callee_invariant_on oracle2 (λs2. bad2 s2 ∧ X_bad s1 s2) ℐ"
and lossless1: "⋀s1 x. ⟦ bad1 s1; I1 s1; x ∈ outs_ℐ ℐ ⟧ ⟹ lossless_spmf (oracle1 s1 x)"
and lossless2: "⋀s2 x. ⟦ bad2 s2; I2 s2; x ∈ outs_ℐ ℐ ⟧ ⟹ lossless_spmf (oracle2 s2 x)"
and lossless: "lossless_gpv ℐ gpv"
and WT_gpv: "ℐ ⊢g gpv √"
and I1: "callee_invariant_on oracle1 I1 ℐ"
and I2: "callee_invariant_on oracle2 I2 ℐ"
and s1: "I1 s1"
and s2: "I2 s2"
shows "rel_spmf (λ(a, s1') (b, s2'). bad1 s1' = bad2 s2' ∧ (if bad2 s2' then X_bad s1' s2' else a = b ∧ X s1' s2')) (exec_gpv oracle1 gpv s1) (exec_gpv oracle2 gpv s2)"
including lifting_syntax
proof -
interpret I1: callee_invariant_on oracle1 I1 ℐ by(fact I1)
interpret I2: callee_invariant_on oracle2 I2 ℐ by(fact I2)
(* Non-emptiness of the carrier set, needed when cancelling the first type definition. *)
from s1 have nonempty1: "{s. I1 s} ≠ {}" by auto
(* Main block: prove the thesis under the assumption that type copies of the two
   invariant sets exist. *)
{ assume "∃(Rep1 :: 's1' ⇒ 's1) Abs1. type_definition Rep1 Abs1 {s. I1 s}"
and "∃(Rep2 :: 's2' ⇒ 's2) Abs2. type_definition Rep2 Abs2 {s. I2 s}"
then obtain Rep1 :: "'s1' ⇒ 's1" and Abs1 and Rep2 :: "'s2' ⇒ 's2" and Abs2
where td1: "type_definition Rep1 Abs1 {s. I1 s}" and td2: "type_definition Rep2 Abs2 {s. I2 s}"
by blast
interpret td1: type_definition Rep1 Abs1 "{s. I1 s}" by(rule td1)
interpret td2: type_definition Rep2 Abs2 "{s. I2 s}" by(rule td2)
(* Correspondence relations between the original state types and their copies,
   with domains I1 and I2. *)
define cr1 where "cr1 ≡ λx y. x = Rep1 y"
have [transfer_rule]: "bi_unique cr1" "right_total cr1" using td1 cr1_def by(rule typedef_bi_unique typedef_right_total)+
have [transfer_domain_rule]: "Domainp cr1 = I1" using type_definition_Domainp[OF td1 cr1_def] by simp
define cr2 where "cr2 ≡ λx y. x = Rep2 y"
have [transfer_rule]: "bi_unique cr2" "right_total cr2" using td2 cr2_def by(rule typedef_bi_unique typedef_right_total)+
have [transfer_domain_rule]: "Domainp cr2 = I2" using type_definition_Domainp[OF td2 cr2_def] by simp
(* ?C relates a query to itself exactly when it lies in outs_ℐ ℐ. *)
let ?C = "eq_onp (λout. out ∈ outs_ℐ ℐ)"
(* Primed versions of the oracles, start states, bad flags and relations over the
   copied state types, each with its transfer rule. *)
define oracle1' where "oracle1' ≡ (Rep1 ---> id ---> map_spmf (map_prod id Abs1)) oracle1"
have [transfer_rule]: "(cr1 ===> ?C ===> rel_spmf (rel_prod (=) cr1)) oracle1 oracle1'"
by(auto simp add: oracle1'_def rel_fun_def cr1_def spmf_rel_map prod.rel_map td1.Abs_inverse eq_onp_def intro!: rel_spmf_reflI intro: td1.Rep[simplified] dest: I1.callee_invariant)
define oracle2' where "oracle2' ≡ (Rep2 ---> id ---> map_spmf (map_prod id Abs2)) oracle2"
have [transfer_rule]: "(cr2 ===> ?C ===> rel_spmf (rel_prod (=) cr2)) oracle2 oracle2'"
by(auto simp add: oracle2'_def rel_fun_def cr2_def spmf_rel_map prod.rel_map td2.Abs_inverse eq_onp_def intro!: rel_spmf_reflI intro: td2.Rep[simplified] dest: I2.callee_invariant)
define s1' where "s1' ≡ Abs1 s1"
have [transfer_rule]: "cr1 s1 s1'" using s1 by(simp add: cr1_def s1'_def td1.Abs_inverse)
define s2' where "s2' ≡ Abs2 s2"
have [transfer_rule]: "cr2 s2 s2'" using s2 by(simp add: cr2_def s2'_def td2.Abs_inverse)
define bad1' where "bad1' ≡ (Rep1 ---> id) bad1"
have [transfer_rule]: "(cr1 ===> (=)) bad1 bad1'" by(simp add: rel_fun_def bad1'_def cr1_def)
define bad2' where "bad2' ≡ (Rep2 ---> id) bad2"
have [transfer_rule]: "(cr2 ===> (=)) bad2 bad2'" by(simp add: rel_fun_def bad2'_def cr2_def)
define X' where "X' ≡ (Rep1 ---> Rep2 ---> id) X"
have [transfer_rule]: "(cr1 ===> cr2 ===> (=)) X X'" by(simp add: rel_fun_def X'_def cr1_def cr2_def)
define X_bad' where "X_bad' ≡ (Rep1 ---> Rep2 ---> id) X_bad"
have [transfer_rule]: "(cr1 ===> cr2 ===> (=)) X_bad X_bad'" by(simp add: rel_fun_def X_bad'_def cr1_def cr2_def)
(* gpv' is gpv restricted to ℐ; it is shown further below to execute like gpv from s1 and s2. *)
define gpv' where "gpv' ≡ restrict_gpv ℐ gpv"
have [transfer_rule]: "rel_gpv (=) ?C gpv' gpv'"
by(fold eq_onp_top_eq_eq)(auto simp add: gpv.rel_eq_onp eq_onp_same_args pred_gpv_def gpv'_def dest: in_outs'_restrict_gpvD)
(* Collect the hypotheses of exec_gpv_oracle_bisim_bad' for the primed entities. *)
have "if bad2' s2' then X_bad' s1' s2' else X' s1' s2'" using * by transfer
moreover have "bad1' s1' ⟷ bad2' s2'" using bad by transfer
moreover have x: "?C x x" if "x ∈ outs_ℐ ℐ" for x using that by(simp add: eq_onp_def)
have "rel_spmf (λ(a, s1') (b, s2'). (bad1' s1' ⟷ bad2' s2') ∧ (if bad2' s2' then X_bad' s1' s2' else a = b ∧ X' s1' s2')) (oracle1' s1 x) (oracle2' s2 x)"
if "X' s1 s2" and "x ∈ outs_ℐ ℐ" for s1 s2 x using that(1) supply that(2)[THEN x, transfer_rule]
by(transfer)(rule bisim[OF _ that(2)])
moreover have [transfer_rule]: "rel_ℐ ?C (=) ℐ ℐ" by(rule rel_ℐI)(auto simp add: set_relator_eq_onp eq_onp_same_args rel_set_eq dest: eq_onp_to_eq)
have "callee_invariant_on oracle1' (λs1. bad1' s1 ∧ X_bad' s1 s2) ℐ" if "bad2' s2" for s2
using that unfolding callee_invariant_on_alt_def apply(transfer)
using bad_sticky1[unfolded callee_invariant_on_alt_def] by blast
moreover have "callee_invariant_on oracle2' (λs2. bad2' s2 ∧ X_bad' s1 s2) ℐ" if "bad1' s1" for s1
using that unfolding callee_invariant_on_alt_def apply(transfer)
using bad_sticky2[unfolded callee_invariant_on_alt_def] by blast
moreover have "lossless_spmf (oracle1' s1 x)" if "bad1' s1" "x ∈ outs_ℐ ℐ" for s1 x
using that supply that(2)[THEN x, transfer_rule] by transfer(rule lossless1)
moreover have "lossless_spmf (oracle2' s2 x)" if "bad2' s2" "x ∈ outs_ℐ ℐ" for s2 x
using that supply that(2)[THEN x, transfer_rule] by transfer(rule lossless2)
moreover have "lossless_gpv ℐ gpv'" using WT_gpv lossless by(simp add: gpv'_def lossless_restrict_gpvI)
moreover have "ℐ ⊢c oracle1' s1 √" for s1 using I1.WT_callee by transfer
moreover have "ℐ ⊢c oracle2' s2 √" for s2 using I2.WT_callee by transfer
moreover have "ℐ ⊢g gpv' √" by(simp add: gpv'_def)
(* Apply the unconditional lemma on the copied types ... *)
ultimately have **: "rel_spmf (λ(a, s1') (b, s2'). bad1' s1' = bad2' s2' ∧ (if bad2' s2' then X_bad' s1' s2' else a = b ∧ X' s1' s2')) (exec_gpv oracle1' gpv' s1') (exec_gpv oracle2' gpv' s2')"
by(rule exec_gpv_oracle_bisim_bad')
have [transfer_rule]: "((=) ===> ?C ===> rel_spmf (rel_prod (=) (=))) oracle2 oracle2"
"((=) ===> ?C ===> rel_spmf (rel_prod (=) (=))) oracle1 oracle1"
by(simp_all add: rel_fun_def eq_onp_def prod.rel_eq)
note [transfer_rule] = bi_unique_eq_onp bi_unique_eq
(* ... and transfer the result back to the original state types. *)
from ** have "rel_spmf (λ(a, s1') (b, s2'). bad1 s1' = bad2 s2' ∧ (if bad2 s2' then X_bad s1' s2' else a = b ∧ X s1' s2')) (exec_gpv oracle1 gpv' s1) (exec_gpv oracle2 gpv' s2)"
by(transfer)
(* Executing the restricted gpv' coincides with executing gpv from s1 resp. s2. *)
also have "exec_gpv oracle1 gpv' s1 = exec_gpv oracle1 gpv s1"
unfolding gpv'_def using WT_gpv s1 by(rule I1.exec_gpv_restrict_gpv_invariant)
also have "exec_gpv oracle2 gpv' s2 = exec_gpv oracle2 gpv s2"
unfolding gpv'_def using WT_gpv s2 by(rule I2.exec_gpv_restrict_gpv_invariant)
finally have ?thesis . }
(* Discharge the two type_definition assumptions; non-emptiness of the second carrier
   set is derived from s2 by blast. *)
from this[cancel_type_definition, OF nonempty1, cancel_type_definition] s2 show ?thesis by blast
qed
(* Bisimulation up to bad with an arbitrary target relation R: here the bisimulation and
   losslessness hypotheses are stated for all queries, and R only has to follow from
   "equal badness; equal results and X if not bad; X_bad if bad". *)
lemma exec_gpv_oracle_bisim_bad:
  assumes *: "if bad2 s2 then X_bad s1 s2 else X s1 s2"
    and bad: "bad1 s1 = bad2 s2"
    and bisim: "⋀s1 s2 x. X s1 s2 ⟹ rel_spmf (λ(a, s1') (b, s2'). bad1 s1' = bad2 s2' ∧ (if bad2 s2' then X_bad s1' s2' else a = b ∧ X s1' s2')) (oracle1 s1 x) (oracle2 s2 x)"
    and bad_sticky1: "⋀s2. bad2 s2 ⟹ callee_invariant_on oracle1 (λs1. bad1 s1 ∧ X_bad s1 s2) ℐ"
    and bad_sticky2: "⋀s1. bad1 s1 ⟹ callee_invariant_on oracle2 (λs2. bad2 s2 ∧ X_bad s1 s2) ℐ"
    and lossless1: "⋀s1 x. bad1 s1 ⟹ lossless_spmf (oracle1 s1 x)"
    and lossless2: "⋀s2 x. bad2 s2 ⟹ lossless_spmf (oracle2 s2 x)"
    and lossless: "lossless_gpv ℐ gpv"
    and WT_oracle1: "⋀s1. ℐ ⊢c oracle1 s1 √"
    and WT_oracle2: "⋀s2. ℐ ⊢c oracle2 s2 √"
    and WT_gpv: "ℐ ⊢g gpv √"
    and R: "⋀a s1 b s2. ⟦ bad1 s1 = bad2 s2; ¬ bad2 s2 ⟹ a = b ∧ X s1 s2; bad2 s2 ⟹ X_bad s1 s2 ⟧ ⟹ R (a, s1) (b, s2)"
  shows "rel_spmf R (exec_gpv oracle1 gpv s1) (exec_gpv oracle2 gpv s2)"
proof -
  (* The primed lemma instantiated with all hypotheses except R. *)
  note core = exec_gpv_oracle_bisim_bad'[OF * bad bisim bad_sticky1 bad_sticky2 lossless1 lossless2 lossless WT_oracle1 WT_oracle2 WT_gpv]
  (* Weaken its relation to R. *)
  from core show ?thesis
    by(rule rel_spmf_mono)(auto intro: R)
qed
(* Instance of exec_gpv_oracle_bisim_bad for the interface ℐ_full with the trivial
   relation "λ_ _. True" in place of X_bad; stickiness of the bad flags is stated as
   callee_invariant oracle1 bad1 and callee_invariant oracle2 bad2. *)
lemma exec_gpv_oracle_bisim_bad_full:
  assumes init: "X s1 s2"
    and bad_eq: "bad1 s1 = bad2 s2"
    and sim: "⋀s1 s2 x. X s1 s2 ⟹ rel_spmf (λ(a, s1') (b, s2'). bad1 s1' = bad2 s2' ∧ (¬ bad2 s2' ⟶ a = b ∧ X s1' s2')) (oracle1 s1 x) (oracle2 s2 x)"
    and sticky1: "callee_invariant oracle1 bad1"
    and sticky2: "callee_invariant oracle2 bad2"
    and ll1: "⋀s1 x. bad1 s1 ⟹ lossless_spmf (oracle1 s1 x)"
    and ll2: "⋀s2 x. bad2 s2 ⟹ lossless_spmf (oracle2 s2 x)"
    and ll_gpv: "lossless_gpv ℐ_full gpv"
    and R: "⋀a s1 b s2. ⟦ bad1 s1 = bad2 s2; ¬ bad2 s2 ⟹ a = b ∧ X s1 s2 ⟧ ⟹ R (a, s1) (b, s2)"
  shows "rel_spmf R (exec_gpv oracle1 gpv s1) (exec_gpv oracle2 gpv s2)"
  (* All hypotheses are chained in, in the order in which they are stated. *)
  using init bad_eq sim sticky1 sticky2 ll1 ll2 ll_gpv R
  by(intro exec_gpv_oracle_bisim_bad[of bad2 s2 "λ_ _. True" s1 X bad1 oracle1 oracle2 ℐ_full gpv R])
    (auto intro: rel_spmf_mono)
(* enn2ereal commutes with max. Proved by unfolding max and transferring along the
   ennreal lifting setup. *)
lemma max_enn2ereal: "max (enn2ereal x) (enn2ereal y) = enn2ereal (max x y)"
including ennreal.lifting unfolding max_def by transfer simp
(* "Identical until bad": if the bad flag has the same distribution under p and q
   (bad_eq), and the event "f x ∈ A and not bad" has the same probability under p and q
   (not_bad), then the probabilities of "f x ∈ A" under p and q differ by at most the
   probability that bad holds under p. *)
lemma identical_until_bad:
assumes bad_eq: "map_spmf bad p = map_spmf bad q"
and not_bad: "measure (measure_spmf (map_spmf (λx. (f x, bad x)) p)) (A × {False}) = measure (measure_spmf (map_spmf (λx. (f x, bad x)) q)) (A × {False})"
shows "¦measure (measure_spmf (map_spmf f p)) A - measure (measure_spmf (map_spmf f q)) A¦ ≤ spmf (map_spmf bad p) True"
proof -
(* Step 1: write both probabilities as integrals of the indicator of A applied to f x. *)
have "¦enn2ereal (measure (measure_spmf (map_spmf f p)) A) - enn2ereal (measure (measure_spmf (map_spmf f q)) A)¦ =
¦enn2ereal (∫⇧+ x. indicator A (f x) ∂measure_spmf p) - enn2ereal (∫⇧+ x. indicator A (f x) ∂measure_spmf q)¦"
unfolding measure_spmf.emeasure_eq_measure[symmetric]
by(simp add: nn_integral_indicator[symmetric] indicator_vimage[abs_def] o_def)
(* Step 2: split each indicator according to the value of bad x. *)
also have "… =
¦enn2ereal (∫⇧+ x. indicator (A × {False}) (f x, bad x) + indicator (A × {True}) (f x, bad x) ∂measure_spmf p) -
enn2ereal (∫⇧+ x. indicator (A × {False}) (f x, bad x) + indicator (A × {True}) (f x, bad x) ∂measure_spmf q)¦"
by(intro arg_cong[where f=abs] arg_cong2[where f="(-)"] arg_cong[where f=enn2ereal] nn_integral_cong)(simp_all split: split_indicator)
(* Step 3: the not-bad summand is the emeasure of A × {False} under the paired map. *)
also have "… =
¦enn2ereal (emeasure (measure_spmf (map_spmf (λx. (f x, bad x)) p)) (A × {False}) + (∫⇧+ x. indicator (A × {True}) (f x, bad x) ∂measure_spmf p)) -
enn2ereal (emeasure (measure_spmf (map_spmf (λx. (f x, bad x)) q)) (A × {False}) + (∫⇧+ x. indicator (A × {True}) (f x, bad x) ∂measure_spmf q))¦"
by(subst (1 2) nn_integral_add)(simp_all add: indicator_vimage[abs_def] o_def nn_integral_indicator[symmetric])
(* Step 4: by not_bad the not-bad summands cancel; ?x and ?y name the remaining bad parts. *)
also have "… = ¦enn2ereal (∫⇧+ x. indicator (A × {True}) (f x, bad x) ∂measure_spmf p) - enn2ereal (∫⇧+ x. indicator (A × {True}) (f x, bad x) ∂measure_spmf q)¦"
(is "_ = ¦?x - ?y¦")
by(simp add: measure_spmf.emeasure_eq_measure not_bad plus_ennreal.rep_eq ereal_diff_add_eq_diff_diff_swap ereal_diff_add_assoc2 ereal_add_uminus_conv_diff)
(* Step 5: the absolute difference of ?x and ?y is bounded by their maximum. *)
also have "… ≤ max ?x ?y"
proof(rule ereal_abs_leI)
have "?x - ?y ≤ ?x - 0" by(rule ereal_minus_mono)(simp_all)
also have "… ≤ max ?x ?y" by simp
finally show "?x - ?y ≤ …" .
have "- (?x - ?y) = ?y - ?x"
by(rule ereal_minus_diff_eq)(simp_all add: measure_spmf.nn_integral_indicator_neq_top)
also have "… ≤ ?y - 0" by(rule ereal_minus_mono)(simp_all)
also have "… ≤ max ?x ?y" by simp
finally show "- (?x - ?y) ≤ …" .
qed
(* Step 6: dropping the constraint f x ∈ A only enlarges each bad part. *)
also have "… ≤ enn2ereal (max (∫⇧+ x. indicator {True} (bad x) ∂measure_spmf p) (∫⇧+ x. indicator {True} (bad x) ∂measure_spmf q))"
unfolding max_enn2ereal less_eq_ennreal.rep_eq[symmetric]
by(intro max.mono nn_integral_mono)(simp_all split: split_indicator)
(* Step 7: by bad_eq both integrals equal the probability of bad under p. *)
also have "… = enn2ereal (spmf (map_spmf bad p) True)"
using arg_cong2[where f=spmf, OF bad_eq refl, of True, THEN arg_cong[where f=ennreal]]
unfolding ennreal_spmf_map_conv_nn_integral indicator_vimage[abs_def] by simp
finally show ?thesis by simp
qed
lemma (in callee_invariant_on) exec_gpv_bind_materialize:
fixes f :: "'s ⇒ 'r spmf"
and g :: "'x × 's ⇒ 'r ⇒ 'y spmf"
and s :: "'s"
defines "exec_gpv2 ≡ exec_gpv"
assumes cond: "⋀s x y s'. ⟦ (y, s') ∈ set_spmf (callee s x); I s ⟧ ⟹ f s = f s'"
and ℐ: "ℐ = ℐ_full"
shows "bind_spmf (exec_gpv callee gpv s) (λas. bind_spmf (f (snd as)) (g as)) =
exec_gpv2 (λ(r, s) x. bind_spmf (callee s x) (λ(y, s'). if I s' ∧ r = None then map_spmf (λr. (y, (Some r, s'))) (f s') else return_spmf (y, (r, s')))) gpv (None, s)
⤜ (λ(a, r, s). case r of None ⇒ bind_spmf (f s) (g (a, s)) | Some r' ⇒ g (a, s) r')"
(is "?lhs = ?rhs" is "_ = bind_spmf (exec_gpv2 ?callee2 _ _) _")
proof -
define exec_gpv1 :: "('a, 'b, 's option × 's) callee ⇒ ('x, 'a, 'b) gpv ⇒ _"
where [simp]: "exec_gpv1 = exec_gpv"
let ?X = "λs (ss, s'). s = s'"
let ?callee = "λ(ss, s) x. map_spmf (λ(y, s'). (y, if I s' ∧ ss = None then Some s' else ss, s')) (callee s x)"
let ?track = "exec_gpv1 ?callee gpv (None, s)"
have "rel_spmf (rel_prod (=) ?X) (exec_gpv callee gpv s) ?track" unfolding exec_gpv1_def
by(rule exec_gpv_oracle_bisim[where X="?X"])(auto simp add: spmf_rel_map intro!: rel_spmf_reflI)
hence "exec_gpv callee gpv s = map_spmf (λ(a, ss, s). (a, s)) ?track"
by(auto simp add: spmf_rel_eq[symmetric] spmf_rel_map elim: rel_spmf_mono)
hence "?lhs = bind_spmf ?track (λ(a, s'', s'). bind_spmf (f s') (g (a, s')))"
by(simp add: bind_map_spmf o_def split_def)
also let ?inv = "λ(ss, s). case ss of None ⇒ True | Some s' ⇒ f s = f s' ∧ I s' ∧ I s"
interpret inv: callee_invariant_on "?callee" "?inv" ℐ
by unfold_locales(auto 4 4 split: option.split if_split_asm dest: cond callee_invariant simp add: ℐ)
have "bind_spmf ?track (λ(a, s'', s'). bind_spmf (f s') (g (a, s'))) =
bind_spmf ?track (λ(a, ss', s'). bind_spmf (f (case ss' of None ⇒ s' | Some s'' ⇒ s'')) (g (a, s')))"
(is "_ = ?rhs'")
by(rule bind_spmf_cong[OF refl])(auto dest!: inv.exec_gpv_invariant split: option.split_asm simp add: ℐ)
also
have track_Some: "exec_gpv ?callee gpv (Some ss, s) = map_spmf (λ(a, s). (a, Some ss, s)) (exec_gpv callee gpv s)"
for s ss :: 's and gpv :: "('x, 'a, 'b) gpv"
proof -
let ?X = "λ(ss', s') s. s = s' ∧ ss' = Some ss"
have "rel_spmf (rel_prod (=) ?X) (exec_gpv ?callee gpv (Some ss, s)) (exec_gpv callee gpv s)"
by(rule exec_gpv_oracle_bisim[where X="?X"])(auto simp add: spmf_rel_map intro!: rel_spmf_reflI)
thus ?thesis by(auto simp add: spmf_rel_eq[symmetric] spmf_rel_map elim: rel_spmf_mono)
qed
have sample_Some: "exec_gpv ?callee2 gpv (Some r, s) = map_spmf (λ(a, s). (a, Some r, s)) (exec_gpv callee gpv s)"
for s :: 's and r :: 'r and gpv :: "('x, 'a, 'b) gpv"
proof -
let ?X = "λ(r', s') s. s' = s ∧ r' = Some r"
have "rel_spmf (rel_prod (=) ?X) (exec_gpv ?callee2 gpv (Some r, s)) (exec_gpv callee gpv s)"
by(rule exec_gpv_oracle_bisim[where X="?X"])(auto simp add: spmf_rel_map map_spmf_conv_bind_spmf[symmetric] split_def intro!: rel_spmf_reflI)
then show ?thesis by(auto simp add: spmf_rel_eq[symmetric] spmf_rel_map elim: rel_spmf_mono)
qed
have "?rhs' = ?rhs"
proof(rule spmf.leq_antisym)
show "ord_spmf (=) ?rhs' ?rhs" unfolding exec_gpv1_def
proof(induction arbitrary: gpv s rule: exec_gpv_fixp_induct_strong)
case adm show ?case by simp
case bottom show ?case by simp
case (step exec_gpv')
show ?case unfolding exec_gpv2_def
apply(rewrite in "ord_spmf _ _ ⌑" exec_gpv.simps)
apply(clarsimp split: generat.split simp add: bind_map_spmf intro!: ord_spmf_bind_reflI split del: if_split)
subgoal for out rpv ret s'
apply(cases "I s'")
subgoal
apply simp
apply(rule spmf.leq_trans)
apply(rule ord_spmf_bindI[OF step.hyps])
apply hypsubst
apply(rule spmf.leq_refl)
apply(simp add: track_Some sample_Some bind_map_spmf o_def)
apply(subst bind_commute_spmf)
apply(simp add: split_def)
done
subgoal
apply simp
apply(rule step.IH[THEN spmf.leq_trans])
apply(simp add: split_def exec_gpv2_def)
done
done
done
qed
show "ord_spmf (=) ?rhs ?rhs'" unfolding exec_gpv2_def
proof(induction arbitrary: gpv s rule: exec_gpv_fixp_induct_strong)
case adm show ?case by simp
case bottom show ?case by simp
case (step exec_gpv')
show ?case unfolding exec_gpv1_def
apply(rewrite in "ord_spmf _ _ ⌑" exec_gpv.simps)
apply(clarsimp split: generat.split simp add: bind_map_spmf intro!: ord_spmf_bind_reflI split del: if_split)
subgoal for out rpv ret s'
apply(cases "I s'")
subgoal
apply(simp add: bind_map_spmf o_def)
apply(rule spmf.leq_trans)
apply(rule ord_spmf_bind_reflI)
apply(rule ord_spmf_bindI)
apply(rule step.hyps)
apply hypsubst
apply(rule spmf.leq_refl)
apply(simp add: track_Some sample_Some bind_map_spmf o_def)
apply(subst bind_commute_spmf)
apply(simp add: split_def)
done
subgoal
apply simp
apply(rule step.IH[THEN spmf.leq_trans])
apply(simp add: split_def exec_gpv2_def)
done
done
done
qed
qed
finally show ?thesis .
qed
(* gpv_stop converts a gpv into one that can be stopped from the outside:
   results are wrapped in Some, outputs are kept as they are (id), and every
   continuation takes an optional input.  On input None the computation ends
   with Done None; on input Some input' it continues with the original
   continuation applied to input', again under gpv_stop. *)
primcorec gpv_stop :: "('a, 'c, 'r) gpv ⇒ ('a option, 'c, 'r option) gpv"
where
"the_gpv (gpv_stop gpv) =
map_spmf (map_generat Some id (λrpv input. case input of None ⇒ Done None | Some input' ⇒ gpv_stop (rpv input')))
(the_gpv gpv)"
(* Simp rule: a finished computation with result x becomes one with result Some x. *)
lemma gpv_stop_Done [simp]: "gpv_stop (Done x) = Done (Some x)"
by(rule gpv.expand) simp
(* Simp rule: gpv_stop leaves the failing computation Fail unchanged. *)
lemma gpv_stop_Fail [simp]: "gpv_stop Fail = Fail"
by(rule gpv.expand) simp
(* Simp rule: gpv_stop keeps the output of a Pause and replaces its continuation
   by one on optional inputs: None ends in Done None, Some input' continues
   with gpv_stop (rpv input'). *)
lemma gpv_stop_Pause [simp]: "gpv_stop (Pause out rpv) = Pause out (λinput. case input of None ⇒ Done None | Some input' ⇒ gpv_stop (rpv input'))"
by(rule gpv.expand) simp
(* Simp rule: gpv_stop of a lifted spmf is the lifted spmf with its results
   wrapped in Some. *)
lemma gpv_stop_lift_spmf [simp]: "gpv_stop (lift_spmf p) = lift_spmf (map_spmf Some p)"
by(rule gpv.expand)(simp add: spmf.map_comp o_def)
(* gpv_stop distributes over bind_gpv.  After stopping, the continuation of the
   bind has to deal with an optional intermediate result: None short-circuits
   to Done None, Some x' runs the stopped version of f x'. *)
lemma gpv_stop_bind [simp]:
  "gpv_stop (bind_gpv gpv f) = bind_gpv (gpv_stop gpv) (λx. case x of None ⇒ Done None | Some x' ⇒ gpv_stop (f x'))"
  by(coinduction arbitrary: gpv rule: gpv.coinduct_strong)
    (auto 4 3
      simp add: spmf_rel_map map_spmf_bind_spmf o_def bind_map_spmf bind_gpv.sel generat.rel_map
      simp del: bind_gpv_sel'
      intro!: rel_spmf_bind_reflI generat.rel_refl_strong rel_spmf_reflI rel_funI
      split!: generat.split option.split)
(* Parametricity and transfer properties of gpv_stop.
   The lifting_syntax bundle is included for this context only; presumably it
   provides the ===> notation used in the statements below. *)
context includes lifting_syntax begin
(* Parametricity of gpv_stop w.r.t. rel_gpv'': related gpvs are mapped to gpvs
   related with rel_option on results and on inputs; the output relation C is
   unchanged.  Proved by unfolding the definition and running transfer_prover
   with the locally declared transfer rules.
   NOTE(review): the_gpv_parametric' is listed twice in the notes clause below;
   the second occurrence looks redundant -- confirm whether another rule was
   intended. *)
lemma gpv_stop_parametric':
notes [transfer_rule] = the_gpv_parametric' the_gpv_parametric' Done_parametric' corec_gpv_parametric'
shows "(rel_gpv'' A C R ===> rel_gpv'' (rel_option A) C (rel_option R)) gpv_stop gpv_stop"
unfolding gpv_stop_def by transfer_prover
(* Parametricity of gpv_stop w.r.t. rel_gpv; registered as a transfer rule. *)
lemma gpv_stop_parametric [transfer_rule]:
shows "(rel_gpv A C ===> rel_gpv (rel_option A) C) gpv_stop gpv_stop"
unfolding gpv_stop_def by transfer_prover
(* Relates the identity function on the left to gpv_stop on the right: gpvs
   related by rel_gpv'' A B C stay related when only the right one is stopped,
   with pcr_Some A on results and pcr_Some C on inputs. *)
lemma gpv_stop_transfer:
"(rel_gpv'' A B C ===> rel_gpv'' (pcr_Some A) B (pcr_Some C)) (λx. x) gpv_stop"
apply(rule rel_funI)
subgoal for gpv gpv'
(* coinduction generalising over both related gpvs *)
apply(coinduction arbitrary: gpv gpv')
apply(drule rel_gpv''D)
apply(auto simp add: spmf_rel_map generat.rel_map rel_fun_def elim!: pcr_SomeE generat.rel_mono_strong rel_spmf_mono)
done
done
end
(* gpv_stop commutes with map_gpv': the result map f and the input map h are
   lifted through map_option, the output map g is passed on unchanged. *)
lemma gpv_stop_map' [simp]:
  "gpv_stop (map_gpv' f g h gpv) = map_gpv' (map_option f) g (map_option h) (gpv_stop gpv)"
  by(coinduction arbitrary: gpv rule: gpv.coinduct_strong)
    (auto 4 3
      simp add: spmf_rel_map generat.rel_map
      intro!: rel_spmf_reflI generat.rel_refl_strong
      split!: option.split)
(* Stopping a gpv does not change its interaction bound. *)
lemma interaction_bound_gpv_stop [simp]:
"interaction_bound consider (gpv_stop gpv) = interaction_bound consider gpv"
(* Parallel fixpoint induction: the rule is instantiated with
   interaction_bound.mono / interaction_bound_def twice, once for each side of
   the equation. *)
proof(induction arbitrary: gpv rule: parallel_fixp_induct_strong_1_1[OF complete_lattice_partial_function_definitions complete_lattice_partial_function_definitions interaction_bound.mono interaction_bound.mono interaction_bound_def interaction_bound_def, case_names adm bottom step])
case adm show ?case by simp
case bottom show ?case by simp
next
case (step interaction_bound' interaction_bound'')
(* Key fact for an IO step: the supremum over optional inputs of the stopped
   continuation equals the supremum over plain inputs of the original one. *)
have "(SUP x. interaction_bound' (case x of None ⇒ Done None | Some input ⇒ gpv_stop (c input))) =
(SUP input. interaction_bound'' (c input))" (is "?lhs = ?rhs" is "(SUP x. ?f x) = _")
if "IO out c ∈ set_spmf (the_gpv gpv)" for out c
proof -
(* split the supremum into the None input and the Some inputs *)
have "?lhs = sup (interaction_bound' (Done None)) (⨆x. ?f (Some x))"
by (simp add: UNIV_option_conv image_comp)
(* the None input leads to Done None, which contributes 0 *)
also have "interaction_bound' (Done None) = 0" using step.hyps(1)[of "Done None"] by simp
(* the Some inputs are handled by the induction hypothesis *)
also have "(⨆x. ?f (Some x)) = ?rhs" by (simp add: step.IH)
finally show ?thesis by (simp add: bot_enat_def [symmetric])
qed
then show ?case
by (auto simp add: case_map_generat o_def image_comp cong del: generat.case_cong_weak if_weak_cong intro!: SUP_cong split: generat.split)
qed
(* exec_gpv_stop callee gpv abbreviates running the stopped version of gpv
   with exec_gpv; per the type, the callee returns optional responses and the
   result is an optional value paired with the final state. *)
abbreviation exec_gpv_stop :: "('s ⇒ 'c ⇒ ('r option × 's) spmf) ⇒ ('a, 'c, 'r) gpv ⇒ 's ⇒ ('a option × 's) spmf"
where "exec_gpv_stop callee gpv ≡ exec_gpv callee (gpv_stop gpv)"
(* inline_stop callee gpv abbreviates inlining the callee into the stopped
   version of gpv; the counterpart of exec_gpv_stop for inline. *)
abbreviation inline_stop :: "('s ⇒ 'c ⇒ ('r option × 's, 'c', 'r') gpv) ⇒ ('a, 'c, 'r) gpv ⇒ 's ⇒ ('a option × 's, 'c', 'r') gpv"
where "inline_stop callee gpv ≡ inline callee (gpv_stop gpv)"
(* Auxiliary joint execution used in the proof of ord_spmf_exec_gpv_stop.
   joint_oracle answers an output for a pair of states with a pmf over pairs of
   optional (response, state) pairs; callee1 is an ordinary stoppable callee on
   the first state.
   NOTE(review): function_internals is presumably enabled so that the internal
   facts exec_until_stop.mono and exec_until_stop_def stay accessible; both are
   used by ord_spmf_exec_gpv_stop -- confirm before removing. *)
context
fixes joint_oracle :: "'s1 ⇒ 's2 ⇒ 'c ⇒ (('r option × 's1) option × ('r option × 's2) option) pmf"
and callee1 :: "'s1 ⇒ 'c ⇒ ('r option × 's1) spmf"
notes [[function_internals]]
begin
(* exec_until_stop gpv s1 s2 b runs gpv while tracking both states.
   b = True (joint mode): a Pure result x is returned with both states; an
   output is answered by joint_oracle s1 s2.
     - first component None: the whole run yields None (failure);
     - second component None while the first is Some: undefined;
     - responses (None, None): continue with Done None in joint mode;
     - responses (Some r1', Some r2'): continue with rpv r1' in joint mode;
     - responses (Some r1', None): continue with rpv r1' with b = False;
     - responses (None, Some r2'): continue with Done None in joint mode.
       NOTE(review): this last case looks unreachable for joint oracles
       satisfying the relation used in ord_spmf_exec_gpv_stop (it requires
       le_option r2 r1); the other unreachable branch uses undefined -- confirm
       which is intended.
   b = False: only callee1 is queried, s2 is never changed, and a Pure result
   yields None as the returned value; a None response continues with Done None,
   a response Some r1' continues with rpv r1'. *)
partial_function (spmf) exec_until_stop :: "('a option, 'c, 'r) gpv ⇒ 's1 ⇒ 's2 ⇒ bool ⇒ ('a option × 's1 × 's2) spmf"
where
"exec_until_stop gpv s1 s2 b =
(if b then
bind_spmf (the_gpv gpv) (λgenerat. case generat of
Pure x ⇒ return_spmf (x, s1, s2)
| IO out rpv ⇒ bind_pmf (joint_oracle s1 s2 out) (λ(a, b).
case a of None ⇒ return_pmf None
| Some (r1, s1') ⇒ (case b of None ⇒ undefined | Some (r2, s2') ⇒
(case (r1, r2) of (None, None) ⇒ exec_until_stop (Done None) s1' s2' True
| (Some r1', Some r2') ⇒ exec_until_stop (rpv r1') s1' s2' True
| (None, Some r2') ⇒ exec_until_stop (Done None) s1' s2' True
| (Some r1', None) ⇒ exec_until_stop (rpv r1') s1' s2' False))))
else
bind_spmf (the_gpv gpv) (λgenerat. case generat of
Pure x ⇒ return_spmf (None, s1, s2)
| IO out rpv ⇒ bind_spmf (callee1 s1 out) (λ(r1, s1').
case r1 of None ⇒ exec_until_stop (Done None) s1' s2 False
| Some r1' ⇒ exec_until_stop (rpv r1') s1' s2 False)))"
end
(* Simulation lemma for stoppable callees.
   Assumptions:
     bisim    - from S-related states with the second one not stopped, callee1
                and callee2 are related by ord_spmf such that the second
                response is below the first (le_option r2 r1), the successor
                states are again S-related, and the second successor state is
                stopped exactly when the second response is None while the
                first is not;
     init     - the initial states are S-related;
     go       - the second initial state is not stopped;
     sticking - once the second state is stopped, steps of callee1 preserve S.
   Conclusion: the stopped runs of gpv with callee1 and with callee2 are
   related by ord_spmf (rel_prod (ord_option ⊤)¯¯ S).
   Proof idea: obtain a joint oracle whose marginals are the two callees,
   relate the callee1 run to the projection of exec_until_stop, bound that
   projection by the callee2 run, and chain the two facts. *)
lemma ord_spmf_exec_gpv_stop:
fixes callee1 :: "('c, 'r option, 's) callee"
and callee2 :: "('c, 'r option, 's) callee"
and S :: "'s ⇒ 's ⇒ bool"
and gpv :: "('a, 'c, 'r) gpv"
assumes bisim:
"⋀s1 s2 x. ⟦ S s1 s2; ¬ stop s2 ⟧ ⟹
ord_spmf (λ(r1, s1') (r2, s2'). le_option r2 r1 ∧ S s1' s2' ∧ (r2 = None ∧ r1 ≠ None ⟷ stop s2'))
(callee1 s1 x) (callee2 s2 x)"
and init: "S s1 s2"
and go: "¬ stop s2"
and sticking: "⋀s1 s2 x y s1'. ⟦ (y, s1') ∈ set_spmf (callee1 s1 x); S s1 s2; stop s2 ⟧ ⟹ S s1' s2"
shows "ord_spmf (rel_prod (ord_option ⊤)¯¯ S) (exec_gpv_stop callee1 gpv s1) (exec_gpv_stop callee2 gpv s2)"
proof -
let ?R = "λ(r1, s1') (r2, s2'). le_option r2 r1 ∧ S s1' s2' ∧ (r2 = None ∧ r1 ≠ None ⟷ stop s2')"
(* Turn the pointwise relation from bisim into a joint distribution: its first
   and second marginals are callee1 and callee2 (j1, j2), and every pair in its
   support satisfies ord_option ?R (rel). *)
obtain joint :: "'s ⇒ 's ⇒ 'c ⇒ (('r option × 's) option × ('r option × 's) option) pmf"
where j1: "map_pmf fst (joint s1 s2 x) = callee1 s1 x"
and j2: "map_pmf snd (joint s1 s2 x) = callee2 s2 x"
and rel [rule_format, rotated -1]: "∀(a, b) ∈ set_pmf (joint s1 s2 x). ord_option ?R a b"
if "S s1 s2" "¬ stop s2" for x s1 s2 using bisim
apply atomize_elim
apply(subst (asm) rel_pmf.simps)
apply(unfold rel_spmf_simps all_conj_distrib[symmetric] all_simps(6) imp_conjR[symmetric])
apply(subst all_comm)
apply(subst (2) all_comm)
apply(subst choice_iff[symmetric] ex_simps(6))+
apply fastforce
done
note [simp del] = top_apply conversep_iff id_apply
(* Part 1: the callee1 run is related to the joint run projected to the result
   and the second state.  Three statements are proved simultaneously:
   joint mode, the Done None computation (any mode b), and callee1-only mode. *)
have "¬ stop s2 ⟹ rel_spmf (rel_prod (ord_option ⊤)¯¯ S) (exec_gpv_stop callee1 gpv s1) (map_spmf (λ(x, s1, s2). (x, s2)) (exec_until_stop joint callee1 (map_gpv Some id gpv) s1 s2 True))"
and "rel_spmf (rel_prod (ord_option ⊤)¯¯ S) (exec_gpv callee1 (Done None :: ('a option, 'c, 'r option) gpv) s1) (map_spmf (λ(x, s1, s2). (x, s2)) (exec_until_stop joint callee1 (Done None :: ('a option, 'c, 'r) gpv) s1 s2 b))"
and "stop s2 ⟹ rel_spmf (rel_prod (ord_option ⊤)¯¯ S) (exec_gpv_stop callee1 gpv s1) (map_spmf (λ(x, s1, y). (x, y)) (exec_until_stop joint callee1 (map_gpv Some id gpv) s1 s2 False))"
for b using init
(* parallel fixpoint induction over exec_gpv and exec_until_stop *)
proof(induction arbitrary: gpv s1 s2 b rule: parallel_fixp_induct_2_4[OF partial_function_definitions_spmf partial_function_definitions_spmf exec_gpv.mono exec_until_stop.mono exec_gpv_def exec_until_stop_def, unfolded lub_spmf_empty, case_names adm bottom step])
case adm show ?case by simp
{ case bottom case 1 show ?case by simp }
{ case bottom case 2 show ?case by simp }
{ case bottom case 3 show ?case by simp }
next
(* joint mode: replace callee1 by the first marginal of joint (j1) and use rel
   on the support of joint *)
case (step exec_gpv' exec_until_stop') case step: 1
show ?case using step.prems
apply(rewrite gpv_stop.sel)
apply(simp add: map_spmf_bind_spmf bind_map_spmf gpv.map_sel)
apply(rule rel_spmf_bind_reflI)
apply(clarsimp split!: generat.split)
apply(rewrite j1[symmetric], assumption+)
apply(rewrite bind_spmf_def)
apply(auto 4 3 split!: option.split dest: rel intro: step.IH intro!: rel_pmf_bind_reflI simp add: map_bind_pmf bind_map_pmf)
done
next
case step case 2
then show ?case by(simp add: conversep_iff)
next
(* callee1-only mode: the second state is stopped, so S is maintained via
   the sticking assumption *)
case (step exec_gpv' exec_until_stop') case step: 3
show ?case using step.prems
apply(simp add: map_spmf_bind_spmf bind_map_spmf gpv.map_sel)
apply(rule rel_spmf_bind_reflI)
apply(clarsimp simp add: map_spmf_bind_spmf split!: generat.split)
apply(rule rel_spmf_bind_reflI)
apply clarsimp
apply(drule (2) sticking)
apply(auto split!: option.split intro: step.IH)
done
qed
(* start the calculation with the joint-mode statement, discharged by go *)
note this(1)[OF go]
also
(* Part 2: the projected joint run is below the callee2 run w.r.t. ord_spmf (=);
   for Done None and in callee1-only mode it is below return_spmf (None, s2). *)
have "¬ stop s2 ⟹ ord_spmf (=) (map_spmf (λ(x, s1, s2). (x, s2)) (exec_until_stop joint callee1 (map_gpv Some id gpv) s1 s2 True)) (exec_gpv_stop callee2 gpv s2)"
and "ord_spmf (=) (map_spmf (λ(x, s1, y). (x, y)) (exec_until_stop joint callee1 (Done None :: ('a option, 'c, 'r) gpv) s1 s2 b)) (return_spmf (None, s2))"
and "stop s2 ⟹ ord_spmf (=) (map_spmf (λ(x, s1, s2). (x, s2)) (exec_until_stop joint callee1 (map_gpv Some id gpv) s1 s2 False)) (return_spmf (None, s2))"
for b using init
(* fixpoint induction over exec_until_stop only *)
proof(induction arbitrary: gpv s1 s2 b rule: exec_until_stop.fixp_induct[case_names adm bottom step])
case adm show ?case by simp
{ case bottom case 1 show ?case by simp }
{ case bottom case 2 show ?case by simp }
{ case bottom case 3 show ?case by simp }
next
(* joint mode: replace callee2 by the second marginal of joint (j2) *)
case (step exec_until_stop') case step: 1
show ?case using step.prems
apply(rewrite exec_gpv.simps)
apply(simp add: map_spmf_bind_spmf bind_map_spmf gpv.map_sel)
apply(rule ord_spmf_bind_reflI)
apply(clarsimp split!: generat.split simp add: map_bind_pmf bind_spmf_def)
apply(rewrite j2[symmetric], assumption+)
apply(auto 4 3 split!: option.split dest: rel intro: step.IH intro!: rel_pmf_bind_reflI simp add: bind_map_pmf)
done
next
case step case 2 thus ?case by simp
next
(* callee1-only mode: every continuation is bounded by return_spmf (None, s2) *)
case (step exec_until_stop') case 3
thus ?case
apply(simp add: map_spmf_bind_spmf o_def)
apply(rule ord_spmf_bind_spmfI1)
apply(clarsimp split!: generat.split simp add: map_spmf_bind_spmf o_def gpv.map_sel)
apply(rule ord_spmf_bind_spmfI1)
apply clarsimp
apply(drule (2) sticking)
apply(clarsimp split!: option.split simp add: step.IH)
done
qed
note this(1)[OF go]
(* chain part 1 and part 2, then weaken the resulting relation to the goal *)
finally show ?thesis by(rule pmf.rel_mono_strong)(auto elim!: option.rel_cases)
qed
end