(* Theory MDP_RP_Certification *)
section ‹Certification of Reachability Problems on MDPs›
theory MDP_RP_Certification
imports
"../MDP_Reachability_Problem"
"HOL-Library.IArray"
"HOL-Library.Code_Target_Numeral"
begin
context Reachability_Problem
begin
(* Upper-bound rule for p, derived from p_ub (a fact of the enclosing locale,
   not shown in this file).  Applying p_ub with assumptions 1 and 4 leaves the
   obligation that x vanishes wherever p does; it is discharged below from
   assumptions 2 and 3 together with the locale facts p_pos and p_S2.

   Assumptions:
     1: s is a state, and on S1 the function x dominates its expectation under
        every distribution D in K s.
     2: from every S1 state with non-zero x, some S2 state is reachable in the
        reflexive-transitive closure of the relation that links an S1 state to
        every element of the support of one of its distributions.
     3: x is 0 outside S1 and S2.
     4: x is 1 on S2. *)
lemma p_ub':
fixes x
assumes 1: "s ∈ S" "⋀s D. s ∈ S1 ⟹ D ∈ K s ⟹ (∑t∈S. pmf D t * x t) ≤ x s"
assumes 2: "⋀s. s ∈ S1 ⟹ x s ≠ 0 ⟹ (∃t∈S2. (s, t) ∈ (SIGMA s:S1. ⋃D∈K s. set_pmf D)⇧*)"
assumes 3: "⋀s. s ∈ S - S1 - S2 ⟹ x s = 0"
assumes 4: "⋀s. s ∈ S2 ⟹ x s = 1"
shows "enn2real (p s) ≤ x s"
proof (rule p_ub[OF 1 _ 4])
(* Remaining premise of p_ub: p s = 0 implies x s = 0. *)
fix s assume "s ∈ S" "p s = 0" with 2[of s] p_pos[of s] p_S2[of s] 3[of s] show "x s = 0"
by (cases "x s = 0") auto
qed
(* Lower-bound rule for n, derived from n_lb (a fact of the enclosing locale,
   not shown in this file).  Applying n_lb with assumptions 1 and 4 leaves the
   obligation that x vanishes wherever n does; it is discharged below by
   contradiction using the locale fact n_pos.

   Assumptions:
     wf R: R is a well-founded relation.
     1: s is a state, and on S1 the function x is dominated by its expectation
        under every distribution D in K s.
     2: for every S1 state with non-zero x and every D in K s, the support of D
        contains either an R-smaller S1 state with non-zero x, or an S2 state.
     3: x is 0 outside S1 and S2.
     4: x is 1 on S2. *)
lemma n_lb':
fixes x
assumes "wf R"
assumes 1: "s ∈ S" "⋀s D. s ∈ S1 ⟹ D ∈ K s ⟹ x s ≤ (∑t∈S. pmf D t * x t)"
assumes 2: "⋀s D. s ∈ S1 ⟹ D ∈ K s ⟹ x s ≠ 0 ⟹ ∃t∈D. ((t, s) ∈ R ∧ t ∈ S1 ∧ x t ≠ 0) ∨ t ∈ S2"
assumes 3: "⋀s. s ∈ S - S1 - S2 ⟹ x s = 0"
assumes 4: "⋀s. s ∈ S2 ⟹ x s = 1"
shows "x s ≤ enn2real (n s)"
proof (rule n_lb[OF 1 _ 4])
(* Remaining premise of n_lb: n s = 0 implies x s = 0. *)
fix s assume *: "s ∈ S" "n s = 0"
show "x s = 0"
proof (rule ccontr)
assume "x s ≠ 0"
(* A state with n s = 0 and x s non-zero can only lie in S1. *)
with * n_S2[of s] n_nS12[of s] 3[of s] have "s ∈ S1"
by (metis DiffI zero_neq_one)
(* n_pos, instantiated with the predicate "x is non-zero", gives n s > 0. *)
have "0 < n s"
by (intro n_pos[of "λs. x s ≠ 0", OF ‹x s ≠ 0› ‹s ∈ S1› ‹wf R›])
(metis zero_less_one n_S2 2)
with ‹n s = 0› show False by auto
qed
qed
end
(* Remove the infix !! notation of Stream.snth.  Presumably this is done so
   that !! below unambiguously denotes IArray indexing -- TODO confirm. *)
