(* Theory Tracking_SPMF *)
section ‹Tracking SPMFs\label{sec:tracking_spmfs}›
text ‹This section introduces tracking SPMFs --- a resource monad on top of SPMFs. We also
introduce the Scott-continuous monad morphism @{term "tspmf_of_ra"}, with which it is possible to
reason about the joint distribution of a randomized algorithm's result and used coin-flips.
An example application of the results in this theory can be found in Section~\ref{sec:dice_roll}.›
theory Tracking_SPMF
imports Tracking_Randomized_Algorithm
begin
(* A tracking SPMF: an SPMF over pairs of a result and a natural-number counter.
   The counter is what the coin-usage definitions below read via snd. *)
type_synonym 'a tspmf = "('a × nat) spmf"
(* Monadic unit: returns x paired with the counter value 0. *)
definition return_tspmf :: "'a ⇒ 'a tspmf" where
"return_tspmf x = return_spmf (x,0)"
(* A single coin flip: the outcome of coin_spmf paired with the constant counter 1. *)
definition coin_tspmf :: "bool tspmf" where
"coin_tspmf = pair_spmf coin_spmf (return_spmf 1)"
(* Monadic bind: runs f to obtain a result r and a counter c, then runs g r and
   adds c to the counter component of every outcome of g r. *)
