(* Theory SPMF_Applicative *)
theory SPMF_Applicative imports
Applicative_Lifting.Applicative_PMF
Set_Applicative
"HOL-Probability.SPMF"
begin
(* Remove eq_on_def from the default simp set for the rest of this theory.
   NOTE(review): presumably one of the imported theories declares it [simp]
   and automatic unfolding is unwanted here -- confirm against the imports. *)
declare eq_on_def [simp del]
subsection ‹Applicative instance for @{typ "'a spmf"}›
(* Input-only abbreviation: pure_spmf is just another name for return_spmf.
   It serves as the "pure" operation of the applicative instance below. *)
abbreviation (input) pure_spmf :: "'a ⇒ 'a spmf"
where "pure_spmf ≡ return_spmf"
(* Applicative application for spmf: pair the spmf of functions f with the
   spmf of arguments x using pair_spmf, then map function application over
   the resulting pairs. *)
definition ap_spmf :: "('a ⇒ 'b) spmf ⇒ 'a spmf ⇒ 'b spmf"
where "ap_spmf f x = map_spmf (λ(f, x). f x) (pair_spmf f x)"
(* Monadic characterisation of ap_spmf: bind the function first, then the
   argument, and return the application. Proved by unfolding ap_spmf_def and
   rewriting map_spmf and pair_spmf into their bind_spmf forms. *)
lemma ap_spmf_conv_bind: "ap_spmf f x = bind_spmf f (λf. bind_spmf x (λx. return_spmf (f x)))"
by(simp add: ap_spmf_def map_spmf_conv_bind_spmf pair_spmf_alt_def)
(* Register ap_spmf as an ad-hoc overloading variant of Applicative.ap.
   NOTE(review): presumably this is what lets the ⋄ notation used below
   resolve to ap_spmf on spmf arguments -- confirm in Applicative_Lifting. *)
adhoc_overloading Applicative.ap ap_spmf
(* Local context that includes the applicative_syntax bundle; the lemmas
   inside are stated with the infix ⋄ notation (presumably supplied by this
   bundle -- confirm). The context is closed by the first of the two
   trailing `end`s. *)
context includes applicative_syntax begin
(* Identity law: applying the lifted identity function leaves x unchanged. *)
lemma ap_spmf_id: "pure_spmf (λx. x) ⋄ x = x"
by(simp add: ap_spmf_def pair_spmf_return_spmf1 spmf.map_comp o_def)
(* Composition law: applying the lifted composition operator to u, v and w
   equals applying u to the result of applying v to w. *)
lemma ap_spmf_comp: "pure_spmf (∘) ⋄ u ⋄ v ⋄ w = u ⋄ (v ⋄ w)"
by(simp add: ap_spmf_def pair_spmf_return_spmf1 pair_map_spmf1 pair_map_spmf2 spmf.map_comp o_def split_def pair_pair_spmf)
(* Homomorphism law: applying a pure function to a pure value is the pure
   value of the ordinary application f x. *)
lemma ap_spmf_homo: "pure_spmf f ⋄ pure_spmf x = pure_spmf (f x)"
by(simp add: ap_spmf_def pair_spmf_return_spmf1)
(* Interchange law: applying u to a pure argument x is the same as applying
   the pure "apply to x" function to u. *)
lemma ap_spmf_interchange: "u ⋄ pure_spmf x = pure_spmf (λf. f x) ⋄ u"
by(simp add: ap_spmf_def pair_spmf_return_spmf1 pair_spmf_return_spmf2 spmf.map_comp o_def)
(* C combinator law: lifting the argument-flipping function and applying it
   to f, x, y equals applying f to y and then x, i.e. the two arguments may
   be supplied in either order. Used for the (C) declaration in the
   `applicative` registration. *)
lemma ap_spmf_C: "return_spmf (λf x y. f y x) ⋄ f ⋄ x ⋄ y = f ⋄ y ⋄ x"
(* Unfold ap_spmf and normalise both sides to a single map_spmf over nested pair_spmf. *)
apply(simp add: ap_spmf_def pair_map_spmf1 spmf.map_comp pair_spmf_return_spmf1 pair_pair_spmf o_def split_def)
(* Rewrite the second matching pair_spmf occurrence with pair_commute_spmf. *)
apply(subst (2) pair_commute_spmf)
(* Re-normalise the maps introduced by the commutation step to close the goal. *)
apply(simp add: pair_map_spmf2 spmf.map_comp o_def split_def)
done
(* Register spmf with the `applicative` command, declaring the combinator C,
   with pure_spmf as pure and ap_spmf as ap. The resulting proof obligations
   are closed by repeatedly applying whichever of the law lemmas fits;
   ap_spmf_comp is supplied with o_def[abs_def] unfolded, so that its
   statement no longer mentions the ∘ constant. *)
applicative spmf (C)
for
pure: pure_spmf
ap: ap_spmf
by(rule ap_spmf_id ap_spmf_comp[unfolded o_def[abs_def]] ap_spmf_homo ap_spmf_interchange ap_spmf_C)+
(* Simp rule: set_spmf of an application p ⋄ q is the application, at the
   level of sets, of set_spmf p to set_spmf q. The right-hand ⋄ is the set
   instance (the proof unfolds ap_set_def). *)
lemma set_ap_spmf [simp]: "set_spmf (p ⋄ q) = set_spmf p ⋄ set_spmf q"
by(auto simp add: ap_spmf_def ap_set_def)
(* Binding over an application p ⋄ x is the same as binding p, then x, and
   continuing with f on the applied result. Follows from ap_spmf_conv_bind. *)
lemma bind_ap_spmf: "bind_spmf (p ⋄ x) f = bind_spmf p (λp. x ⤜ (λx. f (p x)))"
by(simp add: ap_spmf_conv_bind)
(* Simp rule at the pmf level: a bind_pmf over the application of a pure
   function f to p can be pushed into the continuation, which then receives
   map_option f of each outcome of p. The proof unfolds bind_spmf_def and
   case-splits on the option value. *)
lemma bind_pmf_ap_return_spmf [simp]: "bind_pmf (ap_spmf (return_spmf f) p) g = bind_pmf p (g ∘ map_option f)"
by(auto simp add: ap_spmf_conv_bind bind_spmf_def bind_return_pmf bind_assoc_pmf intro: bind_pmf_cong split: option.split)
(* Expresses map_spmf f p as the application of the pure function f to p.
   Declared [applicative_unfold]; presumably this lets the applicative
   lifting machinery rewrite map_spmf into applicative form -- confirm. *)
lemma map_spmf_conv_ap [applicative_unfold]: "map_spmf f p = return_spmf f ⋄ p"
by(simp add: map_spmf_conv_bind_spmf ap_spmf_conv_bind)
end
end