no_notation Stream.snth (infixl "!!" 100)
subsection ‹Computable representation›
(* Concrete description of a reachability problem on an MDP.
     state_count: number of states; states are the naturals below it
                  (see valid_mdp_rp and the abbreviation S).
     distrs:      per state, a list of distributions; each distribution is a
                  sparse list of (target state, rational weight) pairs.
     states1:     per-state flag, used to define the set S1 further below.
     states2:     per-state flag, used to define the set S2 further below. *)
record mdp_reachability_problem =
state_count :: nat
distrs :: "(nat × rat) list list iarray"
states1 :: "bool iarray"
states2 :: "bool iarray"
(* One half of a certificate.
     solution: one rational value per state.
     witness:  per state, a payload of type 'a paired with a natural number;
               valid_pos_cert and valid_neg_cert require that number to
               strictly decrease along the witnessed successors. *)
record 'a RP_sub_cert =
solution :: "rat iarray"
witness :: "('a × nat) iarray"
(* Full certificate: a sub-certificate checked by valid_pos_cert, whose witness
   payload is a pair of naturals, and one checked by valid_neg_cert, whose
   witness payload is a list of naturals. *)
record RP_cert =
pos_cert :: "(nat × nat) RP_sub_cert"
neg_cert :: "nat list RP_sub_cert"
(* Sparse dot product: sums x * y !! n over all pairs (n, x) of the list sx. *)
definition "sparse_mult sx y = sum_list (map (λ(n, x). x * y !! n) sx)"
(* Association-list lookup with default: returns the second component of the
   first pair whose first component equals x, or d if there is none. *)
primrec lookup where
"lookup d [] x = d"
| "lookup d (y#ys) x = (if fst y = x then snd y else lookup d ys x)"
(* lookup agrees with map_of, with the default d substituted for None. *)
lemma lookup_eq_map_of: "lookup d xs x = (case map_of xs x of Some x ⇒ x | None ⇒ d)"
by (induct xs) simp_all
(* With distinct keys, looking up the key of any stored pair returns its value. *)
lemma lookup_in_set:
"distinct (map fst xs) ⟹ x ∈ set xs ⟹ lookup d xs (fst x) = snd x"
unfolding lookup_eq_map_of by (subst map_of_is_SomeI[where y="snd x"]) simp_all
(* Looking up a key that does not occur in the list returns the default d. *)
lemma lookup_not_in_set:
"x ∉ fst ` set xs ⟹ lookup d xs x = d"
unfolding lookup_eq_map_of
by (subst map_of_eq_None_iff[of xs x, THEN iffD2]) auto
(* If every value stored in xs is non-negative, then so is the result of a
   lookup with default 0.  Proof by induction on xs: the empty list is closed
   by simp, the cons step by force. *)
lemma lookup_nonneg:
"(⋀x v. (x, v) ∈ set xs ⟹ 0 ≤ v) ⟹ (0::'a::ordered_comm_monoid_add) ≤ lookup 0 xs x"
by (induction xs) (simp, force)
(* Rewrites the sparse dot product as a dense sum over all indices below M,
   provided every key of xs is below M and the keys are distinct. *)
lemma sparse_mult_eq_sum_lookup:
fixes xs :: "(nat × 'a::comm_semiring_1) list"
assumes "list_all (λ(n, x). n < M) xs" "distinct (map fst xs)"
shows "sparse_mult xs y = (∑i<M. lookup 0 xs i * y !! i)"
proof -
(* Distinct keys make the list itself distinct and fst injective on it. *)
from ‹distinct (map fst xs)› have "distinct xs" "inj_on fst (set xs)"
by (simp_all add: distinct_map)
(* Step 1: list sum to a sum over the set of pairs. *)
then have "sparse_mult xs y = (∑x∈set xs. snd x * y !! fst x)"
by (auto intro!: sum.cong simp add: sparse_mult_def sum_list_distinct_conv_sum_set)
(* Step 2: express each stored value through lookup. *)
also have "… = (∑x∈set xs. lookup 0 xs (fst x) * y !! fst x)"
by (intro sum.cong refl arg_cong2[where f="(*)"]) (simp add: lookup_in_set assms)
(* Step 3: reindex the sum by keys, using injectivity of fst. *)
also have "… = (∑x∈fst ` set xs. lookup 0 xs x * y !! x)"
using ‹inj_on fst (set xs)› by (simp add: sum.reindex)
(* Step 4: extend the index set to all i < M; absent keys contribute 0. *)
also have "… = (∑x<M. lookup 0 xs x * y !! x)"
using assms(1)
by (intro sum.mono_neutral_cong_left)
(auto simp: list_all_iff lookup_eq_map_of map_of_eq_None_iff[THEN iffD2])
finally show ?thesis .
qed
(* Rewrites the sum of all values stored in a sparse list as a dense sum of
   lookups over all indices below M, provided every key of xs is below M and
   the keys are distinct.  The proof mirrors sparse_mult_eq_sum_lookup without
   the multiplication by y. *)
