(* Theory Combinatorial_Nullstellensatz *)

theory Combinatorial_Nullstellensatz
  imports "HOL-Computational_Algebra.Polynomial"
begin

section ‹Alon's Combinatorial Nullstellensatz›

text ‹
  Alon's Combinatorial Nullstellensatz \cite{Alon1999} is a
  polynomial method theorem: if a polynomial has total degree at most
  $d_1 + \dots + d_n$ and the coefficient of the distinguished monomial
  $x_1^{d_1} \cdots x_n^{d_n}$ is nonzero, then the polynomial cannot
  vanish on every point of any grid whose $i$-th side has more than
  $d_i$ elements.

  The development below proves this statement for sparse multivariate
  polynomials over arbitrary fields.  A polynomial is represented by a
  finite-support coefficient function from exponent lists to coefficients.
  This keeps the formalization independent of a particular multivariate
  polynomial library while still matching the usual mathematical statement.
›

subsection ‹Univariate interpolation›

text ‹
  The proof starts with a small amount of univariate interpolation.  For a
  finite set ‹S› and a point ‹x ∈ S›, the following denominator and basis
  polynomial are the standard Lagrange factors.
›

(* Denominator of the Lagrange basis polynomial at the node x:
   the product of (x - y) over all nodes y of S other than x. *)
definition lagrange_denom :: "'a::field set \<Rightarrow> 'a \<Rightarrow> 'a" where
  "lagrange_denom S x = (\<Prod>y\<in>S - {x}. x - y)"

(* Lagrange basis polynomial for the node x of S: the product of the linear
   factors (X - y) for y in S - {x}, scaled by the inverse of
   lagrange_denom S x. *)
definition lagrange_basis :: "'a::field set \<Rightarrow> 'a \<Rightarrow> 'a poly" where
  "lagrange_basis S x =
     smult (inverse (lagrange_denom S x))
       (\<Prod>y\<in>S - {x}. [:- y, 1:])"

(* For a finite node set the Lagrange denominator never vanishes: every
   factor x - y ranges over y in S - {x}, hence is nonzero. *)
lemma lagrange_denom_nonzero:
  assumes "finite S" "x \<in> S"
  shows "lagrange_denom S x \<noteq> 0"
  using assms by (auto simp: lagrange_denom_def)

(* Kronecker-delta property: on the node set S, the basis polynomial for x
   evaluates to 1 at x and to 0 at every other node. *)
lemma poly_lagrange_basis:
  assumes "finite S" "x \<in> S" "z \<in> S"
  shows "poly (lagrange_basis S x) z = (if z = x then 1 else 0)"
proof (cases "z = x")
  case True
  (* At the node itself the product of linear factors is exactly the denominator. *)
  have "poly (\<Prod>y\<in>S - {x}. [:- y, 1:]) x = lagrange_denom S x"
    by (simp add: lagrange_denom_def poly_prod)
  with assms True lagrange_denom_nonzero[of S x] show ?thesis
    by (simp add: lagrange_basis_def)
next
  case False
  (* Any other node z contributes the vanishing factor (z - z). *)
  with assms have "z \<in> S - {x}"
    by simp
  then have "poly (\<Prod>y\<in>S - {x}. [:- y, 1:]) z = 0"
    using assms by (simp add: poly_prod)
  with False show ?thesis
    by (simp add: lagrange_basis_def)
qed

(* Degree bound: a Lagrange basis polynomial over S has degree at most
   card S - 1. *)
lemma degree_lagrange_basis_le:
  assumes "finite S" "x \<in> S"
  shows "degree (lagrange_basis S x) \<le> card S - 1"
