Theory UnboxedNats
theory UnboxedNats
imports
HOLCF
Nats
WorkerWrapperNew
begin
section‹Unboxing types.›
text‹The original application of the worker/wrapper transformation
was the unboxing of flat types by \<^citet>‹"SPJ-JL:1991"›. We can model
the boxed and unboxed types as (respectively) pointed and unpointed
domains in HOLCF. Concretely @{typ "UNat"} denotes the discrete domain
of naturals, @{typ "UNat⇩⊥"} the lifted (flat and pointed) variant, and
@{typ "Nat"} the standard boxed domain, isomorphic to @{typ
"UNat⇩⊥"}. This latter distinction helps us keep the boxed naturals and
lifted function codomains separated; applications of @{term "unbox"}
should be thought of in the same way as Haskell's @{term "newtype"}
constructors, i.e. operationally equivalent to @{term "ID"}.
The divergence monad is used to handle the unboxing, see below.›
subsection‹Factorial example.›
text‹Standard definition of factorial.›
(* fac: the textbook factorial on the boxed (pointed) naturals Nat, defined
   by general recursion with fixrec. *)
fixrec fac :: "Nat → Nat"
where
"fac⋅n = If n =⇩B 0 then 1 else n * fac⋅(n - 1)"
(* Take the recursive equation out of the simpset; presumably so that simp
   does not keep unfolding fac⋅(n - 1). *)
declare fac.simps[simp del]
(* fac is strict: a divergent argument gives a divergent result. *)
lemma fac_strict[simp]: "fac⋅⊥ = ⊥"
by fixrec_simp
(* fac_body: the non-recursive functional behind fac. The recursive call is
   abstracted as the parameter r. *)
definition
fac_body :: "(Nat → Nat) → Nat → Nat" where
"fac_body ≡ Λ r n. If n =⇩B 0 then 1 else n * r⋅(n - 1)"
(* fac_body⋅r is strict for every r. This is the strictness hypothesis that
   fac_fac_ww_eq feeds to wrapB_unwrapB_body. *)
lemma fac_body_strict[simp]: "fac_body⋅r⋅⊥ = ⊥"
unfolding fac_body_def by simp
(* fac is the least fixed point of fac_body. This puts fac in the
   fix-of-a-body form used by the worker/wrapper theorems. *)
lemma fac_fac_body_eq: "fac = fix⋅fac_body"
unfolding fac_body_def by (rule cfun_eqI, subst fac_def, simp)
text‹Wrap / unwrap functions. Note the explicit lifting of the
codomain. For some reason the published version of
\<^citet>‹"GillHutton:2009"› does not discuss this point: if we are going to
handle recursive functions, we need a bottom.
@{term "unbox"} simply removes the tag, yielding a possibly-divergent
unboxed value, the result of the function.›
(* unwrapB: turn a function on boxed naturals into one on unboxed naturals
   with a lifted codomain. It boxes the argument, applies f, and unboxes the
   result. *)
definition
unwrapB :: "(Nat → Nat) → UNat → UNat⇩⊥" where
"unwrapB ≡ Λ f. unbox oo f oo box"
text‹Note that the monadic bind operator @{term "(>>=)"} here stands
in for the \textsf{case} construct in the paper.›
(* wrapB: the opposite direction. It unboxes the argument, passes it to f,
   and boxes f's lifted result, sequencing the steps with the
   divergence-monad bind. *)
definition
wrapB :: "(UNat → UNat⇩⊥) → Nat → Nat" where
"wrapB ≡ Λ f x . unbox⋅x >>= f >>= box"
(* wrapB is a left inverse of unwrapB on strict functions f. The strictness
   assumption is only used in the final case split on x; presumably it handles
   the x = ⊥ case, where the left-hand side is ⊥ whatever f is. *)