lemma sum_list_eq_sum_lookup:
fixes xs :: "(nat × 'a::comm_semiring_1) list"
assumes "list_all (λ(n, x). n < M) xs" "distinct (map fst xs)"
shows "sum_list (map snd xs) = (∑i<M. lookup 0 xs i)"
proof -
(* Distinct keys make the list itself distinct and fst injective on it. *)
from ‹distinct (map fst xs)› have "distinct xs" "inj_on fst (set xs)"
by (simp_all add: distinct_map)
(* Step 1: list sum to a sum over the set of pairs. *)
then have "sum_list (map snd xs) = (∑x∈set xs. snd x)"
by (auto intro!: sum.cong simp add: sum_list_distinct_conv_sum_set)
(* Step 2: express each stored value through lookup. *)
also have "… = (∑x∈set xs. lookup 0 xs (fst x))"
by (intro sum.cong refl) (simp add: lookup_in_set assms)
(* Step 3: reindex the sum by keys, using injectivity of fst. *)
also have "… = (∑x∈fst ` set xs. lookup 0 xs x)"
using ‹inj_on fst (set xs)› by (simp add: sum.reindex)
(* Step 4: extend the index set to all i < M; absent keys contribute 0. *)
also have "… = (∑x<M. lookup 0 xs x)"
using assms(1)
by (intro sum.mono_neutral_cong_left)
(auto simp: list_all_iff lookup_eq_map_of map_of_eq_None_iff[THEN iffD2])
finally show ?thesis .
qed
(* Well-formedness of a concrete problem description:
     - there is at least one state;
     - distrs, states1 and states2 all have exactly state_count entries;
     - for every state i:
         * i is not flagged in both states1 and states2;
         * every distribution of i has distinct keys, only non-negative
           weights, only targets below state_count, and weights summing to 1;
         * i has at least one distribution. *)
definition
"valid_mdp_rp mdp ⟷
0 < state_count mdp ∧
IArray.length (distrs mdp) = state_count mdp ∧
IArray.length (states1 mdp) = state_count mdp ∧
IArray.length (states2 mdp) = state_count mdp ∧
(∀i<state_count mdp. ¬ (states1 mdp !! i ∧ states2 mdp !! i) ∧
list_all (λds. distinct (map fst ds) ∧ list_all (λ(n, x). 0 ≤ x ∧ n < state_count mdp) ds ∧
sum_list (map snd ds) = 1) (distrs mdp !! i) ∧
¬ List.null (distrs mdp !! i))"
(* Generic validity check for a sub-certificate c, parameterised by
     ord:   the comparison between the expected value under a distribution
            (sparse_mult ds (solution c)) and the solution at the state;
     check: a predicate on the state's distributions and its witness entry.
   Requirements:
     - witness and solution have exactly state_count entries;
     - for every state i:
         * flagged in states2:        solution is 1;
         * else flagged in states1:   solution is non-negative, ord holds for
                                      every distribution of i, and check holds
                                      whenever the solution is positive;
         * otherwise:                 solution is 0. *)
