(* Theory Pair_Memory *)
subsection ‹Pair Memory›
theory Pair_Memory
imports "../state_monad/Memory"
begin
(* Monotonicity of map addition (++) with respect to the map order ⊆⇩m.
   The disjointness premise is needed because m2' may be defined at keys where
   m2 is not; at such a key it must not shadow an entry contributed by m1. *)
lemma map_add_mono:
"(m1 ++ m2) ⊆⇩m (m1' ++ m2')" if "m1 ⊆⇩m m1'" "m2 ⊆⇩m m2'" "dom m1 ∩ dom m2' = {}"
using that unfolding map_le_def map_add_def dom_def by (auto split: option.splits)
(* A point update on the LEFT summand of a map addition can be pulled outside,
   provided the two maps have disjoint domains and the updated key x is not in
   the domain of the right summand g. *)
lemma map_add_upd2:
"f(x ↦ y) ++ g = (f ++ g)(x ↦ y)" if "dom f ∩ dom g = {}" "x ∉ dom g"
(* Rewrite with map_add_comm; `defer` postpones the first resulting goal to the end. *)
apply (subst map_add_comm)
defer
apply simp
(* Rewrite with map_add_comm a second time; the remaining goals are closed by
   auto using the lemma's assumptions. *)
apply (subst map_add_comm)
using that
by auto
(* Definitions for a memory built from two component memories ("rows") that
   live in the same state type 'mem.
   Parameters:
     lookup1/update1, lookup2/update2 - lookup and update operations of the two
                                        component memories, indexed by 'a;
     move12        - a state action taking a 'k1 (its required behaviour is only
                     fixed by the assumptions of locale pair_mem);
     get_k1/get_k2 - state actions returning the 'k1 value currently associated
                     with component 1 resp. component 2;
     P             - a predicate on the state;
     key1/key2     - split a full key 'k into a 'k1 part and an 'a part. *)
locale pair_mem_defs =
fixes lookup1 lookup2 :: "'a ⇒ ('mem, 'v option) state"
and update1 update2 :: "'a ⇒ 'v ⇒ ('mem, unit) state"
and move12 :: "'k1 ⇒ ('mem, unit) state"
and get_k1 get_k2 :: "('mem, 'k1) state"
and P :: "'mem ⇒ bool"
fixes key1 :: "'k ⇒ 'k1" and key2 :: "'k ⇒ 'a"
begin
text ‹We assume that look-ups happen on the older row, so it is biased towards the second entry.›
(* Lookup on the combined memory for a full key k:
     - if key1 k equals the value returned by get_k2, delegate to lookup2 (key2 k);
     - otherwise, if it equals the value returned by get_k1, delegate to lookup1 (key2 k);
     - otherwise return None.
   get_k2 is consulted first, and get_k1 is only run when the first test fails. *)
definition
"lookup_pair k = do {
let k' = key1 k;
k2 ← get_k2;
if k' = k2
then lookup2 (key2 k)
else do {
k1 ← get_k1;
if k' = k1
then lookup1 (key2 k)
else State_Monad.return None
}
}
"
text ‹We assume that updates happen on the newer row, so it is biased towards the first entry.›
(* Update on the combined memory for a full key k and value v:
     - if key1 k equals the value returned by get_k1, store via update1 (key2 k) v;
     - otherwise, if it equals the value returned by get_k2, store via update2 (key2 k) v;
     - otherwise run move12 (key1 k) first and then store via update1 (key2 k) v.
   get_k1 is consulted first, and get_k2 is only run when the first test fails. *)
definition
"update_pair k v = do {
let k' = key1 k;
k1 ← get_k1;
if k' = k1
then update1 (key2 k) v
else do {
k2 ← get_k2;
if k' = k2
then update2 (key2 k) v
else (move12 k' ⪢ update1 (key2 k) v)
}
}
"
(* Instantiate state_mem_defs (from the imported Memory theory) three times:
   for the combined memory and for each of the two components. This makes the
   qualified constants pair.map_of, mem1.map_of and mem2.map_of available,
   which are used in the definitions and lemmas below. *)
