(* Theory Monomorphic_Monad *)
theory Monomorphic_Monad imports
"HOL-Probability.Probability"
"HOL-Library.Multiset"
"HOL-Library.Countable_Set_Type"
begin
section ‹Preliminaries›
(* Within locale comp_fun_idem: folding f over the union of two finite sets equals
   folding over A first and then continuing the fold over B. *)
lemma (in comp_fun_idem) fold_set_union:
"⟦ finite A; finite B ⟧ ⟹ Finite_Set.fold f x (A ∪ B) = Finite_Set.fold f (Finite_Set.fold f x A) B"
(* Induction on the finite set A, generalising the start value x; fold_insert_idem is
   removed from the simp set in favour of fold_insert_idem2. *)
by(induction A arbitrary: x rule: finite_induct)(simp_all add: fold_insert_idem2 del: fold_insert_idem)
(* fset version of fold_set_union, obtained by transferring it through the fset lifting;
   no finiteness premises are needed here. *)
lemma (in comp_fun_idem) ffold_set_union: "ffold f x (A |∪| B) = ffold f (ffold f x A) B"
including fset.lifting by(transfer fixing: f)(rule fold_set_union)
(* Composing the total relation with itself yields the total relation. *)
lemma relcompp_top_top [simp]: "top OO top = top"
by(auto simp add: fun_eq_iff)
(* Attribute [locale_witness]: registers the annotated theorem via Locale.witness_add
   (used below, e.g. on monad_altc.monad_fail). *)
attribute_setup locale_witness = ‹Scan.succeed Locale.witness_add›
(* Named theorem collection monad_unfold; per its description it is meant to collect the
   defining equations of overloaded monad operations (no members are added in this part). *)
named_theorems monad_unfold "Defining equations for overloaded monad operations"
(* Parametricity (transfer) rules and an induction principle for auxiliary constants.
   lifting_syntax provides the ===> notation for rel_fun. *)
context includes lifting_syntax begin
(* Relator for type tokens: any two TYPE(_) values are related. *)
inductive rel_itself :: "'a itself ⇒ 'b itself ⇒ bool"
where "rel_itself TYPE(_) TYPE(_)"
lemma type_parametric [transfer_rule]: "rel_itself TYPE('a) TYPE('b)"
by(simp add: rel_itself.simps)
(* Multiset addition preserves rel_mset; proved by induction on the first rel_mset premise. *)
lemma plus_multiset_parametric [transfer_rule]:
"(rel_mset A ===> rel_mset A ===> rel_mset A) (+) (+)"
apply(rule rel_funI)+
subgoal premises prems using prems by induction(auto intro: rel_mset_Plus)
done
(* The empty multisets are related. *)
lemma Mempty_parametric [transfer_rule]: "rel_mset A {#} {#}"
by(fact rel_mset_Zero)
(* fold_mset is parametric when both folding functions are comp_fun_commute.
   Not tagged [transfer_rule]: the comp_fun_commute premises must be supplied explicitly. *)
lemma fold_mset_parametric:
assumes 12: "(A ===> B ===> B) f1 f2"
and "comp_fun_commute f1" "comp_fun_commute f2"
shows "(B ===> rel_mset A ===> B) (fold_mset f1) (fold_mset f2)"
proof(rule rel_funI)+
interpret f1: comp_fun_commute f1 by fact
interpret f2: comp_fun_commute f2 by fact
(* Induction over the rel_mset derivation relating X and Y. *)
show "B (fold_mset f1 z1 X) (fold_mset f2 z2 Y)"
if "rel_mset A X Y" "B z1 z2" for z1 z2 X Y
using that(1) by(induction R≡A X Y)(simp_all add: that(2) 12[THEN rel_funD, THEN rel_funD])
qed
(* Induction principle for rel_fset (declared as the default induct rule for rel_fset):
   related finite sets arise from the empty pair by simultaneously inserting related
   elements x and y, at least one of which is new. *)
lemma rel_fset_induct [consumes 1, case_names empty step, induct pred: rel_fset]:
assumes XY: "rel_fset A X Y"
and empty: "P {||} {||}"
and step: "⋀X Y x y. ⟦ rel_fset A X Y; P X Y; A x y; x |∉| X ∨ y |∉| Y ⟧ ⟹ P (finsert x X) (finsert y Y)"
shows "P X Y"
proof -
(* Witness the relation by a finite set Z of related pairs whose projections are X and Y. *)
from XY obtain Z where X: "X = fst |`| Z" and Y: "Y = snd |`| Z" and Z: "fBall Z (λ(x, y). A x y)"
unfolding fset.in_rel by auto
from Z show ?thesis unfolding X Y
proof(induction Z)
case (insert xy Z)
obtain x y where [simp]: "xy = (x, y)" by(cases xy)
show ?case using insert
(* If both components already occur, the insertion is absorbed; otherwise apply step. *)
apply(cases "x |∈| fst |`| Z ∧ y |∈| snd |`| Z")
apply(simp add: finsert_absorb)
apply(auto intro!: step simp add: fset.in_rel; blast)
done
qed(simp add: assms)
qed
(* ffold is parametric when both folding functions are comp_fun_idem; the induction uses
   rel_fset_induct, registered above for rel_fset. *)
lemma ffold_parametric:
assumes 12: "(A ===> B ===> B) f1 f2"
and "comp_fun_idem f1" "comp_fun_idem f2"
shows "(B ===> rel_fset A ===> B) (ffold f1) (ffold f2)"
proof(rule rel_funI)+
interpret f1: comp_fun_idem f1 by fact
interpret f2: comp_fun_idem f2 by fact
show "B (ffold f1 z1 X) (ffold f2 z2 Y)"
if "rel_fset A X Y" "B z1 z2" for z1 z2 X Y
using that(1) by(induction)(simp_all add: that(2) 12[THEN rel_funD, THEN rel_funD])
qed
end
(* The set relator maps the graph relation of f on A to the graph of image f on subsets of A. *)
lemma rel_set_Grp: "rel_set (BNF_Def.Grp A f) = BNF_Def.Grp {X. X ⊆ A} (image f)"
by(auto simp add: fun_eq_iff Grp_def rel_set_def)
(* Lemmas about unions over countable sets (cset). Several are proved by transferring to
   ordinary sets through the cset lifting. *)
