section ‹Matrices›
(* Interface for matrices modelled as total functions from index pairs
   (nat×nat) to entries, plus generic pointwise operations (unary, binary,
   comparison) whose implementations are synthesized from abstract get/set
   operations supplied as locale parameters. *)
theory IICF_Matrix
imports "../../Sepref"
begin

  subsection ‹Relator and Interface›

  (* Relator for matrices: index pairs are related by nat_rel on both
     components, entries by A. *)
  definition [to_relAPP]: "mtx_rel A ≡ nat_rel ×⇩r nat_rel → A"

  (* With the identity relation on entries, the matrix relator is the identity. *)
  lemma mtx_rel_id[simp]: "⟨Id⟩mtx_rel = Id" unfolding mtx_rel_def by auto

  (* A matrix is a function from (row, column)-style index pairs to entries.
     The type itself carries no dimensions. *)
  type_synonym 'a mtx = "nat×nat ⇒ 'a"
  sepref_decl_intf 'a i_mtx is "nat×nat ⇒ 'a"

  (* Interface synthesis: mtx_rel over an A with interface 'a has interface 'a i_mtx. *)
  lemma [synth_rules]: "INTF_OF_REL A TYPE('a) ⟹ INTF_OF_REL (⟨A⟩mtx_rel) TYPE('a i_mtx)"
    by simp

  subsection ‹Operations›

  (* Creating a matrix from a function: abstractly the identity. *)
  definition op_mtx_new :: "'a mtx ⇒ 'a mtx" where [simp]: "op_mtx_new c ≡ c"

  sepref_decl_op (no_def) mtx_new: "op_mtx_new" :: "(nat_rel×⇩rnat_rel → A) → ⟨A⟩mtx_rel"
    apply (rule fref_ncI) unfolding op_mtx_new_def[abs_def] mtx_rel_def
    by parametricity

  (* TODO: Ad-hoc rule *)
  lemma mtx_init_adhoc_frame_match_rule[sepref_frame_match_rules]:
    "hn_val (nat_rel×⇩rnat_rel → A) x y ⟹⇩t hn_val (nat_rel×⇩rnat_rel → the_pure (pure A)) x y"
    by simp

  (* Copying a matrix: abstractly the identity. *)
  definition op_mtx_copy :: "'a mtx ⇒ 'a mtx" where [simp]: "op_mtx_copy c ≡ c"

  sepref_decl_op (no_def) mtx_copy: "op_mtx_copy" :: "⟨A⟩mtx_rel → ⟨A⟩mtx_rel" .

  (* Reading an entry: function application at the index pair. *)
  sepref_decl_op mtx_get: "λ(c::'a mtx) ij. c ij" :: "⟨A⟩mtx_rel → (nat_rel×⇩rnat_rel) → A"
    apply (rule fref_ncI) unfolding mtx_rel_def
    by parametricity

  (* Writing an entry: function update (fun_upd) at the index pair. *)
  sepref_decl_op mtx_set: "fun_upd::'a mtx ⇒ _" :: "⟨A⟩mtx_rel → (nat_rel×⇩rnat_rel) → A → ⟨A⟩mtx_rel"
    apply (rule fref_ncI)
    unfolding mtx_rel_def
  proof goal_cases
    case 1
    (* Local parametricity fact for equality on index pairs, used by the
       parametricity step below. *)
    have [param]: "((=), (=)) ∈ nat_rel ×⇩r nat_rel → nat_rel ×⇩r nat_rel → bool_rel" by simp
    show ?case by parametricity
  qed

  (* Set of all index pairs whose entry differs from 0. *)
  definition mtx_nonzero :: "_ mtx ⇒ (nat×nat) set" where "mtx_nonzero m ≡ {(i,j). m (i,j)≠0}"

  (* mtx_nonzero as an operation; declared under the side condition that the
     entry relation A is the identity (IS_ID A). *)
  sepref_decl_op mtx_nonzero: "mtx_nonzero" :: "⟨A⟩mtx_rel → ⟨nat_rel×⇩rnat_rel⟩set_rel"
    where "IS_ID (A::(_×(_::zero)) set)"
  proof goal_cases
    case 1
    assume "IS_ID A"
    hence U: "A=Id" by (simp only: IS_ID_def)
    have [param]: "((=),(=))∈A→A→bool_rel" using U by simp
    show ?case
      apply (rule fref_ncI)
      unfolding mtx_rel_def
      apply parametricity
      unfolding U by simp_all
  qed

  subsection ‹Patterns›

  (* Pattern rules rewriting plain application and fun_upd to the matrix
     get/set operations. *)
  lemma pat_amtx_get: "c$e≡op_mtx_get$'c$'e" by simp
  lemma pat_amtx_set: "fun_upd$c$e$v≡op_mtx_set$'c$'e$'v" by simp

  lemmas amtx_pats[pat_rules] = pat_amtx_get pat_amtx_set

  subsection ‹Pointwise Operations›
  subsubsection ‹Auxiliary Definitions and Lemmas›

  (* Abstract setting for pointwise updates:
       f p s  updates state s at position p,
       q s p  queries state s at position p.
     upd_indep1: an update at p does not change the query at a different p'.
     upd_indep2: the query at p after updating p does not depend on an
                 earlier update at a different position p'. *)
  locale pointwise_op =
    fixes f :: "'p ⇒ 's ⇒ 's"
    fixes q :: "'s ⇒ 'p ⇒ 'a"
    assumes upd_indep1[simp, intro]: "p≠p' ⟹ q (f p s) p' = q s p'"
    assumes upd_indep2[simp, intro]: "p≠p' ⟹ q (f p (f p' s)) p = q (f p s) p"
  begin
    (* Folding f over a distinct list of positions: the query at p sees the
       single update at p if p is in the list, and the original state otherwise. *)
    lemma pointwise_upd_fold: "distinct ps ⟹ q (fold f ps s) p = (if p∈set ps then q (f p s) p else q s p)"
      by (induction ps arbitrary: s) auto
  end

  (* Instance of pointwise_upd_fold for states that are functions, with the
     query being the identity (i.e. function application). *)
  lemma pointwise_fun_fold:
    fixes f :: "'a ⇒ ('a ⇒ 'b) ⇒ ('a ⇒ 'b)"
    fixes s :: "'a ⇒ 'b"
    assumes indep1: "⋀x x' s. x ≠ x' ⟹ f x s x' = s x'"
    assumes indep2: "⋀x x' s. x ≠ x' ⟹ f x (f x' s) x = f x s x"
    assumes [simp]: "distinct xs"
    shows "fold f xs s x = (if x ∈ set xs then f x s x else s x)"
  proof -
    interpret pointwise_op f "λs. s"
      by unfold_locales fact+
    show ?thesis
      using pointwise_upd_fold[of xs s x]
      by auto
  qed

  (* The product of two index ranges equals the flat range [0..<N*M] mapped
     through (div N, mod N). *)
  lemma list_prod_divmod_eq: "List.product [0..<M] [0..<N] = map (λi. (i div N, i mod N)) [0..<N*M]"
  proof -
    have [simp]: "i < m*n ⟹ (i::nat) div n < m" for i m n
      by (metis mult.commute div_eq_0_iff div_mult2_eq gr_implies_not_zero mult_not_zero)
    have [simp]: "i<N*M ⟹ N>0 ∧ M>0" for i
      by (cases N; cases M; auto)
    show ?thesis
      by (rule nth_equalityI) (auto simp add: product_nth algebra_simps)
  qed

  (* Iterating over the product list equals iterating over the flat range
     with indices recovered by div/mod. *)
  lemma nfoldli_prod_divmod_conv: "nfoldli (List.product [0..<N] [0..<M]) ctd (λ(i,j). f i j) = nfoldli [0..<N*M] ctd (λi. f (i div M) (i mod M))"
    apply (intro ext)
    apply (subst list_prod_divmod_eq)
    apply (simp add: nfoldli_map)
    apply (fo_rule cong)+
    apply (auto simp: algebra_simps)
    done

  (* Same conversion, stated for two nested nfoldli loops. *)
  lemma nfoldli_prod_divmod_conv': "nfoldli [0..<M] ctd (λi. nfoldli [0..<N] ctd (f i)) = nfoldli [0..<N*M] ctd (λi. f (i div N) (i mod N))"
    apply (intro ext)
    apply (subst nfoldli_nfoldli_prod_conv)
    by (simp add: nfoldli_prod_divmod_conv algebra_simps)

  (* Same conversion for foldli, obtained from the nfoldli version via RETURN. *)
  lemma foldli_prod_divmod_conv': "foldli [0..<M] ctd (λi. foldli [0..<N] ctd (f i)) = foldli [0..<N*M] ctd (λi. f (i div N) (i mod N))"
    (is "?lhs=?rhs")
  proof -
    have "RETURN (?lhs s) = RETURN (?rhs s)" for s
      apply (subst foldli_eq_nfoldli)+
      apply (subst nfoldli_prod_divmod_conv')
      ..
    thus ?thesis by auto
  qed

  (* Same conversion for plain fold, obtained from the foldli version with the
     always-true continuation condition. *)
  lemma fold_prod_divmod_conv': "fold (λi. fold (f i) [0..<N]) [0..<M] = fold (λi. f (i div N) (i mod N)) [0..<N*M]"
    using foldli_prod_divmod_conv'[of M "λ_. True" N f, THEN fun_cong]
    apply (intro ext)
    apply (simp add: foldli_foldl foldl_conv_fold)
    done

  (* Case distinction: an index pair is either in mtx_nonzero m, or its entry is 0. *)
  lemma mtx_nonzero_cases[consumes 0, case_names nonzero zero]:
    obtains "(i,j)∈mtx_nonzero m" | "m (i,j) = 0"
    by (auto simp: mtx_nonzero_def)

  subsubsection ‹Unary Pointwise›

  (* Apply f at every index; f receives both the index pair and the entry. *)
  definition mtx_pointwise_unop :: "(nat×nat ⇒ 'a ⇒ 'a) ⇒ 'a mtx ⇒ 'a mtx" where
    "mtx_pointwise_unop f m ≡ λ(i,j). f (i,j) (m(i,j))"

  context fixes f :: "nat×nat ⇒ 'a ⇒ 'a" begin
    sepref_register "PR_CONST (mtx_pointwise_unop f)" :: "'a i_mtx ⇒ 'a i_mtx"
    lemma [def_pat_rules]: "mtx_pointwise_unop$f ≡ UNPROTECT (mtx_pointwise_unop f)" by simp
  end

  (* Setting for a fold-based unary pointwise operation on the index range
     {0..<N}×{0..<M}. pres_zero: outside that range, f maps 0 to 0. *)
  locale mtx_pointwise_unop_loc =
    fixes N :: nat and M :: nat
    fixes f :: "(nat×nat) ⇒ 'a::{zero} ⇒ 'a"
    assumes pres_zero[simp]: "⟦ i≥N ∨ j≥M ⟧ ⟹ f (i,j) 0 = 0"
  begin
    (* Two nested folds over [0..<N] and [0..<M] that overwrite each entry
       (i,j) with f (i,j) applied to the current entry. *)
    definition "opr_fold_impl ≡ fold (λi. fold (λj m. m( (i,j) := f (i,j) (m(i,j)) )) [0..<M]) [0..<N]"

    (* The fold agrees with the abstract operation, provided that all nonzero
       entries of m lie inside {0..<N}×{0..<M}. *)
    lemma opr_fold_impl_eq:
      assumes "mtx_nonzero m ⊆ {0..<N}×{0..<M}"
      shows "mtx_pointwise_unop f m = opr_fold_impl m"
      apply (rule ext)
      unfolding opr_fold_impl_def
      apply (simp add: fold_fold_prod_conv)
      apply (subst pointwise_fun_fold)
      apply (auto simp: mtx_pointwise_unop_def distinct_product) [3]
      apply clarsimp
      subgoal for a b
        (* Split on whether the entry at (a,b) is nonzero or zero. *)
        apply (cases a b m rule: mtx_nonzero_cases)
        using assms
        apply (auto simp: mtx_pointwise_unop_def)
        done
      done

    (* opr_fold_impl_eq as a refinement with the range condition as precondition. *)
    lemma opr_fold_impl_refine: "(opr_fold_impl, mtx_pointwise_unop f) ∈ [λm. mtx_nonzero m ⊆ {0..<N}×{0..<M}]⇩f Id → Id"
      apply (rule frefI)
      using opr_fold_impl_eq
      by auto

  end

  (* Generic implementation of the unary pointwise operation.
     Parameters:
       assn      assertion relating abstract matrices to implementations
       A         assertion for entries
       get_impl  implementation of op_mtx_get (indices bounded by N and M)
       set_impl  implementation of op_mtx_set (indices bounded by N and M)
       fi        implementation of f
     assn_range: every matrix in the domain of assn has its nonzero entries
     inside {0..<N}×{0..<M}. *)
  locale mtx_pointwise_unop_gen_impl = mtx_pointwise_unop_loc +
    fixes assn :: "'a mtx ⇒ 'i ⇒ assn"
    fixes A :: "'a ⇒ 'ai ⇒ assn"
    fixes get_impl :: "'i ⇒ nat×nat ⇒ 'ai Heap"
    fixes set_impl :: "'i ⇒ nat×nat ⇒ 'ai ⇒ 'i Heap"
    fixes fi :: "nat×nat ⇒ 'ai ⇒ 'ai Heap"
    assumes assn_range: "rdomp assn m ⟹ mtx_nonzero m ⊆ {0..<N}×{0..<M}"
    assumes get_impl_hnr: "(uncurry get_impl,uncurry (RETURN oo op_mtx_get)) ∈ assn⇧k *⇩a (prod_assn (nbn_assn N) (nbn_assn M))⇧k →⇩a A"
    assumes set_impl_hnr: "(uncurry2 set_impl,uncurry2 (RETURN ooo op_mtx_set)) ∈ assn⇧d *⇩a (prod_assn (nbn_assn N) (nbn_assn M))⇧k *⇩a A⇧k →⇩a assn"
    assumes fi_hnr: "(uncurry fi,uncurry (RETURN oo f)) ∈ (prod_assn nat_assn nat_assn)⇧k *⇩a A⇧k →⇩a A"
  begin

    (* The locale predicate for the current parameters; used below to
       instantiate the theorem exported outside the locale. *)
    lemma this_loc: "mtx_pointwise_unop_gen_impl N M f assn A get_impl set_impl fi"
      by unfold_locales

    (* Local setup for the synthesis: register f, N, M ad hoc, and make the
       locale's get/set/fi rules available as refinement rules. *)
    context
      notes [[sepref_register_adhoc f N M]]
      notes [intf_of_assn] = intf_of_assnI[where R=assn and 'a="'a i_mtx"]
      notes [sepref_import_param] = IdI[of N] IdI[of M]
      notes [sepref_fr_rules] = get_impl_hnr set_impl_hnr fi_hnr
    begin
      sepref_thm opr_fold_impl1 is "RETURN o opr_fold_impl" :: "assn⇧d →⇩a assn"
        unfolding opr_fold_impl_def
        supply [[goals_limit = 1]]
        by sepref
    end

    (* NOTE(review): the constant name below is spelled "unnop" (double n),
       unlike the surrounding "unop" names; it is kept unchanged here because
       the name is referenced below and may be referenced from other theories. *)
    concrete_definition (in -) mtx_pointwise_unnop_fold_impl1 uses mtx_pointwise_unop_gen_impl.opr_fold_impl1.refine_raw
    prepare_code_thms (in -) mtx_pointwise_unnop_fold_impl1_def

    (* Final refinement rule: composes the synthesized implementation with
       opr_fold_impl_refine; the range precondition is discharged by assn_range. *)
    lemma op_hnr[sepref_fr_rules]: "(mtx_pointwise_unnop_fold_impl1 N M get_impl set_impl fi, RETURN ∘ PR_CONST (mtx_pointwise_unop f)) ∈ assn⇧d →⇩a assn"
      unfolding PR_CONST_def
      apply (rule hfref_weaken_pre'[OF _ mtx_pointwise_unnop_fold_impl1.refine[OF this_loc,FCOMP opr_fold_impl_refine]])
      by (simp add: assn_range)

  end

  subsubsection ‹Binary Pointwise›

  (* Combine two matrices entry by entry with f. *)
  definition mtx_pointwise_binop :: "('a ⇒ 'a ⇒ 'a) ⇒ 'a mtx ⇒ 'a mtx ⇒ 'a mtx" where
    "mtx_pointwise_binop f m n ≡ λ(i,j). f (m(i,j)) (n(i,j))"

  context fixes f :: "'a ⇒ 'a ⇒ 'a" begin
    sepref_register "PR_CONST (mtx_pointwise_binop f)" :: "'a i_mtx ⇒ 'a i_mtx ⇒ 'a i_mtx"
    lemma [def_pat_rules]: "mtx_pointwise_binop$f ≡ UNPROTECT (mtx_pointwise_binop f)" by simp
  end

  (* Setting for a fold-based binary pointwise operation on the index range
     {0..<N}×{0..<M}. pres_zero: f maps two zeros to zero. *)
  locale mtx_pointwise_binop_loc =
    fixes N :: nat and M :: nat
    fixes f :: "'a::{zero} ⇒ 'a ⇒ 'a"
    assumes pres_zero[simp]: "f 0 0 = 0"
  begin
    (* Nested folds that update the first matrix m in place: entry (i,j)
       becomes f applied to the current entry of m and the entry of n. *)
    definition "opr_fold_impl m n ≡ fold (λi. fold (λj m. m( (i,j) := f (m(i,j)) (n(i,j)) )) [0..<M]) [0..<N] m"

    (* The fold agrees with the abstract operation, provided that the nonzero
       entries of both arguments lie inside {0..<N}×{0..<M}. *)
    lemma opr_fold_impl_eq:
      assumes "mtx_nonzero m ⊆ {0..<N}×{0..<M}"
      assumes "mtx_nonzero n ⊆ {0..<N}×{0..<M}"
      shows "mtx_pointwise_binop f m n = opr_fold_impl m n"
      apply (rule ext)
      unfolding opr_fold_impl_def
      apply (simp add: fold_fold_prod_conv)
      apply (subst pointwise_fun_fold)
      apply (auto simp: mtx_pointwise_binop_def distinct_product) [3]
      apply clarsimp
      subgoal for a b
        (* Split on zero/nonzero of the entry at (a,b) in both m and n. *)
        apply (cases a b m rule: mtx_nonzero_cases; cases a b n rule: mtx_nonzero_cases)
        using assms
        apply (auto simp: mtx_pointwise_binop_def)
        done
      done

    (* opr_fold_impl_eq as a refinement with both range conditions as precondition. *)
    lemma opr_fold_impl_refine: "(uncurry opr_fold_impl, uncurry (mtx_pointwise_binop f)) ∈ [λ(m,n). mtx_nonzero m ⊆ {0..<N}×{0..<M} ∧ mtx_nonzero n ⊆ {0..<N}×{0..<M}]⇩f Id×⇩rId → Id"
      apply (rule frefI)
      using opr_fold_impl_eq
      by auto

  end

  (* Generic implementation of the binary pointwise operation. Parameters are
     as in mtx_pointwise_unop_gen_impl, except that fi implements the binary
     function f on two entries. *)
  locale mtx_pointwise_binop_gen_impl = mtx_pointwise_binop_loc +
    fixes assn :: "'a mtx ⇒ 'i ⇒ assn"
    fixes A :: "'a ⇒ 'ai ⇒ assn"
    fixes get_impl :: "'i ⇒ nat×nat ⇒ 'ai Heap"
    fixes set_impl :: "'i ⇒ nat×nat ⇒ 'ai ⇒ 'i Heap"
    fixes fi :: "'ai ⇒ 'ai ⇒ 'ai Heap"
    assumes assn_range: "rdomp assn m ⟹ mtx_nonzero m ⊆ {0..<N}×{0..<M}"
    assumes get_impl_hnr: "(uncurry get_impl,uncurry (RETURN oo op_mtx_get)) ∈ assn⇧k *⇩a (prod_assn (nbn_assn N) (nbn_assn M))⇧k →⇩a A"
    assumes set_impl_hnr: "(uncurry2 set_impl,uncurry2 (RETURN ooo op_mtx_set)) ∈ assn⇧d *⇩a (prod_assn (nbn_assn N) (nbn_assn M))⇧k *⇩a A⇧k →⇩a assn"
    assumes fi_hnr: "(uncurry fi,uncurry (RETURN oo f)) ∈ A⇧k *⇩a A⇧k →⇩a A"
  begin

    (* The locale predicate for the current parameters; used below to
       instantiate the theorem exported outside the locale. *)
    lemma this_loc: "mtx_pointwise_binop_gen_impl N M f assn A get_impl set_impl fi"
      by unfold_locales

    (* Local setup for the synthesis, as in the unary case. *)
    context
      notes [[sepref_register_adhoc f N M]]
      notes [intf_of_assn] = intf_of_assnI[where R=assn and 'a="'a i_mtx"]
      notes [sepref_import_param] = IdI[of N] IdI[of M]
      notes [sepref_fr_rules] = get_impl_hnr set_impl_hnr fi_hnr
    begin
      sepref_thm opr_fold_impl1 is "uncurry (RETURN oo opr_fold_impl)" :: "assn⇧d*⇩aassn⇧k →⇩a assn"
        unfolding opr_fold_impl_def[abs_def]
        by sepref
    end

    concrete_definition (in -) mtx_pointwise_binop_fold_impl1 uses mtx_pointwise_binop_gen_impl.opr_fold_impl1.refine_raw is "(uncurry ?f,_)∈_"
    prepare_code_thms (in -) mtx_pointwise_binop_fold_impl1_def

    (* Final refinement rule: composes the synthesized implementation with
       opr_fold_impl_refine; the range preconditions follow from assn_range. *)
    lemma op_hnr[sepref_fr_rules]: "(uncurry (mtx_pointwise_binop_fold_impl1 N M get_impl set_impl fi), uncurry (RETURN oo PR_CONST (mtx_pointwise_binop f))) ∈ assn⇧d *⇩a assn⇧k →⇩a assn"
      unfolding PR_CONST_def
      apply (rule hfref_weaken_pre'[OF _ mtx_pointwise_binop_fold_impl1.refine[OF this_loc,FCOMP opr_fold_impl_refine]])
      apply (auto dest: assn_range)
      done

  end

  subsubsection ‹Compare Pointwise›

  (* Pointwise comparison: f must hold at every index, and g must hold at
     at least one index. *)
  definition mtx_pointwise_cmpop :: "('a ⇒ 'a ⇒ bool) ⇒ ('a ⇒ 'a ⇒ bool) ⇒ 'a mtx ⇒ 'a mtx ⇒ bool" where
    "mtx_pointwise_cmpop f g m n ≡ (∀i j. f (m(i,j)) (n(i,j))) ∧ (∃i j. g (m(i,j)) (n(i,j)))"

  context fixes f g :: "'a ⇒ 'a ⇒ bool" begin
    sepref_register "PR_CONST (mtx_pointwise_cmpop f g)" :: "'a i_mtx ⇒ 'a i_mtx ⇒ bool"
    lemma [def_pat_rules]: "mtx_pointwise_cmpop$f$g ≡ UNPROTECT (mtx_pointwise_cmpop f g)" by simp
  end

  (* TODO: Move *)
  (* If the nonzero entries of m lie inside {0..<N}×{0..<M}, then every entry
     with a row or column index outside that range is 0. *)
  lemma mtx_nonzeroD:
    "⟦¬i<N; mtx_nonzero m ⊆ {0..<N}×{0..<M}⟧ ⟹ m(i,j) = 0"
    "⟦¬j<M; mtx_nonzero m ⊆ {0..<N}×{0..<M}⟧ ⟹ m(i,j) = 0"
    by (auto simp: mtx_nonzero_def)

  (* Setting for a loop-based pointwise comparison on {0..<N}×{0..<M}.
     pres_zero: on two zeros, f holds and g does not. *)
  locale mtx_pointwise_cmpop_loc =
    fixes N :: nat and M :: nat
    fixes f g :: "'a::{zero} ⇒ 'a ⇒ bool"
    assumes pres_zero[simp]: "f 0 0 = True" "g 0 0 = False"
  begin
    (* Loop over all index pairs with a nat-valued state s, starting at 0:
         s becomes 2 as soon as f fails at some index; the loop stops then
           (continuation condition s≠2),
         s becomes 1 when s is 0 and g holds at the current index,
         otherwise s is left unchanged.
       The result is True iff the final state is 1. *)
    definition "opr_fold_impl m n ≡ do {
      s ← nfoldli (List.product [0..<N] [0..<M]) (λs. s≠2) (λ(i,j) s. do {
        if f (m(i,j)) (n(i,j)) then
          if s=0 then
            if g (m(i,j)) (n(i,j)) then RETURN 1 else RETURN s
          else
            RETURN s
        else RETURN 2
      }) (0::nat);
      RETURN (s=1)
    }"

    (* The loop refines the abstract comparison, provided that the nonzero
       entries of both arguments lie inside {0..<N}×{0..<M}. *)
    lemma opr_fold_impl_eq:
      assumes "mtx_nonzero m ⊆ {0..<N}×{0..<M}"
      assumes "mtx_nonzero n ⊆ {0..<N}×{0..<M}"
      shows "opr_fold_impl m n ≤ RETURN (mtx_pointwise_cmpop f g m n)"
    proof -
      (* f on the bounded range extends to all indices. *)
      have "(∀i<N. ∀j<M. f (m (i, j)) (n (i, j))) ⟹ f (m (i, j)) (n (i, j))" for i j
        apply (cases "i<N"; cases "j<M")
        using assms by (auto simp: mtx_nonzeroD)
      (* A witness for g anywhere yields a witness inside the bounded range. *)
      moreover have "g (m (i, j)) (n (i, j)) ⟹ (∃i<N. ∃j<M. g (m (i, j)) (n (i, j)))" for i j
        apply (cases "i<N"; cases "j<M")
        using assms by (auto simp: mtx_nonzeroD)
      (* Hence the comparison can be restated with bounded quantifiers. *)
      ultimately have EQ: "mtx_pointwise_cmpop f g m n ⟷ (∀i<N. ∀j<M. f (m(i,j)) (n(i,j))) ∧ (∃i<N. ∃j<M. g (m(i,j)) (n(i,j)))"
        unfolding mtx_pointwise_cmpop_def by meson

      (* Any element occurring in the product list has both indices in range. *)
      have aux: "List.product [0..<N] [0..<M] = l1 @ (i, j) # l2 ⟹ i<N ∧ j<M" for l1 i j l2
      proof -
        assume "List.product [0..<N] [0..<M] = l1 @ (i, j) # l2"
        hence "(i,j)∈set (List.product [0..<N] [0..<M])" by simp
        thus ?thesis by simp
      qed

      (* Loop invariant, with l1 the list of already processed index pairs:
           s=2: f fails at some index inside the range;
           otherwise: s is 0 or 1, f holds on all of l1, and s=1 iff g holds
           somewhere on l1. *)
      show ?thesis
        unfolding opr_fold_impl_def
        apply (refine_vcg
          nfoldli_rule[where I="λl1 _ s.
              if s=2 then ∃i<N. ∃j<M. ¬f (m(i,j)) (n(i,j))
              else (
                (s=0 ∨ s=1) ∧
                (∀(i,j)∈set l1. f (m(i,j)) (n(i,j))) ∧
                (s=1 ⟷ (∃(i,j)∈set l1. g (m(i,j)) (n(i,j))))
              )"]
          )
        apply (vc_solve dest: aux solve: asm_rl simp: EQ) [6]
        apply (fastforce simp: EQ)
        done
    qed

    (* opr_fold_impl_eq as a refinement with both range conditions as precondition. *)
    lemma opr_fold_impl_refine: "(uncurry opr_fold_impl, uncurry (RETURN oo mtx_pointwise_cmpop f g)) ∈ [λ(m,n). mtx_nonzero m ⊆ {0..<N}×{0..<M} ∧ mtx_nonzero n ⊆ {0..<N}×{0..<M}]⇩f Id×⇩rId → ⟨bool_rel⟩nres_rel"
      apply (rule frefI)
      using opr_fold_impl_eq
      by (auto intro: nres_relI)

  end

  (* Generic implementation of the pointwise comparison. Only a get operation
     on matrices is required; fi and gi implement the predicates f and g. *)
  locale mtx_pointwise_cmpop_gen_impl = mtx_pointwise_cmpop_loc +
    fixes assn :: "'a mtx ⇒ 'i ⇒ assn"
    fixes A :: "'a ⇒ 'ai ⇒ assn"
    fixes get_impl :: "'i ⇒ nat×nat ⇒ 'ai Heap"
    fixes fi :: "'ai ⇒ 'ai ⇒ bool Heap"
    fixes gi :: "'ai ⇒ 'ai ⇒ bool Heap"
    assumes assn_range: "rdomp assn m ⟹ mtx_nonzero m ⊆ {0..<N}×{0..<M}"
    assumes get_impl_hnr: "(uncurry get_impl,uncurry (RETURN oo op_mtx_get)) ∈ assn⇧k *⇩a (prod_assn (nbn_assn N) (nbn_assn M))⇧k →⇩a A"
    assumes fi_hnr: "(uncurry fi,uncurry (RETURN oo f)) ∈ A⇧k *⇩a A⇧k →⇩a bool_assn"
    assumes gi_hnr: "(uncurry gi,uncurry (RETURN oo g)) ∈ A⇧k *⇩a A⇧k →⇩a bool_assn"
  begin

    (* The locale predicate for the current parameters; used below to
       instantiate the theorem exported outside the locale. *)
    lemma this_loc: "mtx_pointwise_cmpop_gen_impl N M f g assn A get_impl fi gi"
      by unfold_locales

    (* Local setup for the synthesis. The product loop is first rewritten into
       two nested loops (nfoldli_nfoldli_prod_conv[symmetric]).
       NOTE(review): the first matrix argument is annotated ⇧d while the
       second is ⇧k, although the only matrix operation available here is
       get_impl, whose rule uses assn⇧k — confirm that ⇧d is intended. *)
    context
      notes [[sepref_register_adhoc f g N M]]
      notes [intf_of_assn] = intf_of_assnI[where R=assn and 'a="'a i_mtx"]
      notes [sepref_import_param] = IdI[of N] IdI[of M]
      notes [sepref_fr_rules] = get_impl_hnr fi_hnr gi_hnr
    begin
      sepref_thm opr_fold_impl1 is "uncurry opr_fold_impl" :: "assn⇧d*⇩aassn⇧k →⇩a bool_assn"
        unfolding opr_fold_impl_def[abs_def] nfoldli_nfoldli_prod_conv[symmetric]
        by sepref
    end

    concrete_definition (in -) mtx_pointwise_cmpop_fold_impl1 uses mtx_pointwise_cmpop_gen_impl.opr_fold_impl1.refine_raw is "(uncurry ?f,_)∈_"
    prepare_code_thms (in -) mtx_pointwise_cmpop_fold_impl1_def

    (* Final refinement rule: composes the synthesized implementation with
       opr_fold_impl_refine; the range preconditions follow from assn_range. *)
    lemma op_hnr[sepref_fr_rules]: "(uncurry (mtx_pointwise_cmpop_fold_impl1 N M get_impl fi gi), uncurry (RETURN oo PR_CONST (mtx_pointwise_cmpop f g))) ∈ assn⇧d *⇩a assn⇧k →⇩a bool_assn"
      unfolding PR_CONST_def
      apply (rule hfref_weaken_pre'[OF _ mtx_pointwise_cmpop_fold_impl1.refine[OF this_loc,FCOMP opr_fold_impl_refine]])
      apply (auto dest: assn_range)
      done

  end

end