(* Theory CVM_Original_Algorithm *)

section ‹The Original CVM Algorithm\label{sec:cvm_original}›

text ‹In this section, we verify the algorithm as presented by Chakraborty et
al.~\cite{chakraborty2022} (replicated, here, in Algorithm~\ref{alg:cvm_classic}),
with the following caveat:

In the original algorithm the elements are kept with probability $f := \frac{1}{2}$ in the
subsampling step. The version verified here allows for any $f \in [\frac{1}{2},e^{-1/12}]$.

\begin{algorithm}[h!]
	\caption{Original CVM algorithm.\label{alg:cvm_classic}}
	\begin{algorithmic}[1]
  \Require Stream elements $a_1,\ldots,a_l$, $0 < \varepsilon$, $0 < \delta < 1$, $f$ subsampling param.
  \Ensure An estimate $R$, s.t., $\prob \left( | R - |A| | > \varepsilon |A| \right) \leq \delta$ where $A := \{a_1,\ldots,a_l\}.$
  \State $\chi \gets \{\}, p \gets 1, n \geq \left\lceil \frac{12}{\varepsilon^2} \ln(\frac{6l}{\delta}) \right\rceil$
  \For{$i \gets 1$ to $l$}
    \State $b \getsr \Ber(p)$ \Comment insert $a_i$ with probability $p$ (and remove it otherwise)
    \If{$b$}
      \State $\chi \gets \chi \cup \{a_i\}$
    \Else
      \State $\chi \gets \chi - \{a_i\}$
    \EndIf
    \If{$|\chi| = n$}
      \State $\chi \getsr \mathrm{subsample}(\chi)$ \Comment keep each element of $\chi$ indep. with prob. $f$
      \State $p \gets p f$
    \EndIf
    \If{$|\chi| = n$}
      \State \Return $\bot$
    \EndIf
  \EndFor
  \State \Return $\frac{|\chi|}{p}$ \Comment estimate cardinality of $A$
\end{algorithmic}
\end{algorithm}

The first step of the proof is identical to the original proof~\cite{chakraborty2022}, where the
above algorithm is approximated by a second algorithm, where lines 11--12 are removed, i.e., the two
algorithms behave identically, unless the very improbable event---where the subsampling step fails
to remove any elements---occurs. It is possible to show that the total variation distance between
the two algorithms is at most $\frac{\delta}{2}$.

In the second step, we verify that the probability that the second algorithm returns an estimate
outside of the desired interval is also at most $\frac{\delta}{2}$. This, of course, works by
noticing that it is an instance of the abstract algorithm we introduced in
Section~\ref{sec:cvm_abs}. In combination, we conclude a failure probability of $\delta$ for the
unmodified version of the algorithm.

On the other hand, the fact that the number of elements in the buffer is at most $n$ can be seen
directly from Algorithm~\ref{alg:cvm_classic}.›

theory CVM_Original_Algorithm
  imports CVM_Abstract_Algorithm
begin

(* Parameters shared by every definition and lemma up to the closing "end (* context *)":
     f - probability with which the subsampling step keeps an element,
     n - buffer size that triggers subsampling (step_2) and failure (step_3). *)
context
  fixes f :: real
  fixes n :: nat
  assumes f_range: ‹f ∈ {1/2..exp(-1/12)}›
  assumes n_gt_0: ‹n > 0›
begin

text ‹Line 1:›
(* Starting state: empty buffer, sampling probability 1. *)
definition initial_state :: ‹'a state› where
  ‹initial_state = State {} 1›

text ‹Lines 3--7:›
(* Processes the stream element a: with probability p it is inserted into the
   buffer, otherwise it is removed from the buffer. p itself is unchanged. *)
fun step_1 :: ‹'a ⇒ 'a state ⇒ 'a state spmf› where
  ‹step_1 a (State χ p) =
    do {
      b ← bernoulli_pmf p;
      let χ = (if b then χ ∪ {a} else χ - {a});

      return_spmf (State χ p)
    }›

