(* Theory Certification_Monads.Strict_Sum *)

(* Title:     Strict_Sum
   Author:    Christian Sternagel
   Author:    René Thiemann
*)

section ‹A Sum Type with Bottom Element›

theory Strict_Sum
imports
  "HOL-Library.Monad_Syntax"
  Error_Syntax
  Partial_Function_MR.Partial_Function_MR
begin

(* A sum type with an additional element Bottom.
   Left carries an error and Right a regular result.
   The map function sum_bot_map acts only on the 'a component ('e is dead).
   The infix syntax differs from HOL's "+" on the sum type, so both types can
   be written unambiguously (e.g. "'e + 'a" versus "'e +⇩⊥ 'a"). *)
datatype (dead 'e, 'a) sum_bot (infixr "+⇩⊥" 10) = Bottom | Left 'e | Right 'a for map: sum_bot_map


subsection ‹Setup for Partial Functions›

(* The flat order on sum_bot: Bottom lies below every value, and any two
   distinct values other than Bottom are incomparable. *)
abbreviation sum_bot_ord :: "('e, 'a) sum_bot ⇒ ('e, 'a) sum_bot ⇒ bool"
where
  "sum_bot_ord ≡ flat_ord Bottom"

(* sum_bot with its flat order and flat_lub Bottom is an instance of
   partial_function_definitions. This provides the fixpoint constants and
   rules (sum_bot.fixp_fun, sum_bot.fixp_rule_uc, ...) registered below. *)
interpretation sum_bot:
  partial_function_definitions sum_bot_ord "flat_lub Bottom"
  by (rule flat_interpretation)

(* Register "sum_bot" as a mode of the partial_function command, using the
   fixpoint constant, monotonicity predicate, unfolding rule and induction rule
   of the sum_bot interpretation. No additional induction rule is given (NONE). *)
declaration ‹Partial_Function.init
  "sum_bot"
  @{term sum_bot.fixp_fun}
  @{term sum_bot.mono_body}
  @{thm sum_bot.fixp_rule_uc}
  @{thm sum_bot.fixp_induct_uc}
  NONE›


subsection ‹Monad Setup›

(* Monadic bind.
   Bottom and errors (Left) are propagated unchanged.
   Only a regular result (Right) is passed on to the continuation f. *)
fun bind :: "('e, 'a) sum_bot ⇒ ('a ⇒ ('e, 'b) sum_bot) ⇒ ('e, 'b) sum_bot"
where
  "bind Bottom f = Bottom" |
  "bind (Left e) f = Left e" |
  "bind (Right x) f = f x"

(* Congruence rule for function definitions.
   The continuations only have to agree on the values x with ys = Right x,
   since bind does not call them otherwise. *)
lemma bind_cong [fundef_cong]:
  assumes "xs = ys" and "⋀x. ys = Right x ⟹ f x = g x"
  shows "bind xs f = bind ys g"
  using assms by (cases ys) simp_all

(* Monotonicity of a functional with respect to the pointwise flat order on
   its function argument and the flat order on its result. *)
abbreviation mono_sum_bot :: "(('a ⇒ ('e, 'b) sum_bot) ⇒ ('f, 'c) sum_bot) ⇒ bool"
where
  "mono_sum_bot ≡ monotone (fun_ord sum_bot_ord) sum_bot_ord"

(* TODO: perhaps use Partial_Function.bind_mono to prove this result immediately *)
(* bind is monotone in both its first argument and its continuation.
   The proof goes through an intermediate term: first the first argument is
   varied with the continuation fixed, then the continuation is varied. *)
lemma bind_mono [partial_function_mono]:
  assumes mf: "mono_sum_bot B" and mg: "⋀y. mono_sum_bot (λf. C y f)"
  shows "mono_sum_bot (λf. bind (B f) (λy. C y f))"
proof (rule monotoneI)
  fix f g :: "'a ⇒ ('b, 'c) sum_bot"
  assume fg: "fun_ord sum_bot_ord f g"
  with mf have "sum_bot_ord (B f) (B g)" by (rule monotoneD [of _ _ _ f g])
  then have "sum_bot_ord (bind (B f) (λy. C y f)) (bind (B g) (λy. C y f))"
    unfolding flat_ord_def by auto
  also from mg have "⋀y'. sum_bot_ord (C y' f) (C y' g)"
    by (rule monotoneD) (rule fg)
  then have "sum_bot_ord (bind (B g) (λy'. C y' f)) (bind (B g) (λy'. C y' g))"
    unfolding flat_ord_def by (cases "B g") auto
  finally (sum_bot.leq_trans)
  show "sum_bot_ord (bind (B f) (λy. C y f)) (bind (B g) (λy'. C y' g))" .
