Theory DL_Shallow_Model
section ‹Shallow Network Model›
theory DL_Shallow_Model
imports DL_Network Tensor_Rank
begin
fun shallow_model' where
"shallow_model' Z M 0 = Conv (Z,M) (Input M)" |
"shallow_model' Z M (Suc N) = Pool (shallow_model' Z M 0) (shallow_model' Z M N)"
definition shallow_model where
"shallow_model Y Z M N = Conv (Y,Z) (shallow_model' Z M N)"
lemma valid_shallow_model': "valid_net (shallow_model' Z M N)"
apply (induction N) unfolding shallow_model'.simps
by (simp add: valid_net.intros, metis shallow_model'.elims shallow_model'.simps(1) valid_net.intros output_size.simps)
lemma output_size_shallow_model': "output_size (shallow_model' Z M N) = Z"
apply (induction N) unfolding shallow_model'.simps using output_size.simps by simp_all
lemma valid_shallow_model: "valid_net (shallow_model Y Z M N)"
unfolding shallow_model_def using valid_shallow_model' valid_net.intros output_size.simps output_size_shallow_model' by metis
lemma output_size_shallow_model: "output_size (shallow_model Y Z M N) = Y"
unfolding shallow_model_def using output_size_shallow_model' output_size.simps by simp
lemma input_sizes_shallow_model: "input_sizes (shallow_model Y Z M N) = replicate (Suc N) M"
apply (induction N) unfolding shallow_model_def input_sizes.simps by simp_all
lemma balanced_net_shallow_model': "balanced_net (shallow_model' Z M N)"
proof(induction N)
case 0
then show ?case
by (metis balanced_net.simps shallow_model'.simps(1))
next
case (Suc N)
have "count_weights True (Conv (Z, M) (Input M)) = count_weights True (shallow_model' Z M N)"
by (induction N; simp)
then show ?case unfolding shallow_model'.simps
by (simp add: Suc.IH balanced_net_Conv balanced_net_Input balanced_net_Pool)
qed
lemma balanced_net_shallow_model: "balanced_net (shallow_model Y Z M N)"
unfolding shallow_model_def
by (simp add: balanced_net_Conv balanced_net_shallow_model')
lemma cprank_max1_shallow_model':
assumes "y < output_size (shallow_model' Z M N)"
shows "cprank_max1 (tensors_from_net (insert_weights s (shallow_model' Z M N) w) $ y)"
using assms proof (induction N arbitrary:w)
case 0
then have "input_sizes (insert_weights s (shallow_model' Z M 0) w) = [M]"
unfolding shallow_model_def shallow_model'.simps insert_weights.simps
input_sizes.simps by metis
then have "dims (tensors_from_net (insert_weights s (shallow_model' Z M 0) w) $ y) = [M]"
using dims_tensors_from_net[OF vec_setI] "0.prems"(1) output_size_correct_tensors
remove_insert_weights valid_shallow_model' by metis
then show ?case
using order1 by (metis One_nat_def eq_imp_le length_Cons list.size(3))
next
case (Suc N)
have y_le_IH:"y < dim_vec (tensors_from_net (insert_weights s (shallow_model' Z M N) (λi. w (i + (count_weights s (shallow_model' Z M 0))))))"
using output_size_correct_tensors[of "insert_weights s (shallow_model' Z M N) (λi. w (i + (count_weights s (shallow_model' Z M 0))))",
unfolded remove_insert_weights, OF valid_shallow_model']
using Suc.prems(1) output_size_shallow_model' by auto
have cprank_max1_IH:"cprank_max1 (tensors_from_net (insert_weights s (shallow_model' Z M N) (λi. w (i + (count_weights s (shallow_model' Z M 0))))) $ y)"
using Suc.IH Suc.prems(1) output_size_shallow_model' by auto
have y_le_0:"y < dim_vec (tensors_from_net (insert_weights s (shallow_model' Z M 0) w))"
by (metis assms output_size_correct_tensors output_size_shallow_model' remove_insert_weights valid_shallow_model')
have cprank_max1_0:"cprank_max1 (tensors_from_net (insert_weights s (shallow_model' Z M 0) w) $ y)"
proof -
have "input_sizes (insert_weights s (shallow_model' Z M 0) w) = [M]"
unfolding shallow_model_def shallow_model'.simps insert_weights.simps
input_sizes.simps by metis
then show ?thesis using order1 dims_tensors_from_net[OF vec_setI] One_nat_def eq_imp_le length_Cons list.size(3) y_le_0 by metis
qed
then show ?case unfolding shallow_model'.simps(2) insert_weights.simps tensors_from_net.simps
using cprank_max1_IH cprank_max1_0 cprank_max1_prod index_component_mult y_le_0 y_le_IH
by (metis Suc.IH output_size_correct_tensors remove_insert_weights valid_shallow_model')
qed
lemma cprank_shallow_model:
assumes "m = insert_weights s (shallow_model Y Z M N) w"
assumes "y < Y"
shows "cprank (tensors_from_net m $ y) ≤ Z"
proof -
have "s ⟹ shared_weight_net m"
by (simp add: assms(1) balanced_net_shallow_model shared_weight_net_insert_weights)
have "cprank_max Z (tensors_from_net m $ y)"
proof -
have dim_extract: "dim_row (extract_matrix w Y Z) = Y"
using dim_extract_matrix(1) by force
have dimc_extract_matrix: "dim_col (extract_matrix w Y Z) = Z"
using dim_extract_matrix(2) by force
have input_sizes: "(input_sizes (insert_weights s (shallow_model' Z M N) (λi. w (i + Y * Z)))) = (input_sizes (shallow_model' Z M N))"
using input_sizes_remove_weights remove_insert_weights by auto
have 0:"tensors_from_net m $ y = Tensor_Plus.listsum (input_sizes (shallow_model' Z M N))
(map (λj. (extract_matrix w Y Z) $$ (y, j) ⋅ (tensors_from_net (insert_weights s (shallow_model' Z M N) (λi. w (i + Y * Z)))) $ j) [0..<Z])"
unfolding ‹m = insert_weights s (shallow_model Y Z M N) w› shallow_model_def insert_weights.simps tensors_from_net.simps
using nth_mat_tensorlist_mult dims_tensors_from_net assms(2) dim_extract output_size_correct_tensors[of "insert_weights s (shallow_model' Z M N) (λi. w (i + Y * Z))", unfolded remove_insert_weights, OF valid_shallow_model']
dimc_extract_matrix output_size_shallow_model' input_sizes by auto
define Bs where "Bs = map (λj. extract_matrix w Y Z $$ (y, j) ⋅ tensors_from_net (insert_weights s (shallow_model' Z M N) (λi. w (i + Y * Z))) $ j) [0..<Z]"
have "⋀B. B ∈ set Bs ⟹ cprank_max1 B" "⋀B. B ∈ set Bs ⟹ dims B = input_sizes (shallow_model' Z M N)"
proof -
fix B assume "B ∈ set Bs"
then obtain j where "B = Bs ! j" "j < length Bs" by (metis in_set_conv_nth)
then have "j < Z" using length_map Bs_def by simp
have 1:"cprank_max1 (tensors_from_net (insert_weights s (shallow_model' Z M N) (λi. w (i + Y * Z))) $ j)"
using ‹j < Z› output_size_shallow_model' cprank_max1_shallow_model' by auto
then have "cprank_max1 (extract_matrix w Y Z $$ (y, j) ⋅ tensors_from_net (insert_weights s (shallow_model' Z M N) (λi. w (i + Y * Z))) $ j)"
using smult_prod_extract1 cprank_max1_order0[OF 1, of "extract_matrix w Y Z $$ (y, j) ⋅ 1"]
by (metis dims_smult mult.left_neutral order_tensor_one)
then show "cprank_max1 B" by (simp add: Bs_def ‹B = Bs ! j› ‹j < Z›)
show "dims B = input_sizes (shallow_model' Z M N)" unfolding ‹B = Bs ! j› Bs_def
nth_map[of j "[0..<Z]", unfolded length_upt Nat.diff_0, OF ‹j < Z›] dims_smult
input_sizes[symmetric]
by (rule dims_tensors_from_net; rule vec_setI[where i=j], simp add:‹j < Z›, metis (no_types) ‹j < Z› output_size_correct_tensors output_size_shallow_model' remove_insert_weights valid_shallow_model')
qed
then show ?thesis unfolding 0 using cprank_maxI length_map Bs_def by (metis (no_types, lifting) diff_zero length_upt)
qed
then show ?thesis unfolding cprank_def by (simp add: Least_le)
qed
end