(* Subsampling: draws one Bernoulli(f) coin per element of χ (as a product
   distribution indexed by χ) and keeps exactly the elements whose coin is True. *)
definition subsample :: ‹'a set ⇒ 'a set spmf› where
  ‹subsample χ =
    do {
      keep_in_χ ← prod_pmf χ (λ_. bernoulli_pmf f);
      return_spmf (Set.filter keep_in_χ χ)
    }›

text ‹Lines 8--10:›
(* If the buffer holds exactly n elements, it is subsampled and p is multiplied
   by f; otherwise the state is returned unchanged. *)
fun step_2 :: ‹'a state ⇒ 'a state spmf› where
  ‹step_2 (State χ p) =
    do {
      if card χ = n then do {
        χ ← subsample χ;
        return_spmf (State χ (p * f))
      } else
        return_spmf (State χ p)
    }›

text ‹Lines 11--12:›
(* Fails (fail_spmf, the bottom result of the algorithm) if the buffer has
   exactly n elements; otherwise passes the state through unchanged. *)
fun step_3 :: ‹'a state ⇒ 'a state spmf› where
  ‹step_3 (State χ p) =
    do {
      if card χ = n
      then fail_spmf
      else return_spmf (State χ p)
    }›

text ‹Lines 1--12:›
(* Main loop: starting from initial_state, folds step_1, step_2, step_3 (in
   this order) over the stream xs. *)
definition run_steps :: ‹'a list ⇒ 'a state spmf› where
  ‹run_steps xs ≡ foldM_spmf (λx σ. step_1 x σ ⤜ step_2 ⤜ step_3) xs initial_state›

text ‹Line 13:›
(* Cardinality estimate read off a state: buffer size divided by the
   sampling probability. *)
definition estimate :: ‹'a state ⇒ real› where
  ‹estimate σ = card (state_χ σ) / state_p σ›

(* Complete algorithm: runs the main loop and maps the final state to its estimate. *)
definition run_algo :: ‹'a list ⇒ real spmf› where
  ‹run_algo xs = map_spmf estimate (run_steps xs)›


(* Unfolding equation for step_1 on an arbitrary state σ (instead of the
   constructor pattern used in step_1.simps): σ is first rewritten with
   state.collapse, then step_1.simps applies; the right-hand side ?x is
   synthesised by the proof. *)
schematic_goal step_1_m_def: ‹step_1 x σ = ?x›
  by (subst state.collapse[symmetric], subst step_1.simps, rule refl)

(* Unfolding equation for step_2 on an arbitrary state σ, obtained from
   step_2.simps after rewriting σ with state.collapse. *)
schematic_goal step_2_m_def: ‹step_2 σ = ?x›
  by (subst state.collapse[symmetric], subst step_2.simps, rule refl)

(* Unfolding equation for step_3 on an arbitrary state σ, obtained from
   step_3.simps after rewriting σ with state.collapse. *)
schematic_goal step_3_m_def: ‹step_3 σ = ?x›
  by (subst state.collapse[symmetric], subst step_3.simps, rule refl)

(* Dropping step_3 from a single loop iteration can only increase the result
   in the ord_spmf (=) order: step_3 either fails or returns its input. *)
lemma ord_spmf_remove_step3:
  ‹ord_spmf (=) (step_1 x σ ⤜ step_2 ⤜ step_3) (step_1 x σ ⤜ step_2)›