definition bind_tspmf :: "'a tspmf ⇒ ('a ⇒ 'b tspmf) ⇒ 'b tspmf" where
"bind_tspmf f g = bind_spmf f (λ(r,c). map_spmf (apsnd ((+) c)) (g r))"
(* Make bind_tspmf available through the overloaded bind of Monad_Syntax. *)
adhoc_overloading Monad_Syntax.bind bind_tspmf
text ‹Monad laws:›
(* Left identity: binding a returned value x into g is the same as g x. *)
lemma return_bind_tspmf:
"bind_tspmf (return_tspmf x) g = g x"
unfolding bind_tspmf_def return_tspmf_def map_spmf_conv_bind_spmf
by (simp add:apsnd_def map_prod_def)
(* Associativity of bind_tspmf. *)
lemma bind_tspmf_assoc:
"bind_tspmf (bind_tspmf f g) h = bind_tspmf f (λx. bind_tspmf (g x) h)"
unfolding bind_tspmf_def
by (simp add: case_prod_beta' algebra_simps map_spmf_conv_bind_spmf apsnd_def map_prod_def)
(* Right identity: binding m into return_tspmf gives back m. *)
lemma bind_return_tspmf:
"bind_tspmf m return_tspmf = m"
unfolding bind_tspmf_def return_tspmf_def map_spmf_conv_bind_spmf apsnd_def
by (simp add:case_prod_beta')
(* bind_tspmf is monotone in both arguments with respect to ord_spmf (=):
   if f1 is below f2 and g1 is pointwise below g2, then the binds are ordered
   the same way. Auxiliary fact used by bind_mono_tspmf. *)
lemma bind_mono_tspmf_aux:
assumes "ord_spmf (=) f1 f2" "⋀y. ord_spmf (=) (g1 y) (g2 y)"
shows "ord_spmf (=) (bind_tspmf f1 g1) (bind_tspmf f2 g2)"
using assms unfolding bind_tspmf_def map_spmf_conv_bind_spmf
by (auto intro!: bind_spmf_mono' simp add:case_prod_beta')
(* Monotonicity rule for bind_tspmf in the form mono_spmf, declared with the
   partial_function_mono attribute. *)
lemma bind_mono_tspmf [partial_function_mono]:
assumes "mono_spmf B" and "⋀y. mono_spmf (C y)"
shows "mono_spmf (λf. bind_tspmf (B f) (λy. C y f))"
using assms by (intro monotoneI bind_mono_tspmf_aux) (auto simp:monotone_def)
(* Order on tracking SPMFs: ord_spmf over the relation that holds between
   outcomes x (left) and y (right) when their results agree (fst x = fst y)
   and the left counter is at least the right counter (snd x ≥ snd y). *)
definition ord_tspmf :: "'a tspmf ⇒ 'a tspmf ⇒ bool" where
"ord_tspmf = ord_spmf (λx y. fst x = fst y ∧ snd x ≥ snd y)"
(* Bundle that introduces the infix notation ≤⇩R for ord_tspmf. *)
bundle ord_tspmf_notation
begin
notation ord_tspmf ("(_/ ≤⇩R _)" [51, 51] 50)
end
(* Counterpart bundle that removes the same notation again. *)
bundle no_ord_tspmf_notation
begin
no_notation ord_tspmf ("(_/ ≤⇩R _)" [51, 51] 50)
end
(* Activate the notation in the current context. *)
unbundle ord_tspmf_notation
(* Distribution of the counter component as an enat-valued pmf:
   a None outcome is mapped to ∞, and Some y is mapped to enat (snd y). *)
definition coin_usage_of_tspmf :: "'a tspmf ⇒ enat pmf"
where "coin_usage_of_tspmf = map_pmf (λx. case x of None ⇒ ∞ | Some y ⇒ enat (snd y))"
(* Expected coin usage of a tracking SPMF: the non-negative integral of the
   identity over coin_usage_of_tspmf p, after converting enat to ennreal. *)
definition expected_coin_usage_of_tspmf :: "'a tspmf ⇒ ennreal"
where
"expected_coin_usage_of_tspmf p = (∫⇧+x. x ∂(map_pmf ennreal_of_enat (coin_usage_of_tspmf p)))"
(* Same construction as expected_coin_usage_of_tspmf, but over coin_usage_of_ra p
   (coin_usage_of_ra is not defined in this theory). No type annotation is given,
   so the type is inferred. *)
definition expected_coin_usage_of_ra where
"expected_coin_usage_of_ra p = ∫⇧+x. x ∂(map_pmf ennreal_of_enat (coin_usage_of_ra p))"
(* Projects a tracking SPMF onto its result component, dropping the counter. *)
definition result :: "'a tspmf ⇒ 'a spmf"
where "result = map_spmf fst"
(* Alternative form of coin_usage_of_tspmf: first project to the counter with
   map_spmf snd, then convert the resulting option nat to enat. *)
lemma coin_usage_of_tspmf_alt_def:
"coin_usage_of_tspmf p = map_pmf (λx. case x of None ⇒ ∞ | Some y ⇒ enat y) (map_spmf snd p)"
unfolding coin_usage_of_tspmf_def map_pmf_comp map_option_case
by (metis enat_def infinity_enat_def option.case_eq_if option.sel)
(* Post-processing the result of f with a pure function g (through return_tspmf)
   does not change the coin-usage distribution. *)
lemma coin_usage_of_tspmf_bind_return:
"coin_usage_of_tspmf (bind_tspmf f (λx. return_tspmf (g x))) = (coin_usage_of_tspmf f)"
unfolding bind_tspmf_def return_tspmf_def coin_usage_of_tspmf_alt_def map_spmf_bind_spmf
by (simp add:comp_def case_prod_beta map_spmf_conv_bind_spmf)
(* If p is below q in ord_tspmf, then for every bound k the probability that the
   coin usage is at most k under p is no larger than under q. *)
lemma coin_usage_of_tspmf_mono:
assumes "ord_tspmf p q"
shows "measure (coin_usage_of_tspmf p) {..k} ≤ measure (coin_usage_of_tspmf q) {..k}"
proof -
(* p' and q' are the counter projections of p and q. *)
define p' where "p' = map_spmf snd p"
define q' where "q' = map_spmf snd q"
(* The counters alone are related by ord_spmf (≥). *)
have 0:"ord_spmf (≥) p' q'"
using assms(1) ord_spmf_mono unfolding p'_def q'_def ord_tspmf_def ord_spmf_map_spmf12 by fastforce
have cp:"coin_usage_of_tspmf p = map_pmf (case_option ∞ enat) p'"
unfolding coin_usage_of_tspmf_alt_def p'_def by simp
have cq:"coin_usage_of_tspmf q = map_pmf (case_option ∞ enat) q'"
unfolding coin_usage_of_tspmf_alt_def q'_def by simp
(* Note: this re-uses the label 0; from here on 0 refers to the rel_pmf fact,
   while the proof of this step still uses the earlier fact named 0. *)
have 0:"rel_pmf (≥) (coin_usage_of_tspmf p) (coin_usage_of_tspmf q)"
unfolding cp cq map_pmf_def by (intro rel_pmf_bindI[OF 0]) (auto split:option.split)
show ?thesis
unfolding atMost_def by (intro measure_Ici[OF 0] transp_on_ge) (simp add:reflp_def)
qed
(* Complement form of coin_usage_of_tspmf_mono: if p is below q in ord_tspmf,
   the probability of using more than k coins under q is at most that under p. *)
lemma coin_usage_of_tspmf_mono_rev:
assumes "ord_tspmf p q"
shows "measure (coin_usage_of_tspmf q) {x. x > k} ≤ measure (coin_usage_of_tspmf p) {x. x > k}"
(is "?L ≤ ?R")
proof -
(* The complement of the event "more than k" is the event "at most k". *)
have 0:"UNIV - {x. x > k} = {..k}"
by (auto simp add:set_diff_eq set_eq_iff)
(* Rewrite both sides as complement probabilities and apply the monotonicity lemma. *)
have "1 - ?R ≤ 1 - ?L"
using coin_usage_of_tspmf_mono[OF assms]
by (subst (1 2) measure_pmf.prob_compl[symmetric]) (auto simp:0)
thus ?thesis
by simp
qed
(* Expresses the expected coin usage as the infinite sum, over all natural k,
   of the probability that the coin usage exceeds k. *)
lemma expected_coin_usage_of_tspmf:
"expected_coin_usage_of_tspmf p = (∑k. ennreal (measure (coin_usage_of_tspmf p) {x. x > enat k}))" (is "?L = ?R")
proof -
have "?L = integral⇧N (measure_pmf (coin_usage_of_tspmf p)) ennreal_of_enat"
unfolding expected_coin_usage_of_tspmf_def by simp
(* Integral of an enat-valued function written as a sum of tail measures. *)
also have "... = (∑k. emeasure (measure_pmf (coin_usage_of_tspmf p)) {x. enat k < x})"
by (subst nn_integral_enat_function) auto
also have "... = ?R"
by (subst measure_pmf.emeasure_eq_measure) simp
finally show ?thesis
by simp
qed
(* return_pmf None is below every tracking SPMF in ord_tspmf. *)
lemma ord_tspmf_min: "ord_tspmf (return_pmf None) p"
unfolding ord_tspmf_def by (simp add: ord_spmf_reflI)
(* ord_tspmf is reflexive. *)
lemma ord_tspmf_refl: "ord_tspmf p p"
unfolding ord_tspmf_def by (simp add: ord_spmf_reflI)
(* ord_tspmf is transitive; declared with the trans attribute. *)
lemma ord_tspmf_trans[trans]:
assumes "ord_tspmf p q" "ord_tspmf q r"
shows "ord_tspmf p r"
proof -
(* Transitivity is inherited from the underlying relation on outcome pairs. *)
have 0:"transp (ord_tspmf)"
unfolding ord_tspmf_def
by (intro transp_rel_pmf transp_ord_option) (auto simp:transp_def)
thus ?thesis
using assms transpD[OF 0] by auto
qed
(* If f never decreases its argument, then applying f to the counter component
   of p yields a tracking SPMF that is below p in ord_tspmf. *)
lemma ord_tspmf_map_spmf:
assumes "⋀x. x ≤ f x"
shows "ord_tspmf (map_spmf (apsnd f) p) p"
using assms unfolding ord_tspmf_def ord_spmf_map_spmf1
by (intro ord_spmf_reflI) auto
(* A pointwise ord_tspmf relation between f and g lifts through bind_pmf
   with the same first argument p. *)
lemma ord_tspmf_bind_pmf:
assumes "⋀x. ord_tspmf (f x) (g x)"
shows "ord_tspmf (bind_pmf p f) (bind_pmf p g)"
using assms unfolding ord_tspmf_def
by (intro rel_pmf_bindI[where R="(=)"]) (auto simp add: pmf.rel_refl)
(* A pointwise ord_tspmf relation between f and g lifts through bind_tspmf
   with the same first argument p. *)
lemma ord_tspmf_bind_tspmf:
assumes "⋀x. ord_tspmf (f x) (g x)"
shows "ord_tspmf (bind_tspmf p f) (bind_tspmf p g)"
using assms unfolding bind_tspmf_def ord_tspmf_def
by (intro ord_spmf_bind_reflI) (simp add:case_prod_beta ord_spmf_map_spmf12)
(* use_coins k adds k to the counter component of every outcome. *)
definition use_coins :: "nat ⇒ 'a tspmf ⇒ 'a tspmf"
where "use_coins k = map_spmf (apsnd ((+) k))"
(* Two nested use_coins applications combine into one with the summed amount. *)
lemma use_coins_add:
"use_coins k (use_coins s f) = use_coins (k+s) f"
unfolding use_coins_def spmf.map_comp
by (simp add:comp_def apsnd_def map_prod_def case_prod_beta' algebra_simps)
(* Binding coin_tspmf into f equals binding the plain coin_spmf into f and then
   adding 1 to the counter with use_coins. *)
lemma coin_tspmf_split:
fixes f :: "bool ⇒ 'b tspmf"
shows "(coin_tspmf ⤜ f) = use_coins 1 (coin_spmf ⤜ f)"
unfolding coin_tspmf_def use_coins_def map_spmf_conv_bind_spmf pair_spmf_alt_def bind_tspmf_def
by (simp)
(* Adding k to the counter produces a tracking SPMF below the original in ord_tspmf. *)
lemma ord_tspmf_use_coins:
"ord_tspmf (use_coins k p) p"
unfolding use_coins_def by (intro ord_tspmf_map_spmf) auto
(* use_coins k preserves ord_tspmf when applied to both sides. *)
lemma ord_tspmf_use_coins_2:
assumes "ord_tspmf p q"
shows "ord_tspmf (use_coins k p) (use_coins k q)"
using assms unfolding use_coins_def ord_tspmf_def ord_spmf_map_spmf12 by auto
(* If p is below q in ord_tspmf, their result projections are ordered by ord_spmf (=). *)
lemma result_mono:
assumes "ord_tspmf p q"
shows "ord_spmf (=) (result p) (result q)"
using assms ord_spmf_mono unfolding result_def ord_tspmf_def ord_spmf_map_spmf12 by force
(* result commutes with bind: the result of a bind_tspmf is the bind of the results. *)
lemma result_bind:
"result (bind_tspmf f g) = result f ⤜ (λx. result (g x))"
unfolding bind_tspmf_def result_def map_spmf_conv_bind_spmf by (simp add:case_prod_beta')
(* result maps return_tspmf x to return_spmf x. *)
lemma result_return:
"result (return_tspmf x) = return_spmf x"
unfolding return_tspmf_def result_def map_spmf_conv_bind_spmf by (simp add:case_prod_beta')
(* result maps coin_tspmf to coin_spmf. *)
lemma result_coin:
"result (coin_tspmf) = coin_spmf"
unfolding coin_tspmf_def result_def pair_spmf_alt_def map_spmf_conv_bind_spmf by (simp add:case_prod_beta')
(* Converts a randomized algorithm into a tracking SPMF by first applying
   track_coin_use and then spmf_of_ra (neither is defined in this theory). *)
definition tspmf_of_ra :: "'a random_alg ⇒ 'a tspmf" where
"tspmf_of_ra = spmf_of_ra ∘ track_coin_use"
(* tspmf_of_ra maps the coin primitive coin_ra to coin_tspmf. *)
lemma tspmf_of_ra_coin: "tspmf_of_ra coin_ra = coin_tspmf"
unfolding tspmf_of_ra_def comp_def track_coin_use_coin coin_tra_def coin_tspmf_def
spmf_of_ra_bind spmf_of_ra_coin spmf_of_ra_return pair_spmf_alt_def
by simp
(* tspmf_of_ra maps return_ra x to return_tspmf x. *)
lemma tspmf_of_ra_return: "tspmf_of_ra (return_ra x) = return_tspmf x"
unfolding tspmf_of_ra_def comp_def track_coin_use_return return_tra_def return_tspmf_def
spmf_of_ra_return by simp
(* tspmf_of_ra commutes with bind: bind_ra is mapped to bind_tspmf. *)
lemma tspmf_of_ra_bind:
"tspmf_of_ra (bind_ra m f) = bind_tspmf (tspmf_of_ra m) (λx. tspmf_of_ra (f x))"
unfolding tspmf_of_ra_def comp_def track_coin_use_bind bind_tra_def bind_tspmf_def
map_spmf_conv_bind_spmf
by (simp add:case_prod_beta' spmf_of_ra_bind spmf_of_ra_return apsnd_def map_prod_def)
(* Collection of the three morphism equations for bind, return and coin. *)
lemmas tspmf_of_ra_simps = tspmf_of_ra_bind tspmf_of_ra_return tspmf_of_ra_coin
(* tspmf_of_ra is monotone from ord_ra to ord_spmf (=). *)
lemma tspmf_of_ra_mono:
assumes "ord_ra f g"
shows "ord_spmf (=) (tspmf_of_ra f) (tspmf_of_ra g)"
unfolding tspmf_of_ra_def comp_def
by (intro spmf_of_ra_mono track_coin_use_mono assms)
(* tspmf_of_ra preserves least upper bounds of ord_ra-chains. *)
lemma tspmf_of_ra_lub:
assumes "Complete_Partial_Order.chain ord_ra A"
shows "tspmf_of_ra (lub_ra A) = lub_spmf (tspmf_of_ra ` A)" (is "?L = ?R")
proof -
(* The image of the chain under track_coin_use is again a chain. *)
have 0:"Complete_Partial_Order.chain ord_ra (track_coin_use ` A)"
by (intro chain_imageI[OF assms] track_coin_use_mono)
(* Push the lub through track_coin_use, then through spmf_of_ra. *)
have "?L = spmf_of_ra (lub_ra (track_coin_use ` A))"
unfolding tspmf_of_ra_def comp_def
by (intro arg_cong[where f="spmf_of_ra"] track_coin_use_lub assms)
also have "... = lub_spmf (spmf_of_ra ` track_coin_use ` A)"
by (intro spmf_of_ra_lub_ra 0)
also have "... = ?R"
unfolding image_image tspmf_of_ra_def by simp
finally show "?thesis" by simp
qed
(* Graph of tspmf_of_ra as a relation: q is related to p iff q = tspmf_of_ra p.
   Used below in the admissibility and transfer rules. *)
definition rel_tspmf_of_ra :: "'a tspmf ⇒ 'a random_alg ⇒ bool" where
"rel_tspmf_of_ra q p ⟷ q = tspmf_of_ra p"
(* rel_tspmf_of_ra, viewed as a predicate on pairs, is admissible with respect to
   the product of the SPMF order and ord_ra: it is preserved by lubs of
   non-empty chains of related pairs. *)
lemma admissible_rel_tspmf_of_ra:
"ccpo.admissible (prod_lub lub_spmf lub_ra) (rel_prod (ord_spmf (=)) ord_ra) (case_prod rel_tspmf_of_ra)"
(is "ccpo.admissible ?lub ?ord ?P")
proof (rule ccpo.admissibleI)
fix Y
assume chain: "Complete_Partial_Order.chain ?ord Y"
and Y: "Y ≠ {}"
and R: "∀(p, q) ∈ Y. rel_tspmf_of_ra p q"
(* Restate R in rule form; this shadows the assumption named R. *)
from R have R: "⋀p q. (p, q) ∈ Y ⟹ rel_tspmf_of_ra p q" by auto
(* Both component projections of the chain are chains. *)
have chain1: "Complete_Partial_Order.chain (ord_spmf (=)) (fst ` Y)"
and chain2: "Complete_Partial_Order.chain ord_ra (snd ` Y)"
using chain by(rule chain_imageI; clarsimp)+
from Y have Y1: "fst ` Y ≠ {}" and Y2: "snd ` Y ≠ {}" by auto
(* Every first component is tspmf_of_ra of its second component, so the images coincide. *)
have "lub_spmf (fst ` Y) = lub_spmf (tspmf_of_ra ` snd ` Y)"
unfolding image_image using R
by (intro arg_cong[of _ _ lub_spmf] image_cong) (auto simp: rel_tspmf_of_ra_def)
also have "… = tspmf_of_ra (lub_ra (snd ` Y))"
by (intro tspmf_of_ra_lub[symmetric] chain2)
finally have "rel_tspmf_of_ra (lub_spmf (fst ` Y)) (lub_ra (snd ` Y))"
unfolding rel_tspmf_of_ra_def .
then show "?P (?lub Y)"
by (simp add: prod_lub_def)
qed
(* Admissibility of rel_tspmf_of_ra composed with two mcont functions f and g;
   declared as a cont_intro rule. Obtained from admissible_rel_tspmf_of_ra by
   substituting the pairing of f and g. *)
lemma admissible_rel_tspmf_of_ra_cont [cont_intro]:
fixes ord
shows "⟦ mcont lub ord lub_spmf (ord_spmf (=)) f; mcont lub ord lub_ra ord_ra g ⟧
⟹ ccpo.admissible lub ord (λx. rel_tspmf_of_ra (f x) (g x))"
by (rule admissible_subst[OF admissible_rel_tspmf_of_ra, where f="λx. (f x, g x)", simplified])
(rule mcont_Pair)
(* tspmf_of_ra is monotone and continuous (mcont) from the order on randomized
   algorithms to the SPMF order; follows from tspmf_of_ra_mono and tspmf_of_ra_lub. *)
lemma mcont_tspmf_of_ra:
"mcont lub_ra ord_ra lub_spmf (ord_spmf (=)) tspmf_of_ra"
unfolding mcont_def monotone_def cont_def
by (auto simp: tspmf_of_ra_mono tspmf_of_ra_lub)
(* Derived rule: mcont_tspmf_of_ra combined with spmf.mcont2mcont. *)
lemmas mcont2mcont_tspmf_of_ra = mcont_tspmf_of_ra[THEN spmf.mcont2mcont]
(* Transfer rules relating tracking-SPMF constructs to their randomized-algorithm
   counterparts through rel_tspmf_of_ra. The context includes lifting_syntax,
   which the statements below use for the ===> notation. *)
context includes lifting_syntax
begin
(* Fixed points: if F and G are monotone and related by param, then their
   fixed points (spmf.fixp_fun F and random_alg_pf.fixp_fun G) are related too.
   Proved by parallel fixed-point induction. *)
lemma fixp_rel_tspmf_of_ra_parametric[transfer_rule]:
assumes f: "⋀x. mono_spmf (λf. F f x)"
and g: "⋀x. mono_ra (λf. G f x)"
and param: "((A ===> rel_tspmf_of_ra) ===> A ===> rel_tspmf_of_ra) F G"
shows "(A ===> rel_tspmf_of_ra) (spmf.fixp_fun F) (random_alg_pf.fixp_fun G)"
using f g
proof(rule parallel_fixp_induct_1_1[OF
partial_function_definitions_spmf random_alg_pfd _ _ reflexive reflexive,
where P="(A ===> rel_tspmf_of_ra)"])
(* Admissibility of the relation lifted to functions. *)
show "ccpo.admissible (prod_lub (fun_lub lub_spmf) (fun_lub lub_ra))
(rel_prod (fun_ord (ord_spmf (=))) (fun_ord ord_ra))
(λx. (A ===> rel_tspmf_of_ra) (fst x) (snd x))"
unfolding rel_fun_def
by(rule admissible_all admissible_imp cont_intro)+
(* tspmf_of_ra maps the lub of the empty set to return_pmf None. *)
have 0:"tspmf_of_ra (lub_ra {}) = return_pmf None"
using tspmf_of_ra_lub[where A="{}"]
by (simp add:Complete_Partial_Order.chain_def)
(* Base case: the constant functions at the two empty lubs are related. *)
show "(A ===> rel_tspmf_of_ra) (λ_. lub_spmf {}) (λ_. lub_ra {})"
by (auto simp: rel_fun_def rel_tspmf_of_ra_def 0)
(* Step case: F and G preserve the relation, by the param assumption. *)
show "(A ===> rel_tspmf_of_ra) (F f) (G g)" if "(A ===> rel_tspmf_of_ra) f g" for f g
using that by(rule rel_funD[OF param])
qed
(* Transfer rule for return. (The lemma name is spelled "tranfer"; it is kept as is.) *)
lemma return_ra_tranfer[transfer_rule]: "((=) ===> rel_tspmf_of_ra) return_tspmf return_ra"
unfolding rel_fun_def rel_tspmf_of_ra_def tspmf_of_ra_return by simp
(* Transfer rule for bind. *)
lemma bind_ra_tranfer[transfer_rule]:
"(rel_tspmf_of_ra ===> ((=) ===> rel_tspmf_of_ra) ===> rel_tspmf_of_ra) bind_tspmf bind_ra"
unfolding rel_fun_def rel_tspmf_of_ra_def tspmf_of_ra_bind by simp presburger
(* Transfer rule for the coin primitive. *)
lemma coin_ra_tranfer[transfer_rule]:
"rel_tspmf_of_ra coin_tspmf coin_ra"
unfolding rel_fun_def rel_tspmf_of_ra_def tspmf_of_ra_coin by simp
end
(* The result projection of tspmf_of_ra f coincides with spmf_of_ra f. *)
lemma spmf_of_tspmf:
"result (tspmf_of_ra f) = spmf_of_ra f"
unfolding tspmf_of_ra_def result_def
by (simp add: untrack_coin_use spmf_of_ra_map[symmetric])
(* The coin-usage distribution computed from tspmf_of_ra p equals coin_usage_of_ra p.
   The proof shows that the counter projection of tspmf_of_ra p has the same
   underlying measure as coin_usage_of_ra_aux p, and then lifts this to pmfs.
   The symbols ℬ and 𝒟 are notation that is not introduced in this theory. *)
lemma coin_usage_of_tspmf_correct:
"coin_usage_of_tspmf (tspmf_of_ra p) = coin_usage_of_ra p" (is "?L = ?R")
proof -
let ?p = "Rep_random_alg p"
(* Chain of equalities on the measure of the counter projection. *)
have "measure_pmf (map_spmf snd (tspmf_of_ra p)) =
distr (distr_rai (track_random_bits ?p)) 𝒟 (map_option snd)"
unfolding tspmf_of_ra_def map_pmf_rep_eq spmf_of_ra.rep_eq comp_def track_coin_use.rep_eq
by simp
(* Fuse the two nested distr applications into one. *)
also have "... = distr ℬ 𝒟 (map_option snd ∘ (map_option fst ∘ track_random_bits ?p))"
unfolding distr_rai_def
by (intro distr_distr distr_rai_measurable wf_track_random_bits wf_rep_rand_alg) simp
also have "... = distr ℬ 𝒟 (λx. ?p x ⤜ (λxa. consumed_bits ?p x))"
unfolding track_random_bits_def by (simp add:comp_def map_option_bind case_prod_beta)
(* Drop the bind on ?p x, using consumed_bits_inf_iff. *)
also have "... = distr ℬ 𝒟 (λx. consumed_bits ?p x)"
by (intro arg_cong[where f="distr ℬ 𝒟"] ext)
(auto simp:consumed_bits_inf_iff[OF wf_rep_rand_alg] split:bind_split)
also have "... = measure_pmf (coin_usage_of_ra_aux p)"
unfolding coin_usage_of_ra_aux.rep_eq used_bits_distr_def by simp
finally have "measure_pmf (map_spmf snd (tspmf_of_ra p)) = measure_pmf (coin_usage_of_ra_aux p)"
by simp
(* Equal measures give equal pmfs. *)
hence 0:"map_spmf snd (tspmf_of_ra p) = coin_usage_of_ra_aux p"
using measure_pmf_inject by auto
show ?thesis
unfolding coin_usage_of_tspmf_def 0[symmetric] coin_usage_of_ra_def map_pmf_comp
by (intro map_pmf_cong) (auto split:option.split)
qed
(* The expected coin usage computed from tspmf_of_ra p equals
   expected_coin_usage_of_ra p; follows from coin_usage_of_tspmf_correct. *)
lemma expected_coin_usage_of_tspmf_correct:
"expected_coin_usage_of_tspmf (tspmf_of_ra p) = expected_coin_usage_of_ra p"
unfolding expected_coin_usage_of_tspmf_def coin_usage_of_tspmf_correct
expected_coin_usage_of_ra_def by simp
end