(* Theory Sylvester_Criterion *)

theory Sylvester_Criterion
  imports Misc_Matrix_Results

begin

section "Sylvester's Criterion Setup"

(* Sylvester's criterion as a predicate on a matrix A: for every k in the
   interval {0..dim_row A}, the determinant of lps A k is strictly positive.
   lps is defined outside this file; it is presumably the leading principal
   submatrix of size k, as later proofs unfold leading_principal_submatrix_def
   for it.
   The symbols are written in Isabelle's ASCII symbol notation. *)
definition sylvester_criterion :: "('a::{comm_ring_1,ord}) mat \<Rightarrow> bool" where
  "sylvester_criterion A \<equiv> (\<forall>k \<in> {0..dim_row A}. Determinant.det (lps A k) > 0)"

(* The criterion is inherited by leading principal submatrices: if the n-by-n
   matrix A satisfies sylvester_criterion and m is at most n, then lps A m
   satisfies it as well.  The proof relies on the externally defined fact
   nested_leading_principle_submatrices. *)
lemma leading_principle_submatrix_sylvester:
  assumes "A \<in> carrier_mat n n"
  assumes "m \<le> n"
  assumes "sylvester_criterion A"
  shows "sylvester_criterion (lps A m)"
  using nested_leading_principle_submatrices
  by (smt (verit, del_insts) assms atLeastAtMost_iff carrier_matD(1) order.trans leading_principal_submatrix_carrier sylvester_criterion_def)

(* A square matrix satisfying the criterion has positive determinant. *)
lemma sylvester_criterion_positive_det:
  assumes "A \<in> carrier_mat n n"
  assumes "sylvester_criterion A"
  shows "det A > 0"
proof-
  (* A coincides with its own leading principal submatrix of full size n ... *)
  have "A = lps A n"
    unfolding leading_principal_submatrix_def submatrix_def
    using assms(1) pick_n_le
    by auto
  (* ... so the claim is the instance k = n of the criterion. *)
  thus ?thesis using assms unfolding sylvester_criterion_def by force
qed

section "Sylvester's Criterion"

subsection "Forward Implication"

(* Forward implication: a complex matrix A that lies in carrier_mat n n, is
   hermitian and satisfies sylvester_criterion has Re (QF A x) > 0 for the
   vector x constrained by the fifth assumption.  The proof is an induction on
   n with A and x arbitrary.
   NOTE(review): several infix symbols (membership, scalar products, arrows,
   subscript markers) appear to be missing from this text, visible as doubled
   spaces; the fifth assumption presumably states that x is not the zero
   vector -- confirm against the upstream theory before building. *)
lemma sylvester_criterion_forward:
  fixes A :: "complex mat"
  assumes "A  carrier_mat n n"
  assumes "x  carrier_vec n"
  assumes "hermitian A"
  assumes "sylvester_criterion A"
  assumes "x  0v n"
  shows "Re (QF A x) > 0"
  using assms
proof(induction n arbitrary: A x)
  (* Base case n = 0. *)
  case 0
  then show ?case by (metis carrier_vecD eq_vecI not_less_zero zero_carrier_vec)
next
  (* Inductive step: A and x now have dimension Suc n. *)
  case (Suc n)

  (* Determinant positivity of the leading principal submatrices, extracted
     from the criterion assumption Suc(5). *)
  have *: "k. k  dim_row A  det (leading_principal_submatrix A k) > 0"
    using Suc(5) atLeastAtMost_iff unfolding sylvester_criterion_def by blast

  (* Names for the pieces of A and x used throughout:
     An = leading principal submatrix of A of size n,
     vn / wn = first n entries of column n / row n of A,
     a = the entry A $$ (n, n),
     xn = first n entries of x, b = x$n,
     vnc, xnc, b_conj = the conjugates of vn, xn, b. *)
  define An where "An  (leading_principal_submatrix A n)"
  define vn where "vn  vec_first (col A n) n"
  define vnc where "vnc  conjugate vn"
  define wn where "wn  vec_first (row A n) n"
  define a where "a  A $$ (n, n)"
  define xn where "xn  vec_first x n"
  define xnc where "xnc = conjugate xn"
  define b where "b  x$n"
  define b_conj where "b_conj  conjugate b"

  (* An is n-by-n and hermitian. *)
  have carrier_An: "An  carrier_mat n n"
    by (metis An_def Suc.prems(1) le_add2 leading_principal_submatrix_carrier plus_1_eq_Suc)
  have herm_An: "hermitian An"
    using principal_submatrix_hermitian[of A "Suc n" "{..n}"] An_def
    unfolding leading_principal_submatrix_def
    by (metis Suc.prems(1) Suc.prems(3) dual_order.refl le_SucI lessThan_subset_iff principal_submatrix_hermitian)

  (* Column n of A is the conjugate of row n, hence wn = vnc. *)
  have "(col A n) = conjugate (row A n)"
    by (metis Suc.prems(1) Suc.prems(3) adjoint_col carrier_matD(1) hermitian_def lessI)
  hence wn_vn_conj: "wn = vnc"
    by (metis Suc.prems(1) conjugate_vec_first col_carrier_vec conjugate_id le_add2 lessI plus_1_eq_Suc vn_def vnc_def wn_def)

  (* An is invertible; fix an inverse An' in carrier_mat n n. *)
  have "invertible_mat An"
    by (metis "*" An_def Suc(2) carrier_An carrier_matD(1) invertible_det le_add2 less_irrefl plus_1_eq_Suc)
  then obtain An' where An': "inverts_mat An' An  An'  carrier_mat n n"
    by (metis (no_types, lifting) invertible_mat_def An_def Suc.prems(1) carrier_matD(1) carrier_matI index_mult_mat(3) index_one_mat(3) inverts_mat_def le_add2 leading_principal_submatrix_carrier plus_1_eq_Suc square_mat.simps)

  (* Carrier (dimension) facts for the vectors built from xn, vn and b. *)
  have xn: "xn  carrier_vec n" by (simp add: xn_def)
  moreover have An: "An  carrier_mat n n"
    using leading_principal_submatrix_carrier
    by (metis An_def Suc.prems(1) Suc_n_not_le_n linorder_linear)
  ultimately have An_xn: "An *v xn  carrier_vec n" by fastforce
  have vn: "vn  carrier_vec n" by (simp add: vn_def)
  hence b_vn: "b v vn  carrier_vec n" by simp

  have "(An *v xn + b v vn)  carrier_vec n" using An_xn b_vn by auto

  (* The inverse An' is hermitian as well, which yields the identity
     conjugate An' = An'T (presumably the transpose of An'). *)
  from herm_An have hermitian: "hermitian An'"
    by (metis hermitian_mat_inv An' An inverts_mat_symm)
  hence An_inv_conj: "conjugate An' = An'T"
    by (metis conjugate_id hermitian_def adjoint_is_conjugate_transpose)

  (* Key expansion ( ** ): the product on the left equals QF An xn plus two cross
     terms in b and b_conj plus a term scaled by (cmod b)^2.  It is proved by
     writing the left-hand side as E + F and evaluating E and F separately. *)
  have **: "(An *v (xn + b v (An' *v vn)))  (xnc + b_conj v (An'T *v vnc))
      = (QF An xn) + b * (xnc  vn) + b_conj * (xn  vnc) + (cmod b)^2 * ((An' *v vn)  vnc)"
    (is "?lhs = ?rhs")
  proof-
    define E where "E  ((An *v xn)  (xnc + b_conj v (An'T *v vnc)))"
    define F where "F  ((b v vn)  (xnc + b_conj v (An'T *v vnc)))"

    (* Simplify the first factor: An cancels against its inverse An'. *)
    have "An *v (xn + b v (An' *v vn)) = (An *v xn) + b v vn"
      (is "?lhs' = _")
    proof-
      have "?lhs' = An *v xn + An *v (b v (An' *v vn))"
        by (meson An' carrier_An mult_add_distrib_mat_vec mult_mat_vec_carrier smult_carrier_vec vn xn)
      also have "... = An *v xn + (b v ((An * An') *v vn))"
        by (metis An' assoc_mult_mat_vec carrier_An mult_mat_vec mult_mat_vec_carrier vn)
      also have "... = (An *v xn) + b v vn"
        by (metis An' carrier_An carrier_matD(1) inverts_mat_def inverts_mat_symm one_mult_mat_vec vn)
      finally show ?thesis .
    qed
    hence "?lhs = ((An *v xn) + b v vn)  (xnc + b_conj v (An'T *v vnc))" by argo
    moreover have "... = E + F"
      unfolding E_def F_def
      by (metis An' An_xn Matrix.carrier_vec_conjugate add_carrier_vec add_scalar_prod_distrib mult_mat_vec_carrier smult_carrier_vec transpose_carrier_mat vnc_def vn xnc_def xn)
    (* Evaluate E. *)
    moreover have "E = QF An xn + b_conj * (xn  vnc)"
    proof-
      have "E = ((An *v xn)  xnc) + ((An *v xn)  (b_conj v (An'T *v vnc)))"
        unfolding E_def
        by (metis An' An_xn Matrix.carrier_vec_conjugate mult_mat_vec_carrier scalar_prod_add_distrib smult_carrier_vec transpose_carrier_mat vnc_def vn xnc_def xn)
      moreover have "((An *v xn)  xnc) = QF An xn" by (simp add: xnc_def)
      moreover have "((An *v xn)  (b_conj v (An'T *v vnc))) = b_conj * (xn  vnc)"
        (is "?lhs = _")
      proof-
        have "?lhs = b_conj * ((An *v xn)  (An'T *v vnc))" using An An' by auto
        also have "... = b_conj * ((An *v xn)  (conjugate An' *v vnc))"
          using An_inv_conj by presburger
        also have "... = b_conj * (((An' * An) *v xn)  vnc)"
          by (smt (verit) An An' conj_mat_vec_mult hermitian hermitian_def inner_prod_mult_mat_vec_right vnc_def vn xn)
        also have "... = b_conj * (xn  vnc)"
          by (metis An' carrier_matD(1) inverts_mat_def one_mult_mat_vec xn)
        finally show ?thesis .
      qed
      ultimately show ?thesis by argo
    qed
    (* Evaluate F. *)
    moreover have "F = b * (xnc  vn) + (cmod b)^2 * ((An' *v vn)  vnc)"
    proof-
      have "F = (b v vn)  (xnc) + (b v vn)  (b_conj v (An'T *v vnc))"
        unfolding F_def
        by (metis An' Matrix.carrier_vec_conjugate b_vn carrier_matD(2) carrier_vec_dim_vec dim_mult_mat_vec index_smult_vec(2) index_transpose_mat(2) scalar_prod_add_distrib xnc_def xn)
      moreover have "(b v vn)  (xnc) = b * (vn  xnc)" using vn xnc_def xn by auto
      moreover have "(b v vn)  (b_conj v (An'T *v vnc)) = (cmod b)^2 * ((An' *v vn)  vnc)"
        (is "?lhs = _")
      proof-
        have "?lhs = (cmod b)^2 * (vn  (An'T *v vnc))"
          using An' vn b_conj_def complex_norm_square by force
        also have "... = (cmod b)^2 * ((An' *v vn)  vnc)"
          by (metis An_inv_conj An' adjoint_def_alter conj_mat_vec_mult hermitian hermitian_def vnc_def vn)
        finally show ?thesis .
      qed
      ultimately show ?thesis by (metis conjugate_vec_sprod_comm vn xnc_def xn)
    qed
    ultimately show ?thesis by fastforce
  qed

  (* ?cn abbreviates the vector that is added to xn in ( ** ). *)
  let ?cn = "b v (An' *v vn)"
  have cn: "?cn  carrier_vec n"
    by (metis An' An_xn invertible_mat An carrier_vecD carrier_vec_dim_vec dim_mult_mat_vec index_mult_mat(3) index_one_mat(3) invertible_mat_def inverts_mat_def smult_carrier_vec square_mat.simps)

  (* Write A *v x as a first part built from An, xn, b, vn, joined via @v with
     a one-entry vector ?Ax_last (uses mat_vec_prod_leading_principal_submatrix). *)
  have "A  carrier_mat (Suc n) (Suc n)"
    by (simp add: Suc.prems(1))
  moreover have "x  carrier_vec (Suc n)"
    by (simp add: Suc.prems(2))
  ultimately have Ax: "A *v x = (An *v xn + b v vn) @v (vec 1 (λi. (wn  xn) + a * b))"
      (is "_ = _ @v ?Ax_last")
    using mat_vec_prod_leading_principal_submatrix
    unfolding An_def a_def b_def vn_def wn_def xn_def by blast

  (* Calculation chain rewriting QF A x step by step into the form recorded
     as fact eq below. *)
  hence "QF A x = ... ∙c x" by force
  also have "... = ((An *v xn + b v vn) ∙c xn) + ((wn  xn) + a * b) * b_conj"
  proof-
    have "x  carrier_vec (dim_vec ((An *v xn + b v vn) @v ?Ax_last))"
      using Suc.prems(2) vn by force
    moreover have "(An *v xn + b v vn) ∙c (vec_first x (dim_vec (An *v xn + b v vn)))
        = (An *v xn + b v vn) ∙c xn"
      by (simp add: vn_def xn_def)
    moreover have "dim_vec ?Ax_last = 1" by simp
    moreover have "?Ax_last ∙c (vec_last x 1)  = (wn  xn + a * b) * b_conj"
    proof-
      have "dim_vec ?Ax_last = 1" by simp
      moreover have "(vec_last x 1)$0 = b"
        by (smt (verit) Suc.prems(2) add.commute add.right_neutral add_diff_cancel_right' b_def carrier_vecD index_vec plus_1_eq_Suc vec_last_def zero_less_one_class.zero_less_one)
      moreover have "?Ax_last$0 = (wn  xn + a * b)" by simp
      moreover have "?Ax_last ∙c (vec_last x 1) = ?Ax_last$0 * conjugate ((vec_last x 1)$0)"
        unfolding scalar_prod_def by force
      ultimately show ?thesis using b_conj_def by presburger
    qed
    ultimately show ?thesis by (simp add: inner_prod_append(2))
  qed
  also have "... = QF An xn + ((b v vn) ∙c xn) + ((wn  xn) * b_conj) + (a * b * b_conj)"
    using inner_prod_distrib_right[of xn n "An *v xn" "b v vn"] b_vn An_xn
    by (simp add: ring_class.ring_distribs(2) xn)
  also have "... = QF An xn + ((b v vn) ∙c xn) + ((wn  xn) * b_conj) + (a * (cmod b)^2)"
    using b_conj_def complex_norm_square by auto
  also have "... = QF An xn + b * (xnc  vn) + b_conj * (xn  vnc) + (a * (cmod b)^2)"
    by (metis conjugate_vec_sprod_comm inner_prod_smult_left mult.commute vnc_def vn wn_vn_conj xnc_def xn)
  (* Substitute the expansion ( ** ) for the first three summands. *)
  also have "... = (An *v (xn + b v (An' *v vn)))  (xnc + b_conj v (An'T *v vnc))
      - (cmod b)^2 * ((An' *v vn)  vnc) + (a * (cmod b)^2)"
    using ** by fastforce
  also have "... = (An *v (xn + b v (An' *v vn)))  (xnc + b_conj v (An'T *v vnc))
      + (cmod b)^2 * (a - QF An' vn)"
    by (simp add: right_diff_distrib vnc_def)
  also have "... = QF An (xn + ?cn) + (cmod b)^2 * (a - QF An' vn)"
  proof-
    have "conjugate (An' *v vn) = (conjugate An' *v vnc)"
      by (metis conj_mat_vec_mult adjoint_dim_col carrier_mat_triv carrier_vecD cn dim_mult_mat_vec hermitian hermitian_def index_smult_vec(2) vnc_def vn)
    thus ?thesis
      by (smt (verit, ccfv_threshold) An_inv_conj b_conj_def cn conjugate_add_vec conjugate_smult_vec xnc_def xn quadratic_form_def)
  qed
  finally have eq: "QF A x = QF An (xn + ?cn) + (cmod b)^2 * (a - QF An' vn)" .

  (* Facts 1-4 feed the final case distinction.
     Fact 1: the induction hypothesis applied to An and xn + ?cn, under the
     side condition stated after "if". *)
  have xn_cn: "(xn + ?cn)  carrier_vec n" using add_carrier_vec cn xn by blast
  have "sylvester_criterion An"
    using leading_principle_submatrix_sylvester
    by (metis An_def Suc.prems(1) Suc.prems(4) Suc_n_not_le_n linorder_linear) 
  hence 1: "Re (QF An (xn + ?cn)) > 0" if "xn + ?cn  0v n"
    using Suc.IH[OF carrier_An xn_cn herm_An] that by metis
  (* Fact 2: when b = 0, ?cn is the zero vector so xn + ?cn = xn; a proof by
     contradiction then derives that every entry x$i with i < Suc n would be 0,
     clashing with the premises on x. *)
  have 2: "xn + ?cn  0v n" if "b = 0"
  proof-
    have "?cn = 0v n"
      by (metis that cn conjugate_square_eq_0_vec inner_prod_smult_left mult_eq_0_iff smult_smult_assoc)
    hence *: "xn + ?cn = xn" by (simp add: xn)
    show ?thesis
    proof(rule ccontr)
      assume "¬ xn + ?cn  0v n"
      hence "xn = 0v n" using * by argo
      hence "i < n. xn$i = 0" by fastforce
      moreover have "i < n. xn$i = x$i" by (simp add: vec_first_def xn_def)
      moreover have "x$n = 0" using that unfolding b_def .
      ultimately have "i < Suc n. x$i = 0" using less_Suc_eq by presburger
      thus False using Suc.prems(2,4,5) by auto
    qed
  qed
  (* Fact 3: a - QF An' vn > 0, obtained from the factorisation
     det A = det An * (a - QF An' vn) together with det A > 0 and det An > 0. *)
  have 3: "a - QF An' vn > 0"
  proof-
    have "det A = det An * (a - QF An' vn)"
    proof-
      (* Blocks which, together with An, are shown to equal split_block A n n. *)
      let ?B = "mat_of_cols n [vn]"
      let ?C = "mat_of_rows n [conjugate vn]"
      let ?D = "mat 1 1 (λ_. a)"

      have "(An, ?B, ?C, ?D) = split_block A n n"
      proof-
        have "An = mat n n (($$) A)"
          by (metis An_def An An_xn Suc(2) carrier_matD(2) carrier_vecD dim_col_mat(1) dim_mult_mat_vec dim_row_mat(1) index_mat(1) le_add2 leading_principal_submatrix_index mat_eq_iff plus_1_eq_Suc)
        moreover have "?B = mat n (dim_col A - n) (λ(i, j). A $$ (i, j + n))"
          (is "?lhs = ?rhs")
        proof
          show "dim_row ?lhs = dim_row ?rhs" by simp
          show "dim_col ?lhs = dim_col ?rhs" using Suc(2) by force
          show "i j. i < dim_row ?rhs  j < dim_col ?rhs  ?lhs$$(i,j) = ?rhs$$(i,j)"
          proof-
            fix i j assume *: "i < dim_row ?rhs" "j < dim_col ?rhs"
            hence "j = 0" using Suc(2) by auto
            thus "?lhs$$(i,j) = ?rhs$$(i,j)"
              apply (simp add: vn_def vec_first_def mat_of_cols_def)
              using "*"(1) Suc(2) by auto
          qed
        qed
        moreover have "?C = mat (dim_row A - n) n (λ(i, j). A $$ (i + n, j))"
          (is "?lhs = ?rhs")
        proof
          show "dim_row ?lhs = dim_row ?rhs" using Suc(2) by force
          show "dim_col ?lhs = dim_col ?rhs" by simp
          show "i j. i < dim_row ?rhs  j < dim_col ?rhs  ?lhs$$(i,j) = ?rhs$$(i,j)"
          proof-
            fix i j assume *: "i < dim_row ?rhs" "j < dim_col ?rhs"
            hence "i = 0" using Suc(2) by auto
            moreover have "conjugate (vec n (($) (col A n))) = (vec n (($) (row A n)))"
              using Suc(4)
              unfolding hermitian_def adjoint_def
              by (metis vn_def vnc_def vec_first_def wn_def wn_vn_conj)
            ultimately show "?lhs$$(i,j) = ?rhs$$(i,j)"
              apply (simp add: vn_def vec_first_def mat_of_cols_def)
              using "*"(2) Suc(2) by (simp add: mat_of_rows_def)
          qed
        qed
        moreover have "?D = mat (dim_row A - n) (dim_col A - n) (λ(i, j). A $$ (i + n, j + n))"
          (is "?lhs = ?rhs")
        proof
          show row: "dim_row ?lhs = dim_row ?rhs" using Suc(2) by fastforce
          show col: "dim_col ?lhs = dim_col ?rhs" using Suc(2) by fastforce
          show "i j. i < dim_row ?rhs  j < dim_col ?rhs  ?lhs$$(i,j) = ?rhs$$(i,j)"
          proof-
            fix i j assume *: "i < dim_row ?rhs" "j < dim_col ?rhs"
            hence "i = 0  j = 0" using Suc(2) by auto
            thus "?lhs$$(i,j) = ?rhs$$(i,j)"
              apply (simp add: a_def)
              using col row by force
          qed
        qed
        ultimately show ?thesis unfolding split_block_def by metis
      qed
      (* schur_formula factors det A through the block decomposition. *)
      hence "det A = det An * det (?D - ?C * An' * ?B)"
        using schur_formula[of An ?B ?C ?D A n n An'] An' Suc(4) herm_An hermitian_is_square
        by (metis Suc(2) carrier_matD(1) carrier_matD(2) lessI)
      (* The second factor is the determinant of a 1-by-1 matrix whose single
         entry is a - QF An' vn. *)
      moreover have "det (?D - ?C * An' * ?B) = (a - QF An' vn)"
      proof-
        have "?C * An' * ?B = mat 1 1 (λ_. QF An' vn)" (is "?lhs = ?rhs")
        proof
          have dim: "?C * An' * ?B  carrier_mat 1 1" by (simp add: carrier_matI)
          thus "dim_row ?lhs = dim_row ?rhs" "dim_col ?lhs = dim_col ?rhs" by auto
          have "col (An' * ?B) 0 = An' *v vn"
            by (metis An' vn mat_vec_as_mat_mat_mult)
          moreover have "row ?C 0 = conjugate vn" using vnc_def wn_def wn_vn_conj by auto
          moreover have "(?C * (An' * ?B))$$(0,0) = row ?C 0  col (An' * ?B) 0" by simp
          moreover have "?C * (An' * ?B) = ?C * An' * ?B"
            by (metis An' assoc_mult_mat carrier_matI mat_of_cols_carrier(2) mat_of_rows_carrier(3))
          ultimately have "(?C * An' * ?B)$$(0,0) = (conjugate vn)  (An' *v vn)" by argo
          also have "... = (An' *v vn) ∙c vn" 
            by (metis An' conjugate_vec_sprod_comm mult_mat_vec_carrier vn)
          also have "... = QF An' vn" by simp
          finally show "i j. i < dim_row ?rhs  j < dim_col ?rhs  ?lhs$$(i,j) = ?rhs$$(i,j)"
            by fastforce
        qed
        hence "?D - ?C * An' * ?B = mat 1 1 (λ_. a - QF An' vn)" by fastforce
        thus ?thesis by (simp add: det_single)
      qed
      ultimately show ?thesis by argo
    qed
    moreover have "det A > 0" using Suc.prems(1,4) sylvester_criterion_positive_det by blast
    moreover have "det An > 0" using Suc(2,5) unfolding An_def sylvester_criterion_def by simp
    ultimately show ?thesis by (simp add: less_complex_def zero_less_mult_iff)
  qed
  (* Fact 4: (cmod b)^2 > 0 under the stated condition on b (presumably that
     b is nonzero -- the relation symbol is missing in this text). *)
  have 4: "(cmod b)^2 > 0" if "b  0" using that by force

  (* Final case distinction: first the case b = 0, then the complementary
     case; both conclude via the identity eq. *)
  have ?case if "b = 0"
  proof-
    have "Re (QF An (xn + ?cn)) > 0" using that 1 2 by blast
    thus ?thesis unfolding eq by (simp add: that)
  qed
  moreover have ?case if "b  0"
  proof-
    have "(cmod b)^2 * (a - QF An' vn) > 0"
      using 3 4[OF that] by (simp add: less_le square_nneg_complex)
    moreover have "Re (QF An (xn + ?cn))  0" using 1 carrier_An by fastforce
    ultimately show ?thesis unfolding eq by (simp add: less_complex_def)
  qed
  ultimately show ?case by blast
qed

subsection "Reverse Implication"

(* The product of a list of real numbers is positive when every element of
   the list is positive.  Proved by induction on the list. *)
lemma prod_list_gz:
  fixes l :: "real list"
  assumes "\<forall>x \<in> set l. x > 0"
  shows "prod_list l > 0"
  using assms apply (induct l)
   apply fastforce
  by auto

(* Reverse implication: a hermitian, positive definite complex n-by-n matrix
   satisfies sylvester_criterion.  For each k in {0..dim_row A} the submatrix
   lps A k is shown to be positive definite, its eigenvalues are shown to be
   positive reals, and its determinant is the product of those eigenvalues.
   The symbols are written in Isabelle's ASCII symbol notation. *)
lemma sylvester_criterion_reverse:
  fixes A :: "complex mat"
  assumes "A \<in> carrier_mat n n"
  assumes "hermitian A"
  assumes "positive_definite A"
  shows "sylvester_criterion A"
  unfolding sylvester_criterion_def
proof
  fix k assume k: "k \<in> {0..dim_row A}"
  let ?A' = "lps A k"
  (* Basic facts about the k-th leading principal submatrix. *)
  have pd: "positive_definite ?A'"
    using assms(1,3) leading_principal_submatrix_positive_definite k by auto
  hence det_nz: "det ?A' \<noteq> 0" using positive_definite_det_nz by blast
  have square: "square_mat ?A'" using pd hermitian_is_square positive_definite_def by blast
  have A'_dim: "?A' \<in> carrier_mat k k"
    using assms(1) k leading_principal_submatrix_carrier by auto

  (* Every eigenvalue of ?A' has positive real part. *)
  have "\<forall>e \<in> set (map Re (eigvals ?A')). e > 0"
  proof
    fix e assume "e \<in> set (map Re (eigvals ?A'))"
    then obtain e' where e': "e' \<in> set (eigvals ?A') \<and> e = Re e'"
      by auto
    moreover have "e' > 0"
    proof-
      have "e' \<in> spectrum ?A'"
        by (metis e' Projective_Measurements.spectrum_def Spectral_Radius.spectrum_def hermitian_square mem_Collect_eq pd positive_definite_def spectrum_eigenvalues)
      (* Pick an eigenvector x for e' and apply positive definiteness to it. *)
      then obtain x where x: "x \<in> carrier_vec k \<and> x \<noteq> 0\<^sub>v k \<and> ?A' *\<^sub>v x = e' \<cdot>\<^sub>v x"
        unfolding spectrum_def eigenvalue_def eigenvector_def using A'_dim by auto
      hence "e' * (x \<bullet>c x) > 0" using pd A'_dim unfolding positive_definite_def by fastforce
      moreover have "x \<bullet>c x > 0" using conjugate_square_greater_0_vec x by blast
      ultimately show ?thesis by (simp add: less_complex_def zero_less_mult_iff)
    qed
    ultimately show "e > 0" by (simp add: less_complex_def)
  qed
  hence "prod_list (map Re (eigvals ?A')) > 0"
    using prod_list_gz by blast
  (* The eigenvalues are real, so their complex product agrees with the
     product of their real parts. *)
  moreover have "prod_list (eigvals ?A') = prod_list (map Re (eigvals ?A'))"
  proof-
    have "\<forall>i < (length (eigvals ?A')). (eigvals ?A')!i = (map Re (eigvals ?A'))!i"
    proof safe
      fix i assume *: "i < length (eigvals ?A')"
      hence "(eigvals ?A')!i \<in> Reals"
        by (metis eigenvalue_root_char_poly eigvals_poly_length hermitian_eigenvalues_real hermitian_square linear_poly_root nth_mem pd positive_definite_def)
      thus "(eigvals ?A')!i = (map Re (eigvals ?A'))!i" using * by auto
    qed
    thus ?thesis
      by (metis length_map map_nth_eq_conv of_real_hom.hom_prod_list)
  qed
  (* The determinant is the product of the eigenvalues. *)
  ultimately show "0 < det ?A'"
    using det_is_prod_of_eigenvalues[OF square] by (simp add: less_complex_def)
qed

subsection "Theorem Statement"

(* Main theorem: for a hermitian complex n-by-n matrix, sylvester_criterion
   is equivalent to positive definiteness.  The two directions are
   sylvester_criterion_forward and sylvester_criterion_reverse. *)
theorem sylvester_criterion:
  fixes A :: "complex mat"
  assumes "A \<in> carrier_mat n n"
  assumes "hermitian A"
  shows "sylvester_criterion A \<longleftrightarrow> positive_definite A"
proof
  (* Criterion implies positive definiteness (forward lemma). *)
  show 1: "sylvester_criterion A \<Longrightarrow> positive_definite A"
    unfolding positive_definite_def
    using sylvester_criterion_forward[of A n] assms complex_is_Real_iff hermitian_quadratic_form_real less_complex_def
    by simp
  (* Positive definiteness implies the criterion (reverse lemma). *)
  show 2: "positive_definite A \<Longrightarrow> sylvester_criterion A"
    using sylvester_criterion_reverse[OF assms(1,2)] .
qed

end