definition
"valid_sub_cert mdp c ord check ⟷
IArray.length (witness c) = state_count mdp ∧
IArray.length (solution c) = state_count mdp ∧
(∀i<state_count mdp.
if states2 mdp !! i then solution c !! i = 1
else if states1 mdp !! i then 0 ≤ solution c !! i ∧
(list_all (λds. ord (sparse_mult ds (solution c)) (solution c !! i)) (distrs mdp !! i)) ∧
(0 < solution c !! i ⟶ check (distrs mdp !! i) (witness c !! i))
else solution c !! i = 0)"
(* Validity of the sub-certificate stored in pos_cert: valid_sub_cert with the
   order "expected value is at most the solution".  The witness entry
   ((j, a), n) of a state must name a state j below state_count whose own
   witness number is smaller than n and whose solution is positive, and an
   index a into the state's distribution list such that distribution a gives
   j a non-zero weight. *)
definition
"valid_pos_cert mdp c ⟷
valid_sub_cert mdp c (≤)
(λD ((j, a), n). j < state_count mdp ∧ snd (witness c !! j) < n ∧ 0 < solution c !! j ∧
a < length D ∧ lookup 0 (D ! a) j ≠ 0)"
(* Validity of the sub-certificate stored in neg_cert: valid_sub_cert with the
   order "expected value is at least the solution".  The witness entry (J, n)
   of a state must list, position by position for each distribution d of the
   state (list_all2, so J and the distribution list have equal length), a
   state j below state_count whose own witness number is smaller than n, to
   which d gives non-zero weight, and whose solution is positive. *)
definition
"valid_neg_cert mdp c ⟷
valid_sub_cert mdp c (≥)
(λD (J, n). list_all2 (λj d. j < state_count mdp ∧ snd (witness c !! j) < n ∧
lookup 0 d j ≠ 0 ∧ 0 < solution c !! j) J D)"
(* A certificate is valid iff both of its sub-certificates are valid. *)
definition
"valid_cert mdp c ⟷ valid_pos_cert mdp (pos_cert c) ∧ valid_neg_cert mdp (neg_cert c)"
(* Destruction rule: the size-related conjuncts of valid_mdp_rp. *)
lemma valid_mdp_rpD_length:
assumes "valid_mdp_rp mdp"
shows "0 < state_count mdp" "IArray.length (distrs mdp) = state_count mdp"
"IArray.length (states1 mdp) = state_count mdp" "IArray.length (states2 mdp) = state_count mdp"
using assms by (auto simp: valid_mdp_rp_def)
(* Destruction rule: the per-state conjuncts of valid_mdp_rp for a state i,
   stated with set membership instead of list_all.  Conclusions, in order:
     (1) i is not flagged in both states1 and states2;
     (2) every target of every distribution is below state_count;
     (3) every weight is non-negative;
     (4) the weights of every distribution sum to 1;
     (5) the keys of every distribution are distinct;
     (6) i has at least one distribution. *)
lemma valid_mdp_rpD:
assumes "valid_mdp_rp mdp" "i < state_count mdp"
shows "¬ (states1 mdp !! i ∧ states2 mdp !! i)"
and "⋀ds n x. ds ∈ set (distrs mdp !! i) ⟹ (n, x) ∈ set ds ⟹ n < state_count mdp"
and "⋀ds n x. ds ∈ set (distrs mdp !! i) ⟹ (n, x) ∈ set ds ⟹ 0 ≤ x"
and "⋀ds. ds ∈ set (distrs mdp !! i) ⟹ sum_list (map snd ds) = 1"
and "⋀ds. ds ∈ set (distrs mdp !! i) ⟹ distinct (map fst ds)"
and "distrs mdp !! i ≠ []"
using assms by (auto simp: valid_mdp_rp_def list_all_iff List.null_def elim!: allE[of _ i])
(* For a distribution ds of a valid problem, the sparse dot product equals the
   dense sum over all states; instance of sparse_mult_eq_sum_lookup with
   M = state_count mdp. *)
lemma valid_mdp_rp_sparse_mult:
assumes "valid_mdp_rp mdp" "i < state_count mdp" "ds ∈ set (distrs mdp !! i)"
shows "sparse_mult ds y = (∑i<state_count mdp. lookup 0 ds i * y !! i)"
using valid_mdp_rpD(2,5)[OF assms] by (intro sparse_mult_eq_sum_lookup) (auto simp: list_all_iff)
(* Destruction rule: the per-state consequences of valid_sub_cert for a state
   i of a valid problem.  Conclusions, in order:
     (1) solution is 0 if i is flagged in neither states1 nor states2;
     (2) solution is 1 if i is flagged in states2;
     (3) solution is non-negative if i is flagged in states1;
     (4) ord holds for every distribution of a states1 state;
     (5) check holds for a states1 state with positive solution.
   NOTE(review): the bound variable ds in conclusion (5) is unused. *)
