(* Theory HOL-Probability.Tree_Space *)

(*  Title:      HOL/Probability/Tree_Space.thy
    Author:     Johannes Hölzl, CMU *)

theory Tree_Space
  imports "HOL-Analysis.Analysis" "HOL-Library.Tree"
begin

(* The least fixed point of a sup-continuous operator F that maps countable sets to
   countable sets is countable: sup_continuous_lfp rewrites lfp F via the iterates
   F ^^ n of bot, and countable_funpow shows each iterate is countable. *)
lemma countable_lfp:
  assumes step: "⋀Y. countable Y ⟹ countable (F Y)"
  and cont: "Order_Continuity.sup_continuous F"
  shows "countable (lfp F)"
by(subst sup_continuous_lfp[OF cont])(simp add: countable_funpow[OF step])

(* Pointwise variant of countable_lfp for operators on set-valued functions:
   if F preserves "countable at every argument", then lfp F is countable at x. *)
lemma countable_lfp_apply:
  assumes step: "⋀Y x. (⋀x. countable (Y x)) ⟹ countable (F Y x)"
  and cont: "Order_Continuity.sup_continuous F"
  shows "countable (lfp F x)"
proof -
  (* Every finite iterate of F starting from bot is countable at every argument. *)
  { fix n
    have "⋀x. countable ((F ^^ n) bot x)"
      by(induct n)(auto intro: step) }
  thus ?thesis using cont by(simp add: sup_continuous_lfp)
qed

(* trees S: the finite trees all of whose node values lie in S. *)
inductive_set trees :: "'a set ⇒ 'a tree set" for S :: "'a set" where
  [intro!]: "Leaf ∈ trees S"
| "l ∈ trees S ⟹ r ∈ trees S ⟹ v ∈ S ⟹ Node l v r ∈ trees S"

(* Simp rule: a Node lies in trees S iff both subtrees do and its value lies in S. *)
lemma Node_in_trees_iff[simp]: "Node l v r ∈ trees S ⟷ (l ∈ trees S ∧ v ∈ S ∧ r ∈ trees S)"
  by (subst trees.simps) auto

(* trees S is contained in the least fixed point of the one-step tree constructor
   T ↦ T ∪ {Leaf} ∪ {Node l v r | l ∈ T, v ∈ S, r ∈ T}.  This lets countability of
   the lfp be transferred to trees S. *)
lemma trees_sub_lfp: "trees S ⊆ lfp (λT. T ∪ {Leaf} ∪ (⋃l∈T. (⋃v∈S. (⋃r∈T. {Node l v r}))))"
proof
  have mono: "mono (λT. T ∪ {Leaf} ∪ (⋃l∈T. (⋃v∈S. (⋃r∈T. {Node l v r}))))"
    by (auto simp: mono_def)
  fix t assume "t ∈ trees S" then show "t ∈ lfp (λT. T ∪ {Leaf} ∪ (⋃l∈T. (⋃v∈S. (⋃r∈T. {Node l v r}))))"
  proof induction
    case 1 then show ?case
      by (subst lfp_unfold[OF mono]) auto
  next
    case 2 then show ?case
      by (subst lfp_unfold[OF mono]) auto
  qed
qed

(* Trees over a countable value set form a countable set.  Proof: trees A lies in
   the lfp of trees_sub_lfp's operator, which is sup-continuous and preserves
   countability, so countable_lfp applies. *)
lemma countable_trees: "countable A ⟹ countable (trees A)"
proof (intro countable_subset[OF trees_sub_lfp] countable_lfp
         sup_continuous_sup sup_continuous_const sup_continuous_id)
  show "sup_continuous (λT. (⋃l∈T. ⋃v∈A. ⋃r∈T. {⟨l, v, r⟩}))"
    unfolding sup_continuous_def
  proof (intro allI impI equalityI subsetI, goal_cases)
    case (1 M t)
    then obtain i j :: nat and l x r  where "t = Node l x r" "x ∈ A" "l ∈ M i" "r ∈ M j"
      by auto
    (* Both subtrees already appear at the later index max i j, since M is increasing. *)
    hence "l ∈ M (max i j)" "r ∈ M (max i j)"
      using incseqD[OF ‹incseq M›, of i "max i j"] incseqD[OF ‹incseq M›, of j "max i j"] by auto
    with ‹t = Node l x r› and ‹x ∈ A› show ?case by auto
  qed auto
