(* Theory Tensor_Mat_Compl_Properties *)

(*
Author: 
  Mnacho Echenim, Université Grenoble Alpes
*)

theory Tensor_Mat_Compl_Properties 
  imports 
    Commuting_Hermitian.Spectral_Theory_Complements
    Projective_Measurements.Projective_Measurements
begin

section ‹Basic algebraic results›


(* If the sum of finitely many nonnegative terms is strictly positive,
   then at least one of the terms is strictly positive. *)
lemma pos_sum_gt_0:
  assumes "finite I"
and "⋀i. i ∈ I ⟹ (0:: 'a :: linordered_field) ≤ f i"
and "0 < sum f I"
shows "∃j ∈ I. 0 < f j"
proof (rule ccontr)
  assume "¬ (∃j∈I. 0 < f j)"
  hence "∀j ∈ I. f j ≤ 0" by auto
  (* combined with nonnegativity, every term is exactly 0 *)
  hence "∀j ∈ I. f j = 0" using assms by fastforce
  hence "sum f I = 0" by simp
  thus False using assms by simp
qed

(* Suppose nonnegative reals f i over a finite set I sum to 1, and their
   squares also sum to 1. Then some f j equals 1.
   The proof is by contradiction. Pick a positive term f j < 1. Then
   f j * f j < f j, and every other square is bounded by its term, so the
   sum of squares would be strictly below 1. *)
lemma pos_square_1_elem:
  assumes "finite I"
and "⋀i. i ∈ I ⟹ (0::real) ≤ f i"
and "sum f I = 1"
and "sum (λx. f x * f x) I = 1"
shows "∃j ∈ I. f j = 1"
proof (rule ccontr)
  assume "¬ (∃j∈I. f j = 1)"
  hence ne: "∀j∈ I. f j ≠ 1" by simp
  have "∃j ∈ I. 0 < f j" using pos_sum_gt_0[of I f] assms by simp
  from this obtain j where "j∈I" and "0 < f j" by auto
  hence "f j ≠ 1" using ne by simp
  moreover have "f j ≤ 1" using ‹j ∈ I› assms pos_sum_le_comp by force
  ultimately have "f j < 1" by auto
  (* split off the j-th term and bound both parts *)
  have "sum (λx. f x * f x) I = f j * f j + sum (λx. f x * f x) (I-{j})"
    by (meson ‹j ∈ I› assms(1) sum.remove)
  also have "... < f j + sum (λx. f x * f x) (I-{j})"
    by (simp add: ‹0 < f j› ‹f j < 1›)
  also have "... ≤ f j + sum f (I-{j})" 
    using square_pos_mult_le[of "I -  {j}"]
    by (smt (verit, ccfv_SIG) DiffD1 assms
        mult_left_le sum_mono sum_nonneg_leq_bound)
  also have "... = sum f I"
    by (metis ‹j ∈ I› assms(1) sum.remove)
  also have "... = 1" using assms by simp
  finally have "sum (λx. f x * f x) I < 1" .
  thus False using assms by simp
qed

(* Complex counterpart of pos_square_1_elem. Nonnegative complex values
   have zero imaginary part, so the real-valued lemma is applied to their
   real parts and the result is transferred back. *)
lemma cpx_pos_square_1_elem:
  assumes "finite I"
and "⋀i. i ∈ I ⟹ (0::complex) ≤ f i"
and "sum f I = 1"
and "sum (λx. f x * f x) I = 1"
shows "∃j ∈ I. f j = 1"
proof -
  have "∀i ∈I. Im( f i) = 0" using assms complex_is_Real_iff
    by (meson  nonnegative_complex_is_real 
        positive_unitary_diag_pos real_diag_decompD(1))
  (* each f i coincides with the complex embedding of its real part *)
  hence al: "∀i ∈ I. Re(f i) = f i"
    by (simp add: assms complex.expand)
  have "∃ j ∈ I. Re (f j) = 1"
  proof (rule pos_square_1_elem)
    show "finite I" using assms by simp
    show "⋀i. i ∈ I ⟹ 0 ≤ Re (f i)" using assms al
      by (simp add: less_eq_complex_def) 
    show "(∑j∈I. Re (f j)) = 1" using al
      by (metis Re_sum assms(3) one_complex.simps(1))
    show "(∑x∈I. Re (f x) * Re (f x)) = 1" using al
      by (smt (z3) assms(4) of_real_hom.hom_1 of_real_hom.hom_mult 
          of_real_hom.hom_sum sum.cong) 
  qed
  thus ?thesis using al by force
qed

(* If nonnegative terms over a finite set I sum to c, and the term at
   j ∈ I already equals c, then every other term is 0. *)
lemma sum_eq_elmt:
  assumes "finite I"
  and "⋀i. i ∈ I ⟹ (0::'a :: linordered_field) ≤ f i"
  and "sum f I = c"
  and "j∈I"
  and "f j = c"
shows "∀k∈(I-{j}). f k = 0"
proof -
  have "sum f I - f j= sum f (I - {j})" using assms sum_diff1[of I f j] by auto
  also have "sum f I - f j = 0" using assms
    using ‹f j = c› by linarith
  hence "sum f (I - {j}) = 0" using assms
    using calculation by linarith 
  (* a zero sum of nonnegative terms forces every term to vanish *)
  finally show "∀k∈(I-{j}). f k = 0"
    by (meson DiffD1 ‹sum f (I - {j}) = 0› assms(1) assms(2) finite_Diff sum_nonneg_eq_0_iff)
qed

(* Complex counterpart of sum_eq_elmt. If nonnegative complex terms over a
   finite set I sum to c, and the term at j ∈ I equals c, then every other
   term is 0. *)
lemma cpx_sum_eq_elmt:
  assumes "finite I"
  and "⋀i. i ∈ I ⟹ (0::complex) ≤ f i"
  and "sum f I = c"
  and "j∈I"
  and "f j = c"
shows "∀k∈(I-{j}). f k = 0"
proof -
  have "sum f I - f j= sum f (I - {j})" using assms sum_diff1[of I f j] by auto
  also have "sum f I - f j = 0" using assms
    using ‹f j = c› by simp
  hence "sum f (I - {j}) = 0" using assms
    using calculation by simp 
  (* a zero sum of nonnegative terms forces every term to vanish *)
  finally show "∀k∈(I-{j}). f k = 0"
    by (meson DiffD1 ‹sum f (I - {j}) = 0› assms(1) assms(2) 
        finite_Diff sum_nonneg_eq_0_iff)
qed

(* Flattens a double sum over {..<n} × {..<m} into a single sum over
   {..<n*m}. The pair of indices is recovered as (k div m, k mod m).
   The proof is by induction on n. The new block of indices
   {n*m..<(n+1)*m} is reindexed onto {..<m} through k mod m. *)
lemma sum_nat_div_mod:
  shows "sum (λi. sum (λj. f i * g j) {..< (m::nat)}) {..< (n::nat)} =  
    sum (λk. f (k div m) * g (k mod m)) {..< n*m}" 
proof (induct n)
  case 0
  then show ?case by simp
next
  case (Suc n)
  have "(∑i<Suc n. ∑j<m. f i * g j) = (∑i< n. ∑j<m. f i * g j) + 
    (∑j<m. f n * g j)" 
    by simp
  also have "... = (∑k<n * m. f (k div m) * g (k mod m)) + 
    (∑j<m. f n * g j)" 
    using Suc by simp
  also have "... = (∑k<n * m. f (k div m) * g (k mod m)) + 
    sum (λk. f (k div m) * g (k mod m)) {n*m ..< (Suc n) * m}"
  proof -
    (* on {n*m..<(n+1)*m}, k div m = n and k mod m ranges over {..<m} *)
    have "(∑j<m. f n * g j) = 
      sum (λk. f (k div m) * g (k mod m)) {n*m ..< (Suc n) * m}"
    proof (rule sum.reindex_cong)
      show "inj_on (λj. j mod m) {n * m..<Suc n * m}" 
      proof
        fix x y
        assume "x ∈ {n * m..<Suc n * m}" and "y ∈ {n * m..<Suc n * m}"
          and "x mod m = y mod m"
        thus "x = y"
          by (metis atLeastLessThan_iff div_nat_eqI mod_div_decomp 
              mult.commute)
      qed
      show "{..<m} = (λj. j mod m) ` {n * m..<Suc n * m}"
      proof
        show "{..<m} ⊆ (λj. j mod m) ` {n * m..<Suc n * m}"
        proof
          fix x
          assume "x ∈{..< m}"
          hence "n * m + x ∈{n * m..<Suc n * m}" by simp
          moreover have "x = (n * m + x) mod m" using ‹x ∈ {..<m}› by auto 
          ultimately show "x ∈ (λj. j mod m) ` {n * m..<Suc n * m}" 
            using ‹x ∈ {..<m}› by blast 
        qed
      qed auto
      fix x
      assume "x ∈ {n * m..<Suc n * m}"
      thus "f n * g (x mod m) = f (x div m) * g (x mod m)" by auto
    qed
    thus ?thesis by simp
  qed
  also have "... = sum (λk. f (k div m) * g (k mod m)) 
    ({..<n*m} ∪ {n * m..<Suc n * m})" 
    by (rule sum.union_disjoint[symmetric], auto)
  also have "... = (∑k<(Suc n) * m. f (k div m) * g (k mod m))" 
  proof -
    have "{..<n*m} ∪ {n * m..<Suc n * m} = {..< Suc n * m}"
      by (simp add: ivl_disj_un_one(2))
    thus ?thesis by simp
  qed
  finally show ?case .
qed

(* The absolute value of a complex number is its modulus cmod z,
   implicitly embedded back into the complex numbers. *)
lemma abs_cmod_eq:
  fixes z::complex
  shows "¦z¦ = cmod z"
  by (simp add: abs_complex_def)

(* For real-valued complex numbers A and B, a bound of 1 on ¦A * B¦
   transfers to the product of their real parts. *)
lemma real_cpx_abs_leq:
  fixes A::complex
  assumes "A ∈ Reals"
  and "B ∈ Reals"
  and "¦A * B¦ ≤ 1"
shows "¦Re A * Re B¦ ≤ 1"
proof -
  have "¦Re A * Re B¦ = ¦A * B¦"  using assms
      by (metis Reals_mult abs_cmod_eq in_Reals_norm real_mult_re) 
  also have "... ≤ 1" using assms by simp 
  (* move the complex-order bound back to a bound on reals *)
  finally show "¦Re A * Re B¦ ≤ 1"
    by (metis Re_complex_of_real less_eq_complex_def one_complex.sel(1)) 
qed

(* A real-valued complex number z that is the embedding of the real r has
   the same absolute value as r. *)
lemma cpx_real_abs_eq:
  fixes z::complex and r::real
  assumes "z ∈ Reals"
  and "z = r"
shows "¦z¦ = ¦r¦"
proof -
  have "Re z = r" using assms by simp
  have "Im z = 0"  using assms complex_is_Real_iff by auto
  have "¦z¦ = cmod z" by (simp add: abs_complex_def)  
  hence "¦z¦ = ¦Re z¦" using ‹Im z = 0› assms by simp
  thus ?thesis using ‹Re z = r› by simp
qed

(* A real-valued complex number z that is the embedding of r satisfies
   ¦z¦ ≤ k (in the complex order) whenever ¦r¦ ≤ k. *)
lemma cpx_real_abs_leq:
  fixes z::complex and r::real
  assumes "z ∈ Reals"
  and "z = r" 
  and "¦r¦ ≤ k"
shows "¦z¦ ≤ (k::real)"
proof -
  have "Re z = r" using assms by simp
  hence "¦Re z¦ ≤ k" using assms by simp
  have "Im z = 0"  using assms complex_is_Real_iff by auto
  have "¦z¦ = cmod z" by (simp add: abs_complex_def)  
  hence "¦z¦ = ¦Re z¦" using ‹Im z = 0› assms by simp
  thus ?thesis using ‹¦Re z¦ ≤ k› by (simp add: less_eq_complex_def)
qed

(* If two complex numbers have absolute value at most 1, so does their
   product. The proof goes through the real moduli, whose product is bounded
   by 1. *)
lemma cpx_abs_mult_le_1:
  fixes z::complex
  assumes "¦z¦ ≤ 1"
  and "¦z'¦ ≤ 1"
shows "¦z*z'¦ ≤ 1"
proof -
  have a: "cmod z ≤ 1"
    by (metis Reals_1 abs_1 abs_cmod_eq assms(1) 
        cpx_real_abs_leq dual_order.antisym linorder_le_cases 
        of_real_eq_1_iff)
  have b: "cmod z' ≤ 1"
    by (metis Reals_1 abs_1 abs_cmod_eq assms(2) 
        cpx_real_abs_leq dual_order.antisym linorder_le_cases 
        of_real_eq_1_iff)
  have "¦z*z'¦ = ¦z¦*¦z'¦"
    by (simp add: abs_mult)
  also have "... = cmod z * (cmod z')"
    using abs_cmod_eq by auto
  also have "... ≤ 1" using a b
    by (simp add: less_eq_complex_def mult_le_one)
  finally show ?thesis .
qed

(* Triangle inequality for finite sums of complex numbers, stated with
   complex absolute values. It is derived from norm_sum via cmod. *)
lemma sum_abs_cpx:
  shows "¦sum K I¦ ≤ sum (λx. ¦(K x)::complex¦) I"
proof -
  have "¦sum K I¦ = cmod (sum K I)"
    using abs_cmod_eq by blast
  also have "... ≤ sum (λx. cmod (K x)) I" using norm_sum
    by (metis Im_complex_of_real Re_complex_of_real less_eq_complex_def)
  also have "... = sum (λx. ¦(K x)::complex¦) I"
    using abs_cmod_eq by fastforce
  finally show ?thesis .
qed

(* A nonnegative real factor can be pulled out of a complex absolute
   value. *)
lemma abs_mult_cpx:
  fixes z::complex
  assumes "0 ≤ (a::real)"
  shows "¦a*z¦ = a * ¦z¦"
proof -
  have "¦a*z¦ = cmod (a*z)" using abs_cmod_eq by blast
  also have "... = a * cmod z" using assms
    by (simp add: norm_mult)
  also have "... = a * ¦z¦" by (simp add: abs_cmod_eq) 
  finally show ?thesis .
qed

(* A nonnegative real-valued complex number has a nonnegative real part. *)
lemma cpx_ge_0_real:
  fixes c::complex
  assumes "0 ≤ c"
  and "c ∈ Reals"
shows "0 ≤ Re c" 
proof -
  have "Re c = c" using assms by simp
  hence "0 ≤ complex_of_real (Re (c::complex))" using assms by simp
  thus ?thesis using less_eq_complex_def by auto
qed

(* The complex embedding of a real a is nonnegative only if a itself is
   nonnegative. *)
lemma cpx_of_real_ge_0:
  assumes "0 ≤ complex_of_real a"
  shows "0 ≤ a" 
proof -
  have "0 ≤ Re (complex_of_real a)"
    using Reals_of_real assms cpx_ge_0_real by blast
  also have "... = a" by simp
  finally show ?thesis .
qed


(* A nonempty list whose every entry equals x has {x} as its set of
   elements. *)
lemma set_cst_list:  
  shows "(⋀i. i < length l ⟹ l!i = x) ⟹ 0 < length l ⟹ set l = {x}"
proof (induct l)
  case Nil
  then show ?case by simp
next
  case (Cons a l)
  then show ?case
    by (metis in_set_conv_nth insert_absorb is_singletonI' 
        is_singleton_def singleton_iff)
qed

(* Scaling a finite, nonempty set of nonnegative reals by a nonnegative
   factor x scales its maximum by x. *)
lemma pos_mult_Max:
  assumes "finite F"
and "F ≠ {}"
and "0 ≤ x"
and "∀a ∈F. 0 ≤ (a::real)"
shows "Max.F {x * a|a. a ∈ F} = x * Max.F F" 
proof -
  define M where "M = Max.F F"
  have "finite {x * a|a. a ∈ F}" using assms by auto
  have "M ∈F" using assms unfolding M_def by simp
  hence "x*M ∈ {x * a|a. a ∈ F}"  by auto
  (* x*M is attained and bounds every scaled element, hence is the Max *)
  moreover have "∀c∈{x * a|a. a ∈ F}. c ≤ x*M"
    using M_def assms eq_Max_iff 
      ordered_comm_semiring_class.comm_mult_left_mono by fastforce
  ultimately show ?thesis using assms Max_eqI M_def ‹finite {x * a |a. a ∈ F}› 
    by blast
qed

(* For a nonnegative real-valued f on a finite, nonempty set A, the maximum
   of the squares f a * f a is the square of the maximum b of the f a. *)
lemma square_Max:
  assumes "finite A"
  and "A ≠{}"
  and "∀a ∈A. 0 ≤ ((f a)::real)"
  and "b = Max.F {f a |a. a ∈A}"
