(* Theory State_Transformer *)
section ‹State monad transformer›
theory State_Transformer
imports Monad_Zero_Plus
begin
text ‹
This version uses a non-lifted product and a non-lifted function space.
›
(* The state monad transformer.  A value of type 'a⋅('f, 's) stateT wraps a
   continuous function from a state of type 's to an 'f-computation that
   yields a pair of a result of type 'a and a new state.  StateT is the
   constructor and runStateT the selector.  Both the pair type (×) and the
   function space (→) are the non-lifted versions. *)
tycondef 'a⋅('f::"functor", 's) stateT =
StateT (runStateT :: "'s → ('a × 's)⋅'f")
(* coerce commutes with the representation-level abstraction function
   stateT_abs.  Proved by unfolding stateT_abs_def and coerce_def, then
   simplifying with the emb/prj round-trip lemmas and DEFL_eq_stateT. *)
lemma coerce_stateT_abs [simp]: "coerce⋅(stateT_abs⋅x) = stateT_abs⋅(coerce⋅x)"
apply (simp add: stateT_abs_def coerce_def)
apply (simp add: emb_prj_emb prj_emb_prj DEFL_eq_stateT)
done
(* coerce commutes with the StateT constructor.  After unfolding StateT_def,
   simp closes the goal (presumably via the [simp] rule coerce_stateT_abs). *)
lemma coerce_StateT [simp]: "coerce⋅(StateT⋅k) = StateT⋅(coerce⋅k)"
unfolding StateT_def by simp
(* Case analysis: every stateT value y has the form StateT⋅k.  The witness
   is k = runStateT⋅y. *)
lemma stateT_cases [case_names StateT]:
obtains k where "y = StateT⋅k"
proof
show "y = StateT⋅(runStateT⋅y)"
by (cases y, simp_all)
qed
(* "Induction" rule (really exhaustive case analysis): a property that holds
   for every StateT⋅k holds for every stateT value y.  Derived directly
   from stateT_cases. *)
lemma stateT_induct [case_names StateT]:
fixes P :: "'a⋅('f::functor,'s) stateT ⇒ bool"
assumes "⋀k. P (StateT⋅k)"
shows "P y"
by (cases y rule: stateT_cases, simp add: assms)
(* Extensionality: two stateT values are equal whenever runStateT returns the
   same result for every state s.  Both sides are put into StateT form, and
   equality of the wrapped functions is reduced to pointwise equality with
   cfun_eq_iff. *)
lemma stateT_eqI:
"(⋀s. runStateT⋅a⋅s = runStateT⋅b⋅s) ⟹ a = b"
apply (cases a rule: stateT_cases)
apply (cases b rule: stateT_cases)
apply (simp add: cfun_eq_iff)
done
(* Running a coerced stateT value gives the coerced result of running the
   original value. *)
lemma runStateT_coerce [simp]:
"runStateT⋅(coerce⋅k)⋅s = coerce⋅(runStateT⋅k⋅s)"
by (induct k rule: stateT_induct, simp)
subsection ‹Functor class instance›
(* fmapU on a StateT value maps f over the first (result) component of each
   pair produced by the inner functor and leaves the state component
   unchanged.  The unfolded definition involves a fixed point, so one
   unfolding step is taken with fix_eq before simplification. *)
lemma fmapU_StateT [simp]:
"fmapU⋅f⋅(StateT⋅k) =
StateT⋅(Λ s. fmap⋅(Λ(x, s'). (f⋅x, s'))⋅(k⋅s))"
unfolding fmapU_stateT_def stateT_map_def StateT_def
by (subst fix_eq, simp add: cfun_map_def csplit_def prod_map_def)
(* runStateT form of fmapU_StateT, valid for an arbitrary stateT value m. *)
lemma runStateT_fmapU [simp]:
"runStateT⋅(fmapU⋅f⋅m)⋅s =
fmap⋅(Λ(x, s'). (f⋅x, s'))⋅(runStateT⋅m⋅s)"
by (cases m rule: stateT_cases, simp)
(* stateT is a functor whenever the inner type constructor 'f is a functor.
   The state type only needs to be a domain.  induct_tac splits the stateT
   argument xs of the first class obligation into StateT form.  simp_all
   then closes the remaining goals using the inner functor's fmap_fmap
   law. *)
instantiation stateT :: ("functor", "domain") "functor"
begin
instance
apply standard
apply (induct_tac xs rule: stateT_induct)
apply (simp_all add: fmap_fmap ID_def csplit_def)
done
end
subsection ‹Monad class instance›
(* stateT is a monad whenever the inner type constructor is a monad. *)
instantiation stateT :: (monad, "domain") monad
begin
(* returnU pairs the value with the unchanged state and returns the pair in
   the inner monad. *)