qed auto

(* Every tree lies in trees UNIV. *)
lemma trees_UNIV[simp]: "trees UNIV = UNIV"
proof -
  have "t ∈ trees UNIV" for t :: "'a tree"
    by (induction t) (auto intro: trees.intros(2))
  then show ?thesis by auto
qed

(* Trees over a countable type form a countable type: UNIV = trees UNIV is countable
   by countable_trees. *)
instance tree :: (countable) countable
proof
  have "countable (UNIV :: 'a tree set)"
    by (subst trees_UNIV[symmetric]) (intro countable_trees[OF countableI_type])
  then show "∃to_nat::'a tree ⇒ nat. inj to_nat"
    by (auto simp: countable_def)
qed

(* Mapping a tree with f yields a tree in trees S, provided f sends every node value
   of the original tree into S. *)
lemma map_in_trees[intro]: "(⋀x. x ∈ set_tree t ⟹ f x ∈ S) ⟹ map_tree f t ∈ trees S"
  by (induction t) (auto intro: trees.intros(2))

(* Cylinder sets: a tree of sets is mapped to the set of all trees of the same shape
   whose node values lie in the corresponding node sets. *)
primrec trees_cyl :: "'a set tree ⇒ 'a tree set" where
  "trees_cyl Leaf = {Leaf} "
| "trees_cyl (Node l v r) = (⋃l'∈trees_cyl l. (⋃v'∈v. (⋃r'∈trees_cyl r. {Node l' v' r'})))"

(* The measurable space of trees over M: carrier trees (space M), generated by the
   cylinders whose node sets are M-measurable. *)
definition tree_sigma :: "'a measure ⇒ 'a tree measure"
where
  "tree_sigma M = sigma (trees (space M)) (trees_cyl ` trees (sets M))"