lemma wrapB_unwrapB_body:
assumes strictF: "f⋅⊥ = ⊥"
shows "(wrapB oo unwrapB)⋅f = f" (is "?lhs = ?rhs")
proof(rule cfun_eqI)
fix x :: Nat
(* Unfold wrapB, exposing unwrapB⋅f under the bind. *)
have "?lhs⋅x = unbox⋅x >>= (Λ x'. unwrapB⋅f⋅x' >>= box)"
unfolding wrapB_def by simp
(* Unfold unwrapB. *)
also have "… = unbox⋅x >>= (Λ x'. unbox⋅(f⋅(box⋅x')) >>= box)"
unfolding unwrapB_def by simp
(* Case split on the boxed argument x; strictF is needed for one case. *)
also from strictF have "… = f⋅x" by (cases x, simp_all)
finally show "?lhs⋅x = ?rhs⋅x" .
qed
text‹Apply worker/wrapper.›
(* The worker: least fixed point of fac_body conjugated by unwrapB and wrapB.
   It takes an unboxed natural and returns a lifted one. *)
definition
fac_work :: "UNat → UNat⇩⊥" where
"fac_work ≡ fix⋅(unwrapB oo fac_body oo wrapB)"
(* The wrapper: wrapB turns the worker back into a function on Nat. *)
definition
fac_wrap :: "Nat → Nat" where
"fac_wrap ≡ wrapB⋅fac_work"
(* Correctness of the worker/wrapper split: the original factorial equals the
   wrapper around the unboxed worker. *)
lemma fac_fac_ww_eq: "fac = fac_wrap" (is "?lhs = ?rhs")
proof -
(* Post-composing fac_body with unwrapB and then wrapB leaves it unchanged.
   This follows from wrapB_unwrapB_body because every fac_body⋅r is strict. *)
have "wrapB oo unwrapB oo fac_body = fac_body"
using wrapB_unwrapB_body[OF fac_body_strict]
by - (rule cfun_eqI, simp)
(* Instantiate worker_wrapper_body with this fact, and rewrite fac as
   fix⋅fac_body. *)
