(* Theory Error_Transformer *)
section ‹Error monad transformer›
theory Error_Transformer
imports Error_Monad
begin
subsection ‹Type definition›
text ‹The error monad transformer is defined in Haskell by composing
the given monad with a standard error monad:›
text_raw ‹
\begin{verbatim}
data Error e a = Err e | Ok a
newtype ErrorT e m a = ErrorT { runErrorT :: m (Error e a) }
\end{verbatim}
›
text ‹We can formalize this definition directly using ‹tycondef›. \medskip›
(* HOLCF formalization of the Haskell ErrorT newtype shown above.  A value of
   type 'a⋅('f,'e) errorT wraps an 'f-structure of results of type
   'a⋅'e error; ErrorT is the constructor and runErrorT the selector.
   The outer type constructor 'f must be in class functor, while 'e only
   needs to be a domain. *)
tycondef 'a⋅('f::"functor",'e::"domain") errorT =
ErrorT (runErrorT :: "('a⋅'e error)⋅'f")
(* coerce commutes with the abstraction function errorT_abs.  The proof
   unfolds errorT_abs and coerce, then uses the emb/prj round-trip lemmas
   (emb_prj_emb, prj_emb_prj) together with DEFL_eq_errorT (presumably the
   DEFL equation produced by tycondef). *)
lemma coerce_errorT_abs [simp]: "coerce⋅(errorT_abs⋅x) = errorT_abs⋅(coerce⋅x)"
apply (simp add: errorT_abs_def coerce_def)
apply (simp add: emb_prj_emb prj_emb_prj DEFL_eq_errorT)
done
(* The same commutation property for the ErrorT constructor.  Unfolding
   ErrorT_def is enough for simp to finish; presumably ErrorT is defined
   via errorT_abs, so that the simp rule coerce_errorT_abs applies. *)
lemma coerce_ErrorT [simp]: "coerce⋅(ErrorT⋅k) = ErrorT⋅(coerce⋅k)"
unfolding ErrorT_def by simp
(* Case analysis on errorT: every value y is ErrorT applied to something,
   and the witness is runErrorT⋅y. *)
lemma errorT_cases [case_names ErrorT]:
obtains k where "y = ErrorT⋅k"
proof
show "y = ErrorT⋅(runErrorT⋅y)"
by (cases y) simp_all
qed
(* Projecting with runErrorT and re-wrapping with ErrorT is the identity. *)
lemma ErrorT_runErrorT [simp]: "ErrorT⋅(runErrorT⋅m) = m"
by (cases m rule: errorT_cases) simp
(* Induction principle derived from errorT_cases: to prove P for every
   value, it suffices to prove it for every ErrorT⋅k. *)
lemma errorT_induct [case_names ErrorT]:
fixes P :: "'a⋅('f::functor,'e) errorT ⇒ bool"
assumes "⋀k. P (ErrorT⋅k)"
shows "P y"
by (cases y rule: errorT_cases) (simp add: assms)
(* Extensionality for errorT: two values are equal exactly when their
   runErrorT projections are equal. *)
lemma errorT_eq_iff:
"a = b ⟷ runErrorT⋅a = runErrorT⋅b"
by (cases a rule: errorT_cases, cases b rule: errorT_cases) simp
(* Introduction-rule form of errorT_eq_iff.  It reduces an equation between
   errorT values to one between their runErrorT projections. *)
lemma errorT_eqI:
"runErrorT⋅a = runErrorT⋅b ⟹ a = b"
by (rule iffD2 [OF errorT_eq_iff])
(* runErrorT commutes with coerce. *)
lemma runErrorT_coerce [simp]:
"runErrorT⋅(coerce⋅k) = coerce⋅(runErrorT⋅k)"
by (induct k rule: errorT_induct) simp
subsection ‹Functor class instance›
(* Expresses fmap on the error type as the domain-package map function
   error_map applied to ID.  NOTE(review): presumably ID is the mapping on
   the 'e component, so only the Ok payload is transformed; confirm against
   Error_Monad.  This lemma is unfolded by fmapU_ErrorT below. *)
lemma fmap_error_def: "fmap = error_map⋅ID"
(* extensionality in the function argument f and in the value x *)
apply (rule cfun_eqI, rename_tac f)
apply (rule cfun_eqI, rename_tac x)
(* case split on x with error.exhaust; each remaining case is closed by
   unfolding error_map_def, plus Err_def / Ok_def for the constructor cases *)