proof -
  (* The degree of a product of nonzero linear factors is the sum of the degrees. *)
  have "degree (\<Prod>y\<in>S - {x}. [:- y, 1:] :: 'a poly)
      = (\<Sum>y\<in>S - {x}. degree ([:- y, 1:] :: 'a::field poly))"
    using assms by (intro degree_prod_sum_eq) auto
  also have "\<dots> \<le> card S - 1"
    using assms by auto
  finally show ?thesis
    by (simp add: lagrange_basis_def degree_smult_le dual_order.trans)
qed

(* Lagrange interpolation: a polynomial of degree below card S equals the
   combination of basis polynomials weighted by its values on S.  Both sides
   have degree < card S and agree on S, so poly_eqI_degree applies. *)
lemma lagrange_interpolation:
  fixes P :: "'a::{field,ring_no_zero_divisors} poly"
  assumes fin: "finite S" and deg: "degree P < card S"
  shows "P = (\<Sum>x\<in>S. smult (poly P x) (lagrange_basis S x))"
proof (rule poly_eqI_degree[where A = S])
  (* Agreement on every node of S. *)
  fix z
  assume z: "z \<in> S"
  have "poly (\<Sum>x\<in>S. smult (poly P x) (lagrange_basis S x)) z =
      (\<Sum>x\<in>S. poly P x * poly (lagrange_basis S x) z)"
    by (simp add: poly_sum)
  also have "\<dots> = poly P z"
    using fin z
    by (subst sum.remove[of S z]) (auto simp: poly_lagrange_basis eq_commute)
  finally show "poly P z = poly (\<Sum>x\<in>S. smult (poly P x) (lagrange_basis S x)) z"
    by simp
next
  (* Degree bound for the left-hand side. *)
  show "card S > degree P"
    using deg by simp
next
  (* Degree bound for the interpolant. *)
  have "degree (\<Sum>x\<in>S. smult (poly P x) (lagrange_basis S x)) \<le> card S - 1"
  proof (intro degree_sum_le fin)
    fix x
    assume "x \<in> S"
    then show "degree (smult (poly P x) (lagrange_basis S x)) \<le> card S - 1"
      by (meson degree_lagrange_basis_le degree_smult_le fin order_trans)
  qed
  moreover from deg have "card S > 0"
    by simp
  ultimately show "card S > degree (\<Sum>x\<in>S. smult (poly P x) (lagrange_basis S x))"
    by linarith
qed

(* The coefficient of X^(card S - 1) in the basis polynomial is the inverse
   of the Lagrange denominator, because the product of the linear factors is
   monic of degree card S - 1. *)
lemma coeff_lagrange_basis_top:
  assumes fin: "finite S" and x: "x \<in> S"
  shows "poly.coeff (lagrange_basis S x) (card S - 1) =
    inverse (lagrange_denom S x)"
proof -
  have deg_prod: "degree (\<Prod>y\<in>S - {x}. [:- y, 1:] :: 'a::field poly) = card S - 1"
  proof -
    have "degree (\<Prod>y\<in>S - {x}. [:- y, 1:] :: 'a poly) = (\<Sum>y\<in>S - {x}. degree ([:- y, 1:] :: 'a poly))"
      using fin by (intro degree_prod_sum_eq) auto
    also have "\<dots> = card S - 1"
      using fin x by simp
    finally show ?thesis .
  qed
  have "poly.coeff (lagrange_basis S x) (card S - 1) =
      inverse (lagrange_denom S x) *
        poly.coeff (\<Prod>y\<in>S - {x}. [:- y, 1:] :: 'a poly) (card S - 1)"
    by (simp add: lagrange_basis_def)
  (* card S - 1 is the degree of the product, so this is its leading coefficient. *)
  also have "\<dots> = inverse (lagrange_denom S x) *
      lead_coeff (\<Prod>y\<in>S - {x}. [:- y, 1:] :: 'a poly)"
    by (simp add: deg_prod)
  also have "\<dots> = inverse (lagrange_denom S x)"
    using fin by (simp add: lead_coeff_prod)
  finally show ?thesis .
qed

(* Power-sum identity: over a node set of size d + 1, the sum of
   x^k / lagrange_denom S x is 1 for k = d and 0 for k < d.  It is obtained
   by comparing the coefficient of X^d on both sides of the interpolation
   formula for the monomial X^k. *)
lemma lagrange_power_sum:
  fixes S :: "'a::field set"
  assumes fin: "finite S" and card: "card S = Suc d"
  assumes k: "k \<le> d"
  shows "(\<Sum>x\<in>S. x ^ k / lagrange_denom S x) = (if k = d then 1 else 0)"
proof -
  let ?X = "monom (1::'a) k"
  have interp: "?X = (\<Sum>x\<in>S. smult (poly ?X x) (lagrange_basis S x))"
    using card k by (intro lagrange_interpolation[OF fin]) (simp add: degree_monom_eq)
  have coeff_basis:
    "poly.coeff (lagrange_basis S x) d = inverse (lagrange_denom S x)" if "x \<in> S" for x
    using coeff_lagrange_basis_top[OF fin that] card by simp
  have "poly.coeff ?X d =
      poly.coeff (\<Sum>x\<in>S. smult (poly ?X x) (lagrange_basis S x)) d"
    using interp by simp
  also have "\<dots> = (\<Sum>x\<in>S. x ^ k / lagrange_denom S x)"
    using fin by (simp add: coeff_sum poly_monom coeff_basis divide_inverse)
  finally show ?thesis
    using k by simp
qed

(* List version of lagrange_power_sum: for a distinct list of d + 1 nodes,
   the list sum coincides with the set sum over set xs. *)
lemma lagrange_power_sum_list:
  fixes xs :: "'a::field list"
  assumes dist: "distinct xs" and len: "length xs = Suc d" and k: "k \<le> d"
  shows "sum_list (map (λx. x ^ k / lagrange_denom (set xs) x) xs) = (if k = d then 1 else 0)"
proof -
  have "sum_list (map (λx. x ^ k / lagrange_denom (set xs) x) xs) =
      (\<Sum>x\<in>set xs. x ^ k / lagrange_denom (set xs) x)"
    using dist by (rule sum_list_distinct_conv_sum_set)
  also have "\<dots> = (if k = d then 1 else 0)"
    using dist len k by (intro lagrange_power_sum) (auto simp: distinct_card)
  finally show ?thesis .
qed

subsection ‹Sparse multivariate polynomials›

text ‹
  Monomials are indexed by exponent lists.  The value of $[e_1, \dots, e_n]$
  at $[x_1, \dots, x_n]$ is $x_1^{e_1} \cdots x_n^{e_n}$; mismatched
  lengths evaluate to zero.  The predicate ‹sparse_poly› records
  finite support and a fixed arity.
›

(* Value of the monomial with exponent list es at the point xs: the product
   of x_i ^ e_i.  Lists of different lengths evaluate to 0 (catch-all clause). *)
fun monomial_value :: "nat list \<Rightarrow> 'a::comm_semiring_1 list \<Rightarrow> 'a" where
  "monomial_value [] [] = 1"
| "monomial_value (e # es) (x # xs) = x ^ e * monomial_value es xs"
| "monomial_value _ _ = 0"

(* Weight of a grid point: the product, coordinate by coordinate, of the
   Lagrange denominators of x_i with respect to the i-th grid side.
   Lists of different lengths get weight 1 (catch-all clause). *)
fun grid_weight :: "'a::field list list \<Rightarrow> 'a list \<Rightarrow> 'a" where
  "grid_weight [] [] = 1"
| "grid_weight (S # Ss) (x # xs) = lagrange_denom (set S) x * grid_weight Ss xs"
| "grid_weight _ _ = 1"

(* Support of a coefficient function: the exponent lists with nonzero coefficient. *)
definition support :: "(nat list \<Rightarrow> 'a::zero) \<Rightarrow> nat list set" where
  "support p = {m. p m \<noteq> 0}"

(* p is a sparse polynomial in n variables: its support is finite and every
   exponent list in the support has length n. *)
definition sparse_poly :: "nat \<Rightarrow> (nat list \<Rightarrow> 'a::zero) \<Rightarrow> bool" where
  "sparse_poly n p \<longleftrightarrow> finite (support p) \<and> (\<forall>m\<in>support p. length m = n)"

(* Every monomial in the support of p has total degree (sum of exponents) at most d. *)
definition total_degree_le :: "(nat list \<Rightarrow> 'a::zero) \<Rightarrow> nat \<Rightarrow> bool" where
  "total_degree_le p d \<longleftrightarrow> (\<forall>m\<in>support p. sum_list m \<le> d)"

(* Evaluation of a sparse polynomial at the point xs: the sum over the support
   of coefficient times monomial value. *)
definition eval_sparse_poly :: "(nat list \<Rightarrow> 'a::comm_semiring_1) \<Rightarrow> 'a list \<Rightarrow> 'a" where
  "eval_sparse_poly p xs = (\<Sum>m\<in>support p. p m * monomial_value m xs)"

(* Set-level unfolding of product_lists on a non-empty list of sides: each
   element is a head taken from xs consed onto an element of the product of
   the remaining sides. *)
lemma product_lists_set_Cons:
  "set (product_lists (xs # xss)) = (λ(x, ys). x # ys) ` (set xs × set (product_lists xss))"
  by auto

(* Summing a concatenation equals summing the sums of the individual lists.
   NOTE(review): the HOL library may already provide a lemma with this name;
   confirm whether this local copy is still needed. *)
lemma sum_list_concat:
  "sum_list (concat xss) = sum_list (map sum_list xss)"
  by (induction xss) simp_all

(* A list sum vanishes when the summand function is zero on every element. *)
lemma sum_list_map_zero:
  fixes f :: "'a \<Rightarrow> 'b::monoid_add"
  assumes "\<And>x. x \<in> set xs \<Longrightarrow> f x = 0"
  shows "sum_list (map f xs) = 0"
  using assms by (induction xs) auto

(* A sum over product_lists (xs # xss) splits into an outer sum over the
   heads in xs and an inner sum over the product of the remaining sides. *)
lemma sum_list_product_lists_Cons:
  "sum_list (map f (product_lists (xs # xss))) =
    sum_list (map (λx. sum_list (map (λys. f (x # ys)) (product_lists xss))) xs)"
  by (simp add: sum_list_concat map_concat comp_def)

(* Named restatement of the recursive clause of grid_weight. *)
lemma grid_weight_Cons:
  "grid_weight (S # Ss) (x # xs) = lagrange_denom (set S) x * grid_weight Ss xs"
  by simp

(* Named restatement of the recursive clause of monomial_value. *)
lemma monomial_value_Cons:
  "monomial_value (e # es) (x # xs) = x ^ e * monomial_value es xs"
  by simp

(* Weighted sum of the monomial with exponents es over all points of the grid
   Xss, each point's value divided by its grid_weight. *)
definition grid_monom_sum :: "'a::field list list \<Rightarrow> nat list \<Rightarrow> 'a" where
  "grid_monom_sum Xss es =
    sum_list (map (λxs. monomial_value es xs / grid_weight Xss xs) (product_lists Xss))"

(* Factorisation of the weighted grid sum: the sum over (Xs # Xss) for the
   exponents (e # es) is the univariate weighted power sum over Xs times the
   grid sum over the remaining sides. *)
lemma grid_monom_sum_Cons:
  assumes dist: "distinct Xs"
  shows "grid_monom_sum (Xs # Xss) (e # es) =
    sum_list (map (λx. x ^ e / lagrange_denom (set Xs) x) Xs) *
      grid_monom_sum Xss es"
proof -
  have denom_nz: "lagrange_denom (set Xs) x \<noteq> 0" if "x \<in> set Xs" for x
    using that by (intro lagrange_denom_nonzero) auto
  (* For a fixed head x, the head factor can be pulled out of the inner sum. *)
  have inner:
    "sum_list (map (λxs. monomial_value (e # es) (x # xs) /
        grid_weight (Xs # Xss) (x # xs)) (product_lists Xss)) =
      (x ^ e / lagrange_denom (set Xs) x) *
        sum_list (map (λxs. monomial_value es xs / grid_weight Xss xs)
          (product_lists Xss))"
    if "x \<in> set Xs" for x
  proof -
    have term_eq:
      "\<And>xs. monomial_value (e # es) (x # xs) / grid_weight (Xs # Xss) (x # xs) =
        (x ^ e / lagrange_denom (set Xs) x) *
          (monomial_value es xs / grid_weight Xss xs)"
      using denom_nz[OF that]
      by (simp add: divide_inverse ac_simps)
    have "sum_list (map (λxs. monomial_value (e # es) (x # xs) /
        grid_weight (Xs # Xss) (x # xs)) (product_lists Xss)) =
      sum_list (map (λxs. (x ^ e / lagrange_denom (set Xs) x) *
        (monomial_value es xs / grid_weight Xss xs)) (product_lists Xss))"
      by (simp add: term_eq)
    also have "\<dots> =
      (x ^ e / lagrange_denom (set Xs) x) *
        sum_list (map (λxs. monomial_value es xs / grid_weight Xss xs) (product_lists Xss))"
      by (rule sum_list_const_mult)
    finally show ?thesis .
  qed
  (* Lift the pointwise identity to the mapped lists over Xs. *)
  have mapped:
    "map (λx. sum_list (map (λxs. monomial_value (e # es) (x # xs) /
        grid_weight (Xs # Xss) (x # xs)) (product_lists Xss))) Xs =
     map (λx. (x ^ e / lagrange_denom (set Xs) x) *
        sum_list (map (λxs. monomial_value es xs / grid_weight Xss xs)
          (product_lists Xss))) Xs"
    using inner by auto
  have "grid_monom_sum (Xs # Xss) (e # es) =
      sum_list (map (λx. sum_list (map (λxs. monomial_value (e # es) (x # xs) /
        grid_weight (Xs # Xss) (x # xs)) (product_lists Xss))) Xs)"
    unfolding grid_monom_sum_def by (rule sum_list_product_lists_Cons)
  also have "\<dots> =
      sum_list (map (λx. x ^ e / lagrange_denom (set Xs) x) Xs) *
        grid_monom_sum Xss es"
    using grid_monom_sum_def[of Xss es] mapped
          sum_list_mult_const[of _ "grid_monom_sum Xss es" Xs]
    by presburger
  finally show ?thesis .
qed

(* Base case: for the empty grid and the empty exponent list the weighted
   sum is 1.  Declared as a simp rule. *)
lemma grid_monom_sum_Nil [simp]:
  "grid_monom_sum [] [] = 1"
  by (simp add: grid_monom_sum_def)

(* A monomial evaluates to 0 whenever the exponent list and the point have
   different lengths. *)
lemma monomial_value_wrong_length:
  "length es \<noteq> length xs \<Longrightarrow> monomial_value es xs = 0"
proof (induction es arbitrary: xs)
  case Nil
  then show ?case
    by (cases xs) simp_all
next
  case (Cons e es xs)
  then show ?case
    by (cases xs) auto
qed

(* If the exponent list has a different length than the grid, every grid
   point yields monomial value 0, so the weighted grid sum is 0. *)
lemma grid_monom_sum_wrong_length:
  assumes "length es \<noteq> length Xss"
    shows "grid_monom_sum Xss es = 0"
proof -
  have zero: "monomial_value es xs = 0" if "xs \<in> set (product_lists Xss)" for xs
    by (metis in_set_product_lists_length assms monomial_value_wrong_length that)
  show ?thesis
    unfolding grid_monom_sum_def
    by (rule sum_list_map_zero) (metis divide_eq_0_iff zero)
qed

(* Delta property of the weighted grid sum: if the i-th side has exactly
   d_i + 1 distinct points and the exponent list es has the right length and
   total degree at most sum_list ds, then the sum is 1 for es = ds and 0
   otherwise.  The proof is by induction on the grid, comparing the head
   exponent e with the head bound d. *)
lemma grid_monom_sum_delta:
  fixes Xss :: "'a::field list list"
  assumes grids: "list_all2 (λXs d. distinct Xs \<and> length Xs = Suc d) Xss ds"
  assumes len: "length es = length ds"
  assumes deg: "sum_list es \<le> sum_list ds"
  shows "grid_monom_sum Xss es = (if es = ds then 1 else 0)"
  using grids len deg
proof (induction Xss arbitrary: ds es)
  case Nil
  then show ?case
    by (cases ds; cases es) simp_all
next
  case (Cons X Xss ds es)
  then obtain d ds' where ds: "ds = d # ds'"
    by (cases ds) auto
  then obtain e es' where es: "es = e # es'"
    using Cons.prems by (cases es) auto
  from Cons.prems ds es have X: "distinct X" "length X = Suc d"
    by auto
  from Cons.prems ds es have grids_tail:
    "list_all2 (λXs d. distinct Xs \<and> length Xs = Suc d) Xss ds'"
    and len_tail: "length es' = length ds'"
    and deg_all: "e + sum_list es' \<le> d + sum_list ds'"
    by auto
  show ?case
  proof (cases e d rule: linorder_cases)
    case less
    (* e < d: the head factor already vanishes. *)
    have head_delta: "sum_list (map (λx. x ^ e / lagrange_denom (set X) x) X) =
        (if e = d then 1 else 0)"
      using X less by (intro lagrange_power_sum_list) auto
    then have head: "sum_list (map (λx. x ^ e / lagrange_denom (set X) x) X) = 0"
      using less by simp
    show ?thesis
      using ds es less head X by (simp add: grid_monom_sum_Cons)
  next
    case equal
    (* e = d: the head factor is 1 and the induction hypothesis handles the tail. *)
    have head_delta: "sum_list (map (λx. x ^ e / lagrange_denom (set X) x) X) =
        (if e = d then 1 else 0)"
      using X equal by (intro lagrange_power_sum_list) auto
    then have head: "sum_list (map (λx. x ^ e / lagrange_denom (set X) x) X) = 1"
      using equal by simp
    have tail_deg: "sum_list es' \<le> sum_list ds'"
      using deg_all equal by simp
    have tail: "grid_monom_sum Xss es' = (if es' = ds' then 1 else 0)"
      by (rule Cons.IH) (use grids_tail len_tail tail_deg in auto)
    show ?thesis
      using ds es equal head tail X by (simp add: grid_monom_sum_Cons)
  next
    case greater
    (* e > d: the degree bound forces the tail exponents to differ from ds',
       so the tail factor vanishes by the induction hypothesis. *)
    have tail_deg: "sum_list es' \<le> sum_list ds'"
      using deg_all greater by simp
    have tail_ne: "es' \<noteq> ds'"
    proof
      assume "es' = ds'"
      with deg_all greater show False
        by simp
    qed
    have tail: "grid_monom_sum Xss es' = 0"
      using Cons.IH[OF grids_tail len_tail tail_deg] tail_ne by auto
    show ?thesis
      using ds es greater tail X by (simp add: grid_monom_sum_Cons)
  qed
qed

text ‹
  The next lemma is the coefficient formula specialized to the sparse
  representation: the weighted sum of all grid evaluations extracts exactly
  the coefficient of the target exponent list.
›

(* Interchange of a list sum with a finite set sum. *)
lemma sum_list_sum:
  fixes f :: "'b \<Rightarrow> 'c \<Rightarrow> 'a::comm_monoid_add"
  assumes "finite A"
  shows "sum_list (map (λx. \<Sum>a\<in>A. f x a) xs) =
    (\<Sum>a\<in>A. sum_list (map (λx. f x a) xs))"
  using assms by (induction xs) (simp_all add: sum.distrib)

(* Coefficient formula: for a sparse polynomial p of arity length ds and total
   degree at most sum_list ds, the weighted sum of its evaluations over a grid
   whose i-th side has exactly d_i + 1 distinct points equals the coefficient
   p ds. *)
lemma eval_sparse_poly_grid_sum:
  fixes p :: "nat list \<Rightarrow> 'a::field"
  assumes sp: "sparse_poly (length ds) p"
  assumes grids: "list_all2 (λXs d. distinct Xs \<and> length Xs = Suc d) Xss ds"
  assumes deg: "total_degree_le p (sum_list ds)"
  shows "sum_list (map (λxs. eval_sparse_poly p xs / grid_weight Xss xs) (product_lists Xss)) =
    p ds"
proof -
  have fin: "finite (support p)"
    using sp by (simp add: sparse_poly_def)
  have len_Xss: "length Xss = length ds"
    using grids by (simp add: list_all2_lengthD)
  (* Expand the evaluation and push the division inside the monomial sum. *)
  have "sum_list (map (λxs. eval_sparse_poly p xs / grid_weight Xss xs) (product_lists Xss)) =
      sum_list (map (λxs. \<Sum>m\<in>support p.
        p m * (monomial_value m xs / grid_weight Xss xs)) (product_lists Xss))"
    by (simp add: eval_sparse_poly_def sum_divide_distrib sum_distrib_left mult.assoc)
  (* Swap the grid sum with the finite sum over the support. *)
  also have "\<dots> =
      (\<Sum>m\<in>support p.
        sum_list (map (λxs. p m * (monomial_value m xs / grid_weight Xss xs))
          (product_lists Xss)))"
    using fin by (simp add: sum_list_sum)
  also have "\<dots> = (\<Sum>m\<in>support p. p m * grid_monom_sum Xss m)"
    by (intro sum.cong refl)
      (simp add: grid_monom_sum_def sum_list_const_mult divide_inverse ac_simps)
  also have "\<dots> = p ds"
  proof -
    (* Each monomial in the support contributes 1 if it is ds and 0 otherwise. *)
    have delta: "grid_monom_sum Xss m = (if m = ds then 1 else 0)" if "m \<in> support p" for m
    proof -
      have "length m = length ds"
        using sp that by (simp add: sparse_poly_def)
      moreover have "sum_list m \<le> sum_list ds"
        using deg that by (simp add: total_degree_le_def)
      ultimately show ?thesis
        by (rule grid_monom_sum_delta[OF grids])
    qed
    show ?thesis
    proof (cases "ds \<in> support p")
      case True
      have "(\<Sum>m\<in>support p. p m * grid_monom_sum Xss m) =
          p ds * grid_monom_sum Xss ds"
        using fin True delta by (subst sum.remove[of "support p" ds]) auto
      also have "\<dots> = p ds"
        using delta True by simp
      finally show ?thesis .
    next
      case False
      then have p0: "p ds = 0"
        by (simp add: support_def)
      have "(\<Sum>m\<in>support p. p m * grid_monom_sum Xss m) = 0"
        using fin False delta by (intro sum.neutral) auto
      with p0 show ?thesis
        by simp
    qed
  qed
  finally show ?thesis .
qed

(* Combinatorial Nullstellensatz for grids whose i-th side has exactly
   d_i + 1 distinct points: if p ds is nonzero and p has total degree at most
   sum_list ds, some grid point is not a zero of p.  Otherwise the weighted
   grid sum would be 0, contradicting the coefficient formula. *)
theorem combinatorial_nullstellensatz_exact_lists:
  fixes p :: "nat list \<Rightarrow> 'a::field"
  assumes sp: "sparse_poly (length ds) p"
  assumes deg: "total_degree_le p (sum_list ds)"
  assumes coeff: "p ds \<noteq> 0"
  assumes grids: "list_all2 (λXs d. distinct Xs \<and> length Xs = Suc d) Xss ds"
  shows "\<exists>xs\<in>set (product_lists Xss). eval_sparse_poly p xs \<noteq> 0"
proof (rule ccontr)
  assume "¬ ?thesis"
  then have zero: "\<And>xs. xs \<in> set (product_lists Xss) \<Longrightarrow> eval_sparse_poly p xs = 0"
    by auto
  have "sum_list (map (λxs. eval_sparse_poly p xs / grid_weight Xss xs) (product_lists Xss)) = 0"
    by (rule sum_list_map_zero) (metis divide_eq_0_iff zero)
  moreover have "sum_list (map (λxs. eval_sparse_poly p xs / grid_weight Xss xs) (product_lists Xss)) =
      p ds"
    by (rule eval_sparse_poly_grid_sum[OF sp grids deg])
  ultimately show False
    using coeff by simp
qed

text ‹
  Finally, the standard ``more than $d_i$ points'' formulation follows by
  selecting $d_i + 1$ points from each side of the grid.
›

(* Truncate each grid side to its first d_i + 1 elements; sides are paired
   with bounds via zip. *)
definition exact_grid_sublists :: "'a list list \<Rightarrow> nat list \<Rightarrow> 'a list list" where
  "exact_grid_sublists Xss ds = map (λ(Xs, d). take (Suc d) Xs) (zip Xss ds)"

(* Truncating distinct sides with more than d_i elements yields distinct
   sides with exactly d_i + 1 elements. *)
lemma exact_grid_sublists_all2:
  assumes "list_all2 (λXs d. distinct Xs \<and> length Xs > d) Xss ds"
  shows "list_all2 (λXs d. distinct Xs \<and> length Xs = Suc d)
    (exact_grid_sublists Xss ds) ds"
  using assms
proof (induction Xss arbitrary: ds)
  case Nil
  then show ?case
    by (cases ds) (simp_all add: exact_grid_sublists_def)
next
  case (Cons X Xss ds)
  then obtain d ds' where ds: "ds = d # ds'"
    by (cases ds) auto
  with Cons.prems have "distinct (take (Suc d) X)" "length (take (Suc d) X) = Suc d"
    by auto
  moreover have "list_all2 (λXs d. distinct Xs \<and> length Xs = Suc d)
      (exact_grid_sublists Xss ds') ds'"
    using Cons.IH Cons.prems ds by auto
  ultimately show ?case
    by (simp add: exact_grid_sublists_def ds)
qed

(* Every point of the truncated grid is a point of the original grid. *)
lemma exact_grid_sublists_subset:
  assumes "list_all2 (λXs d. length Xs > d) Xss ds"
  shows "set (product_lists (exact_grid_sublists Xss ds)) \<subseteq> set (product_lists Xss)"
  using assms
proof (induction Xss arbitrary: ds)
  case Nil
  then show ?case
    by (cases ds) (simp_all add: exact_grid_sublists_def)
next
  case (Cons X Xss ds)
  then obtain d ds' where ds: "ds = d # ds'"
    by (cases ds) auto
  have tail: "set (product_lists (exact_grid_sublists Xss ds')) \<subseteq> set (product_lists Xss)"
    using Cons.IH Cons.prems ds by auto
  show ?case
  proof
    fix xs
    assume "xs \<in> set (product_lists (exact_grid_sublists (X # Xss) ds))"
    (* Split the point into its head (from the truncated first side) and its tail. *)
    then obtain x ys where xs: "xs = x # ys"
      and x: "x \<in> set (take (Suc d) X)"
      and ys: "ys \<in> set (product_lists (exact_grid_sublists Xss ds'))"
      by (auto simp: exact_grid_sublists_def ds)
    from x have "x \<in> set X"
      by (rule in_set_takeD)
    moreover from tail ys have "ys \<in> set (product_lists Xss)"
      by auto
    ultimately show "xs \<in> set (product_lists (X # Xss))"
      using xs by auto
  qed
qed

(* Combinatorial Nullstellensatz, standard form: if p ds is nonzero, p has
   total degree at most sum_list ds, and the i-th grid side is distinct with
   more than d_i elements, then p does not vanish on the whole grid.  The
   proof reduces to the exact-size theorem on the truncated grid. *)
theorem combinatorial_nullstellensatz_lists:
  fixes p :: "nat list \<Rightarrow> 'a::field"
  assumes sp: "sparse_poly (length ds) p"
  assumes deg: "total_degree_le p (sum_list ds)"
  assumes coeff: "p ds \<noteq> 0"
  assumes grids: "list_all2 (λXs d. distinct Xs \<and> length Xs > d) Xss ds"
  shows "\<exists>xs\<in>set (product_lists Xss). eval_sparse_poly p xs \<noteq> 0"
proof -
  have exact: "list_all2 (λXs d. distinct Xs \<and> length Xs = Suc d)
      (exact_grid_sublists Xss ds) ds"
    by (rule exact_grid_sublists_all2[OF grids])
  (* A non-root exists on the truncated grid ... *)
  obtain xs where xs: "xs \<in> set (product_lists (exact_grid_sublists Xss ds))"
    and nz: "eval_sparse_poly p xs \<noteq> 0"
    using combinatorial_nullstellensatz_exact_lists[OF sp deg coeff exact] by blast
  have lengths: "list_all2 (λXs d. length Xs > d) Xss ds"
    by (rule list_all2_mono[OF grids]) auto
  (* ... and that point also lies in the full grid. *)
  have "xs \<in> set (product_lists Xss)"
    using exact_grid_sublists_subset[OF lengths] xs by auto
  with nz show ?thesis
    by blast
qed

end