Theory HOL-Probability.Giry_Monad

(*  Title:      HOL/Probability/Giry_Monad.thy
    Author:     Johannes Hölzl, TU München
    Author:     Manuel Eberl, TU München

Defines subprobability spaces, the subprobability functor and the Giry monad on subprobability
spaces.
*)

section ‹The Giry monad›

theory Giry_Monad
  imports Probability_Measure "HOL-Library.Monad_Syntax"
begin

subsection ‹Sub-probability spaces›

(* A sub-probability space: a finite measure whose total mass is at most 1
   and whose underlying space is nonempty. *)
locale subprob_space = finite_measure +
  assumes emeasure_space_le_1: "emeasure M (space M) ≤ 1"
  assumes subprob_not_empty: "space M ≠ {}"

(* Introduction rule: a measure with total mass at most 1 on a nonempty space
   is a sub-probability space. *)
lemma subprob_spaceI[Pure.intro!]:
  assumes *: "emeasure M (space M) ≤ 1"
  assumes "space M ≠ {}"
  shows "subprob_space M"
proof -
  (* Finiteness of the total mass follows from the bound by 1. *)
  interpret finite_measure M
  proof
    show "emeasure M (space M) ≠ ∞" using * by (auto simp: top_unique)
  qed
  show "subprob_space M" by standard fact+
qed

(* In a sub-probability space no set has measure top. *)
lemma (in subprob_space) emeasure_subprob_space_less_top: "emeasure M A ≠ top"
  by simp

(* Every probability space is a sub-probability space. *)
lemma prob_space_imp_subprob_space:
  "prob_space M ⟹ subprob_space M"
  by (rule subprob_spaceI) (simp_all add: prob_space.emeasure_space_1 prob_space.not_empty)

(* Every sub-probability space is sigma-finite. *)
lemma subprob_space_imp_sigma_finite: "subprob_space M ⟹ sigma_finite_measure M"
  unfolding subprob_space_def finite_measure_def by simp

(* Makes all subprob_space facts available inside the prob_space locale. *)
sublocale prob_space ⊆ subprob_space
  by (rule subprob_spaceI) (simp_all add: emeasure_space_1 not_empty)

(* The measure "sigma Ω X" over a nonempty Ω is a sub-probability space. *)
lemma subprob_space_sigma [simp]: "Ω ≠ {} ⟹ subprob_space (sigma Ω X)"
  by(rule subprob_spaceI)(simp_all add: emeasure_sigma space_measure_of_conv)

(* The null measure on a nonempty space is a sub-probability space. *)
lemma subprob_space_null_measure: "space M ≠ {} ⟹ subprob_space (null_measure M)"
  by(simp add: null_measure_def)

(* The push-forward of a sub-probability space along a measurable map into a
   nonempty target space is again a sub-probability space. *)
lemma (in subprob_space) subprob_space_distr:
  assumes f: "f ∈ measurable M M'" and "space M' ≠ {}" shows "subprob_space (distr M M' f)"
proof (rule subprob_spaceI)
  (* The preimage of the whole target space is the whole source space. *)
  have "f -` space M' ∩ space M = space M" using f by (auto dest: measurable_space)
  with f show "emeasure (distr M M' f) (space (distr M M' f)) ≤ 1"
    by (auto simp: emeasure_distr emeasure_space_le_1)
  show "space (distr M M' f) ≠ {}" by (simp add: assms)
qed

(* In a sub-probability space every set has extended measure at most 1. *)
lemma (in subprob_space) subprob_emeasure_le_1: "emeasure M X ≤ 1"
  by (rule order.trans[OF emeasure_space emeasure_space_le_1])

(* Real-valued variant of subprob_emeasure_le_1. *)
lemma (in subprob_space) subprob_measure_le_1: "measure M X ≤ 1"
  using subprob_emeasure_le_1[of X] by (simp add: emeasure_eq_measure)

(* In a sub-probability space, the nonnegative integral of a function that is
   almost everywhere bounded by a constant c is itself bounded by c. *)
lemma (in subprob_space) nn_integral_le_const:
  assumes "0 ≤ c" "AE x in M. f x ≤ c"
  shows "(∫⇧+x. f x ∂M) ≤ c"
proof -
  have "(∫⇧+ x. f x ∂M) ≤ (∫⇧+ x. c ∂M)"
    by(rule nn_integral_mono_AE) fact
  also have "… ≤ c * emeasure M (space M)"
    using ‹0 ≤ c› by simp
  (* The total mass is at most 1. *)
  also have "… ≤ c * 1" using emeasure_space_le_1 ‹0 ≤ c› by(rule mult_left_mono)
  finally show ?thesis by simp
qed