apply (case_tac x rule: error.exhaust, simp_all)
apply (simp add: error_map_def fix_const)
apply (simp add: error_map_def fix_const Err_def)
apply (simp add: error_map_def fix_const Ok_def)
done
(* fmapU on errorT maps under both layers: fmap over the outer functor,
   then fmap over the inner error value. *)
lemma fmapU_ErrorT [simp]:
"fmapU⋅f⋅(ErrorT⋅m) = ErrorT⋅(fmap⋅(fmap⋅f)⋅m)"
unfolding fmapU_errorT_def errorT_map_def fmap_error_def fix_const ErrorT_def
by simp
(* The same property, stated through the runErrorT selector. *)
lemma runErrorT_fmapU [simp]:
"runErrorT⋅(fmapU⋅f⋅m) = fmap⋅(fmap⋅f)⋅(runErrorT⋅m)"
by (induct m rule: errorT_induct) simp
(* errorT is a functor whenever its first argument is one.  The proof
   discharges the composition law for fmapU by induction on xs, using
   fmap_fmap for the nested fmaps and eta_cfun to tidy the lambda. *)
instance errorT :: ("functor", "domain") "functor"
proof
fix f g and xs :: "udom⋅('a, 'b) errorT"
show "fmapU⋅f⋅(fmapU⋅g⋅xs) = fmapU⋅(Λ x. f⋅(g⋅x))⋅xs"
apply (induct xs rule: errorT_induct)
apply (simp add: fmap_fmap eta_cfun)
done
qed
subsection ‹Transfer properties to polymorphic versions›
(* Polymorphic fmap on errorT, obtained from fmapU by unfolding fmap_def at
   the errorT type and simplifying away the coerce wrappers. *)
lemma fmap_ErrorT [simp]:
fixes f :: "'a → 'b" and m :: "'a⋅'e error⋅('m::functor)"
shows "fmap⋅f⋅(ErrorT⋅m) = ErrorT⋅(fmap⋅(fmap⋅f)⋅m)"
unfolding fmap_def [where 'f="('m,'e) errorT"]
by (simp_all add: coerce_simp eta_cfun)
(* Selector form of fmap_ErrorT, obtained by instantiating it at
   runErrorT⋅m. *)
lemma runErrorT_fmap [simp]:
fixes f :: "'a → 'b" and m :: "'a⋅('m::functor,'e) errorT"
shows "runErrorT⋅(fmap⋅f⋅m) = fmap⋅(fmap⋅f)⋅(runErrorT⋅m)"
using fmap_ErrorT [of f "runErrorT⋅m"]
by simp
(* fmap sends ⊥ to ⊥ on errorT.  NOTE(review): stated for 'm::monad rather
   than 'm::functor; presumably the fmap_strict lemma used here needs the
   monad class. Confirm before generalizing. *)
lemma errorT_fmap_strict [simp]:
shows "fmap⋅f⋅(⊥::'a⋅('m::monad,'e) errorT) = ⊥"
by (simp add: errorT_eq_iff fmap_strict)
subsection ‹Monad operations›
text ‹The error monad transformer does not yield a monad in the
usual sense: We cannot prove a ‹monad› class instance, because
type ‹'a⋅('m,'e) errorT› contains values that break the monad
laws. However, it turns out that such values are inaccessible: The
monad laws are satisfied by all values constructible from the abstract
operations.›
text ‹To explore the properties of the error monad transformer
operations, we define them all as non-overloaded functions. \medskip
›
(* unitET: return x as a successful result (Ok) in the inner monad. *)
definition unitET :: "'a → 'a⋅('m::monad,'e) errorT"
where "unitET = (Λ x. ErrorT⋅(return⋅(Ok⋅x)))"
(* bindET: run m in the inner monad.  An Err result is propagated unchanged;
   an Ok x result continues with k⋅x. *)
definition bindET :: "'a⋅('m::monad,'e) errorT →
('a → 'b⋅('m,'e) errorT) → 'b⋅('m,'e) errorT"
where "bindET = (Λ m k. ErrorT⋅(bind⋅(runErrorT⋅m)⋅
(Λ n. case n of Err⋅e ⇒ return⋅(Err⋅e) | Ok⋅x ⇒ runErrorT⋅(k⋅x))))"
(* liftET: embed an inner-monad computation by tagging its results with Ok. *)
definition liftET :: "'a⋅'m::monad → 'a⋅('m,'e) errorT"
where "liftET = (Λ m. ErrorT⋅(fmap⋅Ok⋅m))"
(* throwET: return the error e (wrapped in Err) in the inner monad. *)
definition throwET :: "'e → 'a⋅('m::monad,'e) errorT"
where "throwET = (Λ e. ErrorT⋅(return⋅(Err⋅e)))"
(* catchET: run m in the inner monad.  An Err e result continues with the
   handler h⋅e; an Ok x result is returned unchanged. *)