shows "Max.F {f a* f a|a. a ∈A} = b * b" 
proof -
  define B where "B = {f a* f a|a. a ∈A}"
  have "finite B" using finite_image_set unfolding B_def by (simp add: assms)
  have "finite {f a |a. a ∈A}" using assms by auto
  hence "b ∈{f a |a. a ∈A}" using assms
    by (metis (mono_tags, lifting) Collect_empty_eq_bot Max_eq_iff all_not_in_conv 
        bot_empty_eq)
  hence "b*b ∈ B" unfolding B_def by auto
  (* squaring is monotone on nonnegative reals, so b*b bounds every square *)
  moreover have "∀c∈B. c ≤ b*b"
  proof
    fix c
    assume "c ∈B"
    hence "∃d ∈A. c = f d* f d" unfolding B_def by auto
    from this obtain d where "d ∈A" and "c = f d * f d" by auto 
      note dprop = this
    hence "f d ∈ {f a |a. a ∈A}" by auto
    hence "f d ≤ b" using assms by auto
    thus "c ≤ b*b" using assms by (simp add: dprop mult_mono')
  qed
  ultimately show ?thesis using assms Max_eqI[of B "b*b"] ‹finite B› 
    by (metis B_def)
qed

(* For a nonempty family f of reals indexed by P that is bounded below by b
   and above by c, ereal commutes with the supremum. Boundedness ensures
   the extended-real supremum is finite, which ereal_SUP requires. *)
lemma ereal_Sup_switch: 
  assumes "∀ m∈ P. (b::real) ≤ f m"
  and "∀m ∈ P. f m ≤ (c::real)"
  and "P ≠ {}"
shows "ereal (Sup (f ` P)) = (⨆m∈ P. ereal (f m))"
proof (rule ereal_SUP)
  have b: "∀m ∈ P. b ≤ (ereal (f m))" using assms by auto
  hence "b ≤ (⨆ m∈ P. ereal (f m))" using assms 
    by (meson Sup_upper2 ex_in_conv image_eqI) 
  have m: "∀m ∈ P. (ereal (f m)) ≤ c" using assms by auto
  hence c: "Sup (ereal ` (f` P)) ≤ c"
    by (simp add: assms(3) cSUP_least image_image)  
  show "¦⨆ m∈ P. ereal (f m)¦ ≠ ∞" using b c MInfty_neq_ereal(2)
    by (metis PInfty_neq_ereal(1) m b 
        assms(3) ereal_SUP_not_infty) 
qed

(* Every element of a real set that is bounded above by c and below by b
   is at most Sup A. The argument goes through the extended reals, where
   Sup_upper holds unconditionally. *)
lemma Sup_ge_real:
  assumes "a ∈(A::real set)"
  and "∀a ∈ A. a ≤ c"
  and "∀a ∈ A. b ≤ a"
shows "a ≤ Sup A"
proof -
  define B where "B = {ereal a|a. a ∈A}"
  have "ereal a ∈ B" using assms unfolding B_def by simp
  hence "ereal a ≤ Sup B" by (simp add: Sup_upper) 
  also have "... = ereal (Sup A)" 
    using ereal_Sup_switch[symmetric, of A b "λx. x" c] assms unfolding B_def
    by (metis B_def Collect_mem_eq empty_iff image_Collect image_ident)
  finally have "ereal a ≤ ereal (Sup A)" .
  thus ?thesis by simp
qed

(* For a nonempty real set bounded above by c and below by b, the supremum
   is at most c. The argument goes through the extended reals. *)
lemma Sup_real_le:
  assumes "∀a ∈ (A::real set). a ≤ c"
  and "∀a ∈ A. b ≤ a"
  and "A ≠{}"
shows "Sup A ≤ c" 
proof -
  define B where "B = {ereal a|a. a ∈A}"
  have "Sup B ≤ ereal c" unfolding B_def using SUP_least[of A "λx. x" c] assms 
    by (simp add: Setcompr_eq_image) 
  moreover have "Sup B = ereal (Sup A)" unfolding B_def
    using ereal_Sup_switch[symmetric, of A b "λx. x" c] assms 
    by (metis B_def Collect_mem_eq image_Collect image_ident)
  ultimately show ?thesis by simp
qed

section ‹Results in linear algebra›


(* If two n×m matrices over a group sum to the zero matrix, then the second
   is the negation of the first. The proof compares the matrices entry by
   entry. *)
lemma mat_add_eq_0_if:
  fixes A::"'a ::group_add Matrix.mat"
  assumes "A ∈carrier_mat n m"
  and "B ∈carrier_mat n m"
  and "A+B = 0⇩m n m"
shows "B = -A" 
proof (rule eq_matI)
  show "dim_row B = dim_row (-A)" using assms by simp
  show "dim_col B = dim_col (-A)" using assms by simp
  fix i j
  assume "i < dim_row (-A)" and "j < dim_col (-A)" note ij= this
  hence "i < dim_row B" "j < dim_col B" 
    using ‹dim_row B = dim_row (-A)› ‹dim_col B = dim_col (-A)› by auto
  hence "A $$ (i,j) + B $$ (i,j) = (A+B)$$(i,j)" using ij by simp
  also have "... = 0"
    by (metis ‹dim_col B = dim_col (-A)› ‹dim_row B = dim_row (-A)› 
        assms(2) assms(3) carrier_matD(1) ij(1) ij(2) index_add_mat(3) 
        index_zero_mat(1) index_zero_mat(3))
  finally have "A $$ (i,j) + B $$ (i,j) = 0" .
  thus "B $$ (i, j) = (- A) $$ (i, j)"
    by (metis ‹dim_col B = dim_col (- A)› ‹dim_row B = dim_row (- A)› 
        ‹i < dim_row B› ‹j < dim_col B› add_eq_0_iff index_uminus_mat(1) 
        index_uminus_mat(2) index_uminus_mat(3))
qed

(* The trace of the rank-one projector built from v is the squared length
   of v. It goes through inner_prod v v and vec_norm. *)
lemma trace_rank_1_proj:
  shows "Complex_Matrix.trace (rank_1_proj v) = ∥v∥⇧2"
proof -
  have "Complex_Matrix.trace (rank_1_proj v) = inner_prod v v" 
    using trace_outer_prod carrier_vecI
    unfolding rank_1_proj_def by blast
  also have "... = (vec_norm v)⇧2" 
    unfolding vec_norm_def using power2_csqrt by presburger
  also have "... = ∥v∥⇧2" using vec_norm_sq_cpx_vec_length_sq by simp
  finally show ?thesis .
qed

(* Expands the trace of A - B + C + D for square n×n matrices into the
   corresponding combination of individual traces. It uses linearity of the
   trace under addition and subtraction. *)
lemma trace_ch_expand:
  fixes A::"'a::{minus,comm_ring} Matrix.mat"
  assumes "A ∈carrier_mat n n"
  and "B ∈carrier_mat n n"
  and "C ∈carrier_mat n n"
  and "D ∈carrier_mat n n"
shows "Complex_Matrix.trace (A - B + C + D) =
  Complex_Matrix.trace A - Complex_Matrix.trace B + 
  Complex_Matrix.trace C + Complex_Matrix.trace D" 
proof -
  have "Complex_Matrix.trace (A - B + C + D) = 
    Complex_Matrix.trace (A - B + C) + Complex_Matrix.trace D" 
    using trace_add_linear[of _ n D] assms by simp
  also have "... = Complex_Matrix.trace (A - B) + Complex_Matrix.trace C + 
    Complex_Matrix.trace D" using assms trace_add_linear[of _ n C] 
    by (metis minus_carrier_mat')
  finally show ?thesis using assms trace_minus_linear by auto
qed

(* Unitarily equivalent matrices have squares with equal traces. *)
lemma squared_A_trace:
  assumes "A ∈carrier_mat n n"
  and "unitarily_equiv A B U"
shows "Complex_Matrix.trace (A*A) = Complex_Matrix.trace (B*B)"
proof (rule unitarily_equiv_trace)
  show "A*A ∈ carrier_mat n n" using assms by simp
  show "unitarily_equiv (A * A) (B * B) U" 
    using assms unitarily_equiv_square[of A n] by simp
qed

(* If A is unitarily diagonalised with diagonal factor B, then the trace of
   A*A is the sum of the squared diagonal entries of B. *)
lemma squared_A_trace':
assumes "A ∈carrier_mat n n"
  and "unitary_diag A B U"
shows "Complex_Matrix.trace (A*A) = (∑ i ∈ {0 ..< n}. (B $$ (i,i) * B $$ (i,i)))"
proof -
  have "Complex_Matrix.trace (A*A) = Complex_Matrix.trace (B*B)"
    using assms squared_A_trace[of A]
    by (meson unitary_diag_imp_unitarily_equiv)
  also have "... = (∑ i ∈ {0 ..< n}. (B * B) $$ (i,i))" using assms 
    unfolding Complex_Matrix.trace_def
    by (metis (mono_tags, lifting) carrier_matD(1) index_mult_mat(2) 
        unitary_diag_carrier(1))
  (* B is diagonal, so each diagonal entry of B*B is a square *)
  also have "... = (∑ i ∈ {0 ..< n}. (B $$ (i,i) * B $$ (i,i)))"
  proof (rule sum.cong)
    fix i
    assume "i ∈ {0..<n}"
    hence "i < n" by simp
    thus "(B*B) $$ (i,i) = B $$(i,i) * B$$(i,i)" using diagonal_mat_sq_index
      by (metis assms(1) assms(2) unitary_diag_carrier(1) 
          unitary_diag_diagonal) 
  qed simp
  finally show ?thesis .
qed


(* Let A be a positive matrix with trace 1 whose square also has trace 1,
   and let B be a real diagonal decomposition of A. Then exactly one
   diagonal entry of B equals 1 and all the others are 0.
   The proof combines cpx_pos_square_1_elem (some entry is 1) with
   cpx_sum_eq_elmt (the remaining entries vanish). *)
lemma positive_square_trace:
  assumes "A ∈ carrier_mat n n"
  and "Complex_Matrix.trace A = (1::real)"
  and "Complex_Matrix.trace (A*A) = 1"
  and "real_diag_decomp A B U"
  and "Complex_Matrix.positive A"
  and "0 < n" (* to remove? -- appears to follow from trace A = 1, since an empty trace is 0 *)
shows "∃j<n. B $$ (j,j) = 1 ∧ (∀i<n. i≠j ⟶ B $$ (i,i) = 0)"
proof -
  have b: "∀i<n. 0 ≤ B $$ (i, i)" using assms positive_unitary_diag_pos
    by (meson ‹real_diag_decomp A B U› real_diag_decompD(1))
  (* B is unitarily equivalent to A, so it has the same trace *)
  have t: "Complex_Matrix.trace B  = (1::real)" 
    using assms
    by (metis ‹real_diag_decomp A B U› of_real_1 real_diag_decompD(1) 
        unitarily_equiv_trace unitary_diag_imp_unitarily_equiv)
  have t_sq: "(∑i∈{0..<n}. (B $$ (i,i) * B $$ (i,i))) = 1" 
    using assms unitary_diag_carrier squared_A_trace'
    by (smt (verit, ccfv_SIG) ‹real_diag_decomp A B U› real_diag_decompD(1) sum.cong)
  have dim_n: "dim_row B = n" using assms
      by (meson ‹real_diag_decomp A B U› carrier_matD(1) 
          real_diag_decompD(1) unitary_diag_carrier(1))
  have ex_j: "∃j∈{0..<n}.  (B $$ (j, j)) = 1"
  proof (rule cpx_pos_square_1_elem)
    show "finite {0..<n}" by simp
    show "⋀i. i ∈ {0..<n} ⟹ 0 ≤ B $$ (i, i)" using b by simp
    show "(∑j ∈ {0..<n}. B $$ (j, j)) = 1" using t 
      unfolding Complex_Matrix.trace_def
      by (metis ‹dim_row B = n› of_real_hom.hom_one)
    show "(∑x = 0..<n. B $$ (x, x) * B $$ (x, x)) = 1" using t_sq
      by blast
  qed
  from this obtain j where jn: "j∈{0..<n}" and bj: "B $$ (j, j) = 1" by auto
  have "∀k ∈ ({0..<n}-{j}). B $$ (k, k) = 0"
  proof (rule cpx_sum_eq_elmt)
    show "finite {0..<n}" by simp
    show "⋀i. i ∈ {0..<n} ⟹ 0 ≤ B $$ (i, i)" using b by simp
    show "(∑k = 0..<n. B $$ (k, k)) = 1" using t 
      unfolding Complex_Matrix.trace_def
      by (simp add: dim_n) 
    show "j ∈ {0..<n}" using jn by simp
    show "B $$ (j, j) = 1" using bj by simp
  qed
  hence "∀i<n. i ≠ j ⟶ B $$ (i, i) = 0"
    using atLeastLessThan_iff by blast
  thus ?thesis
    by (metis atLeastLessThan_iff bj jn)
qed

(* The identity matrix is idempotent. *)
lemma idty_square:
  shows "((1⇩m n):: 'a :: semiring_1 Matrix.mat) * (1⇩m n) = 1⇩m n" 
  using right_mult_one_mat by simp

(* For a positive A and a Hermitian B of the same square dimension, the
   trace of B*A is real. The proof instantiates the cpx_sq_mat locale on
   carrier_mat n n to reuse trace_hermitian_pos_real. *)
lemma pos_hermitian_trace_reals:
  fixes A::"complex Matrix.mat"
  assumes "A ∈carrier_mat n n"
  and "B ∈carrier_mat n n"
  and "0 < n"
  and "Complex_Matrix.positive A"
  and "hermitian B"
  shows "Complex_Matrix.trace (B*A) ∈ Reals"
proof -
  define fc::"complex Matrix.mat set" where "fc = carrier_mat n n"
  interpret cpx_sq_mat n n fc  
  proof 
    show "0 < n" using assms by simp
  qed (auto simp add: fc_def)
  have "Complex_Matrix.trace (B*A) = Complex_Matrix.trace (A*B)" using assms
    by (metis trace_comm)
  also have "... = Re (Complex_Matrix.trace (A * B))" 
  proof (rule trace_hermitian_pos_real[of B A])
    show "hermitian B" using assms by simp
    show "A ∈fc" using assms unfolding fc_def by simp
    show "B ∈fc" using assms unfolding fc_def by simp
    show "Complex_Matrix.positive A" using assms by simp
  qed
  finally have "Complex_Matrix.trace (B*A) =
    Re (Complex_Matrix.trace (A * B))" .
  thus ?thesis by (metis Reals_of_real) 
qed

(* Variant of pos_hermitian_trace_reals with the product in the order A*B.
   It follows from commutativity of the trace. *)
lemma pos_hermitian_trace_reals':
  fixes A::"complex Matrix.mat"
  assumes "A ∈carrier_mat n n"
  and "B ∈carrier_mat n n"
  and "0 < n"
  and "Complex_Matrix.positive A"
  and "hermitian B"
  shows "Complex_Matrix.trace (A*B) ∈ Reals"
  by (metis assms pos_hermitian_trace_reals trace_comm)

(* The product of two commuting Hermitian matrices is Hermitian. *)
lemma hermitian_commute:
  assumes "hermitian A"
  and "hermitian B"
  and "A*B = B*A"
shows "hermitian (A*B)"
  by (metis adjoint_mult assms hermitian_def 
      hermitian_square index_mult_mat(2))


(* If the n×n identity is unitarily diagonalised with diagonal factor B and
   unitary U, then B is the identity. The proof conjugates by adjoint U and
   uses U * B * adjoint U = 1⇩m n. *)
lemma idty_unitary_diag:
  assumes "unitary_diag (1⇩m n) B U"
  shows "B = 1⇩m n"
proof -
  have l: "(Complex_Matrix.adjoint U) * U = 1⇩m n"
    using assms one_carrier_mat similar_mat_witD2(2) unitary_diagD(1) by blast
  (* NOTE(review): r restates l verbatim; it may have been intended as the
     right-inverse identity U * adjoint U = 1⇩m n. It is kept as is because
     both facts are handed to metis below. *)
  have r: "(Complex_Matrix.adjoint U) * U = 1⇩m n"
    by (simp add: l)
  hence "B = ((Complex_Matrix.adjoint U) * U) * B * 
    ((Complex_Matrix.adjoint U) * U)" using l r
    by (metis assms index_one_mat(2) left_mult_one_mat' right_mult_one_mat 
        similar_mat_witD(5) similar_mat_wit_dim_row unitary_diagD(1))
  also have "... = (Complex_Matrix.adjoint U) * 
    (U * B * (Complex_Matrix.adjoint U)) * U"
    by (metis assms calculation similar_mat_witD(3) similar_mat_wit_sym 
        unitary_diagD(1))
  also have "... = (Complex_Matrix.adjoint U) * (1⇩m n) * U"
    by (metis assms one_carrier_mat similar_mat_witD2(3) unitary_diagD(1))
  also have "... = 1⇩m n"
    by (metis assms index_one_mat(2) l right_mult_one_mat similar_mat_witD(7) 
        unitary_diagD(1))
  finally show ?thesis .
qed

(* For n > 0, the set of diagonal entries of the n×n identity is {1}.
   The proof shows both inclusions via in_set_conv_nth. *)
lemma diag_mat_idty:
  assumes "0 < n"
  shows "set (diag_mat ((1⇩m n)::'a::{one,zero} Matrix.mat)) = {1}" 
    (is "?L = ?R")
proof
  show "?L ⊆ ?R"
  proof
    fix x::'a
    assume "x ∈ set (diag_mat (1⇩m n))"
    hence "∃i < length (diag_mat (1⇩m n)). nth (diag_mat (1⇩m n))  i = x" 
      using in_set_conv_nth[of x "diag_mat (1⇩m n)"] assms by simp
    from this obtain i where "i < length (diag_mat (1⇩m n))" 
      and "nth (diag_mat (1⇩m n))  i = x"
      by auto note iprop = this
    hence "i < dim_row (1⇩m n)" unfolding diag_mat_def by simp
    hence "i < n" using assms by simp
    have "x = (1⇩m n)$$(i,i)" using iprop unfolding diag_mat_def by simp
    thus "x ∈ ?R" using ‹i < n› by simp
  qed
next
  (* the entry at (0,0) exists because 0 < n *)
  show "?R ⊆ ?L"
  proof
    fix x
    assume "x ∈?R"
    hence "x = 1" by simp
    also have "... = (1⇩m n)$$(0,0)" using assms by simp
    also have "... ∈ ?L" using assms unfolding diag_mat_def by simp
    finally show "x ∈?L" .
  qed
qed

(* For n > 0, the spectrum of the n×n complex identity is {1}. *)
lemma idty_spectrum:
assumes "0 < n"
shows "spectrum ((1⇩m n)::complex Matrix.mat) = {1}"
proof -
  have "spectrum ((1⇩m n)::complex Matrix.mat) = set (diag_mat (1⇩m n))"
    using similar_spectrum_eq
    by (meson one_carrier_mat similar_mat_refl upper_triangular_one)
  also have "... = {1}" using diag_mat_idty assms by simp
  finally show ?thesis .
qed

(* A square complex matrix of positive dimension has a nonempty spectrum.
   The proof relies on eigvals_poly_length. *)
lemma spectrum_ne:
  fixes A::"complex Matrix.mat"
  assumes "A ∈ carrier_mat n n"
  and "0 < n"
shows "spectrum A ≠ {}" unfolding spectrum_def 
  using eigvals_poly_length[of A] assms by auto

(* For a Hermitian A unitarily diagonalised with diagonal factor B, the
   spectrum of A*A is the set of diagonal entries of B*B. A*A and B*B are
   similar, so their characteristic polynomials coincide. Both factor as
   products of linear terms, so their root sets agree. *)
lemma  unitary_diag_square_spectrum:
  fixes A::"complex Matrix.mat"
  assumes "hermitian A"
  and "A ∈carrier_mat n n"
and "unitary_diag A B U"
shows "spectrum (A*A) = set (diag_mat (B*B))"
proof -
  have sa: "similar_mat (A*A) (B*B)" 
    using assms hermitian_square_similar_mat_wit[of A n] 
    unfolding similar_mat_def by auto
  have  "diagonal_mat (B*B)" using diagonal_mat_sq_diag[of B] assms
    by (meson unitary_diag_carrier(1) unitary_diag_diagonal) 
  have "(∏a←eigvals (A*A). [:- a, 1:]) = char_poly (A*A)" using assms
    by (metis eigvals_poly_length mult_carrier_mat) 
  also have "... = char_poly (B*B)" using char_poly_similar[OF sa] by simp
  also have "... = (∏a←diag_mat (B*B). [:- a, 1:])" using  
      ‹diagonal_mat (B*B)›
    by (metis assms(2) assms(3) char_poly_upper_triangular 
        diagonal_imp_upper_triangular mult_carrier_mat 
        unitary_diag_carrier(1)) 
  finally  have  "(∏a←eigvals (A*A). [:- a, 1:]) = 
    (∏a←diag_mat (B*B). [:- a, 1:])" . 
  hence "set (eigvals (A*A)) = set (diag_mat (B*B))" 
    using poly_root_set_eq[of "eigvals (A*A)"] by simp
  thus ?thesis unfolding spectrum_def by simp
qed

(* For a diagonal square matrix B, the diagonal entries of B*B are exactly
   the squares of the diagonal entries of B. The proof shows both
   inclusions via in_set_conv_nth and diagonal_mat_sq_index. *)
lemma diag_mat_square_eq:
  fixes B::"'a::{ring} Matrix.mat"
  assumes "diagonal_mat B"
  and "B ∈ carrier_mat n n"
  shows "set (diag_mat (B*B)) = {b*b|b. b ∈set (diag_mat B)}"
proof
  show "set (diag_mat (B * B)) ⊆ {b*b |b. b ∈ set (diag_mat B)}"
  proof
    fix x
    assume "x ∈ set (diag_mat (B * B))"
    hence "∃i < length (diag_mat (B * B)). nth (diag_mat (B * B))  i = x" 
      using in_set_conv_nth[of x] by simp
    from this obtain i where "i < length (diag_mat (B * B))" 
      and "nth (diag_mat (B * B))  i = x"
      by auto note iprop = this
    hence "i < n" using assms unfolding diag_mat_def by simp
    have "(B*B) $$ (i,i) = x" using iprop 
      unfolding diag_mat_def by simp
    hence "B $$ (i,i)* B $$ (i,i) = x" 
      using diagonal_mat_sq_index[of B n i i] assms iprop ‹i < n› 
      by simp
    moreover have "B $$ (i,i) ∈ set (diag_mat B)" 
      using ‹i < n› assms in_set_conv_nth[of x] 
      unfolding diag_mat_def by auto
    ultimately show "x ∈ {b*b |b. b ∈ set (diag_mat B)}" by auto
  qed
next
  show "{b * b |b. b ∈ set (diag_mat B)} ⊆ set (diag_mat (B * B))"
  proof
    fix x
    assume "x ∈ {b * b |b. b ∈ set (diag_mat B)}"
    hence "∃b ∈set (diag_mat B). x = b * b" by auto
    from this obtain b where "b ∈set (diag_mat B)" and "x = b * b" by auto
    hence "∃ i < length (diag_mat B). (diag_mat B)!i = b" 
      using in_set_conv_nth[of b] by simp
    from this obtain i where "i < length (diag_mat B)" 
      and "(diag_mat B) ! i = b" by auto
    note iprop = this
    hence "B $$ (i,i) = b" unfolding diag_mat_def by simp
    moreover have "i < n" using assms iprop unfolding diag_mat_def by simp
    ultimately have "(B*B) $$ (i,i) = x" 
      using ‹x = b*b› diagonal_mat_sq_index[of B n i i] assms iprop by simp
    hence "x = (diag_mat (B*B)) ! i" using ‹i < n› assms 
      unfolding diag_mat_def by fastforce
    moreover have "i<length (diag_mat (B * B))" 
      using ‹i < n› assms unfolding diag_mat_def by auto
    ultimately show "x ∈ set (diag_mat (B * B))" 
      using in_set_conv_nth[of x "diag_mat (B*B)"] 
      by simp
  qed
qed

(* For a Hermitian matrix A of positive dimension, the spectrum of A*A
   consists of the squares of the eigenvalues of A. The proof goes through
   a real diagonal decomposition of A. *)
lemma hermitian_square_spectrum_eq:
  fixes A::"complex Matrix.mat"
  assumes "hermitian A"
and "A ∈carrier_mat n n"
and "0 < n"
shows "spectrum (A*A) = {a*a | a. a ∈spectrum A}"
proof -
  obtain B U where herm: "real_diag_decomp A B U" 
    using hermitian_real_diag_decomp[of A] assms by auto
  hence "spectrum (A*A) = set (diag_mat (B*B))" 
    using unitary_diag_square_spectrum assms real_diag_decompD(1)  by blast
  also have "... = {a*a|a. a ∈set (diag_mat B)}" 
    using diag_mat_square_eq[of B] assms herm
    by (meson real_diag_decompD(1) unitary_diagD(2) unitary_diag_carrier(1))
  also have "... = {a*a | a. a ∈spectrum A}" 
    using assms herm real_diag_decompD(1) spectrum_def unitary_diag_spectrum_eq 
    by blast
  finally show ?thesis .
qed

(* The adjoint commutes with matrix negation. The proof compares entries,
   using adjoint_eval and conjugate_neg. *)
lemma adjoint_uminus:
  shows "Complex_Matrix.adjoint (-A) = - (Complex_Matrix.adjoint A)"
proof (rule eq_matI)
  fix i j
  assume "i < dim_row (- Complex_Matrix.adjoint A)" and 
    "j < dim_col (- Complex_Matrix.adjoint A)"
  thus "Complex_Matrix.adjoint (- A) $$ (i, j) = 
    (- Complex_Matrix.adjoint A) $$ (i, j)"
    by (simp add: adjoint_eval conjugate_neg)
qed auto

(* In fixed_carrier_mat: a finite sum of scaled matrices f i ⋅⇩m A i, in
   which every coefficient f i is 0, is the dimR×dimC zero matrix. *)
lemma (in fixed_carrier_mat) sum_mat_zero:
  assumes "finite I"
  and "⋀i. i ∈ I ⟹ A i ∈ fc_mats"
  and "⋀i. i ∈I ⟹ f i = 0"
shows "sum_mat (λ i. (f i) ⋅⇩m (A i)) I = 0⇩m dimR dimC" using assms
proof (induct rule: finite_induct)
  case empty
  then show ?case using sum_mat_empty by simp
next
  case (insert j F)
  hence "sum_mat (λi. f i ⋅⇩m A i) (insert j F) = f j ⋅⇩m A j + 
    sum_mat (λi. f i ⋅⇩m A i) F" 
    using sum_mat_insert
    by (smt (verit, best) Set.basic_monos(7) image_subsetI insertI1 
        smult_mem subset_insertI)
  (* the new coefficient f j is 0, so its term is the zero matrix *)
  also have "... = 0⇩m dimR dimC + sum_mat (λi. f i ⋅⇩m A i) F" 
    using insert smult_zero[of "A j"] fc_mats_carrier by force
  also have "... = 0⇩m dimR dimC + 0⇩m dimR dimC" using insert by simp
  finally show ?case by simp
qed

(* In fixed_carrier_mat: a finite sum of matrices that are all the
   dimR×dimC zero matrix is itself that zero matrix. *)
lemma (in fixed_carrier_mat) sum_mat_zero':
  fixes A::"'b ⇒ 'a Matrix.mat"
  assumes "finite I"
  and "⋀i. i ∈ I ⟹ A i = 0⇩m dimR dimC"
shows "sum_mat A I = 0⇩m dimR dimC" using assms
proof (induct rule: finite_induct)
  case empty
  then show ?case using sum_mat_empty by simp
next
  case (insert j F)
  have "sum_mat A (insert j F) =  A j + sum_mat A F" using sum_mat_insert
    by (metis Set.basic_monos(7) image_subsetI insertI1 insert(1) 
        insert(2) insert(4) subset_insertI zero_mem) 
  also have "... = 0⇩m dimR dimC + sum_mat A F" 
    using insert by simp
  also have "... = 0⇩m dimR dimC + 0⇩m dimR dimC" using insert by simp
  finally show ?case by simp
qed

(* In fixed_carrier_mat: splits one index x out of a finite matrix sum
   whose summands all lie in fc_mats. *)
lemma (in fixed_carrier_mat) sum_mat_remove:
  assumes "A ` I ⊆ fc_mats"
    and A: "finite I" and x: "x ∈ I"
  shows "sum_mat A I = A x + sum_mat A (I-{x})" unfolding sum_mat_def
  using assms sum_with_insert[of A x "I-{x}"] insert_Diff by fastforce

(* In fixed_carrier_mat: if every coefficient other than f j vanishes, the
   weighted matrix sum over I reduces to its j-th term. It combines
   sum_mat_remove with sum_mat_zero on I - {j}. *)
lemma (in fixed_carrier_mat) sum_mat_singleton:
  fixes A::"'b ⇒ 'a Matrix.mat"
  assumes "finite I"
  and "A ` I ⊆ fc_mats"
  and "j ∈ I"
  and "∀i∈I. i ≠ j ⟶ f i = 0"
shows "sum_mat (λ i. (f i) ⋅⇩m (A i)) I = f j ⋅⇩m (A j)"
proof -
  have "sum_mat (λ i. (f i) ⋅⇩m (A i)) I = f j ⋅⇩m (A j) + 
    sum_mat (λ i. (f i) ⋅⇩m (A i)) (I-{j})" using sum_mat_remove
    by (metis (no_types, lifting) assms(1) assms(2) assms(3) 
        image_subset_iff smult_mem)
  moreover have "sum_mat (λ i. (f i) ⋅⇩m (A i)) (I-{j}) = 0⇩m dimR dimC"
  proof (rule sum_mat_zero)
    show "⋀i. i ∈ I - {j} ⟹  A i ∈ fc_mats" using assms by auto
  qed (auto simp add: assms)
  ultimately show "sum_mat (λ i. (f i) ⋅⇩m (A i)) I = f j ⋅⇩m (A j)"
    by (metis Matrix.right_add_zero_mat assms(2) assms(3) fc_mats_carrier 
        image_subset_iff smult_mem)
qed

context fixed_carrier_mat 
begin
lemma sum_mat_disj_union:
  assumes "finite J"
  and "finite I"
  and "I  J = {}"
  and " i  I  J. A i  fc_mats"
shows "sum_mat A (I  J) = sum_mat A I + sum_mat A J" using assms
proof (induct rule: finite_induct)
  case empty  
  then show ?case
    by (simp add: sum_mat_carrier)
next
  case (insert x F)
  have "sum_mat A (I  (insert x F)) = sum_mat A (insert x (I  F))" by simp
  also have "... = A x + sum_mat A (I  F)" 
  proof (rule sum_mat_insert)
    show "A x  fc_mats" by (simp add: local.insert(6))
    show "A ` (I  F)  fc_mats" using local.insert(6) by force
    show "finite (I F)" using insert by simp
    show "x  I  F" using insert by auto 
  qed
  also have "... = A x + sum_mat A I + sum_mat A F" using insert
    by (simp add: add_assoc fc_mats_carrier sum_mat_carrier)
  also have "... = sum_mat A I + sum_mat A (insert x F)"
  proof -
    have "A x + sum_mat A F = sum_mat A (insert x F)"
      by (simp add: insert.prems(3) local.insert(1) local.insert(2) 
          subset_eq sum_mat_insert) 
    thus ?thesis
      by (metis Un_iff add_assoc add_commute fc_mats_carrier 
          insertCI local.insert(6) sum_mat_carrier)
  qed
  finally show ?case .
qed

lemma sum_with_reindex_cong':
  fixes g :: "'c  'a Matrix.mat"
  assumes "x. g x  fc_mats"
  and "x. h x  fc_mats"
  and "inj_on l B"
  and "x. x  B  g (l x) = h x"
  shows "sum_with (+) (0m dimR dimC) g (l ` B) = 
  sum_with (+) (0m dimR dimC) h B" 
  by (rule sum_with_reindex_cong, (simp add: assms)+)

lemma sum_mat_cong':
  shows "finite I  (i. i I  A i = B i)  
    (i. i I  A i  fc_mats)  
    (i. i I  B i  fc_mats)  I = J  sum_mat A I = sum_mat B J"
proof (induct arbitrary: J rule: finite_induct)
  case empty
  then show ?case by simp
next
  case (insert x F)
  have "sum_mat A (insert x F) = A x + sum_mat A F" 
    using insert sum_mat_insert[of A]
    by (meson image_subsetI insert_iff)
  also have "... = B x + sum_mat B F" using insert by force
  also have "... = sum_mat B (insert x F)" using insert sum_mat_insert[of B]
    by (metis image_subsetI insert_iff)
  also have "... = sum_mat B J" using insert by simp
  finally show ?case .
qed

lemma sum_mat_reindex_cong:
  assumes "finite B"
  and "x. x  l` B  g x  fc_mats"
  and "x. x  B  h x  fc_mats"
  and "inj_on l B"
  and "x. x  B  g (l x) = h x"
  shows "sum_mat g (l ` B) = sum_mat h B"
proof -
  define gp where "gp = (λi. if i l`B then g i else (0m dimR dimC))"
  define hp where "hp = (λi. if i  B then h i else (0m dimR dimC))"
  have "sum_mat g (l`B) = sum_mat gp (l`B)" 
  proof (rule sum_mat_cong')
    show "i. i  l ` B  g i = gp i" unfolding gp_def by auto
    show "i. i  l ` B  g i  fc_mats" using assms by simp
    show "i. i  l ` B  gp i  fc_mats" unfolding gp_def using assms by auto
  qed (simp add: assms)+
  also have "... = sum_mat hp B" unfolding sum_mat_def
  proof (rule sum_with_reindex_cong')
    show "x. gp x  fc_mats" unfolding gp_def using assms
      by (simp add: zero_mem)
    show "x. hp x  fc_mats" unfolding hp_def using assms 
      by (simp add: zero_mem)
    show "x. x  B  gp (l x) = hp x"
      by (simp add: assms(5) gp_def hp_def)
  qed (simp add: assms)
  also have "... = sum_mat h B"
  proof (rule sum_mat_cong')
    show "i. i  B  hp i = h i" unfolding hp_def by auto
    show "i. i  B  hp i  fc_mats" unfolding hp_def using assms by auto
    show "i. i  B  h i  fc_mats" using assms by simp
  qed (simp add: assms)+
  finally show ?thesis .
qed

(* Summing A (i mod m) over the shifted indices n * m + i, for i < m, gives
   the same matrix as summing A over {..< m}: on that block, i mod m runs
   through 0 ..< m exactly once.
   This is a direct instance of sum_mat_reindex_cong with the injective
   shift (+) (n * m). *)
lemma sum_mat_mod_eq:
  fixes A :: "nat  'a Matrix.mat"
  assumes "x. x  {..<m}  A x  fc_mats"
shows "sum_mat (λi. A (i mod m)) ((λi. n * m+i)`{..< m}) = sum_mat A {..<m}" 
proof (rule sum_mat_reindex_cong)
  show "x. x  {..<m}  A ((n * m + x) mod m) = A x" by simp
  show "inj_on ((+) (n * m)) {..<m}" by simp
  show "x. x  (+) (n * m) ` {..<m}  A (x mod m)  fc_mats" 
    using assms by force
qed (simp add: assms)+

(* A sum_mat over a singleton index set {i} is the single summand A i,
   provided that summand lies in fc_mats. *)
lemma sum_mat_singleton':
  assumes "A i  fc_mats"
  shows "sum_mat A {i} = A i"
  by (metis add_zero assms comm_add_mat empty_iff fc_mats_carrier 
      finite.intros(1) image_is_empty subsetI sum_mat_empty sum_mat_insert 
      zero_mem)

end

context cpx_sq_mat
begin

(* The lemmas in this context flatten a double sum_mat over i < nC and
   j < nD into a single sum over k < nC * nD, recovering the indices as
   i = k div nD and j = k mod nD.
   Each result comes in three versions:
   - one for nD <> 0, proved by induction on nC;
   - a degenerate one for nD = 0, where both sides are the zero matrix;
   - one combining the two by case analysis on nD.
   Throughout, dimR = n * m is assumed; together with fc_mats_carrier and
   dim_eq this is used to identify fc_mats with carrier_mat (n*m) (n*m). *)

(* Case nD <> 0. The double sum over i < nC and j < nD of the scalar
   multiple (f i * g j) of the tensor product of A i and B j equals the
   flat sum over k < nC * nD of the same term at i = k div nD and
   j = k mod nD. Each A k is n by n and each B j is m by m. *)
lemma sum_mat_mod_div_ne_0:
  assumes "k. k < (nC::nat)  A k  carrier_mat n n"
  and "j. j < (nD::nat)  B j  carrier_mat m m"
  and "0 < n"
  and "0 < m"
  and "dimR = n *m"
  and "nD  0"
shows "sum_mat (λi. sum_mat (λj. f i * g jm ((A i)  (B j))) {..< nD}) 
  {..< nC} = 
  sum_mat (λi. (f (i div nD) * g (i mod nD))m 
  ((A (i div nD))  (B (i mod nD)))) {..< nC*nD}" 
proof -
  define D where "D = (λi. sum_mat (λj. f i*g jm ((A i)  (B j))) {..< nD})"
  have fc: "fc_mats = carrier_mat (n*m) (n*m)" 
    using assms fc_mats_carrier dim_eq 
    by simp
  show ?thesis using  assms
  proof (induct nC)
    (* Base case: both sides are sums over empty index sets. *)
    case 0
    define C where "C = sum_mat D {..< (0::nat)}"
    have "C =  0m (n*m) (n*m)" unfolding C_def 
      using sum_mat_empty assms dim_eq
      by (simp add: fixed_carrier_mat_def)
    moreover have "sum_mat (λi. (f (i div nD) * g (i mod nD))m 
      ((A (i div nD))  (B (i mod nD)))) {..< 0*nD} = 0m (n*m) (n*m)" 
      using sum_mat_empty assms dim_eq
      by (simp add: fixed_carrier_mat_def)
    ultimately show ?case unfolding C_def by simp
  next
    (* Inductive step: split off the outer index nC, rewrite its inner sum
       D nC as a sum over the flat block {nC*nD ..< Suc nC*nD}, and merge
       that block with C, the flat sum over {..< nC*nD} supplied by the
       induction hypothesis. *)
    case (Suc nC)
    define C where "C = sum_mat (λi. (f (i div nD) * g (i mod nD))m 
      ((A (i div nD))  (B (i mod nD)))) {..< nC*nD}"
    (* Every inner sum D i with i < Suc nC lies in fc_mats. *)
    have dm: "i. i  {..<Suc nC}  D i  fc_mats"
    proof -
      fix i
      assume "i  {..<Suc nC}"
      hence "A i  carrier_mat n n" using Suc by simp
      hence "j. j {..< nD}  B j  carrier_mat m m" using Suc
        by simp
      hence "j. j {..< nD}  A i  B j  fc_mats" 
        using fc A i  carrier_mat n n tensor_mat_carrier
        by (metis carrier_matD(1) carrier_matD(2))
      thus "D i  fc_mats" unfolding D_def
        by (metis (mono_tags, lifting) cpx_sq_mat_smult fc_mats_carrier 
            sum_mat_carrier)
    qed
    have "sum_mat D {..< Suc nC} = sum_mat D ({..< nC}  {nC..< Suc nC})" 
    proof -
      have "{..< Suc nC} = {..< nC}  {nC..< Suc nC}" by auto
      thus ?thesis by simp
    qed
    also have "... = sum_mat D {..< nC} + sum_mat D {nC..< Suc nC}"
    proof (rule sum_mat_disj_union)
      show "i{..<nC}  {nC..<Suc nC}. D i  fc_mats" using dm by auto
    qed auto
    also have "... = C + sum_mat D {nC..< Suc nC}" 
      using Suc unfolding C_def D_def by simp
    (* Reindex the last inner sum D nC onto the flat block
       {nC*nD ..< Suc nC*nD}, on which i div nD = nC. *)
    also have "... = C + (sum_mat (λi. (f (i div nD) * g (i mod nD))m 
      ((A (i div nD))  (B (i mod nD)))) {nC*nD..< Suc nC*nD})"
    proof -
      have "sum_mat D {nC..< Suc nC} = sum_mat D {nC}" by simp
      also have "... = D nC"  using  dm
        by (simp add: sum_mat_singleton')
      also have "... = (sum_mat (λi. (f nC * g (i mod nD))m 
        ((A nC)  (B (i mod nD)))) ((+) (nC * nD) ` {..<nD}))" 
        unfolding D_def 
      proof (rule sum_mat_mod_eq[symmetric])
        show "x. x  {..<nD}  f nC * g x m (A nC  B x)  fc_mats" 
        proof -
          fix x
          assume "x {..< nD}"
          hence "B x  carrier_mat m m" using Suc by simp
          have "A nC  carrier_mat n n" using Suc by simp
          hence  "A nC  B x  fc_mats" 
            using fc tensor_mat_carrier  B x  carrier_mat m m by blast
          thus "f nC * g x m (A nC  B x)  fc_mats"
            by (simp add: cpx_sq_mat_smult)
        qed
      qed
      also have "... = sum_mat (λi. (f (i div nD) * g (i mod nD))m 
        ((A (i div nD))  (B (i mod nD)))) {nC*nD..< Suc nC*nD}" 
      proof (rule sum_mat_cong')
        show "(+) (nC * nD) ` {..<nD} = {nC * nD..<Suc nC * nD}"
          by (simp add: lessThan_atLeast0) 
        show "i. i  (+) (nC * nD) ` {..<nD}  
          f nC * g (i mod nD) m (A nC  B (i mod nD))  fc_mats"
        proof -
          fix i
          assume "i  (+) (nC * nD) ` {..<nD}"
          hence "i mod nD < nD" using  assms mod_less_divisor by blast 
          hence "B (i mod nD)  carrier_mat m m" using Suc by simp
          moreover have "A nC  carrier_mat n n" using Suc by simp
          ultimately have "A nC  B (i mod nD)  fc_mats" 
            using fc tensor_mat_carrier by blast
          thus "f nC * g (i mod nD) m (A nC  B (i mod nD))  fc_mats"
            by (simp add: cpx_sq_mat_smult)
        qed
        show "i. i  (+) (nC * nD) ` {..<nD}  
          f (i div nD) * g (i mod nD) m (A (i div nD)  B (i mod nD))  
          fc_mats"
        proof -
          fix i
          assume "i  (+) (nC * nD) ` {..<nD}"
          hence "i div nD = nC" using Suc(2) mod_less_divisor
            by (metis (+) (nC * nD) ` {..<nD} = {nC * nD..<Suc nC * nD} 
                index_div_eq semiring_norm(174))
          have "i mod nD < nD" using i  (+) (nC * nD) ` {..<nD}  
              mod_less_divisor assms by blast
          hence "B (i mod nD)  carrier_mat m m" using Suc by simp
          moreover have "A (i div nD)  carrier_mat n n" 
            using i div nD = nC Suc by simp
          ultimately have "A (i div nD)  B (i mod nD)  fc_mats" 
            using fc tensor_mat_carrier by blast
          thus "f (i div nD) * g (i mod nD) m (A (i div nD)  B (i mod nD)) 
             fc_mats"
            by (simp add: cpx_sq_mat_smult)
        qed
      qed auto
      finally have "sum_mat D {nC..< Suc nC} = 
        sum_mat (λi. (f (i div nD) * g (i mod nD))m 
        ((A (i div nD))  (B (i mod nD)))) {nC*nD..< Suc nC*nD}" .
      thus ?thesis by simp
    qed
    (* Merge the two disjoint flat ranges back into {..< Suc nC * nD}. *)
    also have "... =  
      sum_mat (λi. f (i div nD)*g (i mod nD) m (A (i div nD)  B (i mod nD)))
      ({..< nC * nD}  {nC * nD..<Suc nC * nD})" unfolding C_def 
    proof (rule sum_mat_disj_union[symmetric])
      show "i{..<nC * nD}  {nC * nD..<Suc nC * nD}.
       f (i div nD) * g (i mod nD) m (A (i div nD)  B (i mod nD))  fc_mats" 
      proof
        fix i
        assume "i  {..<nC * nD}  {nC * nD..<Suc nC * nD}"
        hence "i  {..< Suc nC * nD}" by auto
        hence "i div nD < Suc nC" using Suc(2) mod_less_divisor
          by (simp add: less_mult_imp_div_less)
        have "i mod nD < nD" using i  {..<nC * nD}  {nC * nD..<Suc nC * nD}
          Suc(2) mod_less_divisor assms by blast
        hence "B (i mod nD)  carrier_mat m m" using Suc by simp
        moreover have "A (i div nD)  carrier_mat n n" 
          using i div nD < Suc nC Suc by simp
        ultimately have "A (i div nD)  B (i mod nD)  fc_mats" 
          using fc tensor_mat_carrier by blast
        thus "f (i div nD) * g (i mod nD) m (A (i div nD)  B (i mod nD)) 
           fc_mats"
          by (simp add: cpx_sq_mat_smult)
      qed
    qed auto
    also have "... = 
      sum_mat (λi. f (i div nD)*g (i mod nD) m (A (i div nD)  B (i mod nD)))
      {..< Suc nC * nD}"
    proof -
      have "{..< nC * nD}  {nC * nD..<Suc nC * nD} = {..< Suc nC *  nD}" 
        by auto
      thus ?thesis by simp
    qed
    finally show ?case unfolding D_def .
  qed
qed

(* Degenerate case nD = 0: every inner sum ranges over an empty set, and so
   does the flat index set {..< nC * nD}; both sides reduce to the zero
   matrix. *)
lemma sum_mat_mod_div_eq_0:
  assumes "k. k < (nC::nat)  A k  carrier_mat n n"
  and "0 < n"
  and "nD = 0"
  and "dimR = n *m"
shows "sum_mat (λi. sum_mat (λj. f i * g jm ((A i)  (B j))) {..< nD}) 
  {..< nC} = 
  sum_mat (λi. (f (i div nD) * g (i mod nD))m 
  ((A (i div nD))  (B (i mod nD)))) {..< nC*nD}" 
proof-
  have "{..< nC*nD} = {}" using assms by simp
  hence "sum_mat (λi. f (i div nD) * g (i mod nD) m 
    (A (i div nD)  B (i mod nD))) {..<nC * nD} = 0m (n*m) (n*m)"
    using sum_mat_empty assms dim_eq
    by (simp add: fixed_carrier_mat_def)
  moreover have "sum_mat (λi. sum_mat (λj. f i * g j m (A i  B j)) {..<nD}) 
    {..<nC} = 0m dimR dimC" 
  proof (rule sum_mat_zero')
    fix i
    assume "i  {..< nC}"
    show "sum_mat (λj. f i * g j m (A i  B j)) {..<nD} = 0m dimR dimC" 
      using assms sum_mat_empty by simp
  qed simp
  ultimately show ?thesis using assms dim_eq by simp
qed

(* Flattening of the double sum for arbitrary nD, by case analysis on
   whether nD = 0. *)
lemma sum_mat_mod_div:
  assumes "k. k < (nC::nat)  A k  carrier_mat n n"
  and "j. j < (nD::nat)  B j  carrier_mat m m"
  and "0 < n"
  and "0 < m"
  and "dimR = n *m"
shows "sum_mat (λi. sum_mat (λj. f i * g jm ((A i)  (B j))) {..< nD}) 
  {..< nC} = 
  sum_mat (λi. (f (i div nD) * g (i mod nD))m 
  ((A (i div nD))  (B (i mod nD)))) {..< nC*nD}" 
proof (cases "nD = 0")
  case True
  then show ?thesis using sum_mat_mod_div_eq_0 assms by simp
next
  case False
  then show ?thesis using sum_mat_mod_div_ne_0 assms by simp
qed

(* Variant of sum_mat_mod_div_ne_0 in which every summand is additionally
   right-multiplied by a fixed matrix R in carrier_mat (n*m) (n*m).
   The proof is the same induction on nC. Membership of the products in
   fc_mats also uses cpx_sq_mat_mult. *)
lemma sum_sum_mat_expand_ne_0:
  assumes "k. k < (nC::nat)  A k  carrier_mat n n"
  and "j. j < (nD::nat)  B j  carrier_mat m m"
  and "R carrier_mat (n*m) (n*m)"
  and "0 < n"
  and "0 < m"
  and "nD  0"
  and "dimR = n *m"
shows "sum_mat (λi. sum_mat (λj. f i * g jm ((A i)  (B j))*R) {..< nD}) 
  {..< nC} = 
  sum_mat (λi. (f (i div nD) * g (i mod nD))m 
  ((A (i div nD))  (B (i mod nD))) * R) {..< nC*nD}" 
proof -
  define D where "D = (λi. sum_mat (λj. f i*g jm ((A i)  (B j)) * R) 
    {..< nD})"
  have fc: "fc_mats = carrier_mat (n*m) (n*m)" 
    using assms fc_mats_carrier dim_eq 
    by simp
  show ?thesis using  assms
  proof (induct nC)
    (* Base case: both sides are sums over empty index sets. *)
    case 0
    define C where "C = sum_mat D {..< (0::nat)}"
    have "C =  0m (n*m) (n*m)" unfolding C_def 
      using sum_mat_empty assms dim_eq
      by (simp add: fixed_carrier_mat_def)
    moreover have "sum_mat (λi. (f (i div nD) * g (i mod nD))m 
      ((A (i div nD))  (B (i mod nD))) * R) {..< 0*nD} = 0m (n*m) (n*m)" 
      using sum_mat_empty assms dim_eq
      by (simp add: fixed_carrier_mat_def)
    ultimately show ?case unfolding C_def by simp
  next
    (* Inductive step: split off the outer index nC, reindex its inner sum
       onto the flat block {nC*nD ..< Suc nC*nD}, and merge with C. *)
    case (Suc nC)
    define C where "C = sum_mat (λi. (f (i div nD) * g (i mod nD))m 
      ((A (i div nD))  (B (i mod nD)))* R) {..< nC*nD}"
    have "R fc_mats" using fc_mats_carrier Suc dim_eq by simp
    (* Every inner sum D i with i < Suc nC lies in fc_mats. *)
    have dm: "i. i  {..<Suc nC}  D i  fc_mats"
    proof -
      fix i
      assume "i  {..<Suc nC}"
      hence "A i  carrier_mat n n" using Suc by simp
      hence "j. j {..< nD}  B j  carrier_mat m m" using Suc
        by simp
      hence "j. j {..< nD}  A i  B j  fc_mats" 
        using fc A i  carrier_mat n n tensor_mat_carrier
        by (metis carrier_matD(1) carrier_matD(2))
      hence "j. j  {..< nD}  (A i  B j) * R  fc_mats" using Suc fc
        using cpx_sq_mat_mult by blast
      thus "D i  fc_mats" unfolding D_def
        by (metis (mono_tags, lifting) R  fc_mats 
            j. j  {..<nD}  A i  B j  fc_mats cpx_sq_mat_mult 
            cpx_sq_mat_smult fc_mats_carrier sum_mat_carrier)
    qed
    have "sum_mat D {..< Suc nC} = sum_mat D ({..< nC}  {nC..< Suc nC})" 
    proof -
      have "{..< Suc nC} = {..< nC}  {nC..< Suc nC}" by auto
      thus ?thesis by simp
    qed
    also have "... = sum_mat D {..< nC} + sum_mat D {nC..< Suc nC}"
    proof (rule sum_mat_disj_union)
      show "i{..<nC}  {nC..<Suc nC}. D i  fc_mats" using dm by auto
    qed auto
    also have "... = C + sum_mat D {nC..< Suc nC}" 
      using Suc unfolding C_def D_def by simp
    (* Reindex the last inner sum D nC onto the flat block
       {nC*nD ..< Suc nC*nD}, on which i div nD = nC. *)
    also have "... = C + (sum_mat (λi. (f (i div nD) * g (i mod nD))m 
      ((A (i div nD))  (B (i mod nD))) * R) {nC*nD..< Suc nC*nD})"
    proof -
      have "sum_mat D {nC..< Suc nC} = sum_mat D {nC}" by simp
      also have "... = D nC"  using  dm
        by (simp add: sum_mat_singleton')
      also have "... = (sum_mat (λi. (f nC * g (i mod nD))m 
        ((A nC)  (B (i mod nD))) * R) ((+) (nC * nD) ` {..<nD}))" 
        unfolding D_def 
      proof (rule sum_mat_mod_eq[symmetric])
        show "x. x  {..<nD}  f nC * g x m (A nC  B x)*R  fc_mats" 
        proof -
          fix x
          assume "x {..< nD}"
          hence "B x  carrier_mat m m" using Suc by simp
          have "A nC  carrier_mat n n" using Suc by simp
          hence  "A nC  B x  fc_mats" 
            using fc tensor_mat_carrier  B x  carrier_mat m m by blast
          thus "f nC * g x m (A nC  B x) * R  fc_mats"
            by (simp add: R  fc_mats cpx_sq_mat_mult cpx_sq_mat_smult)
        qed
      qed
      also have "... = sum_mat (λi. (f (i div nD) * g (i mod nD))m 
        ((A (i div nD))  (B (i mod nD))) * R) {nC*nD..< Suc nC*nD}" 
      proof (rule sum_mat_cong')
        show "(+) (nC * nD) ` {..<nD} = {nC * nD..<Suc nC * nD}"
          by (simp add: lessThan_atLeast0) 
        show "i. i  (+) (nC * nD) ` {..<nD}  
          f nC * g (i mod nD) m (A nC  B (i mod nD))*R  fc_mats"
        proof -
          fix i
          assume "i  (+) (nC * nD) ` {..<nD}"
          hence "i mod nD < nD" using Suc mod_less_divisor by blast 
          hence "B (i mod nD)  carrier_mat m m" using Suc by simp
          moreover have "A nC  carrier_mat n n" using Suc by simp
          ultimately have "A nC  B (i mod nD)  fc_mats" 
            using fc tensor_mat_carrier by blast
          thus "f nC * g (i mod nD) m (A nC  B (i mod nD)) * R  fc_mats"
             by (simp add: R  fc_mats cpx_sq_mat_mult cpx_sq_mat_smult)
        qed
        show "i. i  (+) (nC * nD) ` {..<nD}  
          f (i div nD) * g (i mod nD) m (A (i div nD)  B (i mod nD)) * R  
          fc_mats"
        proof -
          fix i
          assume "i  (+) (nC * nD) ` {..<nD}"
          hence "i div nD = nC" using Suc(2) mod_less_divisor
            by (metis (+) (nC * nD) ` {..<nD} = {nC * nD..<Suc nC * nD} 
                index_div_eq semiring_norm(174))
          have "i mod nD < nD" using i  (+) (nC * nD) ` {..<nD}  Suc
              mod_less_divisor by blast
          hence "B (i mod nD)  carrier_mat m m" using Suc by simp
          moreover have "A (i div nD)  carrier_mat n n" 
            using i div nD = nC Suc by simp
          ultimately have "A (i div nD)  B (i mod nD)  fc_mats" 
            using fc tensor_mat_carrier by blast
          thus "f (i div nD) * g (i mod nD) m (A (i div nD)  B (i mod nD))*R 
             fc_mats"
             by (simp add: R  fc_mats cpx_sq_mat_mult cpx_sq_mat_smult)
        qed
      qed auto
      finally have "sum_mat D {nC..< Suc nC} = 
        sum_mat (λi. (f (i div nD) * g (i mod nD))m 
        ((A (i div nD))  (B (i mod nD)))* R) {nC*nD..< Suc nC*nD}" .
      thus ?thesis by simp
    qed
    (* Merge the two disjoint flat ranges back into {..< Suc nC * nD}. *)
    also have "... =  
      sum_mat (λi. f (i div nD)*g (i mod nD)m(A (i div nD)  B (i mod nD))*R)
      ({..< nC * nD}  {nC * nD..<Suc nC * nD})" unfolding C_def 
    proof (rule sum_mat_disj_union[symmetric])
      show "i{..<nC * nD}  {nC * nD..<Suc nC * nD}.
       f (i div nD) *g (i mod nD) m (A (i div nD)  B (i mod nD))*R  fc_mats" 
      proof
        fix i
        assume "i  {..<nC * nD}  {nC * nD..<Suc nC * nD}"
        hence "i  {..< Suc nC * nD}" by auto
        hence "i div nD < Suc nC" using Suc(2) mod_less_divisor
          by (simp add: less_mult_imp_div_less)
        have "i mod nD < nD" using i  {..<nC * nD}  {nC * nD..<Suc nC * nD}
          Suc mod_less_divisor by blast
        hence "B (i mod nD)  carrier_mat m m" using Suc by simp
        moreover have "A (i div nD)  carrier_mat n n" 
          using i div nD < Suc nC Suc by simp
        ultimately have "A (i div nD)  B (i mod nD)  fc_mats" 
          using fc tensor_mat_carrier by blast
        thus "f (i div nD) * g (i mod nD) m (A (i div nD)  B (i mod nD))*R 
           fc_mats"
           by (simp add: R  fc_mats cpx_sq_mat_mult cpx_sq_mat_smult)
      qed
    qed auto
    also have "... = 
      sum_mat (λi. f (i div nD)*g (i mod nD)m(A (i div nD)  B (i mod nD))*R)
      {..< Suc nC * nD}"
    proof -
      have "{..< nC * nD}  {nC * nD..<Suc nC * nD} = {..< Suc nC *  nD}" 
        by auto
      thus ?thesis by simp
    qed
    finally show ?case unfolding D_def .
  qed
qed

(* Degenerate case nD = 0 of the right-multiplied flattening: all inner sums
   and the flat index set are empty, so both sides are the zero matrix. *)
lemma sum_sum_mat_expand_eq_0:
  assumes "k. k < (nC::nat)  A k  carrier_mat n n"
  and "R carrier_mat (n*m) (n*m)"
  and "0 < n"
  and "0 < m"
  and "nD = 0"
  and "dimR = n *m"
shows "sum_mat (λi. sum_mat (λj. f i * g jm ((A i)  (B j))*R) {..< nD}) 
  {..< nC} = 
  sum_mat (λi. (f (i div nD) * g (i mod nD))m 
  ((A (i div nD))  (B (i mod nD))) * R) {..< nC*nD}" 
proof -
  have "{..< nC*nD} = {}" using assms by simp
  hence "sum_mat (λi. f (i div nD) * g (i mod nD) m 
    (A (i div nD)  B (i mod nD))*R) {..<nC * nD} = 0m (n*m) (n*m)"
    using sum_mat_empty assms dim_eq
    by (simp add: fixed_carrier_mat_def)
  moreover have "sum_mat (λi. sum_mat (λj. f i * g j m (A i  B j)*R) 
    {..<nD}) 
    {..<nC} = 0m dimR dimC" 
  proof (rule sum_mat_zero')
    fix i
    assume "i  {..< nC}"
    show "sum_mat (λj. f i * g j m (A i  B j) * R) {..<nD} = 
      0m dimR dimC" 
      using assms sum_mat_empty by simp
  qed simp
  ultimately show ?thesis using assms dim_eq by simp
qed

(* Flattening of the right-multiplied double sum for arbitrary nD, by case
   analysis on whether nD = 0. *)
lemma sum_sum_mat_expand:
  assumes "k. k < (nC::nat)  A k  carrier_mat n n"
  and "j. j < (nD::nat)  B j  carrier_mat m m"
  and "R carrier_mat (n*m) (n*m)"
  and "0 < n"
  and "0 < m"
  and "dimR = n *m"
shows "sum_mat (λi. sum_mat (λj. f i * g jm ((A i)  (B j))*R) {..< nD}) 
  {..< nC} = 
  sum_mat (λi. (f (i div nD) * g (i mod nD))m 
  ((A (i div nD))  (B (i mod nD))) * R) {..< nC*nD}"
proof (cases "nD = 0")
  case True
  then show ?thesis using assms sum_sum_mat_expand_eq_0 by simp
next
  case False
  then show ?thesis using assms sum_sum_mat_expand_ne_0 by simp
qed

end

section ‹Results on tensor products›

(* The trace of the tensor product of an n by n matrix A and an m by m
   matrix B is the product of their traces.
   The k-th diagonal entry of the tensor product is
   A(k div m, k div m) * B(k mod m, k mod m). The sum over k < n*m is split
   into a double sum with sum_nat_div_mod and then factored with
   sum_product. *)
lemma tensor_mat_trace:
  assumes "A carrier_mat n n"
  and "B carrier_mat m m"
  and "0 < n"
  and "0 < m"
shows "Complex_Matrix.trace (A  B) = Complex_Matrix.trace A *
  Complex_Matrix.trace B"
proof -
  have "{0 ..< n*m} = {..< n*m}" by auto
  have n: "{0 ..< n} = {..< n}" by auto
  have m: "{0 ..< m} = {..< m}" by auto
  have "Complex_Matrix.trace (A  B) = ( i  {0 ..< n*m}. (A  B) $$ (i,i))"
    unfolding Complex_Matrix.trace_def using tensor_mat_carrier assms by simp
  also have "... = ( i  {..< n*m}. 
    A $$ (i div m, i div m) * B $$ (i mod m, i mod m))"
    using index_tensor_mat' assms {0 ..< n*m} = {..< n*m} by simp
  also have "... =sum (λi. sum (λj. A $$ (i, i) * B $$ (j,j)) {..< m}) {..< n}"
    by (rule sum_nat_div_mod[symmetric])
  also have "... = sum (λi. A $$(i,i)) {..< n}*(sum (λj. B $$ (j,j)) {..< m})" 
    by (rule sum_product[symmetric])
  also have "... = Complex_Matrix.trace A * (Complex_Matrix.trace B)"
    using n m assms unfolding Complex_Matrix.trace_def by simp
  finally show ?thesis .