thus ?thesis
using worker_wrapper_body[where computation=fac and body=fac_body and wrap=wrapB and unwrap=unwrapB]
unfolding fac_work_def fac_wrap_def by (simp add: fac_fac_body_eq)
qed
text‹This is not entirely faithful to the paper, as its authors do not
explicitly handle the lifting of the codomain.›
(* fac_body': the conjugated body unwrapB oo fac_body oo wrapB written out by
   hand. The box/unbox conversions around each arithmetic operation stay
   explicit. fac_body'_fac_body proves the two agree. *)
definition
fac_body' :: "(UNat → UNat⇩⊥) → UNat → UNat⇩⊥" where
"fac_body' ≡ Λ r n.
unbox⋅(If box⋅n =⇩B 0
then 1
else unbox⋅(box⋅n - 1) >>= r >>= (Λ b. box⋅n * box⋅b))"
(* fac_body' equals the conjugated body. Two instances of
   bbind_case_distr_strict are supplied to simp, both instantiated with the
   multiplication by box⋅x.
   NOTE(review): judging by these instantiations, the lemma distributes a
   strict function over a bind; confirm against its statement, which is not
   in this theory. *)
lemma fac_body'_fac_body: "fac_body' = unwrapB oo fac_body oo wrapB" (is "?lhs = ?rhs")
proof(rule cfun_eqI)+
fix r x
show "?lhs⋅r⋅x = ?rhs⋅r⋅x"
using bbind_case_distr_strict[where f="Λ y. box⋅x * y" and g="unbox⋅(box⋅x - 1)"]
bbind_case_distr_strict[where f="Λ y. box⋅x * y" and h="box"]
unfolding fac_body'_def fac_body_def unwrapB_def wrapB_def by simp
qed
text‹The @{term "up"} constructors here again mediate the
isomorphism, operationally doing nothing. Note the switch to the
machine-oriented \emph{if} construct: the test @{term "n = 0"} cannot
diverge.›
(* fac_body_final: the body using only unboxed arithmetic (-⇩# and *⇩#) and
   HOL's if for the test n = 0. The only remaining bind sequences the
   recursive call. *)
definition
fac_body_final :: "(UNat → UNat⇩⊥) → UNat → UNat⇩⊥" where
"fac_body_final ≡ Λ r n.
if n = 0 then up⋅1 else r⋅(n -⇩# 1) >>= (Λ b. up⋅(n *⇩# b))"
(* fac_body_final agrees with fac_body'. The proof unfolds uMinus, uMult and
   the numerals zero_Nat/one_Nat, presumably to express the unboxed operations
   in terms of the boxed ones used in fac_body'. Here bbind_case_distr_strict
   is instantiated with f = unbox. *)
lemma fac_body_final_fac_body': "fac_body_final = fac_body'" (is "?lhs = ?rhs")
proof(rule cfun_eqI)+
fix r x
show "?lhs⋅r⋅x = ?rhs⋅r⋅x"
using bbind_case_distr_strict[where f="unbox" and g="r⋅(x -⇩# 1)" and h="(Λ b. box⋅(x *⇩# b))"]
unfolding fac_body_final_def fac_body'_def uMinus_def uMult_def zero_Nat_def one_Nat_def
by simp
qed
(* fac_work_final: the worker, as the least fixed point of the final body. *)
definition
fac_work_final :: "UNat → UNat⇩⊥" where
"fac_work_final ≡ fix⋅fac_body_final"
(* fac_final: wrapB⋅fac_work_final with wrapB unfolded. It unboxes the
   argument, runs the worker, and boxes the result. *)
definition
fac_final :: "Nat → Nat" where
"fac_final ≡ Λ n. unbox⋅n >>= fac_work_final >>= box"
(* End-to-end correctness of the unboxing: fac equals fac_final. Each step of
   the calculation rewrites with the definitions or lemmas named in it, going
   from fac through the worker/wrapper split to the final unboxed worker. *)
lemma fac_fac_final: "fac = fac_final" (is "?lhs=?rhs")
proof -
have "?lhs = fac_wrap" by (rule fac_fac_ww_eq)
also have "… = wrapB⋅fac_work" by (simp only: fac_wrap_def)
also have "… = wrapB⋅(fix⋅(unwrapB oo fac_body oo wrapB))" by (simp only: fac_work_def)
also have "… = wrapB⋅(fix⋅fac_body')" by (simp only: fac_body'_fac_body)
also have "… = wrapB⋅fac_work_final" by (simp only: fac_body_final_fac_body' fac_work_final_def)
also have "… = fac_final" by (simp add: fac_final_def wrapB_def)
finally show ?thesis .
qed
subsection‹Introducing an accumulator.›
text‹
The final version of factorial uses unboxed naturals but is not
tail-recursive. We can apply worker/wrapper once more to introduce an
accumulator, similar to \S\ref{sec:accum}.
The monadic machinery complicates things slightly here. We use
\emph{Kleisli composition}, denoted @{term "(>=>)"}, in the
homomorphism.
Firstly we introduce an ``accumulator'' monoid and show the
homomorphism.
›
(* An accumulator maps an unboxed natural to a possibly-divergent one. *)
type_synonym UNatAcc = "UNat → UNat⇩⊥"
(* n2a m: the accumulator that multiplies its input by m. *)
definition
n2a :: "UNat → UNatAcc" where
"n2a ≡ Λ m n. up⋅(m *⇩# n)"
(* a2n: extract a value by running the accumulator on 1. *)
definition
a2n :: "UNatAcc → UNat⇩⊥" where
"a2n ≡ Λ a. a⋅1"
lemma a2n_strict[simp]: "a2n⋅⊥ = ⊥"
unfolding a2n_def by simp
(* a2n undoes n2a: running n2a⋅u on 1 yields up⋅u. *)
lemma a2n_n2a: "a2n⋅(n2a⋅u) = up⋅u"
unfolding a2n_def n2a_def by (simp add: uMult_arithmetic)
(* The homomorphism: n2a maps unboxed multiplication to Kleisli composition. *)
lemma A_hom_mult: "n2a⋅(x *⇩# y) = (n2a⋅x >=> n2a⋅y)"
unfolding n2a_def bKleisli_def by (simp add: uMult_arithmetic)
(* unwrapA: turn a function with lifted results into one producing
   accumulators, by binding the (possibly divergent) result into n2a. *)
definition
unwrapA :: "(UNat → UNat⇩⊥) → UNat → UNatAcc" where
"unwrapA ≡ Λ f n. f⋅n >>= n2a"
(* unwrapA is strict. This is one of the side conditions passed to
   worker_wrapper_fusion_new in fac_work_final_body3_eq. *)
lemma unwrapA_strict[simp]: "unwrapA⋅⊥ = ⊥"
unfolding unwrapA_def by (rule cfun_eqI) simp
(* wrapA: convert back by applying a2n, i.e. running each accumulator on 1. *)
definition
wrapA :: "(UNat → UNatAcc) → UNat → UNat⇩⊥" where
"wrapA ≡ Λ f. a2n oo f"
(* wrapA is a left inverse of unwrapA. The proof splits on the lifted value
   x⋅xa and closes each case with a2n_n2a.
   NOTE(review): x and xa are the names Isabelle generates for the variables
   bound by cfun_eqI, so this apply script depends on automatic naming. An
   Isar proof with an explicit fix would be more robust. *)
lemma wrapA_unwrapA_id: "wrapA oo unwrapA = ID"
unfolding wrapA_def unwrapA_def
apply (rule cfun_eqI)+
apply (case_tac "x⋅xa")
apply (simp_all add: a2n_n2a)
done
text‹Some steps along the way.›
(* Step 1: the conjugated body unwrapA oo fac_body_final oo wrapA written out,
   with the binds simplified. The lemma below checks this by unfolding. *)
definition
fac_acc_body1 :: "(UNat → UNatAcc) → UNat → UNatAcc" where
"fac_acc_body1 ≡ Λ r n.
if n = 0 then n2a⋅1 else wrapA⋅r⋅(n -⇩# 1) >>= (Λ res. n2a⋅(n *⇩# res))"
lemma fac_acc_body1_fac_body_final_eq: "fac_acc_body1 = unwrapA oo fac_body_final oo wrapA"
unfolding fac_acc_body1_def fac_body_final_def wrapA_def unwrapA_def
by (rule cfun_eqI)+ simp
text‹Use the homomorphism.›
(* Step 2: as step 1, but n2a⋅(n *⇩# res) is rewritten into the Kleisli
   composition n2a⋅n >=> n2a⋅res using A_hom_mult. *)
definition
fac_acc_body2 :: "(UNat → UNatAcc) → UNat → UNatAcc" where
"fac_acc_body2 ≡ Λ r n.
if n = 0 then n2a⋅1 else wrapA⋅r⋅(n -⇩# 1) >>= (Λ res. n2a⋅n >=> n2a⋅res)"
lemma fac_acc_body2_body1_eq: "fac_acc_body2 = fac_acc_body1"
unfolding fac_acc_body1_def fac_acc_body2_def
by (rule cfun_eqI)+ (simp add: A_hom_mult)
text‹Apply worker/wrapper.›
(* Step 3: the body after worker/wrapper. The recursive call r⋅(n -⇩# 1)
   returns an accumulator directly and is combined with n2a⋅n by Kleisli
   composition; wrapA no longer appears. *)
definition
fac_acc_body3 :: "(UNat → UNatAcc) → UNat → UNatAcc" where
"fac_acc_body3 ≡ Λ r n.
if n = 0 then n2a⋅1 else n2a⋅n >=> r⋅(n -⇩# 1)"
(* Precomposing fac_acc_body3 with unwrapA oo wrapA gives fac_acc_body2.
   After unfolding, unwrapA binds the recursive result into n2a.
   bbind_case_distr_strict, used right-to-left, relates the Kleisli
   composition with n2a⋅n to that bind. The goal is stated through the
   ?lhs/?rhs abbreviations bound by the lemma's own is-pattern. *)
lemma fac_acc_body3_body2: "fac_acc_body3 oo (unwrapA oo wrapA) = fac_acc_body2" (is "?lhs=?rhs")
proof(rule cfun_eqI)+
fix r n acc
show "?lhs⋅r⋅n⋅acc = ?rhs⋅r⋅n⋅acc"
unfolding fac_acc_body2_def fac_acc_body3_def unwrapA_def
using bbind_case_distr_strict[where f="Λ y. n2a⋅n >=> y" and h="n2a", symmetric]
by simp
qed
(* Worker/wrapper fusion: fac_work_final factors as wrapA applied to the least
   fixed point of fac_acc_body3.
   - Side conditions supplied: wrapA oo unwrapA = ID and strictness of
     unwrapA.
   - The remaining goal is closed by simp, using the chain of body equalities
     fac_acc_body3_body2, fac_acc_body2_body1_eq and
     fac_acc_body1_fac_body_final_eq. *)
lemma fac_work_final_body3_eq: "fac_work_final = wrapA⋅(fix⋅fac_acc_body3)"
unfolding fac_work_final_def
by (rule worker_wrapper_fusion_new[OF wrapA_unwrapA_id unwrapA_strict])
(simp add: fac_acc_body3_body2 fac_acc_body2_body1_eq fac_acc_body1_fac_body_final_eq)
(* The tail-recursive body: acc holds the running product. The recursive call
   receives n *⇩# acc, and the base case returns up⋅acc. *)
definition
fac_acc_body_final :: "(UNat → UNatAcc) → UNat → UNatAcc" where
"fac_acc_body_final ≡ Λ r n acc.
if n = 0 then up⋅acc else r⋅(n -⇩# 1)⋅(n *⇩# acc)"
(* The tail-recursive worker: runs the loop with initial accumulator 1. *)
definition
fac_acc_work_final :: "UNat → UNat⇩⊥" where
"fac_acc_work_final ≡ Λ x. fix⋅fac_acc_body_final⋅x⋅1"
(* The direct tail-recursive body equals fac_acc_body3. The proof unfolds n2a
   and Kleisli composition and uses the uMult arithmetic facts; presumably the
   unit law 1 *⇩# acc = acc is needed for the base case.
   NOTE: despite its name, this lemma relates the bodies, not the workers. The
   original name is kept because later proofs refer to it; the alias below
   has a name that matches the statement. *)
lemma fac_acc_work_final_fac_acc_work3_eq: "fac_acc_body_final = fac_acc_body3"
unfolding fac_acc_body3_def fac_acc_body_final_def n2a_def bKleisli_def
by (rule cfun_eqI)+
(simp add: uMult_arithmetic)
lemmas fac_acc_body_final_fac_acc_body3_eq = fac_acc_work_final_fac_acc_work3_eq
(* Final result: the tail-recursive accumulator worker computes the same
   function as fac_work_final. *)
lemma fac_acc_work_final_fac_work: "fac_acc_work_final = fac_work_final" (is "?lhs=?rhs")
proof -
(* Worker/wrapper fusion. *)
have "?rhs = wrapA⋅(fix⋅fac_acc_body3)" by (rule fac_work_final_body3_eq)
(* Replace fac_acc_body3 by the equal tail-recursive body. *)
also have "… = wrapA⋅(fix⋅fac_acc_body_final)"
using fac_acc_work_final_fac_acc_work3_eq by simp
(* Unfold wrapA and a2n: the wrapper applies the fixed point to the initial
   accumulator 1, which is exactly fac_acc_work_final. *)
also have "… = ?lhs"
unfolding fac_acc_work_final_def wrapA_def a2n_def
by (simp add: cfcomp1)
(* The chain proves ?rhs = ?lhs; simp closes the symmetric goal. *)
finally show ?thesis by simp
qed
end