definition catchET :: "'a⋅('m::monad,'e) errorT →
('e → 'a⋅('m,'e) errorT) → 'a⋅('m,'e) errorT"
where "catchET = (Λ m h. ErrorT⋅(bind⋅(runErrorT⋅m)⋅(Λ n. case n of
Err⋅e ⇒ runErrorT⋅(h⋅e) | Ok⋅x ⇒ return⋅(Ok⋅x))))"
(* fmapET: mapping over results, defined through bindET and unitET rather
   than through the class-level fmap. *)
definition fmapET :: "('a → 'b) →
'a⋅('m::monad,'e) errorT → 'b⋅('m,'e) errorT"
where "fmapET = (Λ f m. bindET⋅m⋅(Λ x. unitET⋅(f⋅x)))"
(* Characteristic equations of the operations, stated through the runErrorT
   selector.  Each one follows by simplifying with the matching definition. *)
lemma runErrorT_unitET [simp]:
"runErrorT⋅(unitET⋅x) = return⋅(Ok⋅x)"
by (simp add: unitET_def)
lemma runErrorT_bindET [simp]:
"runErrorT⋅(bindET⋅m⋅k) = bind⋅(runErrorT⋅m)⋅
(Λ n. case n of Err⋅e ⇒ return⋅(Err⋅e) | Ok⋅x ⇒ runErrorT⋅(k⋅x))"
by (simp add: bindET_def)
lemma runErrorT_liftET [simp]:
"runErrorT⋅(liftET⋅m) = fmap⋅Ok⋅m"
by (simp add: liftET_def)
lemma runErrorT_throwET [simp]:
"runErrorT⋅(throwET⋅e) = return⋅(Err⋅e)"
by (simp add: throwET_def)
lemma runErrorT_catchET [simp]:
"runErrorT⋅(catchET⋅m⋅h) =
bind⋅(runErrorT⋅m)⋅(Λ n. case n of
Err⋅e ⇒ runErrorT⋅(h⋅e) | Ok⋅x ⇒ return⋅(Ok⋅x))"
by (simp add: catchET_def)
(* After fmapET is expanded, the bindET and unitET equations above finish
   the job. *)
lemma runErrorT_fmapET [simp]:
"runErrorT⋅(fmapET⋅f⋅m) =
bind⋅(runErrorT⋅m)⋅(Λ n. case n of
Err⋅e ⇒ return⋅(Err⋅e) | Ok⋅x ⇒ return⋅(Ok⋅(f⋅x)))"
by (simp add: fmapET_def)
subsection ‹Laws›
(* Left unit law: binding a unitET value just applies the continuation. *)
lemma bindET_unitET [simp]:
"bindET⋅(unitET⋅x)⋅k = k⋅x"
by (rule errorT_eqI) simp
(* A handler is never invoked on a successful computation. *)
lemma catchET_unitET [simp]:
"catchET⋅(unitET⋅x)⋅h = unitET⋅x"
by (rule errorT_eqI) simp
(* A thrown error is passed to the handler. *)
lemma catchET_throwET [simp]:
"catchET⋅(throwET⋅e)⋅h = h⋅e"
by (rule errorT_eqI) simp
(* Lifting an inner return gives unitET. *)
lemma liftET_return:
"liftET⋅(return⋅x) = unitET⋅x"
by (rule errorT_eqI) (simp add: fmap_return)
(* liftET distributes over the inner bind. *)
lemma liftET_bind:
"liftET⋅(bind⋅m⋅k) = bindET⋅(liftET⋅m)⋅(liftET oo k)"
by (rule errorT_eqI) (simp add: fmap_bind bind_fmap)
(* A thrown error short-circuits bindET; the continuation is ignored. *)
lemma bindET_throwET:
"bindET⋅(throwET⋅e)⋅k = throwET⋅e"
by (rule errorT_eqI) simp
(* Associativity of bindET.  It holds for all arguments, with no invariant
   needed. *)
