Theory Bernoulli

(* Title: Bernoulli.thy
   Author: Andreas Lochbihler, ETH Zurich *)

section ‹Distributions built from coin flips›

subsection ‹ The Bernoulli distribution›

theory Bernoulli imports "HOL-Probability.Probability" begin

(* Numeral literals are strictly positive in every canonically ordered monoid of
   characteristic zero.  Proof idea: a numeral is never 0 (zero_neq_numeral), and in a
   canonically ordered monoid "not greater than 0" means "equal to 0" (not_gr_zero).
   Declared [simp] so that positivity of numeral literals is discharged automatically. *)
lemma zero_lt_num [simp]: "0 < (numeral n :: _ :: {canonically_ordered_monoid_add, semiring_char_0})"
  by (metis not_gr_zero zero_neq_numeral)

(* Moves a numeral factor inside the ennreal embedding.  The lemma has no sign
   condition on x; the proof goes through ennreal_mult''. *)
lemma ennreal_mult_numeral: "ennreal x * numeral n = ennreal (x * numeral n)"
  by (simp add: ennreal_mult'')

(* Adding 1 commutes with the ennreal embedding, provided the real summand is
   nonnegative (hypothesis 0 ≤ x). *)
lemma one_plus_ennreal: "0 ≤ x ⟹ 1 + ennreal x = ennreal (1 + x)"
by simp

text ‹
  We define the Bernoulli distribution as a least fixpoint instead of a loop because this
  avoids the need to add a condition flag to the distribution, which we would have to project
  out at the end again.  As the direct termination proof is so simple, we do not bother to prove
  it equivalent to a while loop.
›

(* bernoulli p samples a Boolean from fair coin flips; for 0 ≤ p ≤ 1 it is True with
   probability p (lemma bernoulli_True below).  Each round flips a coin:
   - heads: return the leading binary digit of p, i.e. whether p ≥ 1 / 2;
   - tails: shift that digit out (2 * p, or 2 * p - 1 when p ≥ 1 / 2) and recurse.
   Arguments outside [0, 1] saturate (see bernoulli_neg and bernoulli_pos). *)
partial_function (spmf) bernoulli :: "real ⇒ bool spmf" where
  "bernoulli p = do {
     b ← coin_spmf;
     if b then return_spmf (p ≥ 1 / 2)
     else if p < 1 / 2 then bernoulli (2 * p)
     else bernoulli (2 * p - 1)
   }"

(* bernoulli terminates almost surely: the probability of non-termination (None) is 0.
   Each coin flip stops the recursion with probability 1 / 2, so by induction on n the
   None-mass is at most 1 / 2 ^ n.  The infimum of these bounds is the limit of the
   decreasing sequence 1 / 2 ^ n, which is 0. *)
lemma pmf_bernoulli_None: "pmf (bernoulli p) None = 0"
proof -
  have "ereal (pmf (bernoulli p) None) ≤ (INF n∈UNIV. ereal (1 / 2 ^ n))"
  proof(rule INF_greatest)
    show "ereal (pmf (bernoulli p) None) ≤ ereal (1 / 2 ^ n)" for n
    proof(induction n arbitrary: p)
      case (Suc n)
      (* Unfold one step: the heads branch returns a value (no None-mass), the tails
         branch (probability 1 / 2) recurses on 2 * p or 2 * p - 1, both covered by the
         induction hypothesis since p is arbitrary. *)
      show ?case using Suc.IH[of "2 * p"] Suc.IH[of "2 * p - 1"]
        by(subst bernoulli.simps)(simp add: UNIV_bool max_def field_simps spmf_of_pmf_pmf_of_set[symmetric] pmf_bind_pmf_of_set ennreal_pmf_bind nn_integral_pmf_of_set del: spmf_of_pmf_pmf_of_set)
      (* Base case n = 0: the bound is 1, which any probability satisfies. *)
    qed(simp add: pmf_le_1)
  qed
  (* The dots below stand for the infimum (INF n∈UNIV. ereal (1 / 2 ^ n)) from above;
     it equals 0 because a decreasing sequence converges to its infimum and also
     converges to 0, and limits are unique. *)
  also have "… = ereal 0"
  proof(rule LIMSEQ_unique)
    show "(λn. ereal (1 / 2 ^ n)) ⇢ …" by(rule LIMSEQ_INF)(simp add: field_simps decseq_SucI)
    show "(λn. ereal (1 / 2 ^ n)) ⇢ ereal 0" by(simp add: LIMSEQ_divide_realpow_zero)
  qed
  finally show ?thesis by simp
qed

(* bernoulli always returns a value: losslessness is equivalent to zero None-mass,
   which is pmf_bernoulli_None.  Declared [simp]. *)
lemma lossless_bernoulli [simp]: "lossless_spmf (bernoulli p)"
by(simp add: lossless_iff_pmf_None pmf_bernoulli_None)

