(* Theory Monad_Normalisation_Test *)
theory Monad_Normalisation_Test
  imports Monad_Normalisation
begin

section ‹Tests and examples›
(* The tests below are stated with the monad_normalisation bundle included.
   NOTE(review): the matching `end` of this context is not visible in this
   chunk; confirm where it is closed. *)
context includes monad_normalisation
begin
(* The two sides differ in bind order (y ← A moves to the front, and a ← D z x
   comes before d ← E z x) and, on the right, by f applied to the result.
   The second simp step uses the assumption f = id. *)
lemma
  assumes "f = id"
  shows
    "do {x ← B; z ← C x; d ← E z x; a ← D z x; y ← A; return_pmf (x,y)} =
     do {y ← A; x ← B; z ← C x; a ← D z x; d ← E z x; return_pmf (f (x,y))}"
  apply (simp)
  apply (simp add: assms)
  done
(* pmf monad: the binds of w (from B b a) and z (from B a b) appear in
   opposite orders on the two sides. *)
lemma
  "(do {a ← E; b ← E; w ← B b a; z ← B a b; return_pmf (w,z)}) =
   (do {a ← E; b ← E; z ← B a b; w ← B b a; return_pmf (w,z)})"
  by simp
(* pmf monad: the binds of w (from B b a) and z (from B a b) appear in
   opposite orders on the two sides.
   NOTE(review): this statement is identical to the preceding lemma; presumably
   one copy was meant to exercise a different monad or typing. Confirm. *)
lemma
  "(do {a ← E; b ← E; w ← B b a; z ← B a b; return_pmf (w,z)}) =
   (do {a ← E; b ← E; z ← B a b; w ← B b a; return_pmf (w,z)})"
  by simp
(* Option monad: the two draws from A and the two unused B-binds appear in
   different orders on each side. *)
lemma
  "do {y ← A; x ← A; z ← B x y y; w ← B x x y; Some (x,y)} =
   do {x ← A; y ← A; z ← B x x y; w ← B x y y; Some (x,y)}"
  by simp
(* Set monad: the two draws from A and the two unused B-binds appear in
   different orders on each side. *)
lemma
  "do {y ← A; x ← A; z ← B x y y; w ← B x x y; {x,y}} =
   do {x ← A; y ← A; z ← B x x y; w ← B x y y; {x,y}}"
  by simp
(* pmf monad: the two draws from A and the two unused B-binds appear in
   different orders on each side. *)
lemma
  "do {y ← A; x ← A; z ← B x y y; w ← B x x y; return_pmf (x,y)} =
   do {x ← A; y ← A; z ← B x x y; w ← B x y y; return_pmf (x,y)}"
  by simp
(* Predicate monad: the unused binds B y y and B x y are swapped. *)
lemma
  "do {x ← A 0; y ← A x; w ← B y y; z ← B x y; a ← C; Predicate.single (a,a)} =
   do {x ← A 0; y ← A x; z ← B x y; w ← B y y; a ← C; Predicate.single (a,a)}"
  by simp
(* pmf monad: the unused binds B x y and B y y are swapped. *)
lemma
  "do {x ← A 0; y ← A x; z ← B x y; w ← B y y; a ← C; return_pmf (a,a)} =
   do {x ← A 0; y ← A x; z ← B y y; w ← B x y; a ← C; return_pmf (a,a)}"
  by simp
(* pmf monad: y ← A moves to the front, and a ← D z x comes before
   d ← E z x on the right-hand side. *)
lemma
  "do {x ← B; z ← C x; d ← E z x; a ← D z x; y ← A; return_pmf (x,y)} =
   do {y ← A; x ← B; z ← C x; a ← D z x; d ← E z x; return_pmf (x,y)}"
  by simp
(* Remove bind_pmf from the ad-hoc overloading of Monad_Syntax.bind (do-notation).
   NOTE(review): presumably this keeps the following spmf do-blocks from
   resolving ambiguously to the pmf bind. Confirm. *)