qed

(* The inner product of two tensor products of vectors factors:
   inner_prod (tensor_vec u v) (tensor_vec a b) equals
   inner_prod u a times inner_prod v b.
   All four vectors are assumed to have the same dimension n > 0.
   Proof outline:
   - after unfolding scalar_prod_def, the inner product of x and y is the
     sum over i of y_i times the conjugate of x_i;
   - the entries of the tensor vectors are expanded at i div n and
     i mod n, and the factors are regrouped;
   - the flat index i < n*n is split into a double sum with
     sum_nat_div_mod, which is then factored with sum_product. *)
lemma tensor_vec_inner_prod:
  assumes "u  carrier_vec n"
  and "v  carrier_vec n"
  and "a  carrier_vec n"
  and "b  carrier_vec n"
  and "0 < n"
shows "Complex_Matrix.inner_prod (tensor_vec u v) (tensor_vec a b) =
  Complex_Matrix.inner_prod u a * Complex_Matrix.inner_prod v b"
proof -
  have "{0 ..< n * n} = {..< n*n}" by auto
  have "{0 ..< n} = {..< n}" by auto
  have "Complex_Matrix.inner_prod (tensor_vec u v) (tensor_vec a b) = 
    ( i  {0 ..< n * n}. (vec_index (tensor_vec a b) i) * 
    vec_index (conjugate (tensor_vec u v)) i)" 
    unfolding scalar_prod_def using assms by simp
  also have "... = ( i  {0 ..< n * n}. vec_index a (i div n) * 
    vec_index b (i mod n) * (vec_index (conjugate (tensor_vec u v)) i))" 
  proof -
    have " i < n * n. vec_index (tensor_vec a b) i = vec_index a (i div n) * 
    vec_index b (i mod n)" using assms by simp
    thus ?thesis by auto
  qed
  also have "... = ( i  {0 ..< n * n}. vec_index a (i div n) * 
    vec_index b (i mod n) * (conjugate (vec_index (tensor_vec u v) i)))" 
    using assms by simp
  also have "... = ( i  {0 ..< n * n}. vec_index a (i div n) * 
    vec_index b (i mod n) * (conjugate (vec_index u (i div n) * 
    vec_index v (i mod n))))" 
  proof -
    have " i < n * n. vec_index (tensor_vec u v) i = vec_index u (i div n) * 
    vec_index v (i mod n)" 
      using assms by simp
    thus ?thesis by auto
  qed
  also have "... = ( i  {0 ..< n * n}. vec_index a (i div n) * 
    vec_index b (i mod n) * (conjugate (vec_index u (i div n)) * 
    (conjugate (vec_index v (i mod n)))))" 
    by simp
  also have "... = ( i  {0 ..< n * n}. vec_index a (i div n) * 
    (conjugate (vec_index u (i div n)) * (vec_index b (i mod n) *  
    (conjugate (vec_index v (i mod n))))))"
    by (simp add: ab_semigroup_mult_class.mult_ac(1) 
        vector_space_over_itself.scale_left_commute)
  also have "... = ( i  {..< n * n}. (vec_index a (i div n) * 
    (conjugate (vec_index u (i div n))) * (vec_index b (i mod n) *  
    (conjugate (vec_index v (i mod n))))))"
    using {0 ..< n * n} = {..< n*n}
    by (metis (no_types, lifting) sum.cong vector_space_over_itself.scale_scale)
  also have "... =sum (λi. sum (λj. vec_index a i * conjugate (vec_index u i) *
    (vec_index b j * (conjugate (vec_index v j)))) {..< n}) {..< n}"
    by (rule sum_nat_div_mod[symmetric])
  also have "... = sum (λi. vec_index a i * conjugate (vec_index u i)) {..< n}*
    (sum (λj. vec_index b j * (conjugate (vec_index v j))) {..< n})" 
    by (rule sum_product[symmetric])
  also have "... = Complex_Matrix.inner_prod u a * Complex_Matrix.inner_prod v b" 
  proof -
    have "dim_vec (conjugate u) = n" using assms by simp
    moreover have "dim_vec (conjugate v) = n" using assms by simp
    ultimately show ?thesis using {0 ..< n} = {..< n} 
      unfolding Matrix.scalar_prod_def by simp
  qed
  finally show ?thesis .