qed

(* Let the generic Monad_Syntax.bind (and therefore do-notation) resolve to
   bind on sum_bot, then hide the unqualified name. The constant remains
   accessible via its qualified name. *)
adhoc_overloading
  Monad_Syntax.bind bind

hide_const (open) bind

(* Error handling: the handler f is applied to the error carried by Left.
   Bottom and Right values pass through unchanged.
   The handler may change the error type from 'e to 'f. *)
fun catch_error :: "('e, 'a) sum_bot ⇒ ('e ⇒ ('f, 'a) sum_bot) ⇒ ('f, 'a) sum_bot"
where
  "catch_error Bottom f = Bottom" |
  "catch_error (Left a) f = f a" |
  "catch_error (Right a) f = Right a"

(* Let the "try ... catch ..." syntax of Error_Syntax resolve to catch_error. *)
adhoc_overloading
  Error_Syntax.catch catch_error

(* catch_error is monotone in both the protected computation and the handler.
   The proof goes through an intermediate term: first the computation is
   varied with the handler fixed, then the handler is varied. *)
lemma catch_mono [partial_function_mono]:
  assumes mf: "mono_sum_bot B" and mg: "⋀y. mono_sum_bot (λf. C y f)"
  shows "mono_sum_bot (λf. try (B f) catch (λy. C y f))"
proof (rule monotoneI)
  fix f g :: "'a ⇒ ('b, 'c) sum_bot"
  assume fg: "fun_ord sum_bot_ord f g"
  with mf have "sum_bot_ord (B f) (B g)" by (rule monotoneD [of _ _ _ f g])
  then have "sum_bot_ord (try (B f) catch (λy. C y f)) (try (B g) catch (λy. C y f))"
    unfolding flat_ord_def by auto
  also from mg have "⋀y'. sum_bot_ord (C y' f) (C y' g)"
    by (rule monotoneD) (rule fg)
  then have "sum_bot_ord (try (B g) catch (λy'. C y' f)) (try (B g) catch (λy'. C y' g))"
    unfolding flat_ord_def by (cases "B g") auto
  finally (sum_bot.leq_trans)
  show "sum_bot_ord (try (B f) catch (λy. C y f)) (try (B g) catch (λy'. C y' g))" .
qed

(* Raise an error. Declared [simp], so the simplifier unfolds it to Left. *)
definition error :: "'e ⇒ ('e, 'a) sum_bot"
where
  [simp]: "error x = Left x"

(* Monadic return. Declared [simp], so the simplifier unfolds it to Right. *)
definition return :: "'a ⇒ ('e, 'a) sum_bot"
where
  [simp]: "return x = Right x"

(* Apply f to every list element from left to right and collect the results.
   Because the results are sequenced with bind, the first Bottom or Left
   produced by f becomes the overall result. *)
fun map_sum_bot :: "('a ⇒ ('e, 'b) sum_bot) ⇒ 'a list ⇒ ('e, 'b list) sum_bot"
where
  "map_sum_bot f [] = return []" |
  "map_sum_bot f (x#xs) = do {
    y ← f x;
    ys ← map_sum_bot f xs;
    return (y # ys)
  }"

(* Congruence rule for function definitions: the mapped functions only have to
   agree on the elements of the list. *)
lemma map_sum_bot_cong [fundef_cong]:
  assumes "xs = ys" and "⋀x. x ∈ set ys ⟹ f x = g x"
  shows "map_sum_bot f xs = map_sum_bot g ys"
  unfolding assms(1) using assms(2) by (induct ys) auto

(* sum_bot.const_mono instantiated to the pointwise order fun_ord sum_bot_ord. *)
lemmas sum_bot_const_mono =
  sum_bot.const_mono [of "fun_ord sum_bot_ord"]

(* map_sum_bot is monotone provided the mapped function is monotone for every
   element of the list. Proved by induction on the list. *)
lemma map_sum_bot_mono [partial_function_mono]:
  fixes C :: "'a ⇒ ('b ⇒ ('e, 'c) sum_bot) ⇒ ('e, 'd) sum_bot"
  assumes "⋀y. y ∈ set B ⟹ mono_sum_bot (C y)"
  shows "mono_sum_bot (λf. map_sum_bot (λy. C y f) B)"
  using assms by (induct B) (auto intro!: partial_function_mono)

