(* Theory While_SPMF *)

(* Title: While_SPMF.thy
   Author: Andreas Lochbihler, ETH Zurich *)

theory While_SPMF imports
  MFMC_Countable.Rel_PMF_Characterisation
  "HOL-Types_To_Sets.Types_To_Sets"
  "HOL-Library.Complete_Partial_Order2"
begin

text ‹
  This theory defines a probabilistic while combinator for discrete (sub-)probabilities and
  formalises rules for probabilistic termination similar to those by Hurd @{cite "Hurd2002TPHOLs"}
  and McIver and Morgan @{cite "McIverMorgan2005"}.
›

section ‹Miscellaneous library additions›

(* Lift a function returning sets of options to option inputs:
   None is mapped to {None}, Some x to f x. *)
fun map_option_set :: "('a ⇒ 'b option set) ⇒ 'a option ⇒ 'b option set"
where
  "map_option_set f None = {None}"
| "map_option_set f (Some x) = f x"

(* Membership of None / Some y in map_option_set, by case distinction on the argument. *)
lemma None_in_map_option_set:
  "None ∈ map_option_set f x ⟷ None ∈ Set.bind (set_option x) f ∨ x = None"
by(cases x) simp_all

lemma None_in_map_option_set_None [intro!]: "None ∈ map_option_set f None"
by simp

lemma None_in_map_option_set_Some [intro!]: "None ∈ f x ⟹ None ∈ map_option_set f (Some x)"
by simp

lemma Some_in_map_option_set [intro!]: "Some y ∈ f x ⟹ Some y ∈ map_option_set f (Some x)"
by simp

(* Interaction of map_option_set with Option.bind; all proved by cases on the option. *)
lemma map_option_set_singleton [simp]: "map_option_set (λx. {f x}) y = {Option.bind y f}"
by(cases y) simp_all

lemma Some_eq_bind_conv: "Some y = Option.bind x f ⟷ (∃z. x = Some z ∧ f z = Some y)"
by(cases x) auto

lemma map_option_set_bind: "map_option_set f (Option.bind x g) = map_option_set (map_option_set f ∘ g) x"
by(cases x) simp_all

lemma Some_in_map_option_set_conv: "Some y ∈ map_option_set f x ⟷ (∃z. x = Some z ∧ Some y ∈ f z)"
by(cases x) auto


(* Interpret the rel_spmf_characterisation locale, discharging its assumption with
   rel_pmf_measureI (presumably provided by the imported
   MFMC_Countable.Rel_PMF_Characterisation -- confirm). Afterwards the unqualified
   name rel_pmf_measureI is hidden; (open) keeps the qualified name accessible. *)
interpretation rel_spmf_characterisation by unfold_locales(rule rel_pmf_measureI)
hide_fact (open) rel_pmf_measureI

