(* Theory Partial_Trace *)

section ‹‹Partial_Trace› -- The partial trace›

theory Partial_Trace
  imports Trace_Class Hilbert_Space_Tensor_Product
begin

(* Hide the unqualified names (the "(open)" variant) of some facts and constants from
   other theories; their fully qualified names remain available.
   NOTE(review): presumably this avoids clashes with the summability facts and with the
   trace-class trace used throughout this theory -- confirm. *)
hide_fact (open) Infinite_Set_Sum.abs_summable_on_Sigma_iff
hide_fact (open) Infinite_Set_Sum.abs_summable_on_comparison_test
hide_const (open) Determinants.trace
hide_fact (open) Determinants.trace_def

(* Partial trace over the second tensor factor 'c.
   For every index j, the trace-class operator t is composed on the left with the adjoint
   of tensor_ell2_right (ket j) (compose_tcr) and on the right with
   tensor_ell2_right (ket j) (compose_tcl).  The resulting operators on the first factor
   are added up in an infinite sum over all j.  Convergence of this sum is proved in
   partial_trace_abs_summable / partial_trace_has_sum. *)
definition partial_trace :: (('a × 'c) ell2, ('b × 'c) ell2) trace_class  ('a ell2, 'b ell2) trace_class where
  partial_trace t = (j. compose_tcl (compose_tcr ((tensor_ell2_right (ket j))*) t) (tensor_ell2_right (ket j)))

(* Equivalent description of partial_trace using sandwich_tc.  Proved by unfolding
   partial_trace_def and sandwich_tc_def. *)
lemma partial_trace_def': partial_trace t = (j. sandwich_tc ((tensor_ell2_right (ket j))*) t)
― ‹We cannot use this as the definition of partial_trace because this definition
      has a more restricted type (t must be a square operator).›
  by (auto intro!: simp: partial_trace_def sandwich_tc_def)

(* Convergence and norm bound for the partial trace. Three facts are proved together:
   - partial_trace_abs_summable: the summands of partial_trace_def are absolutely
     summable over all indices j;
   - partial_trace_has_sum: they sum to partial_trace t;
   - partial_trace_norm_reducing: norm (partial_trace t) is at most norm t.
   All three rest on the local fact bound: for every finite index set F, the sum of the
   norms of the summands indexed by F is at most norm t. *)
lemma partial_trace_abs_summable:
  (λj. compose_tcl (compose_tcr ((tensor_ell2_right (ket j))*) t) (tensor_ell2_right (ket j))) abs_summable_on UNIV
  and partial_trace_has_sum:
  ((λj. compose_tcl (compose_tcr ((tensor_ell2_right (ket j))*) t) (tensor_ell2_right (ket j))) has_sum partial_trace t) UNIV
  and partial_trace_norm_reducing: norm (partial_trace t)  norm t