proof -
  have ‹ord_spmf (=) (step_2 x ⤜ step_3) (step_2 x)› for x :: ‹'a state›
  proof -
    (* step_3 is pointwise below return_spmf; binding with return_spmf is the identity. *)
    have ‹ord_spmf (=) (step_2 x ⤜ step_3) (step_2 x ⤜ return_spmf)›
      by (intro bind_spmf_mono') (simp_all add:step_3_m_def)
    thus ?thesis by simp
  qed
  thus ?thesis unfolding bind_spmf_assoc by (intro bind_spmf_mono') simp_all
qed

(* Lifts ord_spmf_remove_step3 to the whole loop: run_steps is below the loop
   without step_3. Proved by induction on the stream, appending at the end. *)
lemma ord_spmf_run_steps:
  ‹ord_spmf (=) (run_steps xs) (foldM_spmf (λx σ. step_1 x σ ⤜ step_2) xs initial_state)›
  unfolding run_steps_def
proof (induction xs rule:rev_induct)
  case Nil
  then show ?case by simp
next
  case (snoc x xs)
  show ?case
    unfolding run_steps_def foldM_spmf_snoc
    by (intro ord_spmf_remove_step3 bind_spmf_mono' snoc)
qed

(* Weaker but more convenient range for f: since exp(-1/12) < 1, the
   assumption f_range implies f is in the half-open interval [1/2, 1). *)
lemma f_range_simple: ‹f ∈ {1/2..<1}›
proof -
  have ‹exp (- 1 / 12) < (1::real)› by (approximation 5)
  from dual_order.strict_trans2[OF this]
  show ?thesis using f_range by auto
qed

text ‹Main result:›

(* Correctness of the original CVM algorithm: for ε, δ in (0,1) and a buffer
   size n ≥ 12/ε² · ln(6·l/δ), where l is the stream length, the probability that
   run_algo fails or returns an estimate R with |R - A| > ε·A is at most δ
   (A = number of distinct stream elements).

   Proof outline:
     1. The abstract algorithm locale is interpreted with abs_subsample.
     2. A single iteration fails with probability at most α = f^n, hence
        run_steps fails with probability at most l·α, which is ≤ δ/2.
     3. run_steps is below (ord_spmf) the loop without step_3, which coincides
        with the abstract algorithm; abs.correctness bounds its error by δ/2. *)
theorem correctness:
  fixes xs :: ‹'a list›
  assumes ‹ε ∈ {0<..<1}› ‹δ ∈ {0<..<1}›
  assumes ‹real n ≥ 12 / ε⇧2 * ln (6 * real (length xs) / δ)›
  defines ‹A ≡ real (card (set xs))›
  shows ‹𝒫(ω in run_algo xs. fails_or_satisfies (λR. ¦R - A¦ > ε * A) ω) ≤ δ›
    (is ‹?L ≤ ?R›)
proof -
  (* Total (pmf) version of subsample, as used to instantiate the abstract algorithm. *)
  define abs_subsample where
    ‹abs_subsample χ = map_pmf (λω. Set.filter ω χ) (prod_pmf χ (λ_. bernoulli_pmf f))›
    for χ :: ‹'a set›

  interpret abs:cvm_algo_abstract n f abs_subsample
    rewrites ‹abs.estimate = estimate›
  proof -
    show abs:‹cvm_algo_abstract n f abs_subsample›
    proof (unfold_locales, goal_cases)
      case 1 thus ?case by (rule n_gt_0)
    next
      case 2 thus ?case using f_range_simple by auto
    next
      case (3 U x)
      then show ?case unfolding abs_subsample_def by auto
    next
      (* Product condition on abs_subsample (used later as abs.subsample_inequality);
         here it is established with equality via independence of the coins. *)
      case (4 g χ S)
      hence fin_U: ‹finite χ› using n_gt_0 card_gt_0_iff by metis
      note conv = Pi_pmf_subset[OF this 4(1)]

      have ‹(∫ω. (∏s∈S. g (s ∈ ω)) ∂abs_subsample χ) =
        (∫ω. (∏s∈S. g (s ∈ χ ∧ ω s)) ∂prod_pmf χ (λ_. bernoulli_pmf f))›
        unfolding abs_subsample_def by (simp cong:prod.cong)
      also have ‹… = (∫ω. (∏s∈S. g (s ∈ χ ∧ ω s)) ∂prod_pmf S (λ_. bernoulli_pmf f))›
        unfolding conv by simp
      also have ‹… = (∏s∈S. (∫ω. g (s ∈ χ ∧ ω) ∂bernoulli_pmf f))›
        using fin_U finite_subset[OF 4(1)]
        by (intro expectation_prod_Pi_pmf integrable_measure_pmf_finite) auto
      also have ‹… = (∏s∈S. (∫ω. g ω ∂bernoulli_pmf f))›
        using 4(1) by (intro prod.cong refl) auto
      finally show ?case by simp
    qed
    show ‹cvm_algo_abstract.estimate = (estimate :: 'a state ⇒ real)›
      unfolding cvm_algo_abstract.estimate_def[OF abs] estimate_def by simp
  qed

  (* a-d: the concrete steps coincide with the embedded abstract ones. *)
  have a: ‹step_1 σ x = spmf_of_pmf (abs.step_1 σ x)› for σ x
    unfolding step_1_m_def abs.step_1_def Let_def spmf_of_pmf_def by (simp add:map_bind_pmf)

  have b: ‹step_2 σ = map_pmf Some (abs.step_2 σ)› for σ
    unfolding step_2_m_def abs.step_2_def subsample_def abs_subsample_def Let_def
    by (simp add:map_bind_pmf bind_pmf_return_spmf)

  have c: ‹abs.initial_state = initial_state›
    unfolding initial_state_def abs.initial_state_def by simp

  have d: ‹subsample χ = spmf_of_pmf (abs_subsample χ)› for χ
    unfolding subsample_def abs_subsample_def map_pmf_def[symmetric]
    by (simp add:spmf_of_pmf_def map_pmf_comp)

  (* α bounds the probability that subsampling a full buffer removes nothing. *)
  define α :: real where ‹α = f ^ n›

  have α_range: ‹α ∈ {0..1}›
    using f_range_simple unfolding α_def by (auto intro:power_le_one)
  hence [simp]: ‹¦α¦ ≤ 1› by auto

  have ‹(∫x. (if card x = n then 1 else 0) ∂abs_subsample χ) ≤ α› (is ‹?L1 ≤ _›)
    if that': ‹card χ = n› for χ
  proof -
    have fin_U: ‹finite χ› using n_gt_0 that card_gt_0_iff by metis

    (* A subsample of χ still has n elements iff every element was kept. *)
    have ‹(∏s∈χ. of_bool (s ∈ x)::real) = of_bool(card x = n)›
      if ‹x ∈ set_pmf (abs_subsample χ)› for x
    proof -
      have x_ran: ‹x ⊆ χ› using that unfolding abs_subsample_def by auto

      have ‹(∏s∈χ. of_bool (s ∈ x)::real) = of_bool(x = χ)›
        using fin_U x_ran by (induction χ rule:finite_induct) auto
      also have ‹… = of_bool (card x = card χ)›
        using x_ran fin_U card_subset_eq by (intro arg_cong[where f=‹of_bool›]) blast
      also have ‹… = of_bool (card x = n)› using that' by simp
      finally show ?thesis by auto
    qed
    hence ‹?L1 = (∫x. (∏s ∈ χ. of_bool(s ∈ x)) ∂abs_subsample χ)›
      by (intro integral_cong_AE AE_pmfI) simp_all
    also have ‹… ≤ (∏s ∈ χ. (∫x. of_bool x ∂bernoulli_pmf f))›
      by (intro abs.subsample_inequality that) auto
    also have ‹… = f ^ card χ› using f_range_simple by simp
    also have ‹… = α› unfolding α_def that by simp
    finally show ?thesis by simp
  qed
  (* e: step_2 followed by step_3 fails with probability at most α. *)
  hence e: ‹pmf (step_2 σ ⤜ step_3) None ≤ α› for σ :: ‹'a state›
    using α_range unfolding step_2_m_def step_3_m_def d Let_def
    by (simp add:pmf_bind bind_pmf_return_spmf if_distrib if_distribR cong:if_cong)

  (* The same bound for one complete loop iteration; step_1 itself never fails. *)
  have ‹pmf (step_1 x σ ⤜ step_2 ⤜ step_3) None ≤ α› for σ and x :: 'a
  proof-
    have ‹pmf (step_1 x σ ⤜ step_2 ⤜ step_3) None ≤ 0 + (∫_. α ∂ measure_spmf (step_1 x σ))›
      unfolding bind_spmf_assoc pmf_bind_spmf_None[where p=‹step_1 x σ›]
      by (intro add_mono integral_mono_AE measure_spmf.integrable_const_bound[where B=‹1›]
          iffD2[OF AE_measure_spmf_iff] ballI e)
         (simp_all add:pmf_le_1 step_1_m_def map_pmf_def[symmetric] pmf_map vimage_def Let_def)
    also have ‹… ≤ α› using α_range by (simp add: mult_left_le_one_le weight_spmf_le_1)
    finally show ?thesis by simp
  qed
  (* Union bound over the iterations of the loop. *)
  hence ‹prob_fail (run_steps xs) ≤ length xs * α›
    unfolding run_steps_def by (intro prob_fail_foldM_spmf_le[where P=‹λ_. True›]) auto
  also have ‹… ≤ δ / 2›
  proof (cases ‹xs = []›)
    case True
    thus ?thesis using assms(2) by auto
  next
    case False
    have ‹δ ≤ 6 * 1› using assms(2) by simp
    also have ‹… ≤ 6 * real (length xs)›
      using False by (intro mult_mono order.refl) (cases xs, auto)
    finally have [simp]: ‹δ ≤ 6 * real (length xs)› by simp
    (* Chain: f^n ≤ exp(-n/12), then the lower bound on n, then ε < 1. *)
    have ‹2 * real (length xs) * f ^ n ≤ 2 * real (length xs) * exp (-1/12)^n›
      using f_range by (intro mult_left_mono power_mono) auto
    also have ‹… = 2 * real (length xs) * exp (-real n / 12)›
      unfolding exp_of_nat_mult[symmetric] by simp
    also have ‹… ≤ 2 * real (length xs) * exp (-(12 / ε ^ 2 * ln (6 * real (length xs) / δ))/12)›
      using assms(3) by (intro mult_left_mono iffD2[OF exp_le_cancel_iff] divide_right_mono) auto
    also have ‹… = 2 * real (length xs) * exp (-ln (6 * real (length xs) / δ) / ε^2 )›
      by auto
    also have ‹… ≤ 2 * real (length xs) * exp (-ln (6 * real (length xs) / δ) / 1 )›
      using assms(1,2) False
      by (intro mult_left_mono iffD2[OF exp_le_cancel_iff] divide_left_mono_neg power_le_one)
        (auto intro!:ln_ge_zero simp:divide_simps)
    also have ‹… = 2 * real (length xs) * exp (ln (inverse (6 * real (length xs) / δ)))›
      using False assms(2) by (subst ln_inverse[symmetric]) auto
    also have ‹… = 2 * real (length xs) / (6 * real (length xs) / δ)›
      using assms(1,2) False by (subst exp_ln) auto
    also have ‹… = δ / 3› using False assms(2) by auto
    also have ‹… ≤ δ› using assms(2) by auto
    finally have ‹2 * real (length xs) * f^n ≤ δ› by simp
    thus ?thesis unfolding α_def by simp
  qed
  finally have f:‹prob_fail (run_steps xs) ≤ δ / 2› by simp

  (* g: the loop without step_3 is exactly the (embedded) abstract algorithm. *)
  have g:‹spmf_of_pmf (abs.run_steps xs) = foldM_spmf (λx σ. step_1 x σ ⤜ step_2) xs initial_state›
    unfolding abs.run_steps_def foldM_spmf_of_pmf_eq(2)[symmetric]
    unfolding spmf_of_pmf_def map_pmf_def c b a
    by (simp add:bind_assoc_pmf bind_spmf_def bind_return_pmf)

  (* Split the bad event into "failure" and "bad estimate", then bound each by δ/2. *)
  have ‹?L ≤ measure (run_steps xs) {None} +
    measure (measure_spmf (run_steps xs)) {x. ¦estimate x - A¦ > ε * A}›
    unfolding run_algo_def measure_measure_spmf_conv_measure_pmf measure_map_pmf
    by (intro pmf_add) (auto split:option.split_asm)
  also have ‹… ≤ δ / 2 + measure (measure_spmf (run_steps xs)) {x. ¦estimate x - A¦ > ε * A}›
    unfolding measure_pmf_single by (intro add_mono f order.refl)
  also have ‹… ≤ δ/2+measure(measure_spmf (spmf_of_pmf (abs.run_steps xs))) {x. ¦estimate x-A¦>ε*A}›
    using ord_spmf_eqD_emeasure[OF ord_spmf_run_steps] unfolding measure_spmf.emeasure_eq_measure g
    by (intro add_mono) auto
  also have ‹… ≤ δ / 2 + measure (abs.run_steps xs) {x. ¦estimate x - A¦ > ε * A}›
    using measure_spmf_map_pmf_Some spmf_of_pmf_def by auto
  also have ‹… ≤ δ / 2 + δ / 2›
    using assms(1-3) unfolding A_def by (intro add_mono abs.correctness) auto
  finally show ?thesis by simp
qed

(* Space bound: in every reachable (non-failing) state after run_steps the
   buffer is finite and holds fewer than n elements. Proved by induction on the
   stream, tracking the bound through the three steps of the last iteration. *)
lemma space_usage:
  ‹AE σ in measure_spmf (run_steps xs). card (state_χ σ) < n ∧ finite (state_χ σ)›
proof (induction xs rule:rev_induct)
  case Nil thus ?case using n_gt_0 by (simp add:run_steps_def initial_state_def)
next
  case (snoc x xs)
  (* Distributions after step_1, step_2 and step_3 of the final iteration. *)
  define p1 where ‹p1 = run_steps xs ⤜ step_1 x›
  define p2 where ‹p2 = p1 ⤜ step_2›
  define p3 where ‹p3 = p2 ⤜ step_3›

  have a:‹run_steps (xs@[x]) = p3›
    unfolding run_steps_def p1_def p2_def p3_def foldM_spmf_snoc by (simp add:bind_assoc_pmf)

  (* step_1 inserts at most one element: < n becomes ≤ n. *)
  have ‹card (state_χ σ) ≤ n ∧ finite (state_χ σ)› if ‹σ ∈ set_spmf p1› for σ
    using snoc that less_imp_le unfolding p1_def
    by (auto simp: step_1_m_def set_bind_spmf set_spmf_bind_pmf Let_def card_insert_if)+

  (* step_2 only filters the buffer, so the bound is preserved. *)
  hence ‹card (state_χ σ) ≤ n ∧ finite (state_χ σ)› if ‹σ ∈ set_spmf p2› for σ
    using that card_filter_mono unfolding p2_def
    by (auto intro!:card_filter_mono simp:step_2_m_def set_bind_spmf set_spmf_bind_pmf
        subsample_def Let_def if_distrib)

  (* step_3 fails when the buffer has exactly n elements, leaving only < n. *)
  hence ‹card (state_χ σ) < n ∧ finite (state_χ σ)› if ‹σ ∈ set_spmf p3› for σ
    using that unfolding p3_def
    by (auto intro:le_neq_implies_less simp:step_3_m_def set_bind_spmf if_distrib)

  thus ?case unfolding a by simp
qed

end (* context *)

end (* theory *)