(* Sup and (≤) on functions coincide with the pointwise constructions fun_lub / fun_ord
   used by the ccpo machinery; mcont_emeasure_spmf' unfolds with these two facts. *)
lemma Sup_conv_fun_lub: "Sup = fun_lub Sup"
  by(auto simp add: Sup_fun_def fun_eq_iff fun_lub_def intro: arg_cong[where f=Sup])

lemma le_conv_fun_ord: "(≤) = fun_ord (≤)"
  by(auto simp add: fun_eq_iff fun_ord_def le_fun_def)

(* Parallel fixpoint induction where the first function is curried (two arguments)
   and the second takes one argument: parallel_fixp_induct_uc instantiated with
   case_prod/curry for the first and the identity for the second, with the induction
   predicate stated on the curried function (P (curry f) g). *)
lemmas parallel_fixp_induct_2_1 = parallel_fixp_induct_uc[
  of _ _ _ _ "case_prod" _ "curry" "λx. x" _ "λx. x",
  where P="λf g. P (curry f) g",
  unfolded case_prod_curry curry_case_prod curry_K,
  OF _ _ _ _ _ _ refl refl]
  for P

(* Pairing two monotone / continuous / mcont functions yields a function of the same
   kind into the product order (rel_prod) with the product lub (prod_lub). *)
lemma monotone_Pair:
  "⟦ monotone ord orda f; monotone ord ordb g ⟧
  ⟹ monotone ord (rel_prod orda ordb) (λx. (f x, g x))"
by(simp add: monotone_def)

lemma cont_Pair:
  "⟦ cont lub ord luba orda f; cont lub ord lubb ordb g ⟧
  ⟹ cont lub ord (prod_lub luba lubb) (rel_prod orda ordb) (λx. (f x, g x))"
by(rule contI)(auto simp add: prod_lub_def image_image dest!: contD)

lemma mcont_Pair:
  "⟦ mcont lub ord luba orda f; mcont lub ord lubb ordb g ⟧
  ⟹ mcont lub ord (prod_lub luba lubb) (rel_prod orda ordb) (λx. (f x, g x))"
by(rule mcontI)(simp_all add: monotone_Pair mcont_mono cont_Pair)

(* emeasure ∘ measure_spmf is monotone and continuous from the spmf ccpo
   (ord_spmf (=), lub_spmf) into the function order (≤, Sup). The [THEN lfp.…] and
   cont_intro attributes register the lifted forms for fixpoint/continuity proofs. *)
lemma mono2mono_emeasure_spmf [THEN lfp.mono2mono]:
  shows monotone_emeasure_spmf:
  "monotone (ord_spmf (=)) (≤) (λp. emeasure (measure_spmf p))"
  by(rule monotoneI le_funI ord_spmf_eqD_emeasure)+

lemma cont_emeasure_spmf: "cont lub_spmf (ord_spmf (=)) Sup (≤) (λp. emeasure (measure_spmf p))"
  by (rule contI) (simp add: emeasure_lub_spmf fun_eq_iff image_comp)

lemma mcont2mcont_emeasure_spmf [THEN lfp.mcont2mcont, cont_intro]:
  shows mcont_emeasure_spmf: "mcont lub_spmf (ord_spmf (=)) Sup (≤) (λp. emeasure (measure_spmf p))"
  by(simp add: mcont_def monotone_emeasure_spmf cont_emeasure_spmf)

(* Pointwise version for a fixed event A, obtained from the function-valued version
   via fun_lub/fun_ord and mcont_fun_lub_apply. *)
lemma mcont2mcont_emeasure_spmf' [THEN lfp.mcont2mcont, cont_intro]:
  shows mcont_emeasure_spmf': "mcont lub_spmf (ord_spmf (=)) Sup (≤) (λp. emeasure (measure_spmf p) A)"
  using mcont_emeasure_spmf[unfolded Sup_conv_fun_lub le_conv_fun_ord]
  by(subst (asm) mcont_fun_lub_apply) blast

(* Continuity of bind_pmf in the continuation, derived from mcont_bind_spmf by viewing
   p as the lossless spmf spmf_of_pmf p. *)
lemma mcont_bind_pmf [cont_intro]:
  assumes g: "⋀y. mcont luba orda lub_spmf (ord_spmf (=)) (g y)"
  shows "mcont luba orda lub_spmf (ord_spmf (=)) (λx. bind_pmf p (λy. g y x))"
using mcont_bind_spmf[where f="λ_. spmf_of_pmf p" and g=g, OF _ assms] by(simp)

(* An extended non-negative real is below ⊤ exactly when it differs from ⊤. *)
lemma ennreal_less_top_iff: "x < ⊤ ⟷ x ≠ (⊤ :: ennreal)"
  by(cases x) simp_all

(* For the correspondence relation T of a type definition with carrier A,
   the domain of T is exactly A; used to produce transfer_domain_rule facts. *)
lemma type_definition_Domainp: 
  fixes Rep Abs A T
  assumes type: "type_definition Rep Abs A"
  assumes T_def: "T ≡ (λ(x::'a) (y::'b). x = Rep y)"
  shows "Domainp T = (λx. x ∈ A)"
proof -
  interpret type_definition Rep Abs A by(rule type)
  show ?thesis unfolding Domainp_iff[abs_def] T_def fun_eq_iff by(metis Abs_inverse Rep)
qed

(* Transfer rules: weight_spmf and lossless_spmf are invariant under rel_spmf A,
   and UNIV is related to itself under rel_pred for any relation. *)
context includes lifting_syntax begin

lemma weight_spmf_parametric [transfer_rule]:
  "(rel_spmf A ===> (=)) weight_spmf weight_spmf"
by(simp add: rel_fun_def rel_spmf_weightD)

lemma lossless_spmf_parametric [transfer_rule]:
  "(rel_spmf A ===> (=)) lossless_spmf lossless_spmf"
by(simp add: rel_fun_def lossless_spmf_def rel_spmf_weightD)

lemma UNIV_parametric_pred: "rel_pred R UNIV UNIV"
  by(auto intro!: rel_predI)
end

(* For a finite non-empty set, binding on spmf_of_set behaves like bind_pmf on pmf_of_set. *)
lemma bind_spmf_spmf_of_set:
  "⋀A. ⟦ finite A; A ≠ {} ⟧ ⟹ bind_spmf (spmf_of_set A) = bind_pmf (pmf_of_set A)"
by(simp add: spmf_of_set_def fun_eq_iff del: spmf_of_pmf_pmf_of_set)

(* Support of bind_spmf at the pmf level: None in the support of M stays None,
   Some x contributes the support of f x. *)
lemma set_pmf_bind_spmf: "set_pmf (bind_spmf M f) = set_pmf M ⤜ map_option_set (set_pmf ∘ f)"
by(auto 4 3 simp add: bind_spmf_def split: option.splits intro: rev_bexI)

(* Support of spmf_of_set at the pmf level; infinite or empty sets give {None}. *)
lemma set_pmf_spmf_of_set:
  "set_pmf (spmf_of_set A) = (if finite A ∧ A ≠ {} then Some ` A else {None})"
by(simp add: spmf_of_set_def spmf_of_pmf_def del: spmf_of_pmf_pmf_of_set)

(* measure (measure_spmf p) as a named constant of type 'a spmf ⇒ 'a set ⇒ real,
   so that a transfer rule can be stated for it (see the lemma below). *)
definition measure_measure_spmf :: "'a spmf ⇒ 'a set ⇒ real"
where [simp]: "measure_measure_spmf p = measure (measure_spmf p)"

lemma measure_measure_spmf_parametric [transfer_rule]:
  includes lifting_syntax shows
  "(rel_spmf A ===> rel_pred A ===> (=)) measure_measure_spmf measure_measure_spmf"
unfolding measure_measure_spmf_def[abs_def] by(rule measure_spmf_parametric)

(* Small arithmetic facts. *)
lemma of_nat_le_one_cancel_iff [simp]:
  fixes n :: nat shows "real n ≤ 1 ⟷ n ≤ 1"
by linarith

lemma of_int_ceiling_less_add_one [simp]: "of_int ⌈r⌉ < r + 1"
  by linarith

lemma lessThan_subset_Collect: "{..<x} ⊆ Collect P ⟷ (∀y<x. P y)"
  by(auto simp add: lessThan_def)

(* If f bounds spmf p pointwise from above and the integral of f over all points
   equals the weight of p, then the bound is attained at every point. *)
lemma spmf_ub_tight:
  assumes ub: "⋀x. spmf p x ≤ f x"
  and sum: "(∫⇧+ x. f x ∂count_space UNIV) = weight_spmf p"
  shows "spmf p x = f x"
proof -
  have [rule_format]: "∀x. f x ≤ spmf p x"
  proof(rule ccontr)
    (* A strict inequality at some x would push the total integral of f above the weight. *)
    assume "¬ ?thesis"
    then obtain x where x: "spmf p x < f x" by(auto simp add: not_le)
    have *: "(∫⇧+ y. ennreal (f y) * indicator (- {x}) y ∂count_space UNIV) ≠ ⊤"
      by(rule neq_top_trans[where y="weight_spmf p"], simp)(auto simp add: sum[symmetric] intro!: nn_integral_mono split: split_indicator)
      
    have "weight_spmf p = ∫⇧+ y. spmf p y ∂count_space UNIV"
      by(simp add: nn_integral_spmf space_measure_spmf measure_spmf.emeasure_eq_measure)
    also have "… = (∫⇧+ y. ennreal (spmf p y) * indicator (- {x}) y ∂count_space UNIV) +
      (∫⇧+ y. spmf p y * indicator {x} y ∂count_space UNIV)"
      by(subst nn_integral_add[symmetric])(auto intro!: nn_integral_cong split: split_indicator)
    also have "… ≤ (∫⇧+ y. ennreal (f y) * indicator (- {x}) y ∂count_space UNIV) + spmf p x"
      using ub by(intro add_mono nn_integral_mono)(auto split: split_indicator intro: ennreal_leI)
    also have "… < (∫⇧+ y. ennreal (f y) * indicator (- {x}) y ∂count_space UNIV) + (∫⇧+ y. f y * indicator {x} y ∂count_space UNIV)"
      using * x by(simp add: ennreal_less_iff)
    also have "… = (∫⇧+ y. ennreal (f y) ∂count_space UNIV)"
      by(subst nn_integral_add[symmetric])(auto intro: nn_integral_cong split: split_indicator)
    also have "… = weight_spmf p" using sum by simp
    finally show False by simp
  qed
  from this[of x] ub[of x] show ?thesis by simp
qed

section ‹Probabilistic while loop›

(* A probabilistic loop given by a guard and a body returning a sub-probability
   distribution over the next state. *)
locale loop_spmf = 
  fixes guard :: "'a ⇒ bool"
  and body :: "'a ⇒ 'a spmf"
begin

(* function_internals is presumably needed so that internal facts such as while_def
   (unfolded later as loop_spmf.while_def) are available -- confirm. *)
context notes [[function_internals]] begin

(* while runs body while guard holds; defined as a fixpoint in the spmf ccpo. *)
partial_function (spmf) while :: "'a ⇒ 'a spmf"
where "while s = (if guard s then bind_spmf (body s) while else return_spmf s)"

end

(* Fixpoint induction for while, with named cases for admissibility, the bottom
   element (return_pmf None) and the unfolding step. *)
lemma while_fixp_induct [case_names adm bottom step]:
  assumes "spmf.admissible P"
  and "P (λwhile. return_pmf None)"
  and "⋀while'. P while' ⟹ P (λs. if guard s then body s ⤜ while' else return_spmf s)"
  shows "P while"
  using assms by(rule while.fixp_induct)

(* One-step unfolding of while, split on the guard. *)
lemma while_simps:
  "guard s ⟹ while s = bind_spmf (body s) while"
  "¬ guard s ⟹ while s = return_spmf s"
by(rewrite while.simps; simp; fail)+

end

(* Parametricity of loop_spmf.while: guards and bodies related via S yield loops related
   via rel_spmf S. Proved by unfolding the fixpoint definition and applying
   fixp_spmf_parametric to the monotonicity facts of while. *)
lemma while_spmf_parametric [transfer_rule]:
  includes lifting_syntax shows
  "((S ===> (=)) ===> (S ===> rel_spmf S) ===> S ===> rel_spmf S) loop_spmf.while loop_spmf.while"
unfolding loop_spmf.while_def[abs_def]
apply(rule rel_funI)
apply(rule rel_funI)
apply(rule fixp_spmf_parametric[OF loop_spmf.while.mono loop_spmf.while.mono])
subgoal premises [transfer_rule] by transfer_prover
done

(* The loop depends on body only on states where the guard holds. *)
lemma loop_spmf_while_cong:
  "⟦ guard = guard'; ⋀s. guard' s ⟹ body s = body' s ⟧
  ⟹ loop_spmf.while guard body = loop_spmf.while guard' body'"
unfolding loop_spmf.while_def[abs_def] by(simp cong: if_cong)

section ‹Rules for probabilistic termination›

context loop_spmf begin

subsection ‹0/1 termination laws›

(* 0/1 law, immediate version: if from every guarded state a single body step reaches a
   state where the guard is false with probability at least p > 0, and the body is
   lossless on guarded states, then the loop terminates with probability 1. *)
lemma termination_0_1_immediate:
  assumes p: "⋀s. guard s ⟹ spmf (map_spmf guard (body s)) False ≥ p"
  and p_pos: "0 < p"
  and lossless: "⋀s. guard s ⟹ lossless_spmf (body s)"
  shows "lossless_spmf (while s)"
proof -
  have "∀s. lossless_spmf (while s)"
  proof(rule ccontr)
    assume "¬ ?thesis"
    then obtain s where s: "¬ lossless_spmf (while s)" by blast
    hence True: "guard s" by(simp add: while.simps split: if_split_asm)

    from p[OF this] have p_le_1: "p ≤ 1" using pmf_le_1 by(rule order_trans)
    (* If k lower-bounds the termination probability everywhere, so does p * (1 - k) + k. *)
    have new_bound: "p * (1 - k) + k ≤ weight_spmf (while s)" 
      if k: "0 ≤ k" "k ≤ 1" and k_le: "⋀s. k ≤ weight_spmf (while s)" for k s
    proof(cases "guard s")
      case False
      have "p * (1 - k) + k ≤ 1 * (1 - k) + k" using p_le_1 k by(intro mult_right_mono add_mono; simp)
      also have "… ≤ 1" by simp
      finally show ?thesis using False by(simp add: while.simps)
    next
      case True
      let ?M = "λs. measure_spmf (body s)"
      have bounded: "¦∫ s''. weight_spmf (while s'') ∂?M s'¦ ≤ 1" for s'
        using integral_nonneg_AE[of "λs''. weight_spmf (while s'')" "?M s'"]
        by(auto simp add: weight_spmf_nonneg weight_spmf_le_1 intro!: measure_spmf.nn_integral_le_const integral_real_bounded)
      have "p ≤ measure (?M s) {s'. ¬ guard s'}" using p[OF True]
        by(simp add: spmf_conv_measure_spmf measure_map_spmf vimage_def)
      hence "p * (1 - k) + k ≤ measure (?M s) {s'. ¬ guard s'} * (1 - k) + k"
        using k by(intro add_mono mult_right_mono)(simp_all)
      also have "… = ∫ s'. indicator {s'. ¬ guard s'} s' * (1 - k) + k ∂?M s"
        using True by(simp add: ennreal_less_top_iff lossless lossless_weight_spmfD)
      also have "… = ∫ s'. indicator {s'. ¬ guard s'} s' + indicator {s'. guard s'} s' * k ∂?M s"
        by(rule Bochner_Integration.integral_cong)(simp_all split: split_indicator)
      also have "… = ∫ s'. indicator {s'. ¬ guard s'} s' + indicator {s'. guard s'} s' * ∫ s''. k ∂?M s' ∂?M s"
        by(rule Bochner_Integration.integral_cong)(auto simp add: lossless lossless_weight_spmfD split: split_indicator)
      also have "… ≤ ∫ s'. indicator {s'. ¬ guard s'} s' + indicator {s'. guard s'} s' * ∫ s''. weight_spmf (while s'') ∂?M s' ∂?M s"
        using k bounded
        by(intro integral_mono integrable_add measure_spmf.integrable_const_bound[where B=1] add_mono mult_left_mono)
          (simp_all add: weight_spmf_nonneg weight_spmf_le_1 mult_le_one k_le split: split_indicator)
      also have "… = ∫s'. (if ¬ guard s' then 1 else ∫ s''. weight_spmf (while s'') ∂?M s') ∂?M s"
        by(rule Bochner_Integration.integral_cong)(simp_all split: split_indicator)
      also have "… = ∫ s'. weight_spmf (while s') ∂measure_spmf (body s)"
        by(rule Bochner_Integration.integral_cong; simp add: while.simps weight_bind_spmf o_def)
      also have "… = weight_spmf (while s)" using True
        by(simp add: while.simps weight_bind_spmf o_def)
      finally show ?thesis .
    qed

    (* k is the infimum of the termination probabilities; k < 1 by the choice of s,
       yet new_bound gives the strictly larger lower bound k' -- contradiction. *)
    define k where "k ≡ INF s. weight_spmf (while s)"
    define k' where "k' ≡ p * (1 - k) + k"
    from s have "weight_spmf (while s) < 1"
      using weight_spmf_le_1[of "while s"] by(simp add: lossless_spmf_def)
    then have "k < 1"
      unfolding k_def by(rewrite cINF_less_iff)(auto intro!: bdd_belowI2 weight_spmf_nonneg)

    have "0 ≤ k" unfolding k_def by(auto intro: cINF_greatest simp add: weight_spmf_nonneg)
    moreover from ‹k < 1› have "k ≤ 1" by simp
    moreover have "k ≤ weight_spmf (while s)" for s unfolding k_def
      by(rule cINF_lower)(auto intro!: bdd_belowI2 weight_spmf_nonneg)
    ultimately have "⋀s. k' ≤ weight_spmf (while s)"
      unfolding k'_def by(rule new_bound)
    hence "k' ≤ k" unfolding k_def by(auto intro: cINF_greatest)
    also have "k < k'" using p_pos ‹k < 1› by(auto simp add: k'_def)
    finally show False by simp
  qed
  thus ?thesis by blast
qed

(* iter n s unrolls the loop at most n times: it performs body steps while the guard
   holds and returns the current state after n steps, whether or not the guard still holds. *)
primrec iter :: "nat ⇒ 'a ⇒ 'a spmf"
where
  "iter 0 s = return_spmf s"
| "iter (Suc n) s = (if guard s then bind_spmf (body s) (iter n) else return_spmf s)"

(* Basic properties of the bounded unrolling iter. *)
lemma iter_unguarded [simp]: "¬ guard s ⟹ iter n s = return_spmf s"
  by(induction n) simp_all
  
lemma iter_bind_iter: "bind_spmf (iter m s) (iter n) = iter (m + n) s"
  by(induction m arbitrary: s) simp_all

(* Unrolling at the end instead of the front. *)
lemma iter_Suc2: "iter (Suc n) s = bind_spmf (iter n s) (λs. if guard s then body s else return_spmf s)"
  using iter_bind_iter[of n s 1, symmetric]
  by(simp del: iter.simps)(rule bind_spmf_cong; simp cong: bind_spmf_cong)

lemma lossless_iter: "(⋀s. guard s ⟹ lossless_spmf (body s)) ⟹ lossless_spmf (iter n s)"
  by(induction n arbitrary: s) simp_all

(* The probability that the unrolling has reached a state where the guard is false
   does not decrease when one more iteration is allowed. *)
lemma iter_mono_emeasure1:
  "emeasure (measure_spmf (iter n s)) {s. ¬ guard s} ≤ emeasure (measure_spmf (iter (Suc n) s)) {s. ¬ guard s}"
  (is "?lhs ≤ ?rhs")
proof(cases "guard s")
  case True
  have "?lhs = emeasure (measure_spmf (bind_spmf (iter n s) return_spmf)) {s. ¬ guard s}" by simp
  also have "… = ∫⇧+ s'. emeasure (measure_spmf (return_spmf s')) {s. ¬ guard s} ∂measure_spmf (iter n s)"
    by(simp del: bind_return_spmf add: measure_spmf_bind o_def emeasure_bind[where N="measure_spmf _"] space_measure_spmf Pi_def space_subprob_algebra)
  also have "… ≤ ∫⇧+ s'. emeasure (measure_spmf (if guard s' then body s' else return_spmf s')) {s. ¬ guard s} ∂measure_spmf (iter n s)"
    by(rule nn_integral_mono)(simp add: measure_spmf_return_spmf)
  also have "… = ?rhs"
    by(simp add: iter_Suc2 measure_spmf_bind o_def emeasure_bind[where N="measure_spmf _"] space_measure_spmf Pi_def space_subprob_algebra del: iter.simps)
  finally show ?thesis .
qed simp

(* The termination probability of the loop is the supremum over n of the probability
   that the n-step unrolling iter n s ends in a state where the guard is false. *)
lemma weight_while_conv_iter:
  "weight_spmf (while s) = (SUP n. measure (measure_spmf (iter n s)) {s. ¬ guard s})"
  (is "?lhs = ?rhs")
proof(rule antisym)
  (* ≤: fixpoint induction over while, working in ennreal. *)
  have "emeasure (measure_spmf (while s)) UNIV ≤ (SUP n. emeasure (measure_spmf (iter n s)) {s. ¬ guard s})"
    (is "_ ≤ (SUP n. ?f n s)")
  proof(induction arbitrary: s rule: while_fixp_induct)
    case adm show ?case by simp
    case bottom show ?case by simp
    case (step while')
    show ?case (is "?lhs' ≤ ?rhs'")
    proof(cases "guard s")
      case True
      have inc: "incseq ?f" by(rule incseq_SucI le_funI iter_mono_emeasure1)+

      from True have "?lhs' = ∫⇧+ s'. emeasure (measure_spmf (while' s')) UNIV ∂measure_spmf (body s)"
        by(simp add: measure_spmf_bind o_def emeasure_bind[where N="measure_spmf _"] space_measure_spmf Pi_def space_subprob_algebra)
      also have "… ≤ ∫⇧+ s'. (SUP n. ?f n s') ∂measure_spmf (body s)"
        by(rule nn_integral_mono)(rule step.IH)
      also have "… = (SUP n. ∫⇧+ s'. ?f n s' ∂measure_spmf (body s))" using inc
        by(subst nn_integral_monotone_convergence_SUP) simp_all
      also have "… = (SUP n. ?f (Suc n) s)" using True
        by(simp add: measure_spmf_bind o_def emeasure_bind[where N="measure_spmf _"] space_measure_spmf Pi_def space_subprob_algebra)
      also have "… ≤ (SUP n. ?f n s)"
        by(rule SUP_mono)(auto intro: exI[where x="Suc _"])
      finally show ?thesis .
    next
      case False
      then have "?lhs' = emeasure (measure_spmf (iter 0 s)) {s. ¬ guard s}" 
        by(simp add: measure_spmf_return_spmf)
      also have "… ≤ ?rhs'" by(rule SUP_upper) simp
      finally show ?thesis .
    qed
  qed
  also have "… = ennreal (SUP n. measure (measure_spmf (iter n s)) {s. ¬ guard s})"
    by(subst ennreal_SUP)(fold measure_spmf.emeasure_eq_measure, auto simp add: not_less measure_spmf.subprob_emeasure_le_1 intro!: exI[where x="1"])
  (* Non-negativity of the real supremum, needed to move back from ennreal to real. *)
  also have "0 ≤ (SUP n. measure (measure_spmf (iter n s)) {s. ¬ guard s})"
    by(rule cSUP_upper2)(auto intro!: bdd_aboveI[where M=1] simp add: measure_spmf.subprob_measure_le_1)
  ultimately show "?lhs ≤ ?rhs" by(simp add: measure_spmf.emeasure_eq_measure space_measure_spmf)
  
  (* ≥: each unrolling is bounded by the loop, by induction on n. *)
  show "?rhs ≤ ?lhs"
  proof(rule cSUP_least)
    show "measure (measure_spmf (iter n s)) {s. ¬ guard s} ≤ weight_spmf (while s)" (is "?f n s ≤ _") for n
    proof(induction n arbitrary: s)
      case 0 show ?case
        by(simp add: measure_spmf_return_spmf measure_return while_simps split: split_indicator)
    next
      case (Suc n)
      show ?case
      proof(cases "guard s")
        case True
        have "?f (Suc n) s = ∫⇧+ s'. ?f n s' ∂measure_spmf (body s)"
          using True unfolding measure_spmf.emeasure_eq_measure[symmetric]
          by(simp add: measure_spmf_bind o_def emeasure_bind[where N="measure_spmf _"] space_measure_spmf Pi_def space_subprob_algebra)
        also have "… ≤ ∫⇧+ s'. weight_spmf (while s') ∂measure_spmf (body s)"
          by(rule nn_integral_mono ennreal_leI Suc.IH)+
        also have "… = weight_spmf (while s)"
          using True unfolding measure_spmf.emeasure_eq_measure[symmetric] space_measure_spmf
          by(simp add: while_simps measure_spmf_bind o_def emeasure_bind[where N="measure_spmf _"] space_measure_spmf Pi_def space_subprob_algebra)
        finally show ?thesis by(simp)
      next
        case False then show ?thesis
          by(simp add: measure_spmf_return_spmf measure_return while_simps split: split_indicator)
      qed
    qed
  qed simp
qed

(* 0/1 law: if from every guarded state the loop terminates with probability at least
   p > 0 and the body is lossless on guarded states, the loop terminates almost surely. *)
lemma termination_0_1:
  assumes p: "⋀s. guard s ⟹ p ≤ weight_spmf (while s)"
    and p_pos: "0 < p"
    and lossless: "⋀s. guard s ⟹ lossless_spmf (body s)"
  shows "lossless_spmf (while s)"
  unfolding lossless_spmf_def
proof(rule antisym)
  let ?X = "{s. ¬ guard s}"
  show "weight_spmf (while s) ≤ 1" by(rule weight_spmf_le_1)
  
  define p' where "p' ≡ p / 2"
  have p'_pos: "p' > 0" and "p' < p" using p_pos by(simp_all add: p'_def)
  
  (* For each guarded state s, some number N s of unrollings already exits with
     probability at least p'. *)
  have "∃n. p' < measure (measure_spmf (iter n s)) ?X" if "guard s" for s using p[OF that] ‹p' < p›
    unfolding weight_while_conv_iter
    by(subst (asm) le_cSUP_iff)(auto intro!: measure_spmf.subprob_measure_le_1)
  then obtain N where p': "p' ≤ measure (measure_spmf (iter (N s) s)) ?X" if "guard s" for s
    using p by atomize_elim(rule choice, force dest: order.strict_implies_order)

  (* The fused loop takes N s ordinary iterations as a single step; it satisfies the
     premises of termination_0_1_immediate with bound p'. *)
  interpret fuse: loop_spmf guard "λs. iter (N s) s" .
  
  have "1 = weight_spmf (fuse.while s)"
    by(rule lossless_weight_spmfD[symmetric])
      (rule fuse.termination_0_1_immediate; auto simp add: spmf_map vimage_def intro: p' p'_pos lossless_iter lossless)
  also have "… ≤ (⨆n. measure (measure_spmf (iter n s)) ?X)"
    unfolding fuse.weight_while_conv_iter
  proof(rule cSUP_least)
    fix n
    have "emeasure (measure_spmf (fuse.iter n s)) ?X ≤ (SUP n. emeasure (measure_spmf (iter n s)) ?X)"
    proof(induction n arbitrary: s)
      case 0 show ?case by(auto intro!: SUP_upper2[where i=0])
    next
      case (Suc n)
      have inc: "incseq (λn s'. emeasure (measure_spmf (iter n s')) ?X)"
        by(rule incseq_SucI le_funI iter_mono_emeasure1)+

      have "emeasure (measure_spmf (fuse.iter (Suc n) s)) ?X = emeasure (measure_spmf (iter (N s) s ⤜ fuse.iter n)) ?X"
        by simp
      also have "… = ∫⇧+ s'. emeasure (measure_spmf (fuse.iter n s')) ?X ∂measure_spmf (iter (N s) s)"
        by(simp add: measure_spmf_bind o_def emeasure_bind[where N="measure_spmf _"] space_measure_spmf Pi_def space_subprob_algebra)
      also have "… ≤ ∫⇧+ s'. (SUP n. emeasure (measure_spmf (iter n s')) ?X) ∂measure_spmf (iter (N s) s)"
        by(rule nn_integral_mono Suc.IH)+
      also have "… = (SUP n. ∫⇧+ s'. emeasure (measure_spmf (iter n s')) ?X ∂measure_spmf (iter (N s) s))"
        by(rule nn_integral_monotone_convergence_SUP[OF inc]) simp
      also have "… = (SUP n. emeasure (measure_spmf (bind_spmf (iter (N s) s) (iter n))) ?X)"
        by(simp add: measure_spmf_bind o_def emeasure_bind[where N="measure_spmf _"] space_measure_spmf Pi_def space_subprob_algebra)
      also have "… = (SUP n. emeasure (measure_spmf (iter (N s + n) s)) ?X)" by(simp add: iter_bind_iter)
      also have "… ≤ (SUP n. emeasure (measure_spmf (iter n s)) ?X)" by(rule SUP_mono) auto
      finally show ?case .
    qed
    also have "… = ennreal (SUP n. measure (measure_spmf (iter n s)) ?X)"
      by(subst ennreal_SUP)(fold measure_spmf.emeasure_eq_measure, auto simp add: not_less measure_spmf.subprob_emeasure_le_1 intro!: exI[where x="1"])
    also have "0 ≤ (SUP n. measure (measure_spmf (iter n s)) ?X)"
      by(rule cSUP_upper2)(auto intro!: bdd_aboveI[where M=1] simp add: measure_spmf.subprob_measure_le_1)
    ultimately show "measure (measure_spmf (fuse.iter n s)) ?X ≤ …"
      by(simp add: measure_spmf.emeasure_eq_measure)
  qed simp
  finally show  "1 ≤ weight_spmf (while s)" unfolding weight_while_conv_iter .
qed

end

(* termination_0_1_immediate relativised to an invariant I that holds initially and is
   preserved by guarded body steps. Proved by moving to a type isomorphic to {s. I s}
   (Types_To_Sets: transfer, then cancel_type_definition). *)
lemma termination_0_1_immediate_invar:
  fixes I :: "'s ⇒ bool"
  assumes p: "⋀s. ⟦ guard s; I s ⟧ ⟹ spmf (map_spmf guard (body s)) False ≥ p"
  and p_pos: "0 < p"
  and lossless: "⋀s. ⟦ guard s; I s ⟧ ⟹ lossless_spmf (body s)"
  and invar: "⋀s s'. ⟦ s' ∈ set_spmf (body s); I s; guard s ⟧ ⟹ I s'"
  and I: "I s"
  shows "lossless_spmf (loop_spmf.while guard body s)"
  including lifting_syntax
proof -
  { assume "∃(Rep :: 's' ⇒ 's) Abs. type_definition Rep Abs {s. I s}"
    then obtain Rep :: "'s' ⇒ 's" and Abs where td: "type_definition Rep Abs {s. I s}" by blast
    then interpret td: type_definition Rep Abs "{s. I s}" .
    define cr where "cr ≡ λx y. x = Rep y"
    have [transfer_rule]: "bi_unique cr" "right_total cr" using td cr_def by(rule typedef_bi_unique typedef_right_total)+
    have [transfer_domain_rule]: "Domainp cr = I" using type_definition_Domainp[OF td cr_def] by simp

    define guard' where "guard' ≡ (Rep ---> id) guard"
    have [transfer_rule]: "(cr ===> (=)) guard guard'" by(simp add: rel_fun_def cr_def guard'_def)
    (* body1 agrees with body on guarded states (so loop_spmf_while_cong identifies the
       loops) and does not leave the invariant, so it can be transported to 's'. *)
    define body1 where "body1 ≡ λs. if guard s then body s else return_pmf None"
    define body1' where "body1' ≡ (Rep ---> map_spmf Abs) body1"
    have [transfer_rule]: "(cr ===> rel_spmf cr) body1 body1'"
      by(auto simp add: rel_fun_def body1'_def body1_def cr_def spmf_rel_map td.Rep[simplified] invar td.Abs_inverse intro!: rel_spmf_reflI)
    define s' where "s' ≡ Abs s"
    have [transfer_rule]: "cr s s'" by(simp add: s'_def cr_def I td.Abs_inverse)

    have "⋀s. guard' s ⟹ p ≤ spmf (map_spmf guard' (body1' s)) False"
      by(transfer fixing: p)(simp add: body1_def p)
    moreover note p_pos
    moreover have "⋀s. guard' s ⟹ lossless_spmf (body1' s)" by transfer(simp add: lossless body1_def)
    ultimately have "lossless_spmf (loop_spmf.while guard' body1' s')" by(rule loop_spmf.termination_0_1_immediate)
    hence "lossless_spmf (loop_spmf.while guard body1 s)" by transfer }
  from this[cancel_type_definition] I show ?thesis by(auto cong: loop_spmf_while_cong)
qed

(* termination_0_1 relativised to an invariant I that holds initially and is preserved
   by guarded body steps; same Types_To_Sets construction as
   termination_0_1_immediate_invar. *)
lemma termination_0_1_invar:
  fixes I :: "'s ⇒ bool"
  assumes p: "⋀s. ⟦ guard s; I s ⟧ ⟹ p ≤ weight_spmf (loop_spmf.while guard body s)"
    and p_pos: "0 < p"
    and lossless: "⋀s. ⟦ guard s; I s ⟧ ⟹ lossless_spmf (body s)"
    and invar: "⋀s s'. ⟦ s' ∈ set_spmf (body s); I s; guard s ⟧ ⟹ I s'"
    and I: "I s"
  shows "lossless_spmf (loop_spmf.while guard body s)"
  including lifting_syntax
proof-
  { assume "∃(Rep :: 's' ⇒ 's) Abs. type_definition Rep Abs {s. I s}"
    then obtain Rep :: "'s' ⇒ 's" and Abs where td: "type_definition Rep Abs {s. I s}" by blast
    then interpret td: type_definition Rep Abs "{s. I s}" .
    define cr where "cr ≡ λx y. x = Rep y"
    have [transfer_rule]: "bi_unique cr" "right_total cr" using td cr_def by(rule typedef_bi_unique typedef_right_total)+
    have [transfer_domain_rule]: "Domainp cr = I" using type_definition_Domainp[OF td cr_def] by simp

    define guard' where "guard' ≡ (Rep ---> id) guard"
    have [transfer_rule]: "(cr ===> (=)) guard guard'" by(simp add: rel_fun_def cr_def guard'_def)
    (* body restricted to guarded states; equal loops by loop_spmf_while_cong. *)
    define body1 where "body1 ≡ λs. if guard s then body s else return_pmf None"
    define body1' where "body1' ≡ (Rep ---> map_spmf Abs) body1"
    have [transfer_rule]: "(cr ===> rel_spmf cr) body1 body1'"
      by(auto simp add: rel_fun_def body1'_def body1_def cr_def spmf_rel_map td.Rep[simplified] invar td.Abs_inverse intro!: rel_spmf_reflI)
    define s' where "s' ≡ Abs s"
    have [transfer_rule]: "cr s s'" by(simp add: s'_def cr_def I td.Abs_inverse)
    
    interpret loop_spmf guard' body1' .

    note UNIV_parametric_pred[transfer_rule]
    have "⋀s. guard' s ⟹ p ≤ weight_spmf (while s)"
      unfolding measure_measure_spmf_def[symmetric] space_measure_spmf
      by(transfer fixing: p)(simp add: body1_def p[simplified space_measure_spmf] cong: loop_spmf_while_cong)
    moreover note p_pos
    moreover have "⋀s. guard' s ⟹ lossless_spmf (body1' s)" by transfer(simp add: lossless body1_def)
    ultimately have "lossless_spmf (while s')" by(rule termination_0_1)
    hence "lossless_spmf (loop_spmf.while guard body1 s)" by transfer }
  from this[cancel_type_definition] I show ?thesis by(auto cong: loop_spmf_while_cong)
qed

subsection ‹Variant rule›

context loop_spmf begin

(* Variant rule: if f is bounded by bound on guarded states and, from every guarded
   state, one body step decreases f with probability at least p > 0, and the body is
   lossless on guarded states, then the loop terminates almost surely. *)
lemma termination_variant:
  fixes bound :: nat
  assumes bound: "⋀s. guard s ⟹ f s ≤ bound"
  and step: "⋀s. guard s ⟹ p ≤ spmf (map_spmf (λs'. f s' < f s) (body s)) True"
  and p_pos: "0 < p"
  and lossless: "⋀s. guard s ⟹ lossless_spmf (body s)"
  shows "lossless_spmf (while s)"
proof -
  define p' and n where "p' ≡ min p 1" and "n ≡ bound + 1"
  have p'_pos: "0 < p'" and p'_le_1: "p' ≤ 1" 
    and step': "guard s ⟹ p' ≤ measure (measure_spmf (body s)) {s'. f s' < f s}" for s
    using p_pos step[of s] by(simp_all add: p'_def spmf_map vimage_def)
  (* From any state with f s < n the loop terminates with probability at least p' ^ n. *)
  have "p' ^ n ≤ weight_spmf (while s)" if "f s < n" for s using that
  proof(induction n arbitrary: s)
    case 0 thus ?case by simp
  next
    case (Suc n)
    show ?case
    proof(cases "guard s")
      case False
      hence "weight_spmf (while s) = 1" by(simp add: while.simps)
      thus ?thesis using p'_le_1 p_pos 
        by simp(meson less_eq_real_def mult_le_one p'_pos power_le_one zero_le_power)
    next
      case True
      let ?M = "measure_spmf (body s)"
      have "p' ^ Suc n ≤ (∫ s'. indicator {s'. f s' < f s} s' ∂?M) * p' ^ n"
        using step'[OF True] p'_pos by(simp add: mult_right_mono)
      also have "… = (∫ s'. indicator {s'. f s' < f s} s' * p' ^ n ∂?M)" by simp
      also have "… ≤ (∫ s'. indicator {s'. f s' < f s} s' * weight_spmf (while s') ∂?M)"
        using Suc.prems p'_le_1 p'_pos
        by(intro integral_mono)(auto simp add: Suc.IH power_le_one weight_spmf_le_1 split: split_indicator intro!: measure_spmf.integrable_const_bound[where B=1])
      also have "… ≤ … + (∫ s'. indicator {s'. f s' ≥ f s} s' * weight_spmf (while s') ∂?M)"
        by(simp add: integral_nonneg_AE weight_spmf_nonneg)
      also have "… = ∫ s'. weight_spmf (while s') ∂?M"
        by(subst Bochner_Integration.integral_add[symmetric])
          (auto intro!: Bochner_Integration.integral_cong measure_spmf.integrable_const_bound[where B=1] weight_spmf_le_1 split: split_indicator)
      also have "… = weight_spmf (while s)"
        using True by(subst (1 2) while.simps)(simp add: weight_bind_spmf o_def)
      finally show ?thesis .
    qed
  qed
  moreover have "0 < p' ^ n" using p'_pos by simp
  (* Apply the 0/1 law with invariant guard s ⟶ f s < n, which holds everywhere since
     n = bound + 1. *)
  ultimately show ?thesis using lossless
  proof(rule termination_0_1_invar)
    show "f s < n" if "guard s" "guard s ⟶ f s < n" for s using that by simp
    show "guard s ⟶ f s < n" using bound[of s] by(auto simp add: n_def)
    show "guard s' ⟶ f s' < n" for s' using bound[of s'] by(clarsimp simp add: n_def)
  qed
qed

end

(* Variant rule relativised to an invariant I that holds initially and is preserved by
   guarded body steps; proved via Types_To_Sets as in termination_0_1_invar. *)
lemma termination_variant_invar:
  fixes bound :: nat and I :: "'s ⇒ bool"
  assumes bound: "⋀s. ⟦ guard s; I s ⟧ ⟹ f s ≤ bound"
  and step: "⋀s. ⟦ guard s; I s ⟧ ⟹ p ≤ spmf (map_spmf (λs'. f s' < f s) (body s)) True"
  and p_pos: "0 < p"
  and lossless: "⋀s. ⟦ guard s; I s ⟧ ⟹ lossless_spmf (body s)"
  and invar: "⋀s s'. ⟦ s' ∈ set_spmf (body s); I s; guard s ⟧ ⟹ I s'"
  and I: "I s"
  shows "lossless_spmf (loop_spmf.while guard body s)"
  including lifting_syntax
proof -
  { assume "∃(Rep :: 's' ⇒ 's) Abs. type_definition Rep Abs {s. I s}"
    then obtain Rep :: "'s' ⇒ 's" and Abs where td: "type_definition Rep Abs {s. I s}" by blast
    then interpret td: type_definition Rep Abs "{s. I s}" .
    define cr where "cr ≡ λx y. x = Rep y"
    have [transfer_rule]: "bi_unique cr" "right_total cr" using td cr_def by(rule typedef_bi_unique typedef_right_total)+
    have [transfer_domain_rule]: "Domainp cr = I" using type_definition_Domainp[OF td cr_def] by simp

    define guard' where "guard' ≡ (Rep ---> id) guard"
    have [transfer_rule]: "(cr ===> (=)) guard guard'" by(simp add: rel_fun_def cr_def guard'_def)
    (* body restricted to guarded states; equal loops by loop_spmf_while_cong. *)
    define body1 where "body1 ≡ λs. if guard s then body s else return_pmf None"
    define body1' where "body1' ≡ (Rep ---> map_spmf Abs) body1"
    have [transfer_rule]: "(cr ===> rel_spmf cr) body1 body1'"
      by(auto simp add: rel_fun_def body1'_def body1_def cr_def spmf_rel_map td.Rep[simplified] invar td.Abs_inverse intro!: rel_spmf_reflI)
    define s' where "s' ≡ Abs s"
    have [transfer_rule]: "cr s s'" by(simp add: s'_def cr_def I td.Abs_inverse)
    define f' where "f' ≡ (Rep ---> id) f"
    have [transfer_rule]: "(cr ===> (=)) f f'" by(simp add: rel_fun_def cr_def f'_def)

    have "⋀s. guard' s ⟹ f' s ≤ bound" by(transfer fixing: bound)(rule bound)
    moreover have "⋀s. guard' s ⟹ p ≤ spmf (map_spmf (λs'. f' s' < f' s) (body1' s)) True"
      by(transfer fixing: p)(simp add: step body1_def)
    (* Re-state the step fact together with p_pos: the next moreover appends both to the
       calculation, so the facts reach termination_variant in premise order. *)
    note this p_pos
    moreover have "⋀s. guard' s ⟹ lossless_spmf (body1' s)"
      by transfer(simp add: body1_def lossless)
    ultimately have "lossless_spmf (loop_spmf.while guard' body1' s')" by(rule loop_spmf.termination_variant)
    hence "lossless_spmf (loop_spmf.while guard body1 s)" by transfer }
  from this[cancel_type_definition] I show ?thesis by(auto cong: loop_spmf_while_cong)
qed

end