(* Theory Finite_Tensor_Product *)

section ‹Tensor products (finite dimensional)›

theory Finite_Tensor_Product
  imports Complex_Bounded_Operators.Complex_L2 Misc
begin

(* Theory-wide setup: add cblinfun.scaleC_right to the simpset, enable the
   cblinfun_notation bundle used for operator application and composition below,
   and drop the notation of m_inv -- presumably so that the plain function inverse
   inv used later (e.g. inv ket, inv t4) parses unambiguously. *)
declare cblinfun.scaleC_right[simp]

unbundle cblinfun_notation
no_notation m_inv ("invı _" [81] 80)

(* Tensor product of two vectors indexed by finite types, defined pointwise:
   the component at index (i,j) is psi i * phi j.  The lifting obligation (the
   result is again an ell2 element) is discharged by simp.
   NOTE(review): in this copy the type arrows, cartouche delimiters and the subscript
   of the infix symbol appear to be missing; confirm against the upstream file. *)
lift_definition tensor_ell2 :: 'a::finite ell2  'b::finite ell2  ('a×'b) ell2 (infixr "s" 70) is
  λψ φ (i,j). ψ i * φ j
  by simp

(* Bilinearity of tensor_ell2, one law per lemma.  Each proof transfers to the
   underlying functions, proves equality pointwise (rule ext) after splitting the
   index pair with case_prod_beta, and closes the remaining arithmetic goal. *)

(* Additivity in the second argument. *)
lemma tensor_ell2_add2: tensor_ell2 a (b + c) = tensor_ell2 a b + tensor_ell2 a c
  apply transfer apply (rule ext) apply (auto simp: case_prod_beta)
  by (meson algebra_simps)

(* Additivity in the first argument. *)
lemma tensor_ell2_add1: tensor_ell2 (a + b) c = tensor_ell2 a c + tensor_ell2 b c
  apply transfer apply (rule ext) apply (auto simp: case_prod_beta)
  by (simp add: vector_space_over_itself.scale_left_distrib)

(* Scalars can be pulled out of the second argument. *)
lemma tensor_ell2_scaleC2: tensor_ell2 a (c *C b) = c *C tensor_ell2 a b
  apply transfer apply (rule ext) by (auto simp: case_prod_beta)

(* Scalars can be pulled out of the first argument. *)
lemma tensor_ell2_scaleC1: tensor_ell2 (c *C a) b = c *C tensor_ell2 a b
  apply transfer apply (rule ext) by (auto simp: case_prod_beta)

(* The inner product of two elementary tensors is the product of the inner products
   of the factors.  After transfer, the sum over the product index type is split into
   a product of two sums (sum_product, sum.cartesian_product). *)
lemma tensor_ell2_inner_prod[simp]: tensor_ell2 a b C tensor_ell2 c d = (a C c) * (b C d)
  apply transfer
  by (auto simp: case_prod_beta sum_product sum.cartesian_product mult.assoc mult.left_commute)

(* Linearity of tensor_ell2 in each argument separately.
   NOTE(review): the numeric suffixes do not follow the argument position.
   clinear_tensor_ell21 is linearity in the SECOND argument (lambda b), while
   clinear_tensor_ell22 is linearity in the FIRST argument (lambda a). *)
lemma clinear_tensor_ell21: "clinear (λb. tensor_ell2 a b)"
  apply (rule clinearI; transfer)
   apply (auto simp: case_prod_beta)
  by (simp add: cond_case_prod_eta algebra_simps)