definition returnU_stateT_def:
"returnU = (Λ x. StateT⋅(Λ s. return⋅(x, s)))"
(* bindU runs m from state s.  It then runs k applied to the result value x,
   starting from the resulting state s'. *)
definition bindU_stateT_def:
"bindU = (Λ m k. StateT⋅(Λ s. runStateT⋅m⋅s ⤜ (Λ (x, s'). runStateT⋅(k⋅x)⋅s')))"
(* bindU on an explicit StateT value. *)
lemma bindU_stateT_StateT [simp]:
"bindU⋅(StateT⋅f)⋅k =
StateT⋅(Λ s. f⋅s ⤜ (Λ (x, s'). runStateT⋅(k⋅x)⋅s'))"
unfolding bindU_stateT_def by simp
(* runStateT form of bindU, for an arbitrary stateT value m. *)
lemma runStateT_bindU [simp]:
"runStateT⋅(bindU⋅m⋅k)⋅s = runStateT⋅m⋅s ⤜ (Λ (x, s'). runStateT⋅(k⋅x)⋅s')"
unfolding bindU_stateT_def by simp
(* Each class obligation is proved by extensionality (stateT_eqI) and
   simplification with the corresponding law of the inner monad. *)
instance proof
(* fmapU coincides with binding and then returning the mapped value. *)
fix f :: "udom → udom" and r :: "udom⋅('a,'b) stateT"
show "fmapU⋅f⋅r = bindU⋅r⋅(Λ x. returnU⋅(f⋅x))"
by (rule stateT_eqI)
(simp add: returnU_stateT_def monad_fmap prod_map_def csplit_def)
next
(* Left unit: binding returnU⋅x into f is just f⋅x. *)
fix f :: "udom → udom⋅('a,'b) stateT" and x :: "udom"
show "bindU⋅(returnU⋅x)⋅f = f⋅x"
by (rule stateT_eqI)
(simp add: returnU_stateT_def eta_cfun)
next
(* Associativity of bindU. *)
fix r :: "udom⋅('a,'b) stateT" and f g :: "udom → udom⋅('a,'b) stateT"
show "bindU⋅(bindU⋅r⋅f)⋅g = bindU⋅r⋅(Λ x. bindU⋅(f⋅x)⋅g)"
by (rule stateT_eqI)
(simp add: bind_bind csplit_def)
qed
end
subsection ‹Monad zero instance›
(* stateT has a zero whenever the inner monad has one. *)
instantiation stateT :: (monad_zero, "domain") monad_zero
begin
(* zeroU ignores the incoming state and yields the inner monad's mzero. *)
definition zeroU_stateT_def:
"zeroU = StateT⋅(Λ s. mzero)"
lemma runStateT_zeroU [simp]:
"runStateT⋅zeroU⋅s = mzero"
unfolding zeroU_stateT_def by simp
instance proof
(* Obligation: zeroU is a left zero for bindU. *)
fix k :: "udom → udom⋅('a,'b) stateT"
show "bindU⋅zeroU⋅k = zeroU"
by (rule stateT_eqI, simp add: bind_mzero)
qed
end
subsection ‹Monad plus instance›
(* stateT has a plus operation whenever the inner monad has one. *)
instantiation stateT :: (monad_plus, "domain") monad_plus
begin
(* plusU runs both computations from the same input state s.  It then
   combines their inner-monad results with mplus. *)
definition plusU_stateT_def:
"plusU = (Λ a b. StateT⋅(Λ s. mplus⋅(runStateT⋅a⋅s)⋅(runStateT⋅b⋅s)))"
lemma runStateT_plusU [simp]:
"runStateT⋅(plusU⋅a⋅b)⋅s =
mplus⋅(runStateT⋅a⋅s)⋅(runStateT⋅b⋅s)"
unfolding plusU_stateT_def by simp
instance proof
(* Binding a sum distributes over its two summands. *)
fix a b :: "udom⋅('a, 'b) stateT" and k :: "udom → udom⋅('a, 'b) stateT"
show "bindU⋅(plusU⋅a⋅b)⋅k = plusU⋅(bindU⋅a⋅k)⋅(bindU⋅b⋅k)"
by (rule stateT_eqI, simp add: bind_mplus)
next
(* plusU is associative. *)
fix a b c :: "udom⋅('a, 'b) stateT"
show "plusU⋅(plusU⋅a⋅b)⋅c = plusU⋅a⋅(plusU⋅b⋅c)"
by (rule stateT_eqI, simp add: mplus_assoc)
qed
end
subsection ‹Monad zero plus instance›
(* zeroU is a left and right identity for plusU.  Both laws follow from the
   inner monad's mplus_mzero_left and mplus_mzero_right. *)
instance stateT :: (monad_zero_plus, "domain") monad_zero_plus
proof
fix m :: "udom⋅('a, 'b) stateT"
show "plusU⋅zeroU⋅m = m"
by (rule stateT_eqI, simp add: mplus_mzero_left)
next
fix m :: "udom⋅('a, 'b) stateT"
show "plusU⋅m⋅zeroU = m"
by (rule stateT_eqI, simp add: mplus_mzero_right)
qed
subsection ‹Transfer properties to polymorphic versions›
(* Push coerce through csplit into the body of the split function.  It is
   tagged coerce_simp so that later proofs can use it together with the
   other coerce_simp rules. *)
lemma coerce_csplit [coerce_simp]:
shows "coerce⋅(csplit⋅f⋅p) = csplit⋅(Λ x y. coerce⋅(f⋅x⋅y))⋅p"
unfolding csplit_def by simp
(* Coercing the pair argument of csplit is the same as coercing each
   component separately before passing it to f. *)
lemma csplit_coerce [coerce_simp]:
fixes p :: "'a × 'b"
shows "csplit⋅f⋅(COERCE('a × 'b, 'c × 'd)⋅p) =
csplit⋅(Λ x y. f⋅(COERCE('a, 'c)⋅x)⋅(COERCE('b, 'd)⋅y))⋅p"
unfolding coerce_prod csplit_def prod_map_def by simp
(* The polymorphic fmap on a StateT value maps f over the result component
   and keeps the state.  It is obtained from fmapU by unfolding fmap_def at
   stateT and simplifying with the coerce_simp rules. *)
lemma fmap_stateT_simps [simp]:
"fmap⋅f⋅(StateT⋅m :: 'a⋅('f::functor,'s) stateT) =
StateT⋅(Λ s. fmap⋅(Λ (x, s'). (f⋅x, s'))⋅(m⋅s))"
unfolding fmap_def [where 'f="('f, 's) stateT"]
by (simp add: coerce_simp eta_cfun)
(* runStateT form of fmap_stateT_simps, for an arbitrary stateT value m. *)
lemma runStateT_fmap [simp]:
"runStateT⋅(fmap⋅f⋅m)⋅s = fmap⋅(Λ (x, s'). (f⋅x, s'))⋅(runStateT⋅m⋅s)"
by (induct m rule: stateT_induct, simp)
(* Characterization of the polymorphic return at stateT.  It pairs the value
   with the unchanged state and returns the pair in the inner monad, just
   like returnU. *)
lemma return_stateT_def:
"(return :: _ → 'a⋅('m::monad, 's) stateT) =
(Λ x. StateT⋅(Λ s. return⋅(x, s)))"
unfolding return_def [where 'm="('m, 's) stateT"] returnU_stateT_def
by (simp add: coerce_simp)
(* Characterization of the polymorphic bind at stateT.  It runs m from state
   s, then runs k applied to the result value x from the resulting state
   s', just like bindU. *)
lemma bind_stateT_def:
"bind = (Λ m k. StateT⋅(Λ s. runStateT⋅m⋅s ⤜ (Λ (x, s'). runStateT⋅(k⋅x)⋅s')))"
(* Rewrite bind to its bindU-based definition, then unfold bindU at stateT. *)
apply (subst bind_def, subst bindU_stateT_def)
(* Move the coercions inward. *)
apply (simp add: coerce_simp)
(* Remove the remaining nested coercions with coerce_idem.
   domain_defl_simps and monofun_cfun are presumably needed for its DEFL
   side conditions. *)
apply (simp add: coerce_idem domain_defl_simps monofun_cfun)
(* Finish by eta-contraction. *)
apply (simp add: eta_cfun)
done
text "TODO: add ‹coerce_idem› to ‹coerce_simp›, along with monotonicity rules for DEFL."
(* The polymorphic bind applied to an explicit StateT value. *)
lemma bind_stateT_simps [simp]:
"bind⋅(StateT⋅m :: 'a⋅('m::monad,'s) stateT)⋅k =
StateT⋅(Λ s. m⋅s ⤜ (Λ (x, s'). runStateT⋅(k⋅x)⋅s'))"
unfolding bind_stateT_def by simp
(* runStateT form of the polymorphic bind, for an arbitrary stateT value m. *)
lemma runStateT_bind [simp]:
"runStateT⋅(m ⤜ k)⋅s = runStateT⋅m⋅s ⤜ (Λ (x, s'). runStateT⋅(k⋅x)⋅s')"
unfolding bind_stateT_def by simp
end