no_adhoc_overloading Monad_Syntax.bind ⇌ bind_pmf
(* Shared parameters for the spmf (ElGamal-style) examples below.
   NOTE(review): the matching `end` of this context is not visible in this
   chunk; confirm where it is closed. *)
context
  fixes 𝒜1 :: "'a ⇒ (('a × 'a) × 'b) spmf"
    and 𝒜2 :: "'a × 'a ⇒ 'b ⇒ bool spmf"
    and sample_uniform :: "nat ⇒ nat spmf"
    and order :: "'a ⇒ nat"
begin

(* The third sample and the coin flip b move below 𝒜1 and the assertion.
   On the right-hand side the third sample is rebound as x. *)
lemma
  "do {
      x ← sample_uniform (order 𝒢);
      y ← sample_uniform (order 𝒢);
      z ← sample_uniform (order 𝒢);
      b ← coin_spmf;
      ((msg1, msg2), σ) ← 𝒜1 (f x);
      _ :: unit ← assert_spmf (valid_plain msg1 ∧ valid_plain msg2);
      guess ← 𝒜2 (f y, xor (f z) (if b then msg1 else msg2)) σ;
      return_spmf (guess ⟷ b)
    } = do {
      x ← sample_uniform (order 𝒢);
      y ← sample_uniform (order 𝒢);
      ((msg1, msg2), σ) ← 𝒜1 (f x);
      _ :: unit ← assert_spmf (valid_plain msg1 ∧ valid_plain msg2);
      b ← coin_spmf;
      x ← sample_uniform (order 𝒢);
      guess ← 𝒜2 (f y, xor (f x) (if b then msg1 else msg2)) σ;
      return_spmf (guess ⟷ b)
    }" for xor
  by (simp add: split_def)

(* On the right-hand side the coin flip is moved last and written as
   map_spmf ((⟷) guess) coin_spmf. The proof adds map_spmf_conv_bind_spmf
   to the simp set.
   NOTE(review): the closing braces and the application of the inner case to
   xb were missing from this chunk and have been reconstructed. Confirm. *)
lemma
  "do {
      x ← sample_uniform (order 𝒢);
      xa ← sample_uniform (order 𝒢);
      x ← 𝒜1 (f x);
      case x of
        (x, xb) ⇒
          (case x of
             (msg1, msg2) ⇒
               λσ. do {
                 a ← assert_spmf (valid_plain msg1 ∧ valid_plain msg2);
                 x ← coin_spmf;
                 xaa ← map_spmf f (sample_uniform (order 𝒢));
                 guess ← 𝒜2 (f xa, xaa) σ;
                 return_spmf (guess ⟷ x)
               })
            xb
    } = do {
      x ← sample_uniform (order 𝒢);
      xa ← sample_uniform (order 𝒢);
      x ← 𝒜1 (f x);
      case x of
        (x, xb) ⇒
          (case x of
             (msg1, msg2) ⇒
               λσ. do {
                 a ← assert_spmf (valid_plain msg1 ∧ valid_plain msg2);
                 z ← map_spmf f (sample_uniform (order 𝒢));
                 guess ← 𝒜2 (f xa, z) σ;
                 map_spmf ((⟷) guess) coin_spmf
               })
            xb
    }"
  by (simp add: map_spmf_conv_bind_spmf)
(* The second uniform sample moves below 𝒜1 and the assertion. The bound names
   differ between the sides: x, y on the left correspond to y, ya on the right. *)