qed

(* The tensor product of two positive square matrices is positive.
   Proof outline:
   - write A = P * adjoint P and B = Q * adjoint Q (positive_only_if_decomp);
   - set M = P tensor Q; by tensor_mat_adjoint and mult_distr_tensor,
     M * adjoint M is A tensor B;
   - positive_if_decomp then concludes. *)
lemma tensor_mat_positive:
  assumes "A carrier_mat n n"
  and "B carrier_mat m m"
  and "0 < n"
  and "0 < m"
  and "Complex_Matrix.positive A"
  and "Complex_Matrix.positive B"
shows "Complex_Matrix.positive (A  B)"
proof (rule positive_if_decomp)
  show "A  B  carrier_mat (n*m) (n*m)" using assms by auto
  have "P  carrier_mat n n. P * Complex_Matrix.adjoint P = A" 
    using assms positive_only_if_decomp by simp
  from this obtain P where "P carrier_mat n n" 
    and "P * Complex_Matrix.adjoint P = A" by auto note ppr = this
  have "Q  carrier_mat m m. Q * Complex_Matrix.adjoint Q = B" 
    using assms positive_only_if_decomp by simp
  from this obtain Q where "Q carrier_mat m m" 
    and "Q * Complex_Matrix.adjoint Q = B" by auto note qpr = this
  define M where "M = P  Q"
  have "Complex_Matrix.adjoint M = 
    Complex_Matrix.adjoint P  (Complex_Matrix.adjoint Q)" unfolding M_def 
    using tensor_mat_adjoint ppr qpr assms 
    by blast
  hence "M * Complex_Matrix.adjoint M = 
    (P * Complex_Matrix.adjoint P)  (Q * Complex_Matrix.adjoint Q)"
    using mult_distr_tensor M_def ppr qpr assms by fastforce
  also have "... = A  B" using ppr qpr by simp
  finally have "M * Complex_Matrix.adjoint M = A  B" .
  thus "M. M * Complex_Matrix.adjoint M = A  B" by auto
qed


(* If A * A is the identity of size n and B * B is the identity of size m,
   then the tensor product of A and B squares to the identity of size n * m.
   The dimension facts about A and B needed by mult_distr_tensor are
   recovered from the two hypotheses. *)
lemma tensor_mat_square_idty:
  assumes "A * A = 1m n"
  and "B * B = 1m m"
  and "0 < n"
  and "0 < m"
shows "(A  B) * (A  B) = 1m (n*m)"
proof -
  have "(A  B) * (A  B) = A*A  (B*B)" 
  proof (rule mult_distr_tensor[symmetric])
    show a: "dim_col A = dim_row A"
      by (metis assms(1) index_mult_mat(2) index_mult_mat(3) index_one_mat(2) 
          index_one_mat(3))
    show b: "dim_col B = dim_row B"
      by (metis assms(2) index_mult_mat(2) index_mult_mat(3) index_one_mat(2) 
          index_one_mat(3))
    (* mult_distr_tensor is applied with both factor pairs equal (A with A,
       B with B). Its positivity side conditions therefore presumably occur
       once for each occurrence, which is why each is discharged twice. *)
    show "0 < dim_col A"
      by (metis a assms(1) assms(3) index_mult_mat(2) index_one_mat(2))
    thus "0 < dim_col A" .
    show "0 < dim_col B"
      by (metis b assms(2) assms(4) index_mult_mat(2) index_one_mat(2))
    thus "0 < dim_col B" .
  qed
  also have "... = 1m n  1m m" using assms by simp
  also have "... = 1m (n*m)" using tensor_mat_id assms by simp
  finally show ?thesis .