lemma valid_sub_certD:
assumes "valid_mdp_rp mdp" "valid_sub_cert mdp c ord check" "i < state_count mdp"
shows "¬ states1 mdp !! i ⟹ ¬ states2 mdp !! i ⟹ solution c !! i = 0"
and "states2 mdp !! i ⟹ solution c !! i = 1"
and "states1 mdp !! i ⟹ 0 ≤ solution c !! i"
and "⋀ds. states1 mdp !! i ⟹ ds ∈ set (distrs mdp !! i) ⟹ ord (sparse_mult ds (solution c)) (solution c !! i)"
and "⋀ds. states1 mdp !! i ⟹ 0 < solution c !! i ⟶ check (distrs mdp !! i) (witness c !! i)"
(* The disjointness fact valid_mdp_rpD(1) is needed to resolve the nested
   if-then-else of valid_sub_cert_def for states flagged in states1. *)
using assms(2,3) valid_mdp_rpD(1)[OF assms(1,3)]
by (auto simp add: valid_sub_cert_def list_all_iff)
(* Destruction rule for valid_pos_cert: for a states1 state i with positive
   solution and witness ((j, a), n), the check predicate of valid_pos_cert
   holds, spelled out for the distribution list of i. *)
lemma valid_pos_certD:
assumes "valid_mdp_rp mdp" "valid_pos_cert mdp c" "i < state_count mdp" "states1 mdp !! i"
"0 < solution c !! i" "witness c !! i = ((j, a), n)"
shows "snd (witness c !! j) < n ∧ j < state_count mdp ∧ a < length (distrs mdp !! i) ∧
lookup 0 ((distrs mdp !! i) ! a) j ≠ 0 ∧ 0 < solution c !! j"
using valid_sub_certD(5)[OF assms(1) assms(2)[unfolded valid_pos_cert_def] assms(3,4)] assms(5-) by auto
(* Destruction rule for valid_neg_cert: for a states1 state i with positive
   solution and witness (js, n), the list_all2 check of valid_neg_cert holds
   between js and the distribution list of i. *)
lemma valid_neg_certD:
assumes "valid_mdp_rp mdp" "valid_neg_cert mdp c" "i < state_count mdp" "states1 mdp !! i"
"0 < solution c !! i" "witness c !! i = (js, n)"
shows "list_all2 (λj ds. j < state_count mdp ∧ snd (witness c !! j) < n ∧ lookup 0 ds j ≠ 0 ∧ 0 < solution c !! j) js (distrs mdp !! i)"
using valid_sub_certD(5)[OF assms(1) assms(2)[unfolded valid_neg_cert_def] assms(3)] assms(4-) by auto
context
fixes mdp c
assumes rp: "valid_mdp_rp mdp"
assumes cert: "valid_cert mdp c"
begin
(* Interpret pmf_as_function; presumably this provides the setup used by the
   lift_definition and the transfer proofs below -- TODO confirm. *)
interpretation pmf_as_function .
(* State space: all naturals below state_count. *)
abbreviation "S ≡ {..< state_count mdp}"
(* States below state_count that are flagged in states1. *)
abbreviation "S1 ≡ {i. i < state_count mdp ∧ (states1 mdp) !! i}"
(* States below state_count that are flagged in states2. *)
abbreviation "S2 ≡ {i. i < state_count mdp ∧ (states2 mdp) !! i}"
(* Set of distributions available in each state, as a set of pmfs.  For a
   state i below state_count, each sparse list D in distrs mdp !! i yields the
   mass function mapping j to the weight lookup 0 D j (cast to real).  For any
   other i, the single distribution with all mass on 0 is used.
   The proof discharges the lifting obligations: every such function is
   non-negative and has total mass 1. *)
