(* Theory Applicative_Lifting.Applicative_PMF *)
subsection ‹Probability mass functions›
theory Applicative_PMF imports
Applicative
"HOL-Probability.Probability"
"HOL-Library.Adhoc_Overloading"
begin
(* The "pure" operation of the pmf applicative: the point-mass distribution
   return_pmf.  Being an (input) abbreviation, pure_pmf is expanded to
   return_pmf while parsing and is never shown when terms are printed. *)
abbreviation (input) pure_pmf :: "'a ⇒ 'a pmf" where
  "pure_pmf ≡ return_pmf"
(* Applicative application for pmfs.  The function distribution f and the
   argument distribution x are combined with the product pair_pmf, and the
   resulting pair distribution is pushed through function application. *)
definition ap_pmf :: "('a ⇒ 'b) pmf ⇒ 'a pmf ⇒ 'b pmf" where
  "ap_pmf f x = map_pmf (λ(g, y). g y) (pair_pmf f x)"

(* Register ap_pmf as an instance of the overloaded constant Applicative.ap,
   so that the generic applicative operator can be used on pmf arguments. *)
adhoc_overloading
  Applicative.ap ap_pmf
context includes applicative_syntax
begin

(* Laws of the pmf applicative (pure_pmf, ⋄).  Each law is proved by
   unfolding ap_pmf_def and simplifying with pair_pmf / map_pmf lemmas.
   All laws are stated uniformly with pure_pmf; since pure_pmf is an input
   abbreviation of return_pmf, this does not change the theorems. *)

(* Identity law. *)
lemma ap_pmf_id: "pure_pmf (λx. x) ⋄ x = x"
  by(simp add: ap_pmf_def pair_return_pmf1 pmf.map_comp o_def)

(* Composition law. *)
lemma ap_pmf_comp: "pure_pmf (∘) ⋄ u ⋄ v ⋄ w = u ⋄ (v ⋄ w)"
  by(simp add: ap_pmf_def pair_return_pmf1 pair_map_pmf1 pair_map_pmf2 pmf.map_comp o_def split_def pair_pair_pmf)

(* Homomorphism law: applying a pure function to a pure value is pure. *)
lemma ap_pmf_homo: "pure_pmf f ⋄ pure_pmf x = pure_pmf (f x)"
  by(simp add: ap_pmf_def pair_return_pmf1)

(* Interchange law. *)
lemma ap_pmf_interchange: "u ⋄ pure_pmf x = pure_pmf (λf. f x) ⋄ u"
  by(simp add: ap_pmf_def pair_return_pmf1 pair_return_pmf2 pmf.map_comp o_def)

(* Law for the K combinator (λx _. x): the second argument's distribution
   does not affect the result. *)
lemma ap_pmf_K: "pure_pmf (λx _. x) ⋄ x ⋄ y = x"
  by(simp add: ap_pmf_def pair_map_pmf1 pmf.map_comp pair_return_pmf1 o_def split_def map_fst_pair_pmf)

(* Law for the C combinator (λf x y. f y x): the two argument distributions
   may be supplied in swapped order. *)
lemma ap_pmf_C: "pure_pmf (λf x y. f y x) ⋄ f ⋄ x ⋄ y = f ⋄ y ⋄ x"
  apply(simp add: ap_pmf_def pair_map_pmf1 pmf.map_comp pair_return_pmf1 pair_pair_pmf o_def split_def)
  (* Rewrite only the second occurrence matching pair_commute_pmf's
     left-hand side, to reorder that pair_pmf product. *)
  apply(subst (2) pair_commute_pmf)
  apply(simp add: pair_map_pmf2 pmf.map_comp o_def split_def)
  done

(* Parametricity of ap_pmf, declared as a transfer rule.  ap_pmf and
   pair_pmf are unfolded first, presumably so that transfer_prover only
   meets constants that already have transfer rules.
   NOTE(review): confirm this is the reason for the unfolding. *)
lemma ap_pmf_transfer[transfer_rule]:
  "rel_fun (rel_pmf (rel_fun A B)) (rel_fun (rel_pmf A) (rel_pmf B)) ap_pmf ap_pmf"
  unfolding ap_pmf_def[abs_def] pair_pmf_def
  by transfer_prover

(* Register pmf as an applicative functor with the extra combinators C and K,
   using pure_pmf / ap_pmf, relator rel_pmf and set function set_pmf. *)
applicative pmf (C, K)
for
  pure: pure_pmf
  ap: ap_pmf
  rel: rel_pmf
  set: set_pmf
proof -
  (* Obligation: pure_pmf respects any relation R via rel_pmf. *)
  fix R :: "'a ⇒ 'b ⇒ bool"
  show "rel_fun R (rel_pmf R) pure_pmf pure_pmf" by transfer_prover
next
  (* Obligation: ap_pmf preserves relatedness of function distributions f, g
     when the argument x is related to itself by equality restricted to its
     support set_pmf x. *)
  fix R and f :: "('a ⇒ 'b) pmf" and g :: "('a ⇒ 'c) pmf" and x
  assume [transfer_rule]: "rel_pmf (rel_fun (eq_on (set_pmf x)) R) f g"
  have [transfer_rule]: "rel_pmf (eq_on (set_pmf x)) x x" by (simp add: pmf.rel_refl_strong)
  show "rel_pmf R (ap_pmf f x) (ap_pmf g x)" by transfer_prover
(* The remaining obligations are closed by rule applications with the
   composition (with (∘) unfolded to a λ-term), homomorphism, C and K laws.
   ap_pmf_id and ap_pmf_interchange are not referenced here. *)
qed(rule ap_pmf_comp[unfolded o_def[abs_def]] ap_pmf_homo ap_pmf_C ap_pmf_K)+

end
end