qed

(* If A commutes with C and B commutes with D, with A and C of size n and
   B and D of size m, then the tensor products of A with B and of C with D
   commute. Both products are rewritten factorwise with mult_distr_tensor. *)
lemma tensor_mat_commute:
  assumes "A  carrier_mat n n"
  and "B  carrier_mat m m"
  and "C  carrier_mat n n"
  and "D  carrier_mat m m"
  and "0 < n"
  and "0 < m"
  and "A * C = C * A"
  and "B * D = D * B"
shows "(A  B) * (C  D) = (C  D) * (A  B)"
proof -
  have "(A  B) * (C  D) = (A*C)  (B*D)" using mult_distr_tensor assms
    by (metis carrier_matD(1) carrier_matD(2))
  also have "... = (C*A)  (D*B)" using assms by simp
  also have "... = (C  D) * (A  B)" using mult_distr_tensor assms
    by (metis carrier_matD(1) carrier_matD(2))
  finally show ?thesis .
qed

(* Factorization of a tensor product through identities.
   A tensored with the identity of size m, multiplied by the identity of
   size n tensored with B, gives the tensor product of A and B.
   The proof uses mult_distr_tensor and the unit laws of matrix
   multiplication. *)
lemma tensor_mat_mult_id:
  assumes "A carrier_mat n n"
  and "B carrier_mat m m"
  and "0 < n"
  and "0 < m"
shows "(A  1m m) * (1m n  B) = A  B"
proof -
  have "(A  1m m) * (1m n  B) = (A * 1m n)  (1m m * B)" 
    using mult_distr_tensor
    by (metis assms carrier_matD(1) carrier_matD(2) 
        index_one_mat(2) index_one_mat(3))
  also have "... = A  B"
    by (metis assms(1) assms(2) left_mult_one_mat right_mult_one_mat)
  finally show ?thesis .
qed

(* The trace of a product of two tensor products factors.
   The trace of (A tensor B) times (C tensor D) is trace (A * C) times
   trace (B * D), for A and C of size n and B and D of size m.
   The proof combines mult_distr_tensor with tensor_mat_trace. *)
lemma tensor_mat_trace_mult_distr:
  assumes "A carrier_mat n n"
  and "B carrier_mat m m"
  and "C carrier_mat n n"
  and "D carrier_mat m m"
  and "0 < n"
  and "0 < m"
  shows "Complex_Matrix.trace ((A   B) * (CD)) =
    Complex_Matrix.trace (A * C) * (Complex_Matrix.trace (B * D))" 
proof -
  have "(A   B) * (CD) = (A*C)  (B*D)" using assms mult_distr_tensor by auto
  hence "Complex_Matrix.trace ((A   B) * (CD)) =
    Complex_Matrix.trace ((A*C)  (B*D))" by simp
  also have "... = Complex_Matrix.trace (A * C) * (Complex_Matrix.trace (B * D))"
    by (meson assms mult_carrier_mat tensor_mat_trace) 
  finally show ?thesis .
qed

(* The tensor product of two diagonal square matrices is diagonal.
   For an off-diagonal position (i, j), the entry is the product of an
   entry of A at the block indices (i div, j div the dimensions of B) and
   an entry of B at the in-block indices (i mod, j mod those dimensions).
   - If the block indices coincide, i <> j forces the in-block indices to
     differ, so the factor from B is zero.
   - Otherwise the factor from A is zero. *)
lemma tensor_mat_diagonal:
  assumes "A carrier_mat n n"
  and "B carrier_mat m m"
  and "diagonal_mat A"
  and "diagonal_mat B"
shows "diagonal_mat (A  B)" unfolding diagonal_mat_def
proof (intro allI impI)
  fix i j
  assume "i < dim_row (A  B)"
  and "j < dim_col (A  B)"
  and "i j"
  have "A  B  carrier_mat (n*m) (n*m)" 
    using assms tensor_mat_carrier by blast
  hence "i < n * m"
    by (metis i < dim_row (A  B) carrier_matD(1))
  have "j < n* m"
    using A  B  carrier_mat (n * m) (n * m) j < dim_col (A  B) by auto 
  have "(A  B) $$ (i, j) = A $$ (i div (dim_row B), j div (dim_col B)) * 
    B $$ (i mod (dim_row B), j mod (dim_col B))" using index_tensor_mat'
    by (metis i < dim_row (A  B) j < dim_col (A  B) dim_col_tensor_mat 
        dim_row_tensor_mat less_nat_zero_code neq0_conv semiring_norm(63) 
        semiring_norm(64))
  also have "... = 0" 
  proof (cases "i div (dim_row B) = j div (dim_col B)")
    (* Same block of A: the entries of B at distinct in-block positions
       vanish. *)
    case True
    have "i div (dim_row B) < n" using assms i < n * m
      by (metis carrier_matD(1) less_mult_imp_div_less)
    moreover have "j div (dim_row B) < n" using assms j < n * m
      by (metis carrier_matD(1) less_mult_imp_div_less)
    ultimately have "(i mod (dim_row B)  j mod (dim_col B))" using i  j
      by (metis True assms(2) carrier_matD(1) carrier_matD(2) mod_div_decomp)
    then show ?thesis using assms unfolding diagonal_mat_def
      by (metis i < n * m carrier_matD(1) carrier_matD(2) gr_zeroI 
          mod_less_divisor mult.commute semiring_norm(63) zero_order(3))
  next
    (* Different blocks: the off-diagonal entry of A vanishes. *)
    case False
    have "i div (dim_row B) < n" using assms i < n * m
      by (metis carrier_matD(1) less_mult_imp_div_less)
    moreover have "j div (dim_row B) < n" using assms j < n * m
      by (metis carrier_matD(1) less_mult_imp_div_less)
    ultimately show ?thesis using assms unfolding diagonal_mat_def
      by (metis False carrier_matD(1) carrier_matD(2) semiring_norm(63))
  qed
  finally show "(A  B) $$ (i, j) = 0" .
qed


(* The tensor product distributes over matrix addition on the right:
   A tensor (B + C) equals A tensor B plus A tensor C, for B and C of the
   same size i by j. The proof is entrywise, via eq_matI and
   index_tensor_mat'. *)
lemma tensor_mat_add_right:
  assumes "A carrier_mat n m"
  and "B carrier_mat i j"
  and "C carrier_mat i j"
  and "0 < m"
  and "0 < j"
shows "A  (B + C) = (A  B) + (A  C)"
proof (rule eq_matI)
  (* Both sides have dimensions n * i by m * j. *)
  have "B + C  carrier_mat i j" using assms by simp
  hence bc: "A  (B + C)  carrier_mat (n * i) (m * j)" 
    using assms tensor_mat_carrier
    by (metis carrier_matD(1) carrier_matD(2))
  have  "A  B  carrier_mat (n * i) (m * j)" 
    using assms tensor_mat_carrier
    by (metis carrier_matD(1) carrier_matD(2))
  moreover have  "A  C  carrier_mat (n * i) (m * j)" 
    using assms tensor_mat_carrier
    by (metis carrier_matD(1) carrier_matD(2))
  ultimately have a: "(A  B) + (A  C)  carrier_mat (n * i) (m * j)" 
    by simp
  thus dr: "dim_row (A  (B + C)) = dim_row ((A  B) + (A  C))" 
    using bc by simp
  show dc: "dim_col (A  B + C) = dim_col ((A  B) + (A  C))" 
    using a bc by simp
  (* Entrywise: expand the tensor entry, split the sum inside, and
     distribute the factor from A. *)
  fix k l
  assume "k < dim_row ((A  B) + (A  C))"
  and "l < dim_col ((A  B) + (A  C))"
  hence "(A  B + C) $$ (k, l) = 
    A $$ (k div dim_row (B + C), l div dim_col (B + C)) * 
    (B + C) $$ (k mod dim_row (B + C), l mod dim_col (B + C))" 
    using index_tensor_mat'
    by (metis B + C  carrier_mat i j dc dr assms(1) assms(4) assms(5) bc 
        carrier_matD(1) carrier_matD(2))
  also have "... = A $$ (k div dim_row (B + C), l div dim_col (B + C)) * 
    (B $$ (k mod dim_row (B + C), l mod dim_col (B + C)) +
    C $$ (k mod dim_row (B + C), l mod dim_col (B + C)))"
    by (metis div_eq_0_iff B + C  carrier_mat i j 
        k < dim_row ((A  B) + (A  C)) assms(3) assms(5) bc carrier_matD(1) 
        carrier_matD(2) dr index_add_mat(1) less_nat_zero_code mod_div_trivial 
        mult_not_zero)
  also have "... = A $$ (k div dim_row (B + C), l div dim_col (B + C)) * 
    B $$ (k mod dim_row (B + C), l mod dim_col (B + C)) + 
    A $$ (k div dim_row (B + C), l div dim_col (B + C)) * 
    C $$ (k mod dim_row (B + C), l mod dim_col (B + C))"
    using distrib_left by blast
  also have "... = (A  B) $$ (k,l) + (A  C) $$ (k,l)"
    using k < dim_row ((A  B) + (A  C)) l < dim_col ((A  B) + (A  C)) 
      assms by force
  also have "... = ((A  B) + (A  C)) $$ (k,l)"
    using k < dim_row ((A  B) + (A  C)) l < dim_col ((A  B) + (A  C)) 
    by force
  finally show "(A  B + C) $$ (k, l) = ((A  B) + (A  C)) $$ (k,l)" .
qed

(* Tensoring the n by m zero matrix with an i by j matrix B gives the zero
   matrix of size n * i by m * j. The proof compares dimensions and then
   entries, via eq_matI and index_tensor_mat. *)
lemma tensor_mat_zero:
  assumes "B  carrier_mat i j"
  and "0 < j"
  and "0 < m"
shows "0m n m  B = 0m (n * i) (m * j)"
proof (rule eq_matI)
  show "dim_row (0m n m  B) = dim_row (0m (n * i) (m * j))" 
    using assms by simp
  show "dim_col (0m n m  B) = dim_col (0m (n * i) (m * j))" 
    using assms by simp
  fix k l
  assume "k < dim_row (0m (n * i) (m * j))" 
    and "l < dim_col (0m (n * i) (m * j))"
  thus "(0m n m  B) $$ (k, l) = 0m (n * i) (m * j) $$ (k,l)" 
    using index_tensor_mat assms less_mult_imp_div_less by force
qed

(* Mirror image of tensor_mat_zero: tensoring an i by j matrix B with the
   n by m zero matrix gives the zero matrix of size i * n by j * m.
   The proof compares dimensions and then entries, via eq_matI and
   index_tensor_mat. *)
lemma tensor_mat_zero':
  assumes "B  carrier_mat i j"
  and "0 < j"
  and "0 < m"
shows "B  0m n m = 0m (i * n) (j*m)"
proof (rule eq_matI)
  show "dim_row (B  0m n m) = dim_row (0m (i * n) (j * m))" 
    using assms by simp
  show "dim_col (B  0m n m) = dim_col (0m (i * n) (j * m))" 
    using assms by simp
  fix k l
  assume "k < dim_row (0m (i * n) (j * m))" 
    and "l < dim_col (0m (i * n) (j * m))"
  thus "(B  0m n m ) $$ (k, l) = 0m (i * n) (j * m) $$ (k,l)" 
    using index_tensor_mat assms less_mult_imp_div_less
    by (metis (no_types, lifting) carrier_matD(1) carrier_matD(2) 
        index_zero_mat(1) index_zero_mat(2) index_zero_mat(3) 
        less_nat_zero_code linorder_neqE_nat mod_less_divisor mult_eq_0_iff)
qed

(* The tensor product distributes over finite sums on the right.
   A tensored with the sum over k in I of B k equals the sum over k in I of
   A tensored with B k. The left sum is taken in fixed_carrier_mat with
   dimensions i, j and the right one with dimensions n*i, m*j.
   The proof is by induction on the finite set I, using tensor_mat_zero'
   for the empty set and tensor_mat_add_right for the insert step.
   NOTE(review): this lemma is stated outside any visible locale context,
   so dimR and dimC look like ordinary free variables here, and the two
   assumptions mentioning them do not appear to be needed by the proof.
   Confirm before relying on or removing them. *)
lemma tensor_mat_sum_right:
  fixes A::"complex Matrix.mat"
  assumes "finite I"
  and "A carrier_mat n m"
  and "k. k I  ((B k)::complex Matrix.mat)  carrier_mat i j"
  and "0 < m"
  and "0 < j"
  and "dimR = n *i"
  and "dimC = m*j"
shows "A  (fixed_carrier_mat.sum_mat i j B I) = 
  fixed_carrier_mat.sum_mat (n*i) (m*j) (λi. A  (B i)) I" 
  using assms
proof (induct rule: finite_induct)
  case empty
  hence "A  (fixed_carrier_mat.sum_mat i j B {}) = 0m (n *i) (m*j)" 
    using tensor_mat_zero'
    by (simp add: fixed_carrier_mat.sum_mat_empty fixed_carrier_mat_def) 
  also have "... = fixed_carrier_mat.sum_mat (n*i) (m*j) (λi. A  (B i)) {}"
    by (metis fixed_carrier_mat.intro fixed_carrier_mat.sum_mat_empty)
  finally show ?case .
next
  case (insert x F)
  hence "A  (fixed_carrier_mat.sum_mat i j B (insert x F)) =
    A  (B x + (fixed_carrier_mat.sum_mat i j B F))"
  proof -
    have "fixed_carrier_mat.sum_mat i j B (insert x F) =
      B x + (fixed_carrier_mat.sum_mat i j B F)"
      using fixed_carrier_mat.sum_mat_insert
      by (metis fixed_carrier_mat.intro image_subsetI insertCI 
          insert(1) insert(2) insert(5))
    thus ?thesis by simp
  qed
  (* Distribute A over the two summands. *)
  also have "... = (A  (B x)) + (A  (fixed_carrier_mat.sum_mat i j B F))"
  proof (rule tensor_mat_add_right)
    show "0 < m" using assms by simp
    show "0 < j" using assms by simp
    show "A  carrier_mat n m" using insert by simp
    show "B x  carrier_mat i j" using insert by simp
    show "fixed_carrier_mat.sum_mat i j B F  carrier_mat i j" 
    proof (rule fixed_carrier_mat.sum_mat_carrier)
      show "k. k  F  B k  carrier_mat i j" using insert by simp
      show "fixed_carrier_mat (carrier_mat i j) i j" 
        by (simp add: fixed_carrier_mat.intro) 
    qed
  qed
  (* Apply the induction hypothesis to the sum over F. *)
  also have "... = (A  (B x)) + 
    fixed_carrier_mat.sum_mat (n*i) (m*j) (λi. A  (B i)) F" 
    using insert by simp
  (* Fold the inserted summand back into the sum over insert x F. *)
  also have "... = fixed_carrier_mat.sum_mat (n*i) (m*j) (λi. A  (B i)) 
    (insert x F)"   
  proof (rule fixed_carrier_mat.sum_mat_insert[symmetric])
    show "finite F" using insert by simp
    show "x F" using insert by simp
    show "A  B x  carrier_mat (n*i) (m*j)" 
      using  tensor_mat_carrier insert
      by (metis carrier_matD(1) carrier_matD(2) insertI1) 
    show "(λi. A  B i) ` F  carrier_mat (n*i) (m*j)" 
    proof -
      {
        fix k
        assume "k F"
        hence "A  (B k)  carrier_mat (n*i) (m*j)" 
          using  tensor_mat_carrier insert by blast
      }
      thus ?thesis by auto
    qed
    show "fixed_carrier_mat (carrier_mat (n * i) (m * j)) (n * i) (m * j)"
      by (simp add: fixed_carrier_mat.intro)
  qed
  finally show "A  (fixed_carrier_mat.sum_mat i j B (insert x F)) =
    fixed_carrier_mat.sum_mat (n*i) (m*j) (λi. A  (B i)) (insert x F)" .
qed

(* The tensor product is additive in its left factor:
   tensor of (A + B) with C = (tensor of A with C) + (tensor of B with C),
   for A, B of size n x m and C of size i x j.
   Proof: entrywise (eq_matI).  After checking the dimensions, each entry
   is expanded with index_tensor_mat' as a left-factor entry at
   (k div dim_row C, l div dim_col C) times a C entry at
   (k mod dim_row C, l mod dim_col C), and then distributed. *)
lemma tensor_mat_add_left:
  assumes "A carrier_mat n m"
  and "B carrier_mat n m"
  and "C carrier_mat i j"
  and "0 < m"
  and "0 < j"
shows "(A + B)  C = (A  C) + (B  C)"
proof (rule eq_matI)
  have "A + B  carrier_mat n m" using assms by simp
  hence bc: "(A+B)  C  carrier_mat (n * i) (m * j)" 
    using assms tensor_mat_carrier
    by (metis carrier_matD(1) carrier_matD(2))
  have  "A  C  carrier_mat (n * i) (m * j)" 
    using assms tensor_mat_carrier
    by (metis carrier_matD(1) carrier_matD(2))
  moreover have  "B  C  carrier_mat (n * i) (m * j)" 
    using assms tensor_mat_carrier
    by (metis carrier_matD(1) carrier_matD(2))
  ultimately have a: "(A  C) + (B  C)  carrier_mat (n * i) (m * j)" 
    by simp
  thus dr: "dim_row ((A+B)  C) = dim_row ((A  C) + (B  C))" 
    using bc by simp
  show dc: "dim_col ((A+B)  C) = dim_col ((A  C) + (B  C))" 
    using a bc by simp
  (* Entrywise equality. *)
  fix k l
  assume "k < dim_row ((A  C) + (B  C))"
  and "l < dim_col ((A  C) + (B  C))"
  hence "((A+B)  C) $$ (k, l) = 
    (A+B) $$ (k div dim_row C, l div dim_col C) * 
    C $$ (k mod dim_row C, l mod dim_col C)" 
    using index_tensor_mat'
    by (metis A + B  carrier_mat n m assms(3) assms(4) assms(5) bc 
        carrier_matD(1) carrier_matD(2) dc dr)
  also have "... = (A $$ (k div dim_row C, l div dim_col C) + 
    B $$ (k div dim_row C, l div dim_col C)) *
    C $$ (k mod dim_row C, l mod dim_col C)"
    using k < dim_row ((A  C) + (B  C)) l < dim_col ((A  C) + (B  C)) 
      less_mult_imp_div_less by force
  also have "... = A $$ (k div dim_row C, l div dim_col C) * 
    C $$ (k mod dim_row C, l mod dim_col C) + 
    B $$ (k div dim_row C, l div dim_col C) * 
    C $$ (k mod dim_row C, l mod dim_col C)"
    using distrib_right by blast
  also have "... = (A  C) $$ (k,l) + (B  C) $$ (k,l)"
    using k < dim_row ((A  C) + (B  C)) l < dim_col ((A  C) + (B  C)) 
      assms by fastforce
  also have "... = ((A  C) + (B  C)) $$ (k,l)"
    using k < dim_row ((A  C) + (B  C)) l < dim_col ((A  C) + (B  C)) 
    by force
  finally show "((A+B)  C) $$ (k, l) = ((A  C) + (B  C)) $$ (k,l)" .
qed

(* A scalar factor on the left operand can be pulled out of the tensor
   product: tensor of (x smult A) with B = x smult (tensor of A with B).
   Proof: entrywise, via index_tensor_mat' and associativity of the
   scalar product with the entry product. *)
lemma tensor_mat_smult_left:
  assumes "A carrier_mat n m"
  and "B carrier_mat i j"
  and "0 < m"
  and "0 < j"
shows "x  m A   B = x  m (A  B)"
proof (rule eq_matI)
  have "x  m A  carrier_mat n m" using assms by simp
  hence "x m A  B  carrier_mat (n * i) (m * j)" 
    using assms tensor_mat_carrier
    by (metis carrier_matD(1) carrier_matD(2))
  moreover have "A  B  carrier_mat (n * i) (m * j)" 
    using assms tensor_mat_carrier
    by (metis carrier_matD(1) carrier_matD(2))
  ultimately show 
    "dim_row (x m A  B) = dim_row (x m (A  B))" 
    "dim_col (x m A  B) = dim_col (x m (A  B))" by auto
  fix k l
  assume k: "k < dim_row (x m (A  B))"
  and l: "l < dim_col (x m (A  B))"
  hence "(x m A  B) $$ (k, l) = 
    (x m A) $$ (k div dim_row B, l div dim_col B) * 
    B $$ (k mod dim_row B, l mod dim_col B)" 
    using index_tensor_mat' assms by force
  also have "... = x * (A $$ (k div dim_row B, l div dim_col B)) * 
    B $$ (k mod dim_row B, l mod dim_col B)"
    using k l less_mult_imp_div_less by fastforce
  also have "... = x * (A $$ (k div dim_row B, l div dim_col B) * 
    B $$ (k mod dim_row B, l mod dim_col B))" by simp
  also have "... = x * (A  B) $$ (k,l)"
    using assms k l by force
  also have "... = (x m (A  B)) $$ (k,l)" using assms k l by auto
  finally show "(x m A  B) $$ (k, l) = (x m (A  B)) $$ (k,l)" .
