(* Theory State_Transformer *)

section ‹State monad transformer›

theory State_Transformer
imports Monad_Zero_Plus
begin

text ‹
  This version uses a non-lifted product and a non-lifted function space.
›

(* State monad transformer over an underlying functor 'f with state type 's.
   A value wraps a continuous function from a state to an 'f-computation of
   (result, new state) pairs.  StateT is the constructor, runStateT the
   selector.  Later proofs in this theory rely on facts generated by this
   declaration (stateT_abs_def, DEFL_eq_stateT, fmapU_stateT_def,
   stateT_map_def). *)
tycondef 'a⋅('f::"functor", 's) stateT =
  StateT (runStateT :: "'s → ('a × 's)⋅'f")

(* coerce commutes with the generated abstraction function stateT_abs. *)
lemma coerce_stateT_abs [simp]: "coerce⋅(stateT_abs⋅x) = stateT_abs⋅(coerce⋅x)"
apply (simp add: stateT_abs_def coerce_def)
apply (simp add: emb_prj_emb prj_emb_prj DEFL_eq_stateT)
done

(* coerce pushes through the StateT constructor onto the wrapped function. *)
lemma coerce_StateT [simp]: "coerce⋅(StateT⋅k) = StateT⋅(coerce⋅k)"
unfolding StateT_def by simp

(* Case analysis: every stateT value y is StateT applied to some k;
   the witness is runStateT⋅y. *)
lemma stateT_cases [case_names StateT]:
  obtains k where "y = StateT⋅k"
proof
  show "y = StateT⋅(runStateT⋅y)"
    by (cases y, simp_all)
qed

(* Induction principle: a predicate that holds for every StateT⋅k holds for
   every stateT value.  Follows directly from stateT_cases. *)
lemma stateT_induct [case_names StateT]:
  fixes P :: "'a⋅('f::functor,'s) stateT ⇒ bool"
  assumes "⋀k. P (StateT⋅k)"
  shows "P y"
by (cases y rule: stateT_cases, simp add: assms)

(* Extensionality: two stateT values are equal if running them from every
   state s gives the same result.  Used to discharge the class laws below. *)
lemma stateT_eqI:
  "(⋀s. runStateT⋅a⋅s = runStateT⋅b⋅s) ⟹ a = b"
apply (cases a rule: stateT_cases)
apply (cases b rule: stateT_cases)
apply (simp add: cfun_eq_iff)
done

(* Running a coerced stateT value equals coercing the result of running it. *)
lemma runStateT_coerce [simp]:
  "runStateT⋅(coerce⋅k)⋅s = coerce⋅(runStateT⋅k⋅s)"
by (induct k rule: stateT_induct, simp)

subsection ‹Functor class instance›

(* Unfolding of fmapU on a StateT value: f is applied to the result component
   of each (result, state) pair via the underlying fmap; the state component
   is left unchanged. *)
lemma fmapU_StateT [simp]:
  "fmapU⋅f⋅(StateT⋅k) =
    StateT⋅(Λ s. fmap⋅(Λ(x, s'). (f⋅x, s'))⋅(k⋅s))"
unfolding fmapU_stateT_def stateT_map_def StateT_def
by (subst fix_eq, simp add: cfun_map_def csplit_def prod_map_def)

(* The same fmapU equation, stated through the runStateT selector. *)
lemma runStateT_fmapU [simp]:
  "runStateT⋅(fmapU⋅f⋅m)⋅s =
    fmap⋅(Λ(x, s'). (f⋅x, s'))⋅(runStateT⋅m⋅s)"
by (cases m rule: stateT_cases, simp)

instantiation stateT :: ("functor", "domain") "functor"
begin