context includes cset.lifting begin
(* Nested countable unions can be flattened. *)
lemma cUNION_assoc: "cUNION (cUNION A f) g = cUNION A (λx. cUNION (f x) g)"
by transfer auto
lemma cUnion_cempty [simp]: "cUnion cempty = cempty"
by transfer simp
lemma cUNION_cempty [simp]: "cUNION cempty f = cempty"
by simp
lemma cUnion_cinsert: "cUnion (cinsert x A) = cUn x (cUnion A)"
by transfer simp
lemma cUNION_cinsert: "cUNION (cinsert x A) f = cUn (f x) (cUNION A f)"
by (simp add: cUnion_cinsert)
lemma cUnion_csingle [simp]: "cUnion (csingle x) = x"
by (simp add: cUnion_cinsert)
lemma cUNION_csingle [simp]: "cUNION (csingle x) f = f x"
by simp
(* csingle is a right unit of cUNION. *)
lemma cUNION_csingle2 [simp]: "cUNION A csingle = A"
by (fact cUN_csingleton)
lemma cUNION_cUn: "cUNION (cUn A B) f = cUn (cUNION A f) (cUNION B f)"
by simp
(* cUNION maps related countable sets and related families to related results. *)
lemma cUNION_parametric [transfer_rule]: includes lifting_syntax shows
"(rel_cset A ===> (A ===> rel_cset B) ===> rel_cset B) cUNION cUNION"
unfolding rel_fun_def by transfer(blast intro: rel_set_UNION)
end
(* Locale for types 'a with at least three pairwise distinct elements. *)
locale three =
fixes tytok :: "'a itself"
assumes ex_three: "∃x y z :: 'a. x ≠ y ∧ x ≠ z ∧ y ≠ z"
begin
(* A fixed triple of distinct elements, picked by Hilbert choice (SOME). *)
definition threes :: "'a × 'a × 'a" where
"threes = (SOME (x, y, z). x ≠ y ∧ x ≠ z ∧ y ≠ z)"
(* ❙1, ❙2, ❙3 name the three components of threes. *)
definition three⇩1 :: 'a ("❙1") where "❙1 = fst threes"
definition three⇩2 :: 'a ("❙2") where "❙2 = fst (snd threes)"
definition three⇩3 :: 'a ("❙3") where "❙3 = snd (snd (threes))"
(* The three named elements are pairwise distinct (by someI_ex and ex_three). *)
lemma three_neq_aux: "❙1 ≠ ❙2" "❙1 ≠ ❙3" "❙2 ≠ ❙3"
proof -
have "❙1 ≠ ❙2 ∧ ❙1 ≠ ❙3 ∧ ❙2 ≠ ❙3"
unfolding three⇩1_def three⇩2_def three⇩3_def threes_def split_def
by(rule someI_ex)(use ex_three in auto)
then show "❙1 ≠ ❙2" "❙1 ≠ ❙3" "❙2 ≠ ❙3" by simp_all
qed
(* Both orientations of the inequalities as simp rules. *)
lemmas three_neq [simp] = three_neq_aux three_neq_aux[symmetric]
(* Relabelling relation: ❙1 ↦ ❙2 and ❙2 ↦ ❙3; bi-unique, so usable with the
   parametricity assumptions of the monad locales. *)
inductive rel_12_23 :: "'a ⇒ 'a ⇒ bool" where
"rel_12_23 ❙1 ❙2"
| "rel_12_23 ❙2 ❙3"
lemma bi_unique_rel_12_23 [simp, transfer_rule]: "bi_unique rel_12_23"
by(auto simp add: bi_unique_def rel_12_23.simps)
(* Swap relation between ❙1 and ❙2; also bi-unique. *)
inductive rel_12_21 :: "'a ⇒ 'a ⇒ bool" where
"rel_12_21 ❙1 ❙2"
| "rel_12_21 ❙2 ❙1"
lemma bi_unique_rel_12_21 [simp, transfer_rule]: "bi_unique rel_12_21"
by(auto simp add: bi_unique_def rel_12_21.simps)
end
(* A Bernoulli distribution with parameter 0 always yields False. *)
lemma bernoulli_pmf_0: "bernoulli_pmf 0 = return_pmf False"
by(rule pmf_eqI)(simp split: split_indicator)
(* A Bernoulli distribution with parameter 1 always yields True. *)
lemma bernoulli_pmf_1: "bernoulli_pmf 1 = return_pmf True"
by(rule pmf_eqI)(simp split: split_indicator)
(* Negating the outcome of bernoulli_pmf r gives bernoulli_pmf (1 - r). *)
lemma bernoulli_Not: "map_pmf Not (bernoulli_pmf r) = bernoulli_pmf (1 - r)"
apply(rule pmf_eqI)
(* Rewrite the evaluation point i as ¬ ¬ i so that pmf_map_inj' (Not is injective) applies. *)
apply(rewrite in "pmf _ ⌑ = _" not_not[symmetric])
apply(subst pmf_map_inj')
apply(simp_all add: inj_on_def bernoulli_pmf.rep_eq min_def max_def)
done
(* Two pmfs are equal if their masses agree at every point except possibly one point x.
   The mass at x is then forced, being 1 minus the mass of the complement of {x}. *)
lemma pmf_eqI_avoid: "p = q" if "⋀i. i ≠ x ⟹ pmf p i = pmf q i"
proof(rule pmf_eqI)
show "pmf p i = pmf q i" for i
proof(cases "i = x")
case [simp]: True
have "pmf p i = measure_pmf.prob p {i}" by(simp add: measure_pmf_single)
also have "… = 1 - measure_pmf.prob p (UNIV - {i})"
by(subst measure_pmf.prob_compl[unfolded space_measure_pmf]) simp_all
(* On the complement of {i} = {x}, p and q agree pointwise by assumption. *)
also have "measure_pmf.prob p (UNIV - {i}) = measure_pmf.prob q (UNIV - {i})"
unfolding integral_pmf[symmetric] by(rule Bochner_Integration.integral_cong)(auto intro: that)
also have "1 - … = measure_pmf.prob q {i}"
by(subst measure_pmf.prob_compl[unfolded space_measure_pmf]) simp_all
also have "… = pmf q i" by(simp add: measure_pmf_single)
finally show ?thesis .
next
case False
then show ?thesis by(rule that)
qed
qed
section ‹Locales for monomorphic monads›
subsection ‹Plain monad›
(* Types of bind and return for a monomorphic monad: 'm is the type of computations and
   'a the single, fixed result type. *)
type_synonym ('a, 'm) bind = "'m ⇒ ('a ⇒ 'm) ⇒ 'm"
type_synonym ('a, 'm) return = "'a ⇒ 'm"
(* Operations without laws; the derived combinators below only need the signatures. *)
locale monad_base =
fixes return :: "('a, 'm) return"
and bind :: "('a, 'm) bind"
begin
(* sequence xs f runs the computations in xs from left to right and passes the list of
   their results, in the same order, to f. *)
primrec sequence :: "'m list ⇒ ('a list ⇒ 'm) ⇒ 'm"
where
"sequence [] f = f []"
| "sequence (x # xs) f = bind x (λa. sequence xs (f ∘ (#) a))"
(* lift f x runs x and returns f applied to its result. *)
definition lift :: "('a ⇒ 'a) ⇒ 'm ⇒ 'm"
where "lift f x = bind x (λx. return (f x))"
end
(* Use the locale-level definitions as code equations. *)
declare
monad_base.sequence.simps [code]
monad_base.lift_def [code]
(* Parametricity of sequence and lift in the monad operations, proved by transfer_prover. *)
context includes lifting_syntax begin
lemma sequence_parametric [transfer_rule]:
"((M ===> (A ===> M) ===> M) ===> list_all2 M ===> (list_all2 A ===> M) ===> M) monad_base.sequence monad_base.sequence"
unfolding monad_base.sequence_def[abs_def] by transfer_prover
lemma lift_parametric [transfer_rule]:
"((A ===> M) ===> (M ===> (A ===> M) ===> M) ===> (A ===> A) ===> M ===> M) monad_base.lift monad_base.lift"
unfolding monad_base.lift_def by transfer_prover
end
(* A monad: return and bind satisfying associativity and the left/right unit laws. *)
locale monad = monad_base return bind
for return :: "('a, 'm) return"
and bind :: "('a, 'm) bind"
+
assumes bind_assoc: "⋀(x :: 'm) f g. bind (bind x f) g = bind x (λy. bind (f y) g)"
and return_bind: "⋀x f. bind (return x) f = f x"
and bind_return: "⋀x. bind x return = x"
begin
(* Binding after lift f is the same as composing the continuation with f. *)
lemma bind_lift [simp]: "bind (lift f x) g = bind x (g ∘ f)"
by(simp add: lift_def bind_assoc return_bind o_def)
(* lift distributes into the continuation of a bind. *)
lemma lift_bind [simp]: "lift f (bind m g) = bind m (λx. lift f (g x))"
by(simp add: lift_def bind_assoc)
end
subsection ‹State›
(* State operations in continuation-passing style: get takes a continuation on the state of
   type 's, and put takes a new state together with the computation to continue with. *)
type_synonym ('s, 'm) get = "('s ⇒ 'm) ⇒ 'm"
type_synonym ('s, 'm) put = "'s ⇒ 'm ⇒ 'm"
locale monad_state_base = monad_base return bind
for return :: "('a, 'm) return"
and bind :: "('a, 'm) bind"
+
fixes get :: "('s, 'm) get"
and put :: "('s, 'm) put"
begin
(* update f m reads the state s, puts f s, then continues with m. *)
definition update :: "('s ⇒ 's) ⇒ 'm ⇒ 'm"
where "update f m = get (λs. put (f s) m)"
end
(* Code equation for update, and its parametricity in get and put. *)
declare monad_state_base.update_def [code]
lemma update_parametric [transfer_rule]: includes lifting_syntax shows
"(((S ===> M) ===> M) ===> (S ===> M ===> M) ===> (S ===> S) ===> M ===> M)
monad_state_base.update monad_state_base.update"
unfolding monad_state_base.update_def by transfer_prover
(* State monad laws:
   put_get   - a get right after put s sees s;
   get_get   - two consecutive gets see the same state;
   put_put   - a later put overrides an earlier one;
   get_put   - putting back the state just read has no effect;
   get_const - a get whose continuation ignores the state has no effect;
   bind_get, bind_put - get and put commute with bind. *)
locale monad_state = monad_state_base return bind get put + monad return bind
for return :: "('a, 'm) return"
and bind :: "('a, 'm) bind"
and get :: "('s, 'm) get"
and put :: "('s, 'm) put"
+
assumes put_get: "⋀f. put s (get f) = put s (f s)"
and get_get: "⋀f. get (λs. get (f s)) = get (λs. f s s)"
and put_put: "put s (put s' m) = put s' m"
and get_put: "get (λs. put s m) = m"
and get_const: "⋀m. get (λ_. m) = m"
and bind_get: "⋀f g. bind (get f) g = get (λs. bind (f s) g)"
and bind_put: "⋀f. bind (put s m) f = put s (bind m f)"
begin
(* Derived laws for update. *)
lemma put_update: "put s (update f m) = put (f s) m"
by(simp add: update_def put_get put_put)
lemma update_put: "update f (put s m) = put s m"
by(simp add: update_def put_put get_const)
lemma bind_update: "bind (update f m) g = update f (bind m g)"
by(simp add: update_def bind_get bind_put)
lemma update_get: "update f (get g) = get (update f ∘ g ∘ f)"
by(simp add: update_def put_get get_get o_def)
lemma update_const: "update (λ_. s) m = put s m"
by(simp add: update_def get_const)
(* Consecutive updates compose: first f, then g. *)
lemma update_update: "update f (update g m) = update (g ∘ f) m"
by(simp add: update_def put_get put_put)
lemma update_id: "update id m = m"
by(simp add: update_def get_put)
end
subsection ‹Failure›
(* A failure operation is just a distinguished computation. *)
type_synonym 'm fail = "'m"
locale monad_fail_base = monad_base return bind
for return :: "('a, 'm) return"
and bind :: "('a, 'm) bind"
+
fixes fail :: "'m fail"
begin
(* assert P m runs m and returns its result if P holds for it, and fails otherwise. *)
definition assert :: "('a ⇒ bool) ⇒ 'm ⇒ 'm"
where "assert P m = bind m (λx. if P x then return x else fail)"
end
(* Code equation for assert, and its parametricity in return, bind and fail.
   The predicate is related by A ===> (=). *)
declare monad_fail_base.assert_def [code]
lemma assert_parametric [transfer_rule]: includes lifting_syntax shows
"((A ===> M) ===> (M ===> (A ===> M) ===> M) ===> M ===> (A ===> (=)) ===> M ===> M)
monad_fail_base.assert monad_fail_base.assert"
unfolding monad_fail_base.assert_def by transfer_prover
(* A monad with failure: fail is a left zero of bind. *)
locale monad_fail = monad_fail_base return bind fail + monad return bind
for return :: "('a, 'm) return"
and bind :: "('a, 'm) bind"
and fail :: "'m fail"
+
assumes fail_bind: "⋀f. bind fail f = fail"
begin
(* Asserting anything on a failed computation still fails. *)
lemma assert_fail: "assert P fail = fail"
by(simp add: assert_def fail_bind)
end
subsection ‹Exception›
(* catch m h: a computation m together with a handler computation h. *)
type_synonym 'm catch = "'m ⇒ 'm ⇒ 'm"
locale monad_catch_base = monad_fail_base return bind fail
for return :: "('a, 'm) return"
and bind :: "('a, 'm) bind"
and fail :: "'m fail"
+
fixes catch :: "'m catch"
(* Exception laws:
   catch_return - a returning computation ignores the handler;
   catch_fail   - a failing computation is replaced by the handler;
   catch_fail2  - fail is a neutral handler;
   catch_assoc  - catch is associative. *)
locale monad_catch = monad_catch_base return bind fail catch + monad_fail return bind fail
for return :: "('a, 'm) return"
and bind :: "('a, 'm) bind"
and fail :: "'m fail"
and catch :: "'m catch"
+
assumes catch_return: "catch (return x) m = return x"
and catch_fail: "catch fail m = m"
and catch_fail2: "catch m fail = m"
and catch_assoc: "catch (catch m m') m'' = catch m (catch m' m'')"
(* Exceptions combined with state: get and put in the protected computation can be moved
   outside catch, so the handler runs with the state already modified. *)
locale monad_catch_state = monad_catch return bind fail catch + monad_state return bind get put
for return :: "('a, 'm) return"
and bind :: "('a, 'm) bind"
and fail :: "'m fail"
and catch :: "'m catch"
and get :: "('s, 'm) get"
and put :: "('s, 'm) put"
+
assumes catch_get: "catch (get f) m = get (λs. catch (f s) m)"
and catch_put: "catch (put s m) m' = put s (catch m m')"
begin
(* Consequently update can also be moved outside catch. *)
lemma catch_update: "catch (update f m) m' = update f (catch m m')"
by(simp add: update_def catch_get catch_put)
end
subsection ‹Reader›
text ‹Since ‹ask› takes a continuation, the monad laws have to be restated for ‹ask› (cf. ‹bind_ask› and ‹bind_ask2›).›
(* ask takes a continuation on the environment value of type 'r. *)
type_synonym ('r, 'm) ask = "('r ⇒ 'm) ⇒ 'm"
locale monad_reader_base = monad_base return bind
for return :: "('a, 'm) return"
and bind :: "('a, 'm) bind"
+
fixes ask :: "('r, 'm) ask"
(* Reader laws:
   ask_ask   - nested asks see the same environment;
   ask_const - an ask whose continuation ignores the environment has no effect;
   bind_ask, bind_ask2 - ask commutes with bind on either side. *)
locale monad_reader = monad_reader_base return bind ask + monad return bind
for return :: "('a, 'm) return"
and bind :: "('a, 'm) bind"
and ask :: "('r, 'm) ask"
+
assumes ask_ask: "⋀f. ask (λr. ask (f r)) = ask (λr. f r r)"
and ask_const: "ask (λ_. m) = m"
and bind_ask: "⋀f g. bind (ask f) g = ask (λr. bind (f r) g)"
and bind_ask2: "⋀f. bind m (λx. ask (f x)) = ask (λr. bind m (λx. f x r))"
begin
(* A single ask can be split between the two parts of a bind. *)
lemma ask_bind: "ask (λr. bind (f r) (g r)) = bind (ask f) (λx. ask (λr. g r x))"
by(simp add: bind_ask bind_ask2 ask_ask)
end
(* Reader combined with state: ask commutes with get and with put. *)
locale monad_reader_state =
monad_reader return bind ask +
monad_state return bind get put
for return :: "('a, 'm) return"
and bind :: "('a, 'm) bind"
and ask :: "('r, 'm) ask"
and get :: "('s, 'm) get"
and put :: "('s, 'm) put"
+
assumes ask_get: "⋀f. ask (λr. get (f r)) = get (λs. ask (λr. f r s))"
and put_ask: "⋀f. put s (ask f) = ask (λr. put s (f r))"
subsection ‹Probability›
(* sample p f: a distribution p on 'p together with a continuation f on its outcomes. *)
type_synonym ('p, 'm) sample = "'p pmf ⇒ ('p ⇒ 'm) ⇒ 'm"
locale monad_prob_base = monad_base return bind
for return :: "('a, 'm) return"
and bind :: "('a, 'm) bind"
+
fixes sample :: "('p, 'm) sample"
(* Probability monad laws:
   sample_const              - sampling with a constant continuation has no effect;
   sample_return_pmf         - sampling from a point distribution just continues with that point;
   sample_bind_pmf           - sampling from a bind_pmf is nested sampling;
   sample_commute            - independent samples commute;
   bind_sample1/bind_sample2 - sample commutes with bind on either side;
   sample_parametric         - sample respects every bi-unique relation on outcomes. *)
locale monad_prob = monad return bind + monad_prob_base return bind sample
for return :: "('a, 'm) return"
and bind :: "('a, 'm) bind"
and sample :: "('p, 'm) sample"
+
assumes sample_const: "⋀p m. sample p (λ_. m) = m"
and sample_return_pmf: "⋀x f. sample (return_pmf x) f = f x"
and sample_bind_pmf: "⋀p f g. sample (bind_pmf p f) g = sample p (λx. sample (f x) g)"
and sample_commute: "⋀p q f. sample p (λx. sample q (f x)) = sample q (λy. sample p (λx. f x y))"
and bind_sample1: "⋀p f g. bind (sample p f) g = sample p (λx. bind (f x) g)"
and bind_sample2: "⋀m f p. bind m (λy. sample p (f y)) = sample p (λx. bind m (λy. f y x))"
and sample_parametric: "⋀R. bi_unique R ⟹ rel_fun (rel_pmf R) (rel_fun (rel_fun R (=)) (=)) sample sample"
begin
(* Congruence: sample only depends on the continuation's values on the support of p.
   Derived from sample_parametric with R restricted to set_pmf p. *)
lemma sample_cong: "(⋀x. x ∈ set_pmf p ⟹ f x = g x) ⟹ sample p f = sample q g" if "p = q"
by(rule sample_parametric[where R="eq_onp (λx. x ∈ set_pmf p)", THEN rel_funD, THEN rel_funD])
(simp_all add: bi_unique_def eq_onp_def rel_fun_def pmf.rel_refl_strong that)
end
text ‹We can implement binary probabilistic choice using @{term sample} provided that the sample space
contains at least three elements.›
(* Probability monad whose sample space 'p has at least three distinct elements. *)
locale monad_prob3 = monad_prob return bind sample + three "TYPE('p)"
for return :: "('a, 'm) return"
and bind :: "('a, 'm) bind"
and sample :: "('p, 'm) sample"
begin
(* Binary probabilistic choice: sample ❙1 if bernoulli_pmf r yields True and ❙2 otherwise,
   then continue with m on ❙1 and with m' otherwise. bernoulli_pmf effectively clips r to
   [0, 1]; the side conditions of pchoose_assoc therefore use min 1 (max 0 _). *)
definition pchoose :: "real ⇒ 'm ⇒ 'm ⇒ 'm" where
"pchoose r m m' = sample (map_pmf (λb. if b then ❙1 else ❙2) (bernoulli_pmf r)) (λx. if x = ❙1 then m else m')"
abbreviation pchoose_syntax :: "'m ⇒ real ⇒ 'm ⇒ 'm" ("_ ⊲ _ ⊳ _" [100, 0, 100] 99) where
"m ⊲ r ⊳ m' ≡ pchoose r m m'"
lemma pchoose_0: "m ⊲ 0 ⊳ m' = m'"
by(simp add: pchoose_def bernoulli_pmf_0 sample_return_pmf)
lemma pchoose_1: "m ⊲ 1 ⊳ m' = m"
by(simp add: pchoose_def bernoulli_pmf_1 sample_return_pmf)
lemma pchoose_idemp: "m ⊲ r ⊳ m = m"
by(simp add: pchoose_def sample_const)
(* bind distributes over pchoose from the left and from the right. *)
lemma pchoose_bind1: "bind (m ⊲ r ⊳ m') f = bind m f ⊲ r ⊳ bind m' f"
by(simp add: pchoose_def bind_sample1 if_distrib[where f="λm. bind m _"])
lemma pchoose_bind2: "bind m (λx. f x ⊲ p ⊳ g x) = bind m f ⊲ p ⊳ bind m g"
by(auto simp add: pchoose_def bind_sample2 intro!: arg_cong2[where f=sample])
(* Swapping the branches complements the probability. Proved with bernoulli_Not and
   parametricity of sample under the swap relation rel_12_21. *)
lemma pchoose_commute: "m ⊲ 1 - r ⊳ m' = m' ⊲ r ⊳ m"
apply(simp add: pchoose_def bernoulli_Not[symmetric] pmf.map_comp o_def)
apply(rule sample_parametric[where R=rel_12_21, THEN rel_funD, THEN rel_funD])
subgoal by(simp)
subgoal by(rule pmf.map_transfer[where Rb="(=)", THEN rel_funD, THEN rel_funD])
(simp_all add: rel_fun_def rel_12_21.simps pmf.rel_eq)
subgoal by(simp add: rel_fun_def rel_12_21.simps)
done
(* Associativity of pchoose under the usual reweighting conditions, stated for the
   clipped probabilities min 1 (max 0 _). *)
lemma pchoose_assoc: "m ⊲ p ⊳ (m' ⊲ q ⊳ m'') = (m ⊲ r ⊳ m') ⊲ s ⊳ m''" (is "?lhs = ?rhs")
if "min 1 (max 0 p) = min 1 (max 0 r) * min 1 (max 0 s)"
and "1 - min 1 (max 0 s) = (1 - min 1 (max 0 p)) * (1 - min 1 (max 0 q))"
proof -
(* Both sides become a single sample over a distribution on ❙1, ❙2, ❙3 (?p for the
   left side, ?q for the right side) with the common continuation ?f. *)
let ?f = "(λx. if x = ❙1 then m else if x = ❙2 then m' else m'')"
let ?p = "bind_pmf (map_pmf (λb. if b then ❙1 else ❙2) (bernoulli_pmf p))
(λx. if x = ❙1 then return_pmf ❙1 else map_pmf (λb. if b then ❙2 else ❙3) (bernoulli_pmf q))"
let ?q = "bind_pmf (map_pmf (λb. if b then ❙1 else ❙2) (bernoulli_pmf s))
(λx. if x = ❙1 then map_pmf (λb. if b then ❙1 else ❙2) (bernoulli_pmf r) else return_pmf ❙3)"
have [simp]: "{x. ¬ x} = {False}" "{x. x} = {True}" by auto
(* The inner choice is relabelled from {❙1, ❙2} to {❙2, ❙3} via rel_12_23. *)
have "?lhs = sample ?p ?f"
by(auto simp add: pchoose_def sample_bind_pmf if_distrib[where f="λx. sample x _"] sample_return_pmf rel_fun_def rel_12_23.simps pmf.rel_eq cong: if_cong intro!: sample_cong[OF refl] sample_parametric[where R="rel_12_23", THEN rel_funD, THEN rel_funD] pmf.map_transfer[where Rb="(=)", THEN rel_funD, THEN rel_funD])
(* The two distributions agree; by pmf_eqI_avoid the point ❙2 need not be checked. *)
also have "?p = ?q"
proof(rule pmf_eqI_avoid)
fix i :: "'p"
assume "i ≠ ❙2"
then consider (one) "i = ❙1" | (three) "i = ❙3" | (other) "i ≠ ❙1" "i ≠ ❙2" "i ≠ ❙3" by metis
then show "pmf ?p i = pmf ?q i"
proof cases
(* Mass at ❙1: p on the left, s * r on the right, equal by the first side condition. *)
case [simp]: one
have "pmf ?p i = measure_pmf.expectation (map_pmf (λb. if b then ❙1 else ❙2) (bernoulli_pmf p)) (indicator {❙1})"
unfolding pmf_bind
by(rule arg_cong2[where f=measure_pmf.expectation, OF refl])(auto simp add: fun_eq_iff pmf_eq_0_set_pmf)
also have "… = min 1 (max 0 p)"
by(simp add: vimage_def)(simp add: measure_pmf_single bernoulli_pmf.rep_eq)
also have "… = min 1 (max 0 s) * min 1 (max 0 r)" using that(1) by simp
also have "… = measure_pmf.expectation (bernoulli_pmf s)
(λx. indicator {True} x * pmf (map_pmf (λb. if b then ❙1 else ❙2) (bernoulli_pmf r)) ❙1)"
by(simp add: pmf_map vimage_def measure_pmf_single)(simp add: bernoulli_pmf.rep_eq)
also have "… = pmf ?q i"
unfolding pmf_bind integral_map_pmf
by(rule arg_cong2[where f=measure_pmf.expectation, OF refl])(auto simp add: fun_eq_iff pmf_eq_0_set_pmf)
finally show ?thesis .
next
(* Mass at ❙3: (1 - p) * (1 - q) on the left, 1 - s on the right, equal by the second side condition. *)
case [simp]: three
have "pmf ?p i = measure_pmf.expectation (bernoulli_pmf p)
(λx. indicator {False} x * pmf (map_pmf (λb. if b then ❙2 else ❙3) (bernoulli_pmf q)) ❙3)"
unfolding pmf_bind integral_map_pmf
by(rule arg_cong2[where f=measure_pmf.expectation, OF refl])(auto simp add: fun_eq_iff pmf_eq_0_set_pmf)
also have "… = (1 - min 1 (max 0 p)) * (1 - min 1 (max 0 q))"
by(simp add: pmf_map vimage_def measure_pmf_single)(simp add: bernoulli_pmf.rep_eq)
also have "… = 1 - min 1 (max 0 s)" using that(2) by simp
also have "… = measure_pmf.expectation (map_pmf (λb. if b then ❙1 else ❙2) (bernoulli_pmf s)) (indicator {❙2})"
by(simp add: vimage_def)(simp add: measure_pmf_single bernoulli_pmf.rep_eq)
also have "… = pmf ?q i"
unfolding pmf_bind
by(rule Bochner_Integration.integral_cong_AE)(auto simp add: fun_eq_iff pmf_eq_0_set_pmf AE_measure_pmf_iff)
finally show ?thesis .
next
(* Any other point lies outside both supports. *)
case other
then have "pmf ?p i = 0" "pmf ?q i = 0" by(auto simp add: pmf_eq_0_set_pmf)
then show ?thesis by simp
qed
qed
also have "sample ?q ?f = ?rhs"
by(auto simp add: pchoose_def sample_bind_pmf if_distrib[where f="λx. sample x _"] sample_return_pmf cong: if_cong intro!: sample_cong[OF refl])
finally show ?thesis .
qed
(* Variant of pchoose_assoc for probabilities already known to lie in [0, 1]. *)
lemma pchoose_assoc': "m ⊲ p ⊳ (m' ⊲ q ⊳ m'') = (m ⊲ r ⊳ m') ⊲ s ⊳ m''"
if "p = r * s" and "1 - s = (1 - p) * (1 - q)"
and "0 ≤ p" "p ≤ 1" "0 ≤ q" "q ≤ 1" "0 ≤ r" "r ≤ 1" "0 ≤ s" "s ≤ 1"
by(rule pchoose_assoc; use that in ‹simp add: min_def max_def›)
end
(* Probability combined with state: sampling commutes with get. *)
locale monad_state_prob = monad_state return bind get put + monad_prob return bind sample
for return :: "('a, 'm) return"
and bind :: "('a, 'm) bind"
and get :: "('s, 'm) get"
and put :: "('s, 'm) put"
and sample :: "('p, 'm) sample"
+
assumes sample_get: "sample p (λx. get (f x)) = get (λs. sample p (λx. f x s))"
begin
(* Sampling also commutes with put; this follows from the existing laws rather than being
   assumed. *)
lemma sample_put: "sample p (λx. put s (m x)) = put s (sample p m)"
proof -
(* UU is an arbitrary result value; it lets put s be written as bind (put s (return UU)) _
   so that bind_sample2 can move sample inside. *)
fix UU
have "sample p (λx. put s (m x)) = sample p (λx. bind (put s (return UU)) (λ_. m x))"
by(simp add: bind_put return_bind)
also have "… = bind (put s (return UU)) (λ_. sample p m)"
by(simp add: bind_sample2)
also have "… = put s (sample p m)"
by(simp add: bind_put return_bind)
finally show ?thesis .
qed
lemma sample_update: "sample p (λx. update f (m x)) = update f (sample p m)"
by(simp add: update_def sample_get sample_put)
end
subsection ‹Nondeterministic choice›
subsubsection ‹Binary choice›
(* Binary nondeterministic choice between two computations. *)
type_synonym 'm alt = "'m ⇒ 'm ⇒ 'm"
locale monad_alt_base = monad_base return bind
for return :: "('a, 'm) return"
and bind :: "('a, 'm) bind"
+
fixes alt :: "'m alt"
(* alt is associative, and bind distributes over alt in its first argument. *)
locale monad_alt = monad return bind + monad_alt_base return bind alt
for return :: "('a, 'm) return"
and bind :: "('a, 'm) bind"
and alt :: "'m alt"
+
assumes alt_assoc: "alt (alt m1 m2) m3 = alt m1 (alt m2 m3)"
and bind_alt1: "bind (alt m m') f = alt (bind m f) (bind m' f)"
(* Failure combined with choice: fail is a left and right unit of alt. *)
locale monad_fail_alt = monad_fail return bind fail + monad_alt return bind alt
for return :: "('a, 'm) return"
and bind :: "('a, 'm) bind"
and fail :: "'m fail"
and alt :: "'m alt"
+
assumes alt_fail1: "alt fail m = m"
and alt_fail2: "alt m fail = m"
begin
(* assert distributes over alt. *)
lemma assert_alt: "assert P (alt m m') = alt (assert P m) (assert P m')"
by(simp add: assert_def bind_alt1)
end
(* State combined with choice: a choice between two gets is a get of choices, and a choice
   between two puts of the same state is a put of the choice. *)
locale monad_state_alt = monad_state return bind get put + monad_alt return bind alt
for return :: "('a, 'm) return"
and bind :: "('a, 'm) bind"
and get :: "('s, 'm) get"
and put :: "('s, 'm) put"
and alt :: "'m alt"
+
assumes alt_get: "alt (get f) (get g) = get (λx. alt (f x) (g x))"
and alt_put: "alt (put s m) (put s m') = put s (alt m m')"
begin
lemma alt_update: "alt (update f m) (update f m') = update f (alt m m')"
by(simp add: update_def alt_get alt_put)
end
subsubsection ‹Countable choice›
(* Countable choice: altc C f chooses among the computations f c for c in the countable set C. *)
type_synonym ('c, 'm) altc = "'c cset ⇒ ('c ⇒ 'm) ⇒ 'm"
locale monad_altc_base = monad_base return bind
for return :: "('a, 'm) return"
and bind :: "('a, 'm) bind"
+
fixes altc :: "('c, 'm) altc"
begin
(* Failure is choice over the empty set; the continuation is irrelevant there. *)
definition fail :: "'m fail" where "fail = altc cempty (λ_. undefined)"
end
declare monad_altc_base.fail_def [code]
(* Countable-choice laws:
   bind_altc1      - bind distributes over altc;
   altc_single     - choice over a singleton is that alternative;
   altc_cUNION     - choice over a countable union is nested choice;
   altc_parametric - altc respects every bi-unique relation on choice indices. *)
locale monad_altc = monad return bind + monad_altc_base return bind altc
for return :: "('a, 'm) return"
and bind :: "('a, 'm) bind"
and altc :: "('c, 'm) altc"
+
assumes bind_altc1: "⋀C g f. bind (altc C g) f = altc C (λc. bind (g c) f)"
and altc_single: "⋀x f. altc (csingle x) f = f x"
and altc_cUNION: "⋀C f g. altc (cUNION C f) g = altc C (λx. altc (f x) g)"
and altc_parametric: "⋀R. bi_unique R ⟹ rel_fun (rel_cset R) (rel_fun (rel_fun R (=)) (=)) altc altc"
begin
(* Congruence: altc only depends on the continuation's values on C. Derived from
   altc_parametric with R restricted to members of C. *)
lemma altc_cong: "cBall C (λx. f x = g x) ⟹ altc C f = altc C g"
apply(rule altc_parametric[where R="eq_onp (λx. cin x C)", THEN rel_funD, THEN rel_funD])
subgoal by(simp add: bi_unique_def eq_onp_def)
subgoal by(simp add: cset.rel_eq_onp eq_onp_same_args pred_cset_def cin_def)
subgoal by(simp add: rel_fun_def eq_onp_def cBall_def cin_def)
done
(* The derived fail is a left zero of bind; registered as a locale witness. *)
lemma monad_fail [locale_witness]: "monad_fail return bind fail"
proof
show "bind fail f = fail" for f
by(simp add: fail_def bind_altc1 cong: altc_cong)
qed
end
text ‹We can implement ‹alt› via ‹altc› only if we know that there are sufficiently
many elements in the choice type @{typ 'c}. For the associativity law, we need at least
three elements.›
(* Countable choice over an index type with at least three distinct elements. *)
locale monad_altc3 = monad_altc return bind altc + three "TYPE('c)"
for return :: "('a, 'm) return"
and bind :: "('a, 'm) bind"
and altc :: "('c, 'm) altc"
begin
(* Binary choice as countable choice over {❙1, ❙2}: m1 on ❙1, m2 otherwise. *)
definition alt :: "'m alt"
where "alt m1 m2 = altc (cinsert ❙1 (csingle ❙2)) (λc. if c = ❙1 then m1 else m2)"
(* The derived alt satisfies the monad_alt laws. *)
lemma monad_alt: "monad_alt return bind alt"
proof
show "bind (alt m m') f = alt (bind m f) (bind m' f)" for m m' f
by(simp add: alt_def bind_altc1 if_distrib[where f="λm. bind m _"])
fix m1 m2 m3 :: 'm
(* Associativity: both sides are altc over a three-element index set (?C and ?D as
   countable unions) with the common continuation ?f. *)
let ?C = "cUNION (cinsert ❙1 (csingle ❙2)) (λc. if c = ❙1 then cinsert ❙1 (csingle ❙2) else csingle ❙3)"
let ?D = "cUNION (cinsert ❙1 (csingle ❙2)) (λc. if c = ❙1 then csingle ❙1 else cinsert ❙2 (csingle ❙3))"
let ?f = "λc. if c = ❙1 then m1 else if c = ❙2 then m2 else m3"
have "alt (alt m1 m2) m3 = altc ?C ?f"
by (simp only: altc_cUNION) (auto simp add: alt_def altc_single intro!: altc_cong)
also have "?C = ?D" including cset.lifting by transfer(auto simp add: insert_commute)
(* The inner choice over {❙1, ❙2} is relabelled to {❙2, ❙3} via the converse of rel_12_23. *)
also have "altc ?D ?f = alt m1 (alt m2 m3)"
apply (simp only: altc_cUNION)
apply (clarsimp simp add: alt_def altc_single intro!: altc_cong)
apply (rule altc_parametric [where R="conversep rel_12_23", THEN rel_funD, THEN rel_funD])
subgoal by simp
subgoal including cset.lifting by transfer
(simp add: rel_set_def rel_12_23.simps)
subgoal by (simp add: rel_fun_def rel_12_23.simps)
done
finally show "alt (alt m1 m2) m3 = alt m1 (alt m2 m3)" .
qed
end
(* State combined with countable choice: get and put commute with altc. *)
locale monad_state_altc =
monad_state return bind get put +
monad_altc return bind altc
for return :: "('a, 'm) return"
and bind :: "('a, 'm) bind"
and get :: "('s, 'm) get"
and put :: "('s, 'm) put"
and altc :: "('c, 'm) altc"
+
assumes altc_get: "⋀C f. altc C (λc. get (f c)) = get (λs. altc C (λc. f c s))"
and altc_put: "⋀C f. altc C (λc. put s (f c)) = put s (altc C f)"
subsection ‹Writer monad›
(* tell w m: a value w of type 'w together with the computation m to continue with
   (presumably w is emitted as output before m runs; only bind_tell is assumed below). *)
type_synonym ('w, 'm) tell = "'w ⇒ 'm ⇒ 'm"
locale monad_writer_base = monad_base return bind
for return :: "('a, 'm) return"
and bind :: "('a, 'm) bind"
+
fixes tell :: "('w, 'm) tell"
(* Writer law: bind passes through tell into its continuation computation. *)
locale monad_writer = monad_writer_base return bind tell + monad return bind
for return :: "('a, 'm) return"
and bind :: "('a, 'm) bind"
and tell :: "('w, 'm) tell"
+
assumes bind_tell: "⋀w m f. bind (tell w m) f = tell w (bind m f)"
subsection ‹Resumption monad›
(* pause out c: an output value of type 'o and a continuation c on inputs of type 'i. *)
type_synonym ('o, 'i, 'm) pause = "'o ⇒ ('i ⇒ 'm) ⇒ 'm"
locale monad_resumption_base = monad_base return bind
for return :: "('a, 'm) return"
and bind :: "('a, 'm) bind"
+
fixes pause :: "('o, 'i, 'm) pause"
(* Resumption law: bind is pushed into the continuation of pause. *)
locale monad_resumption = monad_resumption_base return bind pause + monad return bind
for return :: "('a, 'm) return"
and bind :: "('a, 'm) bind"
and pause :: "('o, 'i, 'm) pause"
+
assumes bind_pause: "bind (pause out c) f = pause out (λi. bind (c i) f)"
subsection ‹Commutative monad›
(* Commutative monad: two binds whose computations do not depend on each other can be swapped. *)
locale monad_commute = monad return bind
for return :: "('a, 'm) return"
and bind :: "('a, 'm) bind"
+
assumes bind_commute: "bind m (λx. bind m' (f x)) = bind m' (λy. bind m (λx. f x y))"
subsection ‹Discardable monad›
(* Discardable monad: a computation whose result is ignored can be dropped. *)
locale monad_discard = monad return bind
for return :: "('a, 'm) return"
and bind :: "('a, 'm) bind"
+
assumes bind_const: "bind m (λ_. m') = m'"
subsection ‹Duplicable monad›
(* Duplicable monad: running m twice and using both results is the same as running it once
   and using its result twice. *)
locale monad_duplicate = monad return bind
for return :: "('a, 'm) return"
and bind :: "('a, 'm) bind"
+
assumes bind_duplicate: "bind m (λx. bind m (f x)) = bind m (λx. f x x)"
section ‹Monad implementations›
subsection ‹Identity monad›
text ‹We need a type constructor so that the monad operations can be overloaded.›