qed

(* A scalar factor on the right operand can be pulled out of the tensor
   product: tensor of A with (x smult B) = x smult (tensor of A with B).
   Proof: entrywise, via index_tensor_mat' and index_smult_mat. *)
lemma tensor_mat_smult_right:
  assumes "A carrier_mat n m"
  and "B carrier_mat i j"
  and "0 < m"
  and "0 < j"
shows "A   (x  m B) = x  m (A  B)"
proof (rule eq_matI)
  have "x  m B  carrier_mat i j" using assms by simp
  hence "A  (x m B)  carrier_mat (n * i) (m * j)" 
    using assms tensor_mat_carrier
    by (metis carrier_matD(1) carrier_matD(2))
  moreover have "A  B  carrier_mat (n * i) (m * j)" 
    using assms tensor_mat_carrier
    by (metis carrier_matD(1) carrier_matD(2))
  ultimately show 
    "dim_row (A  x m B) = dim_row (x m (A  B))" 
    "dim_col (A  x m B) = dim_col (x m (A  B))" by auto
  fix k l
  assume k: "k < dim_row (x m (A  B))"
  and l: "l < dim_col (x m (A  B))"
  hence "(A  (x m B)) $$ (k, l) = 
    A $$ (k div dim_row (x m B), l div dim_col (x m B)) * 
    (x m B) $$ (k mod dim_row (x m B), l mod dim_col (x m B))" 
    using index_tensor_mat' assms by force
  also have "... = A $$ (k div dim_row (x m B), l div dim_col (x m B)) * 
    (x * B $$ (k mod dim_row (x m B), l mod dim_col (x m B)))"
    using k l
    by (metis (no_types, opaque_lifting) add_lessD1 dim_col_tensor_mat 
        dim_row_tensor_mat index_smult_mat(1) index_smult_mat(2) 
        index_smult_mat(3) mod_less_divisor nat_0_less_mult_iff 
        plus_nat.simps(1)) 
  also have "... = x *(A$$ (k div dim_row (x m B), l div dim_col (x m B))* 
    B $$ (k mod dim_row (x m B), l mod dim_col (x m B)))" by simp
  also have "... = x * (A  B) $$ (k,l)"
    using assms k l by force
  also have "... = (x m (A  B)) $$ (k,l)" using assms k l by auto
  finally show "(A  (x m B)) $$ (k, l) = (x m (A  B)) $$ (k,l)" .
qed

(* Scalars on both operands combine into one product scalar:
   tensor of (x smult A) with (y smult B) = (x * y) smult (tensor of A
   with B).  Follows from tensor_mat_smult_left, tensor_mat_smult_right
   and smult_smult_times. *)
lemma tensor_mat_smult:
  assumes "A carrier_mat n m"
  and "B carrier_mat i j"
  and "0 < m"
  and "0 < j"
shows "x m A   (y  m B) = x * y  m (A  B)"
  by (metis (no_types, opaque_lifting) assms smult_carrier_mat 
      smult_smult_times tensor_mat_smult_left tensor_mat_smult_right)

(* Tensoring on the right with a 1 x 1 matrix B is scalar multiplication
   by its only entry B $$ (0,0).  With dim_row B = dim_col B = 1, the
   index formula reduces to A $$ (i,j) * B $$ (0,0). *)
lemma tensor_mat_singleton_right:
  assumes "0 < dim_col A"
  and "B  carrier_mat 1 1"
shows "A  B = B $$(0,0)  m A"
proof (rule eq_matI)
  show "dim_row (A  B) = dim_row (B $$ (0, 0) m A)" using assms by auto
  show "dim_col (A  B) = dim_col (B $$ (0, 0) m A)" using assms by auto
  fix i j
  assume "i < dim_row (B $$ (0, 0) m A)"
  and "j < dim_col (B $$ (0, 0) m A)"
  have "(A  B) $$ (i, j) = A $$ (i div dim_row B,j div dim_col B) * 
    B $$(i mod dim_row B, j mod dim_col B)" using index_tensor_mat
    i < dim_row (B $$ (0, 0) m A) j < dim_col (B $$ (0, 0) m A) assms 
    by fastforce
  (* div 1 is the identity and mod 1 is 0. *)
  also have "... = A $$(i,j) * B$$(0,0)" using assms by auto
  also have "... = (B $$ (0, 0) m A) $$ (i, j)"
    using i < dim_row (B $$ (0, 0) m A) j < dim_col (B $$ (0, 0) m A) 
    by force
  finally show "(A  B) $$ (i, j) = (B $$ (0, 0) m A) $$ (i, j)" .
qed

(* Tensoring on the left with a 1 x 1 matrix B is scalar multiplication
   by its only entry B $$ (0,0).
   NOTE(review): the first calculational step states the entry of the
   product with A as the outer factor, which is the index pattern of the
   product with the operands swapped.  It is still a valid equation here
   because B is 1 x 1 and both expressions reduce to
   A $$ (i,j) * B $$ (0,0).  Consider restating it with B as the outer
   factor for readability. *)
lemma tensor_mat_singleton_left:
  assumes "0 < dim_col A"
  and "B  carrier_mat 1 1"
shows "B  A = B $$(0,0)  m A"
proof (rule eq_matI)
  show "dim_row (B  A) = dim_row (B $$ (0, 0) m A)" using assms by auto
  show "dim_col (B  A) = dim_col (B $$ (0, 0) m A)" using assms by auto
  fix i j
  assume "i < dim_row (B $$ (0, 0) m A)"
  and "j < dim_col (B $$ (0, 0) m A)"
  have "(B  A) $$ (i, j) = A $$ (i div dim_row B,j div dim_col B) * 
    B $$(i mod dim_row B, j mod dim_col B)" using index_tensor_mat
    i < dim_row (B $$ (0, 0) m A) j < dim_col (B $$ (0, 0) m A) assms 
    by fastforce
  also have "... = A $$(i,j) * B$$(0,0)" using assms by auto
  also have "... = (B $$ (0, 0) m A) $$ (i, j)"
    using i < dim_row (B $$ (0, 0) m A) j < dim_col (B $$ (0, 0) m A) 
    by force
  finally show "(B  A) $$ (i, j) = (B $$ (0, 0) m A) $$ (i, j)" .
qed

(* Tensoring with a fixed i x j matrix B on the right distributes over a
   finite sum (fixed_carrier_mat.sum_mat n m) of n x m matrices A k,
   k in I.  Left-hand counterpart of tensor_mat_sum_right.
   Proof: finite induction on I.  The empty case uses tensor_mat_zero;
   the insert case uses tensor_mat_add_left.
   NOTE(review): as in tensor_mat_sum_right, the hypotheses
   "dimR = n *i" and "dimC = m*j" do not constrain the conclusion and
   look vestigial.  Confirm before removing. *)
lemma tensor_mat_sum_left:
  assumes "finite I"
  and "B carrier_mat i j"
  and "k. k I  A k  carrier_mat n m"
  and "0 < m"
  and "0 < j"
  and "dimR = n *i"
  and "dimC = m*j"
shows "(fixed_carrier_mat.sum_mat n m A I)  B = 
  fixed_carrier_mat.sum_mat (n*i) (m*j) (λi. (A i)  B) I" 
  using assms
proof (induct rule: finite_induct)
  (* Empty sum: both sides are zero matrices. *)
  case empty
  hence "(fixed_carrier_mat.sum_mat n m A {})  B = 0m (n *i) (m*j)" 
    using tensor_mat_zero
    by (simp add: fixed_carrier_mat.sum_mat_empty fixed_carrier_mat_def) 
  also have "... = fixed_carrier_mat.sum_mat (n*i) (m*j) (λi. (A i)  B) {}"
    by (metis fixed_carrier_mat.intro fixed_carrier_mat.sum_mat_empty)
  finally show ?case .
next
  (* Insert step: peel off A x, distribute, apply the IH, then fold back. *)
  case (insert x F)
  hence "(fixed_carrier_mat.sum_mat n m A (insert x F))  B =
    (A x + (fixed_carrier_mat.sum_mat n m A F))   B"
  proof -
    have "fixed_carrier_mat.sum_mat n m A (insert x F) =
      A x + (fixed_carrier_mat.sum_mat n m A F)"
      using fixed_carrier_mat.sum_mat_insert
      by (metis fixed_carrier_mat.intro image_subsetI insertCI 
          insert(1) insert(2) insert(5))
    thus ?thesis by simp
  qed
  also have "... = (A x  B) + (fixed_carrier_mat.sum_mat n m A F  B)"
  proof (rule tensor_mat_add_left)
    show "0 < m" using assms by simp
    show "0 < j" using assms by simp
    show "A x  carrier_mat n m" using insert by simp
    show "B  carrier_mat i j" using insert by simp
    show "fixed_carrier_mat.sum_mat n m A F  carrier_mat n m" 
    proof (rule fixed_carrier_mat.sum_mat_carrier)
      show "k. k  F  A k  carrier_mat n m" using insert by simp
      show "fixed_carrier_mat (carrier_mat n m) n m" 
        by (simp add: fixed_carrier_mat.intro) 
    qed
  qed
  also have "... = (A x  B) + 
    fixed_carrier_mat.sum_mat (n*i) (m*j) (λi. A i  B) F" 
    using insert by simp
  also have "... = fixed_carrier_mat.sum_mat (n*i) (m*j) (λi. A i  B) 
    (insert x F)"  
  proof (rule fixed_carrier_mat.sum_mat_insert[symmetric])
    show "finite F" using insert by simp
    show "x F" using insert by simp
    show "A x  B  carrier_mat (n*i) (m*j)" 
      using  tensor_mat_carrier insert by blast
    show "(λi. A i  B) ` F  carrier_mat (n*i) (m*j)" 
    proof -
      {
        fix k
        assume "k F"
        hence "A k  B  carrier_mat (n*i) (m*j)" 
          using tensor_mat_carrier insert by blast
      }
      thus ?thesis by auto
    qed
    show "fixed_carrier_mat (carrier_mat (n * i) (m * j)) (n * i) (m * j)"
      by (simp add: fixed_carrier_mat.intro)
  qed
  finally show "fixed_carrier_mat.sum_mat n m A (insert x F)  B =
    fixed_carrier_mat.sum_mat (n*i) (m*j) (λi. A i  B) (insert x F)" .
qed

(* Diagonal entries of the tensor product of a square n x n matrix A and
   a square m x m matrix B: entry (i,i) is
   A $$ (i div m, i div m) * B $$ (i mod m, i mod m). *)
lemma tensor_mat_diag_elem:
  assumes "A carrier_mat n n"
  and "B carrier_mat m m"
  and "i < n * m"
  and "0 < n*m"
shows "(A  B) $$ (i, i) = A $$ (i div m, i div m) * 
    B $$ (i mod m, i mod m)"
proof -
  have "i < dim_row (A  B)" using assms by auto
  have "(A  B) $$ (i, i) = A $$ (i div (dim_row B), i div (dim_col B)) * 
    B $$ (i mod (dim_row B), i mod (dim_col B))" using index_tensor_mat'
    by (metis i < dim_row (A  B) assms carrier_matD(2) dim_row_tensor_mat 
        nat_0_less_mult_iff)
  (* B is m x m, so dim_row B = dim_col B = m. *)
  also have "... = A $$ (i div m, i div m) * B $$ (i mod m, i mod m)"
    using assms by auto
  finally show ?thesis .
qed

(* Results stated inside the cpx_sq_mat locale.  Here sum_mat is the
   locale's matrix sum over fc_mats.  dimR presumably is the locale's
   dimension parameter, which the hypotheses tie to the product
   dimensions; dim_eq is used to relate the locale's dimensions -- TODO
   confirm against the locale definition. *)
context cpx_sq_mat
begin

(* Square specialisation of tensor_mat_sum_right, with the right-hand sum
   written as the locale's sum_mat. *)
lemma tensor_mat_sum_mat_right:
  assumes "finite I"
  and "A carrier_mat n n"
  and "k. k I  B k  carrier_mat i i"
  and "0 < n"
  and "0 < i"
  and "dimR = n *i"
shows "A  (fixed_carrier_mat.sum_mat i i B I) = sum_mat (λi. A  (B i)) I"
  using assms dim_eq tensor_mat_sum_right by blast 

(* Square specialisation of tensor_mat_sum_left, with the right-hand sum
   written as the locale's sum_mat. *)
lemma tensor_mat_sum_mat_left:
  assumes "finite I"
  and "B carrier_mat i i"
  and "k. k I  A k  carrier_mat n n"
  and "0 < n"
  and "0 < i"
  and "dimR = n *i"
shows "(fixed_carrier_mat.sum_mat n n A I)  B = sum_mat (λi. (A i)  B) I" 
  using assms dim_eq tensor_mat_sum_left by blast 

(* Bilinear expansion of a tensor product of two finite sums, case nD
   nonzero.  Let C = sum over i < nC of (f i smult A i) and
   D = sum over j < nD of (g j smult B j).  Then the tensor product of C
   and D equals the sum, over the flattened range {..< nC*nD}, of
   (f (i div nD) * g (i mod nD)) smult (tensor of A (i div nD) with
   B (i mod nD)).  Proof: induction on nC (generalising C). *)
lemma tensor_mat_sum_nat_mod_div_ne_0:
  assumes "k. k < (nC::nat)  A k  carrier_mat n n"
  and "j. j < (nD::nat)  B j  carrier_mat m m"
  and "fixed_carrier_mat.sum_mat n n (λi. f i m (A i)) {..< nC} = C"
  and "fixed_carrier_mat.sum_mat m m (λj. g j m (B j)) {..< nD} = D"
  and "0 < n"
  and "0 < m"
  and "nD 0"
  and "dimR = n *m"
shows "sum_mat (λi. (f (i div nD) * g (i mod nD))m 
  ((A (i div nD))  (B (i mod nD)))) 
  {..< nC*nD} = C D" using assms
proof (induct nC arbitrary: C)
  (* Base case: C is an empty sum, hence zero, and so is the tensor
     product; the left-hand index range {..< 0*nD} is empty. *)
  case 0
  hence "C = fixed_carrier_mat.sum_mat n n (λi. f i m (A i)) {}" by simp
  also have "... =  0m n n" 
    using fixed_carrier_mat.sum_mat_empty[of _ n n "λi. f i m (A i)"]
    by (simp add: fixed_carrier_mat_def)
  finally have "C = 0m n n" .
  moreover have "D  carrier_mat m m" using 0
    fixed_carrier_mat.sum_mat_carrier[of _ m m "{..< nD}" "λj. g j m (B j)"]
    by (simp add: fixed_carrier_mat_def)
  ultimately have "C D = 0m (n*m) (n*m)" using tensor_mat_zero
    by (simp add: "0"(5) "0"(6)) 
  have "sum_mat (λi. (f (i div nD) * g (i mod nD))m 
    ((A (i div nD))  (B (i mod nD)))) 
    {..< 0*nD} = sum_mat (λi. (f (i div nD) * g (i mod nD))m 
    ((A (i div nD))  (B (i mod nD)))) {}" by simp
  also have "... = 0m (n*m) (n*m)" using sum_mat_empty
    using "0" dim_eq by blast
  also have "... = C D" using C D = 0m (n*m) (n*m) by simp
  finally show ?case .