(* Correctness of bernoulli: for 0 ≤ p ≤ 1 it returns True with probability p and
   False with probability 1 - p.  Both facts are declared [simp]. *)
lemma [simp]: assumes "0 ≤ p" "p ≤ 1"
  shows bernoulli_True: "spmf (bernoulli p) True = p" (is ?True)
  and bernoulli_False: "spmf (bernoulli p) False = 1 - p" (is ?False)
proof -
  (* Fixpoint induction yields only upper bounds on both probabilities; the bottom
     case is closed by simp and the admissibility side condition by cont_intro. *)
  { have "ennreal (spmf (bernoulli p) b) ≤ ennreal (if b then p else 1 - p)" for b using assms
    proof(induction arbitrary: p rule: bernoulli.fixp_induct[case_names adm bottom step])
      case adm show ?case by(rule cont_intro)+
    next
      case (step bernoulli' p)
      show ?case using step.prems step.IH[of "2 * p"] step.IH[of "2 * p - 1"]
        by(auto simp add: UNIV_bool max_def divide_le_posI_ennreal ennreal_mult_numeral numeral_mult_ennreal field_simps spmf_of_pmf_pmf_of_set[symmetric] ennreal_pmf_bind nn_integral_pmf_of_set one_plus_ennreal simp del: spmf_of_pmf_pmf_of_set ennreal_plus)
    qed simp }
  note this[of True] this[of False]
  (* Since bernoulli is lossless, the two probabilities sum to 1, so neither upper
     bound can be strict. *)
  moreover have "spmf (bernoulli p) True + spmf (bernoulli p) False = 1"
    by(simp add: spmf_False_conv_True)
  ultimately show ?True ?False using assms by(auto simp add: ennreal_le_iff2)
qed

(* For p ≤ 0, bernoulli p is deterministically False.  Since p < 1 / 2, only the
   bernoulli (2 * p) branch is taken, and 2 * p ≤ 0 again; fixpoint induction shows
   bernoulli p lies below return_spmf False in the spmf order.  Hence True has
   probability 0 and, by losslessness, False has probability 1. *)
lemma bernoulli_neg [simp]:
  assumes "p ≤ 0"
  shows "bernoulli p = return_spmf False"
proof -
  from assms have "ord_spmf (=) (bernoulli p) (return_spmf False)"
  proof(induction arbitrary: p rule: bernoulli.fixp_induct[case_names adm bottom step])
    case (step bernoulli' p)
    show ?case using step.prems step.IH[of "2 * p"]
      by(auto simp add: ord_spmf_return_spmf2 set_bind_spmf bind_UNION field_simps)
  qed simp_all
  from ord_spmf_eq_leD[OF this, of True] have "spmf (bernoulli p) True = 0" by simp
  moreover then have "spmf (bernoulli p) False = 1" by(simp add: spmf_False_conv_True)
  ultimately show ?thesis by(auto intro: spmf_eqI split: split_indicator)
qed

(* For 1 ≤ p, bernoulli p is deterministically True.  Since p ≥ 1 / 2, only the
   bernoulli (2 * p - 1) branch is taken, and 2 * p - 1 ≥ 1 again; fixpoint induction
   shows bernoulli p lies below return_spmf True in the spmf order.  Hence False has
   probability 0 and, by losslessness, True has probability 1. *)
lemma bernoulli_pos [simp]:
  assumes "1 ≤ p"
  shows "bernoulli p = return_spmf True"
proof -
  from assms have "ord_spmf (=) (bernoulli p) (return_spmf True)"
  proof(induction arbitrary: p rule: bernoulli.fixp_induct[case_names adm bottom step])
    case (step bernoulli' p)
    show ?case using step.prems step.IH[of "2 * p - 1"]
      by(auto simp add: ord_spmf_return_spmf2 set_bind_spmf bind_UNION field_simps)
  qed simp_all
  from ord_spmf_eq_leD[OF this, of False] have "spmf (bernoulli p) False = 0" by simp
  moreover then have "spmf (bernoulli p) True = 1" by(simp add: spmf_False_conv_True)
  ultimately show ?thesis by(auto intro: spmf_eqI split: split_indicator)
qed

(* bernoulli coincides with the library's bernoulli_pmf, embedded via spmf_of_pmf.
   The pmf_as_function interpretation is opened only inside this anonymous context,
   presumably to make its transfer setup available to the transfer step below.
   The max_def/min_def unfolding suggests bernoulli_pmf clamps p to [0, 1], which
   matches bernoulli_neg and bernoulli_pos -- NOTE(review): confirm against HOL-Probability. *)
context begin interpretation pmf_as_function .
lemma bernoulli_eq_bernoulli_pmf:
  "bernoulli p = spmf_of_pmf (bernoulli_pmf p)"
by(rule spmf_eqI; simp)(transfer; auto simp add: max_def min_def)
end

end