(* Change of variables on an interval: the push-forward along h of the measure
   with Lebesgue density f assigns to {a..b} the same mass as the measure with
   density "f (g x) * g' x", where g is strictly monotone on {a..b} with
   continuous derivative g' and satisfies "g (h x) = x" whenever h x lies in {a..b}. *)
lemma emeasure_density_distr_interval:
  fixes h :: "real ⇒ real" and g :: "real ⇒ real" and g' :: "real ⇒ real"
  assumes [simp]: "a ≤ b"
  assumes Mf[measurable]: "f ∈ borel_measurable borel"
  assumes Mg[measurable]: "g ∈ borel_measurable borel"
  assumes Mg'[measurable]: "g' ∈ borel_measurable borel"
  assumes Mh[measurable]: "h ∈ borel_measurable borel"
  assumes prob: "subprob_space (density lborel f)"
  assumes nonnegf: "⋀x. f x ≥ 0"
  assumes derivg: "⋀x. x ∈ {a..b} ⟹ (g has_real_derivative g' x) (at x)"
  assumes contg': "continuous_on {a..b} g'"
  assumes mono: "strict_mono_on {a..b} g" and inv: "⋀x. h x ∈ {a..b} ⟹ g (h x) = x"
  assumes range: "{a..b} ⊆ range h"
  shows "emeasure (distr (density lborel f) lborel h) {a..b} =
             emeasure (density lborel (λx. f (g x) * g' x)) {a..b}"
proof (cases "a < b")
  (* Non-degenerate interval: reduce to the substitution rule for nn_integral. *)
  assume "a < b"
  from mono have inj: "inj_on g {a..b}" by (rule strict_mono_on_imp_inj_on)
  from mono have mono': "mono_on {a..b} g" by (rule strict_mono_on_imp_mono_on)
  from mono' derivg have "⋀x. x ∈ {a<..<b} ⟹ g' x ≥ 0"
    by (rule mono_on_imp_deriv_nonneg) auto
  (* Nonnegativity extends from the open to the closed interval by continuity. *)
  from contg' this have derivg_nonneg: "⋀x. x ∈ {a..b} ⟹ g' x ≥ 0"
    by (rule continuous_ge_on_Ioo) (simp_all add: ‹a < b›)

  from derivg have contg: "continuous_on {a..b} g" by (rule has_real_derivative_imp_continuous_on)
  (* The preimage of {a..b} under h is the image interval of g. *)
  have A: "h -` {a..b} = {g a..g b}"
  proof (intro equalityI subsetI)
    fix x assume x: "x ∈ h -` {a..b}"
    hence "g (h x) ∈ {g a..g b}" by (auto intro: mono_onD[OF mono'])
    with inv and x show "x ∈ {g a..g b}" by simp
  next
    fix y assume y: "y ∈ {g a..g b}"
    with IVT'[OF _ _ _ contg, of y] obtain x where "x ∈ {a..b}" "y = g x" by auto
    with range and inv show "y ∈ h -` {a..b}" by auto
  qed

  have prob': "subprob_space (distr (density lborel f) lborel h)"
    by (rule subprob_space.subprob_space_distr[OF prob]) (simp_all add: Mh)
  have B: "emeasure (distr (density lborel f) lborel h) {a..b} =
            ∫⇧+x. f x * indicator (h -` {a..b}) x ∂lborel"
    by (subst emeasure_distr) (simp_all add: emeasure_density Mf Mh measurable_sets_borel[OF Mh])
  also note A
  also have "emeasure (distr (density lborel f) lborel h) {a..b} ≤ 1"
    by (rule subprob_space.subprob_emeasure_le_1) (rule prob')
  hence "emeasure (distr (density lborel f) lborel h) {a..b} ≠ ∞" by (auto simp: top_unique)
  with assms have "(∫⇧+x. f x * indicator {g a..g b} x ∂lborel) =
                      (∫⇧+x. f (g x) * g' x * indicator {a..b} x ∂lborel)"
    by (intro nn_integral_substitution_aux)
       (auto simp: derivg_nonneg A B emeasure_density mult.commute ‹a < b›)
  also have "... = emeasure (density lborel (λx. f (g x) * g' x)) {a..b}"
    by (simp add: emeasure_density)
  finally show ?thesis .
next
  (* Degenerate interval: a = b, so both sides concern a single point. *)
  assume "¬a < b"
  with ‹a ≤ b› have [simp]: "b = a" by (simp add: not_less del: ‹a ≤ b›)
  from inv and range have "h -` {a} = {g a}" by auto
  thus ?thesis by (simp_all add: emeasure_distr emeasure_density measurable_sets_borel[OF Mh])
qed

(* Two sub-probability spaces M1 and M2, combined with the pair_sigma_finite
   locale; facts of the components are accessible via the prefixes M1 and M2. *)
locale pair_subprob_space =
  pair_sigma_finite M1 M2 + M1: subprob_space M1 + M2: subprob_space M2 for M1 M2

(* The product measure of two sub-probability spaces is a sub-probability
   space; its facts are available under the optional prefix P. *)
sublocale pair_subprob_space ⊆ P?: subprob_space "M1 ⨂⇩M M2"
proof
  (* Total mass of the product is the product of the total masses, each at most 1. *)
  from mult_le_one[OF M1.emeasure_space_le_1 _ M2.emeasure_space_le_1]
  show "emeasure (M1 ⨂⇩M M2) (space (M1 ⨂⇩M M2)) ≤ 1"
    by (simp add: M2.emeasure_pair_measure_Times space_pair_measure)
  from M1.subprob_not_empty and M2.subprob_not_empty show "space (M1 ⨂⇩M M2) ≠ {}"
    by (simp add: space_pair_measure)
qed

(* The null measure is a sub-probability space exactly when the space is nonempty. *)
lemma subprob_space_null_measure_iff:
    "subprob_space (null_measure M) ⟷ space M ≠ {}"
  by (auto intro!: subprob_spaceI dest: subprob_space.subprob_not_empty)

(* Restricting a sub-probability space to A yields a sub-probability space,
   provided "A ∩ space M" is measurable and nonempty. *)
lemma subprob_space_restrict_space:
  assumes M: "subprob_space M"
  and A: "A ∩ space M ∈ sets M" "A ∩ space M ≠ {}"
  shows "subprob_space (restrict_space M A)"
proof(rule subprob_spaceI)
  have "emeasure (restrict_space M A) (space (restrict_space M A)) = emeasure M (A ∩ space M)"
    using A by(simp add: emeasure_restrict_space space_restrict_space)
  also have "… ≤ 1" by(rule subprob_space.subprob_emeasure_le_1)(rule M)
  finally show "emeasure (restrict_space M A) (space (restrict_space M A)) ≤ 1" .
next
  show "space (restrict_space M A) ≠ {}"
    using A by(simp add: space_restrict_space)
qed

(* The measurable space of sub-probability measures on K: the supremum, over
   all measurable sets A of K, of the preimage algebras of the evaluation maps
   "λM. emeasure M A" on the set of sub-probability measures with the sets of K. *)
definition subprob_algebra :: "'a measure ⇒ 'a measure measure" where
  "subprob_algebra K =
    (SUP A ∈ sets K. vimage_algebra {M. subprob_space M ∧ sets M = sets K} (λM. emeasure M A) borel)"

(* The points of "subprob_algebra A" are the sub-probability measures that
   have the same measurable sets as A. *)
lemma space_subprob_algebra: "space (subprob_algebra A) = {M. subprob_space M ∧ sets M = sets A}"
  by (auto simp add: subprob_algebra_def space_Sup_eq_UN)

(* subprob_algebra depends only on the measurable sets of its argument. *)
lemma subprob_algebra_cong: "sets M = sets N ⟹ subprob_algebra M = subprob_algebra N"
  by (simp add: subprob_algebra_def)

(* Evaluation at a fixed measurable set is a measurable function on subprob_algebra. *)
lemma measurable_emeasure_subprob_algebra[measurable]:
  "a ∈ sets A ⟹ (λM. emeasure M a) ∈ borel_measurable (subprob_algebra A)"
  by (auto intro!: measurable_Sup1 measurable_vimage_algebra1 simp: subprob_algebra_def)

(* Real-valued variant of measurable_emeasure_subprob_algebra. *)
lemma measurable_measure_subprob_algebra[measurable]:
  "a ∈ sets A ⟹ (λM. measure M a) ∈ borel_measurable (subprob_algebra A)"
  unfolding measure_def by measurable

(* Destruction rules for a measurable kernel N into "subprob_algebra S": each
   value N x shares its space, its sets, and hence its measurable functions with S. *)
lemma subprob_measurableD:
  assumes N: "N ∈ measurable M (subprob_algebra S)" and x: "x ∈ space M"
  shows "space (N x) = space S"
    and "sets (N x) = sets S"
    and "measurable (N x) K = measurable S K"
    and "measurable K (N x) = measurable K S"
  using measurable_space[OF N x]
  by (auto simp: space_subprob_algebra intro!: measurable_cong_sets dest: sets_eq_imp_space_eq)

ML ‹
(* Preprocessor for the measurability prover.  Given a theorem, it inspects the
   conclusion; when the head of the relevant subterm (after stripping
   abstractions) is a free variable, it registers the instance of
   subprob_measurableD(2) obtained from the theorem as a local congruence rule.
   In every other case, and whenever a term or theorem operation fails, the
   context is returned unchanged.  No additional theorems are ever produced. *)
fun subprob_cong thm ctxt = (
  let
    val thm' = Thm.transfer' ctxt thm
    val free = thm' |> Thm.concl_of |> HOLogic.dest_Trueprop |> dest_comb |> fst |>
      dest_comb |> snd |> strip_abs_body |> head_of |> is_Free
  in
    if free then ([], Measurable.add_local_cong (thm' RS @{thm subprob_measurableD(2)}) ctxt)
            else ([], ctxt)
  end
  handle THM _ => ([], ctxt) | TERM _ => ([], ctxt))
›

(* Registers subprob_cong as a preprocessor of the measurability prover. *)
setup ‹Context.theory_map (Measurable.add_preprocessor "subprob_cong" subprob_cong)›

(* Basic facts about a fixed measurable kernel K from M into subprob_algebra N. *)
context
  fixes K M N assumes K: "K ∈ measurable M (subprob_algebra N)"
begin

(* Every value of the kernel is a sub-probability space. *)
lemma subprob_space_kernel: "a ∈ space M ⟹ subprob_space (K a)"
  using measurable_space[OF K] by (simp add: space_subprob_algebra)

(* Every value of the kernel has the measurable sets of N. *)
lemma sets_kernel: "a ∈ space M ⟹ sets (K a) = sets N"
  using measurable_space[OF K] by (simp add: space_subprob_algebra)

(* For a fixed measurable set A, the mass "K a" assigns to A is measurable in a. *)
lemma measurable_emeasure_kernel[measurable]:
    "A ∈ sets N ⟹ (λa. emeasure (K a) A) ∈ borel_measurable M"
  using measurable_compose[OF K measurable_emeasure_subprob_algebra] .

end

(* Introduction rule for measurability into subprob_algebra: it suffices that
   every K a is a sub-probability space with the sets of N and that
   "λa. emeasure (K a) A" is measurable for each measurable A. *)
lemma measurable_subprob_algebra:
  "(⋀a. a ∈ space M ⟹ subprob_space (K a)) ⟹
  (⋀a. a ∈ space M ⟹ sets (K a) = sets N) ⟹
  (⋀A. A ∈ sets N ⟹ (λa. emeasure (K a) A) ∈ borel_measurable M) ⟹
  K ∈ measurable M (subprob_algebra N)"
  by (auto intro!: measurable_Sup2 measurable_vimage_algebra2 simp: subprob_algebra_def)

(* Characterisation of measurable kernels from M into subprob_algebra M. *)
lemma measurable_submarkov:
  "K ∈ measurable M (subprob_algebra M) ⟷
    (∀x∈space M. subprob_space (K x) ∧ sets (K x) = sets M) ∧
    (∀A∈sets M. (λx. emeasure (K x) A) ∈ measurable M borel)"
proof
  assume "(∀x∈space M. subprob_space (K x) ∧ sets (K x) = sets M) ∧
    (∀A∈sets M. (λx. emeasure (K x) A) ∈ borel_measurable M)"
  then show "K ∈ measurable M (subprob_algebra M)"
    by (intro measurable_subprob_algebra) auto
next
  assume "K ∈ measurable M (subprob_algebra M)"
  then show "(∀x∈space M. subprob_space (K x) ∧ sets (K x) = sets M) ∧
    (∀A∈sets M. (λx. emeasure (K x) A) ∈ borel_measurable M)"
    by (auto dest: subprob_space_kernel sets_kernel)
qed

(* Variant of measurable_subprob_algebra for a sigma-algebra generated by an
   intersection-stable family G over Ω: measurability of the evaluation maps
   only needs to be checked on G and on Ω itself. *)
lemma measurable_subprob_algebra_generated:
  assumes eq: "sets N = sigma_sets Ω G" and "Int_stable G" "G ⊆ Pow Ω"
  assumes subsp: "⋀a. a ∈ space M ⟹ subprob_space (K a)"
  assumes sets: "⋀a. a ∈ space M ⟹ sets (K a) = sets N"
  assumes "⋀A. A ∈ G ⟹ (λa. emeasure (K a) A) ∈ borel_measurable M"
  assumes Ω: "(λa. emeasure (K a) Ω) ∈ borel_measurable M"
  shows "K ∈ measurable M (subprob_algebra N)"
proof (rule measurable_subprob_algebra)
  fix a assume "a ∈ space M" then show "subprob_space (K a)" "sets (K a) = sets N" by fact+
next
  interpret G: sigma_algebra Ω "sigma_sets Ω G"
    using ‹G ⊆ Pow Ω› by (rule sigma_algebra_sigma_sets)
  (* Induction over the generated sigma-algebra using disjoint unions. *)
  fix A assume "A ∈ sets N" with assms(2,3) show "(λa. emeasure (K a) A) ∈ borel_measurable M"
    unfolding ‹sets N = sigma_sets Ω G›
  proof (induction rule: sigma_sets_induct_disjoint)
    case (basic A) then show ?case by fact
  next
    case empty then show ?case by simp
  next
    case (compl A)
    (* The measure of a complement is a difference; this needs finiteness. *)
    have "(λa. emeasure (K a) (Ω - A)) ∈ borel_measurable M ⟷
      (λa. emeasure (K a) Ω - emeasure (K a) A) ∈ borel_measurable M"
      using G.top G.sets_into_space sets eq compl subprob_space.emeasure_subprob_space_less_top[OF subsp]
      by (intro measurable_cong emeasure_Diff) auto
    with compl Ω show ?case
      by simp
  next
    case (union F)
    (* The measure of a disjoint countable union is a countable sum. *)
    moreover have "(λa. emeasure (K a) (⋃i. F i)) ∈ borel_measurable M ⟷
        (λa. ∑i. emeasure (K a) (F i)) ∈ borel_measurable M"
      using sets union eq
      by (intro measurable_cong suminf_emeasure[symmetric]) auto
    ultimately show ?case
      by auto
  qed
qed

(* There are no sub-probability measures on N exactly when the space of N is empty. *)
lemma space_subprob_algebra_empty_iff:
  "space (subprob_algebra N) = {} ⟷ space N = {}"
proof
  (* If N has a point, the zero-density measure is a witness. *)
  have "⋀x. x ∈ space N ⟹ density N (λ_. 0) ∈ space (subprob_algebra N)"
    by (auto simp: space_subprob_algebra emeasure_density intro!: subprob_spaceI)
  then show "space (subprob_algebra N) = {} ⟹ space N = {}"
    by auto
next
  assume "space N = {}"
  hence "sets N = {{}}" by (simp add: space_empty_iff)
  moreover have "⋀M. subprob_space M ⟹ sets M ≠ {{}}"
    by (simp add: subprob_space.subprob_not_empty space_empty_iff[symmetric])
  ultimately show "space (subprob_algebra N) = {}" by (auto simp: space_subprob_algebra)
qed

(* For a Borel-measurable f on N, the map sending a measure to the nonnegative
   integral of f is measurable on subprob_algebra N.  The proof is by induction
   on f, with cases cong, set, mult, add and seq. *)
lemma nn_integral_measurable_subprob_algebra[measurable]:
  assumes f: "f ∈ borel_measurable N"
  shows "(λM. integral⇧N M f) ∈ borel_measurable (subprob_algebra N)" (is "_ ∈ ?B")
  using f
proof induct
  case (cong f g)
  moreover have "(λM'. ∫⇧+M''. f M'' ∂M') ∈ ?B ⟷ (λM'. ∫⇧+M''. g M'' ∂M') ∈ ?B"
    by (intro measurable_cong nn_integral_cong cong)
       (auto simp: space_subprob_algebra dest!: sets_eq_imp_space_eq)
  ultimately show ?case by simp
next
  case (set B)
  then have "(λM'. ∫⇧+M''. indicator B M'' ∂M') ∈ ?B ⟷ (λM'. emeasure M' B) ∈ ?B"
    by (intro measurable_cong nn_integral_indicator) (simp add: space_subprob_algebra)
  with set show ?case
    by (simp add: measurable_emeasure_subprob_algebra)
next
  case (mult f c)
  then have "(λM'. ∫⇧+M''. c * f M'' ∂M') ∈ ?B ⟷ (λM'. c * ∫⇧+M''. f M'' ∂M') ∈ ?B"
    by (intro measurable_cong nn_integral_cmult) (auto simp add: space_subprob_algebra)
  with mult show ?case
    by simp
next
  case (add f g)
  then have "(λM'. ∫⇧+M''. f M'' + g M'' ∂M') ∈ ?B ⟷ (λM'. (∫⇧+M''. f M'' ∂M') + (∫⇧+M''. g M'' ∂M')) ∈ ?B"
    by (intro measurable_cong nn_integral_add) (auto simp add: space_subprob_algebra)
  with add show ?case
    by (simp add: ac_simps)
next
  case (seq F)
  then have "(λM'. ∫⇧+M''. (SUP i. F i) M'' ∂M') ∈ ?B ⟷ (λM'. SUP i. (∫⇧+M''. F i M'' ∂M')) ∈ ?B"
    unfolding SUP_apply
    by (intro measurable_cong nn_integral_monotone_convergence_SUP) (auto simp add: space_subprob_algebra)
  with seq show ?case
    by (simp add: ac_simps)
qed

(* Pushing forward along a measurable f is itself a measurable map between the
   corresponding sub-probability algebras. *)
lemma measurable_distr:
  assumes [measurable]: "f ∈ measurable M N"
  shows "(λM'. distr M' N f) ∈ measurable (subprob_algebra M) (subprob_algebra N)"
proof (cases "space N = {}")
  case False
  show ?thesis
  proof (rule measurable_subprob_algebra)
    fix A assume A: "A ∈ sets N"
    (* The push-forward evaluated at A is the original measure of the preimage. *)
    then have "(λM'. emeasure (distr M' N f) A) ∈ borel_measurable (subprob_algebra M) ⟷
      (λM'. emeasure M' (f -` A ∩ space M)) ∈ borel_measurable (subprob_algebra M)"
      by (intro measurable_cong)
         (auto simp: emeasure_distr space_subprob_algebra
               intro!: arg_cong2[where f=emeasure] sets_eq_imp_space_eq arg_cong2[where f="(∩)"])
    also have "…"
      using A by (intro measurable_emeasure_subprob_algebra) simp
    finally show "(λM'. emeasure (distr M' N f) A) ∈ borel_measurable (subprob_algebra M)" .
  qed (auto intro!: subprob_space.subprob_space_distr simp: space_subprob_algebra False cong: measurable_cong_sets)
qed (use assms in ‹auto simp: measurable_empty_iff space_subprob_algebra_empty_iff›)

(* The total mass is a measurable function on subprob_algebra N. *)
lemma emeasure_space_subprob_algebra[measurable]:
  "(λa. emeasure a (space a)) ∈ borel_measurable (subprob_algebra N)"
proof-
  have "(λa. emeasure a (space N)) ∈ borel_measurable (subprob_algebra N)" (is "?f ∈ ?M")
    by (rule measurable_emeasure_subprob_algebra) simp
  (* On subprob_algebra N every measure has the same space as N. *)
  also have "?f ∈ ?M ⟷ (λa. emeasure a (space a)) ∈ ?M"
    by (rule measurable_cong) (auto simp: space_subprob_algebra dest: sets_eq_imp_space_eq)
  finally show ?thesis .
qed

(* Integrability of a fixed Borel-measurable f is a measurable predicate on
   subprob_algebra N; it is rewritten as finiteness of the integral of the norm. *)
lemma integrable_measurable_subprob_algebra[measurable]:
  fixes f :: "'a ⇒ 'b::{banach, second_countable_topology}"
  assumes [measurable]: "f ∈ borel_measurable N"
  shows "Measurable.pred (subprob_algebra N) (λM. integrable M f)"
proof (rule measurable_cong[THEN iffD2])
  show "M ∈ space (subprob_algebra N) ⟹ integrable M f ⟷ (∫⇧+x. norm (f x) ∂M) < ∞" for M
    by (auto simp: space_subprob_algebra integrable_iff_bounded)
qed measurable

(* The Bochner integral of a fixed Borel-measurable f is a measurable function
   on subprob_algebra N.  It is obtained as the pointwise limit of the integrals
   of an approximating sequence of simple functions. *)
lemma integral_measurable_subprob_algebra[measurable]:
  fixes f :: "'a ⇒ 'b::{banach, second_countable_topology}"
  assumes f [measurable]: "f ∈ borel_measurable N"
  shows "(λM. integral⇧L M f) ∈ subprob_algebra N →⇩M borel"
proof -
  (* Simple functions F i converging pointwise to f and dominated by 2 * norm f. *)
  from borel_measurable_implies_sequence_metric[OF f, of 0]
  obtain F where F: "⋀i. simple_function N (F i)"
    "⋀x. x ∈ space N ⟹ (λi. F i x) ⇢ f x"
    "⋀i x. x ∈ space N ⟹ norm (F i x) ≤ 2 * norm (f x)"
    unfolding norm_conv_dist by blast

  have [measurable]: "F i ∈ N →⇩M count_space UNIV" for i
    using F(1) by (rule measurable_simple_function)

  (* Integral of the i-th approximation, set to 0 where f is not integrable. *)
  define F' where [abs_def]:
    "F' M i = (if integrable M f then integral⇧L M (F i) else 0)" for M i

  have "(λM. F' M i) ∈ subprob_algebra N →⇩M borel" for i
  proof (rule measurable_cong[THEN iffD2])
    fix M assume "M ∈ space (subprob_algebra N)"
    then have [simp]: "sets M = sets N" "space M = space N" "subprob_space M"
      by (auto simp: space_subprob_algebra intro!: sets_eq_imp_space_eq)
    interpret subprob_space M by fact
    have "F' M i = (if integrable M f then Bochner_Integration.simple_bochner_integral M (F i) else 0)"
      using F(1)
      by (subst simple_bochner_integrable_eq_integral)
         (auto simp: simple_bochner_integrable.simps simple_function_def F'_def)
    (* The simple integral is a finite sum of measures of level sets. *)
    then show "F' M i = (if integrable M f then ∑y∈F i ` space N. measure M {x∈space N. F i x = y} *⇩R y else 0)"
      unfolding simple_bochner_integral_def by simp
  qed measurable
  moreover
  have "F' M ⇢ integral⇧L M f" if M: "M ∈ space (subprob_algebra N)" for M
  proof cases
    from M have [simp]: "sets M = sets N" "space M = space N"
      by (auto simp: space_subprob_algebra intro!: sets_eq_imp_space_eq)
    assume "integrable M f" then show ?thesis
      unfolding F'_def using F(1)[THEN borel_measurable_simple_function] F
      by (auto intro!: integral_dominated_convergence[where w="λx. 2 * norm (f x)"]
               cong: measurable_cong_sets)
  qed (auto simp: F'_def not_integrable_integral_eq)
  ultimately show ?thesis
    by (rule borel_measurable_LIMSEQ_metric)
qed

(* TODO: Rename. This name is too general -- Manuel *)
(* The pointwise product of two measurable sub-probability kernels is a
   measurable kernel into the sub-probability algebra of the product space. *)
lemma measurable_pair_measure:
  assumes f: "f ∈ measurable M (subprob_algebra N)"
  assumes g: "g ∈ measurable M (subprob_algebra L)"
  shows "(λx. f x ⨂⇩M g x) ∈ measurable M (subprob_algebra (N ⨂⇩M L))"
proof (rule measurable_subprob_algebra)
  (* Facts about the product at a fixed point x; exported below as Times,
     Compl and sets_eq. *)
  { fix x assume "x ∈ space M"
    with measurable_space[OF f] measurable_space[OF g]
    have fx: "f x ∈ space (subprob_algebra N)" and gx: "g x ∈ space (subprob_algebra L)"
      by auto
    interpret F: subprob_space "f x"
      using fx by (simp add: space_subprob_algebra)
    interpret G: subprob_space "g x"
      using gx by (simp add: space_subprob_algebra)

    interpret pair_subprob_space "f x" "g x" ..
    show "subprob_space (f x ⨂⇩M g x)" by unfold_locales
    show sets_eq: "sets (f x ⨂⇩M g x) = sets (N ⨂⇩M L)"
      using fx gx by (simp add: space_subprob_algebra)

    have 1: "⋀A B. A ∈ sets N ⟹ B ∈ sets L ⟹ emeasure (f x ⨂⇩M g x) (A × B) = emeasure (f x) A * emeasure (g x) B"
      using fx gx by (intro G.emeasure_pair_measure_Times) (auto simp: space_subprob_algebra)
    have "emeasure (f x ⨂⇩M g x) (space (f x ⨂⇩M g x)) =
              emeasure (f x) (space (f x)) * emeasure (g x) (space (g x))"
      by (subst G.emeasure_pair_measure_Times[symmetric]) (simp_all add: space_pair_measure)
    hence 2: "⋀A. A ∈ sets (N ⨂⇩M L) ⟹ emeasure (f x ⨂⇩M g x) (space N × space L - A) =
                                             ... - emeasure (f x ⨂⇩M g x) A"
      using emeasure_compl[simplified, OF _ P.emeasure_finite]
      unfolding sets_eq
      unfolding sets_eq_imp_space_eq[OF sets_eq]
      by (simp add: space_pair_measure G.emeasure_pair_measure_Times)
    note 1 2 sets_eq }
  note Times = this(1) and Compl = this(2) and sets_eq = this(3)

  (* Measurability of the evaluation maps, by induction over the generator of
     the product sigma-algebra. *)
  fix A assume A: "A ∈ sets (N ⨂⇩M L)"
  show "(λa. emeasure (f a ⨂⇩M g a) A) ∈ borel_measurable M"
    using Int_stable_pair_measure_generator pair_measure_closed A
    unfolding sets_pair_measure
  proof (induct A rule: sigma_sets_induct_disjoint)
    case (basic A) then show ?case
      by (auto intro!: borel_measurable_times_ennreal simp: Times cong: measurable_cong)
         (auto intro!: measurable_emeasure_kernel f g)
  next
    case (compl A)
    then have A: "A ∈ sets (N ⨂⇩M L)"
      by (auto simp: sets_pair_measure)
    have "(λx. emeasure (f x) (space (f x)) * emeasure (g x) (space (g x)) -
                   emeasure (f x ⨂⇩M g x) A) ∈ borel_measurable M" (is "?f ∈ ?M")
      using compl(2) f g by measurable
    thus ?case by (simp add: Compl A cong: measurable_cong)
  next
    case (union A)
    then have "range A ⊆ sets (N ⨂⇩M L)" "disjoint_family A"
      by (auto simp: sets_pair_measure)
    then have "(λa. emeasure (f a ⨂⇩M g a) (⋃i. A i)) ∈ borel_measurable M ⟷
      (λa. ∑i. emeasure (f a ⨂⇩M g a) (A i)) ∈ borel_measurable M"
      by (intro measurable_cong suminf_emeasure[symmetric])
         (auto simp: sets_eq)
    also have "…"
      using union by auto
    finally show ?case .
  qed simp
qed

(* Restricting every value of a measurable kernel to a fixed nonempty
   measurable set X gives a measurable kernel into the sub-probability algebra
   of the restricted space. *)
lemma restrict_space_measurable:
  assumes X: "X ≠ {}" "X ∈ sets K"
  assumes N: "N ∈ measurable M (subprob_algebra K)"
  shows "(λx. restrict_space (N x) X) ∈ measurable M (subprob_algebra (restrict_space K X))"
proof (rule measurable_subprob_algebra)
  fix a assume a: "a ∈ space M"
  from N[THEN measurable_space, OF this]
  have "subprob_space (N a)" and [simp]: "sets (N a) = sets K" "space (N a) = space K"
    by (auto simp add: space_subprob_algebra dest: sets_eq_imp_space_eq)
  then interpret subprob_space "N a"
    by simp
  show "subprob_space (restrict_space (N a) X)"
  proof
    show "space (restrict_space (N a) X) ≠ {}"
      using X by (auto simp add: space_restrict_space)
    show "emeasure (restrict_space (N a) X) (space (restrict_space (N a) X)) ≤ 1"
      using X by (simp add: emeasure_restrict_space space_restrict_space subprob_emeasure_le_1)
  qed
  show "sets (restrict_space (N a) X) = sets (restrict_space K X)"
    by (intro sets_restrict_space_cong) fact
next
  fix A assume A: "A ∈ sets (restrict_space K X)"
  show "(λa. emeasure (restrict_space (N a) X) A) ∈ borel_measurable M"
  proof (subst measurable_cong)
    fix a assume "a ∈ space M"
    from N[THEN measurable_space, OF this]
    have [simp]: "sets (N a) = sets K" "space (N a) = space K"
      by (auto simp add: space_subprob_algebra dest: sets_eq_imp_space_eq)
    (* The restricted measure of A is the original measure of A ∩ X. *)
    show "emeasure (restrict_space (N a) X) A = emeasure (N a) (A ∩ X)"
      using X A by (subst emeasure_restrict_space) (auto simp add: sets_restrict_space ac_simps)
  next
    show "(λw. emeasure (N w) (A ∩ X)) ∈ borel_measurable M"
      using A X
      by (intro measurable_compose[OF N measurable_emeasure_subprob_algebra])
         (auto simp: sets_restrict_space)
  qed
qed

subsection ‹Properties of ``return''›

(* The monadic unit: the measure on the space and sets of R that assigns to a
   set A the value of its indicator function at x. *)
definition return :: "'a measure ⇒ 'a ⇒ 'a measure" where
  "return R x = measure_of (space R) (sets R) (λA. indicator A x)"

(* "return M x" has the same space as M. *)
lemma space_return[simp]: "space (return M x) = space M"
  by (simp add: return_def)

(* "return M x" has the same measurable sets as M. *)
lemma sets_return[simp]: "sets (return M x) = sets M"
  by (simp add: return_def)

(* Measurability out of "return N x" coincides with measurability out of N. *)
lemma measurable_return1[simp]: "measurable (return N x) L = measurable N L"
  by (simp cong: measurable_cong_sets)

(* Measurability into "return N x" coincides with measurability into N. *)
lemma measurable_return2[simp]: "measurable L (return N x) = measurable L N"
  by (simp cong: measurable_cong_sets)

(* return depends only on the measurable sets of its first argument
   (stated as an equality of functions). *)
lemma return_sets_cong: "sets M = sets N ⟹ return M = return N"
  by (auto dest: sets_eq_imp_space_eq simp: fun_eq_iff return_def)

(* Pointwise variant of return_sets_cong. *)
lemma return_cong: "sets A = sets B ⟹ return A x = return B x"
  by (auto simp add: return_def dest: sets_eq_imp_space_eq)

(* On measurable sets, "return M x" is the indicator at x. *)
lemma emeasure_return[simp]:
  assumes "A ∈ sets M"
  shows "emeasure (return M x) A = indicator A x"
proof (rule emeasure_measure_of[OF return_def])
  show "sets M ⊆ Pow (space M)" by (rule sets.space_closed)
  show "positive (sets (return M x)) (λA. indicator A x)" by (simp add: positive_def)
  from assms show "A ∈ sets (return M x)" unfolding return_def by simp
  show "countably_additive (sets (return M x)) (λA. indicator A x)"
    by (auto intro!: countably_additiveI suminf_indicator)
qed

(* "return M x" is a probability space when x lies in the space of M. *)
lemma prob_space_return: "x ∈ space M ⟹ prob_space (return M x)"
  by rule simp

(* "return M x" is a sub-probability space when x lies in the space of M. *)
lemma subprob_space_return: "x ∈ space M ⟹ subprob_space (return M x)"
  by (intro prob_space_return prob_space_imp_subprob_space)

(* "return M x" is a sub-probability space whenever the space of M is
   nonempty; x itself is not required to lie in that space. *)
lemma subprob_space_return_ne:
  assumes "space M ≠ {}" shows "subprob_space (return M x)"
  by (metis assms emeasure_return indicator_simps(2) sets.top space_return subprob_spaceI subprob_space_return zero_le)

(* Real-valued variant of emeasure_return. *)
lemma measure_return: assumes X: "X ∈ sets M" shows "measure (return M x) X = indicator X x"
  unfolding measure_def emeasure_return[OF X, of x] by (simp split: split_indicator)

(* A measurable predicate holds almost everywhere w.r.t. "return M x" exactly
   when it holds at x. *)
lemma AE_return:
  assumes [simp]: "x ∈ space M" and [measurable]: "Measurable.pred M P"
  shows "(AE y in return M x. P y) ⟷ P x"
proof -
  (* Express the statement through the set where P fails being a null set. *)
  have "(AE y in return M x. y ∉ {x∈space M. ¬ P x}) ⟷ P x"
    by (subst AE_iff_null_sets[symmetric]) (simp_all add: null_sets_def split: split_indicator)
  also have "(AE y in return M x. y ∉ {x∈space M. ¬ P x}) ⟷ (AE y in return M x. P y)"
    by (rule AE_cong) auto
  finally show ?thesis .
qed

(* Integrating a measurable g against "return M x" evaluates g at x. *)
lemma nn_integral_return:
  assumes "x ∈ space M" "g ∈ borel_measurable M"
  shows "(∫⇧+ a. g a ∂return M x) = g x"
proof-
  interpret prob_space "return M x" by (rule prob_space_return[OF ‹x ∈ space M›])
  (* Almost everywhere the integrand equals the constant g x. *)
  have "(∫⇧+ a. g a ∂return M x) = (∫⇧+ a. g x ∂return M x)" using assms
    by (intro nn_integral_cong_AE) (auto simp: AE_return)
  also have "... = g x"
    using nn_integral_const[of "return M x"] emeasure_space_1 by simp
  finally show ?thesis .
qed

(* Bochner-integral variant of nn_integral_return. *)
lemma integral_return:
  fixes g :: "_ ⇒ 'a :: {banach, second_countable_topology}"
  assumes "x ∈ space M" "g ∈ borel_measurable M"
  shows "(∫a. g a ∂return M x) = g x"
proof-
  interpret prob_space "return M x" by (rule prob_space_return[OF ‹x ∈ space M›])
  have "(∫a. g a ∂return M x) = (∫a. g x ∂return M x)" using assms
    by (intro integral_cong_AE) (auto simp: AE_return)
  then show ?thesis
    using prob_space by simp
qed

(* return is a measurable map from N into its sub-probability algebra. *)
lemma return_measurable[measurable]: "return N ∈ measurable N (subprob_algebra N)"
  by (rule measurable_subprob_algebra) (auto simp: subprob_space_return)

(* Pushing "return M x" forward along a measurable f gives "return N (f x)". *)
lemma distr_return:
  assumes "f ∈ measurable M N" and "x ∈ space M"
  shows "distr (return M x) N f = return N (f x)"
  using assms by (intro measure_eqI) (simp_all add: indicator_def emeasure_distr)

(* return commutes with restriction to a measurable set. *)
lemma return_restrict_space:
  "Ω ∈ sets M ⟹ return (restrict_space M Ω) x = restrict_space (return M x) Ω"
  by (auto intro!: measure_eqI simp: sets_restrict_space emeasure_restrict_space)

(* Parametrised push-forward: if f is jointly measurable and g is a measurable
   kernel, then "λx. distr (g x) N (f x)" is a measurable kernel.  The proof
   rewrites it as the push-forward of "return L x ⨂⇩M g x" along "case_prod f". *)
lemma measurable_distr2:
  assumes f[measurable]: "case_prod f ∈ measurable (L ⨂⇩M M) N"
  assumes g[measurable]: "g ∈ measurable L (subprob_algebra M)"
  shows "(λx. distr (g x) N (f x)) ∈ measurable L (subprob_algebra N)"
proof -
  have "(λx. distr (g x) N (f x)) ∈ measurable L (subprob_algebra N)
    ⟷ (λx. distr (return L x ⨂⇩M g x) N (case_prod f)) ∈ measurable L (subprob_algebra N)"
  proof (rule measurable_cong)
    fix x assume x: "x ∈ space L"
    have gx: "g x ∈ space (subprob_algebra M)"
      using measurable_space[OF g x] .
    then have [simp]: "sets (g x) = sets M"
      by (simp add: space_subprob_algebra)
    then have [simp]: "space (g x) = space M"
      by (rule sets_eq_imp_space_eq)
    let ?R = "return L x"
    from measurable_compose_Pair1[OF x f] have f_M': "f x ∈ measurable M N"
      by simp
    interpret subprob_space "g x"
      using gx by (simp add: space_subprob_algebra)
    have space_pair_M'[simp]: "⋀X. space (X ⨂⇩M g x) = space (X ⨂⇩M M)"
      by (simp add: space_pair_measure)
    show "distr (g x) N (f x) = distr (?R ⨂⇩M g x) N (case_prod f)" (is "?l = ?r")
    proof (rule measure_eqI)
      show "sets ?l = sets ?r"
        by simp
    next
      fix A assume "A ∈ sets ?l"
      then have A[measurable]: "A ∈ sets N"
        by simp
      then have "emeasure ?r A = emeasure (?R ⨂⇩M g x) ((λ(x, y). f x y) -` A ∩ space (?R ⨂⇩M g x))"
        by (auto simp add: emeasure_distr f_M' cong: measurable_cong_sets)
      (* Write the product measure as an iterated integral over the first component. *)
      also have "… = (∫⇧+M''. emeasure (g x) (f M'' -` A ∩ space M) ∂?R)"
        apply (subst emeasure_pair_measure_alt)
        apply (force simp add: f_M' cong: measurable_cong_sets intro!: measurable_sets[OF _ A])
        apply (intro nn_integral_cong arg_cong[where f="emeasure (g x)"])
        apply (auto simp: space_subprob_algebra space_pair_measure)
        done
      (* Integrating against return evaluates the integrand at x. *)
      also have "… = emeasure (g x) (f x -` A ∩ space M)"
        by (subst nn_integral_return)
           (auto simp: x intro!: measurable_emeasure)
      also have "… = emeasure ?l A"
        by (simp add: emeasure_distr f_M' cong: measurable_cong_sets)
      finally show "emeasure ?l A = emeasure ?r A" ..
    qed
  qed
  also have "…"
  proof (intro measurable_compose[OF measurable_pair_measure measurable_distr])
    show "return L ∈ L →⇩M subprob_algebra L"
      by (rule return_measurable)
  qed measurable
  finally show ?thesis .
qed

(* Parametrised version of nn_integral_measurable_subprob_algebra: both the
   integrand "f x" and the measure "L x" may depend measurably on x. *)
lemma nn_integral_measurable_subprob_algebra2:
  assumes f[measurable]: "(λ(x, y). f x y) ∈ borel_measurable (M ⨂⇩M N)"
  assumes N[measurable]: "L ∈ measurable M (subprob_algebra N)"
  shows "(λx. integral⇧N (L x) (f x)) ∈ borel_measurable M"
proof -
  note nn_integral_measurable_subprob_algebra[measurable]
  note measurable_distr2[measurable]
  (* Integrate the uncurried f against the push-forward of L x along "λy. (x, y)". *)
  have "(λx. integral⇧N (distr (L x) (M ⨂⇩M N) (λy. (x, y))) (λ(x, y). f x y)) ∈ borel_measurable M"
    by measurable
  then show "(λx. integral⇧N (L x) (f x)) ∈ borel_measurable M"
    by (rule measurable_cong[THEN iffD1, rotated])
       (simp add: nn_integral_distr)
qed

(* Parametrised version of measurable_emeasure_subprob_algebra: the set "A x"
   may depend on x, provided the dependent sum of the sections is measurable in
   the product. *)
lemma emeasure_measurable_subprob_algebra2:
  assumes A[measurable]: "(SIGMA x:space M. A x) ∈ sets (M ⨂⇩M N)"
  assumes L[measurable]: "L ∈ measurable M (subprob_algebra N)"
  shows "(λx. emeasure (L x) (A x)) ∈ borel_measurable M"
proof -
  (* Every section A x is measurable in N. *)
  { fix x assume "x ∈ space M"
    then have "Pair x -` Sigma (space M) A = A x"
      by auto
    with sets_Pair1[OF A, of x] have "A x ∈ sets N"
      by auto }
  note ** = this

  have *: "⋀x. fst x ∈ space M ⟹ snd x ∈ A (fst x) ⟷ x ∈ (SIGMA x:space M. A x)"
    by (auto simp: fun_eq_iff)
  have MN: "Measurable.pred (M ⨂⇩M N) (λw. w ∈ Sigma (space M) A)"
    by auto
  then have "(λ(x, y). indicator (A x) y::ennreal) ∈ borel_measurable (M ⨂⇩M N)"
    apply measurable
    by (smt (verit, best) MN measurable_cong mem_Sigma_iff prod.collapse space_pair_measure)
  then have "(λx. integral⇧N (L x) (indicator (A x))) ∈ borel_measurable M"
    by (intro nn_integral_measurable_subprob_algebra2[where N=N] L)
  then show "(λx. emeasure (L x) (A x)) ∈ borel_measurable M"
    by (smt (verit) "**" L measurable_cong_simp nn_integral_indicator sets_kernel)
qed

(* Real-valued variant of emeasure_measurable_subprob_algebra2. *)
lemma measure_measurable_subprob_algebra2:
  assumes A[measurable]: "(SIGMA x:space M. A x) ∈ sets (M ⨂⇩M N)"
  assumes L[measurable]: "L ∈ measurable M (subprob_algebra N)"
  shows "(λx. measure (L x) (A x)) ∈ borel_measurable M"
  unfolding measure_def
  by (intro borel_measurable_enn2real emeasure_measurable_subprob_algebra2[OF assms])

(* Chooses, via the choice operator SOME, some N whose sub-probability algebra
   has the same measurable sets as M. *)
definition "select_sets M = (SOME N. sets M = sets (subprob_algebra N))"

(* If M has the sets of some sub-probability algebra, then select_sets M is
   such a witness. *)
lemma select_sets1:
  "sets M = sets (subprob_algebra N) ⟹ sets M = sets (subprob_algebra (select_sets M))"
  unfolding select_sets_def by (rule someI)

(* The choice made by select_sets is unique up to measurable sets. *)
lemma sets_select_sets[simp]:
  assumes sets: "sets M = sets (subprob_algebra N)"
  shows "sets (select_sets M) = sets N"
  unfolding select_sets_def
proof (rule someI2)
  show "sets M = sets (subprob_algebra N)"
    by fact
next
  fix L assume "sets M = sets (subprob_algebra L)"
  with sets have eq: "space (subprob_algebra N) = space (subprob_algebra L)"
    by (intro sets_eq_imp_space_eq) simp
  show "sets L = sets N"
  proof cases
    (* No sub-probability measures: both underlying spaces are empty. *)
    assume "space (subprob_algebra N) = {}"
    with space_subprob_algebra_empty_iff[of N] space_subprob_algebra_empty_iff[of L]
    show ?thesis
      by (simp add: eq space_empty_iff)
  next
    (* Otherwise a common element carries the sets of both N and L. *)
    assume "space (subprob_algebra N) ≠ {}"
    with eq show ?thesis
      by (smt (verit) equals0I mem_Collect_eq space_subprob_algebra)
  qed
qed

(* Space-level consequence of sets_select_sets. *)
lemma space_select_sets[simp]:
  "sets M = sets (subprob_algebra N) ⟹ space (select_sets M) = space N"
  by (intro sets_eq_imp_space_eq sets_select_sets)

subsection ‹Join›

(* The monadic join: a measure on measures M is flattened to the measure that
   assigns to B the integral, with respect to M, of "λM'. emeasure M' B".
   Space and sets are taken from select_sets M. *)
definition join :: "'a measure measure ⇒ 'a measure" where
  "join M = measure_of (space (select_sets M)) (sets (select_sets M)) (λB. ∫⇧+ M'. emeasure M' B ∂M)"

(* Space and measurable sets of "join M" are those of "select_sets M". *)
lemma
  shows space_join[simp]: "space (join M) = space (select_sets M)"
    and sets_join[simp]: "sets (join M) = sets (select_sets M)"
  by (simp_all add: join_def)

(* On a measure M over subprob_algebra N, "join M" evaluated at a measurable
   set A is the integral of the evaluation map at A with respect to M. *)
lemma emeasure_join:
  assumes M[simp, measurable_cong]: "sets M = sets (subprob_algebra N)" and A: "A ∈ sets N"
  shows "emeasure (join M) A = (∫⇧+ M'. emeasure M' A ∂M)"
proof (rule emeasure_measure_of[OF join_def])
  show "countably_additive (sets (join M)) (λB. ∫⇧+ M'. emeasure M' B ∂M)"
  proof (rule countably_additiveI)
    fix A :: "nat ⇒ 'a set" assume A: "range A ⊆ sets (join M)" "disjoint_family A"
    (* Exchange the countable sum with the integral. *)
    have "(∑i. ∫⇧+ M'. emeasure M' (A i) ∂M) = (∫⇧+M'. (∑i. emeasure M' (A i)) ∂M)"
      using A by (subst nn_integral_suminf) (auto simp: measurable_emeasure_subprob_algebra)
    also have "… = (∫⇧+M'. emeasure M' (⋃i. A i) ∂M)"
    proof (rule nn_integral_cong)
      fix M' assume "M' ∈ space M"
      then show "(∑i. emeasure M' (A i)) = emeasure M' (⋃i. A i)"
        using A sets_eq_imp_space_eq[OF M] by (simp add: suminf_emeasure space_subprob_algebra)
    qed
    finally show "(∑i. ∫⇧+M'. emeasure M' (A i) ∂M) = (∫⇧+M'. emeasure M' (⋃i. A i) ∂M)" .
  qed
qed (auto simp: A sets.space_closed positive_def)

(* join is a measurable map from subprob_algebra (subprob_algebra N) to
   subprob_algebra N.
   The proof splits on whether space N is empty.  In the non-empty case it applies
   the introduction rule measurable_subprob_algebra, which leaves three kinds of
   goals: measurability of M' -> emeasure (join M') A for each measurable A,
   subprob_space (join M) for each M in the source space, and a sets condition
   handled by the closing auto.  The empty case is closed by
   measurable_empty_iff and space_subprob_algebra_empty_iff. *)
lemma measurable_join:
  "join  measurable (subprob_algebra (subprob_algebra N)) (subprob_algebra N)"
proof (cases "space N  {}", rule measurable_subprob_algebra)
  fix A assume "A  sets N"
  let ?B = "borel_measurable (subprob_algebra (subprob_algebra N))"
  (* Replace emeasure (join M') A by the integral formula from emeasure_join. *)
  have "(λM'. emeasure (join M') A)  ?B  (λM'. (+ M''. emeasure M'' A M'))  ?B"
  proof (rule measurable_cong)
    fix M' assume "M'  space (subprob_algebra (subprob_algebra N))"
    then show "emeasure (join M') A = (+ M''. emeasure M'' A M')"
      by (intro emeasure_join) (auto simp: space_subprob_algebra Asets N)
  qed
  also have "(λM'. +M''. emeasure M'' A M')  ?B"
    using measurable_emeasure_subprob_algebra[OF Asets N]
    by (rule nn_integral_measurable_subprob_algebra)
  finally show "(λM'. emeasure (join M') A)  borel_measurable (subprob_algebra (subprob_algebra N))" .
next
  (* join M is a sub-probability space: bound the total mass by the integral of 1. *)
  assume [simp]: "space N  {}"
  fix M assume M: "M  space (subprob_algebra (subprob_algebra N))"
  then have "(+M'. emeasure M' (space N) M)  (+M'. 1 M)"
  proof (intro nn_integral_mono)
    show "x. M  space (subprob_algebra (subprob_algebra N)); x  space M
          emeasure x (space N)  1"
      by (smt (verit) mem_Collect_eq sets_eq_imp_space_eq space_subprob_algebra subprob_space.subprob_emeasure_le_1)
  qed
  with M show "subprob_space (join M)"
    by (intro subprob_spaceI)
       (auto simp: emeasure_join space_subprob_algebra M dest: subprob_space.emeasure_space_le_1)
next
  (* Degenerate case: space N is empty. *)
  assume "¬(space N  {})"
  thus ?thesis by (simp add: measurable_empty_iff space_subprob_algebra_empty_iff)
qed (auto simp: space_subprob_algebra)

(* Integration against join: for f Borel measurable on N and
   sets M = sets (subprob_algebra N), the nonnegative integral of f with respect to
   join M equals the iterated integral, over M, of M' -> integral of f w.r.t. M'.
   The proof is by the induction rule for Borel measurable functions selected by
   "using f ... proof induct", with cases cong, set, mult, add and seq. *)
lemma nn_integral_join:
  assumes f: "f  borel_measurable N"
    and M[measurable_cong]: "sets M = sets (subprob_algebra N)"
  shows "(+x. f x join M) = (+M'. +x. f x M' M)"
  using f
proof induct
  case (cong f g)
  (* Both sides are unchanged when f is replaced by g (as provided by the cong case). *)
  moreover have "integralN (join M) f = integralN (join M) g"
    by (intro nn_integral_cong cong) (simp add: M)
  moreover from M have "(+ M'. integralN M' f M) = (+ M'. integralN M' g M)"
    by (intro nn_integral_cong cong)
       (auto simp add: space_subprob_algebra dest!: sets_eq_imp_space_eq)
  ultimately show ?case
    by simp
next
  case (set A)
  (* Indicator functions: reduce to emeasure_join. *)
  with M have "(+ M'. integralN M' (indicator A) M) = (+ M'. emeasure M' A M)"
    by (intro nn_integral_cong nn_integral_indicator)
       (auto simp: space_subprob_algebra dest!: sets_eq_imp_space_eq)
  with set show ?case
    using M by (simp add: emeasure_join)
next
  case (mult f c)
  (* Scalar multiples: pull the constant c out of the inner and then the outer integral. *)
  have "(+ M'. + x. c * f x M' M) = (+ M'. c * + x. f x M' M)"
    using mult M M[THEN sets_eq_imp_space_eq]
    by (intro nn_integral_cong nn_integral_cmult) (auto simp add: space_subprob_algebra)
  also have " = c * (+ M'. + x. f x M' M)"
    using nn_integral_measurable_subprob_algebra[OF mult(2)]
    by (intro nn_integral_cmult mult) (simp add: M)
  also have " = c * (integralN (join M) f)"
    by (simp add: mult)
  also have " = (+ x. c * f x join M)"
    using mult(2,3) by (intro nn_integral_cmult[symmetric] mult) (simp add: M cong: measurable_cong_sets)
  finally show ?case by simp
next
  case (add f g)
  (* Sums: additivity of the inner and outer integrals (nn_integral_add). *)
  have "(+ M'. + x. f x + g x M' M) = (+ M'. (+ x. f x M') + (+ x. g x M') M)"
    using add M M[THEN sets_eq_imp_space_eq]
    by (intro nn_integral_cong nn_integral_add) (auto simp add: space_subprob_algebra)
  also have " = (+ M'. + x. f x M' M) + (+ M'. + x. g x M' M)"
    using nn_integral_measurable_subprob_algebra[OF add(1)]
    using nn_integral_measurable_subprob_algebra[OF add(4)]
    by (intro nn_integral_add add) (simp_all add: M)
  also have " = (integralN (join M) f) + (integralN (join M) g)"
    by (simp add: add)
  also have " = (+ x. f x + g x join M)"
    using add by (intro nn_integral_add[symmetric] add) (simp_all add: M cong: measurable_cong_sets)
  finally show ?case by (simp add: ac_simps)
next
  case (seq F)
  (* Suprema of sequences: monotone convergence, applied inside and outside. *)
  have "(+ M'. + x. (SUP i. F i) x M' M) = (+ M'. (SUP i. + x. F i x M') M)"
    using seq M M[THEN sets_eq_imp_space_eq] unfolding SUP_apply
    by (intro nn_integral_cong nn_integral_monotone_convergence_SUP)
       (auto simp add: space_subprob_algebra)
  also have " = (SUP i. + M'. + x. F i x M' M)"
    using nn_integral_measurable_subprob_algebra[OF seq(1)] seq
    by (intro nn_integral_monotone_convergence_SUP)
       (simp_all add: M incseq_nn_integral incseq_def le_fun_def nn_integral_mono )
  also have " = (SUP i. integralN (join M) (F i))"
    by (simp add: seq)
  also have " = (+ x. (SUP i. F i x) join M)"
    using seq by (intro nn_integral_monotone_convergence_SUP[symmetric] seq)
                 (simp_all add: M cong: measurable_cong_sets)
  finally show ?case by (simp add: ac_simps image_comp)
qed

(* Transfer of measurability to join: a function measurable from N to K is also
   measurable from join M to K, provided sets M = sets (subprob_algebra N).
   Proved by simplification with measurable_def. *)
lemma measurable_join1:
  " f  measurable N K; sets M = sets (subprob_algebra N) 
   f  measurable (join M) K"
  by(simp add: measurable_def)

(* Bochner-integral versions of nn_integral_join for a real-valued f.
   Assumptions: f is Borel measurable on N and bounded in absolute value by B on
   space N; sets M = sets (subprob_algebra N); M is a finite measure; and for
   almost every M' in M the total mass emeasure M' (space M') is bounded by
   ennreal B'.
   Conclusions:
     integrable_join: f is integrable with respect to join M;
     integral_join:   the integral of f w.r.t. join M equals the integral over M of
                      M' -> integral of f w.r.t. M'.
   Both goals are handled together; "case_tac [!]" splits each of them on whether
   space N is empty. *)
lemma
  fixes f :: "_  real"
  assumes f_measurable [measurable]: "f  borel_measurable N"
  and f_bounded: "x. x  space N  ¦f x¦  B"
  and M [measurable_cong]: "sets M = sets (subprob_algebra N)"
  and fin: "finite_measure M"
  and M_bounded: "AE M' in M. emeasure M' (space M')  ennreal B'"
  shows integrable_join: "integrable (join M) f" (is ?integrable)
  and integral_join: "integralL (join M) f =  M'. integralL M' f M" (is ?integral)
proof(case_tac [!] "space N = {}")
  (* Degenerate case: with space N empty, both integrals are over empty spaces. *)
  assume *: "space N = {}"
  show ?integrable
    using M * by(simp add: real_integrable_def measurable_def nn_integral_empty)
  have "( M'. integralL M' f M) = ( M'. 0 M)"
  proof(rule Bochner_Integration.integral_cong)
    fix M'
    assume "M'  space M"
    with sets_eq_imp_space_eq[OF M] have "space M' = space N"
      by(auto simp add: space_subprob_algebra dest: sets_eq_imp_space_eq)
    with * show "( x. f x M') = 0" by(simp add: Bochner_Integration.integral_empty)
  qed simp
  then show ?integral
    using M * by(simp add: Bochner_Integration.integral_empty)
next
  assume *: "space N  {}"

  (* A point of space N exists, so the bound on abs (f x) forces B to be nonnegative. *)
  from * have B [simp]: "0  B" by(auto dest: f_bounded)

  have [measurable]: "f  borel_measurable (join M) " using f_measurable M
    by(rule measurable_join1)

  (* bounded1: for any measurable f bounded above by B on space N, and any M' in
     space M whose total mass is bounded by ennreal B', the integral of ennreal (f x)
     w.r.t. M' is at most ennreal of the product B times abs B'. *)
  { fix f M'
    assume [measurable]: "f  borel_measurable N"
      and f_bounded: "x. x  space N  f x  B"
      and "M'  space M" "emeasure M' (space M')  ennreal B'"
    have "AE x in M'. ennreal (f x)  ennreal B"
    proof(rule AE_I2)
      fix x
      assume "x  space M'"
      with M'  space M sets_eq_imp_space_eq[OF M]
      have "x  space N" by(auto simp add: space_subprob_algebra dest: sets_eq_imp_space_eq)
      from f_bounded[OF this] show "ennreal (f x)  ennreal B" by simp
    qed
    then have "(+ x. ennreal (f x) M')  (+ x. ennreal B M')"
      by(rule nn_integral_mono_AE)
    also have " = ennreal B * emeasure M' (space M')" by(simp)
    also have "  ennreal B * ennreal B'" by(rule mult_left_mono)(fact, simp)
    also have "  ennreal B * ennreal ¦B'¦" by(rule mult_left_mono)(simp_all)
    finally have "(+ x. ennreal (f x) M')  ennreal (B * ¦B'¦)" by (simp add: ennreal_mult) }
  note bounded1 = this

  (* bounded: the integral of ennreal (f x) w.r.t. join M is not top.  It is rewritten
     with nn_integral_join, bounded through bounded1, and the bound is finite because
     M is a finite measure. *)
  have bounded:
    "f.  f  borel_measurable N; x. x  space N  f x  B 
     (+ x. ennreal (f x) join M)  top"
  proof -
    fix f
    assume [measurable]: "f  borel_measurable N"
      and f_bounded: "x. x  space N  f x  B"
    have "(+ x. ennreal (f x) join M) = (+ M'. + x. ennreal (f x) M' M)"
      by(rule nn_integral_join[OF _ M]) simp
    also have "  + M'. B * ¦B'¦ M"
      using bounded1[OF f  borel_measurable N f_bounded]
      by(rule nn_integral_mono_AE[OF AE_mp[OF M_bounded AE_I2], rule_format])
    also have " = B * ¦B'¦ * emeasure M (space M)" by simp
    also have " < "
      using finite_measure.finite_emeasure_space[OF fin]
      by(simp add: ennreal_mult_less_top less_top)
    finally show "?thesis f" by simp
  qed
  (* Apply bounded to f and to - f (both are bounded above by B via abs_le_iff). *)
  have f_pos: "(+ x. ennreal (f x) join M)  "
    and f_neg: "(+ x. ennreal (- f x) join M)  "
    using f_bounded by(auto del: notI intro!: bounded simp add: abs_le_iff)

  show ?integrable using f_pos f_neg by(simp add: real_integrable_def)

  note [measurable] = nn_integral_measurable_subprob_algebra

  have int_f: "(+ x. f x join M) = + M'. + x. f x M' M"
    by(simp add: nn_integral_join[OF _ M])
  have int_mf: "(+ x. - f x join M) = (+ M'. + x. - f x M' M)"
    by(simp add: nn_integral_join[OF _ M])

  (* Almost everywhere in M the inner integral of f is not top, so taking enn2real
     and converting back with ennreal does not change the outer integral. *)
  have pos_finite: "AE M' in M. (+ x. f x M')  "
    using AE_space M_bounded
  proof eventually_elim
    fix M' assume "M'  space M" "emeasure M' (space M')  ennreal B'"
    then have "(+ x. ennreal (f x) M')  ennreal (B * ¦B'¦)"
      using f_measurable by(auto intro!: bounded1 dest: f_bounded)
    then show "(+ x. ennreal (f x) M')  "
      by (auto simp: top_unique)
  qed
  hence [simp]: "(+ M'. ennreal (enn2real (+ x. f x M')) M) = (+ M'. + x. f x M' M)"
    by (rule nn_integral_cong_AE[OF AE_mp]) (simp add: less_top)
  from f_pos have [simp]: "integrable M (λM'. enn2real (+ x. f x M'))"
    by(simp add: int_f real_integrable_def nn_integral_0_iff_AE[THEN iffD2] ennreal_neg enn2real_nonneg)

  (* Same three facts for - f. *)
  have neg_finite: "AE M' in M. (+ x. - f x M')  "
    using AE_space M_bounded
  proof eventually_elim
    fix M' assume "M'  space M" "emeasure M' (space M')  ennreal B'"
    then have "(+ x. ennreal (- f x) M')  ennreal (B * ¦B'¦)"
      using f_measurable by(auto intro!: bounded1 dest: f_bounded)
    then show "(+ x. ennreal (- f x) M')  "
      by (auto simp: top_unique)
  qed
  hence [simp]: "(+ M'. ennreal (enn2real (+ x. - f x M')) M) = (+ M'. + x. - f x M' M)"
    by (rule nn_integral_cong_AE[OF AE_mp]) (simp add: less_top)
  from f_neg have [simp]: "integrable M (λM'. enn2real (+ x. - f x M'))"
    by(simp add: int_mf real_integrable_def nn_integral_0_iff_AE[THEN iffD2] ennreal_neg enn2real_nonneg)

  (* Final calculation: write the integral w.r.t. join M as a difference of the two
     nonnegative integrals (real_lebesgue_integral_def), push enn2real inside the
     outer integral, and recombine pointwise into the inner Bochner integral. *)
  have "( x. f x join M) = enn2real (+ N. +x. f x N M) - enn2real (+ N. +x. - f x N M)"
    unfolding real_lebesgue_integral_def[OF ?integrable] by (simp add: nn_integral_join[OF _ M])
  also have " = (N. enn2real (+x. f x N) M) - (N. enn2real (+x. - f x N) M)"
    using pos_finite neg_finite by (subst (1 2) integral_eq_nn_integral) (auto simp: enn2real_nonneg)
  also have " = (N. enn2real (+x. f x N) - enn2real (+x. - f x N) M)"
    by simp
  also have " = M'.  x. f x M' M"
  proof (rule integral_cong_AE)
    show "AE x in M.
        enn2real (+ x. ennreal (f x) x) - enn2real (+ x. ennreal (- f x) x) = integralL x f"
      using AE_space M_bounded
    proof eventually_elim
      fix M' assume "M'  space M" "emeasure M' (space M')  B'"
      then interpret subprob_space M'
        by (auto simp: M[THEN sets_eq_imp_space_eq] space_subprob_algebra)

      from M'  space M sets_eq_imp_space_eq[OF M]
      have [measurable_cong]: "sets M' = sets N" by(simp add: space_subprob_algebra)
      hence [simp]: "space M' = space N" by(rule sets_eq_imp_space_eq)
      (* f is bounded and M' is a sub-probability space, hence f is integrable on M'. *)
      have "integrable M' f"
        by(rule integrable_const_bound[where B=B])(auto simp add: f_bounded)
      then show "enn2real (+ x. f x M') - enn2real (+ x. - f x M') =  x. f x M'"
        by(simp add: real_lebesgue_integral_def)
    qed
  qed simp_all
  finally show ?integral by simp
qed

(* Associativity law for join: for M with
   sets M = sets (subprob_algebra (subprob_algebra N)), flattening the inner layer
   first (push M forward along join, then join) gives the same measure as flattening
   the outer layer first (join twice).
   Proved by measure_eqI: the emeasure equation on each measurable A is discharged
   by auto with emeasure_join, nn_integral_distr and nn_integral_join; the equality
   of sets is the closing simp. *)
lemma join_assoc:
  assumes M[measurable_cong]: "sets M = sets (subprob_algebra (subprob_algebra N))"
  shows "join (distr M (subprob_algebra N) join) = join (join M)"
proof (rule measure_eqI)
  fix A assume "A  sets (join (distr M (subprob_algebra N) join))"
  then have A: "A  sets N" by simp
  show "emeasure (join (distr M (subprob_algebra N) join)) A = emeasure (join (join M)) A"
    using measurable_join[of N]
    by (auto simp: M A nn_integral_distr emeasure_join measurable_emeasure_subprob_algebra
                   sets_eq_imp_space_eq[OF M] space_subprob_algebra nn_integral_join[OF _ M]
             intro!: nn_integral_cong emeasure_join)
qed (simp add: M)

(* Unit law (outer return): if sets M = sets N and M is a sub-probability space,
   then join of the point measure "return (subprob_algebra N) M" is M itself.
   Proved by measure_eqI and simplification with emeasure_join and
   nn_integral_return. *)
lemma join_return:
  assumes "sets M = sets N" and "subprob_space M"
  shows "join (return (subprob_algebra N) M) = M"
  by (rule measure_eqI)
     (simp_all add: emeasure_join space_subprob_algebra
                    measurable_emeasure_subprob_algebra nn_integral_return assms)

(* Unit law (inner return): if sets N = sets M, then pushing M forward along
   "return N" into subprob_algebra N and joining gives back M.
   Proved by measure_eqI; the emeasure equation uses emeasure_join and
   nn_integral_distr together with measurability of return N on M. *)
lemma join_return':
  assumes "sets N = sets M"
  shows "join (distr M (subprob_algebra N) (return N)) = M"
proof (rule measure_eqI)
  fix A
  have "return N  measurable M (subprob_algebra N)"
    using assms by auto
  moreover
  assume "A  sets (join (distr M (subprob_algebra N) (return N)))"
  ultimately show "emeasure (join (distr M (subprob_algebra N) (return N))) A = emeasure M A"
    by (simp add: emeasure_join nn_integral_distr measurable_emeasure_subprob_algebra assms)
qed (simp add: assms)

(* Naturality of join with respect to distr: for sets M = sets (subprob_algebra R)
   and f measurable from R to N, mapping every inner measure along f and then
   joining (?r) equals joining first and then pushing forward along f (?l).
   Proved by measure_eqI as a calculation starting from emeasure ?l A and ending at
   emeasure ?r A; the final ".." flips the resulting equation. *)
lemma join_distr_distr:
  fixes f :: "'a  'b" and M :: "'a measure measure" and N :: "'b measure"
  assumes "sets M = sets (subprob_algebra R)" and "f  measurable R N"
  shows "join (distr M (subprob_algebra N) (λM. distr M N f)) = distr (join M) N f" (is "?r = ?l")
proof (rule measure_eqI)
  fix A assume "A  sets ?r"
  hence A_in_N: "A  sets N" by simp

  (* Unfold the right-hand side: emeasure of a distr is the emeasure of a preimage,
     and emeasure of join is an integral over M. *)
  from assms have "f  measurable (join M) N"
      by (simp cong: measurable_cong_sets)
  moreover from assms and A_in_N have "f-`A  space R  sets R"
      by (intro measurable_sets) simp_all
  ultimately have "emeasure (distr (join M) N f) A = +M'. emeasure M' (f-`A  space R) M"
      by (simp_all add: A_in_N emeasure_distr emeasure_join assms)

  (* Pointwise in M', fold the preimage back into emeasure (distr M' N f) A; this
     needs space M' = space R for every M' in space M. *)
  also have "... = + x. emeasure (distr x N f) A M" using A_in_N
  proof (intro nn_integral_cong, subst emeasure_distr)
    fix M' assume "M'  space M"
    from assms have "space M = space (subprob_algebra R)"
        using sets_eq_imp_space_eq by blast
    with M'  space M have [simp]: "sets M' = sets R" using space_subprob_algebra by blast
    show "f  measurable M' N" by (simp cong: measurable_cong_sets add: assms)
    have "space M' = space R" by (rule sets_eq_imp_space_eq) simp
    thus "emeasure M' (f -` A  space R) = emeasure M' (f -` A  space M')" by simp
  qed

  (* Recognise the integral as the emeasure of the left-hand side. *)
  also have "(λM. distr M N f)  measurable M (subprob_algebra N)"
      by (simp cong: measurable_cong_sets add: assms measurable_distr)
  hence "(+ x. emeasure (distr x N f) A M) =
             emeasure (join (distr M (subprob_algebra N) (λM. distr M N f))) A"
      by (simp_all add: emeasure_join assms A_in_N nn_integral_distr measurable_emeasure_subprob_algebra)
  finally show "emeasure ?r A = emeasure ?l A" ..
qed simp

(* Monadic bind, defined through join and distr.
   - If space M is empty the result is count_space {}.
   - Otherwise M is pushed forward along f into subprob_algebra K and joined, where
     the target K is read off from f at a point chosen by SOME from space M.
   Lemmas bind_nonempty' and bind_nonempty'' in this file replace the SOME-based
   target by an explicit N when f is measurable into subprob_algebra N. *)
definition bind :: "'a measure  ('a  'b measure)  'b measure" where
  "bind M f = (if space M = {} then count_space {} else
    join (distr M (subprob_algebra (f (SOME x. x  space M))) f))"

(* Register bind as an instance of the overloaded Monad_Syntax.bind constant. *)
adhoc_overloading Monad_Syntax.bind bind

(* The "then" branch of bind_def: on an empty space, bind M f is count_space {}. *)
lemma bind_empty:
  "space M = {}  bind M f = count_space {}"
  by (simp add: bind_def)

(* The "else" branch of bind_def: on a non-empty space, bind M f unfolds to join of
   the push-forward of M along f, with the SOME-chosen target algebra. *)
lemma bind_nonempty:
  "space M  {}  bind M f = join (distr M (subprob_algebra (f (SOME x. x  space M))) f)"
  by (simp add: bind_def)

(* Sets of bind M f under the hypothesis sets M = {}; closed by auto.
   NOTE(review): the premise is about "sets M" being empty, whereas the neighbouring
   lemmas bind_empty and space_bind_empty assume "space M = {}".  Presumably the
   premise can never hold (a measure's sets should always contain the empty set),
   which would make this statement vacuous -- confirm whether "space M = {}" was
   intended. *)
lemma sets_bind_empty: "sets M = {}  sets (bind M f) = {{}}"
  by auto

(* On an empty space, the space of bind M f is empty as well (by bind_def). *)
lemma space_bind_empty: "space M = {}  space (bind M f) = {}"
  by (simp add: bind_def)

(* Simp / measurable_cong rule: if every f x (x in space M) has the sets of N and
   space M is non-empty, then bind M f has the sets of N.
   The hypothesis is instantiated at the SOME-chosen point used in bind_nonempty;
   some_in_eq ensures that point lies in space M. *)
lemma sets_bind[simp, measurable_cong]:
  assumes f: "x. x  space M  sets (f x) = sets N" and M: "space M  {}"
  shows "sets (bind M f) = sets N"
  using f [of "SOME x. x  space M"] by (simp add: bind_nonempty M some_in_eq)

(* Simp rule: under the hypotheses of sets_bind, the space of bind M f equals the
   space of N.  Obtained from sets_bind via sets_eq_imp_space_eq. *)
lemma space_bind[simp]:
  assumes "x. x  space M  sets (f x) = sets N" and "space M  {}"
  shows "space (bind M f) = space N"
  using assms by (intro sets_eq_imp_space_eq sets_bind)

(* Congruence for bind: if f and g agree on every point of space M, then
   bind M f = bind M g.
   In the non-empty case the SOME-chosen point lies in space M (someI_ex), so f and
   g also agree there and the two unfolded definitions coincide (distr_cong).
   The empty case is immediate from bind_empty. *)
lemma bind_cong_All:
  assumes "x  space M. f x = g x"
  shows "bind M f = bind M g"
proof (cases "space M = {}")
  assume "space M  {}"
  hence "(SOME x. x  space M)  space M" by (rule_tac someI_ex) blast
  with assms have "f (SOME x. x  space M) = g (SOME x. x  space M)" by blast
  with space M  {} and assms show ?thesis by (simp add: bind_nonempty cong: distr_cong)
qed (simp add: bind_empty)

(* Congruence rule in the usual two-premise shape: equal measures and pointwise
   equal kernels on space M give equal binds.  Derived from bind_cong_All. *)
lemma bind_cong:
  "M = N  (x. x  space M  f x = g x)  bind M f = bind N g"
  using bind_cong_All[of M f g] by auto

(* Variant of bind_nonempty with an explicit target: if f is measurable from M into
   subprob_algebra N and some x lies in space M, then bind M f is the join of the
   push-forward of M into subprob_algebra N.
   The first step replaces the SOME-based target algebra by subprob_algebra N
   (subprob_algebra_cong, subprob_measurableD(2)); the second unfolds bind. *)
lemma bind_nonempty':
  assumes "f  measurable M (subprob_algebra N)" "x  space M"
  shows "bind M f = join (distr M (subprob_algebra N) f)"
proof -
  have "join (distr M (subprob_algebra (f (SOME x. x  space M))) f) = join (distr M (subprob_algebra N) f)"
    by (metis assms someI_ex subprob_algebra_cong subprob_measurableD(2))
  with assms show ?thesis
    by (metis bind_nonempty empty_iff)
qed

(* Same conclusion as bind_nonempty', but with the hypothesis stated as
   non-emptiness of space M instead of a concrete element. *)
lemma bind_nonempty'':
  assumes "f  measurable M (subprob_algebra N)" "space M  {}"
  shows "bind M f = join (distr M (subprob_algebra N) f)"
  using assms by (auto intro: bind_nonempty')

(* Measure of a set under bind: for non-empty space M, f measurable into
   subprob_algebra N, and X a measurable set of N, the emeasure of X under the bind
   of M and f is the integral over M of x -> emeasure (f x) X.
   Follows from bind_nonempty'', emeasure_join and nn_integral_distr. *)
lemma emeasure_bind:
    "space M  {}; f  measurable M (subprob_algebra N);X  sets N
       emeasure (M  f) X = +x. emeasure (f x) X M"
  by (simp add: bind_nonempty'' emeasure_join nn_integral_distr measurable_emeasure_subprob_algebra)

(* Nonnegative integration against bind: for f Borel measurable on B and N
   measurable from M into subprob_algebra B, integrating f against the bind of M
   and N equals the iterated integral (outer over M, inner over N x).
   Non-empty case: unfold bind via bind_nonempty'', apply nn_integral_join, then
   nn_integral_distr.  Empty case: both sides are integrals over empty spaces. *)
lemma nn_integral_bind:
  assumes f: "f  borel_measurable B"
  assumes N: "N  measurable M (subprob_algebra B)"
  shows "(+x. f x (M  N)) = (+x. +y. f y N x M)"
proof cases
  assume M: "space M  {}" show ?thesis
    unfolding bind_nonempty''[OF N M] nn_integral_join[OF f sets_distr]
    by (rule nn_integral_distr[OF N])
       (simp add: f nn_integral_measurable_subprob_algebra)
qed (simp add: bind_empty space_empty[of M] nn_integral_count_space)

(* Almost-everywhere statements under bind: for N measurable from M into
   subprob_algebra B and a measurable predicate P on B, P holds almost everywhere
   in the bind of M and N iff for almost every x in M, P holds almost everywhere
   in N x.
   The non-empty case characterises both sides through the integral of the
   indicator of the set where P fails being zero. *)
lemma AE_bind:
  assumes N[measurable]: "N  measurable M (subprob_algebra B)"
  assumes P[measurable]: "Measurable.pred B P"
  shows "(AE x in M  N. P x)  (AE x in M. AE y in N x. P y)"
proof cases
  (* Empty case: rewrite with bind_empty and space_empty, then AE_count_space. *)
  assume M: "space M = {}" show ?thesis
    unfolding bind_empty[OF M] unfolding space_empty[OF M] by (simp add: AE_count_space)
next
  assume M: "space M  {}"
  note sets_kernel[OF N, simp]
  (* Restricting the failure set to space B does not change the integral. *)
  have *: "(+x. indicator {x. ¬ P x} x (M  N)) = (+x. indicator {xspace B. ¬ P x} x (M  N))"
    by (intro nn_integral_cong) (simp add: space_bind[OF _ M] split: split_indicator)

  (* AE in the bind  <->  iterated integral of the indicator is 0 (nn_integral_bind). *)
  have "(AE x in M  N. P x)  (+ x. integralN (N x) (indicator {x  space B. ¬ P x}) M) = 0"
    by (simp add: AE_iff_nn_integral sets_bind[OF _ M] space_bind[OF _ M] * nn_integral_bind[where B=B]
             del: nn_integral_indicator)
  (* Outer integral is 0  <->  the inner integral is 0 almost everywhere in M. *)
  also have "... = (AE x in M. integralN (N x) (indicator {x  space B. ¬ P x}) = 0)"
  proof (rule nn_integral_0_iff_AE)
    show "(λx. integralN (N x) (indicator {x  space B. ¬ P x}))  borel_measurable M"
    apply (rule measurable_compose[OF N nn_integral_measurable_subprob_algebra])
      by measurable
  qed
  (* Inner integral is 0  <->  P holds almost everywhere in N x. *)
  also have " = (AE x in M. AE y in N x. P y)"
    apply (intro eventually_subst AE_I2)
    by (auto simp add: subprob_measurableD(1)[OF N] intro!: AE_iff_measurable[symmetric])
  finally show ?thesis .
qed

(* Measurability of a parametrised bind: if f is measurable from M into
   subprob_algebra N and the uncurried g (case_prod g) is measurable from the
   product of M and N into subprob_algebra R, then x -> bind (f x) (g x) is
   measurable from M into subprob_algebra R.
   The proof first rewrites bind (f x) (g x), for x in space M, into the join of a
   distr (measurable_cong plus bind_nonempty''), and then composes
   measurable_distr2 with measurable_join. *)
lemma measurable_bind':
  assumes M1: "f  measurable M (subprob_algebra N)" and
          M2: "case_prod g  measurable (M M N) (subprob_algebra R)"
  shows "(λx. bind (f x) (g x))  measurable M (subprob_algebra R)"
proof (subst measurable_cong)
  fix x assume x_in_M: "x  space M"
  (* f x is a sub-probability space, hence its space is non-empty. *)
  with assms have "space (f x)  {}"
      by (blast dest: subprob_space_kernel subprob_space.subprob_not_empty)
  moreover from M2 x_in_M have "g x  measurable (f x) (subprob_algebra R)"
      by (subst measurable_cong_sets[OF sets_kernel[OF M1 x_in_M] refl])
         (auto dest: measurable_Pair2)
  ultimately show "bind (f x) (g x) = join (distr (f x) (subprob_algebra R) (g x))"
      by (simp_all add: bind_nonempty'')
next
  show "(λw. join (distr (f w) (subprob_algebra R) (g w)))  measurable M (subprob_algebra R)"
    apply (rule measurable_compose[OF _ measurable_join])
    apply (rule measurable_distr2[OF M2 M1])
    done
qed

(* Version of measurable_bind' registered as a raw measurability rule.  The second
   hypothesis is stated with explicit fst / snd projections instead of case_prod;
   measurable_split_conv converts between the two forms. *)
lemma measurable_bind[measurable (raw)]:
  assumes M1: "f  measurable M (subprob_algebra N)" and
          M2: "(λx. g (fst x) (snd x))  measurable (M M N) (subprob_algebra R)"
  shows "(λx. bind (f x) (g x))  measurable M (subprob_algebra R)"
  using assms by (auto intro: measurable_bind' simp: measurable_split_conv)

(* Special case of measurable_bind' in which the second kernel g does not depend on
   the parameter x: f measurable into subprob_algebra N and g measurable from N into
   subprob_algebra R give measurability of x -> bind (f x) g. *)
lemma measurable_bind2:
  assumes "f  measurable M (subprob_algebra N)" and "g  measurable N (subprob_algebra R)"
  shows "(λx. bind (f x) g)  measurable M (subprob_algebra R)"
    using assms by (intro measurable_bind' measurable_const) auto

(* bind preserves sub-probability spaces: if M is a sub-probability space and f is
   measurable from M into subprob_algebra N, then the bind of M and f is one too.
   The proof views x -> bind x f as a kernel from subprob_algebra M into
   subprob_algebra N (measurable_bind) and applies subprob_space_kernel at the
   point M, which lies in space (subprob_algebra M) by the first assumption. *)
lemma subprob_space_bind:
  assumes "subprob_space M" "f  measurable M (subprob_algebra N)"
  shows "subprob_space (M  f)"
proof (rule subprob_space_kernel[of "λx. x  f"])
  show "(λx. x  f)  measurable (subprob_algebra M) (subprob_algebra N)"
    by (rule measurable_bind, rule measurable_ident_sets, rule refl,
        rule measurable_compose[OF measurable_snd assms(2)])
  from assms(1) show "M  space (subprob_algebra M)"
    by (simp add: space_subprob_algebra)
qed

(* Bochner-integral versions of nn_integral_bind for a real-valued f.
   Assumptions: f is Borel measurable on K and bounded in absolute value by B on
   space K; N is measurable from M into subprob_algebra K; M is a finite measure;
   and almost everywhere in M the total mass of N x is bounded by ennreal B'.
   Conclusions:
     integrable_bind: f is integrable w.r.t. bind M N;
     integral_bind:   its integral equals the integral over M of
                      x -> integral of f w.r.t. N x.
   The non-empty case reduces to integrable_join / integral_join applied to
   distr M (subprob_algebra K) N, followed by integral_distr.  The empty case is
   discharged by the closing simp_all using bind_def and count_space facts. *)
lemma
  fixes f :: "_  real"
  assumes f_measurable [measurable]: "f  borel_measurable K"
  and f_bounded: "x. x  space K  ¦f x¦  B"
  and N [measurable]: "N  measurable M (subprob_algebra K)"
  and fin: "finite_measure M"
  and M_bounded: "AE x in M. emeasure (N x) (space (N x))  ennreal B'"
  shows integrable_bind: "integrable (bind M N) f" (is ?integrable)
  and integral_bind: "integralL (bind M N) f =  x. integralL (N x) f M" (is ?integral)
proof(case_tac [!] "space M = {}")
  assume [simp]: "space M  {}"
  interpret finite_measure M by(rule fin)

  have "integrable (join (distr M (subprob_algebra K) N)) f"
    using f_measurable f_bounded
    by(rule integrable_join[where B'=B'])(simp_all add: finite_measure_distr AE_distr_iff M_bounded)
  then show ?integrable by(simp add: bind_nonempty''[where N=K])

  have "integralL (join (distr M (subprob_algebra K) N)) f =  M'. integralL M' f distr M (subprob_algebra K) N"
    using f_measurable f_bounded
    by(rule integral_join[where B'=B'])(simp_all add: finite_measure_distr AE_distr_iff M_bounded)
  (* Change of variables along N: integral over the push-forward becomes integral over M. *)
  also have " =  x. integralL (N x) f M"
    by(rule integral_distr)(simp_all add: integral_measurable_subprob_algebra[OF _])
  finally show ?integral by(simp add: bind_nonempty''[where N=K])
qed(simp_all add: bind_def integrable_count_space lebesgue_integral_count_space_finite Bochner_Integration.integral_empty)

(* Inside a probability space M: if N x is a probability space for almost every x
   and N is measurable from M into subprob_algebra S, then the bind of M and N is a
   probability space.
   The total mass of the bind is computed with emeasure_bind as the integral over M
   of the total mass of N x, which is 1 almost everywhere; the integral of the
   constant 1 over M is 1 by emeasure_space_1. *)
lemma (in prob_space) prob_space_bind:
  assumes ae: "AE x in M. prob_space (N x)"
    and N[measurable]: "N  measurable M (subprob_algebra S)"
  shows "prob_space (M  N)"
proof
  have "emeasure (M  N) (space (M  N)) = (+x. emeasure (N x) (space (N x)) M)"
    by (subst emeasure_bind[where N=S])
       (auto simp: not_empty space_bind[OF sets_kernel] subprob_measurableD[OF N] intro!: nn_integral_cong)
  also have " = (+x. 1 M)"
    using ae by (intro nn_integral_cong_AE, eventually_elim) (rule prob_space.emeasure_space_1)
  finally show "emeasure (M  N) (space (M  N)) = 1"
    by (simp add: emeasure_space_1)
qed

(* Inside a sub-probability space M: for A measurable from M into subprob_algebra N,
   the bind of M and A is an element of space (subprob_algebra N).
   Uses subprob_space_bind; the remaining locale obligation for M is closed by
   unfold_locales. *)
lemma (in subprob_space) bind_in_space:
  "A  measurable M (subprob_algebra N)  (M  A)  space (subprob_algebra N)"
  by (auto simp add: space_subprob_algebra subprob_not_empty sets_kernel intro!: subprob_space_bind)
     unfold_locales

(* Real-valued version of emeasure_bind inside a sub-probability space M: for f
   measurable into subprob_algebra N and X a measurable set of N, the measure of X
   under the bind of M and f is the Bochner integral over M of x -> measure (f x) X. *)
lemma (in subprob_space) measure_bind:
  assumes f: "f  measurable M (subprob_algebra N)" and X: "X  sets N"
  shows "measure (M  f) X = x. measure (f x) X M"
proof -
  (* The bind itself is a sub-probability space, so its emeasure is finite. *)
  interpret Mf: subprob_space "M  f"
    by (rule subprob_space_bind[OF _ f]) unfold_locales

  (* For every x in space M, f x is a sub-probability space: its emeasure of X is
     the ennreal of its measure, and that measure is at most 1. *)
  { fix x assume "x  space M"
    from f[THEN measurable_space, OF this] interpret subprob_space "f x"
      by (simp add: space_subprob_algebra)
    have "emeasure (f x) X = ennreal (measure (f x) X)" "measure (f x) X  1"
      by (auto simp: emeasure_eq_measure subprob_measure_le_1) }
  note this[simp]

  have "emeasure (M  f) X = +x. emeasure (f x) X M"
    using subprob_not_empty f X by (rule emeasure_bind)
  also have " = +x. ennreal (measure (f x) X) M"
    by (intro nn_integral_cong) simp
  (* The integrand is bounded by 1, hence integrable; switch to the Bochner integral. *)
  also have " = x. measure (f x) X M"
    by (intro nn_integral_eq_integral integrable_const_bound[where B=1]
              measure_measurable_subprob_algebra2[OF _ f] pair_measureI X)
       (auto simp: measure_nonneg)
  finally show ?thesis
    by (simp add: Mf.emeasure_eq_measure measure_nonneg integral_nonneg)
qed

(* Bind with a constant kernel: for non-empty space M, X a measurable set of N and
   N a sub-probability space, the emeasure of X under the bind of M with the
   constant function returning N is emeasure N X times the total mass of M. *)
lemma emeasure_bind_const:
    "space M  {}  X  sets N  subprob_space N 
         emeasure (M  (λx. N)) X = emeasure N X * emeasure M (space M)"
  by (simp add: bind_nonempty emeasure_join nn_integral_distr
                space_subprob_algebra measurable_emeasure_subprob_algebra)

(* Variant of emeasure_bind_const without a measurability hypothesis on X, at the
   price of requiring M to be a sub-probability space.
   The proof splits on whether X is a measurable set of N: the measurable case is
   emeasure_bind_const; the other case uses emeasure_notin_sets together with
   sets_bind.  (The membership symbols of the two case assumptions are not visible
   in this copy of the text; the reading follows from the facts used in each branch.) *)
lemma emeasure_bind_const':
  assumes "subprob_space M" "subprob_space N"
  shows "emeasure (M  (λx. N)) X = emeasure N X * emeasure M (space M)"
using assms
proof (case_tac "X  sets N")
  fix X assume "X  sets N"
  thus "emeasure (M  (λx. N)) X = emeasure N X * emeasure M (space M)" using assms
      by (subst emeasure_bind_const)
         (simp_all add: subprob_space.subprob_not_empty subprob_space.emeasure_space_le_1)
next
  fix X assume "X  sets N"
  with assms show "emeasure (M  (λx. N)) X = emeasure N X * emeasure M (space M)"
      by (simp add: sets_bind[of _ _ N] subprob_space.subprob_not_empty
                    space_subprob_algebra emeasure_notin_sets)
qed

(* When M is a probability space its total mass is 1, so binding M with the constant
   kernel N leaves emeasure N X unchanged.  Specialisation of emeasure_bind_const'
   using prob_space.emeasure_space_1. *)
lemma emeasure_bind_const_prob_space:
  assumes "prob_space M" "subprob_space N"
  shows "emeasure (M  (λx. N)) X = emeasure N X"
  using assms by (simp add: emeasure_bind_const' prob_space_imp_subprob_space
                            prob_space.emeasure_space_1)

(* Left unit law of the monad: for f measurable from M into subprob_algebra N and
   x in space M, binding the point measure return M x with f yields f x.
   Reduces to distr_return and join_return after unfolding bind with
   bind_nonempty'. *)
lemma bind_return:
  assumes "f  measurable M (subprob_algebra N)" and "x  space M"
  shows "bind (return M x) f = f x"
  using sets_kernel[OF assms] assms
  by (simp_all add: distr_return join_return subprob_space_kernel bind_nonempty'
               cong: subprob_algebra_cong)

(* Right unit law of the monad: binding M with return M gives M, with no side
   conditions.  The proof splits on whether space M is empty; the non-empty case
   uses join_return', the empty case uses bind_empty and space_empty. *)
lemma bind_return':
  shows "bind M (return M) = M"
  by (cases "space M = {}")
     (simp_all add: bind_empty space_empty[symmetric] bind_nonempty join_return'
               cong: subprob_algebra_cong)

(* Push-forward commutes with bind: for N measurable from M into subprob_algebra K,
   non-empty space M, and f measurable from K to R, pushing the bind of M and N
   forward along f equals binding M with x -> distr (N x) R f.
   The core equation on joins is join_distr_distr read from right to left
   ("flip:"); the rest converts between bind and its unfolded join form. *)
lemma distr_bind:
  assumes N: "N  measurable M (subprob_algebra K)" "space M  {}"
  assumes f: "f  measurable K R"
  shows "distr (M  N) R f = (M  (λx. distr (N x) R f))"
proof -
  have "distr (join (distr M (subprob_algebra K) N)) R f =
       join (distr M (subprob_algebra R) (λx. distr (N x) R f))"
    by (simp add: assms distr_distr[OF measurable_distr] comp_def flip: join_distr_distr)
  with assms show ?thesis
    unfolding bind_nonempty''[OF N]
    by (smt (verit) bind_nonempty sets_distr subprob_algebra_cong)
qed

(* Binding a push-forward: for f measurable from M to X, N measurable from X into
   subprob_algebra K, and non-empty space M, binding distr M X f with N equals
   binding M with the composed kernel x -> N (f x).
   Non-emptiness of space X is derived from that of space M via measurable_space;
   the result then follows from bind_nonempty'' and distr_distr. *)
lemma bind_distr:
  assumes f[measurable]: "f  measurable M X"
  assumes N[measurable]: "N  measurable X (subprob_algebra K)" and "space M  {}"
  shows "(distr M X f  N) = (M  (λx. N (f x)))"
proof -
  have "space X  {}" "space M  {}"
    using space M  {} f[THEN measurable_space] by auto
  then show ?thesis
    by (simp add: bind_nonempty''[where N=K] distr_distr comp_def)
qed

(* Binding the counting measure on a singleton {x} with f gives f x, provided f x is
   a sub-probability space.
   The counting measure on {x} is first identified with the point measure
   return (count_space {x}) x (fact A: subsets of {x} are {} or {x}); then
   bind_return applies. *)
lemma bind_count_space_singleton:
  assumes "subprob_space (f x)"
  shows "count_space {x}  f = f x"
proof-
  have A: "A. A  {x}  A = {}  A = {x}" by auto
  have "count_space {x} = return (count_space {x}) x"
    by (intro measure_eqI) (auto dest: A)
  also have "...  f = f x"
    by (subst bind_return[of _ _ "f x"]) (auto simp: space_subprob_algebra assms)
  finally show ?thesis .
qed

(* Restriction commutes with bind on the result side: for N measurable from M into
   subprob_algebra K, non-empty space M, and a non-empty measurable set X of K,
   restricting the bind of M and N to X equals binding M with the kernel
   x -> restrict_space (N x) X.
   Proved by measure_eqI: equality of sets first, then a calculation on the
   emeasure of each measurable set A of the restricted space. *)
lemma restrict_space_bind:
  assumes N: "N  measurable M (subprob_algebra K)"
  assumes "space M  {}"
  assumes X[simp]: "X  sets K" "X  {}"
  shows "restrict_space (bind M N) X = bind M (λx. restrict_space (N x) X)"
proof (rule measure_eqI)
  note N_sets = sets_bind[OF sets_kernel[OF N] assms(2), simp]
  note N_space = sets_eq_imp_space_eq[OF N_sets, simp]
  show "sets (restrict_space (bind M N) X) = sets (bind M (λx. restrict_space (N x) X))"
    by (simp add: sets_restrict_space assms(2) sets_bind[OF sets_kernel[OF restrict_space_measurable[OF assms(4,3,1)]]])
  fix A assume "A  sets (restrict_space (M  N) X)"
  (* A measurable set of the restricted space is a measurable set of K contained in X. *)
  with X have A: "A  sets K" "A  X"
    by (auto simp: sets_restrict_space)
  then have "emeasure (restrict_space (M  N) X) A = emeasure (M  N) A"
    by (simp add: emeasure_restrict_space)
  also have " = + x. emeasure (N x) A M"
    by (metis A  sets K N space M  {} emeasure_bind)
  also have "... = + x. emeasure (restrict_space (N x) X) A M"
    using A assms by (smt (verit, best) emeasure_restrict_space nn_integral_cong sets.Int_space_eq2 subprob_measurableD(2))
  also have " = emeasure (M  (λx. restrict_space (N x) X)) A"
    using A assms
    apply (subst emeasure_bind[OF _ restrict_space_measurable])
    apply (auto simp: sets_restrict_space)
    done
  finally show "emeasure (restrict_space (M  N) X) A = emeasure (M  (λx. restrict_space (N x) X)) A" .
qed

(* Restriction on the source side of bind: binding restrict_space M A with f equals
   binding M with the kernel ?f that agrees with f on A and is a null measure
   outside A.  The null measure is taken over f evaluated at a SOME-chosen point
   ?x satisfying ?P, so that it has the sets of N.
   Assumptions A: two conditions on the intersection of A with space M (the first
   compares it with {}, the second relates it to sets M); f is measurable from
   restrict_space M A into subprob_algebra N. *)
lemma bind_restrict_space:
  assumes A: "A  space M  {}" "A  space M  sets M"
  and f: "f  measurable (restrict_space M A) (subprob_algebra N)"
  shows "restrict_space M A  f = M  (λx. if x  A then f x else null_measure (f (SOME x. x  A  x  space M)))"
  (is "?lhs = ?rhs" is "_ = M  ?f")
proof -
  let ?P = "λx. x  A  x  space M"
  let ?x = "Eps ?P"
  let ?c = "null_measure (f ?x)"
  (* The chosen point satisfies ?P, so f ?x is a sub-probability space with the sets of N. *)
  from A have "?P ?x" by-(rule someI_ex, blast)
  hence "?x  space (restrict_space M A)" by(simp add: space_restrict_space)
  with f have "f ?x  space (subprob_algebra N)" by(rule measurable_space)
  hence sps: "subprob_space (f ?x)"
    and sets: "sets (f ?x) = sets N"
    by(simp_all add: space_subprob_algebra)
  (* Hence the null measure ?c also lies in space (subprob_algebra N). *)
  have "space (f ?x)  {}" using sps by(rule subprob_space.subprob_not_empty)
  moreover have "sets ?c = sets N" by(simp add: null_measure_def)(simp add: sets)
  ultimately have c: "?c  space (subprob_algebra N)"
    by(simp add: space_subprob_algebra subprob_space_null_measure)
  (* The extended kernel ?f is measurable on all of M. *)
  from f A c have f': "?f  measurable M (subprob_algebra N)"
    by(simp add: measurable_restrict_space_iff)

  from A have [simp]: "space M  {}" by blast

  have "?lhs = join (distr (restrict_space M A) (subprob_algebra N) f)"
    using assms by(simp add: space_restrict_space bind_nonempty'')
  also have " = join (distr M (subprob_algebra N) ?f)"
    by(rule measure_eqI)(auto simp add: emeasure_join nn_integral_distr nn_integral_restrict_space f f' A intro: nn_integral_cong)
  also have " = ?rhs" using f' by(simp add: bind_nonempty'')
  finally show ?thesis .
qed

(* Binding a probability space M with a constant kernel returning the
   sub-probability space N gives N itself.  Proved by measure_eqI using
   emeasure_bind_const_prob_space. *)
lemma bind_const': "prob_space M; subprob_space N  M  (λx. N) = N"
  by (intro measure_eqI)
     (simp_all add: space_subprob_algebra prob_space.not_empty emeasure_bind_const_prob_space)

(* distr expressed through bind: for non-empty space M and f measurable from M to N,
   binding M with the composition of return N and f equals distr M N f.
   Steps: unfold bind_def; replace the SOME-based target algebra by subprob_algebra N
   (sets_return, subprob_algebra_cong); then split the composition with distr_distr
   and collapse with join_return'. *)
lemma bind_return_distr:
  assumes "space M  {}" "f  measurable M N"
  shows "bind M (return N  f) = distr M N f"
proof -
  have "bind M (return N  f)
      = join (distr M (subprob_algebra (return N (f (SOME x. x  space M)))) (return N  f))"
    by (simp add: Giry_Monad.bind_def assms)
  also have " = join (distr M (subprob_algebra N) (return N  f))"
    by (metis sets_return subprob_algebra_cong)
  also have " = distr M N f"
    by (metis assms(2) distr_distr join_return' return_measurable sets_distr)
  finally show ?thesis .
qed

(* Same statement as bind_return_distr with the kernel written as an explicit
   lambda abstraction instead of a function composition (comp_def). *)
lemma bind_return_distr':
  "space M  {}  f  measurable M N  bind M (λx. return N (f x)) = distr M N f"
  using bind_return_distr[of M f N] by (simp add: comp_def)

(* Associativity law of the monad: for f measurable from M into subprob_algebra N
   and g measurable from N into subprob_algebra R,
   bind (bind M f) g = bind M (x -> bind (f x) g).
   The non-empty case rewrites both sides into join / distr form and uses
   join_assoc and join_distr_distr; the calculation starts from the right-hand side,
   so the final ".." flips the equation.  The empty case is immediate from
   bind_empty. *)
lemma bind_assoc:
  fixes f :: "'a  'b measure" and g :: "'b  'c measure"
  assumes M1: "f  measurable M (subprob_algebra N)" and M2: "g  measurable N (subprob_algebra R)"
  shows "bind (bind M f) g = bind M (λx. bind (f x) g)"
proof (cases "space M = {}")
  assume [simp]: "space M  {}"
  (* Measurability into a subprob_algebra from a non-empty space forces the target
     spaces to be non-empty as well. *)
  from assms have [simp]: "space N  {}" "space R  {}"
      by (auto simp: measurable_empty_iff space_subprob_algebra_empty_iff)
  from assms have sets_fx[simp]: "x. x  space M  sets (f x) = sets N"
      by (simp add: sets_kernel)
  have ex_in: "A. A  {}  x. x  A" by blast
  (* Sets and spaces of f and g at the SOME-chosen points that occur when bind is
     unfolded with bind_nonempty. *)
  note sets_some[simp] = sets_kernel[OF M1 someI_ex[OF ex_in[OF space M  {}]]]
                         sets_kernel[OF M2 someI_ex[OF ex_in[OF space N  {}]]]
  note space_some[simp] = sets_eq_imp_space_eq[OF this(1)] sets_eq_imp_space_eq[OF this(2)]


  have *: "(λx. distr x (subprob_algebra R) g)  f  M M subprob_algebra (subprob_algebra R)"
    using M1 M2 measurable_comp measurable_distr by blast
  (* Right-hand side as a single join of a push-forward along a three-fold composition. *)
  have "bind M (λx. bind (f x) g) =
        join (distr M (subprob_algebra R) (join  (λx. (distr x (subprob_algebra R) g))  f))"
    by (simp add: sets_eq_imp_space_eq[OF sets_fx] bind_nonempty o_def
             cong: subprob_algebra_cong distr_cong)
  (* Split the composed push-forward into three nested distr applications. *)
  also have "distr M (subprob_algebra R) (join  (λx. (distr x (subprob_algebra R) g))  f) =
             distr (distr (distr M (subprob_algebra N) f)
                          (subprob_algebra (subprob_algebra R))
                          (λx. distr x (subprob_algebra R) g))
                   (subprob_algebra R) join"
    by (simp add: distr_distr M1 M2 measurable_distr measurable_join fun.map_comp *)
  also have "join ... = bind (bind M f) g"
      by (simp add: join_assoc join_distr_distr M2 bind_nonempty cong: subprob_algebra_cong)
  finally show ?thesis ..
qed (simp add: bind_empty)

(* A two-level do-block ending in g (h x y) equals the same do-block ending in
   return N (h x y), followed by a bind with g.
   Proved as a chain of three equalities: bind_assoc on the outer level,
   bind_assoc on the inner level, then bind_return to absorb `return … ⤜ g`. *)
lemma double_bind_assoc:
  assumes Mg: "g ∈ measurable N (subprob_algebra N')"
  assumes Mf: "f ∈ measurable M (subprob_algebra M')"
  assumes Mh: "case_prod h ∈ measurable (M ⨂⇩M M') N"
  shows "do {x ← M; y ← f x; g (h x y)} = do {x ← M; y ← f x; return N (h x y)} ⤜ g"
proof-
  have "do {x ← M; y ← f x; return N (h x y)} ⤜ g =
            do {x ← M; do {y ← f x; return N (h x y)} ⤜ g}"
    using Mh by (auto intro!: bind_assoc measurable_bind'[OF Mf] Mf Mg
                      measurable_compose[OF _ return_measurable] simp: measurable_split_conv)
  also from Mh have "⋀x. x ∈ space M ⟹ h x ∈ measurable M' N" by measurable
  hence "do {x ← M; do {y ← f x; return N (h x y)} ⤜ g} =
            do {x ← M; y ← f x; return N (h x y) ⤜ g}"
    apply (intro ballI bind_cong refl bind_assoc)
    apply (subst measurable_cong_sets[OF sets_kernel[OF Mf] refl], simp)
    apply (rule measurable_compose[OF _ return_measurable], auto intro: Mg)
    done
  also have "⋀x. x ∈ space M ⟹ space (f x) = space M'"
    by (intro sets_eq_imp_space_eq sets_kernel[OF Mf])
  with measurable_space[OF Mh]
    have "do {x ← M; y ← f x; return N (h x y) ⤜ g} = do {x ← M; y ← f x; g (h x y)}"
    by (intro ballI bind_cong bind_return[OF Mg]) (auto simp: space_pair_measure)
  finally show ?thesis ..
qed

(* Inside the prob_space locale, M itself is an element of the space of
   subprob_algebra M. Declared as a raw rule for the measurability prover. *)
lemma (in prob_space) M_in_subprob[measurable (raw)]: "M ∈ space (subprob_algebra M)"
  by (simp add: space_subprob_algebra) unfold_locales

(* For two probability spaces, the product measure M1 ⨂⇩M M2 can be written as a
   nested bind that draws x from M1, y from M2, and returns the pair (x, y).
   Proved via measure_eqI: equal sets, then equal emeasure on every measurable A. *)
lemma (in pair_prob_space) pair_measure_eq_bind:
  "(M1 ⨂⇩M M2) = (M1 ⤜ (λx. M2 ⤜ (λy. return (M1 ⨂⇩M M2) (x, y))))"
proof (rule measure_eqI)
  have ps_M2: "prob_space M2" by unfold_locales
  note return_measurable[measurable]
  show "sets (M1 ⨂⇩M M2) = sets (M1 ⤜ (λx. M2 ⤜ (λy. return (M1 ⨂⇩M M2) (x, y))))"
    by (simp_all add: M1.not_empty M2.not_empty)
  fix A assume [measurable]: "A ∈ sets (M1 ⨂⇩M M2)"
  show "emeasure (M1 ⨂⇩M M2) A = emeasure (M1 ⤜ (λx. M2 ⤜ (λy. return (M1 ⨂⇩M M2) (x, y)))) A"
    by (auto simp: M2.emeasure_pair_measure M1.not_empty M2.not_empty emeasure_bind[where N="M1 ⨂⇩M M2"]
             intro!: nn_integral_cong)
qed

(* Two nested binds over independent probability spaces M1 and M2 may be swapped,
   provided the kernel C is jointly measurable on the product space.
   The proof routes both sides through the product measure: rewrite the nested
   bind as a bind over M1 ⨂⇩M M2, swap the components with distr_pair_swap, and
   unfold again using the `swap` interpretation of pair_prob_space M2 M1. *)
lemma (in pair_prob_space) bind_rotate:
  assumes C[measurable]: "(λ(x, y). C x y) ∈ measurable (M1 ⨂⇩M M2) (subprob_algebra N)"
  shows "(M1 ⤜ (λx. M2 ⤜ (λy. C x y))) = (M2 ⤜ (λy. M1 ⤜ (λx. C x y)))"
proof -
  interpret swap: pair_prob_space M2 M1 by unfold_locales
  note measurable_bind[where N="M2", measurable]
  note measurable_bind[where N="M1", measurable]
  have [simp]: "M1 ∈ space (subprob_algebra M1)" "M2 ∈ space (subprob_algebra M2)"
    by (auto simp: space_subprob_algebra) unfold_locales
  have "(M1 ⤜ (λx. M2 ⤜ (λy. C x y))) =
    (M1 ⤜ (λx. M2 ⤜ (λy. return (M1 ⨂⇩M M2) (x, y)))) ⤜ (λ(x, y). C x y)"
    by (auto intro!: bind_cong simp: bind_return[where N=N] space_pair_measure bind_assoc[where N="M1 ⨂⇩M M2" and R=N])
  also have "… = (distr (M2 ⨂⇩M M1) (M1 ⨂⇩M M2) (λ(x, y). (y, x))) ⤜ (λ(x, y). C x y)"
    unfolding pair_measure_eq_bind[symmetric] distr_pair_swap[symmetric] ..
  also have "… = (M2 ⤜ (λx. M1 ⤜ (λy. return (M2 ⨂⇩M M1) (x, y)))) ⤜ (λ(y, x). C x y)"
    unfolding swap.pair_measure_eq_bind[symmetric]
    by (auto simp add: space_pair_measure M1.not_empty M2.not_empty bind_distr[OF _ C] intro!: bind_cong)
  also have "… = (M2 ⤜ (λy. M1 ⤜ (λx. C x y)))"
    by (auto intro!: bind_cong simp: bind_return[where N=N] space_pair_measure bind_assoc[where N="M2 ⨂⇩M M1" and R=N])
  finally show ?thesis .
qed

(* Right-unit law: binding M with return N gives M whenever M and N have the same
   sets. The proof distinguishes the empty and non-empty space cases. *)
lemma bind_return'': "sets M = sets N ⟹ M ⤜ return N = M"
   by (cases "space M = {}")
      (simp_all add: bind_empty space_empty[symmetric] bind_nonempty join_return'
                cong: subprob_algebra_cong)

(* In a probability space, the push-forward along a constant function (λx. c),
   with c in space N, is the point measure return N c. Declared [simp]. *)
lemma (in prob_space) distr_const[simp]:
  "c ∈ space N ⟹ distr M N (λx. c) = return N c"
  by (rule measure_eqI) (auto simp: emeasure_distr emeasure_space_1)

(* The point measure at x over count_space M equals the counting measure on M
   weighted by the indicator function of the singleton {x}.
   Proved by comparing the two measures via measure_eqI. *)
lemma return_count_space_eq_density:
    "return (count_space M) x = density (count_space M) (indicator {x})"
  by (rule measure_eqI)
     (auto simp: indicator_inter_arith[symmetric] emeasure_density split: split_indicator)

(* The null measure on M belongs to the space of subprob_algebra M exactly when
   space M is non-empty. Declared [simp]. *)
lemma null_measure_in_space_subprob_algebra [simp]:
  "null_measure M ∈ space (subprob_algebra M) ⟷ space M ≠ {}"
by(simp add: space_subprob_algebra subprob_space_null_measure_iff)

subsection ‹Giry monad on probability spaces›

(* prob_algebra K is subprob_algebra K restricted to those measures that are
   probability spaces. *)
definition prob_algebra :: "'a measure ⇒ 'a measure measure" where
  "prob_algebra K = restrict_space (subprob_algebra K) {M. prob_space M}"

(* Characterisation of the carrier of prob_algebra M: all probability spaces N
   whose sets coincide with those of M. *)
lemma space_prob_algebra: "space (prob_algebra M) = {N. sets N = sets M ∧ prob_space N}"
  unfolding prob_algebra_def by (auto simp: space_subprob_algebra space_restrict_space prob_space_imp_subprob_space)

(* For a fixed measurable set a, evaluating the (real-valued) measure of a is a
   measurable map from prob_algebra A to borel. Inherited from the corresponding
   subprob_algebra fact through the restriction. *)
lemma measurable_measure_prob_algebra[measurable]:
  "a ∈ sets A ⟹ (λM. Sigma_Algebra.measure M a) ∈ prob_algebra A →⇩M borel"
  unfolding prob_algebra_def by (intro measurable_restrict_space1 measurable_measure_subprob_algebra)

(* Destruction rule: a map measurable into prob_algebra M is also measurable into
   subprob_algebra M. *)
lemma measurable_prob_algebraD:
  "f ∈ N →⇩M prob_algebra M ⟹ f ∈ N →⇩M subprob_algebra M"
  unfolding prob_algebra_def measurable_restrict_space2_iff by auto

(* prob_algebra version of measure_measurable_subprob_algebra2: if the dependent
   set family A is measurable as a subset of the product and L is a measurable
   kernel into prob_algebra N, then x ↦ measure (L x) (A x) is Borel measurable. *)
lemma measure_measurable_prob_algebra2:
  "Sigma (space M) A ∈ sets (M ⨂⇩M N) ⟹ L ∈ M →⇩M prob_algebra N ⟹
    (λx. Sigma_Algebra.measure (L x) (A x)) ∈ borel_measurable M"
  using measure_measurable_subprob_algebra2[of M A N L] by (auto intro: measurable_prob_algebraD)

(* Introduction rule: a map measurable into subprob_algebra M whose values on
   space N are all probability spaces is measurable into prob_algebra M. *)
lemma measurable_prob_algebraI:
  "(⋀x. x ∈ space N ⟹ prob_space (f x)) ⟹ f ∈ N →⇩M subprob_algebra M ⟹ f ∈ N →⇩M prob_algebra M"
  unfolding prob_algebra_def by (intro measurable_restrict_space2) auto

(* Push-forward along a measurable f is a measurable map from prob_algebra M to
   prob_algebra N. After unfolding, measurability into the subprob_algebra is
   solved by the intro rules; the remaining goal is that distr preserves
   probability spaces. *)
lemma measurable_distr_prob_space:
  assumes f: "f ∈ M →⇩M N"
  shows "(λM'. distr M' N f) ∈ prob_algebra M →⇩M prob_algebra N"
  unfolding prob_algebra_def measurable_restrict_space2_iff
proof (intro conjI measurable_restrict_space1 measurable_distr f)
  show "(λM'. distr M' N f) ∈ space (restrict_space (subprob_algebra M) (Collect prob_space)) → Collect prob_space"
    using f by (auto simp: space_restrict_space space_subprob_algebra intro!: prob_space.prob_space_distr)
qed

(* return N is measurable from N into prob_algebra N. *)
lemma measurable_return_prob_space[measurable]: "return N ∈ N →⇩M prob_algebra N"
  by (rule measurable_prob_algebraI) (auto simp: prob_space_return)

(* Parametrised push-forward: if g is a measurable kernel into prob_algebra M and
   f is jointly measurable on L ⨂⇩M M, then x ↦ distr (g x) N (f x) is measurable
   into prob_algebra N. Declared as a raw measurability rule. *)
lemma measurable_distr_prob_space2[measurable (raw)]:
  assumes f: "g ∈ L →⇩M prob_algebra M" "(λ(x, y). f x y) ∈ L ⨂⇩M M →⇩M N"
  shows "(λx. distr (g x) N (f x)) ∈ L →⇩M prob_algebra N"
  unfolding prob_algebra_def measurable_restrict_space2_iff
proof (intro conjI measurable_restrict_space1 measurable_distr2[where M=M] f measurable_prob_algebraD)
  (* Remaining goal: every value of the map is a probability space. *)
  show "(λx. distr (g x) N (f x)) ∈ space L → Collect prob_space"
    using f subprob_measurableD[OF measurable_prob_algebraD[OF f(1)]]
    by (auto simp: measurable_restrict_space2_iff prob_algebra_def
             intro!: prob_space.prob_space_distr)
qed

(* Binding with a fixed kernel g preserves measurability into prob_algebra:
   if f : M → prob_algebra N and g : N → prob_algebra R are measurable, then so is
   x ↦ bind (f x) g into prob_algebra R. *)
lemma measurable_bind_prob_space:
  assumes f: "f ∈ M →⇩M prob_algebra N" and g: "g ∈ N →⇩M prob_algebra R"
  shows "(λx. bind (f x) g) ∈ M →⇩M prob_algebra R"
  unfolding prob_algebra_def measurable_restrict_space2_iff
proof (intro conjI measurable_restrict_space1 measurable_bind2[where N=N] f g measurable_prob_algebraD)
  (* Remaining goal: each bind result is a probability space. *)
  show "(λx. f x ⤜ g) ∈ space M → Collect prob_space"
    using g f subprob_measurableD[OF measurable_prob_algebraD[OF f]]
    by (auto simp: measurable_restrict_space2_iff prob_algebra_def
                intro!: prob_space.prob_space_bind[where S=R] AE_I2)
qed

(* Variant of measurable_bind_prob_space where the second kernel may also depend
   on the outer argument x; g must be jointly measurable on M ⨂⇩M N.
   Declared as a raw measurability rule. *)
lemma measurable_bind_prob_space2[measurable (raw)]:
  assumes f: "f ∈ M →⇩M prob_algebra N" and g: "(λ(x, y). g x y) ∈ (M ⨂⇩M N) →⇩M prob_algebra R"
  shows "(λx. bind (f x) (g x)) ∈ M →⇩M prob_algebra R"
  unfolding prob_algebra_def measurable_restrict_space2_iff
proof (intro conjI measurable_restrict_space1 measurable_bind[where N=N] f g measurable_prob_algebraD)
  (* Remaining goal: each bind result is a probability space. *)
  show "(λx. f x ⤜ g x) ∈ space M → Collect prob_space"
    using g f subprob_measurableD[OF measurable_prob_algebraD[OF f]]
      using measurable_space[OF g]
    by (auto simp: measurable_restrict_space2_iff prob_algebra_def space_pair_measure Pi_iff
                intro!: prob_space.prob_space_bind[where S=R] AE_I2)
qed (use g in simp)


(* Measurability into prob_algebra N can be checked on a generator: if sets N is
   generated by an Int-stable family G over Ω, every K a is a probability space
   with the sets of N, and a ↦ emeasure (K a) A is Borel measurable for all A in G,
   then K is measurable into prob_algebra N.
   Reduces to measurable_subprob_algebra_generated; the extra obligation for the
   whole space Ω holds because that emeasure is constantly 1. *)
lemma measurable_prob_algebra_generated:
  assumes eq: "sets N = sigma_sets Ω G" and "Int_stable G" "G ⊆ Pow Ω"
  assumes subsp: "⋀a. a ∈ space M ⟹ prob_space (K a)"
  assumes sets: "⋀a. a ∈ space M ⟹ sets (K a) = sets N"
  assumes "⋀A. A ∈ G ⟹ (λa. emeasure (K a) A) ∈ borel_measurable M"
  shows "K ∈ measurable M (prob_algebra N)"
  unfolding measurable_restrict_space2_iff prob_algebra_def
proof
  show "K ∈ M →⇩M subprob_algebra N"
  proof (rule measurable_subprob_algebra_generated[OF assms(1,2,3) _ assms(5,6)])
    fix a assume "a ∈ space M" then show "subprob_space (K a)"
      using subsp[of a] by (intro prob_space_imp_subprob_space)
  next
    have "(λa. emeasure (K a) Ω) ∈ borel_measurable M ⟷ (λa. 1::ennreal) ∈ borel_measurable M"
      using sets_eq_imp_space_eq[of "sigma Ω G" N] ‹G ⊆ Pow Ω› eq sets_eq_imp_space_eq[OF sets]
        prob_space.emeasure_space_1[OF subsp]
      by (intro measurable_cong) auto
    then show "(λa. emeasure (K a) Ω) ∈ borel_measurable M" by simp
  qed
qed (use subsp in auto)

(* Every element x of the space of prob_algebra M assigns emeasure 1 to space M. *)
lemma in_space_prob_algebra:
  "x ∈ space (prob_algebra M) ⟹ emeasure x (space M) = 1"
  unfolding prob_algebra_def space_restrict_space space_subprob_algebra
  by (auto dest!: prob_space.emeasure_space_1 sets_eq_imp_space_eq)

(* The product measure of two probability spaces is a probability space. *)
lemma prob_space_pair:
  assumes "prob_space M" "prob_space N" shows "prob_space (M ⨂⇩M N)"
  by (metis assms measurable_fst prob_space.distr_pair_fst prob_space_distrD)

(* The pointwise product of two measurable kernels into prob_algebra N and
   prob_algebra L is a measurable kernel into prob_algebra (N ⨂⇩M L). *)
lemma measurable_pair_prob[measurable]:
  "f ∈ M →⇩M prob_algebra N ⟹ g ∈ M →⇩M prob_algebra L ⟹ (λx. f x ⨂⇩M g x) ∈ M →⇩M prob_algebra (N ⨂⇩M L)"
  unfolding prob_algebra_def measurable_restrict_space2_iff
  by (auto intro!: measurable_pair_measure prob_space_pair)

(* emeasure of a bind, stated with prob_algebra hypotheses: for A in the space of
   prob_algebra N, a measurable kernel B into prob_algebra L and X measurable in L,
   emeasure (bind A B) X is the integral over A of x ↦ emeasure (B x) X.
   Specialises emeasure_bind. *)
lemma emeasure_bind_prob_algebra:
  assumes A: "A ∈ space (prob_algebra N)"
  assumes B: "B ∈ N →⇩M prob_algebra L"
  assumes X: "X ∈ sets L"
  shows "emeasure (bind A B) X = (∫⇧+x. emeasure (B x) X ∂A)"
  using A B
  by (intro emeasure_bind[OF _ _ X])
     (auto simp: space_prob_algebra measurable_prob_algebraD cong: measurable_cong_sets intro!: prob_space.not_empty)

(* Binding a probability measure A with a measurable kernel B into prob_algebra N
   gives a probability space. Derived by applying measurable_bind_prob_space to the
   constant map at A and reading off membership in the space of prob_algebra. *)
lemma prob_space_bind':
  assumes A: "A ∈ space (prob_algebra M)" and B: "B ∈ M →⇩M prob_algebra N" shows "prob_space (A ⤜ B)"
  using measurable_bind_prob_space[OF measurable_const, OF A B, THEN measurable_space, of undefined "count_space UNIV"]
  by (simp add: space_prob_algebra)

(* Under the same hypotheses as prob_space_bind', the bind A ⤜ B has the sets
   of N. Derived the same way, from measurable_bind_prob_space at a constant map. *)
lemma sets_bind':
  assumes A: "A ∈ space (prob_algebra M)" and B: "B ∈ M →⇩M prob_algebra N" shows "sets (A ⤜ B) = sets N"
  using measurable_bind_prob_space[OF measurable_const, OF A B, THEN measurable_space, of undefined "count_space UNIV"]
  by (simp add: space_prob_algebra)

(* Congruence for bind over prob_algebra: two measurable kernels f and g that
   agree almost everywhere w.r.t. M yield the same bind.
   Proved via measure_eqI, using sets_bind' for the sets and
   emeasure_bind_prob_algebra plus nn_integral_cong_AE for the emeasures. *)
lemma bind_cong_AE':
  assumes M: "M ∈ space (prob_algebra L)"
    and f: "f ∈ L →⇩M prob_algebra N" and g: "g ∈ L →⇩M prob_algebra N"
    and ae: "AE x in M. f x = g x"
  shows "bind M f = bind M g"
proof (rule measure_eqI)
  show "sets (M ⤜ f) = sets (M ⤜ g)"
    unfolding sets_bind'[OF M f] sets_bind'[OF M g] ..
  show "A ∈ sets (M ⤜ f) ⟹ emeasure (M ⤜ f) A = emeasure (M ⤜ g) A" for A
    unfolding sets_bind'[OF M f]
    using emeasure_bind_prob_algebra[OF M f, of A] emeasure_bind_prob_algebra[OF M g, of A] ae
    by (auto intro: nn_integral_cong_AE)
qed

(* A measure N on the full power set of a countable set A equals the counting
   measure on A weighted by f, provided f x is the emeasure of the singleton {x}
   for every x in A. *)
lemma density_discrete:
  "countable A ⟹ sets N = Set.Pow A ⟹ (⋀x. f x ≥ 0) ⟹ (⋀x. x ∈ A ⟹ f x = emeasure N {x}) ⟹
    density (count_space A) f = N"
  by (rule measure_eqI_countable[of _ A]) (auto simp: emeasure_density)

(* Push-forward of a density measure into a countable count_space, written as a
   density: the weight f x is the integral of f' over the preimage of x under g.
   Proved with density_discrete; the main step computes f x as the emeasure of
   the singleton {x} under the push-forward. *)
lemma distr_density_discrete:
  fixes f'
  assumes "countable A"
  assumes "f' ∈ borel_measurable M"
  assumes "g ∈ measurable M (count_space A)"
  defines "f ≡ λx. ∫⇧+t. (if g t = x then 1 else 0) * f' t ∂M"
  assumes "⋀x.  x ∈ space M ⟹ g x ∈ A"
  shows "density (count_space A) (λx. f x) = distr (density M f') (count_space A) g"
proof (rule density_discrete)
  fix x assume x: "x ∈ A"
  (* Rewrite the if-then-else weight as the indicator of the preimage g -` {x}. *)
  have "f x = ∫⇧+t. indicator (g -` {x} ∩ space M) t * f' t ∂M" (is "_ = ?I") unfolding f_def
    by (intro nn_integral_cong) (simp split: split_indicator)
  also from x have in_sets: "g -` {x} ∩ space M ∈ sets M"
    by (intro measurable_sets[OF assms(3)]) simp
  have "?I = emeasure (density M f') (g -` {x} ∩ space M)" unfolding f_def
    by (subst emeasure_density[OF assms(2) in_sets], subst mult.commute) (rule refl)
  also from assms(3) x have "... = emeasure (distr (density M f') (count_space A) g) {x}"
    by (subst emeasure_distr) simp_all
  finally show "f x = emeasure (distr (density M f') (count_space A) g) {x}" .
qed (use assms in auto)

(* Congruence for bind over subprob_algebra: if M = N and the measurable kernels
   f and g agree almost everywhere w.r.t. N, then bind M f = bind N g.
   Case split on whether space N is empty; the non-empty case compares sets and
   emeasures via measure_eqI. *)
lemma bind_cong_AE:
  assumes "M = N"
  assumes f: "f ∈ measurable N (subprob_algebra B)"
  assumes g: "g ∈ measurable N (subprob_algebra B)"
  assumes ae: "AE x in N. f x = g x"
  shows "bind M f = bind N g"
proof cases
  assume "space N = {}" then show ?thesis
    using ‹M = N› by (simp add: bind_empty)
next
  assume "space N ≠ {}"
  show ?thesis unfolding ‹M = N›
  proof (rule measure_eqI)
    have *: "sets (N ⤜ f) = sets B"
      using sets_bind[OF sets_kernel[OF f] ‹space N ≠ {}›] by simp
    then show "sets (N ⤜ f) = sets (N ⤜ g)"
      using sets_bind[OF sets_kernel[OF g] ‹space N ≠ {}›] by auto
    fix A assume "A ∈ sets (N ⤜ f)"
    then have "A ∈ sets B"
      unfolding * .
    with ae f g ‹space N ≠ {}› show "emeasure (N ⤜ f) A = emeasure (N ⤜ g) A"
      by (subst (1 2) emeasure_bind[where N=B]) (auto intro!: nn_integral_cong_AE)
  qed
qed

(* bind_cong phrased with the simplifier's =simp=> implication: the kernels only
   need to agree on space M. *)
lemma bind_cong_simp: "M = N ⟹ (⋀x. x∈space M =simp=> f x = g x) ⟹ bind M f = bind N g"
  by (auto simp: simp_implies_def intro!: bind_cong)

(* For a measurable kernel f into subprob_algebra B on a non-empty space M,
   the bind M ⤜ f has the sets of B. *)
lemma sets_bind_measurable:
  assumes f: "f ∈ measurable M (subprob_algebra B)"
  assumes M: "space M ≠ {}"
  shows "sets (M ⤜ f) = sets B"
  using M by (intro sets_bind[OF sets_kernel[OF f]]) auto

(* For a measurable kernel f into subprob_algebra B on a non-empty space M,
   the bind M ⤜ f has the space of B. *)
lemma space_bind_measurable:
  assumes f: "f ∈ measurable M (subprob_algebra B)"
  assumes M: "space M ≠ {}"
  shows "space (M ⤜ f) = space B"
  using M by (intro space_bind[OF sets_kernel[OF f]]) auto

(* Binding a push-forward measure with (return L ∘ g) collapses to a single
   push-forward along the composition x ↦ g (f x), for measurable f and g on a
   non-empty space M. *)
lemma bind_distr_return:
  "f ∈ M →⇩M N ⟹ g ∈ N →⇩M L ⟹ space M ≠ {} ⟹
    distr M N f ⤜ (λx. return L (g x)) = distr M L (λx. g (f x))"
  by (subst bind_distr[OF _ measurable_compose[OF _ return_measurable]])
     (auto intro!: bind_return_distr')

(* In a probability space, if almost every point equals y, then M is the point
   measure at y and y lies in space M. *)
lemma (in prob_space) AE_eq_constD:
  assumes "AE x in M. x = y"
  shows   "M = return M y" "y ∈ space M"
proof -
  have "AE x in M. x ∈ space M"
    by auto
  with assms have "AE x in M. y ∈ space M"
    by eventually_elim auto
  thus "y ∈ space M"
    by simp
  show "M = return M y"
  proof (rule measure_eqI)
    fix X assume X: "X ∈ sets M"
    (* Almost everywhere, membership in X agrees with membership in the whole
       space (if y ∈ X) or in the empty set (otherwise). *)
    have "AE x in M. (x ∈ X) = (x ∈ (if y ∈ X then space M else {}))"
      using assms by eventually_elim (use X ‹y ∈ space M› in auto)
    hence "emeasure M X = emeasure M (if y ∈ X then space M else {})"
      using X by (intro emeasure_eq_AE) auto
    also have "… = emeasure (return M y) X"
      using X by (auto simp: emeasure_space_1)
    finally show "emeasure M X = …" .
  qed auto
qed

end