Theory Efron_Stein_Inequality

section ‹Efron-Stein Inequality›

text ‹In this section we verify the Efron-Stein inequality. The verified theorem is stated as the
Efron-Stein inequality for non-symmetric functions by Steele~\cite{steele1986}. However, most
textbooks refer to this version as ``the Efron-Stein inequality''. The original result shown
by Efron and Stein is an upper bound for the variance of symmetric functions of i.i.d.
random variables~\cite{efron1981}.›

theory Efron_Stein_Inequality
  imports Concentration_Inequalities_Preliminary
begin

text ‹Efron-Stein inequality for an explicit product of probability spaces. The function f is
defined on the product space PiM I M, while the right-hand side is stated on the doubled product
space PiM (I × UNIV) (M ∘ fst), which contains two copies of every coordinate (indexed by False
and True). The i-th summand compares f on the False copy with f on the False copy in which only
coordinate i has been replaced by its True copy.›

theorem efron_stein_inequality_distr:
  fixes f :: "_ ⇒ real"
  assumes "finite I"
  assumes "⋀i. i ∈ I ⟹ prob_space (M i)"
  assumes "integrable (PiM I M) (λx. f x^2)" and f_meas: "f ∈ borel_measurable (PiM I M)"
  shows "prob_space.variance (PiM I M) f ≤
    (∑i∈I. (∫x. (f (λj. x (j,False)) - f (λj. x (j, j=i)))^2 ∂PiM (I×UNIV) (M ∘ fst))) / 2"
    (is "?L ≤ ?R")