lift_definition K :: "nat ⇒ nat pmf set" is
"λi. if i < state_count mdp then
{ (λj. of_rat (lookup 0 D j) :: real) | D. D ∈ set (distrs mdp !! i) }
else { indicator {0} }"
proof (auto split: if_split_asm simp del: IArray.sub_def)
fix n D assume n: "n < state_count mdp" and D: "D ∈ set (distrs mdp !! n)"
(* Non-negativity of every looked-up weight, from valid_mdp_rpD(3). *)
from valid_mdp_rpD(3)[OF rp this] show nn: "⋀i. 0 ≤ lookup 0 D i"
by (auto simp add: lookup_eq_map_of split: option.split dest: map_of_SomeD)
(* Total mass 1: restrict the integral to the states below state_count, then
   use sum_list_eq_sum_lookup and the fact that the weights sum to 1. *)
show "(∫⇧+ x. ennreal (real_of_rat (lookup 0 D x)) ∂count_space UNIV) = 1"
using valid_mdp_rpD(2,3,4,5)[OF rp n D]
apply (subst nn_integral_count_space'[of "{..< state_count mdp}"])
apply (auto intro: nn lookup_not_in_set simp: of_rat_sum[symmetric] lookup_nonneg)
apply (subst sum_list_eq_sum_lookup[symmetric])
apply (auto simp: list_all_iff lookup_eq_map_of split: option.split)
done
next
(* The fallback distribution concentrated on 0 has total mass 1. *)
show "(∫⇧+ x. ennreal (indicator {0} x) ∂count_space UNIV) = 1"
by (subst nn_integral_count_space'[of "{0}"]) auto
qed
(* Instantiate the Reachability_Problem locale with the concrete K, S, S1, S2
   derived from mdp; each show below discharges one locale assumption. *)
interpretation MDP: Reachability_Problem K S S1 S2
proof
(* S1 and S2 are disjoint subsets of S. *)
show "S1 ∩ S2 = {}" "S1 ⊆ S" "S2 ⊆ S"
using valid_mdp_rpD(1)[OF rp] by auto
(* S is finite and non-empty, since state_count is positive. *)
show "finite S" "S ≠ {}"
using ‹valid_mdp_rp mdp› by (auto simp add: valid_mdp_rp_def)
(* Every state has at least one distribution. *)
show "⋀s. K s ≠ {}"
using valid_mdp_rpD(6)[OF rp] by transfer simp
(* Every state has finitely many distributions. *)
show "⋀s. finite (K s)"
by transfer simp
(* Distributions of states in S are supported within S. *)
fix s assume "s ∈ S" then show "(⋃D∈K s. set_pmf D) ⊆ S"
using valid_mdp_rpD(2)[OF rp]
by transfer (auto simp: lookup_eq_map_of split: option.splits dest!: map_of_SomeD)
qed
(* Real-valued versions of the locale quantities MDP.p and MDP.n.  Judging by
   the names, these are presumably the maximal and minimal reachability
   probabilities -- confirm against the Reachability_Problem locale. *)
definition "P_max s = enn2real (MDP.p s)"
definition "P_min s = enn2real (MDP.n s)"
(* Soundness of certificates: for every state i, the solution of pos_cert is
   an upper bound on P_max i and the solution of neg_cert is a lower bound on
   P_min i.  The two parts apply MDP.p_ub' and MDP.n_lb' respectively, with x
   instantiated to the corresponding certificate solution. *)
lemma
assumes "i < state_count mdp"
shows P_max: "P_max i ≤ real_of_rat (solution (pos_cert c) !! i)" (is ?max)
and P_min: "P_min i ≥ real_of_rat (solution (neg_cert c) !! i)" (is ?min)
proof -
have "valid_pos_cert mdp (pos_cert c)" "valid_neg_cert mdp (neg_cert c)"
using ‹valid_cert mdp c› by (auto simp: valid_cert_def)
(* pos and neg are the two facts in valid_sub_cert form, as required by the
   valid_sub_certD rules. *)
note pos = this(1)[unfolded valid_pos_cert_def] and neg = this(2)[unfolded valid_neg_cert_def]
(* ---- Part 1: upper bound, via MDP.p_ub' ---- *)
let ?x = "λs. real_of_rat (solution (pos_cert c) !! s)"
have "enn2real (MDP.p i) ≤ ?x i"
proof (rule MDP.p_ub')
show "i ∈ S" using assms by simp
next
(* Assumption 1 of p_ub': the expectation under every distribution is at most
   the solution.  D is identified with the j-th sparse list of state s, and
   the sparse product is turned into the dense sum. *)
fix s D assume "s ∈ S1" "D ∈ K s"
then obtain j where j: "j < length (distrs mdp !! s)"
"⋀i. i < state_count mdp ⟹ pmf D i = real_of_rat (lookup 0 (distrs mdp !! s ! j) i)"
by transfer (auto simp: in_set_conv_nth)
with valid_sub_certD(4)[OF ‹valid_mdp_rp mdp› pos, of s "distrs mdp !! s ! j"] ‹s ∈ S1›
valid_mdp_rp_sparse_mult[OF ‹valid_mdp_rp mdp›, of s "distrs mdp !! s ! j" "solution (pos_cert c)"]
show "(∑t∈S. pmf D t * ?x t) ≤ ?x s"
by (simp add: of_rat_mult[symmetric] of_rat_sum[symmetric] of_rat_less_eq j)
next
(* Assumption 4 of p_ub': the solution is 1 on S2. *)
fix s a assume "s ∈ S2" then show "?x s = 1"
using valid_sub_certD[OF ‹valid_mdp_rp mdp› pos] by simp
next
(* Assumption 2 of p_ub': from every S1 state with non-zero solution an S2
   state is reachable.  X abbreviates the edge relation. *)
fix s define X where "X = (SIGMA s:S1. ⋃D∈K s. set_pmf D)"
assume "s ∈ S1" "?x s ≠ 0"
with valid_sub_certD(3)[OF rp pos, of s]
have "0 < ?x s"
by simp
with ‹s∈S1› show "∃t∈S2. (s, t) ∈ X⇧*"
(* Induction on the natural number stored in the witness of s; the witness
   names a successor t with a strictly smaller number. *)
proof (induction n≡"snd (witness (pos_cert c) !! s)" arbitrary: s rule: less_induct)
case (less s)
obtain t a n where eq: "witness (pos_cert c) !! s = ((t, a), n)"
by (metis prod.exhaust)
from valid_pos_certD[OF rp ‹valid_pos_cert mdp (pos_cert c)› _ _ _ this] less.prems
have ord: "snd (witness (pos_cert c) !! t) < snd (witness (pos_cert c) !! s)"
and t: "lookup 0 (distrs mdp !! s ! a) t ≠ 0" "0 < ?x t" "t∈S" "a < length (distrs mdp !! s)"
unfolding eq by auto
(* The witnessed successor t is an X-successor of s: distribution a of s gives
   it non-zero weight. *)
with ‹s∈S1› have X: "(s, t) ∈ X"
unfolding X_def
by (transfer fixing: s t a c)
(auto simp: X_def in_set_conv_nth
intro!: exI[of _ "λj. real_of_rat (lookup 0 (distrs mdp !! s ! a) j)"]
exI[of _ "distrs mdp !! s ! a"] exI[of _ a])
show ?case
proof cases
(* t is again in S1: continue by the induction hypothesis. *)
assume "t ∈ S1"
with less.hyps[OF ord _ ‹0 < ?x t›] X show ?thesis
by auto
next
(* t is not in S1 but has positive solution, so it must lie in S2. *)
assume "t ∉ S1"
with valid_sub_certD[OF ‹valid_mdp_rp mdp› pos, of t] ‹0 < ?x t› ‹t∈S›
have "t ∈ S2"
by auto
with X show ?thesis
by auto
qed
qed
next
(* Assumption 3 of p_ub': the solution is 0 outside S1 and S2. *)
fix s assume "s ∈ S - S1 - S2" then show "?x s = 0"
using valid_sub_certD(1)[OF ‹valid_mdp_rp mdp› pos, of s] by simp
qed
then show ?max
by (simp add: P_max_def)
(* ---- Part 2: lower bound, via MDP.n_lb' ---- *)
let ?x = "λs. real_of_rat (solution (neg_cert c) !! s)"
have "?x i ≤ enn2real (MDP.n i)"
proof (rule MDP.n_lb')
show "i ∈ S" using assms by simp
next
(* Assumption 1 of n_lb': the solution is at most the expectation under every
   distribution; same argument as in part 1 with the order reversed. *)
fix s D assume "s ∈ S1" "D ∈ K s"
then obtain j where j: "j < length (distrs mdp !! s)"
"⋀i. i < state_count mdp ⟹ pmf D i = real_of_rat (lookup 0 (distrs mdp !! s ! j) i)"
by transfer (auto simp: in_set_conv_nth)
with valid_sub_certD(4)[OF ‹valid_mdp_rp mdp› neg, of s "distrs mdp !! s ! j"] ‹s ∈ S1›
valid_mdp_rp_sparse_mult[OF ‹valid_mdp_rp mdp›, of s "distrs mdp !! s ! j" "solution (neg_cert c)"]
show "?x s ≤ (∑t∈S. pmf D t * ?x t)"
by (simp add: of_rat_mult[symmetric] of_rat_sum[symmetric] of_rat_less_eq j)
next
(* Assumption 4 of n_lb': the solution is 1 on S2. *)
fix s a assume "s ∈ S2" then show "?x s = 1"
using valid_sub_certD[OF ‹valid_mdp_rp mdp› neg] by simp
next
(* The well-founded relation for n_lb': the converse of the relation on S
   that links s to t when the witness number of t is smaller than that of s.
   Well-foundedness follows from finiteness of S and acyclicity, the latter
   via the witness number as an order function. *)
show "wf ((S × S ∩ {(s, t). snd (witness (neg_cert c) !! t) < snd (witness (neg_cert c) !! s)})¯)" (is "wf ?F")
using MDP.S_finite
by (intro finite_acyclic_wf_converse acyclicI_order[where f="λs. snd (witness (neg_cert c) !! s)"]) auto
(* Assumption 2 of n_lb': every distribution D of an S1 state with non-zero
   solution has in its support either an ?F-smaller S1 state with non-zero
   solution or an S2 state. *)
fix s D assume 2: "s ∈ S1" "D ∈ K s" and "?x s ≠ 0"
then have "0 < ?x s"
using valid_sub_certD(3)[OF ‹valid_mdp_rp mdp› neg, of s] by auto
(* Identify D with the a-th sparse list of state s. *)
from 2 obtain a where a: "a < length (distrs mdp !! s)"
"⋀i. i < state_count mdp ⟹ pmf D i = real_of_rat (lookup 0 (distrs mdp !! s ! a) i)"
by transfer (auto simp: in_set_conv_nth)
obtain js n where eq: "witness (neg_cert c) !! s = (js, n)"
by (metis prod.exhaust)
(* The a-th entry of the witness list is the successor chosen for
   distribution a. *)
from valid_neg_certD[OF ‹valid_mdp_rp mdp› ‹valid_neg_cert mdp (neg_cert c)› _ _ _ eq] a ‹s ∈ S1› ‹0 < ?x s›
have *: "length js = length (distrs mdp !! s)" "js ! a ∈ S"
"snd (witness (neg_cert c) !! (js ! a)) < snd (witness (neg_cert c) !! s)"
"lookup 0 (distrs mdp !! s ! a) (js ! a) ≠ 0"
"0 < ?x (js ! a)"
unfolding eq by (auto dest: list_all2_nthD2 list_all2_lengthD)
with a ‹s ∈ S1› have js_a: "js ! a ∈ D" "(js ! a, s) ∈ ?F"
by (auto simp: set_pmf_iff)
show "∃t∈D. (t, s) ∈ ?F ∧ t ∈ S1 ∧ ?x t ≠ 0 ∨ t ∈ S2"
proof cases
(* The chosen successor is in S1: first disjunct. *)
assume "js ! a ∈ S1" with js_a ‹0 < ?x (js ! a)› show ?thesis by auto
next
(* Otherwise it has positive solution outside S1, hence lies in S2. *)
assume "js ! a ∉ S1"
with ‹0 < ?x (js ! a)› ‹js!a ∈ S› valid_sub_certD[OF rp neg, of "js ! a"]
have "js ! a ∈ S2"
by (auto simp: less_le)
with ‹js ! a ∈ D› show ?thesis
by auto
qed
next
(* Assumption 3 of n_lb': the solution is 0 outside S1 and S2. *)
fix s assume "s ∈ S - S1 - S2" then show "?x s = 0"
using valid_sub_certD(1)[OF ‹valid_mdp_rp mdp› neg, of s] by simp
qed
then show ?min
by (simp add: P_min_def)
qed
end
end