sublocale pair: state_mem_defs lookup_pair update_pair .
sublocale mem1: state_mem_defs lookup1 update1 .
sublocale mem2: state_mem_defs lookup2 update2 .
(* Invariant of the combined memory. With k1/k2 the values returned by
   get_k1/get_k2 on the given heap, it demands:
     1. every key stored in component 1 is key2 k' of some full key k' with key1 k' = k1;
     2. every key stored in component 2 is key2 k' of some full key k' with key1 k' = k2;
     3. k1 and k2 differ;
     4. the predicate P holds for the heap. *)
definition
"inv_pair heap ≡
let
k1 = fst (State_Monad.run_state get_k1 heap);
k2 = fst (State_Monad.run_state get_k2 heap)
in
(∀ k ∈ dom (mem1.map_of heap). ∃ k'. key1 k' = k1 ∧ key2 k' = k) ∧
(∀ k ∈ dom (mem2.map_of heap). ∃ k'. key1 k' = k2 ∧ key2 k' = k) ∧
k1 ≠ k2 ∧ P heap
"
(* Component 1 viewed as a partial map on full keys: defined only for keys
   whose key1 part equals the value returned by get_k1 on m, in which case
   the entry of mem1.map_of m at key2 k is returned. *)
definition
"map_of1 m k = (if key1 k = fst (State_Monad.run_state get_k1 m) then mem1.map_of m (key2 k) else None)"
(* Component 2 viewed as a partial map on full keys: defined only for keys
   whose key1 part equals the value returned by get_k2 on m, in which case
   the entry of mem2.map_of m at key2 k is returned. *)
definition
"map_of2 m k = (if key1 k = fst (State_Monad.run_state get_k2 m) then mem2.map_of m (key2 k) else None)"
end
(* Assumptions about the parameters of pair_mem_defs under which the combined
   memory is shown to be correct (see lemma mem_correct_pair).
   NOTE(review): lift_p is defined in the imported Memory theory; presumably
   `lift_p P c` states that the action c preserves P - the proofs in this
   locale use it only through lift_p_P and lift_pI. *)
locale pair_mem = pair_mem_defs +
(* get_k1 and get_k2 are read-only: running them returns the state unchanged. *)
assumes get_state:
"State_Monad.run_state get_k1 m = (k, m') ⟹ m' = m"
"State_Monad.run_state get_k2 m = (k, m') ⟹ m' = m"
(* After move12 (started in a P-state) component 1 holds no entries, and
   component 2 holds at most the entries component 1 had before. *)
assumes move12_correct:
"P m ⟹ State_Monad.run_state (move12 k1) m = (x, m') ⟹ mem1.map_of m' ⊆⇩m Map.empty"
"P m ⟹ State_Monad.run_state (move12 k1) m = (x, m') ⟹ mem2.map_of m' ⊆⇩m mem1.map_of m"
(* After move12 k1, get_k1 yields k1 and get_k2 yields what get_k1 yielded before. *)
assumes move12_keys:
"State_Monad.run_state (move12 k1) m = (x, m') ⟹ fst (State_Monad.run_state get_k1 m') = k1"
"State_Monad.run_state (move12 k1) m = (x, m') ⟹ fst (State_Monad.run_state get_k2 m') = fst (State_Monad.run_state get_k1 m)"
(* lift_p P facts for move12 and for the component lookups and updates. *)
assumes move12_inv:
"lift_p P (move12 k1)"
assumes lookup_inv:
"lift_p P (lookup1 k')" "lift_p P (lookup2 k')"
assumes update_inv:
"lift_p P (update1 k' v)" "lift_p P (update2 k' v)"
(* Component lookups, started in a P-state, leave the results of get_k1 and get_k2 unchanged. *)
assumes lookup_keys:
"P m ⟹ State_Monad.run_state (lookup1 k') m = (v', m') ⟹
fst (State_Monad.run_state get_k1 m') = fst (State_Monad.run_state get_k1 m)"
"P m ⟹ State_Monad.run_state (lookup1 k') m = (v', m') ⟹
fst (State_Monad.run_state get_k2 m') = fst (State_Monad.run_state get_k2 m)"
"P m ⟹ State_Monad.run_state (lookup2 k') m = (v', m') ⟹
fst (State_Monad.run_state get_k1 m') = fst (State_Monad.run_state get_k1 m)"
"P m ⟹ State_Monad.run_state (lookup2 k') m = (v', m') ⟹
fst (State_Monad.run_state get_k2 m') = fst (State_Monad.run_state get_k2 m)"
(* Component updates, started in a P-state, leave the results of get_k1 and get_k2 unchanged. *)
assumes update_keys:
"P m ⟹ State_Monad.run_state (update1 k' v) m = (x, m') ⟹
fst (State_Monad.run_state get_k1 m') = fst (State_Monad.run_state get_k1 m)"
"P m ⟹ State_Monad.run_state (update1 k' v) m = (x, m') ⟹
fst (State_Monad.run_state get_k2 m') = fst (State_Monad.run_state get_k2 m)"
"P m ⟹ State_Monad.run_state (update2 k' v) m = (x, m') ⟹
fst (State_Monad.run_state get_k1 m') = fst (State_Monad.run_state get_k1 m)"
"P m ⟹ State_Monad.run_state (update2 k' v) m = (x, m') ⟹
fst (State_Monad.run_state get_k2 m') = fst (State_Monad.run_state get_k2 m)"
(* A lookup on either component does not add entries to either component's map. *)
assumes
lookup_correct:
"P m ⟹ mem1.map_of (snd (State_Monad.run_state (lookup1 k') m)) ⊆⇩m (mem1.map_of m)"
"P m ⟹ mem2.map_of (snd (State_Monad.run_state (lookup1 k') m)) ⊆⇩m (mem2.map_of m)"
"P m ⟹ mem1.map_of (snd (State_Monad.run_state (lookup2 k') m)) ⊆⇩m (mem1.map_of m)"
"P m ⟹ mem2.map_of (snd (State_Monad.run_state (lookup2 k') m)) ⊆⇩m (mem2.map_of m)"
(* An update on component i yields a map below the old map of component i
   extended by k' ↦ v (first two facts) and does not add entries to the
   other component (last two facts). *)
assumes
update_correct:
"P m ⟹ mem1.map_of (snd (State_Monad.run_state (update1 k' v) m)) ⊆⇩m (mem1.map_of m)(k' ↦ v)"
"P m ⟹ mem2.map_of (snd (State_Monad.run_state (update2 k' v) m)) ⊆⇩m (mem2.map_of m)(k' ↦ v)"
"P m ⟹ mem2.map_of (snd (State_Monad.run_state (update1 k' v) m)) ⊆⇩m mem2.map_of m"
"P m ⟹ mem1.map_of (snd (State_Monad.run_state (update2 k' v) m)) ⊆⇩m mem1.map_of m"
begin
(* Under the invariant, the map induced by lookup_pair is contained in the
   combination of the two key-level component views.
   The proof unfolds all definitions involved down to the state monad bind
   and lets auto finish, using get_state to identify intermediate states. *)
lemma map_of_le_pair:
"pair.map_of m ⊆⇩m map_of1 m ++ map_of2 m"
if "inv_pair m"
using that
unfolding pair.map_of_def map_of1_def map_of2_def
unfolding lookup_pair_def inv_pair_def map_of_def map_le_def dom_def map_add_def
unfolding State_Monad.bind_def
by (auto 4 4
simp: mem2.map_of_def mem1.map_of_def Let_def
dest: get_state split: prod.split_asm if_split_asm
)
(* Converse inclusion of map_of_le_pair: under the invariant, the combination
   of the two key-level component views is contained in the map induced by
   lookup_pair. Proved in the same way, by unfolding everything and calling auto. *)
lemma pair_le_map_of:
"map_of1 m ++ map_of2 m ⊆⇩m pair.map_of m"
if "inv_pair m"
using that
unfolding pair.map_of_def map_of1_def map_of2_def
unfolding lookup_pair_def inv_pair_def map_of_def map_le_def dom_def map_add_def
unfolding State_Monad.bind_def
by (auto
simp: mem2.map_of_def mem1.map_of_def State_Monad.run_state_return Let_def
dest: get_state split: prod.splits if_split_asm option.split
)
(* Under the invariant, the map induced by lookup_pair coincides with the
   combination of the two key-level component views.
   Both inclusions are already available as pair_le_map_of and map_of_le_pair,
   so the equality follows from antisymmetry of the map order. *)
lemma map_of_eq_pair:
  "map_of1 m ++ map_of2 m = pair.map_of m"
  if "inv_pair m"
  by (rule map_le_antisym[OF pair_le_map_of[OF that] map_of_le_pair[OF that]])
(* The invariant is contradicted if get_k1 and get_k2 return the same value
   (inv_pair demands k1 ≠ k2). Declared [simp]. *)
lemma inv_pair_neq[simp]:
False if "inv_pair m" "fst (State_Monad.run_state get_k1 m) = fst (State_Monad.run_state get_k2 m)"
using that unfolding inv_pair_def by auto
(* Destruction rule: the invariant contains P as its last conjunct. *)
lemma inv_pair_P_D:
"P m" if "inv_pair m"
using that unfolding inv_pair_def by (auto simp: Let_def)
(* Under the invariant the two key-level views have disjoint domains:
   map_of1 is only defined where key1 k = k1, map_of2 only where key1 k = k2,
   and the invariant gives k1 ≠ k2. Declared [intro]. *)
lemma inv_pair_domD[intro]:
"dom (map_of1 m) ∩ dom (map_of2 m) = {}" if "inv_pair m"
using that unfolding inv_pair_def map_of1_def map_of2_def by (auto split: if_split_asm)
(* Key-level version of the first move12_correct assumption: after move12
   (started in a P-state) the view map_of1 has no entries. *)
lemma move12_correct1:
"map_of1 heap' ⊆⇩m Map.empty" if "State_Monad.run_state (move12 k1) heap = (x, heap')" "P heap"
using move12_correct[OF that(2,1)] unfolding map_of1_def by (auto simp: move12_keys map_le_def)
(* Key-level version of the second move12_correct assumption: after move12
   (started in a P-state) the view map_of2 of the new state is contained in
   the view map_of1 of the old state. move12_keys relates the new result of
   get_k2 to the old result of get_k1. *)
lemma move12_correct2:
"map_of2 heap' ⊆⇩m map_of1 heap" if "State_Monad.run_state (move12 k1) heap = (x, heap')" "P heap"
using move12_correct(2)[OF that(2,1)] that unfolding map_of1_def map_of2_def
by (auto simp: move12_keys map_le_def)
(* Domain form of move12_correct1: after move12 (started in a P-state) the
   view map_of1 has empty domain. Declared [simp]. *)
lemma dom_empty[simp]:
"dom (map_of1 heap') = {}" if "State_Monad.run_state (move12 k1) heap = (x, heap')" "P heap"
using move12_correct1[OF that] by (auto dest: map_le_implies_dom_le)
(* lookup1 preserves the invariant. Ingredients: lookup_inv (for P),
   lookup_keys (get_k1/get_k2 results unchanged) and lookup_correct
   (domains of the component maps do not grow). *)
lemma inv_pair_lookup1:
"inv_pair m'" if "State_Monad.run_state (lookup1 k) m = (v, m')" "inv_pair m"
using that lookup_inv[of k] inv_pair_P_D[OF ‹inv_pair m›] unfolding inv_pair_def
by (auto 4 4
simp: Let_def lookup_keys
dest: lift_p_P lookup_correct[of _ k, THEN map_le_implies_dom_le]
)
(* lookup2 preserves the invariant. Ingredients: lookup_inv (for P),
   lookup_keys (get_k1/get_k2 results unchanged) and lookup_correct
   (domains of the component maps do not grow). *)
lemma inv_pair_lookup2:
"inv_pair m'" if "State_Monad.run_state (lookup2 k) m = (v, m')" "inv_pair m"
using that lookup_inv[of k] inv_pair_P_D[OF ‹inv_pair m›] unfolding inv_pair_def
by (auto 4 4
simp: Let_def lookup_keys
dest: lift_p_P lookup_correct[of _ k, THEN map_le_implies_dom_le]
)
(* update1 at key2 k preserves the invariant, provided get_k1 currently
   returns key1 k; that premise supplies the witness for the newly stored key
   in the first conjunct of inv_pair. *)
lemma inv_pair_update1:
"inv_pair m'"
if "State_Monad.run_state (update1 (key2 k) v) m = (v', m')" "inv_pair m" "fst (State_Monad.run_state get_k1 m) = key1 k"
using that update_inv[of "key2 k" v] inv_pair_P_D[OF ‹inv_pair m›] unfolding inv_pair_def
apply (auto
simp: Let_def update_keys
dest: lift_p_P update_correct[of _ "key2 k" v, THEN map_le_implies_dom_le]
)
(* Two goals remain after auto; each is closed by feeding in the domain bound
   obtained from update_correct and running a deeper auto search. *)
apply (frule update_correct[of _ "key2 k" v, THEN map_le_implies_dom_le]; auto 13 2; fail)
apply (frule update_correct[of _ "key2 k" v, THEN map_le_implies_dom_le]; auto 13 2; fail)
done
(* update2 at key2 k preserves the invariant, provided get_k2 currently
   returns key1 k; that premise supplies the witness for the newly stored key
   in the second conjunct of inv_pair. *)
lemma inv_pair_update2:
"inv_pair m'"
if "State_Monad.run_state (update2 (key2 k) v) m = (v', m')" "inv_pair m" "fst (State_Monad.run_state get_k2 m) = key1 k"
using that update_inv[of "key2 k" v] inv_pair_P_D[OF ‹inv_pair m›] unfolding inv_pair_def
apply (auto
simp: Let_def update_keys
dest: lift_p_P update_correct[of _ "key2 k" v, THEN map_le_implies_dom_le]
)
(* Two goals remain after auto; each is closed by feeding in the domain bound
   obtained from update_correct and running a deeper auto search. *)
apply (frule update_correct[of _ "key2 k" v, THEN map_le_implies_dom_le]; auto 13 2; fail)
apply (frule update_correct[of _ "key2 k" v, THEN map_le_implies_dom_le]; auto 13 2; fail)
done
(* move12 k preserves the invariant, provided get_k1 does not already return
   k. By move12_keys the new get_k1 result is k and the new get_k2 result is
   the old get_k1 result, so this premise is what keeps the two distinct. *)
lemma inv_pair_move12:
"inv_pair m'"
if "State_Monad.run_state (move12 k) m = (v', m')" "inv_pair m" "fst (State_Monad.run_state get_k1 m) ≠ k"
using that move12_inv[of "k"] inv_pair_P_D[OF ‹inv_pair m›] unfolding inv_pair_def
apply (auto
simp: Let_def move12_keys
dest: lift_p_P move12_correct[of _ "k", THEN map_le_implies_dom_le]
)
(* The goal left by auto is closed by blast with the domain bounds from move12_correct. *)
apply (blast dest: move12_correct[of _ "k", THEN map_le_implies_dom_le])
done
(* Main result of the locale: lookup_pair and update_pair, together with the
   invariant inv_pair, satisfy mem_correct, provided a full key is uniquely
   determined by its two parts (key1 k, key2 k).
   The proof has four cases, as produced by `standard`:
     1. lookup_pair preserves inv_pair;
     2. update_pair preserves inv_pair;
     3. lookup_pair does not add entries to pair.map_of;
     4. after update_pair k v, pair.map_of is below the old map updated with k ↦ v. *)
lemma mem_correct_pair:
"mem_correct lookup_pair update_pair inv_pair"
if injective: "∀ k k'. key1 k = key1 k' ∧ key2 k = key2 k' ⟶ k = k'"
proof (standard, goal_cases)
(* Case 1: unfold lookup_pair and use that lookup1/lookup2 preserve the invariant. *)
case (1 k)
show ?case
unfolding lookup_pair_def Let_def
by (auto 4 4
intro!: lift_pI
dest: get_state inv_pair_lookup1 inv_pair_lookup2
simp: State_Monad.bind_def State_Monad.run_state_return
split: if_split_asm prod.split_asm
)
next
(* Case 2: unfold update_pair. The branches that call update1/update2 directly
   are handled by auto; the move12 branch is finished with inv_pair_update1
   and inv_pair_move12, whose side conditions follow from move12_keys and get_state. *)
case (2 k v)
show ?case
unfolding update_pair_def Let_def
apply (auto 4 4
intro!: lift_pI intro: inv_pair_update1 inv_pair_update2
dest: get_state
simp: State_Monad.bind_def get_state State_Monad.run_state_return
split: if_split_asm prod.split_asm
)+
apply (elim inv_pair_update1 inv_pair_move12)
apply (((subst get_state, assumption)+)?, auto intro: move12_keys dest: get_state; fail)+
done
next
(* Case 3: lookup_pair either runs lookup2, runs lookup1, or returns None.
   The two raw blocks below establish the claim for the states reached via
   lookup2 and lookup1 respectively; the final step combines them. *)
case (3 m k)
{
(* State after lookup2: both key-level views shrink (or stay equal), they stay
   disjoint, and the invariant still holds; then rewrite pair.map_of on both
   sides via map_of_eq_pair and apply map_add_mono. *)
let ?m = "snd (State_Monad.run_state (lookup2 (key2 k)) m)"
have "map_of1 ?m ⊆⇩m map_of1 m"
by (smt 3 domIff inv_pair_P_D local.lookup_keys lookup_correct map_le_def map_of1_def surjective_pairing)
moreover have "map_of2 ?m ⊆⇩m map_of2 m"
by (smt 3 domIff inv_pair_P_D local.lookup_keys lookup_correct map_le_def map_of2_def surjective_pairing)
moreover have "dom (map_of1 ?m) ∩ dom (map_of2 m) = {}"
using 3 ‹map_of1 ?m ⊆⇩m map_of1 m› inv_pair_domD map_le_implies_dom_le by fastforce
moreover have "inv_pair ?m"
using 3 inv_pair_lookup2 surjective_pairing by metis
ultimately have "pair.map_of ?m ⊆⇩m pair.map_of m"
apply (subst map_of_eq_pair[symmetric])
defer
apply (subst map_of_eq_pair[symmetric])
by (auto intro: 3 map_add_mono)
}
moreover
{
(* State after lookup1: same argument as for lookup2 above. *)
let ?m = "snd (State_Monad.run_state (lookup1 (key2 k)) m)"
have "map_of1 ?m ⊆⇩m map_of1 m"
by (smt 3 domIff inv_pair_P_D local.lookup_keys lookup_correct map_le_def map_of1_def surjective_pairing)
moreover have "map_of2 ?m ⊆⇩m map_of2 m"
by (smt 3 domIff inv_pair_P_D local.lookup_keys lookup_correct map_le_def map_of2_def surjective_pairing)
moreover have "dom (map_of1 ?m) ∩ dom (map_of2 m) = {}"
using 3 ‹map_of1 ?m ⊆⇩m map_of1 m› inv_pair_domD map_le_implies_dom_le by fastforce
moreover have "inv_pair ?m"
using 3 inv_pair_lookup1 surjective_pairing by metis
ultimately have "pair.map_of ?m ⊆⇩m pair.map_of m"
apply (subst map_of_eq_pair[symmetric])
defer
apply (subst map_of_eq_pair[symmetric])
by (auto intro: 3 map_add_mono)
}
(* Combine: unfold lookup_pair and split on its conditionals; the remaining
   None branch leaves the state unchanged and is closed by map_le_refl. *)
ultimately show ?case
by (auto
split:if_split prod.split
simp: Let_def lookup_pair_def State_Monad.bind_def State_Monad.run_state_return dest: get_state intro: map_le_refl
)
next
(* Case 4: unfolding update_pair yields three subgoals, one per branch:
   direct update1, direct update2, and move12 followed by update1. *)
case prems: (4 m k v)
let ?m1 = "snd (State_Monad.run_state (update1 (key2 k) v) m)"
let ?m2 = "snd (State_Monad.run_state (update2 (key2 k) v) m)"
from prems have disjoint: "dom (map_of1 m) ∩ dom (map_of2 m) = {}"
by (simp add: inv_pair_domD)
show ?case
apply (auto
intro: map_le_refl dest: get_state
split: prod.split
simp: Let_def update_pair_def State_Monad.bind_def State_Monad.run_state_return
)
proof goal_cases
(* Subgoal 1: update1 is applied directly. m' is the state after get_k1,
   which equals m by get_state. *)
case (1 m')
then have "m' = m"
by (rule get_state)
(* View 1 gains at most k ↦ v; injectivity of the key split is used to
   identify the full key k from its two parts. *)
from 1 prems have "map_of1 ?m1 ⊆⇩m (map_of1 m)(k ↦ v)"
by (smt inv_pair_P_D map_le_def map_of1_def surjective_pairing domIff
fst_conv fun_upd_apply injective update_correct update_keys
)
(* View 2 does not grow. *)
moreover from prems have "map_of2 ?m1 ⊆⇩m map_of2 m"
by (smt domIff inv_pair_P_D update_correct update_keys map_le_def map_of2_def surjective_pairing)
(* Disjointness and k ∉ dom (map_of2 m) are the side conditions of
   map_add_mono and map_add_upd2 used in the final step. *)
moreover from prems have "dom (map_of1 ?m1) ∩ dom (map_of2 m) = {}"
by (smt inv_pair_P_D[OF ‹inv_pair m›] domIff Int_emptyI eq_snd_iff inv_pair_neq
map_of1_def map_of2_def update_keys(1)
)
moreover from 1 prems have "k ∉ dom (map_of2 m)"
using inv_pair_neq map_of2_def by fastforce
moreover from 1 prems have "inv_pair ?m1"
using inv_pair_update1 fst_conv surjective_pairing by metis
ultimately show "pair.map_of (snd (State_Monad.run_state (update1 (key2 k) v) m')) ⊆⇩m (pair.map_of m)(k ↦ v)"
unfolding ‹m' = m› using disjoint
apply (subst map_of_eq_pair[symmetric])
defer
apply (subst map_of_eq_pair[symmetric], rule prems)
apply (subst map_add_upd2[symmetric])
by (auto intro: map_add_mono)
next
(* Subgoal 2: update2 is applied directly. m' and m'' are the states after
   get_k1 and get_k2; both equal m by get_state. *)
case (2 k1 m' m'')
then have "m' = m" "m'' = m"
by (auto dest: get_state)
(* View 2 gains at most k ↦ v. *)
from 2 prems have "map_of2 ?m2 ⊆⇩m (map_of2 m)(k ↦ v)"
unfolding ‹m' = m› ‹m'' = m›
by (smt inv_pair_P_D map_le_def map_of2_def surjective_pairing domIff
fst_conv fun_upd_apply injective update_correct update_keys
)
(* View 1 does not grow. *)
moreover from prems have "map_of1 ?m2 ⊆⇩m map_of1 m"
by (smt domIff inv_pair_P_D update_correct update_keys map_le_def map_of1_def surjective_pairing)
moreover from 2 have "dom (map_of1 ?m2) ∩ dom ((map_of2 m)(k ↦ v)) = {}"
unfolding ‹m' = m›
by (smt domIff ‹map_of1 ?m2 ⊆⇩m map_of1 m› disjoint_iff_not_equal fst_conv fun_upd_apply
map_le_def map_of1_def map_of2_def
)
moreover from 2 prems have "inv_pair ?m2"
unfolding ‹m' = m›
using inv_pair_update2 fst_conv surjective_pairing by metis
(* Here the update sits on the right summand, so map_add_upd is used to pull it out. *)
ultimately show "pair.map_of (snd (State_Monad.run_state (update2 (key2 k) v) m'')) ⊆⇩m (pair.map_of m)(k ↦ v)"
unfolding ‹m' = m› ‹m'' = m›
apply (subst map_of_eq_pair[symmetric])
defer
apply (subst map_of_eq_pair[symmetric], rule prems)
apply (subst map_add_upd[symmetric])
by (rule map_add_mono)
next
(* Subgoal 3: move12 followed by update1. m1 and m2 are the states after
   get_k1 and get_k2 (both equal m); m3 is the state after move12. *)
case (3 k1 m1 k2 m2 m3)
then have "m1 = m" "m2 = m"
by (auto dest: get_state)
let ?m3 = "snd (State_Monad.run_state (update1 (key2 k) v) m3)"
(* The new view 1 is bounded by (map_of2 m)(k ↦ v). *)
from 3 prems have "map_of1 ?m3 ⊆⇩m (map_of2 m)(k ↦ v)"
unfolding ‹m2 = m›
by (smt inv_pair_P_D map_le_def map_of1_def surjective_pairing domIff
fst_conv fun_upd_apply injective
inv_pair_move12 move12_correct move12_keys update_correct update_keys
)
(* The new view 2 is bounded by the old view 1: chain update_correct(3)
   (update1 does not grow component 2) with move12_correct(2). *)
moreover have "map_of2 ?m3 ⊆⇩m map_of1 m"
proof -
from prems 3 have "P m" "P m3"
unfolding ‹m1 = m› ‹m2 = m›
using inv_pair_P_D[OF prems] by (auto elim: lift_p_P[OF move12_inv])
from 3(3)[unfolded ‹m2 = m›] have "mem2.map_of ?m3 ⊆⇩m mem1.map_of m"
by - (erule map_le_trans[OF update_correct(3)[OF ‹P m3›] move12_correct(2)[OF ‹P m›]])
with 3 prems show ?thesis
unfolding ‹m1 = m› ‹m2 = m› map_le_def map_of2_def
apply auto
apply (frule move12_keys(2), simp)
by (metis
domI inv_pair_def map_of1_def surjective_pairing
inv_pair_move12 move12_keys(2) update_keys(2)
)
qed
moreover from prems 3 have "dom (map_of1 ?m3) ∩ dom (map_of1 m) = {}"
unfolding ‹m1 = m› ‹m2 = m›
by (smt inv_pair_P_D disjoint_iff_not_equal map_of1_def surjective_pairing domIff
fst_conv inv_pair_move12 move12_keys update_keys
)
moreover from 3 have "k ∉ dom (map_of1 m)"
by (simp add: domIff map_of1_def)
moreover from 3 prems have "inv_pair ?m3"
unfolding ‹m2 = m›
by (metis inv_pair_move12 inv_pair_update1 move12_keys(1) fst_conv surjective_pairing)
(* Because the bounds above relate new view 1 to old view 2 and new view 2 to
   old view 1, the summands of the old map are swapped with map_add_comm
   before map_add_upd2 and map_add_mono are applied. *)
ultimately show ?case
unfolding ‹m1 = m› ‹m2 = m› using disjoint
apply (subst map_of_eq_pair[symmetric])
defer
apply (subst map_of_eq_pair[symmetric])
apply (rule prems)
apply (subst (2) map_add_comm)
defer
apply (subst map_add_upd2[symmetric])
apply (auto intro: map_add_mono)
done
qed
qed
(* If the invariant holds and both component maps are empty (below Map.empty),
   then the map of the combined memory is empty as well. Proved by rewriting
   pair.map_of into map_of1 ++ map_of2 via map_of_eq_pair. *)
lemma emptyI:
assumes "inv_pair m" "mem1.map_of m ⊆⇩m Map.empty" "mem2.map_of m ⊆⇩m Map.empty"
shows "pair.map_of m ⊆⇩m Map.empty"
using assms by (auto simp: map_of1_def map_of2_def map_le_def map_of_eq_pair[symmetric])
end