(* Componentwise characterisation of when a Node lies in a cylinder set. *)
lemma Node_in_trees_cyl: "Node l' v' r' ∈ trees_cyl t ⟷
  (∃l v r. t = Node l v r ∧ l' ∈ trees_cyl l ∧ r' ∈ trees_cyl r ∧ v' ∈ v)"
  by (cases t) auto

(* A cylinder built from a tree of subsets of B contains only trees with values in B. *)
lemma trees_cyl_sub_trees:
  assumes "t ∈ trees A" "A ⊆ Pow B" shows "trees_cyl t ⊆ trees B"
  using assms(1)
proof induction
  case (2 l v r) with ‹A ⊆ Pow B› show ?case
    by (auto intro!: trees.intros(2))
qed auto

(* The generators of tree_sigma M are subsets of its carrier trees (space M). *)
lemma trees_cyl_sets_in_space: "trees_cyl ` trees (sets M) ⊆ Pow (trees (space M))"
  using trees_cyl_sub_trees[OF _ sets.space_closed, of _ M] by auto

(* The carrier of tree_sigma M is the set of trees with values in space M. *)
lemma space_tree_sigma: "space (tree_sigma M) = trees (space M)"
  unfolding tree_sigma_def by (rule space_measure_of_conv)

(* The measurable sets of tree_sigma M are the sigma algebra generated by the cylinders;
   sets_measure_of needs the generators to lie in the carrier (trees_cyl_sets_in_space). *)
lemma sets_tree_sigma_eq: "sets (tree_sigma M) = sigma_sets (trees (space M)) (trees_cyl ` trees (sets M))"
  unfolding tree_sigma_def by (rule sets_measure_of) (rule trees_cyl_sets_in_space)

(* Leaf belongs to the carrier of tree_sigma M. *)
lemma Leaf_in_space_tree_sigma [measurable, simp, intro]: "Leaf ∈ space (tree_sigma M)"
  by (auto simp: space_tree_sigma)

(* {Leaf} is measurable: it is the generator trees_cyl Leaf. *)
lemma Leaf_in_tree_sigma [measurable, simp, intro]: "{Leaf} ∈ sets (tree_sigma M)"
  unfolding sets_tree_sigma_eq
  by (rule sigma_sets.Basic) (auto intro: trees.intros(2) image_eqI[where x=Leaf])

(* A tree with values in A lies in the cylinder of its own shape with every node set
   equal to A. *)
lemma trees_cyl_map_treeI: "t ∈ trees_cyl (map_tree (λx. A) t)" if *: "t ∈ trees A"
  using * by induction auto

(* A cylinder whose node sets are all M-measurable is tree_sigma-measurable. *)
lemma trees_cyl_map_in_sets:
  "(⋀x. x ∈ set_tree t ⟹ f x ∈ sets M) ⟹ trees_cyl (map_tree f t) ∈ sets (tree_sigma M)"
  by (subst sets_tree_sigma_eq) auto

(* A set of Node trees is measurable whenever its (value, left, right) triples form a
   measurable set in M ⨂⇩M (tree_sigma M ⨂⇩M tree_sigma M).
   Proof idea: sets_pair_eq rewrites the product sigma algebra as the one generated by
   the rectangles A × trees_cyl B × trees_cyl C; the claim then follows by induction
   over sigma_sets. *)
lemma Node_in_tree_sigma:
  assumes L: "X ∈ sets (M ⨂⇩M (tree_sigma M ⨂⇩M tree_sigma M))"
  shows "{Node l v r | l v r. (v, l, r) ∈ X} ∈ sets (tree_sigma M)"
proof -
  (* ?E s is the cylinder of shape s (s :: unit tree) with every node set equal to
     space M.  Facts 1-5 are the premises sets_pair_eq needs: these countably many
     generators cover the space. *)
  let ?E = "λs::unit tree. trees_cyl (map_tree (λ_. space M) s)"
  have 1: "countable (range ?E)"
    by (intro countable_image countableI_type)
  have 2: "trees_cyl ` trees (sets M) ⊆ Pow (space (tree_sigma M))"
    using trees_cyl_sets_in_space[of M] by (simp add: space_tree_sigma)
  have 3: "sets (tree_sigma M) = sigma_sets (space (tree_sigma M)) (trees_cyl ` trees (sets M))"
    unfolding sets_tree_sigma_eq by (simp add: space_tree_sigma)
  have 4: "(⋃s. ?E s) = space (tree_sigma M)"
  proof (safe; clarsimp simp: space_tree_sigma)
    fix t s assume "t ∈ trees_cyl (map_tree (λ_::unit. space M) s)"
    then show "t ∈ trees (space M)"
      by (induction s arbitrary: t) auto
  next
    fix t assume "t ∈ trees (space M)"
    then show "∃t'. t ∈ ?E t'"
      by (intro exI[of _ "map_tree (λ_. ()) t"])
         (auto simp: tree.map_comp comp_def intro: trees_cyl_map_treeI)
  qed
  have 5: "range ?E ⊆ trees_cyl ` trees (sets M)" by auto
  (* Generators of the binary product tree_sigma M ⨂⇩M tree_sigma M. *)
  let ?P = "{A × B | A B. A ∈ trees_cyl ` trees (sets M) ∧ B ∈ trees_cyl ` trees (sets M)}"
  have P: "sets (tree_sigma M ⨂⇩M tree_sigma M) = sets (sigma (space (tree_sigma M) × space (tree_sigma M)) ?P)"
    by (rule sets_pair_eq[OF 2 3 1 5 4 2 3 1 5 4])

  have "sets (M ⨂⇩M (tree_sigma M ⨂⇩M tree_sigma M)) =
    sets (sigma (space M × space (tree_sigma M ⨂⇩M tree_sigma M)) {A × BC | A BC. A ∈ sets M ∧ BC ∈ ?P})"
  proof (rule sets_pair_eq)
    show "sets M ⊆ Pow (space M)" "sets M = sigma_sets (space M) (sets M)"
      by (auto simp: sets.sigma_sets_eq sets.space_closed)
    show "countable {space M}" "{space M} ⊆ sets M" "⋃{space M} = space M"
      by auto
    show "?P ⊆ Pow (space (tree_sigma M ⨂⇩M tree_sigma M))"
      using trees_cyl_sets_in_space[of M]
      by (auto simp: space_pair_measure space_tree_sigma subset_eq)
    then show "sets (tree_sigma M ⨂⇩M tree_sigma M) =
      sigma_sets (space (tree_sigma M ⨂⇩M tree_sigma M)) ?P"
      by (subst P, subst sets_measure_of) (auto simp: space_tree_sigma space_pair_measure)
    show "countable ((λ(a, b). a × b) ` (range ?E × range ?E))"
      by (intro countable_image countable_SIGMA countableI_type)
    show "(λ(a, b). a × b) ` (range ?E × range ?E) ⊆ ?P"
      by auto
  qed (insert 4, auto simp: space_pair_measure space_tree_sigma set_eq_iff)
  also have "… = sigma_sets (space M × trees (space M) × trees (space M))
                    {A × BC |A BC. A ∈ sets M ∧ BC ∈ {A × B |A B.
                       A ∈ trees_cyl ` trees (sets M) ∧ B ∈ trees_cyl ` trees (sets M)}}"
    (is "_ = sigma_sets ?X ?Y") using sets.space_closed[of M] trees_cyl_sub_trees[of _ "sets M" "space M"]
    by (subst sets_measure_of)
       (auto simp: space_pair_measure space_tree_sigma)
  also have "?Y = {A × trees_cyl B × trees_cyl C | A B C. A ∈ sets M ∧
                     B ∈ trees (sets M) ∧ C ∈ trees (sets M)}" by blast
  finally have "X ∈ sigma_sets (space M × trees (space M) × trees (space M))
    {A × trees_cyl B × trees_cyl C | A B C. A ∈ sets M ∧ B ∈ trees (sets M) ∧ C ∈ trees (sets M) }"
    using assms by blast
  (* Induction over the generated sigma algebra: each sigma_sets case is preserved
     when a set of triples is mapped to the corresponding set of Node trees. *)
  then show ?thesis
  proof induction
    case (Basic A')
    then obtain A B C where "A' = A × trees_cyl B × trees_cyl C"
      and *: "A ∈ sets M" "B ∈ trees (sets M)" "C ∈ trees (sets M)"
      by auto
    then have "{Node l v r |l v r. (v, l, r) ∈ A'} = trees_cyl (Node B A C)"
      by auto
    then show ?case
      by (auto simp del: trees_cyl.simps simp: sets_tree_sigma_eq intro!: sigma_sets.Basic *)
  next
    case Empty show ?case by auto
  next
    (* Leaf has no (value, left, right) triple, so it is removed explicitly. *)
    case (Compl A)
    have "{Node l v r |l v r. (v, l, r) ∈ space M × trees (space M) × trees (space M) - A} =
      (space (tree_sigma M) - {Node l v r |l v r. (v, l, r) ∈ A}) - {Leaf}"
      by (auto simp: space_tree_sigma elim: trees.cases)
    also have "… ∈ sets (tree_sigma M)"
      by (intro sets.Diff Compl) auto
    finally show ?case .
  next
    case (Union I)
    have *: "{Node l v r |l v r. (v, l, r) ∈ ⋃(I ` UNIV)} =
      (⋃i. {Node l v r |l v r. (v, l, r) ∈ I i})" by auto
    show ?case unfolding * using Union(2) by (intro sets.countable_UN) auto
  qed
qed

(* The left selector is measurable.  The preimage of a measurable A is {Leaf} (when
   Leaf ∈ A, mirroring left Leaf = Leaf) together with the Node trees whose left
   subtree lies in A; the latter is handled by Node_in_tree_sigma. *)
lemma measurable_left[measurable]: "left ∈ tree_sigma M →⇩M tree_sigma M"
proof (rule measurableI)
  show "t ∈ space (tree_sigma M) ⟹ left t ∈ space (tree_sigma M)" for t
    by (cases t) (auto simp: space_tree_sigma)
  fix A assume A: "A ∈ sets (tree_sigma M)"
  from sets.sets_into_space[OF this]
  have *: "left -` A ∩ space (tree_sigma M) =
    (if Leaf ∈ A then {Leaf} else {}) ∪
    {Node a v r | a v r. (v, a, r) ∈ space M × A × space (tree_sigma M)}"
    by (auto simp: space_tree_sigma elim: trees.cases)
  show "left -` A ∩ space (tree_sigma M) ∈ sets (tree_sigma M)"
    unfolding * using A by (intro sets.Un Node_in_tree_sigma pair_measureI) auto
qed

(* The right selector is measurable.  This is symmetric to measurable_left: the
   preimage is {Leaf} (when Leaf ∈ A) plus the Node trees whose right subtree lies
   in A. *)
lemma measurable_right[measurable]: "right ∈ tree_sigma M →⇩M tree_sigma M"
proof (rule measurableI)
  show "t ∈ space (tree_sigma M) ⟹ right t ∈ space (tree_sigma M)" for t
    by (cases t) (auto simp: space_tree_sigma)
  fix A assume A: "A ∈ sets (tree_sigma M)"
  from sets.sets_into_space[OF this]
  have *: "right -` A ∩ space (tree_sigma M) =
    (if Leaf ∈ A then {Leaf} else {}) ∪
    {Node l v a | l v a. (v, l, a) ∈ space M × space (tree_sigma M) × A}"
    by (auto simp: space_tree_sigma elim: trees.cases)
  show "right -` A ∩ space (tree_sigma M) ∈ sets (tree_sigma M)"
    unfolding * using A by (intro sets.Un Node_in_tree_sigma pair_measureI) auto
qed

(* The value selector is measurable on the subspace of non-Leaf trees: the preimage
   of a measurable A is the set of Node trees whose value lies in A. *)
lemma measurable_value': "value ∈ restrict_space (tree_sigma M) (-{Leaf}) →⇩M M"
proof (rule measurableI)
  show "t ∈ space (restrict_space (tree_sigma M) (- {Leaf})) ⟹ value t ∈ space M" for t
    by (cases t) (auto simp: space_restrict_space space_tree_sigma)
  fix A assume A: "A ∈ sets M"
  from sets.sets_into_space[OF this]
  have "value -` A ∩ space (restrict_space (tree_sigma M) (- {Leaf})) =
    {Node l a r | l a r. (a, l, r) ∈ A × space (tree_sigma M) × space (tree_sigma M)}"
    by (auto simp: space_tree_sigma space_restrict_space elim: trees.cases)
  also have "… ∈ sets (tree_sigma M)"
    using A by (intro sets.Un Node_in_tree_sigma pair_measureI) auto
  finally show "value -` A ∩ space (restrict_space (tree_sigma M) (- {Leaf})) ∈
      sets (restrict_space (tree_sigma M) (- {Leaf}))"
    by (auto simp: sets_restrict_space_iff space_restrict_space)
qed

(* value ∘ f is measurable when f is measurable into tree_sigma M and never yields
   Leaf on space X.  Proof: factor f through the non-Leaf subspace and compose with
   measurable_value'. *)
lemma measurable_value[measurable (raw)]:
  assumes "f ∈ X →⇩M tree_sigma M"
    and "⋀x. x ∈ space X ⟹ f x ≠ Leaf"
  shows "(λω. value (f ω)) ∈ X →⇩M M"
proof -
  from assms have "f ∈ X →⇩M restrict_space (tree_sigma M) (- {Leaf})"
    by (intro measurable_restrict_space2) auto
  from this and measurable_value' show ?thesis by (rule measurable_compose)
qed


(* The uncurried Node constructor is measurable from the triple product space.  It is
   enough to check the generators (measurable_sigma_sets): the preimage of {Leaf} is
   empty, and the preimage of trees_cyl (Node l B r) is the rectangle
   trees_cyl l × B × trees_cyl r. *)
lemma measurable_Node [measurable]:
  "(λ(l,x,r). Node l x r) ∈ tree_sigma M ⨂⇩M M ⨂⇩M tree_sigma M →⇩M tree_sigma M"
proof (rule measurable_sigma_sets)
  show "sets (tree_sigma M) = sigma_sets (trees (space M)) (trees_cyl ` trees (sets M))"
    by (simp add: sets_tree_sigma_eq)
  show "trees_cyl ` trees (sets M) ⊆ Pow (trees (space M))"
    by (rule trees_cyl_sets_in_space)
  show "(λ(l, x, r). ⟨l, x, r⟩) ∈ space (tree_sigma M ⨂⇩M M ⨂⇩M tree_sigma M) → trees (space M)"
    by (auto simp: space_pair_measure space_tree_sigma)
  fix A assume t: "A ∈ trees_cyl ` trees (sets M)"
  then obtain t where t: "t ∈ trees (sets M)" "A = trees_cyl t" by auto
  show "(λ(l, x, r). ⟨l, x, r⟩) -` A ∩
         space (tree_sigma M ⨂⇩M M ⨂⇩M tree_sigma M)
         ∈ sets (tree_sigma M ⨂⇩M M ⨂⇩M tree_sigma M)"
  proof (cases t)
    case Leaf
    have "(λ(l, x, r). ⟨l, x, r⟩) -` {Leaf :: 'a tree} = {}" by auto
    with Leaf show ?thesis using t by simp
  next
    case (Node l B r)
    hence "(λ(l, x, r). ⟨l, x, r⟩) -` A ∩ space (tree_sigma M ⨂⇩M M ⨂⇩M tree_sigma M) =
             trees_cyl l × B × trees_cyl r"
      using t and Node and trees_cyl_sub_trees[of _ "sets M" "space M"]
      by (auto simp: space_pair_measure space_tree_sigma
               dest: sets.sets_into_space[of _ M])
    thus ?thesis using t and Node
      by (auto intro!: pair_measureI simp: sets_tree_sigma_eq)
  qed
qed

(* Curried form of measurable_Node, registered as a raw measurability rule: a Node
   built from measurable left/value/right components is measurable. *)
lemma measurable_Node' [measurable (raw)]:
  assumes [measurable]: "l ∈ B →⇩M tree_sigma A"
  assumes [measurable]: "x ∈ B →⇩M A"
  assumes [measurable]: "r ∈ B →⇩M tree_sigma A"
  shows   "(λy. Node (l y) (x y) (r y)) ∈ B →⇩M tree_sigma A"
proof -
  have "(λy. Node (l y) (x y) (r y)) = (λ(a,b,c). Node a b c) ∘ (λy. (l y, x y, r y))"
    by (simp add: o_def)
  also have "… ∈ B →⇩M tree_sigma A"
    by (intro measurable_comp[OF _ measurable_Node]) simp_all
  finally show ?thesis .
qed

(* rec_tree is measurable.  B is split into the countably many pieces
   ?C t s = t -` (cylinder of shape s), one for each shape s :: unit tree
   (measurable_piecewise_restrict).  On each piece, measurability follows by
   induction on s, unfolding one recursion step through left, value and right. *)
lemma measurable_rec_tree[measurable (raw)]:
  assumes t: "t ∈ B →⇩M tree_sigma M"
  assumes l: "l ∈ B →⇩M A"
  assumes n: "(λ(x, l, v, r, al, ar). n x l v r al ar) ∈
    (B ⨂⇩M tree_sigma M ⨂⇩M M ⨂⇩M tree_sigma M ⨂⇩M A ⨂⇩M A) →⇩M A" (is "?N ∈ ?M →⇩M A")
  shows "(λx. rec_tree (l x) (n x) (t x)) ∈ B →⇩M A"
proof (rule measurable_piecewise_restrict)
  let ?C = "λt. λs::unit tree. t -` trees_cyl (map_tree (λ_. space M) s)"
  show "countable (range (?C t))" by (intro countable_image countableI_type)
  show "space B ⊆ (⋃s. ?C t s)"
  proof (safe; clarsimp)
    fix x assume x: "x ∈ space B" have "t x ∈ trees (space M)"
      using t[THEN measurable_space, OF x] by (simp add: space_tree_sigma)
    then show "∃xa::unit tree. t x ∈ trees_cyl (map_tree (λ_. space M) xa)"
      by (intro exI[of _ "map_tree (λ_. ()) (t x)"])
         (simp add: tree.map_comp comp_def trees_cyl_map_treeI)
  qed
  fix Ω assume "Ω ∈ range (?C t)"
  then obtain s :: "unit tree" where Ω: "Ω = ?C t s" by auto
  then show "Ω ∩ space B ∈ sets B"
    by (safe intro!: measurable_sets[OF t] trees_cyl_map_in_sets)
  show "(λx. rec_tree (l x) (n x) (t x)) ∈ restrict_space B Ω →⇩M A"
    unfolding Ω using t
  proof (induction s arbitrary: t)
    (* Shape Leaf: on this piece t ω = Leaf, so rec_tree reduces to l. *)
    case Leaf
    show ?case
    proof (rule measurable_cong[THEN iffD2])
      fix ω assume "ω ∈ space (restrict_space B (?C t Leaf))"
      then show "rec_tree (l ω) (n ω) (t ω) = l ω"
        by (auto simp: space_restrict_space)
    next
      show "l ∈ restrict_space B (?C t Leaf) →⇩M A"
        using l by (rule measurable_restrict_space1)
    qed
  next
    (* Shape Node: unfold one step of rec_tree into ?F, whose recursive calls on the
       left/right subtrees are covered by the induction hypotheses. *)
    case (Node ls u rs)
    let ?F = "λω. ?N (ω, left (t ω), value (t ω), right (t ω),
        rec_tree (l ω) (n ω) (left (t ω)), rec_tree (l ω) (n ω) (right (t ω)))"
    show ?case
    proof (rule measurable_cong[THEN iffD2])
      fix ω assume "ω ∈ space (restrict_space B (?C t (Node ls u rs)))"
      then show "rec_tree (l ω) (n ω) (t ω) = ?F ω"
        by (auto simp: space_restrict_space)
    next
      show "?F ∈ (restrict_space B (?C t (Node ls u rs))) →⇩M A"
        apply (intro measurable_compose[OF _ n] measurable_Pair[rotated])
        subgoal
          apply (rule measurable_restrict_mono[OF Node(2)])
          apply (rule measurable_compose[OF Node(3) measurable_right])
          by auto
        subgoal
          apply (rule measurable_restrict_mono[OF Node(1)])
          apply (rule measurable_compose[OF Node(3) measurable_left])
          by auto
        subgoal
          by (rule measurable_restrict_space1)
             (rule measurable_compose[OF Node(3) measurable_right])
        subgoal
          apply (rule measurable_compose[OF _ measurable_value'])
          apply (rule measurable_restrict_space3[OF Node(3)])
          by auto
        subgoal
          by (rule measurable_restrict_space1)
             (rule measurable_compose[OF Node(3) measurable_left])
        by (rule measurable_restrict_space1) auto
    qed
  qed
qed

(* case_tree is measurable.  It is rewritten as rec_tree with the step function n',
   which ignores the two recursive results; measurable_rec_tree then applies. *)
lemma measurable_case_tree [measurable (raw)]:
  assumes "t ∈ B →⇩M tree_sigma M"
  assumes "l ∈ B →⇩M A"
  assumes "(λ(x, l, v, r). n x l v r)
             ∈ B ⨂⇩M tree_sigma M ⨂⇩M M ⨂⇩M tree_sigma M →⇩M A"
  shows   "(λx. case_tree (l x) (n x) (t x)) ∈ B →⇩M (A :: 'a measure)"
proof -
  define n' where "n' = (λx l v r (_::'a) (_::'a). n x l v r)"
  have "(λx. case_tree (l x) (n x) (t x)) = (λx. rec_tree (l x) (n' x) (t x))"
    (is "_ = (λx. rec_tree _ (?n' x) _)") by (rule ext) (auto split: tree.splits simp: n'_def)
  also have "… ∈ B →⇩M A"
  proof (rule measurable_rec_tree)
    have "(λ(x, l, v, r, al, ar). n' x l v r al ar) =
            (λ(x,l,v,r). n x l v r) ∘ (λ(x,l,v,r,al,ar). (x,l,v,r))"
      by (simp add: n'_def o_def case_prod_unfold)
    also have "… ∈ B ⨂⇩M tree_sigma M ⨂⇩M M ⨂⇩M tree_sigma M ⨂⇩M A ⨂⇩M A →⇩M A"
      using assms(3) by measurable
    finally show "(λ(x, l, v, r, al, ar). n' x l v r al ar) ∈ …" .
  qed (insert assms, simp_all)
  finally show ?thesis .
qed

(* Hide the unqualified names of the selectors left and right; qualified access
   remains available.  Presumably this avoids name clashes in importing theories
   — confirm. *)
hide_const (open) left
hide_const (open) right

end