lemma clinear_tensor_ell22: "clinear (λa. tensor_ell2 a b)"
  apply (rule clinearI; transfer)
   apply (auto simp: case_prod_beta)
  by (simp add: case_prod_beta' algebra_simps)

(* The tensor product of two basis vectors is the basis vector indexed by the pair. *)
lemma tensor_ell2_ket[simp]: "tensor_ell2 (ket i) (ket j) = ket (i,j)"
  apply transfer by auto


(* Tensor product of operators, defined by Hilbert choice: some bounded operator that
   sends each basis vector ket (a,c) to (M *V ket a) tensored with (N *V ket c).
   The domain index types 'a and 'c carry no sort constraint here.  That the choice is
   satisfiable, so that tensor_op really has this property, is only proven in
   tensor_op_ket, for finite 'a and 'c. *)
definition tensor_op :: ('a ell2, 'b::finite ell2) cblinfun  ('c ell2, 'd::finite ell2) cblinfun
       (('a×'c) ell2, ('b×'d) ell2) cblinfun (infixr "o" 70) where
  tensor_op M N = (SOME P. a c. P *V (ket (a,c))
      = tensor_ell2 (M *V ket a) (N *V ket c))

(* Defining property of tensor_op on basis vectors: it maps ket (a,c) to
   (M *V ket a) tensored with (N *V ket c).  This shows that the choice in
   tensor_op_def is satisfiable for finite 'a and 'c.
   The fixed variables b and d do not occur in the statement. *)
lemma tensor_op_ket: 
  fixes a :: 'a::finite and b :: 'b and c :: 'c::finite and d :: 'd
  shows tensor_op M N *V (ket (a,c)) = tensor_ell2 (M *V ket a) (N *V ket c)
proof -
  (* S: all basis kets.  phi: the desired image of each index pair.
     phi': the same map expressed on kets, via the inverse of ket. *)
  define S :: ('a×'c) ell2 set where "S = ket ` UNIV"
  define φ where φ = (λ(a,c). tensor_ell2 (M *V ket a) (N *V ket c))
  define φ' where φ' = φ  inv ket

  have def: tensor_op M N = (SOME P. a c. P *V (ket (a,c)) = φ (a,c))
    unfolding tensor_op_def φ_def by auto

  (* The kets are independent and span the whole finite-dimensional space,
     so phi' extends to a bounded linear operator. *)
  have cindependent S
    using S_def cindependent_ket by blast
  moreover have cspan S = UNIV
    using S_def cspan_range_ket_finite by blast
  ultimately have "cblinfun_extension_exists S φ'"
    by (rule cblinfun_extension_exists_finite_dim)
  then have "P. xS. P *V x = φ' x"
    unfolding cblinfun_extension_exists_def by auto
  (* Translate "P agrees with phi' on S" back to index pairs, using injectivity of ket. *)
  then have ex: P. a c. P *V ket (a,c) = φ (a,c)
    by (metis S_def φ'_def comp_eq_dest_lhs inj_ket inv_f_f rangeI)


  (* Since such an operator exists, the operator chosen in tensor_op_def has the same
     property (someI2_ex). *)
  then have tensor_op M N *V (ket (a,c)) = φ (a,c)
    unfolding def apply (rule someI2_ex[where P=λP. a c. P *V (ket (a,c)) = φ (a,c)])
    by auto
  then show ?thesis
    unfolding φ_def by auto
qed


(* tensor_op acts factor-wise on elementary tensors:
   (A tensor B) applied to (psi tensor phi) equals (A psi) tensor (B phi).
   Starting from the ket case (tensor_op_ket), the identity is extended to arbitrary
   psi using linearity of both sides in the first factor (facts 1 and 2). It is then
   extended to arbitrary phi using linearity in the second factor (facts 3 and 4).
   Both steps use clinear_equal_ket. *)
lemma tensor_op_ell2: "tensor_op A B *V tensor_ell2 ψ φ = tensor_ell2 (A *V ψ) (B *V φ)"
proof -
  have 1: clinear (λa. tensor_op A B *V tensor_ell2 a (ket b)) for b
    by (auto intro!: clinearI simp: tensor_ell2_add1 tensor_ell2_scaleC1 cblinfun.add_right)
  have 2: clinear (λa. tensor_ell2 (A *V a) (B *V ket b)) for b
    by (auto intro!: clinearI simp: tensor_ell2_add1 tensor_ell2_scaleC1 cblinfun.add_right)
  have 3: clinear (λa. tensor_op A B *V tensor_ell2 ψ a)
    by (auto intro!: clinearI simp: tensor_ell2_add2 tensor_ell2_scaleC2 cblinfun.add_right)
  have 4: clinear (λa. tensor_ell2 (A *V ψ) (B *V a))
    by (auto intro!: clinearI simp: tensor_ell2_add2 tensor_ell2_scaleC2 cblinfun.add_right)

  (* Base case: both factors are kets. *)
  have eq_ket_ket: tensor_op A B *V tensor_ell2 (ket a) (ket b) = tensor_ell2 (A *V ket a) (B *V ket b) for a b
    by (simp add: tensor_op_ket)
  (* Generalize the first factor. *)
  have eq_ket: tensor_op A B *V tensor_ell2 ψ (ket b) = tensor_ell2 (A *V ψ) (B *V ket b) for b
    apply (rule fun_cong[where x=ψ])
    using 1 2 eq_ket_ket by (rule clinear_equal_ket)
  (* Generalize the second factor. *)
  show ?thesis 
    apply (rule fun_cong[where x=φ])
    using 3 4 eq_ket by (rule clinear_equal_ket)
qed

(* Composition of tensor products of operators is the tensor product of the
   compositions.  Checked on kets: each ket (i,j) is rewritten as ket i tensor ket j,
   and tensor_op_ell2 is then applied. *)
lemma comp_tensor_op: "(tensor_op a b) oCL (tensor_op c d) = tensor_op (a oCL c) (b oCL d)"
  for a :: "'e::finite ell2 CL 'c::finite ell2" and b :: "'f::finite ell2 CL 'd::finite ell2" and
      c :: "'a::finite ell2 CL 'e ell2" and d :: "'b::finite ell2 CL 'f ell2"
  apply (rule equal_ket)
  apply (rename_tac ij, case_tac ij, rename_tac i j, hypsubst_thin)
  by (simp flip: tensor_ell2_ket add: tensor_op_ell2 cblinfun_apply_cblinfun_compose)


(* tensor_op is bilinear.  Linearity in each argument is checked on kets (equal_ket),
   after splitting the index pair and rewriting ket (i,j) as an elementary tensor. *)
lemma tensor_op_cbilinear: cbilinear (tensor_op :: 'a::finite ell2 CL 'b::finite ell2
                                                  'c::finite ell2 CL 'd::finite ell2  _)
proof -
  (* Linearity in the second argument. *)
  have clinear (λb::'c ell2 CL 'd ell2. tensor_op a b) for a :: 'a ell2 CL 'b ell2
    apply (rule clinearI)
     apply (rule equal_ket, rename_tac ij, case_tac ij, rename_tac i j, hypsubst_thin)
     apply (simp flip: tensor_ell2_ket add: tensor_op_ell2 cblinfun.add_left tensor_ell2_add2)
    apply (rule equal_ket, rename_tac ij, case_tac ij, rename_tac i j, hypsubst_thin)
    by (simp add: scaleC_cblinfun.rep_eq tensor_ell2_scaleC2 tensor_op_ket)

  (* Linearity in the first argument. *)
  moreover have clinear (λa::'a::finite ell2 CL 'b::finite ell2. tensor_op a b) for b :: 'c ell2 CL 'd ell2
    apply (rule clinearI)
     apply (rule equal_ket, rename_tac ij, case_tac ij, rename_tac i j, hypsubst_thin)
     apply (simp flip: tensor_ell2_ket add: tensor_op_ell2 cblinfun.add_left tensor_ell2_add1)
    apply (rule equal_ket, rename_tac ij, case_tac ij, rename_tac i j, hypsubst_thin)
    by (simp add: scaleC_cblinfun.rep_eq tensor_ell2_scaleC1 tensor_op_ket)

  ultimately show ?thesis
    unfolding cbilinear_def by auto
qed


(* The tensor product of two ket-butterflies is the ket-butterfly on the index pairs. *)
lemma tensor_butter: tensor_op (butterket i j) (butterket k l) = butterket (i,k) (j,l)
  for i :: "_" and j :: "_::finite" and k :: "_" and l :: "_::finite"
  apply (rule equal_ket, rename_tac x, case_tac x)
  apply (auto simp flip: tensor_ell2_ket simp: cblinfun_apply_cblinfun_compose tensor_op_ell2 butterfly_def)
  by (auto simp: tensor_ell2_scaleC1 tensor_ell2_scaleC2)

(* Tensor products of ket-butterflies span all operators on the product space.
   After rewriting with tensor_butter this reduces to cspan_butterfly_ket. *)
lemma cspan_tensor_op: cspan {tensor_op (butterket i j) (butterket k l)| i (j::_::finite) k (l::_::finite). True} = UNIV
  unfolding tensor_butter
  apply (subst cspan_butterfly_ket[symmetric])
  by (metis surj_pair)

(* ...and they are linearly independent.  After rewriting with tensor_butter, the set
   is contained in the independent set of all ket-butterflies. *)
lemma cindependent_tensor_op: cindependent {tensor_op (butterket i j) (butterket k l)| i (j::_::finite) k (l::_::finite). True}
  unfolding tensor_butter
  using cindependent_butterfly_ket
  by (smt (z3) Collect_mono_iff complex_vector.independent_mono)


(* Two linear maps F, G on operators over a product space coincide once they agree on
   all tensor products tensor_op a b.  It suffices to compare them on the tensor
   products of ket-butterflies, which span the whole space (cspan_tensor_op). *)
lemma tensor_extensionality:
  fixes F G :: ((('a::finite × 'b::finite) ell2) CL (('c::finite × 'd::finite) ell2))  'e::complex_vector
  assumes [simp]: "clinear F" "clinear G"
  assumes tensor_eq: "(a b. F (tensor_op a b) = G (tensor_op a b))"
  shows "F = G"
proof (rule ext, rule complex_vector.linear_eq_on_span[where f=F and g=G])
  show clinear F and clinear G
    using assms by (simp_all add: cbilinear_def)
  (* Every operator lies in the span of the tensor products of ket-butterflies. *)
  show x  cspan  {tensor_op (butterket i j) (butterket k l)| i j k l. True} 
    for x :: ('a × 'b) ell2 CL ('c × 'd) ell2
    using cspan_tensor_op by auto
  (* On the spanning set, F and G agree by assumption. *)
  show F x = G x if x  {tensor_op (butterket i j) (butterket k l) |i j k l. True} for x
    using that by (auto simp: tensor_eq)
qed

(* The tensor product of identities is the identity (checked on kets). *)
lemma tensor_id[simp]: tensor_op id_cblinfun id_cblinfun = id_cblinfun
  apply (rule equal_ket, rename_tac x, case_tac x)
  by (simp flip: tensor_ell2_ket add: tensor_op_ell2)

(* The adjoint of a tensor product is the tensor product of the adjoints
   (matrix entries compared on kets via cinner_ket_adjointI). *)
lemma tensor_op_adjoint: (tensor_op a b)* = tensor_op (a*) (b*)
  apply (rule cinner_ket_adjointI[symmetric])
  apply (auto simp flip: tensor_ell2_ket simp: tensor_op_ell2)
  by (simp add: cinner_adj_left)

(* The tensor product of two butterflies is the butterfly of the tensored vectors. *)
lemma tensor_butterfly[simp]: "tensor_op (butterfly ψ ψ') (butterfly φ φ') = butterfly (tensor_ell2 ψ φ) (tensor_ell2 ψ' φ')"
  apply (rule equal_ket, rename_tac x, case_tac x)
  by (simp flip: tensor_ell2_ket add: tensor_op_ell2 butterfly_def
      cblinfun_apply_cblinfun_compose tensor_ell2_scaleC1 tensor_ell2_scaleC2)

(* Lifts a two-argument function F2 on operators to a function on operators over the
   product space: some linear G with G (tensor_op a b) = F2 a b, chosen by Hilbert
   choice.  The choice is only shown to be satisfiable for bilinear F2 with a
   complex_normed_vector codomain (tensor_lift_clinear / tensor_lift_correct), even
   though the definition itself only requires complex_vector. *)
definition tensor_lift :: (('a1::finite ell2 CL 'a2::finite ell2)  ('b1::finite ell2 CL 'b2::finite ell2)  'c)
                         ((('a1×'b1) ell2 CL ('a2×'b2) ell2)  'c::complex_vector) where
  "tensor_lift F2 = (SOME G. clinear G  (a b. G (tensor_op a b) = F2 a b))"

(* Universal property of the operator tensor product: for bilinear F2,
   tensor_lift F2 is linear and satisfies tensor_lift F2 (tensor_op a b) = F2 a b.

   Proof outline:
   1. t4 enumerates the tensor products of ket-butterflies.  It is injective because
      the matrix entry between bra (i,k) and ket (j,l) of t4 x is 1 for
      x = (i,j,k,l) and 0 for every other index tuple.
   2. These operators are independent and spanning.  So the map phi, sending
      t4 (i,j,k,l) to F2 (butterket i j) (butterket k l), extends to a bounded
      linear map G.
   3. By linearity of F2 and G, G agrees with F2 first on
      tensor_op a (butterket k l) (varying a) and then on all tensor_op a b.
   4. Hence the application function of G meets the specification in
      tensor_lift_def, and someI transfers both properties to tensor_lift F2. *)
lemma 
  fixes F2 :: "'a::finite ell2 CL 'b::finite ell2
             'c::finite ell2 CL 'd::finite ell2
             'e::complex_normed_vector"
  assumes "cbilinear F2"
  shows tensor_lift_clinear: "clinear (tensor_lift F2)"
    and tensor_lift_correct:  (λa b. tensor_lift F2 (tensor_op a b)) = F2
proof -
  define F2' t4 φ where
    F2' = tensor_lift F2 and
    t4 = (λ(i,j,k,l). tensor_op (butterket i j) (butterket k l)) and
    φ m = (let (i,j,k,l) = inv t4 m in F2 (butterket i j) (butterket k l)) for m
  (* Step 1: t4 is injective (compare a single matrix entry). *)
  have t4inj: "x = y" if "t4 x = t4 y" for x y
  proof (rule ccontr)
    obtain i  j  k  l  where x: "x = (i,j,k,l)" by (meson prod_cases4) 
    obtain i' j' k' l' where y: "y = (i',j',k',l')" by (meson prod_cases4) 
    have 1: "bra (i,k) *V t4 x *V ket (j,l) = 1"
      by (auto simp: t4_def x tensor_op_ell2 butterfly_def cinner_ket simp flip: tensor_ell2_ket)
    assume x  y
    then have 2: "bra (i,k) *V t4 y *V ket (j,l) = 0"
      by (auto simp: t4_def x y tensor_op_ell2 butterfly_def cblinfun_apply_cblinfun_compose cinner_ket
               simp flip: tensor_ell2_ket)
    from 1 2 that
    show False
      by auto
  qed

  have *: range t4 = {tensor_op (butterket i j) (butterket k l) |i j k l. True}
    apply (auto simp: case_prod_beta t4_def)
    using image_iff by fastforce

  (* Step 2: range t4 is independent and spans the operator space, so phi extends
     to a bounded linear map. *)
  have "cblinfun_extension_exists (range t4) φ"
    apply (rule cblinfun_extension_exists_finite_dim)
     apply auto unfolding * 
    using cindependent_tensor_op
    using cspan_tensor_op
    by auto

  then obtain G where G: G *V (t4 (i,j,k,l)) = F2 (butterket i j) (butterket k l) for i j k l
    apply atomize_elim
    unfolding cblinfun_extension_exists_def
    apply auto
    by (metis (no_types, lifting) t4inj φ_def f_inv_into_f rangeI split_conv)

  (* Step 3: G agrees with F2 on butterfly tensors, then (by linearity in the first
     argument) on tensor_op a (butterket k l), then (by linearity in the second
     argument) on all tensor_op a b. *)
  have *: G *V tensor_op (butterket i j) (butterket k l) = F2 (butterket i j) (butterket k l) for i j k l
    using G by (auto simp: t4_def)
  have *: G *V tensor_op a (butterket k l) = F2 a (butterket k l) for a k l
    apply (rule complex_vector.linear_eq_on_span[where g=λa. F2 a _ and B={butterket k l|k l. True}])
    unfolding cspan_butterfly_ket
    using * apply (auto intro!: clinear_compose[unfolded o_def, where f=λa. tensor_op a _ and g=(*V) G])
     apply (metis cbilinear_def tensor_op_cbilinear)
    using assms unfolding cbilinear_def by blast
  have G_F2: G *V tensor_op a b = F2 a b for a b
    apply (rule complex_vector.linear_eq_on_span[where g=F2 a and B={butterket k l|k l. True}])
    unfolding cspan_butterfly_ket
    using * apply (auto simp: cblinfun.add_right clinearI
                        intro!: clinear_compose[unfolded o_def, where f=tensor_op a and g=(*V) G])
    apply (meson cbilinear_def tensor_op_cbilinear)
    using assms unfolding cbilinear_def by blast

  (* Step 4: the application function of G is a witness for the choice in
     tensor_lift_def, so the chosen function has both properties as well. *)
  have clinear F2'  (a b. F2' (tensor_op a b) = F2 a b)
    unfolding F2'_def tensor_lift_def 
    apply (rule someI[where x=(*V) G and P=λG. clinear G  (a b. G (tensor_op a b) = F2 a b)])
    using G_F2 by (simp add: cblinfun.add_right clinearI)

  then show clinear F2' and (λa b. tensor_lift F2 (tensor_op a b)) = F2
    unfolding F2'_def by auto
qed

(* Reassociation of a triple index on the level of functions:
   (assoc_ell20 f) (a,(b,c)) = f ((a,b),c), and assoc_ell20' is the reverse
   reindexing.  The lifting obligations are closed by auto. *)
lift_definition assoc_ell20 :: (('a::finite×'b::finite)×'c::finite) ell2  ('a×('b×'c)) ell2 is
  λf (a,(b,c)). f ((a,b),c)
  by auto

lift_definition assoc_ell20' :: ('a::finite×('b::finite×'c::finite)) ell2  (('a×'b)×'c) ell2 is
  λf ((a,b),c). f (a,(b,c))
  by auto

(* The same reindexings as bounded operators.  bounded_clinear_finite_dim reduces
   boundedness to linearity on these finite-dimensional spaces; linearity is then
   checked by transfer. *)
lift_definition assoc_ell2 :: (('a::finite×'b::finite)×'c::finite) ell2 CL ('a×('b×'c)) ell2
  is assoc_ell20
  apply (subst bounded_clinear_finite_dim)
   apply (rule clinearI; transfer)
  by auto

lift_definition assoc_ell2' :: ('a::finite×('b::finite×'c::finite)) ell2 CL (('a×'b)×'c) ell2 is
  assoc_ell20'
  apply (subst bounded_clinear_finite_dim)
   apply (rule clinearI; transfer)
  by auto

(* assoc_ell2 reassociates elementary tensors: (a tensor b) tensor c is mapped to
   a tensor (b tensor c).  Each of a, b, c is reduced to a ket in turn with
   clinear_equal_ket; the two extra subgoals after each such rule are the linearity
   side conditions.  The remaining ket case is computed by unfolding
   assoc_ell2.rep_eq and using transfer. *)
lemma assoc_ell2_tensor: assoc_ell2 *V tensor_ell2 (tensor_ell2 a b) c = tensor_ell2 a (tensor_ell2 b c)
  apply (rule clinear_equal_ket[THEN fun_cong, where x=a])
    apply (simp add: cblinfun.add_right clinearI tensor_ell2_add1 tensor_ell2_scaleC1)
   apply (simp add: clinear_tensor_ell22)
  apply (rule clinear_equal_ket[THEN fun_cong, where x=b])
    apply (simp add: cblinfun.add_right clinearI tensor_ell2_add1 tensor_ell2_add2 tensor_ell2_scaleC1 tensor_ell2_scaleC2)
   apply (simp add: clinearI tensor_ell2_add1 tensor_ell2_add2 tensor_ell2_scaleC1 tensor_ell2_scaleC2)
  apply (rule clinear_equal_ket[THEN fun_cong, where x=c])
    apply (simp add: cblinfun.add_right clinearI tensor_ell2_add2 tensor_ell2_scaleC2)
   apply (simp add: clinearI tensor_ell2_add2 tensor_ell2_scaleC2)
  unfolding assoc_ell2.rep_eq
  apply transfer
  by auto

(* The inverse direction: assoc_ell2' maps a tensor (b tensor c) to
   (a tensor b) tensor c.  The proof has the same structure as above. *)
lemma assoc_ell2'_tensor: assoc_ell2' *V tensor_ell2 a (tensor_ell2 b c) = tensor_ell2 (tensor_ell2 a b) c
  apply (rule clinear_equal_ket[THEN fun_cong, where x=a])
    apply (simp add: cblinfun.add_right clinearI tensor_ell2_add1 tensor_ell2_scaleC1)
   apply (simp add: clinearI tensor_ell2_add1 tensor_ell2_scaleC1)
  apply (rule clinear_equal_ket[THEN fun_cong, where x=b])
    apply (simp add: cblinfun.add_right clinearI tensor_ell2_add1 tensor_ell2_add2 tensor_ell2_scaleC1 tensor_ell2_scaleC2)
   apply (simp add: clinearI tensor_ell2_add1 tensor_ell2_add2 tensor_ell2_scaleC1 tensor_ell2_scaleC2)
  apply (rule clinear_equal_ket[THEN fun_cong, where x=c])
    apply (simp add: cblinfun.add_right clinearI tensor_ell2_add2 tensor_ell2_scaleC2)
   apply (simp add: clinearI tensor_ell2_add2 tensor_ell2_scaleC2)
  unfolding assoc_ell2'.rep_eq
  apply transfer
  by auto

(* The adjoint of assoc_ell2 is assoc_ell2'.  By adjoint_eqI it suffices to show
   <assoc_ell2' x, y> = <x, assoc_ell2 y> for all x, y.  This is first checked for
   kets (via the action on elementary tensors), then extended to arbitrary y by
   linearity and to arbitrary x by antilinearity.  The local [simp] facts provide the
   required (anti)linearity side conditions. *)
lemma adjoint_assoc_ell2[simp]: assoc_ell2* = assoc_ell2'
proof (rule adjoint_eqI[symmetric])
  have [simp]: clinear (cinner (assoc_ell2' *V x)) for x :: ('a × 'b × 'c) ell2
    by (metis (no_types, lifting) cblinfun.add_right cinner_scaleC_right clinearI complex_scaleC_def mult.comm_neutral of_complex_def vector_to_cblinfun_adj_apply)
  have [simp]: clinear (λa. x C (assoc_ell2 *V a)) for x :: ('a × 'b × 'c) ell2
    by (simp add: cblinfun.add_right cinner_add_right clinearI)
  have [simp]: antilinear (λa. a C y) for y :: ('a × 'b × 'c) ell2
    using bounded_antilinear_cinner_left bounded_antilinear_def by blast
  have [simp]: antilinear (λa. (assoc_ell2' *V a) C y) for y :: (('a × 'b) × 'c) ell2
    by (simp add: cblinfun.add_right cinner_add_left antilinearI)
  (* Ket case. *)
  have (assoc_ell2' *V ket x) C ket y = ket x C (assoc_ell2 *V ket y) for x :: 'a × 'b × 'c and y
    apply (cases x, cases y)
    by (simp flip: tensor_ell2_ket add: assoc_ell2'_tensor assoc_ell2_tensor)
  (* Generalize y (linearity in the second argument). *)
  then have (assoc_ell2' *V ket x) C y = ket x C (assoc_ell2 *V y) for x :: 'a × 'b × 'c and y
    by (rule clinear_equal_ket[THEN fun_cong, rotated 2], simp_all)
  (* Generalize x (antilinearity in the first argument). *)
  then show (assoc_ell2' *V x) C y = x C (assoc_ell2 *V y) for x :: ('a × 'b × 'c) ell2 and y
    by (rule antilinear_equal_ket[THEN fun_cong, rotated 2], simp_all)
qed

(* Conversely, the adjoint of assoc_ell2' is assoc_ell2 (take adjoints of the above). *)
lemma adjoint_assoc_ell2'[simp]: assoc_ell2'* = assoc_ell2
  by (simp flip: adjoint_assoc_ell2)


(* Exchange of the two index components on the level of functions:
   (swap_ell20 f) (a,b) = f (b,a). *)
lift_definition swap_ell20 :: ('a::finite×'b::finite) ell2  ('b×'a) ell2 is
  λf (a,b). f (b,a)
  by auto

(* The same map as a bounded operator.  Boundedness reduces to linearity on
   finite-dimensional spaces (bounded_clinear_finite_dim). *)
lift_definition swap_ell2 :: ('a::finite×'b::finite) ell2 CL ('b×'a) ell2
  is swap_ell20
  apply (subst bounded_clinear_finite_dim)
   apply (rule clinearI; transfer)
  by auto

(* swap_ell2 exchanges the factors of an elementary tensor.  Both a and b are reduced
   to kets via clinear_equal_ket; the ket case is computed by transfer. *)
lemma swap_ell2_tensor[simp]: swap_ell2 *V tensor_ell2 a b = tensor_ell2 b a
  apply (rule clinear_equal_ket[THEN fun_cong, where x=a])
    apply (simp add: cblinfun.add_right clinearI tensor_ell2_add1 tensor_ell2_scaleC1)
   apply (simp add: clinear_tensor_ell21)
  apply (rule clinear_equal_ket[THEN fun_cong, where x=b])
    apply (simp add: cblinfun.add_right clinearI tensor_ell2_add1 tensor_ell2_add2 tensor_ell2_scaleC1 tensor_ell2_scaleC2)
   apply (simp add: clinearI tensor_ell2_add1 tensor_ell2_add2 tensor_ell2_scaleC1 tensor_ell2_scaleC2)
  unfolding swap_ell2.rep_eq
  apply transfer
  by auto

(* swap_ell2 is self-adjoint.  Same method as adjoint_assoc_ell2: check
   <swap x, y> = <x, swap y> on kets, then extend by linearity in y and antilinearity
   in x.  The local [simp] facts provide the side conditions. *)
lemma adjoint_swap_ell2[simp]: swap_ell2* = swap_ell2
proof (rule adjoint_eqI[symmetric])
  have [simp]: clinear (cinner (swap_ell2 *V x)) for x :: ('a × 'b) ell2
    by (metis (no_types, lifting) cblinfun.add_right cinner_scaleC_right clinearI complex_scaleC_def mult.comm_neutral of_complex_def vector_to_cblinfun_adj_apply)
  have [simp]: clinear (λa. x C (swap_ell2 *V a)) for x :: ('a × 'b) ell2
    by (simp add: cblinfun.add_right cinner_add_right clinearI)
  have [simp]: antilinear (λa. a C y) for y :: ('a × 'b) ell2
    using bounded_antilinear_cinner_left bounded_antilinear_def by blast
  have [simp]: antilinear (λa. (swap_ell2 *V a) C y) for y :: ('b × 'a) ell2
    by (simp add: cblinfun.add_right cinner_add_left antilinearI)
  (* Ket case. *)
  have (swap_ell2 *V ket x) C ket y = ket x C (swap_ell2 *V ket y) for x :: 'a × 'b and y
    apply (cases x, cases y)
    by (simp flip: tensor_ell2_ket add: swap_ell2_tensor)
  (* Generalize y, then x. *)
  then have (swap_ell2 *V ket x) C y = ket x C (swap_ell2 *V y) for x :: 'a × 'b and y
    by (rule clinear_equal_ket[THEN fun_cong, rotated 2], simp_all)
  then show (swap_ell2 *V x) C y = x C (swap_ell2 *V y) for x :: ('a × 'b) ell2 and y
    apply (rule antilinear_equal_ket[THEN fun_cong, rotated 2])
    by simp_all
qed


(* Operators on a product space are determined by their values on elementary tensors.
   It suffices to compare them on kets, since ket (i,j) is itself an elementary tensor
   (tensor_ell2_ket). *)
lemma tensor_ell2_extensionality:
  assumes "(s t. a *V (s s t) = b *V (s s t))"
  shows "a = b"
  apply (rule equal_ket, case_tac x, hypsubst_thin)
  by (simp add: assms flip: tensor_ell2_ket)

(* assoc_ell2 and assoc_ell2' are mutually inverse (checked on kets)... *)
lemma assoc_ell2'_assoc_ell2[simp]: assoc_ell2' oCL assoc_ell2 = id_cblinfun
  by (auto intro!: equal_ket simp: cblinfun_apply_cblinfun_compose assoc_ell2'_tensor assoc_ell2_tensor simp flip: tensor_ell2_ket)

lemma assoc_ell2_assoc_ell2'[simp]: assoc_ell2 oCL assoc_ell2' = id_cblinfun
  by (auto intro!: equal_ket simp: cblinfun_apply_cblinfun_compose assoc_ell2'_tensor assoc_ell2_tensor simp flip: tensor_ell2_ket)

(* ...and hence both are unitary (each adjoint is the other, by the adjoint lemmas above). *)
lemma unitary_assoc_ell2[simp]: "unitary assoc_ell2"
  unfolding unitary_def by auto

lemma unitary_assoc_ell2'[simp]: "unitary assoc_ell2'"
  unfolding unitary_def by auto

(* Linearity of tensor_op in each argument, stated as individual equations and proven
   directly on kets via tensor_op_ket. *)

(* Additivity in the first argument. *)
lemma tensor_op_left_add: (x + y) o b = x o b + y o b
  for x y :: 'a::finite ell2 CL 'c::finite ell2 and b :: 'b::finite ell2 CL 'd::finite ell2
  apply (auto intro!: equal_ket simp: tensor_op_ket)
  by (simp add: plus_cblinfun.rep_eq tensor_ell2_add1 tensor_op_ket)

(* Additivity in the second argument. *)
lemma tensor_op_right_add: b o (x + y) = b o x + b o y
  for x y :: 'a::finite ell2 CL 'c::finite ell2 and b :: 'b::finite ell2 CL 'd::finite ell2
  apply (auto intro!: equal_ket simp: tensor_op_ket)
  by (simp add: plus_cblinfun.rep_eq tensor_ell2_add2 tensor_op_ket)

(* Scalars can be pulled out of the first argument. *)
lemma tensor_op_scaleC_left: (c *C x) o b = c *C (x o b)
  for x :: 'a::finite ell2 CL 'c::finite ell2 and b :: 'b::finite ell2 CL 'd::finite ell2
  apply (auto intro!: equal_ket simp: tensor_op_ket)
  by (metis scaleC_cblinfun.rep_eq tensor_ell2_ket tensor_ell2_scaleC1 tensor_op_ell2)

(* Scalars can be pulled out of the second argument. *)
lemma tensor_op_scaleC_right: b o (c *C x) = c *C (b o x)
  for x :: 'a::finite ell2 CL 'c::finite ell2 and b :: 'b::finite ell2 CL 'd::finite ell2
  apply (auto intro!: equal_ket simp: tensor_op_ket)
  by (metis scaleC_cblinfun.rep_eq tensor_ell2_ket tensor_ell2_scaleC2 tensor_op_ell2)

(* The previous lemmas packaged as clinear facts for the simpset. *)
lemma clinear_tensor_left[simp]: clinear (λa. a o b :: _::finite ell2 CL _::finite ell2)
  apply (rule clinearI)
   apply (rule tensor_op_left_add)
  by (rule tensor_op_scaleC_left)

lemma clinear_tensor_right[simp]: clinear (λb. a o b :: _::finite ell2 CL _::finite ell2)
  apply (rule clinearI)
   apply (rule tensor_op_right_add)
  by (rule tensor_op_scaleC_right)

(* The tensor product of two nonzero vectors is nonzero.  After transfer, both
   functions have a nonzero value somewhere, and the product of those values is
   nonzero. *)
lemma tensor_ell2_nonzero: a s b  0 if a  0 and b  0
  apply (use that in transfer)
  apply auto
  by (metis mult_eq_0_iff old.prod.case)

(* The tensor product of two nonzero operators is nonzero: pick kets i, j on which
   a and b are nonzero.  Then a tensor b maps ket (i,j) to the nonzero vector
   (a ket i) tensor (b ket j). *)
lemma tensor_op_nonzero:
  fixes a :: 'a::finite ell2 CL 'c::finite ell2 and b :: 'b::finite ell2 CL 'd::finite ell2
  assumes a  0 and b  0
  shows a o b  0
proof -
  from a  0 obtain i where i: a *V ket i  0
    by (metis cblinfun.zero_left equal_ket)
  from b  0 obtain j where j: b *V ket j  0
    by (metis cblinfun.zero_left equal_ket)
  from i j have ijneq0: (a *V ket i) s (b *V ket j)  0
    by (simp add: tensor_ell2_nonzero)
  have (a *V ket i) s (b *V ket j) = (a o b) *V ket (i,j)
    by (simp add: tensor_op_ket)
  with ijneq0 show a o b  0
    by force
qed

(* Tensoring with a fixed nonzero vector b on the right is injective.  Proof by
   contradiction: if x differs from y but x tensor b = y tensor b, then
   (x - y) tensor b would be zero, contradicting tensor_ell2_nonzero. *)
lemma inj_tensor_ell2_left: inj (λa::'a::finite ell2. a s b) if b  0 for b :: 'b::finite ell2
proof (rule injI, rule ccontr)
  fix x y :: 'a ell2
  assume eq: x s b = y s b
  assume neq: x  y
  define a where a = x - y
  from neq a_def have neq0: a  0
    by auto
  with b  0 have a s b  0
    by (simp add: tensor_ell2_nonzero)
  then have x s b  y s b
    unfolding a_def
    by (metis add_cancel_left_left diff_add_cancel tensor_ell2_add1)
  with eq show False
    by auto
qed

(* Symmetric version: tensoring with a fixed nonzero vector a on the left is injective. *)
lemma inj_tensor_ell2_right: inj (λb::'b::finite ell2. a s b) if a  0 for a :: 'a::finite ell2
proof (rule injI, rule ccontr)
  fix x y :: 'b ell2
  assume eq: a s x = a s y
  assume neq: x  y
  define b where b = x - y
  from neq b_def have neq0: b  0
    by auto
  with a  0 have a s b  0
    by (simp add: tensor_ell2_nonzero)
  then have a s x  a s y
    unfolding b_def
    by (metis add_cancel_left_left diff_add_cancel tensor_ell2_add2)
  with eq show False
    by auto
qed



(* Operator version: tensoring with a fixed nonzero operator b on the right is
   injective.  Same argument as for vectors, using tensor_op_nonzero. *)
lemma inj_tensor_left: inj (λa::'a::finite ell2 CL 'c::finite ell2. a o b) if b  0 for b :: 'b::finite ell2 CL 'd::finite ell2
proof (rule injI, rule ccontr)
  fix x y :: 'a ell2 CL 'c ell2
  assume eq: x o b = y o b
  assume neq: x  y
  define a where a = x - y
  from neq a_def have neq0: a  0
    by auto
  with b  0 have a o b  0
    by (simp add: tensor_op_nonzero)
  then have x o b  y o b
    unfolding a_def
    by (metis add_cancel_left_left diff_add_cancel tensor_op_left_add) 
  with eq show False
    by auto
qed

(* Operator version: tensoring with a fixed nonzero operator a on the left is injective. *)
lemma inj_tensor_right: inj (λb::'b::finite ell2 CL 'c::finite ell2. a o b) if a  0 for a :: 'a::finite ell2 CL 'd::finite ell2
proof (rule injI, rule ccontr)
  fix x y :: 'b ell2 CL 'c ell2
  assume eq: a o x = a o y
  assume neq: x  y
  define b where b = x - y
  from neq b_def have neq0: b  0
    by auto
  with a  0 have a o b  0
    by (simp add: tensor_op_nonzero)
  then have a o x  a o y
    unfolding b_def
    by (metis add_cancel_left_left diff_add_cancel tensor_op_right_add) 
  with eq show False
    by auto
qed

(* If a tensor b = c tensor d and a is nonzero, then b is a scalar multiple of d.
   Pick a coordinate i with <ket i, a> nonzero.  Comparing coefficients gives
   <ket i, a> <ket j, b> = <ket i, c> <ket j, d> for all j.
   With gamma = <ket i, c> / <ket i, a>, cancelling <ket i, a> yields
   <ket j, b> = <ket j, gamma d> for all j, hence b = gamma d. *)
lemma tensor_ell2_almost_injective:
  assumes tensor_ell2 a b = tensor_ell2 c d
  assumes a  0
  shows γ. b = γ *C d
proof -
  from a  0 obtain i where i: cinner (ket i) a  0
    by (metis cinner_eq_zero_iff cinner_ket_left ell2_pointwise_ortho)
  have cinner (ket i s ket j) (a s b) = cinner (ket i s ket j) (c s d) for j
    using assms by simp
  then have eq2: (cinner (ket i) a) * (cinner (ket j) b) = (cinner (ket i) c) * (cinner (ket j) d) for j
    by (metis tensor_ell2_inner_prod)
  then obtain γ where cinner (ket i) c = γ * cinner (ket i) a
    by (metis i eq_divide_eq)
  with eq2 have (cinner (ket i) a) * (cinner (ket j) b) = (cinner (ket i) a) * (γ * cinner (ket j) d) for j
    by simp
  then have cinner (ket j) b = cinner (ket j) (γ *C d) for j
    using i by force
  then have b = γ *C d
    by (simp add: cinner_ket_eqI)
  then show ?thesis
    by auto
qed


(* Operator version: if a tensor b = c tensor d and a is nonzero, then b is a scalar
   multiple of d.
   Case d nonzero: choose psi with a psi nonzero and phi0 with d phi0 nonzero.
   Applying the vector lemma to the swapped tensors gives c psi = gamma (a psi).
   Then (a psi) tensor (b phi) = (a psi) tensor (gamma d phi) for every phi, so
   b phi = gamma d phi.
   Case d = 0: then a tensor b = 0, so b = 0 by tensor_op_nonzero, and any gamma works. *)
lemma tensor_op_almost_injective:
  fixes a c :: 'a::finite ell2 CL 'b::finite ell2
    and b d :: 'c::finite ell2 CL 'd::finite ell2
  assumes tensor_op a b = tensor_op c d
  assumes a  0
  shows γ. b = γ *C d
proof (cases d = 0)
  case False
  from a  0 obtain ψ where ψ: a *V ψ  0
    by (metis cblinfun.zero_left cblinfun_eqI)
  have (a o b) (ψ s φ) = (c o d) (ψ s φ) for φ
    using assms by simp
  then have eq2: (a ψ) s (b φ) = (c ψ) s (d φ) for φ
    by (simp add: tensor_op_ell2)
  then have eq2': (d φ) s (c ψ) = (b φ) s (a ψ) for φ
    by (metis swap_ell2_tensor)
  from False obtain φ0 where φ0: d φ0  0
    by (metis cblinfun.zero_left cblinfun_eqI)
  (* The scalar relating c psi and a psi. *)
  obtain γ where c ψ = γ *C a ψ
    apply atomize_elim
    using eq2' φ0 by (rule tensor_ell2_almost_injective)
  with eq2 have (a ψ) s (b φ) = (a ψ) s (γ *C d φ) for φ
    by (simp add: tensor_ell2_scaleC1 tensor_ell2_scaleC2)
  then have b φ = γ *C d φ for φ
    by (smt (verit, best) ψ complex_vector.scale_cancel_right tensor_ell2_almost_injective tensor_ell2_nonzero tensor_ell2_scaleC2)
  then have b = γ *C d
    by (simp add: cblinfun_eqI)
  then show ?thesis
    by auto
next
  case True
  then have c o d = 0
    by (metis add_cancel_right_left tensor_op_right_add)
  then have a o b = 0
    using assms(1) by presburger
  with a  0 have b = 0
    by (meson tensor_op_nonzero)
  then show ?thesis
    by auto
qed


(* Tensoring with the zero vector gives zero, on either side. *)
lemma tensor_ell2_0_left[simp]: tensor_ell2 0 x = 0
  apply transfer by auto

lemma tensor_ell2_0_right[simp]: tensor_ell2 x 0 = 0
  apply transfer by auto

(* Tensoring with the zero operator gives zero, on either side (checked on kets). *)
lemma tensor_op_0_left[simp]: tensor_op 0 x = (0 :: ('a::finite*'b::finite) ell2 CL ('c::finite*'d::finite) ell2)
  apply (rule equal_ket)
  by (auto simp flip: tensor_ell2_ket simp: tensor_op_ell2)

lemma tensor_op_0_right[simp]: tensor_op x 0 = (0 :: ('a::finite*'b::finite) ell2 CL ('c::finite*'d::finite) ell2)
  apply (rule equal_ket)
  by (auto simp flip: tensor_ell2_ket simp: tensor_op_ell2)

(* For a one-dimensional index type 'a (CARD_1) and nonzero psi, the map
   x |-> psi tensor x is a bijection.  Injectivity is inj_tensor_ell2_right.
   For surjectivity: all elements of 'a are equal (everything_the_same), so psi is a
   constant function with some value c, and c is nonzero.  Hence phi has the preimage
   x j = phi (undefined, j) / c. *)
lemma bij_tensor_ell2_one_dim_left:
  assumes ψ  0
  shows bij (λx::'b::finite ell2. (ψ :: 'a::CARD_1 ell2) s x)
proof (rule bijI)
  show inj (λx::'b::finite ell2. (ψ :: 'a::CARD_1 ell2) s x)
    using assms by (rule inj_tensor_ell2_right)
  have x. ψ s x = φ for φ :: ('a*'b) ell2
  proof (use assms in transfer)
    fix ψ :: 'a  complex and φ :: 'a*'b  complex
    assume has_ell2_norm φ and ψ  (λ_. 0)
    (* psi is constant, with value c. *)
    define c where c = ψ undefined
    then have ψ a = c for a 
      apply (subst everything_the_same[of _ undefined])
      by simp
    with ψ  (λ_. 0) have c  0
      by auto

    (* The explicit preimage of phi. *)
    define x where x j = φ (undefined, j) / c for j
    have (λ(i, j). ψ i * x j) = φ
      apply (auto intro!: ext simp: x_def ψ _ = c c  0)
      apply (subst (2) everything_the_same[of _ undefined])
      by simp
    then show xCollect has_ell2_norm. (λ(i, j). ψ i * x j) = φ
      apply (rule bexI[where x=x])
      by simp
  qed

  then show surj (λx::'b::finite ell2. (ψ :: 'a::CARD_1 ell2) s x)
    by (metis surj_def)
qed

(* For one-dimensional 'a and 'b (CARD_1) and a nonzero a, the map x |-> a tensor x
   is a bijection on operators.  Injectivity is inj_tensor_right.  For surjectivity, a
   right inverse i is built with tensor_lift from the bilinear map
   (x, y) |-> (one_dim_iso x / one_dim_iso a) *C y.
   t (i (tensor_op x y)) = tensor_op x y follows by moving and cancelling the scalar.
   tensor_extensionality then extends this to all arguments. *)
lemma bij_tensor_op_one_dim_left:
  assumes a  0
  shows bij (λx::'c::finite ell2 CL 'd::finite ell2. (a :: 'a::{CARD_1,enum} ell2 CL 'b::{CARD_1,enum} ell2) o x)
proof (rule bijI)
  define t where t = (λx::'c ell2 CL 'd ell2. (a :: 'a ell2 CL 'b ell2) o x)
  define i where
    i = tensor_lift (λ(x::'a ell2 CL 'b ell2) (y::'c ell2 CL 'd ell2). (one_dim_iso x / one_dim_iso a) *C y)

  have [simp]: clinear i
    by (auto intro!: tensor_lift_clinear simp: i_def cbilinear_def clinearI scaleC_add_left add_divide_distrib)
  have [simp]: clinear t
    by (simp add: t_def)
  (* i on elementary tensors, by tensor_lift_correct. *)
  have i (x o y) = (one_dim_iso x / one_dim_iso a) *C y for x y
    by (auto intro!: clinearI tensor_lift_correct[THEN fun_cong, THEN fun_cong] simp: t_def i_def cbilinear_def  scaleC_add_left add_divide_distrib)
  (* t undoes i on elementary tensors... *)
  then have t (i (x o y)) = x o y for x y
    apply (simp add: t_def)
    by (smt (z3) assms complex_vector.scale_eq_0_iff nonzero_mult_div_cancel_right one_dim_scaleC_1 scaleC_scaleC tensor_op_scaleC_left tensor_op_scaleC_right times_divide_eq_left)
  (* ...hence everywhere, as both sides are linear. *)
  then have t (i x) = x for x
    apply (rule_tac fun_cong[where x=x])
    apply (rule tensor_extensionality)
    by (auto intro: clinear_compose complex_vector.module_hom_ident simp flip: o_def[of t i])
  then show surj t 
    by (rule surjI)

  show inj t
    unfolding t_def using assms by (rule inj_tensor_right)
qed

(* swap_ell2 is its own inverse (checked on elementary tensors). *)
lemma swap_ell2_selfinv[simp]: swap_ell2 oCL swap_ell2 = id_cblinfun
  apply (rule tensor_ell2_extensionality)
  by auto

(* Right-hand analogue of bij_tensor_op_one_dim_left: for one-dimensional 'a, 'b and
   a nonzero b, the map x |-> x tensor b is a bijection.
   Conjugating by swap_ell2 turns ?f into the left-tensoring map ?g, which is
   bijective.  Conjugation ?s is its own inverse, hence bijective.  Since ?f is the
   composition of ?s with ?sf, ?f is bijective as well. *)
lemma bij_tensor_op_one_dim_right:
  assumes b  0
  shows bij (λx::'c::finite ell2 CL 'd::finite ell2. x o (b :: 'a::{CARD_1,enum} ell2 CL 'b::{CARD_1,enum} ell2))
    (is bij ?f)
proof -
  let ?sf = (λx. swap_ell2 oCL (?f x) oCL swap_ell2)
  let ?s = (λx. swap_ell2 oCL x oCL swap_ell2)
  let ?g = (λx::'c::finite ell2 CL 'd::finite ell2. (b :: 'a::{CARD_1,enum} ell2 CL 'b::{CARD_1,enum} ell2) o x)
  (* Conjugating ?f by swap_ell2 gives the left-tensoring map ?g. *)
  have ?sf = ?g
    by (auto intro!: ext tensor_ell2_extensionality simp add: swap_ell2_tensor tensor_op_ell2)
  have bij ?g
    using assms by (rule bij_tensor_op_one_dim_left)
  (* Conjugating twice by swap_ell2 gives back ?f. *)
  have ?s o ?sf = ?f
    apply (auto intro!: ext simp: cblinfun_assoc_left)
    by (auto simp: cblinfun_assoc_right)
  (* Conjugation by swap_ell2 is its own inverse, hence bijective.  This is a plain
     have: the fact is referenced explicitly below, not part of a calculation. *)
  have bij ?s
    apply (rule o_bij[where g=(λx. swap_ell2 oCL x oCL swap_ell2)])
     apply (auto intro!: ext simp: cblinfun_assoc_left)
    by (auto simp: cblinfun_assoc_right)
  show bij ?f
    apply (subst ?s o ?sf = ?f[symmetric], subst ?sf = ?g)
    using bij ?g bij ?s by (rule bij_comp)
qed

(* Overlapping tensor factorizations.  Suppose an operator on a three-fold product
   space has two forms, with all four vectors nonzero:
   - butterfly psi psi' tensor a23, and
   - up to reassociation by assoc_ell2, b12 tensor butterfly phi phi'.
   Then it equals butterfly psi psi' tensor c tensor butterfly phi phi' for some
   operator c on the middle factor.

   Proof outline:
   - Normalize the four vectors and rescale a23 and b12 to compensate (n1, n2).
   - Show that d (the operator) is unchanged when sandwiched between the
     normalized projectors.
   - Compress d with the normalized outer vectors into c' on unit x 'a2 x unit.
   - Pull c' back to c'' on 'a2, using that tensoring with the identity on unit
     (on both sides) is bijective.
   - Undo the normalization to obtain c. *)
lemma overlapping_tensor:
  fixes a23 :: ('a2::finite*'a3::finite) ell2 CL ('b2::finite*'b3::finite) ell2
    and b12 :: ('a1::finite*'a2) ell2 CL ('b1::finite*'b2) ell2
  assumes eq: butterfly ψ ψ' o a23 = assoc_ell2 oCL (b12 o butterfly φ φ') oCL assoc_ell2'
  assumes ψ  0 ψ'  0 φ  0 φ'  0
  shows c. butterfly ψ ψ' o a23 = butterfly ψ ψ' o c o butterfly φ φ'
proof -
  let ?id1 = id_cblinfun :: unit ell2 CL unit ell2
  note id_cblinfun_eq_1[simp del]
  define d where d = butterfly ψ ψ' o a23

  (* Normalize psi and psi', compensating in a23. *)
  define ψn ψn' a23n where ψn = ψ /C norm ψ and ψn' = ψ' /C norm ψ' and a23n = norm ψ *C norm ψ' *C a23
  have [simp]: norm ψn = 1 norm ψn' = 1
    using ψ  0 ψ'  0 by (auto simp: ψn_def ψn'_def norm_inverse)
  have n1: butterfly ψn ψn' o a23n = butterfly ψ ψ' o a23
    apply (auto simp: ψn_def ψn'_def a23n_def tensor_op_scaleC_left tensor_op_scaleC_right)
    by (metis (no_types, lifting) assms(2) assms(3) inverse_mult_distrib mult.commute no_zero_divisors norm_eq_zero of_real_eq_0_iff right_inverse scaleC_one)

  (* Normalize phi and phi', compensating in b12. *)
  define φn φn' b12n where φn = φ /C norm φ and φn' = φ' /C norm φ' and b12n = norm φ *C norm φ' *C b12
  have [simp]: norm φn = 1 norm φn' = 1
    using φ  0 φ'  0 by (auto simp: φn_def φn'_def norm_inverse)
  have n2: b12n o butterfly φn φn' = b12 o butterfly φ φ'
    apply (auto simp: φn_def φn'_def b12n_def tensor_op_scaleC_left tensor_op_scaleC_right)
    by (metis (no_types, lifting) assms(4) assms(5) field_class.field_inverse inverse_mult_distrib mult.commute no_zero_divisors norm_eq_zero of_real_hom.hom_0 scaleC_one)

  (* c': d compressed by the normalized outer vectors, an operator on unit x 'a2 x unit. *)
  define c' :: (unit*'a2*unit) ell2 CL (unit*'b2*unit) ell2 
    where c' = (vector_to_cblinfun ψn o id_cblinfun o vector_to_cblinfun φn)* oCL d
            oCL (vector_to_cblinfun ψn' o id_cblinfun o vector_to_cblinfun φn')

  (* c'': the operator on 'a2 that corresponds to c' when tensored with the identity
     on unit from both sides. *)
  define c'' :: 'a2 ell2 CL 'b2 ell2
    where c'' = inv (λc''. id_cblinfun o c'' o id_cblinfun) c'

  have *: bij (λc''::'a2 ell2 CL 'b2 ell2. ?id1 o c'' o ?id1)
    apply (subst asm_rl[of _ = (λx. id_cblinfun o x) o (λc''. c'' o id_cblinfun)])
    by (auto intro!: bij_comp bij_tensor_op_one_dim_left bij_tensor_op_one_dim_right)

  have c'_c'': c' = ?id1 o c'' o ?id1
    unfolding c''_def 
    apply (rule surj_f_inv_f[where y=c', symmetric])
    using * by (rule bij_is_surj)

  (* c: c'' with the normalization factors divided out. *)
  define c :: 'a2 ell2 CL 'b2 ell2
    where c = c'' /C norm ψ /C norm ψ' /C norm φ /C norm φ'

  have aux: assoc_ell2' oCL (assoc_ell2 oCL x oCL assoc_ell2') oCL assoc_ell2 = x for x
    apply (simp add: cblinfun_assoc_left)
    by (simp add: cblinfun_assoc_right)
  have aux2: (assoc_ell2 oCL ((x o y) o z) oCL assoc_ell2') = x o (y o z) for x y z
    apply (rule equal_ket, rename_tac xyz)
    apply (case_tac xyz, hypsubst_thin)
    by (simp flip: tensor_ell2_ket add: assoc_ell2'_tensor assoc_ell2_tensor tensor_op_ell2)

  (* Main calculation: sandwich d between the projectors, move them through the
     associators, and express everything through c' and then c. *)
  have d = (butterfly ψn ψn o id_cblinfun) oCL d oCL (butterfly ψn' ψn' o id_cblinfun)
    by (auto simp: d_def n1[symmetric] comp_tensor_op cnorm_eq_1[THEN iffD1])
  also have  = (butterfly ψn ψn o id_cblinfun) oCL assoc_ell2 oCL (b12n o butterfly φn φn')
                  oCL assoc_ell2' oCL (butterfly ψn' ψn' o id_cblinfun)
    by (auto simp: d_def eq n2 cblinfun_assoc_left)
  also have  = (butterfly ψn ψn o id_cblinfun) oCL assoc_ell2 oCL 
               ((id_cblinfun o butterfly φn φn) oCL (b12n o butterfly φn φn') oCL (id_cblinfun o butterfly φn' φn'))
               oCL assoc_ell2' oCL (butterfly ψn' ψn' o id_cblinfun)
    by (auto simp: comp_tensor_op cnorm_eq_1[THEN iffD1])
  also have  = (butterfly ψn ψn o id_cblinfun) oCL assoc_ell2 oCL 
               ((id_cblinfun o butterfly φn φn) oCL (assoc_ell2' oCL d oCL assoc_ell2) oCL (id_cblinfun o butterfly φn' φn'))
               oCL assoc_ell2' oCL (butterfly ψn' ψn' o id_cblinfun)
    by (auto simp: d_def n2 eq aux)
  also have  = ((butterfly ψn ψn o id_cblinfun) oCL (assoc_ell2 oCL (id_cblinfun o butterfly φn φn) oCL assoc_ell2'))
                oCL d oCL ((assoc_ell2 oCL (id_cblinfun o butterfly φn' φn') oCL assoc_ell2') oCL (butterfly ψn' ψn' o id_cblinfun))
    by (auto simp: sandwich_def cblinfun_assoc_left)
  also have  = (butterfly ψn ψn o id_cblinfun o butterfly φn φn)
               oCL d oCL (butterfly ψn' ψn' o id_cblinfun o butterfly φn' φn')
    apply (simp only: tensor_id[symmetric] comp_tensor_op aux2)
    by (simp add: cnorm_eq_1[THEN iffD1])
  also have  = (vector_to_cblinfun ψn o id_cblinfun o vector_to_cblinfun φn)
               oCL c' oCL (vector_to_cblinfun ψn' o id_cblinfun o vector_to_cblinfun φn')*
    apply (simp add: c'_def butterfly_def_one_dim[where 'c="unit ell2"] cblinfun_assoc_left comp_tensor_op
                      tensor_op_adjoint cnorm_eq_1[THEN iffD1])
    by (simp add: cblinfun_assoc_right comp_tensor_op)
  also have  = butterfly ψn ψn' o c'' o butterfly φn φn'
    by (simp add: c'_c'' comp_tensor_op tensor_op_adjoint butterfly_def_one_dim[symmetric])
  also have  = butterfly ψ ψ' o c o butterfly φ φ'
    by (simp add: ψn_def ψn'_def φn_def φn'_def c_def tensor_op_scaleC_left tensor_op_scaleC_right)
  finally have d_c: d = butterfly ψ ψ' o c o butterfly φ φ'
    by -
  then show ?thesis
    by (auto simp: d_def)
qed

(* The norm is multiplicative on elementary tensors.  After transfer, the finite ell2
   norm (ell2_norm_finite) is a square root of a sum over the product index type.
   That sum factorizes into the product of the two sums. *)
lemma norm_tensor_ell2: norm (a s b) = norm a * norm b
  apply transfer
  by (simp add: ell2_norm_finite sum_product sum.cartesian_product case_prod_beta
      norm_mult power_mult_distrib flip: real_sqrt_mult)

(* tensor_ell2 is a bounded bilinear map: the four linearity laws come from the
   tensor_ell2_add / tensor_ell2_scaleC lemmas.  The bound holds with constant K = 1
   by norm_tensor_ell2. *)
lemma bounded_cbilinear_tensor_ell2[bounded_cbilinear]: bounded_cbilinear (⊗s)
proof standard
  fix a a' :: "'a ell2" and b b' :: "'b ell2" and r :: complex
  show tensor_ell2 (a + a') b = tensor_ell2 a b + tensor_ell2 a' b
    by (meson tensor_ell2_add1)
  show tensor_ell2 a (b + b') = tensor_ell2 a b + tensor_ell2 a b'
    by (simp add: tensor_ell2_add2)  
  show tensor_ell2 (r *C a) b = r *C tensor_ell2 a b
    by (simp add: tensor_ell2_scaleC1)
  show tensor_ell2 a (r *C b) = r *C tensor_ell2 a b
    by (simp add: tensor_ell2_scaleC2)
  (* The norm of the tensor equals the product of the norms, so K = 1 suffices. *)
  show K. a b. norm (tensor_ell2 a b)  norm a * norm b * K
    apply (rule exI[of _ 1])
    by (simp add: norm_tensor_ell2)
qed


end