next
  (* Inductive step.  Cp is the partial sum over {..< nC}.  The range
     {..< Suc nC * nD} is split into {..< nC*nD} (handled by the IH) and
     the last block {nC*nD ..< Suc nC * nD}, on which i div nD = nC. *)
  case (Suc nC)
  define Cp where 
    "Cp = fixed_carrier_mat.sum_mat n n (λi. f i m (A i)) {..< nC}"
  (* Every summand's tensor factor lies in fc_mats. *)
  have fc: "i{..<nC * nD}  {nC * nD..<Suc nC * nD}.
    (A (i div nD)  B (i mod nD))  fc_mats" 
  proof
    fix i
    assume "i  {..<nC * nD}  {nC * nD..<Suc nC * nD}"
    hence i: "i  {..< Suc nC * nD}" by auto
    hence "i div nD < Suc nC"
      by (simp add: less_mult_imp_div_less) 
    hence "A (i div nD)  carrier_mat n n" using Suc by simp
    have "i mod nD < nD" using Suc by simp 
    hence "B (i mod nD)  carrier_mat m m" using Suc by simp
    hence "A (i div nD)  B (i mod nD)  carrier_mat (n*m) (n*m)" 
      using tensor_mat_carrier
      by (metis A (i div nD)  carrier_mat n n 
          carrier_matD(1) carrier_matD(2))
    thus "(A (i div nD)  B (i mod nD))  fc_mats"
      using Suc dim_eq fc_mats_carrier by blast 
  qed
  (* Split the sum over the disjoint union of the two index ranges. *)
  have "sum_mat (λi. (f (i div nD) * g (i mod nD))m 
    ((A (i div nD))  (B (i mod nD)))) 
    {..< (Suc nC)*nD} = sum_mat (λi. (f (i div nD) * g (i mod nD))m 
    ((A (i div nD))  (B (i mod nD)))) {..< nC*nD} + 
    sum_mat (λi. (f (i div nD) * g (i mod nD))m 
    ((A (i div nD))  (B (i mod nD)))) {nC*nD..< (Suc nC)*nD}"  
  proof -
    have "{..< (Suc nC)*nD} = {..< nC*nD}  {nC*nD..< (Suc nC)*nD}" by auto
    moreover have "sum_mat (λi. (f (i div nD) * g (i mod nD))m 
      ((A (i div nD))  (B (i mod nD)))) 
      ({..< nC*nD}  {nC*nD..< (Suc nC)*nD}) = 
      sum_mat (λi. (f (i div nD) * g (i mod nD))m 
      ((A (i div nD))  (B (i mod nD)))) {..< nC*nD} + 
      sum_mat (λi. (f (i div nD) * g (i mod nD))m 
      ((A (i div nD))  (B (i mod nD)))) {nC*nD..< (Suc nC)*nD}"
    proof (rule sum_mat_disj_union)
      show "{..<nC * nD}  {nC * nD..<Suc nC * nD} = {}"
        by (simp add: ivl_disj_int(2))
      show "i{..<nC * nD}  {nC * nD..<Suc nC * nD}.
       f (i div nD) * g (i mod nD) m (A (i div nD)  B (i mod nD))  fc_mats" 
        using fc smult_mem by blast        
    qed simp+
    ultimately show ?thesis by simp
  qed
  (* Apply the induction hypothesis to the first part. *)
  also have "... = 
    (Cp  D) + 
    sum_mat (λi. (f (i div nD) * g (i mod nD))m 
    ((A (i div nD))  (B (i mod nD)))) {nC*nD..< (Suc nC)*nD}"
  proof -
    have "sum_mat (λi. (f (i div nD) * g (i mod nD))m 
      ((A (i div nD))  (B (i mod nD)))) {..< nC*nD} = Cp  D"
      unfolding Cp_def using Suc by simp
    thus ?thesis by simp
  qed
  (* On the last block, i div nD = nC. *)
  also have "... = 
    (Cp  D) + 
    sum_mat (λi. (f nC * g (i mod nD))m 
    ((A nC)  (B (i mod nD)))) {nC*nD..< (Suc nC)*nD}" 
  proof -
    have "sum_mat (λi. (f (i div nD) * g (i mod nD))m 
      ((A (i div nD))  (B (i mod nD)))) {nC*nD..< (Suc nC)*nD} = 
      sum_mat (λi. (f nC * g (i mod nD))m 
      ((A nC)  (B (i mod nD)))) {nC*nD..< (Suc nC)*nD}" 
    proof (rule sum_mat_cong)
      show "i. i  {nC * nD..<Suc nC * nD} 
        f (i div nD) * g (i mod nD) m (A (i div nD)  B (i mod nD))  
        fc_mats" using fc by (metis UnI2 smult_mem) 
      show "i. i  {nC * nD..<Suc nC * nD}  
        f nC * g (i mod nD) m (A nC  B (i mod nD))  fc_mats" 
      proof
        fix i
        assume "i  {nC * nD..<Suc nC * nD}" 
        hence "i mod nD < nD" using Suc mod_less_divisor by blast 
        hence "B (i mod nD)  carrier_mat m m" using Suc by simp
        moreover have "A nC  carrier_mat n n" using Suc by simp
        ultimately have "A nC  B (i mod nD)  carrier_mat (n*m) (n*m)" 
          using tensor_mat_carrier by (metis carrier_matD(1) carrier_matD(2))
        hence "(A nC  B (i mod nD))  fc_mats"
            using Suc dim_eq fc_mats_carrier by blast 
        thus "f nC * g (i mod nD) m (A nC  B (i mod nD))  fc_mats"
          using smult_mem by blast
      qed simp
      show "i. i  {nC * nD..<Suc nC * nD} 
        f (i div nD) * g (i mod nD) m (A (i div nD)  B (i mod nD)) =
        f nC * g (i mod nD) m (A nC  B (i mod nD))"
      proof -
        fix i
        assume "i  {nC * nD..<Suc nC * nD}"
        hence "i div nD = nC"
          by (metis atLeastLessThan_iff div_nat_eqI mult.commute)
        thus "f (i div nD) * g (i mod nD) m (A (i div nD)  B (i mod nD)) =
          f nC * g (i mod nD) m (A nC  B (i mod nD))" by simp
      qed
    qed simp
    thus ?thesis by simp
  qed
  (* Distribute the product scalar onto the two tensor factors
     (tensor_mat_smult). *)
  also have "... = 
    (Cp  D) + 
    sum_mat (λi. (f nC m (A nC))  (g (i mod nD)m (B (i mod nD)))) 
    {nC*nD..< (Suc nC)*nD}" 
  proof -
    have "sum_mat (λi. f nC * g (i mod nD) m (A nC  B (i mod nD))) 
      {nC * nD..<Suc nC * nD} =
      sum_mat (λi. f nC m A nC  g (i mod nD) m B (i mod nD)) 
      {nC * nD..<Suc nC * nD}" 
    proof (rule sum_mat_cong)
      show "i. i  {nC * nD..<Suc nC * nD}  
        f nC * g (i mod nD) m (A nC  B (i mod nD))  fc_mats" 
      proof -
        fix i
        assume "i  {nC * nD..<Suc nC * nD}"
        have "i mod nD < nD" using Suc mod_less_divisor by blast
        hence "B (i mod nD)  carrier_mat m m" using Suc by simp
        moreover have "A nC  carrier_mat n n" by (simp add: Suc(2))
        ultimately have "A nC  (B (i mod nD))  carrier_mat (n*m) (n*m)"
          using tensor_mat_carrier
          by (metis carrier_matD(1) carrier_matD(2))
        hence "A nC  B (i mod nD)  fc_mats" using fc_mats_carrier
          Suc dim_eq by blast
        thus "f nC * g (i mod nD) m (A nC  B (i mod nD))  fc_mats"
          using cpx_sq_mat_smult by blast
      qed
      show "i. i  {nC * nD..<Suc nC * nD}  
        f nC m A nC  g (i mod nD) m B (i mod nD)  fc_mats"
      proof -
        fix i
        assume "i  {nC * nD..<Suc nC * nD}"
        have "i mod nD < nD" using Suc mod_less_divisor by blast
        hence "g (i mod nD) m B (i mod nD)  carrier_mat m m" using Suc 
          by simp
        moreover have "f nC m A nC  carrier_mat n n" by (simp add: Suc(2))
        ultimately have "f nC m A nC  g (i mod nD) m B (i mod nD)  
          carrier_mat (n*m) (n*m)"
          using tensor_mat_carrier
          by (metis carrier_matD(1) carrier_matD(2))
        thus "f nC m A nC  g (i mod nD) m B (i mod nD)  fc_mats" 
          using fc_mats_carrier Suc dim_eq by blast
      qed
      show "i. i  {nC * nD..<Suc nC * nD} 
        f nC * g (i mod nD) m (A nC  B (i mod nD)) =
        f nC m A nC  g (i mod nD) m B (i mod nD)" 
      proof -
        fix i
        assume "i  {nC * nD..<Suc nC * nD}"
        show "f nC * g (i mod nD) m (A nC  B (i mod nD)) =
        f nC m A nC  g (i mod nD) m B (i mod nD)" using tensor_mat_smult
          by (metis div_eq_0_iff Suc(3) Suc(8) 
              Suc.prems(1) assms(5) assms(6) lessI mod_div_trivial)
      qed
    qed simp
    thus ?thesis by simp
  qed
  (* Factor the common left operand out of the block sum
     (tensor_mat_sum_mat_right). *)
  also have "... = 
    (Cp  D) +
    ((f nC m (A nC))  (fixed_carrier_mat.sum_mat m m 
      (λi. g (i mod nD)m (B (i mod nD))) 
    {nC*nD..< (Suc nC)*nD}))"
  proof -
    have "sum_mat (λi. f nC m A nC  g (i mod nD) m B (i mod nD)) 
      {nC * nD..<Suc nC * nD} = 
      f nC m (A nC)  (fixed_carrier_mat.sum_mat m m 
        (λi. g (i mod nD)m (B (i mod nD))) 
      {nC*nD..< (Suc nC)*nD})" 
    proof (rule tensor_mat_sum_mat_right[symmetric])
      show "0 < n" "0 < m" "dimR = n*m" using Suc by auto
      show "f nC m A nC  carrier_mat n n" by (simp add: Suc(2))
      fix i
      assume "i  {nC * nD..<Suc nC * nD}"
      have "i mod nD < nD" using Suc mod_less_divisor by blast
      hence "B (i mod nD)  carrier_mat m m" using Suc by simp
      thus "g (i mod nD) m B (i mod nD)  carrier_mat m m" by simp
    qed simp
    thus ?thesis by simp
  qed
  (* Re-index the block {nC*nD ..< Suc nC * nD} to {..< nD}: this is the
     shifted range, and the summand only depends on i mod nD
     (fixed_carrier_mat.sum_mat_mod_eq). *)
  also have "... = 
    (Cp  D) +
    ((f nC m (A nC))  (fixed_carrier_mat.sum_mat m m 
    (λj. g j m (B j)) {..< nD}))" 
  proof -
    have "fixed_carrier_mat.sum_mat m m (λi. g (i mod nD) m B (i mod nD)) 
      {nC * nD..<Suc nC * nD} = 
      fixed_carrier_mat.sum_mat m m (λi. g (i mod nD) m B (i mod nD)) 
      ((+) (nC * nD) ` {..<nD})" 
    proof (rule fixed_carrier_mat.sum_mat_cong')
      show "{nC * nD..<Suc nC * nD} = (+) (nC * nD) ` {..<nD}"
        by (simp add: lessThan_atLeast0)
      show "fixed_carrier_mat (carrier_mat m m) m m"
        by (simp add: fixed_carrier_mat.intro)
      show "i. i  {nC * nD..<Suc nC * nD}  
        g (i mod nD) m B (i mod nD)  carrier_mat m m"
      proof -
        fix i
        assume "i  {nC * nD..<Suc nC * nD}"
        hence "i mod nD < nD"
          using Suc mod_less_divisor by blast
        thus "g (i mod nD) m B (i mod nD)  carrier_mat m m"
          using Suc(3) smult_carrier_mat by blast
      qed
      thus "i. i  {nC * nD..<Suc nC * nD}  
        g (i mod nD) m B (i mod nD)  carrier_mat m m" .
    qed simp+
    also have "... = 
      fixed_carrier_mat.sum_mat m m (λj. g j m B j) {..<nD}" 
    proof (rule fixed_carrier_mat.sum_mat_mod_eq)
      show "fixed_carrier_mat (carrier_mat m m) m m"
        by (simp add: fixed_carrier_mat.intro)
      show "x. x  {..<nD}  g x m B x  carrier_mat m m"
        by (simp add: Suc(3))
    qed
    finally have "fixed_carrier_mat.sum_mat m m 
      (λi. g (i mod nD) m B (i mod nD)) 
      {nC * nD..<Suc nC * nD} =
      fixed_carrier_mat.sum_mat m m (λj. g j m B j) {..<nD}" .
    thus ?thesis by simp
  qed
  also have "... = (Cp  D) + ((f nC m (A nC))  D)" using Suc by simp
  (* Recombine the two tensor products (tensor_mat_add_left). *)
  also have "... = Cp + (f nC m (A nC))  D" 
  proof (rule tensor_mat_add_left[symmetric])      
    show "Cp  carrier_mat n n" unfolding Cp_def 
    proof (rule fixed_carrier_mat.sum_mat_carrier)
      show "i. i  {..< nC}  f i m A i  carrier_mat n n"
        by (simp add: Suc(2))
      show "fixed_carrier_mat (carrier_mat n n) n n"
        by (simp add: fixed_carrier_mat.intro)
    qed
    have "fixed_carrier_mat.sum_mat m m (λj. g j m B j) {..<nD} 
      carrier_mat m m" 
    proof (rule fixed_carrier_mat.sum_mat_carrier)
      show "i. i  {..< nD}  g i m B i  carrier_mat m m"
        by (simp add: Suc)
      show "fixed_carrier_mat (carrier_mat m m ) m m"
        by (simp add: fixed_carrier_mat.intro)
    qed
    thus "D  carrier_mat m m" using Suc by simp
    show "f nC m A nC  carrier_mat n n"
      by (simp add: Suc(2))
  qed (auto simp add: Suc)
  (* Cp plus the new term is the partial sum over {..< Suc nC}. *)
  also have "... = 
    (fixed_carrier_mat.sum_mat n n (λi. f i m (A i)) {..< Suc nC})  D"
  proof -
    have "Cp + f nC m A nC = f nC m A nC + Cp" 
    proof (rule comm_add_mat)
      show "f nC m A nC  carrier_mat n n" by (simp add: Suc(2))
      show "Cp  carrier_mat n n" unfolding Cp_def 
      proof (rule fixed_carrier_mat.sum_mat_carrier)
        show "i. i  {..< nC}  f i m A i  carrier_mat n n"
          by (simp add: Suc(2))
        show "fixed_carrier_mat (carrier_mat n n) n n"
          by (simp add: fixed_carrier_mat.intro)
      qed
    qed
    also have "... = fixed_carrier_mat.sum_mat n n 
      (λi. f i m (A i)) (insert nC {..< nC})" unfolding Cp_def 
    proof (rule fixed_carrier_mat.sum_mat_insert[symmetric])
      show "f nC m A nC  carrier_mat n n"
        by (simp add: Suc(2))
      show "fixed_carrier_mat (carrier_mat n n) n n"
        by (simp add: fixed_carrier_mat.intro)
      show "(λi. f i m A i) ` {..<nC}  carrier_mat n n"
      proof 
        fix x
        assume "x  (λi. f i m A i) ` {..<nC}"
        hence "i  {..<nC}. x = f i m A i" by auto
        from this obtain i where "i {..<nC}" and "x = f i m A i" by auto
        have "f i m A i  carrier_mat n n"
          using Suc.prems(1) i  {..<nC} by auto 
        thus "x  carrier_mat n n" using x = f i m A i by simp
      qed
    qed auto
    also have "... = fixed_carrier_mat.sum_mat n n (λi. f i m (A i)) 
      {..< Suc nC}"
    proof (rule fixed_carrier_mat.sum_mat_cong')
      show "fixed_carrier_mat (carrier_mat n n) n n"
        by (simp add: fixed_carrier_mat.intro)
      show "insert nC {..<nC} = {..<Suc nC}"
        by (simp add: lessThan_Suc)
      show "i. i  insert nC {..<nC}  f i m A i  carrier_mat n n"
        by (simp add: Suc(2) insert nC {..<nC} = {..<Suc nC})
      thus "i. i  insert nC {..<nC}  f i m A i  carrier_mat n n" .
    qed auto
    finally have "Cp + f nC m A nC = fixed_carrier_mat.sum_mat n n 
      (λi. f i m (A i)) {..< Suc nC}" .
    thus ?thesis by simp
  qed
  also have "... = C D" using Suc by simp
  finally show ?case .
qed

(* Degenerate case nD = 0 of the expansion: D is the empty sum, i.e. the
   m x m zero matrix, so the tensor product of C and D is zero; the
   left-hand index range {..< nC*0} is empty, so that sum is zero too.
   No hypothesis on B is needed. *)
lemma tensor_mat_sum_nat_mod_div_eq_0:
  assumes "k. k < (nC::nat)  A k  carrier_mat n n"
  and "fixed_carrier_mat.sum_mat n n (λi. f i m (A i)) {..< nC} = C"
  and "fixed_carrier_mat.sum_mat m m (λj. g j m (B j)) {..< nD} = D"
  and "0 < n"
  and "0 < m"
  and "nD = 0"
  and "dimR = n *m"
shows "sum_mat (λi. (f (i div nD) * g (i mod nD))m 
  ((A (i div nD))  (B (i mod nD)))) 
  {..< nC*nD} = C D"
proof -
  have "D = fixed_carrier_mat.sum_mat m m (λi. g i m (B i)) {}" 
    using assms by auto
  also have "... =  0m m m" 
    using fixed_carrier_mat.sum_mat_empty[of _ m m "λi. g i m (B i)"]
    by (simp add: fixed_carrier_mat_def)
  finally have "D = 0m m m" .
  moreover have "C  carrier_mat n n" using assms
    fixed_carrier_mat.sum_mat_carrier[of _ n n "{..< nC}" "λj. f j m (A j)"]
    by (simp add: fixed_carrier_mat_def)
  ultimately have "C D = 0m (n*m) (n*m)" using tensor_mat_zero'
    by (simp add: assms)
  have "sum_mat (λi. (f (i div nD) * g (i mod nD))m 
    ((A (i div nD))  (B (i mod nD)))) 
    {..< nC*nD} = sum_mat (λi. (f (i div nD) * g (i mod nD))m 
    ((A (i div nD))  (B (i mod nD)))) {}" using assms by simp
  also have "... = 0m (n*m) (n*m)" using sum_mat_empty
    using assms(7) dim_eq by blast
  also have "... = C D" using C D = 0m (n*m) (n*m) by simp
  finally show ?thesis .
qed

(* Bilinear expansion of the tensor product of two finite weighted sums,
   for arbitrary nD: combines the nD = 0 and nD nonzero cases above. *)
lemma tensor_mat_sum_nat_mod_div:
  assumes "k. k < (nC::nat)  A k  carrier_mat n n"
  and "j. j < (nD::nat)  B j  carrier_mat m m"
  and "fixed_carrier_mat.sum_mat n n (λi. f i m (A i)) {..< nC} = C"
  and "fixed_carrier_mat.sum_mat m m (λj. g j m (B j)) {..< nD} = D"
  and "0 < n"
  and "0 < m"
  and "dimR = n *m"
shows "sum_mat (λi. (f (i div nD) * g (i mod nD))m 
  ((A (i div nD))  (B (i mod nD)))) 
  {..< nC*nD} = C D"
proof (cases "nD = 0")
  case True
  then show ?thesis using assms 
      tensor_mat_sum_nat_mod_div_eq_0[OF assms(1) assms(3)] by simp
next
  case False
  then show ?thesis using assms tensor_mat_sum_nat_mod_div_ne_0 by simp
qed

end
  
(* Trace expansion, case nD nonzero.  With C and D the weighted sums of
   the A i and B j, the trace of (tensor of C and D) * R equals the sum,
   over the flattened range {..< nC*nD}, of the traces of the summands
   (f (i div nD) * g (i mod nD)) smult (tensor of A (i div nD) with
   B (i mod nD)), each multiplied on the right by R.
   Proof: locally interpret cpx_sq_mat with
   fc = carrier_mat (n*m) (n*m).  Pull R out of the matrix sum
   (sum_mat_distrib_right), rewrite the sum with
   tensor_mat_sum_nat_mod_div, then swap trace and sum (trace_sum_mat). *)
lemma tensor_mat_sum_mult_trace_expand_ne_0:
  assumes "k. k < (nC::nat)  A k  carrier_mat n n"
  and "j. j < (nD::nat)  B j  carrier_mat m m"
  and "R  carrier_mat (n*m) (n*m)"
  and "fixed_carrier_mat.sum_mat n n (λi. f i m (A i)) {..< nC} = C"
  and "fixed_carrier_mat.sum_mat m m (λj. g j m (B j)) {..< nD} = D"
  and "0 < n"
  and "0 < m"
  and "nD  0"
  shows "sum (λi. Complex_Matrix.trace ((f (i div nD) * g (i mod nD))m 
    ((A (i div nD))  (B (i mod nD))) * R)) {..< nC * nD} = 
    Complex_Matrix.trace ((C  D) * R)" 
proof -
  define fc::"complex Matrix.mat set" where "fc = carrier_mat (n*m) (n*m)"
  interpret cpx_sq_mat "n*m" "n*m" fc  
  proof 
    show "0 < n*m" using assms by simp
  qed (auto simp add: fc_def)
  (* Every scaled summand lies in fc. *)
  have fc: "i{..<nC * nD}.
    f (i div nD) * g (i mod nD) m (A (i div nD)  B (i mod nD))  fc" 
  proof
    fix i
    assume "i  {..<nC * nD}"
    hence "i div nD <  nC"
      by (simp add: less_mult_imp_div_less) 
    hence "A (i div nD)  carrier_mat n n" using assms by simp
    have "i mod nD < nD" using assms by simp 
    hence "B (i mod nD)  carrier_mat m m" using assms by simp
    hence "A (i div nD)  B (i mod nD)  carrier_mat (n*m) (n*m)" 
      using tensor_mat_carrier
      by (metis A (i div nD)  carrier_mat n n 
          carrier_matD(1) carrier_matD(2))
    hence "(A (i div nD)  B (i mod nD))  fc"
      using assms dim_eq fc_mats_carrier by blast 
    thus "f (i div nD) * g (i mod nD) m(A (i div nD)  B (i mod nD))  fc"
      using smult_mem by blast
  qed
  (* Right multiplication by R distributes over the matrix sum. *)
  have "sum_mat (λi. ((f (i div nD) * g (i mod nD))m 
    ((A (i div nD))  (B (i mod nD))) * R)) {..< nC * nD} = 
    (sum_mat (λi. f (i div nD) * g (i mod nD)m 
    ((A (i div nD))  (B (i mod nD)))) {..< nC * nD}) * R"
  proof (rule sum_mat_distrib_right)
    show "R fc" using assms unfolding fc_def by simp
  qed (auto simp add: fc assms)
  also have "... = (C  D) * R"
  proof -
    have "sum_mat (λi. f (i div nD) * g (i mod nD)m 
      ((A (i div nD))  (B (i mod nD)))) {..< nC * nD} = C  D"
      using tensor_mat_sum_nat_mod_div assms by simp
    thus ?thesis by simp
  qed
  finally have sr: "sum_mat (λi. ((f (i div nD) * g (i mod nD))m 
    ((A (i div nD))  (B (i mod nD))) * R)) {..< nC * nD} = (C  D) * R" .
  (* The trace of a matrix sum is the sum of the traces. *)
  have "sum (λi. Complex_Matrix.trace ((f (i div nD) * g (i mod nD))m 
    ((A (i div nD))  (B (i mod nD))) * R)) {..< nC * nD} = 
    Complex_Matrix.trace (sum_mat (λi. ((f (i div nD) * g (i mod nD))m 
    ((A (i div nD))  (B (i mod nD))) * R)) {..< nC * nD})"
  proof (rule trace_sum_mat[symmetric])
    show "i. i  {..<nC * nD} 
      f (i div nD) * g (i mod nD) m (A (i div nD)  B (i mod nD)) * R  fc"
      using fc assms cpx_sq_mat_mult fc_def by blast 
  qed simp
  also have "... = Complex_Matrix.trace ((C  D) * R)" using sr by simp
  finally show ?thesis .
qed

(* Trace expansion, degenerate case nD = 0.  D is the m x m zero matrix,
   so the tensor product of C and D is zero (tensor_mat_zero'), hence so
   is its product with R and its trace.  The left-hand sum ranges over
   the empty set {..< nC*0} and is 0 as well. *)
lemma tensor_mat_sum_mult_trace_expand_eq_0:
  assumes "k. k < (nC::nat)  A k  carrier_mat n n"
  and "R  carrier_mat (n*m) (n*m)"
  and "fixed_carrier_mat.sum_mat n n (λi. f i m (A i)) {..< nC} = C"
  and "fixed_carrier_mat.sum_mat m m (λj. g j m (B j)) {..< nD} = D"
  and "0 < n"
  and "0 < m"
  and "nD = 0"
  shows "sum (λi. Complex_Matrix.trace ((f (i div nD) * g (i mod nD))m 
    ((A (i div nD))  (B (i mod nD))) * R)) {..< nC * nD} = 
    Complex_Matrix.trace ((C  D) * R)" 
proof -
  have "D = 0m m m" using assms fixed_carrier_mat.sum_mat_empty
    fixed_carrier_mat.intro by fastforce
  hence "C  D = C  (0m m m)" by simp
  also have "... = 0m (n*m) (n*m)" 
  proof (rule tensor_mat_zero')
    have "fixed_carrier_mat.sum_mat n n (λi. f i m A i) {..<nC}  
      carrier_mat n n" 
    proof (rule fixed_carrier_mat.sum_mat_carrier)
      show "fixed_carrier_mat (carrier_mat n n) n n"
        by (simp add: fixed_carrier_mat.intro)
      show "i. i  {..<nC}  f i m A i  carrier_mat n n" using assms
        by simp
    qed 
    thus "C carrier_mat n n" using assms by simp
  qed (simp add: assms)+
  finally have "C  D = 0m (n*m) (n*m)" .
  hence "(C  D) * R = 0m (n*m) (n*m)"
    by (simp add: assms left_mult_zero_mat)
  hence "Complex_Matrix.trace ((C  D) * R) = 0" by simp
  moreover have "sum (λi. Complex_Matrix.trace ((f (i div nD)*g (i mod nD))m 
    ((A (i div nD))  (B (i mod nD))) * R)) {..< nC * nD} = 0" 
    using assms by simp
  ultimately show ?thesis by simp
qed

(* Flattened trace expansion of a tensor product of matrix sums.
   C is the sum over {..< nC} of the scalar multiples f i of A i, and D is
   the sum over {..< nD} of the scalar multiples g j of B j.  The trace of
   the product of (C tensor D) with R equals a single sum over the flattened
   index range {..< nC * nD}.  Index i picks the C-summand i div nD and the
   D-summand i mod nD.  The proof only case-splits on whether nD = 0. *)
lemma tensor_mat_sum_mult_trace_expand:
  assumes "k. k < (nC::nat)  A k  carrier_mat n n"
  and "j. j < (nD::nat)  B j  carrier_mat m m"
  and "R  carrier_mat (n*m) (n*m)"
  and "fixed_carrier_mat.sum_mat n n (λi. f i m (A i)) {..< nC} = C"
  and "fixed_carrier_mat.sum_mat m m (λj. g j m (B j)) {..< nD} = D"
  and "0 < n"
  and "0 < m"
  shows "sum (λi. Complex_Matrix.trace ((f (i div nD) * g (i mod nD))m 
    ((A (i div nD))  (B (i mod nD))) * R)) {..< nC * nD} = 
    Complex_Matrix.trace ((C  D) * R)" 