(* Functor instance for stateT.  No fmapU definition appears in this file;
   it is presumably generated by the tycondef declaration (fmapU_stateT_def
   is unfolded in fmapU_StateT) -- confirm against tycondef.
   The remaining class obligation(s) are discharged by induction on the
   stateT value xs (stateT_induct), followed by simplification with the
   underlying functor's fmap_fmap. *)
instance
apply standard
apply (induct_tac xs rule: stateT_induct)
apply (simp_all add: fmap_fmap ID_def csplit_def)
done

end

subsection ‹Monad class instance›

instantiation stateT :: (monad, "domain") monad
begin

(* returnU pairs the value with the unchanged current state and returns the
   pair in the underlying monad. *)
definition returnU_stateT_def:
  "returnU = (Λ x. StateT⋅(Λ s. return⋅(x, s)))"

(* bindU runs m from state s, then feeds the resulting value x and new
   state s' to the continuation k. *)
definition bindU_stateT_def:
  "bindU = (Λ m k. StateT⋅(Λ s. runStateT⋅m⋅s ⤜ (Λ (x, s'). runStateT⋅(k⋅x)⋅s')))"

lemma bindU_stateT_StateT [simp]:
  "bindU⋅(StateT⋅f)⋅k =
    StateT⋅(Λ s. f⋅s ⤜ (Λ (x, s'). runStateT⋅(k⋅x)⋅s'))"
unfolding bindU_stateT_def by simp

lemma runStateT_bindU [simp]:
  "runStateT⋅(bindU⋅m⋅k)⋅s = runStateT⋅m⋅s ⤜ (Λ (x, s'). runStateT⋅(k⋅x)⋅s')"
unfolding bindU_stateT_def by simp

(* The three monad obligations (fmapU via bindU/returnU, left unit,
   associativity) are each proved pointwise in the state via stateT_eqI. *)
instance proof
  fix f :: "udom → udom" and r :: "udom⋅('a,'b) stateT"
  show "fmapU⋅f⋅r = bindU⋅r⋅(Λ x. returnU⋅(f⋅x))"
    by (rule stateT_eqI)
       (simp add: returnU_stateT_def monad_fmap prod_map_def csplit_def)
next
  fix f :: "udom → udom⋅('a,'b) stateT" and x :: "udom"
  show "bindU⋅(returnU⋅x)⋅f = f⋅x"
    by (rule stateT_eqI)
       (simp add: returnU_stateT_def eta_cfun)
next
  fix r :: "udom⋅('a,'b) stateT" and f g :: "udom → udom⋅('a,'b) stateT"
  show "bindU⋅(bindU⋅r⋅f)⋅g = bindU⋅r⋅(Λ x. bindU⋅(f⋅x)⋅g)"
    by (rule stateT_eqI)
       (simp add: bind_bind csplit_def)
qed

end

subsection ‹Monad zero instance›

instantiation stateT :: (monad_zero, "domain") monad_zero
begin

(* zeroU ignores the incoming state and yields the underlying mzero. *)
definition zeroU_stateT_def:
  "zeroU = StateT⋅(Λ s. mzero)"

lemma runStateT_zeroU [simp]:
  "runStateT⋅zeroU⋅s = mzero"
unfolding zeroU_stateT_def by simp

(* Left zero law for bindU, proved pointwise from bind_mzero. *)
instance proof
  fix k :: "udom → udom⋅('a,'b) stateT"
  show "bindU⋅zeroU⋅k = zeroU"
    by (rule stateT_eqI, simp add: bind_mzero)
qed

end

subsection ‹Monad plus instance›

instantiation stateT :: (monad_plus, "domain") monad_plus
begin

(* plusU runs both alternatives from the same incoming state s and combines
   their results with the underlying mplus. *)
definition plusU_stateT_def:
  "plusU = (Λ a b. StateT⋅(Λ s. mplus⋅(runStateT⋅a⋅s)⋅(runStateT⋅b⋅s)))"

lemma runStateT_plusU [simp]:
  "runStateT⋅(plusU⋅a⋅b)⋅s =
    mplus⋅(runStateT⋅a⋅s)⋅(runStateT⋅b⋅s)"
unfolding plusU_stateT_def by simp

(* Distributivity of bindU over plusU and associativity of plusU, each
   proved pointwise from the corresponding law of the underlying monad. *)
instance proof
  fix a b :: "udom⋅('a, 'b) stateT" and k :: "udom → udom⋅('a, 'b) stateT"
  show "bindU⋅(plusU⋅a⋅b)⋅k = plusU⋅(bindU⋅a⋅k)⋅(bindU⋅b⋅k)"
    by (rule stateT_eqI, simp add: bind_mplus)
next
  fix a b c :: "udom⋅('a, 'b) stateT"
  show "plusU⋅(plusU⋅a⋅b)⋅c = plusU⋅a⋅(plusU⋅b⋅c)"
    by (rule stateT_eqI, simp add: mplus_assoc)
qed

end

subsection ‹Monad zero plus instance›

(* zeroU is a left and right unit for plusU, proved pointwise from the
   underlying mplus_mzero_left / mplus_mzero_right. *)
instance stateT :: (monad_zero_plus, "domain") monad_zero_plus
proof
  fix m :: "udom⋅('a, 'b) stateT"
  show "plusU⋅zeroU⋅m = m"
    by (rule stateT_eqI, simp add: mplus_mzero_left)
next
  fix m :: "udom⋅('a, 'b) stateT"
  show "plusU⋅m⋅zeroU = m"
    by (rule stateT_eqI, simp add: mplus_mzero_right)
qed

subsection ‹Transfer properties to polymorphic versions›

(* coerce_simp rule: coerce distributes into the body of a csplit. *)
lemma coerce_csplit [coerce_simp]:
  shows "coerce⋅(csplit⋅f⋅p) = csplit⋅(Λ x y. coerce⋅(f⋅x⋅y))⋅p"
unfolding csplit_def by simp

(* coerce_simp rule: csplit applied to a coerced pair equals csplit on the
   original pair, with the component coercions moved into the function. *)
lemma csplit_coerce [coerce_simp]:
  fixes p :: "'a × 'b"
  shows "csplit⋅f⋅(COERCE('a × 'b, 'c × 'd)⋅p) =
    csplit⋅(Λ x y. f⋅(COERCE('a, 'c)⋅x)⋅(COERCE('b, 'd)⋅y))⋅p"
unfolding coerce_prod csplit_def prod_map_def by simp

(* Polymorphic fmap on a StateT value, obtained by instantiating fmap_def
   at stateT and simplifying the coercions. *)
lemma fmap_stateT_simps [simp]:
  "fmap⋅f⋅(StateT⋅m :: 'a⋅('f::functor,'s) stateT) =
    StateT⋅(Λ s. fmap⋅(Λ (x, s'). (f⋅x, s'))⋅(m⋅s))"
unfolding fmap_def [where 'f="('f, 's) stateT"]
by (simp add: coerce_simp eta_cfun)

(* Polymorphic fmap stated through the runStateT selector. *)
lemma runStateT_fmap [simp]:
  "runStateT⋅(fmap⋅f⋅m)⋅s = fmap⋅(Λ (x, s'). (f⋅x, s'))⋅(runStateT⋅m⋅s)"
by (induct m rule: stateT_induct, simp)

(* Polymorphic return at stateT: pairs the value with the unchanged state and
   returns it in the underlying monad (mirrors returnU_stateT_def). *)
lemma return_stateT_def:
  "(return :: _ → 'a⋅('m::monad, 's) stateT) =
    (Λ x. StateT⋅(Λ s. return⋅(x, s)))"
unfolding return_def [where 'm="('m, 's) stateT"] returnU_stateT_def
by (simp add: coerce_simp)

(* Polymorphic bind at stateT (mirrors bindU_stateT_def).  The extra
   coerce_idem / monotonicity step is needed because these facts are not yet
   part of coerce_simp. *)
lemma bind_stateT_def:
  "bind = (Λ m k. StateT⋅(Λ s. runStateT⋅m⋅s ⤜ (Λ (x, s'). runStateT⋅(k⋅x)⋅s')))"
apply (subst bind_def, subst bindU_stateT_def)
apply (simp add: coerce_simp)
apply (simp add: coerce_idem domain_defl_simps monofun_cfun)
apply (simp add: eta_cfun)
done

text ‹TODO: add ‹coerce_idem› to ‹coerce_simp›, along with monotonicity rules for DEFL.›

(* Polymorphic bind applied to a StateT value. *)
lemma bind_stateT_simps [simp]:
  "bind⋅(StateT⋅m :: 'a⋅('m::monad,'s) stateT)⋅k =
    StateT⋅(Λ s. m⋅s ⤜ (Λ (x, s'). runStateT⋅(k⋅x)⋅s'))"
unfolding bind_stateT_def by simp

(* Running m ⤜ k from state s runs m from s, then runs k on the resulting
   value from the resulting state. *)
lemma runStateT_bind [simp]:
  "runStateT⋅(m ⤜ k)⋅s = runStateT⋅m⋅s ⤜ (Λ (x, s'). runStateT⋅(k⋅x)⋅s')"
unfolding bind_stateT_def by simp

end