lemma elgamal_step3:
  "do {
      x ← sample_uniform (order 𝒢);
      y ← sample_uniform (order 𝒢);
      b ← coin_spmf;
      p ← 𝒜1 (f x);
      _ ← assert_spmf (valid_plain (fst (fst p)) ∧ valid_plain (snd (fst p)));
      guess ← 𝒜2 (f y, xor (f (x * y)) (if b then fst (fst p) else snd (fst p))) (snd p);
      return_spmf (guess ⟷ b)
    } = do {
      y ← sample_uniform (order 𝒢);
      b ← coin_spmf;
      p ← 𝒜1 (f y);
      _ ← assert_spmf (valid_plain (fst (fst p)) ∧ valid_plain (snd (fst p)));
      ya ← sample_uniform (order 𝒢);
      b' ← 𝒜2 (f ya, xor (f (y * ya)) (if b then fst (fst p) else snd (fst p))) (snd p);
      return_spmf (b' ⟷ b)
    }" for xor
  by simp
text ‹Distributivity›

(* The bind x ← A is pushed into both branches of the if. In the then-branch
   it reduces to A itself. *)
lemma
  "do {
      x ← A :: nat spmf;
      a ← B;
      b ← B;
      if a = b then do {
        return_spmf x
      } else do {
        y ← C;
        return_spmf (x + y)
      }
    } = do {
      a ← B;
      b ← B;
      if b = a then A else do {
        y ← C;
        x ← A;
        return_spmf (y + x)
      }
    }"
  by (simp add: add.commute cong: if_cong)

(* Several changes on the right-hand side:
   - the nested do-block bound to p is flattened;
   - the coin flip q comes first;
   - the binds of A and B are distributed into the if-branches. *)
lemma
  "do {
      x ← A :: nat spmf;
      p ← do {
        a ← B;
        b ← B;
        return_spmf (a, b)
      };
      q ← coin_spmf;
      if q then do {
        return_spmf (x + fst p)
      } else do {
        y ← C;
        return_spmf (y + snd p)
      }
    } = do {
      q ← coin_spmf;
      if q then do {
        x ← A;
        a ← B;
        _ ← B;
        return_spmf (x + a)
      } else do {
        y ← C;
        a ← B;
        _ ← B;
        _ ← A;
        return_spmf (y + a)
      }
    }"
  by (simp cong: if_cong)
(* Set monad: the bind x ← A is pushed into the branches of a case on a sum
   type. In the Inl branch it reduces to A itself. *)
lemma
  fixes f :: "nat ⇒ nat ⇒ nat + nat"
  shows
    "do {
        x ← (A::nat set);
        a ← B;
        b ← B;
        case f a b of
          Inl c ⇒ {x}
        | Inr c ⇒ do {
            y ← C x;
            {(x + y + c)}
          }
      } = do {
        a ← B;
        b ← B;
        case f b a of
          Inl c ⇒ A
        | Inr c ⇒ do {
            x ← A;
            y ← C x;
            {(y + c + x)}
          }
      }"
  by (simp add: add.commute add.left_commute cong: sum.case_cong)
section ‹Limits›
text ‹
  The following example shows that combining monad normalisation with regular ordered
  rewriting is not necessarily confluent.
›
(* The initial simp is wrapped in `?`, so the script continues even if it fails.
   The binds are then commuted explicitly with option_bind_commute, and the
   conjunction is normalised with conj_comms. *)
lemma
  "do {a ← A; b ← A; Some (a ∧ b, b)} =
   do {a ← A; b ← A; Some (a ∧ b, a)}"
  apply (simp add: conj_comms)?
  apply (rewrite option_bind_commute)
  apply (simp only: conj_comms)
  done
text ‹
  The next example shows that even monad normalisation alone is not confluent, because
  the term ordering prevents the reordering of ‹f A› with ‹f B›.
  If we change ‹A› to ‹E›, however, the reordering works as expected.
›

(* Proved by explicitly applying option_bind_commute twice. *)
lemma
  "do {a ← f A; b ← f B; c ← D b; d ← f C; F a c d} =
   do {b ← f B; c ← D b; a ← f A; d ← f C; F a c d}"
  for f :: "'b ⇒ 'a option" and D :: "'a ⇒ 'a option"
  apply (subst option_bind_commute, subst (2) option_bind_commute, rule refl)
  done

(* Same statement with A renamed to E; plain simp proves it. *)
lemma
  "do {a ← f E; b ← f B; c ← D b; d ← f C; F a c d} =
   do {b ← f B; c ← D b; a ← f E; d ← f C; F a c d}"
  for f :: "'b ⇒ 'a option" and D :: "'a ⇒ 'a option"
  by simp