(* Transform the error carried by Left with f.
   Bottom and Right values are left unchanged, as they are by catch_error. *)
abbreviation update_error :: "('e, 'a) sum_bot ⇒ ('e ⇒ 'f) ⇒ ('f, 'a) sum_bot"
where
  "update_error r f ≡ try r catch (λe. error (f e))"

(* Register update_error as an instance of the overloaded
   Error_Syntax.update_error constant. *)
adhoc_overloading
  Error_Syntax.update_error update_error

(* Embedding of HOL's sum type into sum_bot: Inl becomes Left and Inr becomes
   Right. Bottom is not in the image of this function. *)
fun sumbot :: "('e, 'a) sum ⇒ ('e, 'a) sum_bot"
where
  "sumbot (Inl x) = Left x" |
  "sumbot (Inr x) = Right x"

(* Use sumbot as the only code constructor of sum_bot. The code equations that
   follow express Left, Right and the monad operations in terms of it. *)
code_datatype sumbot

(* Code equation for bind on the code constructor sumbot: an error is
   propagated, and a regular result is passed to f. *)
lemma [code]:
  "bind (sumbot a) f = (case a of Inl b ⇒ sumbot (Inl b) | Inr x ⇒ f x)"
  by (cases a) auto

(* Code equation for try/catch on the code constructor sumbot: an error is
   passed to the handler f, and a regular result is kept. *)
lemma [code]:
  "(try (sumbot a) catch f) = (case a of Inl b ⇒ f b | Inr x ⇒ sumbot (Inr x))"
  by (cases a) auto

(* Code equations expressing the logical constructors Right and Left, and the
   operations return and error, through the code constructor sumbot. *)
lemma [code]:
  "Right x = sumbot (Inr x)"
  by simp

lemma [code]:
  "Left x = sumbot (Inl x)"
  by simp

lemma [code]:
  "return x = sumbot (Inr x)"
  by simp

lemma [code]:
  "error x = sumbot (Inl x)"
  by simp

(* Code equation for the case combinator of sum_bot: on values built by sumbot
   it reduces to the case combinator of HOL's sum type. The Bottom branch f is
   not needed, because sumbot never yields Bottom. *)
lemma [code]:
  "case_sum_bot f g h (sumbot p) = case_sum g h p"
  by (cases p) auto


subsection ‹Connection to @{theory Partial_Function_MR.Partial_Function_MR}›

(* Applying sum_bot_map h to the result preserves monotonicity.
   sum_bot_map maps Bottom to Bottom and does not identify distinct inputs in
   a way that breaks the flat order. *)
lemma sum_bot_map_mono [partial_function_mono]:
  assumes mf: "mono_sum_bot B"
  shows "mono_sum_bot (λf. sum_bot_map h (B f))"
proof (rule monotoneI)
  fix f g :: "'a ⇒ ('b, 'c) sum_bot"
  assume fg: "fun_ord sum_bot_ord f g"
  with mf have "sum_bot_ord (B f) (B g)" by (rule monotoneD [of _ _ _ f g])
  then show "sum_bot_ord (sum_bot_map h (B f)) (sum_bot_map h (B g))"
    unfolding flat_ord_def by auto
qed

(* Register the sum_bot monad with Partial_Function_MR. The arguments are:
   - a function building the term "sum_bot_map t_to_ss mt";
   - a function building the type sum_bot applied to commonTs @ argTs;
   - a function splitting a sum_bot type into ([error type], [result type]);
   - the map composition and map identity laws of sum_bot_map.
   NOTE(review): the exact roles of mt, t_to_ss, mtT, msT and t_to_sTs are
   fixed by Partial_Function_MR.init; confirm against that interface. *)
declaration ‹Partial_Function_MR.init
  "sum_bot"
  (fn (mt, t_to_ss, mtT, msT, t_to_sTs) =>
      list_comb (Const (@{const_name sum_bot_map}, t_to_sTs ---> mtT --> msT), t_to_ss) $ mt)
  (fn (commonTs, argTs) => Type (@{type_name sum_bot}, commonTs @ argTs))
  (fn mT => Term.dest_Type mT |> #2 |> (fn [err, res] => ([err], [res])))
  @{thms sum_bot.map_comp}
  @{thms sum_bot.map_ident}›

end