lemma bindET_bindET:
"bindET⋅(bindET⋅m⋅h)⋅k = bindET⋅m⋅(Λ x. bindET⋅(h⋅x)⋅k)"
apply (rule errorT_eqI)
apply simp
(* reassociate the nested inner binds on the left-hand side *)
apply (simp add: bind_bind)
(* both sides are now bind⋅(runErrorT⋅m)⋅_, so compare the continuations
   pointwise *)
apply (rule cfun_arg_cong)
apply (rule cfun_eqI, simp)
(* case split on the intermediate result; the first case needs bind_strict,
   so it is the ⊥ case *)
apply (case_tac x)
apply (simp add: bind_strict)
apply simp
apply simp
done
(* Composition law for fmapET.  It holds for all m, because fmapET is built
   from bindET and unitET. *)
lemma fmapET_fmapET:
"fmapET⋅f⋅(fmapET⋅g⋅m) = fmapET⋅(Λ x. f⋅(g⋅x))⋅m"
unfolding fmapET_def by (simp add: bindET_bindET)
text ‹The right unit monad law does not hold in general.›
(* Concrete counterexample to the right unit law.  If the inner monad's
   return⋅⊥ differs from ⊥, then m = ErrorT⋅(return⋅⊥) satisfies
   bindET⋅m⋅unitET ≠ m.  Presumably this is because bindET's case analysis
   on the ⊥ result yields ⊥, whereas m itself is return⋅⊥. *)
lemma bindET_unitET_right_counterexample:
fixes m :: "'a⋅('m::monad,'e) errorT"
assumes "m = ErrorT⋅(return⋅⊥)"
assumes "return⋅⊥ ≠ (⊥ :: ('a⋅'e error)⋅'m)"
shows "bindET⋅m⋅unitET ≠ m"
by (simp add: errorT_eq_iff assms)
text ‹The right unit law holds for every inner monad whose ‹return› is strict.›
(* If the inner monad's return is strict (return⋅⊥ = ⊥), then the right unit
   law holds for every m. *)