proof (cases "nD = 0")
  case True
  (* Empty D-side: delegate to tensor_mat_sum_mult_trace_expand_eq_0. *)
  then show ?thesis 
    using assms tensor_mat_sum_mult_trace_expand_eq_0[OF assms(1)] by simp
next
  case False
  (* Nonempty D-side: delegate to tensor_mat_sum_mult_trace_expand_ne_0. *)
  then show ?thesis 
    using assms tensor_mat_sum_mult_trace_expand_ne_0[OF assms(1) assms(2)] 
    by simp
qed

(* Double-sum form of the trace expansion, for nonzero nD.
   Summing over i < nC and j < nD the trace of
   ((f i * g j) scaled (A i tensor B j)) times R
   gives the trace of (C tensor D) times R.
   Proof outline:
   - turn the double sum of traces into the trace of a nested sum_mat;
   - flatten it to one sum over {..< nC * nD} with sum_sum_mat_expand;
   - split it back into a sum of traces;
   - conclude with tensor_mat_sum_mult_trace_expand. *)
lemma tensor_mat_sum_mult_trace_ne_0:
  assumes "k. k < (nC::nat)  A k  carrier_mat n n"
  and "j. j < (nD::nat)  B j  carrier_mat m m"
  and "R  carrier_mat (n*m) (n*m)"
  and "fixed_carrier_mat.sum_mat n n (λi. f i m (A i)) {..< nC} = C"
  and "fixed_carrier_mat.sum_mat m m (λj. g j m (B j)) {..< nD} = D"
  and "0 < n"
  and "0 < m"
  and "0  nD"
  shows "sum (λi. (sum (λj. Complex_Matrix.trace ((f i * g j)m 
    ((A i)  (B j)) * R)) {..< nD})) {..< nC} = 
    Complex_Matrix.trace ((C  D) * R)" 
proof -
  (* Work in the cpx_sq_mat locale over fc, the (n*m) x (n*m) complex
     matrices.  This gives the unqualified sum_mat and the membership facts
     about fc used in the steps below. *)
  define fc::"complex Matrix.mat set" where "fc = carrier_mat (n*m) (n*m)"
  interpret cpx_sq_mat "n*m" "n*m" fc  
  proof 
    show "0 < n*m" using assms by simp
  qed (auto simp add: fc_def)
  (* Step 1: for each i, move the inner sum of traces inside the trace. *)
  have "sum (λi. (sum (λj. Complex_Matrix.trace ((f i * g j)m 
    ((A i)  (B j)) * R)) {..< nD})) {..< nC} = 
    sum (λi. Complex_Matrix.trace (sum_mat (λj. (f i * g j)m 
    ((A i)  (B j)) * R) {..< nD})) {..< nC}"
  proof (rule sum.cong)
    fix x
    assume "x  {..< nC}"
    hence "A x  carrier_mat n n" using assms by simp
    show "(j {..< nD}. Complex_Matrix.trace (f x * g j m (A x  B j) * R))=
         Complex_Matrix.trace (sum_mat (λj. f x * g j m (A x  B j) * R) 
          {..<nD})" 
    proof (rule trace_sum_mat[symmetric])
      fix j
      assume "j  {..< nD}"
      hence "B j  carrier_mat m m" using assms by simp
      hence "A x  B j  carrier_mat (n*m) (n*m)" 
        using tensor_mat_carrier
        by (metis A x  carrier_mat n n carrier_matD(1) carrier_matD(2))
      hence "A x  B j  fc"
        using assms dim_eq fc_mats_carrier by blast 
      thus "f x * g j m(A x  B j)*R  fc"
        using smult_mem assms(3) cpx_sq_mat_mult fc_def by blast
    qed simp
  qed simp
  (* Step 2: do the same for the outer sum over i. *)
  also have "... = Complex_Matrix.trace (sum_mat (λi. 
    (sum_mat (λj. (f i * g j)m ((A i)  (B j)) * R) {..< nD})) {..< nC})"
  proof (rule trace_sum_mat[symmetric])
    fix x
    assume "x{..< nC}"
    hence "A x carrier_mat n n" using assms by simp
    show "sum_mat (λj. f x * g j m (A x  B j) * R) {..<nD}  fc" 
      unfolding fc_def
    proof (rule sum_mat_carrier)
      fix j
      assume "j {..< nD}"
      hence "B j  carrier_mat m m" using assms by simp
      hence "A x  B j  carrier_mat (n*m) (n*m)" 
        using tensor_mat_carrier
        by (metis A x  carrier_mat n n carrier_matD(1) carrier_matD(2))
      hence "A x  B j  fc"
        using assms dim_eq fc_mats_carrier by blast 
      thus "f x * g j m(A x  B j)*R  fc"
        using smult_mem assms(3) cpx_sq_mat_mult fc_def by blast
    qed 
  qed simp
  (* Step 3: flatten the nested sum_mat into a single sum over
     {..< nC*nD}, indexing with i div nD and i mod nD. *)
  also have "... = Complex_Matrix.trace 
    (sum_mat (λi. (f (i div nD) * g (i mod nD))m 
    ((A (i div nD))  (B (i mod nD)))*R) {..< nC*nD})" 
  proof -
    have "sum_mat (λi. sum_mat (λj. f i * g j m (A i  B j) * R) {..<nD}) 
      {..<nC} = sum_mat (λi. (f (i div nD) * g (i mod nD))m 
    ((A (i div nD))  (B (i mod nD)))*R) {..< nC*nD}" 
      by (rule sum_sum_mat_expand, (auto simp add: assms))
    thus ?thesis by simp
  qed
  (* Step 4: split the flattened trace back into a sum of traces.
     The carrier facts come from i div nD < nC and i mod nD < nD;
     the latter relies on nD being nonzero. *)
  also have "... = (i<nC * nD.
      Complex_Matrix.trace
       (f (i div nD) * g (i mod nD) m (A (i div nD)  B (i mod nD)) * R))"
  proof (rule trace_sum_mat)
    fix i
    assume "i  {..<nC * nD}"
    hence "i div nD <  nC"
      by (simp add: less_mult_imp_div_less) 
    hence "A (i div nD)  carrier_mat n n" using assms by simp
    have "i mod nD < nD" using assms by simp 
    hence "B (i mod nD)  carrier_mat m m" using assms by simp
    hence "A (i div nD)  B (i mod nD)  carrier_mat (n*m) (n*m)" 
      using tensor_mat_carrier
      by (metis A (i div nD)  carrier_mat n n 
          carrier_matD(1) carrier_matD(2))
    hence "(A (i div nD)  B (i mod nD))  fc"
      using assms dim_eq fc_mats_carrier by blast 
    hence "f (i div nD) * g (i mod nD) m(A (i div nD)  B (i mod nD))  fc"
      using smult_mem by blast
    thus "f (i div nD)*g (i mod nD) m (A (i div nD)  B (i mod nD))*R  fc"
      using assms(3) cpx_sq_mat_mult fc_mats_carrier by blast
  qed simp
  (* Step 5: conclude with the flattened expansion lemma. *)
  also have "... = Complex_Matrix.trace ((C  D) * R)" 
  proof (rule tensor_mat_sum_mult_trace_expand)
    show "k. k < nC  A k  carrier_mat n n" using assms by simp
    show "j. j < nD  B j  carrier_mat m m" using assms by simp
  qed (auto simp add: assms)
  finally show ?thesis .
qed

(* Degenerate case nD = 0 of the double-sum trace expansion.
   Every inner sum ranges over the empty set {..< nD}, so the left-hand
   side is 0.  D is the m x m zero matrix, so the trace of (C tensor D)
   times R is 0 as well.  No carrier hypothesis on B is required. *)
lemma tensor_mat_sum_mult_trace_eq_0:
  assumes "k. k < (nC::nat)  A k  carrier_mat n n"
  and "R  carrier_mat (n*m) (n*m)"
  and "fixed_carrier_mat.sum_mat n n (λi. f i m (A i)) {..< nC} = C"
  and "fixed_carrier_mat.sum_mat m m (λj. g j m (B j)) {..< nD} = D"
  and "0 < n"
  and "0 < m"
  and "0 = (nD::nat)"
  shows "sum (λi. (sum (λj. Complex_Matrix.trace ((f i * g j)m 
    ((A i)  (B j)) * R)) {..< nD})) {..< nC} = 
    Complex_Matrix.trace ((C  D) * R)" 
proof -
  (* NOTE(review): fc and the cpx_sq_mat interpretation do not appear to be
     referenced in the rest of this proof; confirm before removing them. *)
  define fc::"complex Matrix.mat set" where "fc = carrier_mat (n*m) (n*m)"
  interpret cpx_sq_mat "n*m" "n*m" fc  
  proof 
    show "0 < n*m" using assms by simp
  qed (auto simp add: fc_def)
  (* D is a sum over the empty range, hence the zero matrix. *)
  have "fixed_carrier_mat.sum_mat m m (λj. g j m (B j)) {} = 0m m m"
    using assms fixed_carrier_mat.sum_mat_empty[of _ m m ]
    fixed_carrier_mat.intro by fastforce
  hence "D = 0m m m" using assms by simp
  hence "C  D = C  (0m m m)" by simp
  (* Tensoring with a zero matrix gives the (n*m) x (n*m) zero matrix. *)
  also have "... = 0m (n*m) (n*m)" 
  proof (rule tensor_mat_zero')
    have "fixed_carrier_mat.sum_mat n n (λi. f i m A i) {..<nC}  
      carrier_mat n n" 
    proof (rule fixed_carrier_mat.sum_mat_carrier)
      show "fixed_carrier_mat (carrier_mat n n) n n"
        by (simp add: fixed_carrier_mat.intro)
      show "i. i  {..<nC}  f i m A i  carrier_mat n n" using assms
        by simp
    qed 
    thus "C carrier_mat n n" using assms by simp
    show "0<n" "0<m" using assms by auto
  qed 
  finally have "C  D = 0m (n*m) (n*m)" .
  hence "(C  D) * R = 0m (n*m) (n*m)"
    by (simp add: assms left_mult_zero_mat)
  hence 1: "Complex_Matrix.trace ((C  D) * R) = 0" by simp
  (* Each inner sum is over {..< nD}, which is empty, so it vanishes. *)
  have "i. i{..< nC}  sum (λj. Complex_Matrix.trace ((f i * g j)m 
    ((A i)  (B j)) * R)) {..< nD} = 0"
  proof -
    fix i
    assume "i {..< nC}"
    show "sum (λj. Complex_Matrix.trace ((f i * g j)m 
      ((A i)  (B j)) * R)) {..< nD} = 0" using assms by simp
  qed
  hence "sum (λi. (sum (λj. Complex_Matrix.trace ((f i * g j)m 
    ((A i)  (B j)) * R)) {..< nD})) {..< nC} = 0" by simp
  thus ?thesis using 1 by simp
qed

(* Double-sum trace expansion of a tensor product of matrix sums, for any nD.
   C is the sum over {..< nC} of the scalar multiples f i of A i, and D is
   the sum over {..< nD} of the scalar multiples g j of B j.  Summing over
   i < nC and j < nD the trace of ((f i * g j) scaled (A i tensor B j))
   times R gives the trace of (C tensor D) times R.  The proof case-splits
   on nD = 0 between tensor_mat_sum_mult_trace_eq_0 and
   tensor_mat_sum_mult_trace_ne_0. *)
lemma tensor_mat_sum_mult_trace:
  assumes "k. k < (nC::nat)  A k  carrier_mat n n"
  and "j. j < (nD::nat)  B j  carrier_mat m m"
  and "R  carrier_mat (n*m) (n*m)"
  and "fixed_carrier_mat.sum_mat n n (λi. f i m (A i)) {..< nC} = C"
  and "fixed_carrier_mat.sum_mat m m (λj. g j m (B j)) {..< nD} = D"
  and "0 < n"
  and "0 < m"
  shows "sum (λi. (sum (λj. Complex_Matrix.trace ((f i * g j)m 
    ((A i)  (B j)) * R)) {..< nD})) {..< nC} = 
    Complex_Matrix.trace ((C  D) * R)" 
proof (cases "nD = 0")
  case True
  then show ?thesis using assms tensor_mat_sum_mult_trace_eq_0[OF assms(1)] 
    by simp
next
  case False
  then show ?thesis 
    using assms tensor_mat_sum_mult_trace_ne_0[OF assms(1) assms(2)] by simp
qed

(* Specialisation of tensor_mat_sum_mult_trace to the projective measurements
   that make_pm builds from A (n x n) and B (m x m).  The scalar coefficients
   are the real outcome values, cast to complex.  The matrices are the
   corresponding outcome projectors.  make_pm_sum shows these data sum back
   to A and B.  proj_measurement_carrier supplies the carrier facts for the
   projectors.  The hermitian hypotheses are presumably what
   make_pm_proj_measurement and make_pm_sum require; they are discharged by
   auto with assms below (TODO confirm). *)
lemma tensor_mat_make_pm_mult_trace:
  assumes "A  carrier_mat n n"
  and "hermitian A"
  and "B carrier_mat m m"
  and "hermitian B"
  and "R  carrier_mat (n*m) (n*m)"
  and "(nA, M) = cpx_sq_mat.make_pm n n A"
  and "(nB, N) = cpx_sq_mat.make_pm m m B"
  and "0 < n"
  and "0 < m"
shows "sum (λi. (sum (λj. Complex_Matrix.trace 
    ((complex_of_real (meas_outcome_val (M i)) * 
    complex_of_real (meas_outcome_val (N j)))m 
    ((meas_outcome_prj (M i))  (meas_outcome_prj (N j))) * R)) {..< nB})) 
    {..< nA} = 
    Complex_Matrix.trace ((A  B) * R)" 
proof (rule tensor_mat_sum_mult_trace)
  (* make_pm yields projective measurements on n x n and m x m matrices. *)
  have A: "cpx_sq_mat.proj_measurement n n (carrier_mat n n) nA M" 
  proof (rule cpx_sq_mat.make_pm_proj_measurement)
    show "A  carrier_mat n n" using assms by simp
    show "cpx_sq_mat n n (carrier_mat n n)"
      by (simp add: assms cpx_sq_mat.intro cpx_sq_mat_axioms.intro 
          fixed_carrier_mat_def)
  qed (auto simp add: assms)
  have B: "cpx_sq_mat.proj_measurement m m (carrier_mat m m) nB N" 
  proof (rule cpx_sq_mat.make_pm_proj_measurement)
    show "B  carrier_mat m m" using assms by simp
    show "cpx_sq_mat m m (carrier_mat m m)"
      by (simp add: assms cpx_sq_mat.intro cpx_sq_mat_axioms.intro 
          fixed_carrier_mat_def)
  qed (auto simp add: assms)
  (* Carrier facts for the projectors of each measurement. *)
  show "k. k < nA  meas_outcome_prj (M k)  carrier_mat n n"
  proof -
    fix k
    assume "k < nA"
    show "meas_outcome_prj (M k)  carrier_mat n n" 
      using cpx_sq_mat.proj_measurement_carrier
      by (meson A k < nA assms(8) cpx_sq_mat_axioms.intro cpx_sq_mat_def 
          fixed_carrier_mat.intro)
  qed
  show "k. k < nB  meas_outcome_prj (N k)  carrier_mat m m"
  proof -
    fix k
    assume "k < nB"
    show "meas_outcome_prj (N k)  carrier_mat m m" 
      using cpx_sq_mat.proj_measurement_carrier
      by (meson B k < nB assms(9) cpx_sq_mat_axioms.intro cpx_sq_mat_def 
          fixed_carrier_mat.intro)
  qed
  (* The weighted projectors sum back to A and B respectively. *)
  show "fixed_carrier_mat.sum_mat n n 
    (λi. complex_of_real (meas_outcome_val (M i)) m meas_outcome_prj (M i)) 
    {..<nA} = A"
  proof (rule cpx_sq_mat.make_pm_sum)
    show "cpx_sq_mat n n (carrier_mat n n)"
      by (simp add: assms cpx_sq_mat.intro cpx_sq_mat_axioms.intro 
          fixed_carrier_mat_def)
  qed (auto simp add: assms)
  show "fixed_carrier_mat.sum_mat m m
    (λi. complex_of_real (meas_outcome_val (N i)) m meas_outcome_prj (N i)) 
    {..<nB} = B"
  proof (rule cpx_sq_mat.make_pm_sum)
    show "cpx_sq_mat m m (carrier_mat m m)"
      by (simp add: assms cpx_sq_mat.intro cpx_sq_mat_axioms.intro 
          fixed_carrier_mat_def)
  qed (auto simp add: assms)
qed (auto simp add: assms)

(* Conjugation is compatible with the tensor product.  If A is the
   conjugation of B by U and C is the conjugation of D by V, then
   A tensor C is the conjugation of B tensor D by U tensor V.
   Here mat_conj U B unfolds to U * B * adjoint U.  No unitarity of U or V
   is assumed.  The proof uses mult_distr_tensor twice and
   tensor_mat_adjoint. *)
lemma tensor_mat_mat_conj:
  assumes "A carrier_mat n n"
  and "B  carrier_mat n n"
  and "U  carrier_mat n n"
  and "C carrier_mat m m"
  and "D  carrier_mat m m"
  and "V  carrier_mat m m"
  and "0 < n"
  and "0 < m"
  and "A = mat_conj U B"
  and "C = mat_conj V D"
shows "A  C = mat_conj (U  V) (B  D)" 
proof -
  have "A  C = (U * B * Complex_Matrix.adjoint U)  
    (V * D * Complex_Matrix.adjoint V)" using assms unfolding mat_conj_def 
    by simp
  (* Distribute the tensor product over the outer matrix product. *)
  also have "... = (U*B  (V*D)) * 
    (Complex_Matrix.adjoint U  Complex_Matrix.adjoint V)" 
    using mult_distr_tensor assms by simp
  (* Distribute again and merge the adjoints into the adjoint of U tensor V. *)
  also have "... = (U  V) * (B  D) * Complex_Matrix.adjoint (U  V)"
    using mult_distr_tensor assms
    by (metis carrier_matD(1) carrier_matD(2) tensor_mat_adjoint)
  finally show ?thesis unfolding mat_conj_def by simp
qed

(* From unitarily_equiv A B U, express A as the conjugation of B by U,
   that is, as mat_conj U B.  The proof unfolds mat_conj_def and uses
   unitarily_equiv_eq.
   NOTE(review): this is declared [simp] although its left-hand side is the
   bare variable A.  The simplifier presumably uses it as a conditional fact
   rather than as a rewrite of A; confirm this is the intended effect. *)
lemma unitarily_equiv_mat_conj[simp]:
  assumes "unitarily_equiv A B U"
  shows "A = mat_conj U B" unfolding mat_conj_def
  by (simp add: assms unitarily_equiv_eq)

(* Tensor product of unitary diagonalisations.  If unitary_diag A B U and
   unitary_diag C D V hold, with A of size n x n and C of size m x m, then
   unitary_diag (A tensor C) (B tensor D) (U tensor V) holds.  Despite the
   lemma name, no hermitian hypothesis is involved.  The proof establishes
   the five premises of unitary_diagI':
   - carrier of A tensor C;
   - carrier of B tensor D;
   - unitarity of U tensor V;
   - diagonality of B tensor D;
   - the conjugation identity, via tensor_mat_mat_conj. *)
lemma hermitian_tensor_mat_decomp:
  assumes "A carrier_mat n n"
  and "C carrier_mat m m"
  and "unitary_diag A B U"
  and "unitary_diag C D V"
  and "0 < n"
  and "0 < m"
shows "unitary_diag (A  C) (B  D) (U  V)"
proof (rule unitary_diagI')
  show "A  C  carrier_mat (n *m) (n*m)" using assms
    by (metis carrier_matD(1) carrier_matD(2) tensor_mat_carrier)
  show "B  D  carrier_mat (n * m) (n * m)" using assms
    by (metis (no_types, opaque_lifting)  carrier_matD(1) 
        carrier_matD(2) carrier_mat_triv dim_col_tensor_mat 
        dim_row_tensor_mat unitary_diag_carrier(1))
  show "Complex_Matrix.unitary (U  V)"
    by (metis Complex_Matrix.unitary_def assms(3) assms(4) 
        carrier_matD(2) carrier_mat_triv dim_col_tensor_mat dim_row_tensor_mat 
        nat_0_less_mult_iff tensor_mat_unitary unitary_diagD(3) unitary_zero 
        zero_order(5))
  show "diagonal_mat (B  D)" using tensor_mat_diagonal
    by (meson assms(3) assms(4) unitarily_equiv_carrier'(2) unitary_diagD(2) 
        unitary_diag_imp_unitarily_equiv)
  (* Conjugation identity: reduce to tensor_mat_mat_conj; the carrier facts
     for B, D, U and V come from unitary_diag_carrier. *)
  show "A  C = mat_conj (U  V) (B  D)" 
  proof (rule tensor_mat_mat_conj[of _ n _ _ _ m])
    show "B carrier_mat n n"
      using assms(1) assms(3) unitary_diag_carrier(1) by auto
    show "D  carrier_mat m m"
      using assms unitary_diag_carrier(1) by auto
    show "U  carrier_mat n n"
      using assms(1) assms(3) unitary_diag_carrier(2) by blast
    show "V  carrier_mat m m"
      using assms unitary_diag_carrier(2) by blast
  qed (auto simp add: assms)
qed

end