(* Theory Tensor_Unit_Vec: unit vectors as tensors, and the decomposition of a tensor along its first index *)
section ‹Unit Vectors as Tensors›
theory Tensor_Unit_Vec
  imports Tensor_Product
begin
text ‹unit_vec n i is the tensor with dimensions [n] built from a lookup function
  that returns 1 at the index [i] and 0 at every other index.›
definition unit_vec :: "nat ⇒ nat ⇒ 'a::ring_1 tensor"
  where "unit_vec n i = tensor_from_lookup [n] (λxs. if xs = [i] then 1 else 0)"
text ‹A unit vector built with length n has the single dimension n.›
lemma dims_unit_vec: "dims (unit_vec n i) = [n]"
  by (simp add: unit_vec_def tensor_from_lookup_def)
text ‹At a valid index j < n, unit_vec n i holds 1 exactly when i = j and 0 otherwise.›
lemma lookup_unit_vec:
  assumes "j<n"
  shows "lookup (unit_vec n i) [j] = (if i=j then 1 else 0)"
proof -
  (* [j] is a valid index for dimensions [n], which lets lookup_tensor_from_lookup
     hand back the generating lookup function applied to [j] *)
  have "[j] ⊲ [n]" by (simp add: assms valid_index.Cons valid_index.Nil)
  then have "lookup (unit_vec n i) [j] = (λx. if x=[i] then 1 else 0) [j]"
    by (simp add: lookup_tensor_from_lookup unit_vec_def)
  then show ?thesis by auto
qed
text ‹Slicing unit_vec n i ⊗ A at a valid first index j yields A itself when i = j,
  and the zero tensor with the dimensions of A otherwise.›
lemma subtensor_prod_with_unit_vec:
  fixes A::"'a::ring_1 tensor"
  assumes "j<n"
  shows "subtensor (unit_vec n i ⊗ A) j = (if i=j then A else (tensor0 (dims A)))"
proof -
  (* entry j of the unit vector; reuse lookup_unit_vec rather than re-deriving it *)
  have 0:"lookup (unit_vec n i) [j] = (if i=j then 1 else 0)"
    by (rule lookup_unit_vec[OF assms])
  (* the unit vector is an order-1 tensor, so subtensor_prod_with_vec applies *)
  have 1:"order (unit_vec n i) = 1" unfolding unit_vec_def by (simp add: tensor_from_lookup_def)
  from assms have 2:"j < hd (dims (tensor_from_lookup [n] (λx. if x = [i] then 1 else 0)))"
    by (simp add: dims_tensor_from_lookup)
  (* the slice is (entry j) ⋅ A, i.e. 1 ⋅ A = A or 0 ⋅ A = tensor0 (dims A) *)
  show ?thesis using unit_vec_def subtensor_prod_with_vec 1 2 0 smult_1 tensor_smult0
    by (metis (no_types, lifting) tensor_from_lookup_eqI)
qed
text ‹A tensor with non-empty dims equals the listsum, over every first index i,
  of unit_vec (hd (dims A)) i ⊗ subtensor A i.›
lemma subtensor_decomposition:
  assumes "dims A ≠ []"
  shows "listsum (dims A) (map (λi. unit_vec (hd (dims A)) i ⊗ subtensor A i) [0..<hd (dims A)]) = A"
    (is "?LS = A")
proof -
  let ?f = "λi. unit_vec (hd (dims A)) i ⊗ subtensor A i"
  (* every summand has the dimensions of A; needed by subtensor_listsum and listsum_dims below *)
  have correct_dims:"⋀B. B ∈ set (map ?f [0..<hd (dims A)]) ⟹ dims B = dims A"
  proof -
    fix B
    assume "B ∈ set (map ?f [0..<hd (dims A)])"
    then obtain i where B:"B = ?f i" and "i<hd (dims A)" by auto
    then have "dims (subtensor A i) = tl (dims A)" using dims_subtensor using assms by blast
    then show "dims B = dims A" unfolding B
      by (metis append_Cons assms dims_tensor_prod dims_unit_vec list.exhaust_sel self_append_conv2)
  qed
  (* both sides agree on every slice along the first index *)
  have "⋀j. j < hd (dims A) ⟹ subtensor ?LS j = subtensor A j"
  proof -
    fix j
    assume "j < hd (dims A)"
    (* the j-th slice of the sum is the sum of the j-th slices *)
    have 1:"subtensor ?LS j = listsum (tl (dims A)) (map (λX. subtensor X j) (map ?f [0..<hd (dims A)]))"
      using subtensor_listsum[of "map ?f [0..<hd (dims A)]" "dims A" j, OF correct_dims assms ‹j < hd (dims A)›]
      by linarith
    also have "... = listsum (tl (dims A)) (map (λi. subtensor (?f i) j) [0..<hd (dims A)])"
    proof -
      have "map (λX. subtensor X j) (map ?f [0..<hd (dims A)]) = map (λi. subtensor (?f i) j) [0..<hd (dims A)]"
        unfolding map_map[of "λX. subtensor X j" "?f" "[0..<hd (dims A)]"] by simp
      with 1 show ?thesis by metis
    qed
    (* after rewriting with subtensor_prod_with_unit_vec, every summand except the j-th
       is a zero tensor, so listsum_all_0_but_one picks out the j-th element *)
    also have "... = map (λi. if i = j then subtensor A i else tensor0 (dims (subtensor A i))) [0..<hd (dims A)] ! j"
      unfolding subtensor_prod_with_unit_vec[OF ‹j < hd (dims A)›]
      using listsum_all_0_but_one[of j "(map (λi. if i = j then subtensor A i else tensor0 (dims (subtensor A i))) [0..<hd (dims A)])" "tl (dims A)"]
      by (simp add: ‹j < hd (dims A)› assms)
    also have "... = subtensor A j" by (simp add: ‹j < hd (dims A)›)
    finally show "subtensor ?LS j = subtensor A j" by auto
  qed
  moreover have "dims ?LS = dims A" using correct_dims listsum_dims by blast
  (* equal dims and equal slices give equal tensors *)
  ultimately show ?thesis using subtensor_eqI by (metis (no_types, lifting) assms)
qed