proof -
  define t' where t' = from_trace_class t
  define s where s k = compose_tcl (compose_tcr ((tensor_ell2_right (ket k))*) t) (tensor_ell2_right (ket k)) for k

  (* Finite partial sums of norm (s k) are bounded by norm t.
     Steps:
     - Write each trace norm as a trace against uk k, the adjoint of the polar
       decomposition of the k-th summand tk k.
     - Glue these into one operator u (a sum of uk k tensor the diagonal butterfly at
       ket k), which has norm at most 1.
     - Bound the result by trace_norm t'. *)
  have bound: (kF. norm (s k))  norm t
    if  F  {F. F  UNIV  finite F}
    for F :: 'a set
  proof -
    from that have [simp]: finite F
      by force
    define tk where tk k = tensor_ell2_right (ket k)* oCL t' oCL tensor_ell2_right (ket k) for k
    have tc_t'[simp]: trace_class t'
      by (simp add: t'_def)
    then have tc_tk[simp]: trace_class (tk k) for k
      by (simp add: tk_def trace_class_comp_left trace_class_comp_right)
    define uk where uk k = (polar_decomposition (tk k))* for k
    define u where u = (kF. uk k o butterfly (ket k) (ket k))
    define B :: 'b ell2 set where B = range ket

    (* aux1: for an index x outside F, the x-th second-factor component of u a is zero,
       because u only has summands indexed by F. *)
    have aux1: tensor_ell2_right (ket x)* *V u *V a = 0 if x  F for x a
    proof -
      have *: u* oCL tensor_ell2_right (ket x) = 0
        by (auto intro!: equal_ket simp: u_def sum_adj tensor_op_adjoint tensor_ell2_right_apply
            cblinfun.sum_left tensor_op_ell2 cinner_ket sum_single[where i=x] x  F)
      have tensor_ell2_right (ket x)* oCL u = 0
        by (rule adj_inject[THEN iffD1]) (use * in simp)
      then show ?thesis
        by (simp flip: cblinfun_apply_cblinfun_compose)
    qed

    (* aux2: for an index x in F, taking the x-th second-factor component after applying
       u is the same as applying uk x after taking that component. *)
    have aux2: uk x *V tensor_ell2_right (ket x)* *V a = tensor_ell2_right (ket x)* *V u *V a if x  F for x a
    proof - 
      have *: tensor_ell2_right (ket x) oCL (uk x)* = u* oCL tensor_ell2_right (ket x)
        by (auto intro!: equal_ket simp: u_def sum_adj tensor_op_adjoint tensor_ell2_right_apply
            cblinfun.sum_left tensor_op_ell2 x  F cinner_ket sum_single[where i=x])
      have uk x oCL tensor_ell2_right (ket x)* = tensor_ell2_right (ket x)* oCL u
        by (rule adj_inject[THEN iffD1]) (use * in simp)
      then show ?thesis
        by (simp flip: cblinfun_apply_cblinfun_compose)
    qed

    (* sum1: the diagonal entries of u oCL t' in the product ket basis are summable,
       since u oCL t' is trace class.  The index pair is swapped via prod.swap to match
       the iterated sum used in the calculation below. *)
    have sum1: (λ(x, y). ket (y, x) C (u *V t' *V ket (y, x))) summable_on UNIV
    proof -
      have trace_class (u oCL t')
        by (simp add: trace_class_comp_right)
      then have (λyx. yx C ((u oCL t') *V yx)) summable_on (range ket)
        using is_onb_ket trace_exists by blast
      then have (λyx. ket yx C ((u oCL t') *V ket yx)) summable_on UNIV
        apply (subst summable_on_reindex_bij_betw[where g=ket and A=UNIV and B=range ket])
         using bij_betw_def inj_ket by blast
      then show ?thesis
        by (subst summable_on_reindex_bij_betw[where g=prod.swap and A=UNIV, symmetric]) auto
    qed

    (* norm_u: the adjoint of u composed with u is a projection, so norm u is at most 1.
       Each uk k is the adjoint of a partial isometry, so the adjoint of uk k composed
       with uk k is a projection. *)
    have norm_u: norm u  1
    proof -
      define u2 uk2 where u2 = u* oCL u and uk2 k = (uk k)* oCL uk k for k (* and ‹u4 = u2* oCL u2› *)
      have *: (iF. (uk i* oCL uk k) o (ket i C ket k) *C butterfly (ket i) (ket k))
           = (uk k* oCL uk k) o butterfly (ket k) (ket k) if [simp]: k  F for k
        apply (subst sum_single[where i=k])
        by (auto simp: cinner_ket)
      have **: (kaF. (uk2 ka oCL uk2 k) o (ket ka C ket k) *C butterfly (ket ka) (ket k))
           = (uk2 k oCL uk2 k) o butterfly (ket k) (ket k) if [simp]: k  F for k
        apply (subst sum_single[where i=k])
        by (auto simp: cinner_ket)
      have proj_uk2: is_Proj (uk2 k) for k
        unfolding uk2_def
        apply (rule partial_isometry_square_proj)
        by (auto intro!: partial_isometry_square_proj partial_isometry_adj simp: uk_def)
      have u2_explicit: u2 = (kF. uk2 k o butterfly (ket k) (ket k))
        by (simp add: u2_def u_def sum_adj tensor_op_adjoint cblinfun_compose_sum_right 
            cblinfun_compose_sum_left tensor_butter comp_tensor_op * uk2_def)
      have u2* = u2
        by (simp add: u2_def)
      moreover have u2 oCL u2 = u2
        by (simp add: u2_explicit cblinfun_compose_sum_right cblinfun_compose_sum_left
            comp_tensor_op ** proj_uk2 is_Proj_idempotent)
      ultimately have is_Proj u2
        by (simp add: is_Proj_I)
      then have norm u2  1
        using norm_is_Proj by blast
      then show norm u  1
        by (simp add: power_le_one_iff norm_AAadj u2_def)
    qed

    (* Calculation:
       1. The sum of trace norms equals the modulus of the sum over F of
          trace (uk k oCL tk k), by polar decomposition.
       2. Using aux2, aux1 and sum1, this becomes the modulus of trace (u oCL t').
       3. That is at most norm u * trace_norm t', which is at most norm t. *)
    have (kF. norm (s k))
      = (kF. trace_norm (tk k))
      by (simp add: s_def tk_def norm_trace_class.rep_eq compose_tcl.rep_eq compose_tcr.rep_eq t'_def)
    also have  = cmod (kF. trace (uk k oCL tk k))
      by (smt (verit, best) norm_of_real of_real_hom.hom_sum polar_decomposition_correct' sum.cong sum_nonneg trace_abs_op trace_norm_nneg uk_def)
    also have  = cmod (kF. trace (tensor_ell2_right (ket k)* oCL u oCL t' oCL tensor_ell2_right (ket k)))
      apply (rule arg_cong[where f=cmod], rule sum.cong[OF refl], rule arg_cong[where f=trace])
      by (auto intro!: equal_ket simp: tk_def aux2)
    also have  = cmod (kF. j. ket j C ((tensor_ell2_right (ket k)* oCL u oCL t' oCL tensor_ell2_right (ket k)) *V ket j))
      by (auto intro!: sum.cong simp: trace_ket_sum trace_class_comp_left trace_class_comp_right)
    also have  = cmod (kF. j. ket j C ((tensor_ell2_right (ket k)* oCL u oCL t' oCL tensor_ell2_right (ket k)) *V ket j))
      by (simp add: finite F)
    also have  = cmod (k. j. ket j C ((tensor_ell2_right (ket k)* oCL u oCL t' oCL tensor_ell2_right (ket k)) *V ket j))
      apply (rule arg_cong[where f=cmod])
      apply (rule infsum_cong_neutral)
      by (auto simp: aux1)
    also have  = cmod (k. j. ket (j,k) C ((u oCL t') *V ket (j,k)))
      apply (rule arg_cong[where f=cmod], rule infsum_cong, rule infsum_cong)
      by (simp add: tensor_ell2_right_apply cinner_adj_right tensor_ell2_ket)
    also have  = cmod ((k,j). ket (j,k) C ((u oCL t') *V ket (j,k)))
      apply (rule arg_cong[where f=cmod])
      apply (subst infsum_Sigma'_banach)
      using sum1 by auto
    also have  = cmod (jk. ket jk C ((u oCL t') *V ket jk))
      apply (subst infsum_reindex_bij_betw[where g=prod.swap and A=UNIV, symmetric])
      by auto
    also have  = cmod (trace (u oCL t'))
      by (simp add: trace_ket_sum trace_class_comp_right)
    also have   trace_norm (u oCL t')
      using trace_leq_trace_norm by blast
    also have   norm u * trace_norm t'
      by (simp add: trace_norm_comp_right)
    also have   trace_norm t'
      using norm_u
      by (metis more_arith_simps(5) mult_right_mono trace_norm_nneg)
    also have  = norm t
      by (simp add: norm_trace_class.rep_eq t'_def)
    finally show (kF. norm (s k))  norm t
      by -
  qed

  (* Nonnegative terms whose finite partial sums are bounded by norm t are summable. *)
  show abs_summable: s abs_summable_on UNIV
    by (intro nonneg_bdd_above_summable_on bdd_aboveI2[where M=norm t] norm_ge_zero bound)

  (* An absolutely summable family has a sum; by partial_trace_def that sum is
     partial_trace t. *)
  from abs_summable
  show has_sum: (s has_sum partial_trace t) UNIV
    by (simp add: abs_summable_summable partial_trace_def s_def[abs_def] t'_def)

  (* The norm of the sum is at most the sum of the norms, which bound bounds by norm t. *)
  show norm (partial_trace t)  norm t
  proof -
    have norm (partial_trace t)  (k. norm (s k))
      using _ has_sum apply (rule norm_has_sum_bound)
      using abs_summable has_sum_infsum by blast
    also from bound have (k. norm (s k))  norm t
      by (simp add: abs_summable infsum_le_finite_sums)
    finally show ?thesis
      by -
  qed
qed

(* sandwich_tc versions of partial_trace_abs_summable and partial_trace_has_sum.
   Obtained from those facts by unfolding sandwich_tc_def and sandwich_apply. *)
lemma partial_trace_abs_summable':
  (λj.  sandwich_tc ((tensor_ell2_right (ket j))*) t) abs_summable_on UNIV
  and partial_trace_has_sum':
  ((λj.  sandwich_tc ((tensor_ell2_right (ket j))*) t) has_sum partial_trace t) UNIV
  using partial_trace_abs_summable partial_trace_has_sum
  by (auto intro!: simp: sandwich_tc_def sandwich_apply)

(* definition partial_trace' where ‹partial_trace' t = (if trace_class t then from_trace_class (partial_trace (Abs_trace_class t)) else 0)›

lemma partial_trace_transfer[transfer_rule]: 
  includes lifting_syntax
  shows ‹(cr_trace_class ===> cr_trace_class) partial_trace' partial_trace›
  by (auto intro!: rel_funI simp: cr_trace_class_def partial_trace'_def from_trace_class_inverse) *)


(* Trace duality for the partial trace: composing the partial trace of t with x and
   taking the trace gives the same value as composing t with (x tensor id_cblinfun)
   and taking the trace.
   Proof outline:
   1. Expand the left-hand trace over the ket basis (values s' e).
   2. Push the infinite sum from partial_trace_has_sum through the continuous
      additive maps "apply to x (ket e)" and "inner product with ket e" (values u' e j).
   3. Merge the iterated sum over e and j into one sum over pairs (e, j), which is the
      trace of the right-hand side in the product ket basis (u'_tensor). *)
lemma trace_partial_trace_compose_eq_trace_compose_tensor_id: 
  trace (from_trace_class (partial_trace t) oCL x) = trace (from_trace_class t oCL (x o id_cblinfun))
proof -
  define s where s = trace (from_trace_class (partial_trace t) oCL x)
  define s' where s' e = ket e C ((from_trace_class (partial_trace t) oCL x) *V ket e) for e
  define u where u j = compose_tcl (compose_tcr ((tensor_ell2_right (ket j))*) t) (tensor_ell2_right (ket j)) for j
  define u' where u' e j = ket e C (from_trace_class (u j) *V x *V ket e) for e j
  have (u has_sum partial_trace t) UNIV
    using partial_trace_has_sum[of t]
    by (simp add: u_def[abs_def])
  (* Evaluating at x (ket e) is additive and bounded-linear, hence continuous, so it
     preserves the sum. *)
  then have ((λu. from_trace_class u *V x *V ket e) o u has_sum from_trace_class (partial_trace t) *V x *V ket e) UNIV for e
  proof (rule has_sum_comm_additive[rotated -1])
    show Modules.additive (λu. from_trace_class u *V x *V ket e)
      by (simp add: Modules.additive_def cblinfun.add_left plus_trace_class.rep_eq)
    have bounded_clinear: bounded_clinear (λu. from_trace_class u *V x *V ket e)
    proof (rule bounded_clinearI[where K=norm (x *V ket e)])
      show from_trace_class (b1 + b2) *V x *V ket e = from_trace_class b1 *V x *V ket e + from_trace_class b2 *V x *V ket e for b1 b2
        by (simp add: plus_cblinfun.rep_eq plus_trace_class.rep_eq)
      show from_trace_class (r *C b) *V x *V ket e = r *C (from_trace_class b *V x *V ket e) for b r
        by (simp add: scaleC_trace_class.rep_eq)
      show norm (from_trace_class t *V x *V ket e)  norm t * norm (x *V ket e) for t
      proof -
        have norm (from_trace_class t *V x *V ket e)  norm (from_trace_class t) * norm (x *V ket e)
          by (simp add: norm_cblinfun)
        also have   norm t * norm (x *V ket e)
          by (auto intro!: mult_right_mono simp add: norm_leq_trace_norm norm_trace_class.rep_eq)
        finally show ?thesis
          by -
      qed
    qed
    have isCont (λu. from_trace_class u *V x *V ket e) (partial_trace t)
      using bounded_clinear clinear_continuous_at by auto
    then show (λu. from_trace_class u *V x *V ket e) partial_trace t from_trace_class (partial_trace t) *V x *V ket e
      by (simp add: isCont_def)
  qed
  (* The inner product with ket e is likewise additive and continuous. *)
  then have ((λv. ket e C v) o ((λu. from_trace_class u *V x *V ket e) o u) has_sum ket e C (from_trace_class (partial_trace t) *V x *V ket e)) UNIV for e 
  proof (rule has_sum_comm_additive[rotated -1])
    show Modules.additive (λv. ket e C v)
      by (simp add: Modules.additive_def cinner_simps(2))
    have bounded_clinear: bounded_clinear (λv. ket e C v)
      using bounded_clinear_cinner_right by auto
    then have isCont (λv. ket e C v) l for l
      by simp
    then show (λv. ket e C v) l ket e C l for l
      by (simp add: isContD)
  qed
  then have has_sum_u': ((λj. u' e j) has_sum s' e) UNIV for e 
    by (simp add: o_def u'_def s'_def)
  then have infsum_u': s' e = infsum (u' e) UNIV for e
    by (metis infsumI)
  have tc_u_x[simp]: trace_class (from_trace_class (u j) oCL x) for j
    by (simp add: trace_class_comp_left)

  (* Joint summability over pairs (e, j): the u' e j are diagonal entries of the
     trace-class operator t composed with (x tensor id_cblinfun). *)
  have summable_u'_pairs: (λ(e, j). u' e j) summable_on UNIV × UNIV
  proof -
    have trace_class (from_trace_class t oCL (x o id_cblinfun))
      by (simp add: trace_class_comp_left)
    from trace_exists[OF is_onb_ket this]
    have (λej. ket ej C (from_trace_class t *V (x o id_cblinfun) *V ket ej)) summable_on UNIV
      by (simp_all add: summable_on_reindex o_def)
    then show ?thesis
      by (simp_all add: o_def u'_def[abs_def] u_def
          trace_class_comp_left trace_class_comp_right Abs_trace_class_inverse tensor_ell2_right_apply 
          ket_pair_split tensor_op_ell2 case_prod_unfold cinner_adj_right
          compose_tcl.rep_eq compose_tcr.rep_eq)
  qed

  have u'_tensor: u' e j = ket (e,j) C ((from_trace_class t oCL (x o id_cblinfun)) *V ket (e,j)) for e j
    by (simp add: u'_def u_def tensor_op_ell2 tensor_ell2_right_apply  Abs_trace_class_inverse
        trace_class_comp_left trace_class_comp_right cinner_adj_right compose_tcl.rep_eq compose_tcr.rep_eq
        flip: tensor_ell2_ket)

  (* Final calculation: s is the sum of s' e, i.e. the iterated sum of u' e j; this
     equals the sum over pairs, which is the right-hand trace. *)
  have ((λe. e C ((from_trace_class (partial_trace t) oCL x) *V e)) has_sum s) (range ket)
    unfolding s_def
    apply (rule trace_has_sum)
    by (auto simp: trace_class_comp_left)
  then have (s' has_sum s) UNIV
    apply (subst (asm) has_sum_reindex)
    by (auto simp: o_def s'_def[abs_def])
  then have s = infsum s' UNIV
    by (simp add: infsumI)
  also have  = infsum (λe. infsum (u' e) UNIV) UNIV
    using infsum_u' by presburger
  also have  = ((e, j)UNIV. u' e j)
    apply (subst infsum_Sigma'_banach)
     apply (rule summable_u'_pairs)
    by simp
  also have  = trace (from_trace_class t oCL (x o id_cblinfun))
    unfolding u'_tensor 
    by (simp add: trace_ket_sum cond_case_prod_eta trace_class_comp_left)
  finally show ?thesis
    by (simp add: s_def)
qed



(* Right amplification (a maps to a tensor id_cblinfun) is continuous for the weak* topology.
   Proof: unfold weak_star_topology_def' and apply continuous_map_pullback_both.
   The required map on functionals is g', which precomposes a functional with
   partial_trace.  By trace_partial_trace_compose_eq_trace_compose_tensor_id, g' sends
   the functional "trace of (argument composed with x)" to the functional "trace of
   (argument composed with x tensor id_cblinfun)". *)
lemma right_amplification_weak_star_cont[simp]:
  continuous_map weak_star_topology weak_star_topology (λa. a o id_cblinfun)
  ― ‹Logically does not belong in this theory but uses the partial trace in the proof.›
proof (unfold weak_star_topology_def', rule continuous_map_pullback_both)
  show S  f -` UNIV for S :: 'x set and f :: 'x  'y
    by simp
  define g' :: (('b ell2, 'a ell2) trace_class  complex)  (('b × 'c) ell2, ('a × 'c) ell2) trace_class  complex where
    g' τ t = τ (partial_trace t) for τ t
  have continuous_on UNIV g'
    by (simp add: continuous_on_coordinatewise_then_product g'_def)
  then show continuous_map euclidean euclidean g'
    using continuous_map_iff_continuous2 by blast
  show g' (λt. trace (from_trace_class t oCL x)) =
         (λt. trace (from_trace_class t oCL x o id_cblinfun)) for x
    by (auto intro!: ext simp: g'_def trace_partial_trace_compose_eq_trace_compose_tensor_id)
qed

(* Left amplification (b maps to id_cblinfun tensor b) is continuous for the weak* topology.
   Proof: it is the right amplification followed by composition with swap_ell2 on both
   sides.  Each of these maps is weak*-continuous (right_amplification_weak_star_cont,
   continuous_map_left_comp_weak_star, continuous_map_right_comp_weak_star). *)
lemma left_amplification_weak_star_cont[simp]:
  continuous_map weak_star_topology weak_star_topology (λb. id_cblinfun o b :: ('c×'a) ell2 CL ('c×'b) ell2)
  ― ‹Logically does not belong in this theory but uses the partial trace in the proof.›
proof -
  have continuous_map weak_star_topology weak_star_topology (
        (λx. x oCL swap_ell2) o (λx. swap_ell2 oCL x) o (λa. a o id_cblinfun :: ('a×'c) ell2 CL ('b×'c) ell2))
    by (auto intro!: continuous_map_compose[where X'=weak_star_topology]
        continuous_map_left_comp_weak_star continuous_map_right_comp_weak_star)
  then show ?thesis
    by (auto simp: o_def)
qed


(* Additivity of partial_trace.
   The termwise sum of the two has_sum families converges both to the sum of the partial
   traces (has_sum_add) and to partial_trace (t + u) (partial_trace_has_sum).  The two
   limits agree by has_sum_unique. *)
lemma partial_trace_plus: partial_trace (t + u) = partial_trace t + partial_trace u
proof -
  from partial_trace_has_sum[of t] and partial_trace_has_sum[of u]
  have ((λj. compose_tcl (compose_tcr ((tensor_ell2_right (ket j))*) t) (tensor_ell2_right (ket j))
            + compose_tcl (compose_tcr ((tensor_ell2_right (ket j))*) u) (tensor_ell2_right (ket j))) has_sum
           partial_trace t + partial_trace u) UNIV (is (?f has_sum _) _)
    by (rule has_sum_add)
  moreover have ?f j = compose_tcl (compose_tcr ((tensor_ell2_right (ket j))*) (t + u)) (tensor_ell2_right (ket j)) (is ?f j = ?g j) for j
    by (simp add: compose_tcl.add_left compose_tcr.add_right)
  ultimately have (?g has_sum partial_trace t + partial_trace u) UNIV
    by simp
  moreover have (?g has_sum partial_trace (t + u)) UNIV
    by (simp add: partial_trace_has_sum)
  ultimately show ?thesis
    using has_sum_unique by blast
qed

(* Complex homogeneity of partial_trace.  The scalar is pulled out of compose_tcr and
   compose_tcl and then out of the infinite sum (infsum_scaleC_right). *)
lemma partial_trace_scaleC: partial_trace (c *C t) = c *C partial_trace t
  by (simp add: partial_trace_def infsum_scaleC_right compose_tcr.scaleC_right compose_tcl.scaleC_left)

(* Partial trace of a tensor product: tracing out the second factor of tc_tensor t u
   gives t scaled by trace_tc u.
   Each summand of partial_trace_def reduces to the j-th diagonal entry of u times t
   (fact *).  Summing the diagonal entries over the ket basis gives the trace of u; the
   needed summability is fact 1, from trace_exists. *)
lemma partial_trace_tensor: partial_trace (tc_tensor t u) = trace_tc u *C t
proof -
  define t' u' where t' = from_trace_class t and u' = from_trace_class u
  have 1: (λj. ket j C (from_trace_class u *V ket j)) summable_on UNIV
    using  trace_exists[where B=range ket and A=from_trace_class u]
    by (simp add: summable_on_reindex o_def)
  have partial_trace (tc_tensor t u) =
      (j. compose_tcl (compose_tcr (tensor_ell2_right (ket j)*) (tc_tensor t u)) (tensor_ell2_right (ket j)))
    by (simp add: partial_trace_def)
  also have  = (j. (ket j C (from_trace_class u *V ket j)) *C t)
  proof -
    have *: tensor_ell2_right (ket j)* oCL t' o u' oCL tensor_ell2_right (ket j) =
         (ket j C (u' *V ket j)) *C t' for j
      by (auto intro!: cblinfun_eqI simp: tensor_op_ell2)
    show ?thesis
    apply (rule infsum_cong)
      by (auto intro!: from_trace_class_inject[THEN iffD1] simp flip: t'_def u'_def
        simp: * compose_tcl.rep_eq compose_tcr.rep_eq tc_tensor.rep_eq scaleC_trace_class.rep_eq)
  qed
  also have  = trace_tc u *C t
    by (auto intro!: infsum_scaleC_left simp: trace_tc_def trace_alt_def[OF is_onb_ket] infsum_reindex o_def 1)
  finally show ?thesis
    by -
qed

(* partial_trace is bounded complex-linear with bound constant 1.
   This combines additivity (partial_trace_plus), homogeneity (partial_trace_scaleC) and
   the norm estimate (partial_trace_norm_reducing). *)
lemma bounded_clinear_partial_trace[bounded_clinear, iff]: bounded_clinear partial_trace
  by (rule bounded_clinearI[where K=1])
    (auto simp: partial_trace_plus partial_trace_scaleC partial_trace_norm_reducing)

(* Matrix elements of the partial trace as a convergent sum.  Summing over z the inner
   product of (x tensor ket z) with ρ applied to (y tensor ket z) gives the inner product
   of x with the partial trace of ρ applied to y.
   Proof: apply the bounded linear map "v maps to inner product of x with v applied to y"
   to partial_trace_has_sum, then rewrite each summand with cinner_adj_right and
   tensor_ell2_right_apply. *)
lemma vector_sandwich_partial_trace_has_sum:
  ((λz. ((x s ket z) C (from_trace_class ρ *V (y s ket z))))
      has_sum x C (from_trace_class (partial_trace ρ) *V y)) UNIV
proof -
  define xρy where xρy = x C (from_trace_class (partial_trace ρ) *V y)
  have ((λj. compose_tcl (compose_tcr ((tensor_ell2_right (ket j))*) ρ) (tensor_ell2_right (ket j))) 
        has_sum partial_trace ρ) UNIV
    using partial_trace_has_sum by force
  then have ((λj. x C (from_trace_class (compose_tcl (compose_tcr ((tensor_ell2_right (ket j))*) ρ) (tensor_ell2_right (ket j))) *V y))
        has_sum xρy) UNIV
    unfolding xρy_def
    apply (rule Infinite_Sum.has_sum_bounded_linear[rotated])
    by (intro bounded_clinear.bounded_linear bounded_linear_intros)
  then have ((λj. x C (tensor_ell2_right (ket j)* *V from_trace_class ρ *V y s ket j)) has_sum
     xρy) UNIV
    by (simp add: compose_tcl.rep_eq compose_tcr.rep_eq)
  then show ?thesis
    by (metis (no_types, lifting) cinner_adj_right has_sum_cong tensor_ell2_right_apply xρy_def)
qed

(* The infinite-sum form of vector_sandwich_partial_trace_has_sum, obtained with
   infsumI. *)
lemma vector_sandwich_partial_trace:
  x C (from_trace_class (partial_trace ρ) *V y) =
      (z. ((x s ket z) C (from_trace_class ρ *V (y s ket z))))
  by (metis (mono_tags, lifting) infsumI vector_sandwich_partial_trace_has_sum)



end