proof -
  (* Doubled product space: the coordinates (j, False) and (j, True) both carry M j. *)
  let ?M = "PiM (I×(UNIV::bool set)) (M ∘ fst)"

  have prob: "prob_space (PiM I M)"
    using assms(2) by (intro prob_space_PiM) auto

  interpret prob_space "?M"
    using assms(2) by (intro prob_space_PiM) auto

  define n where "n = card I"

  (* Enumerate I, so that coordinates can be switched to the True copy one at a time. *)
  obtain q :: "_ ⇒ nat" where q:"bij_betw q I {..<n}"
    unfolding n_def using ex_bij_betw_finite_nat[OF assms(1)] atLeast0LessThan by auto

  (* ?hyb k: f where the coordinates j with q j < k are read from the True copy. *)
  let ?hyb = "(λk x. f (λj. x (j, q j < k)))"
  (* ?swp k: f where only the coordinate j with q j = k is read from the True copy. *)
  let ?swp = "(λk x. f (λj. x (j, q j = k)))"
  (* ?orig and ?copy: f evaluated entirely on the False copy and entirely on the True copy. *)
  let ?orig = "(λx. f (λj. x (j, False)))"
  let ?copy = "(λx. f (λj. x (j, True)))"

  (* f composed with a measurable map into PiM I M is Borel measurable on ?M. *)
  have meas_1: "(λω. f (g ω)) ∈ borel_measurable ?M"
    if "g ∈ PiM (I × UNIV) (M ∘ fst) →⇩M PiM I M" for g
    using that by (intro measurable_compose[OF _ f_meas])

  (* Selecting the copy h j for every coordinate j is measurable. *)
  have meas_2: "(λx j. x (j, h j)) ∈ ?M →⇩M PiM I M" for h
  proof -
    have "?thesis ⟷ (λx. (λj ∈ I. x (j, h j))) ∈ ?M →⇩M PiM I M"
      by (intro measurable_cong) (auto simp:space_PiM PiE_def extensional_def)
    also have "... ⟷ True"
      unfolding eq_True
      by (intro measurable_restrict measurable_PiM_component_rev) auto
    finally show ?thesis by simp
  qed

  (* The difference of two square-integrable measurable functions is square-integrable;
     the cross term is handled by cauchy_schwartz(1). *)
  have int_1: "integrable ?M (λx. (g x - h x)^2)"
    if "integrable ?M (λx. (g x)^2)"  "integrable ?M (λx. (h x)^2)"
    and "g ∈ borel_measurable ?M" "h ∈ borel_measurable ?M"
    for g h :: "_ ⇒ real"
  proof -
    have "integrable ?M (λx. (g x)^2 + (h x)^2 - 2 * (g x * h x))"
      using that by (intro Bochner_Integration.integrable_add Bochner_Integration.integrable_diff
          integrable_mult_right cauchy_schwartz(1))
    thus ?thesis by (simp add:algebra_simps power2_eq_square)
  qed

  note meas_rules = borel_measurable_add borel_measurable_times borel_measurable_diff
    borel_measurable_power meas_1 meas_2

  (* f is integrable, since it is square-integrable with respect to a finite measure. *)
  have f_int: "integrable (PiM I M) f"
    by (intro finite_measure.square_integrable_imp_integrable[OF _ f_meas assms(3)]
        prob_space.finite_measure prob)
  moreover have "integrable (PiM I M) (λx. f (restrict x I)) = integrable (PiM I M) f"
    by (intro  Bochner_Integration.integrable_cong) (auto simp:space_PiM)
  ultimately have f_int_2: "integrable (PiM I M) (λx. f (restrict x I))" by simp

  (* Restricting the selected point to I does not change the integral over ?M. *)
  have cong: "(∫x. g (λj∈I. x (j, h j)) ∂?M) = (∫x. g (λj. x (j, h j)) ∂?M)" (is "?L1 = ?R1")
    for g :: "_ ⇒ real" and h
    by (intro Bochner_Integration.integral_cong arg_cong[where f="g"] refl)
       (auto simp add:space_PiM PiE_def extensional_def restrict_def)

  (* An integral over PiM I M equals the integral over ?M of the same function applied to the
     copy selected by h, for any selection h. *)
  have lift: "(∫x. g x ∂PiM I M) = (∫x. g (λj. x (j, h j)) ∂?M)" (is "?L1 = ?R1")
    if "g ∈ borel_measurable (PiM I M)"
    for g :: "_ ⇒ real" and h
  proof -
    have "?R1 = (∫x. g (λj ∈ I. x (j, h j)) ∂?M)"
      by (intro cong[symmetric])
    also have "... = (∫x. g x ∂distr ?M (PiM I (λi. (M∘fst) (i, h i))) (λx. (λj ∈ I. x (j, h j))))"
      using that
      by (intro integral_distr[symmetric] measurable_restrict measurable_component_singleton) auto
    also have "... = (∫x. g x ∂PiM I (λi. (M ∘ fst) (i, h i)))"
      using assms(2)
      by (intro arg_cong2[where f="integral⇧L"] refl distr_PiM_reindex inj_onI) auto
    also have "... = ?L1"
      by auto
    finally show ?thesis
      by simp
  qed

  (* Integrability counterpart of lift. *)
  have lift_int: "integrable ?M (λx. g (λj. x (j, h j)))" if "integrable (PiM I M) g"
    for g :: "_ ⇒ real" and h
  proof -
    have 0:"integrable (distr ?M (PiM I (λi. (M∘fst) (i, h i))) (λx. (λj ∈ I. x (j, h j)))) g"
      using that assms(2) by (subst distr_PiM_reindex) (auto intro:inj_onI)
    have "integrable ?M (λx. g (λj∈I. x (j, h j)))"
      by (intro integrable_distr[OF _ 0] measurable_restrict measurable_component_singleton) auto
    moreover have "integrable ?M (λx. g (λj∈I. x (j, h j))) ⟷ ?thesis"
      by (intro Bochner_Integration.integrable_cong refl arg_cong[where f="g"] ext)
       (auto simp:PiE_def space_PiM extensional_def)
    ultimately show ?thesis
      by simp
  qed

  note int_rules = cauchy_schwartz(1) int_1 lift_int assms(3) f_int f_int_2

  (* ?M is invariant under exchanging the two copies at every coordinate j for which h j holds. *)
  have "(∫x. g x ∂?M) = (∫x. g (λ(j,v). x (j, v ≠ h j)) ∂?M)" (is "?L1 = ?R1")
    if "g ∈ borel_measurable ?M" for g :: "_ ⇒ real" and h
  proof -
    have "?L1 = (∫x. g x ∂distr ?M (PiM (I×UNIV) (λi. (M ∘ fst) (fst i, snd i ≠ h (fst i))))
      (λx.(λi ∈ I×UNIV. x (fst i, snd i ≠ h (fst i))) ))"
      by (subst distr_PiM_reindex) (auto intro:inj_onI assms(2) simp:comp_def)
    also have "... = (∫x. g (λi ∈ I×UNIV. x (fst i, snd i ≠ h (fst i))) ∂?M)"
      using that by (intro integral_distr measurable_restrict measurable_component_singleton)
        (auto simp:comp_def)
    also have "... = ?R1"
      by (intro Bochner_Integration.integral_cong refl arg_cong[where f="g"] ext)
       (auto simp add:space_PiM PiE_def extensional_def restrict_def)
    finally show ?thesis
      by simp
  qed

  hence switch: "(∫x. g x ∂?M) = (∫x. h x ∂?M)"
    if "⋀x. h x = g (λ(j,v). x (j, v ≠ u j))" "g ∈ borel_measurable ?M"
    for g h :: "_ ⇒ real" and u
    using that by simp

  (* Bound for a single telescoping term. Exchanging the copies at the coordinate with q j = i
     rewrites the term using ?swp i; averaging the two forms and applying Cauchy-Schwarz gives the
     bound, where the second factor is rewritten by exchanging the copies at all q j < i. *)
  have 1: "(∫x. (?orig x) * (?hyb i x - ?hyb (i+1) x) ∂?M) ≤ (∫x. (?orig x - ?swp i x)^2 ∂?M) / 2"
    (is "?L1 ≤ ?R1")
    if "i < n" for i
  proof -
    have "?L1 = (∫x. (?swp i x) * (?hyb (i+1) x - ?hyb i x) ∂?M)"
      by (intro switch[of _ _ "(λj. q j = i)"] arg_cong2[where f="(*)"]
            arg_cong2[where f="(-)"] arg_cong[where f="f"] ext meas_rules) (auto intro:arg_cong)
    hence "?L1 = (?L1 + (∫x. (?swp i x) * (?hyb (i+1) x - ?hyb i x) ∂?M)) / 2"
      by simp
    also have "... = (∫x. (?orig x) * (?hyb i x - ?hyb (i+1) x) +
        (?swp i x) * (?hyb (i+1) x - ?hyb i x) ∂?M)/2"
      by (intro Bochner_Integration.integral_add[symmetric] arg_cong2[where f="(/)"] refl
          int_rules meas_rules)
    also have "... = (∫x. (?orig x - ?swp i x) * (?hyb i x - ?hyb (i+1) x) ∂?M)/2"
      by (intro arg_cong2[where f="(/)"] Bochner_Integration.integral_cong)
        (auto simp:algebra_simps)
    also have "... ≤ ((∫x. (?orig x - ?swp i x)^2 ∂?M) powr (1/2) *
        (∫x. (?hyb i x - ?hyb (i+1) x)^2 ∂?M) powr (1/2))/2"
      by (intro divide_right_mono cauchy_schwartz meas_rules int_rules) auto
    also have "... = ((∫x. (?orig x - ?swp i x)^2 ∂?M) powr (1/2) *
        (∫x. (?orig x - ?swp i x)^2 ∂?M) powr (1/2))/2"
      by (intro arg_cong2[where f="(/)"] arg_cong2[where f="(*)"] arg_cong2[where f="(powr)"] refl
         switch[of _ _ "(λj. q j < i)"] arg_cong2[where f="power"] arg_cong2[where f="(-)"]
         arg_cong[where f="f"] ext meas_rules) (auto intro:arg_cong)
    also have "... = (∫x. (?orig x - ?swp i x)^2 ∂?M)/2"
      by (simp add:powr_add[symmetric])
    finally show ?thesis by simp
  qed

  (* The False copy and the True copy are independent, hence so are f applied to each of them. *)
  have "indep_vars (M ∘ fst) (λi ω. ω i) (I × UNIV)"
    using assms(2) by (intro proj_indep) auto
  hence 2:"indep_var (PiM (I×{False}) (M∘fst)) (λx. λj∈I×{False}. x j)
    (PiM (I×{True}) (M∘fst)) (λx. λj∈I×{True}. x j)"
    by (intro indep_var_restrict[where I="I × UNIV"]) auto
  have "indep_var
    (PiM I M) ((λx. (λi ∈ I. x (i, False))) ∘ (λx. (λj ∈ I×{False}. x j)))
    (PiM I M) ((λx. (λi ∈ I. x (i, True))) ∘ (λx. (λj ∈ I×{True}. x j)))"
    by (intro indep_var_compose[OF 2] measurable_restrict measurable_PiM_component_rev) auto
  hence "indep_var (PiM I M) (λx. (λj∈I. x (j, False))) (PiM I M) (λx. (λj∈I. x (j, True)))"
    unfolding comp_def by (simp add:restrict_def cong:if_cong)

  hence "indep_var borel (f ∘ (λx. (λj∈I. x (j, False)))) borel (f ∘ (λx. (λj ∈ I. x (j, True))))"
    by (intro indep_var_compose[OF _ assms(4,4)]) auto
  hence indep:"indep_var borel (λx. f (λj∈I. x (j, False))) borel (λx. f (λj∈I. x (j, True)))"
    by (simp add:comp_def)

  (* For points of the doubled product space and i ∈ I, the condition q j = q i can be replaced
     by j = i: q is injective on I, and outside of I the point is undefined. *)
  have 3: "ω (j, q j = q i) = ω (j, j = i)" if
    "ω ∈ PiE (I × UNIV) (λi. space (M (fst i)))" "i ∈ I" for i j ω
  proof (cases "j ∈ I")
    case True
    hence "(q j = q i) = (j = i)"
      using that inj_onD[OF bij_betw_imp_inj_on[OF q]] by blast
    thus ?thesis by simp
  next
    case False
    hence "ω (j, a) = undefined" for a
      using that unfolding PiE_def extensional_def by simp
    thus ?thesis by simp
  qed

  (* Main calculation: express the variance on ?M, merge the product of the two independent
     copies, telescope along ?hyb 0, ..., ?hyb n, bound every term using fact 1 and finally
     reindex the sum along q. *)
  have "?L = (∫x. (f x)^2 ∂PiM I M) - (∫x. (f x) ∂PiM I M)^2"
    by (intro prob_space.variance_eq f_int assms(3) prob)
  also have "... = (∫x. (f x)^2 ∂PiM I M) - (∫x. f x ∂PiM I M) * (∫x. f x ∂PiM I M)"
    by (simp add:power2_eq_square)
  also have "... = (∫x. (?orig x)^2 ∂?M) - (∫x. ?orig x ∂?M) * (∫x. ?copy x ∂?M)"
    by (intro arg_cong2[where f="(-)"] lift  arg_cong2[where f="(*)"] meas_rules f_meas)
  also have "... = (∫x. (?orig x)^2 ∂?M) -
      (∫x. f (λj∈I. x (j,False)) ∂?M) * (∫x. f (λj∈I. x (j,True)) ∂?M)"
    by (intro arg_cong2[where f="(-)"] arg_cong2[where f="(*)"] cong[symmetric] refl)
  also have "... = (∫x. (?orig x)^2 ∂?M) - (∫x. f (λj∈I. x (j,False)) * f (λj∈I. x (j,True)) ∂?M)"
    by (intro arg_cong2[where f="(-)"] indep_var_lebesgue_integral[symmetric] refl int_rules indep)
  also have "... = (∫x. (?orig x) * (?hyb 0 x) ∂?M) - (∫x. (?orig x) * (?hyb n x) ∂?M)"
    using bij_betw_apply[OF q] by (intro arg_cong2[where f="(-)"] arg_cong2[where f="(*)"] ext
        arg_cong[where f="f"] Bochner_Integration.integral_cong)
     (auto simp:space_PiM power2_eq_square PiE_def extensional_def)
  also have "... = (∑i < n. (∫x. (?orig x) * (?hyb i x) ∂?M) -
      (∫x. (?orig x) * (?hyb (Suc i) x) ∂?M))"
    unfolding power2_eq_square by (intro sum_lessThan_telescope'[symmetric])
  also have "... = (∑i < n. (∫x. (?orig x) * (?hyb i x) - (?orig x) * (?hyb (Suc i) x) ∂?M))"
    by (intro sum.cong Bochner_Integration.integral_diff[symmetric] int_rules meas_rules) auto
  also have "... = (∑i < n. (∫x. (?orig x) * (?hyb i x - ?hyb (i+1) x) ∂?M))"
    by (simp_all add:power2_eq_square algebra_simps)
  also have "... ≤ (∑i< n. ((∫x. (?orig x - ?swp i x)^2 ∂?M)) / 2)"
    by (intro sum_mono 1) auto
  also have "... = (∑i ∈ I. ((∫x. (f (λj. x (j,False)) - f (λj. x (j,q j=q i)))^2 ∂?M))/ 2)"
    by (intro sum.reindex_bij_betw[OF q, symmetric])
  also have "... = (∑i ∈ I. ((∫x. (f (λj. x (j,False)) - f (λj. x (j,q j=q i)))^2 ∂?M)))/2"
    unfolding sum_divide_distrib[symmetric] by simp
  also have "... = ?R"
    using inj_onD[OF bij_betw_imp_inj_on[OF q]]
    by (intro arg_cong2[where f="(/)"] arg_cong2[where f="(-)"]  arg_cong2[where f="power"]
          arg_cong[where f="f"] Bochner_Integration.integral_cong sum.cong refl ext 3)
     (auto  simp add:space_PiM )
  finally show ?thesis
    by simp
qed

text ‹Efron-Stein inequality for random variables: the family X indexed by I × UNIV is
independent, and for every i in I the copies X (i, True) and X (i, False) have the same
distribution. The variance of f applied to the False copies is then bounded by half the sum of the
expected squared changes of f when a single coordinate is replaced by its True copy.›

theorem (in prob_space) efron_stein_inequality_classic:
  fixes f :: "_ ⇒ real"
  assumes "finite I"
  assumes "indep_vars (M' ∘ fst) X (I × (UNIV :: bool set))"
  assumes "f ∈ borel_measurable (PiM I M')"
  assumes "integrable M (λω. f (λi∈I. X (i,False) ω)^2)"
  assumes "⋀i. i ∈ I ⟹ distr M (M' i) (X (i, True)) = distr M (M' i) (X (i, False))"
  shows "variance (λω. f (λi∈I. X (i,False) ω)) ≤
    (∑j ∈ I. expectation (λω. (f (λi∈I. X (i,False) ω) - f (λi∈I. X (i, i=j) ω))^2))/2"
    (is "?L ≤ ?R")
proof -
  (* Joint distribution of the False copies. *)
  let ?D = "distr M (PiM I M') (λω. λi∈I. X (i, False) ω)"

  (* Product of the marginals of the False copies, and its doubled version on I × UNIV. *)
  let ?M = "PiM I (λi. distr M (M' i) (X (i,False)))"
  let ?N = "PiM (I × (UNIV::bool set)) ((λi. distr M (M' i) (X (i,False))) ∘ fst)"

  (* Every X (i, j) with i ∈ I is a random variable into M' i. *)
  have rv: "random_variable (M' i) (X (i, j))" if "i ∈ I" for i j
    using assms(2) that unfolding indep_vars_def by auto

  (* Selecting the copy h j for every coordinate j is measurable. *)
  have proj_meas: "(λx j. x (j, h j)) ∈ PiM (I × UNIV) (M' ∘ fst) →⇩M PiM I M'"
    for h :: " _ ⇒ bool"
  proof -
    have "?thesis ⟷ (λx. (λj ∈ I. x (j, h j))) ∈ PiM (I × UNIV) (M' ∘ fst) →⇩M PiM I M'"
      by (intro measurable_cong) (auto simp:space_PiM PiE_def extensional_def)
    also have "... ⟷ True"
      unfolding eq_True
      by (intro measurable_restrict measurable_PiM_component_rev) auto
    finally show ?thesis by simp
  qed

  note meas_rules = borel_measurable_add borel_measurable_times borel_measurable_diff proj_meas
    borel_measurable_power assms(3) measurable_restrict measurable_compose[OF _ assms(3)]

  (* The False copies are independent, so ?D is the product of their marginals. *)
  have "indep_vars ((M' ∘ fst) ∘ (λi. (i, False))) (λi. X (i, False)) I"
    by (intro indep_vars_reindex indep_vars_subset[OF assms(2)] inj_onI) auto
  hence "indep_vars M' (λi. X (i, False)) I" by (simp add: comp_def)
  hence 0:"?D = PiM I (λi. distr M (M' i) (X (i,False)))"
    by (intro iffD1[OF indep_vars_iff_distr_eq_PiM''] rv)

  (* The True copies have the same marginals as the False copies, so ?N is the joint
     distribution of the whole family X. *)
  have "distr M (M' (fst x)) (X (fst x, False)) = distr M (M' (fst x)) (X x)"
    if "x ∈ I × UNIV" for x
    using that assms(5) by (cases x, cases "snd x") auto

  hence 1: "?N = PiM (I × UNIV) (λi. distr M ((M' ∘ fst) i) (X i))"
    using assms(3) by (intro PiM_cong refl) (simp add:comp_def)
  also have "... = distr M (PiM (I × UNIV) (M' ∘ fst)) (λx. λi∈I × UNIV. X i x)"
    using rv by (intro iffD1[OF indep_vars_iff_distr_eq_PiM'', symmetric] assms(2)) auto
  finally have 2:"?N = distr M (PiM (I × UNIV) (M' ∘ fst)) (λx. λi∈I × UNIV. X i x)"
    by simp

  (* Square-integrability of f with respect to the product of the marginals. *)
  have 3: "integrable (PiM I (λi. distr M (M' i) (X (i, False)))) (λx. (f x)⇧2)"
    unfolding 0[symmetric] by (intro iffD2[OF integrable_distr_eq] meas_rules assms rv)

  (* Push the variance forward to ?D, apply efron_stein_inequality_distr there, and pull the
     resulting integrals back along X using 2. *)
  have "?L = (∫x. (f x - expectation (λω. f (λi∈I. X (i,False) ω)))^2 ∂?D)"
    using rv by (intro integral_distr[symmetric] meas_rules measurable_restrict) auto
  also have "... = prob_space.variance ?D f"
    by (intro arg_cong[where f="integral⇧L ?D"] arg_cong2[where f="(-)"] arg_cong2[where f="power"]
        refl ext integral_distr[symmetric] measurable_restrict rv assms(3))
  also have "... = prob_space.variance ?M f"
    unfolding 0 by simp
  also have "... ≤ (∑i∈I. (∫x. (f (λj. x (j, False)) - f (λj. x (j, j = i)))^2 ∂?N)) / 2"
    using assms(3) by (intro efron_stein_inequality_distr prob_space_distr rv assms(1) 3) auto
  also have "... = (∑i∈I. expectation (λω. (f (λj. (λi∈I×UNIV. X i ω) (j, False)) -
    f (λj. (λi∈I×UNIV. X i ω) (j, j=i)))⇧2)) / 2"
    using rv unfolding 2
    by (intro sum.cong arg_cong2[where f="(/)"] integral_distr refl meas_rules) auto
  also have "... = ?R"
    by (simp add:restrict_def)
  finally show ?thesis
    by simp
qed

end