lemma bindET_unitET_right_restricted:
fixes m :: "'a⋅('m::monad,'e) errorT"
assumes "return⋅⊥ = (⊥ :: ('a⋅'e error)⋅'m)"
shows "bindET⋅m⋅unitET = m"
unfolding errorT_eq_iff
apply simp
(* reduce to the inner monad's right unit law, bind⋅x⋅return = x ... *)
apply (rule trans [OF _ monad_right_unit])
(* ... by showing that the continuation agrees with return pointwise *)
apply (rule cfun_arg_cong)
apply (rule cfun_eqI)
(* case split on the result; presumably the strictness assumption is what
   closes the ⊥ case *)
apply (case_tac x, simp_all add: assms)
done
subsection ‹Error monad transformer invariant›
text ‹This inductively defined invariant is intended to represent
the set of all values constructible using the standard ‹errorT›
operations.›
(* invar contains ⊥ and the results of unitET, throwET and liftET.  It is
   closed under bindET and catchET (given invariant arguments and
   continuations/handlers) and under lubs of chains.  The ⊥ and lub rules
   are the cases that proofs by induction on invar discharge with bind_strict
   and admD respectively. *)
inductive invar :: "'a⋅('m::monad, 'e) errorT ⇒ bool"
where invar_bottom: "invar ⊥"
| invar_lub: "⋀Y. ⟦chain Y; ⋀i. invar (Y i)⟧ ⟹ invar (⨆i. Y i)"
| invar_unitET: "⋀x. invar (unitET⋅x)"
| invar_bindET: "⋀m k. ⟦invar m; ⋀x. invar (k⋅x)⟧ ⟹ invar (bindET⋅m⋅k)"
| invar_throwET: "⋀e. invar (throwET⋅e)"
| invar_catchET: "⋀m h. ⟦invar m; ⋀e. invar (h⋅e)⟧ ⟹ invar (catchET⋅m⋅h)"
| invar_liftET: "⋀m. invar (liftET⋅m)"
text ‹The right unit law holds for all arguments built from the standard operations.›
(* Right unit law for every m satisfying invar, proved by rule induction on
   invar.  The seven subgoals follow the order of the introduction rules:
   bottom, lub, unitET, bindET, throwET, catchET, liftET. *)
lemma bindET_unitET_right_invar:
assumes "invar m"
shows "bindET⋅m⋅unitET = m"
using assms
apply (induct set: invar)
(* invar_bottom *)
apply (rule errorT_eqI, simp add: bind_strict)
(* invar_lub: the admissibility side condition (by simp), the chain
   hypothesis and the induction hypothesis for each Y i *)
apply (rule admD, simp, assumption, assumption)
(* invar_unitET *)
apply (rule errorT_eqI, simp)
(* invar_bindET: reassociate with bind_bind, then compare the continuations
   pointwise by cases on the intermediate result *)
apply (simp add: errorT_eq_iff bind_bind)
apply (rule cfun_arg_cong, rule cfun_eqI, simp)
apply (case_tac x, simp add: bind_strict, simp, simp)
(* invar_throwET *)
apply (rule errorT_eqI, simp)
(* invar_catchET: same pattern as the bindET case *)
apply (simp add: errorT_eq_iff bind_bind)
apply (rule cfun_arg_cong, rule cfun_eqI, simp)
apply (case_tac x, simp add: bind_strict, simp, simp)
(* invar_liftET *)
apply (rule errorT_eqI, simp add: monad_fmap bind_bind)
done
text ‹The monad-fmap law holds for all arguments built from the standard operations.›
(* For every m satisfying invar, the class-level fmap agrees with mapping
   via bindET and unitET.  The proof is by rule induction on invar, with
   subgoals in rule order: bottom, lub, unitET, bindET, throwET, catchET,
   liftET. *)
lemma errorT_monad_fmap_invar:
fixes f :: "'a → 'b" and m :: "'a⋅('m::monad,'e) errorT"
assumes "invar m"
shows "fmap⋅f⋅m = bindET⋅m⋅(Λ x. unitET⋅(f⋅x))"
using assms
apply (induct set: invar)
(* invar_bottom *)
apply (rule errorT_eqI, simp add: bind_strict fmap_strict)
(* invar_lub: admissibility, chain hypothesis, induction hypotheses *)
apply (rule admD, simp, assumption, assumption)
(* invar_unitET *)
apply (rule errorT_eqI, simp add: fmap_return)
(* invar_bindET: push fmap through bind, then compare the continuations
   pointwise by cases on the intermediate result *)
apply (simp add: errorT_eq_iff bind_bind fmap_bind)
apply (rule cfun_arg_cong, rule cfun_eqI, simp)
apply (case_tac x)
apply (simp add: bind_strict fmap_strict)
apply (simp add: fmap_return)
apply simp
(* invar_throwET *)
apply (rule errorT_eqI, simp add: fmap_return)
(* invar_catchET: same pattern as the bindET case *)
apply (simp add: errorT_eq_iff bind_bind fmap_bind)
apply (rule cfun_arg_cong, rule cfun_eqI, simp)
apply (case_tac x)
apply (simp add: bind_strict fmap_strict)
apply simp
apply (simp add: fmap_return)
(* invar_liftET *)
apply (rule errorT_eqI, simp add: monad_fmap bind_bind return_error_def)
done
subsection ‹Invariant expressed as a deflation›
text ‹We can also define an invariant in a more semantic way, as the
set of fixed-points of a deflation.›
(* invar' m holds iff m is a fixed point of fmapET⋅ID.  By fmapET_def,
   fmapET⋅ID⋅m is bindET⋅m⋅(Λ x. unitET⋅(ID⋅x)). *)
definition invar' :: "'a⋅('m::monad, 'e) errorT ⇒ bool"
where "invar' m ⟷ fmapET⋅ID⋅m = m"
text ‹All standard operations preserve the invariant.›
(* unitET values satisfy invar'; simp closes the goal once fmapET is
   unfolded. *)
lemma invar'_unitET: "invar' (unitET⋅x)"
unfolding invar'_def by (simp add: fmapET_def)
(* fmapET preserves invar'.  The hypothesis fmapET⋅ID⋅m = m is substituted
   into the goal, and the nested binds are flattened with bindET_bindET. *)
lemma invar'_fmapET: "invar' m ⟹ invar' (fmapET⋅f⋅m)"
unfolding invar'_def
by (erule subst, simp add: fmapET_def bindET_bindET eta_cfun)
(* bindET preserves invar' when m and every continuation result k⋅x
   satisfy it. *)
lemma invar'_bindET: "⟦invar' m; ⋀x. invar' (k⋅x)⟧ ⟹ invar' (bindET⋅m⋅k)"
unfolding invar'_def
by (simp add: fmapET_def bindET_bindET eta_cfun)
(* throwET values satisfy invar' (via bindET_throwET). *)
lemma invar'_throwET: "invar' (throwET⋅e)"
unfolding invar'_def by (simp add: fmapET_def bindET_throwET eta_cfun)
(* catchET preserves invar' when m and every handler result h⋅e satisfy
   it. *)
lemma invar'_catchET: "⟦invar' m; ⋀e. invar' (h⋅e)⟧ ⟹ invar' (catchET⋅m⋅h)"
unfolding invar'_def
apply (simp add: fmapET_def eta_cfun)
(* move to the runErrorT level and reassociate the inner binds *)
apply (rule errorT_eqI)
apply (simp add: bind_bind eta_cfun)
(* compare the continuations of bind⋅(runErrorT⋅m) pointwise *)
apply (rule cfun_arg_cong)
apply (rule cfun_eqI)
(* case split on the result of m; the first case uses bind_strict (⊥) *)
apply (case_tac x)
apply (simp add: bind_strict)
apply simp
(* presumably the Err e case: instantiate the handler hypothesis at e and
   rewrite with it; `back` selects an alternative result of the erule step,
   presumably to hit the intended occurrence of h⋅e *)
apply (drule_tac x=e in meta_spec)
apply (erule_tac t="h⋅e" in subst) back
apply (simp add: eta_cfun)
apply simp
done
(* liftET values satisfy invar'; the proof uses the inner monad's
   monad_fmap and bind_bind laws. *)
lemma invar'_liftET: "invar' (liftET⋅m)"
unfolding invar'_def
apply (simp add: fmapET_def errorT_eq_iff)
apply (simp add: monad_fmap bind_bind)
done
(* ⊥ satisfies invar' (by bind_strict). *)
lemma invar'_bottom: "invar' ⊥"
unfolding invar'_def fmapET_def
by (simp add: errorT_eq_iff bind_strict)
(* invar' is admissible, i.e. preserved by lubs of chains; simp proves this
   once invar' is unfolded to its defining equation. *)
lemma adm_invar': "adm invar'"
unfolding invar'_def [abs_def] by simp
text ‹The monad laws hold for all values satisfying the invariant.›
(* For every f and m, binding fmapET⋅f⋅m to unitET changes nothing. *)
lemma bindET_fmapET_unitET:
shows "bindET⋅(fmapET⋅f⋅m)⋅unitET = fmapET⋅f⋅m"
by (simp add: fmapET_def bindET_bindET)
(* Right unit law under invar'.  Substituting fmapET⋅ID⋅m for m reduces the
   goal to bindET_fmapET_unitET. *)
lemma invar'_right_unit: "invar' m ⟹ bindET⋅m⋅unitET = m"
unfolding invar'_def by (erule subst, rule bindET_fmapET_unitET)
(* Monad-fmap law under invar'.
   NOTE(review): the conclusion follows from fmapET_def alone, since
   fmapET⋅f⋅m unfolds to exactly the right-hand side.  The invar' premise is
   therefore not needed.  Confirm whether the class-level fmap (as in
   errorT_monad_fmap_invar) was intended here. *)
lemma invar'_monad_fmap:
"invar' m ⟹ fmapET⋅f⋅m = bindET⋅m⋅(Λ x. unitET⋅(f⋅x))"
unfolding invar'_def by (erule subst, simp add: errorT_eq_iff)
(* Associativity under invar'.  The premises are not used: the proof is
   bindET_bindET, which holds unconditionally. *)
lemma invar'_bind_assoc:
"⟦invar' m; ⋀x. invar' (f⋅x); ⋀y. invar' (g⋅y)⟧
⟹ bindET⋅(bindET⋅m⋅f)⋅g = bindET⋅m⋅(Λ x. bindET⋅(f⋅x)⋅g)"
by (rule bindET_bindET)
end