Theory Reader_Monad

(*
 * Copyright 2020, Data61, CSIRO (ABN 41 687 119 230)
 * Copyright (c) 2022 Apple Inc. All rights reserved.
 *
 * SPDX-License-Identifier: BSD-2-Clause
 *)

(*
 * Contributions by:
 *   2012 Lars Noschinski <noschinl@in.tum.de>
 *     Option monad while loop formalisation.
 *)

chapter "Option Monad (State Reader)"

theory Reader_Monad
  imports
    "More_Lib" (* FIXME: reduce dependencies *)
    "Less_Monad_Syntax"
begin

(* A lookup reads a state of type 's and either fails (None) or yields a
   value (Some). *)
type_synonym ('s,'a) lookup = "'s ⇒ 'a option"

text ‹Similar to @{const map_option} but the second function returns option as well›
definition
  opt_map :: "('s,'a) lookup ⇒ ('a ⇒ 'b option) ⇒ ('s,'b) lookup" (infixl "|>" 54)
where
  "f |> g ≡ λs. case f s of None ⇒ None | Some x ⇒ g x"

(* Post-compose a lookup with a total function. This is an abbreviation for
   opt_map with (Some ∘ g), so opt_map_def also serves as its definition. *)
abbreviation opt_map_Some :: "('s ⇀ 'a) ⇒ ('a ⇒ 'b) ⇒ 's ⇀ 'b" (infixl "||>" 54) where
  "f ||> g ≡ f |> (Some ∘ g)"
lemmas opt_map_Some_def = opt_map_def

(* Congruence rule for the function package: g only has to agree with g'
   on values that f actually returns. *)
lemma opt_map_cong [fundef_cong]:
  "⟦ f = f'; ⋀v s. f s = Some v ⟹ g v = g' v⟧ ⟹ f |> g = f' |> g'"
  by (rule ext) (simp add: opt_map_def split: option.splits)

(* Characterisation of a successful opt_map result, and the matching elim rule. *)
lemma in_opt_map_eq:
  "((f |> g) s = Some v) = (∃v'. f s = Some v' ∧ g v' = Some v)"
  by (simp add: opt_map_def split: option.splits)

lemma opt_mapE:
  "⟦ (f |> g) s = Some v; ⋀v'. ⟦f s = Some v'; g v' = Some v ⟧ ⟹ P ⟧ ⟹ P"
  by (auto simp: in_opt_map_eq)

(* Updating the input lookup at x turns into an update of the composed
   lookup at x. *)
lemma opt_map_upd_None:
  "f(x := None) |> g = (f |> g)(x := None)"
  by (auto simp: opt_map_def)

lemma opt_map_upd_Some:
  "f(x ↦ v) |> g = (f |> g)(x := g v)"
  by (auto simp: opt_map_def)

lemmas opt_map_upd[simp] = opt_map_upd_None opt_map_upd_Some

declare None_upd_eq[simp]

(* None_upd_eq is declared [simp] so that the pattern below is solved by simp
   alone. Hopefully this does not cause too much slowdown. *)
lemma "⟦ (f |> g) x = None; g v = None ⟧ ⟹ f(x ↦ v) |> g = f |> g"
  by simp

(* Bind for lookups. It runs f on the state s. If f fails the whole bind
   fails; if f succeeds with x, g x is run on the same state s. *)
definition
  obind :: "('s,'a) lookup ⇒ ('a ⇒ ('s,'b) lookup) ⇒ ('s,'b) lookup" (infixl "|>>" 53)
where
  "f |>> g ≡ λs. case f s of None ⇒ None | Some x ⇒ g x s"

(* Enable "do { .. }" syntax *)
adhoc_overloading
  Monad_Syntax.bind obind

(* Failure: ofail returns None for every state. *)
definition
  "ofail = K None"

(* Return: oreturn x returns Some x for every state. *)
definition
  "oreturn = K o Some"

(* Succeed with () if P holds, fail otherwise. *)
definition
  "oassert P ≡ if P then oreturn () else ofail"

(* Look up key x in a state that is itself a partial function. *)
definition oapply :: "'a ⇒ ('a ⇒ 'b option) ⇒ 'b option"
  where
  "oapply x ≡ λs. s x"

text ‹
  Variants for lookups whose result can be an exception, encoded as a sum:
  Inl for an exception and Inr for a normal result.
  A corresponding bindE would be analogous to the lifting in NonDetMonad.
›

definition
  "oreturnOk x = K (Some (Inr x))"

definition
  "othrow e = K (Some (Inl e))"

(* Succeed with () on states satisfying G, fail on all others. *)
definition
  "oguard G ≡ (λs. if G s then Some () else None)"

(* Run L or R, chosen by evaluating c on the current state. *)
definition
  "ocondition c L R ≡ (λs. if c s then L s else R s)"

(* Do nothing: succeed with (). *)
definition
  "oskip ≡ oreturn ()"

text ‹Monad laws›
(* Left identity: binding a returned value just applies the continuation. *)
lemma oreturn_bind [simp]: "(oreturn x |>> f) = f x"
  by (auto simp add: oreturn_def obind_def K_def)

(* Right identity: binding with oreturn leaves the lookup unchanged. *)
lemma obind_return [simp]: "(m |>> oreturn) = m"
  by (auto simp add: oreturn_def obind_def K_def split: option.splits)

(* Associativity of |>>; stated as a plain lemma, not [simp]. *)
lemma obind_assoc:
  "(m |>> f) |>> g  =  m |>> (λx. f x |>> g)"
  by (auto simp add: oreturn_def obind_def K_def split: option.splits)


text ‹Binding fail›

(* Continuing with ofail fails, whatever f does. *)
lemma obind_fail [simp]:
  "f |>> (λ_. ofail) = ofail"
  by (auto simp add: ofail_def obind_def K_def split: option.splits)

(* A failing first computation makes the whole bind fail. *)
lemma ofail_bind [simp]:
  "ofail |>> m = ofail"
  by (auto simp add: ofail_def obind_def K_def split: option.splits)



text ‹Function package setup›
(* Congruence rules declared [fundef_cong] for the function package. Each one
   lets the continuation of a bind be rewritten only for results that can
   actually reach it. The _apply variants state the same congruence at a
   fixed state s. *)
lemma opt_bind_cong [fundef_cong]:
  "⟦ f = f'; ⋀v s. f' s = Some v ⟹ g v s = g' v s ⟧ ⟹ f |>> g = f' |>> g'"
  by (rule ext) (simp add: obind_def split: option.splits)

lemma opt_bind_cong_apply [fundef_cong]:
  "⟦ f s = f' s; ⋀v. f' s = Some v ⟹ g v s = g' v s ⟧ ⟹ (f |>> g) s = (f' |>> g') s"
  by (simp add: obind_def split: option.splits)

(* After oassert P, the continuation only matters when P holds. *)
lemma oassert_bind_cong [fundef_cong]:
  "⟦ P = P'; P' ⟹ m = m' ⟧ ⟹ oassert P |>> m = oassert P' |>> m'"
  by (auto simp: oassert_def)

lemma oassert_bind_cong_apply [fundef_cong]:
  "⟦ P = P'; P' ⟹ m () s = m' () s ⟧ ⟹ (oassert P |>> m) s = (oassert P' |>> m') s"
  by (auto simp: oassert_def)

(* After oreturn x, the continuation only matters at the returned value. *)
lemma oreturn_bind_cong [fundef_cong]:
  "⟦ x = x'; m x' = m' x' ⟧ ⟹ oreturn x |>> m = oreturn x' |>> m'"
  by simp

lemma oreturn_bind_cong_apply [fundef_cong]:
  "⟦ x = x'; m x' s = m' x' s ⟧ ⟹ (oreturn x |>> m) s = (oreturn x' |>> m') s"
  by simp

(* The same rules for the "oreturn $ x" form. *)
lemma oreturn_bind_cong2 [fundef_cong]:
  "⟦ x = x'; m x' = m' x' ⟧ ⟹ (oreturn $ x) |>> m = (oreturn $ x') |>> m'"
  by simp

lemma oreturn_bind_cong2_apply [fundef_cong]:
  "⟦ x = x'; m x' s = m' x' s ⟧ ⟹ ((oreturn $ x) |>> m) s = ((oreturn $ x') |>> m') s"
  by simp

(* Each branch of ocondition only matters on states where it is selected. *)
lemma ocondition_cong [fundef_cong]:
"⟦c = c'; ⋀s. c' s ⟹ l s = l' s; ⋀s. ¬c' s ⟹ r s = r' s⟧
  ⟹ ocondition c l r = ocondition c' l' r'"
  by (auto simp: ocondition_def)


text ‹Decomposition›

(* ocondition with a constant condition selects one branch. *)
lemma ocondition_K_true [simp]:
  "ocondition (λ_. True) T F = T"
  by (simp add: ocondition_def)

lemma ocondition_K_false [simp]:
  "ocondition (λ_. False) T F = F"
  by (simp add: ocondition_def)

(* Generalisations to conditions that are false (resp. true) on every state;
   these are not declared [simp]. *)
lemma ocondition_False:
    "⟦ ⋀s. ¬ P s ⟧ ⟹ ocondition P L R = R"
  by (rule ext, clarsimp simp: ocondition_def)

lemma ocondition_True:
    "⟦ ⋀s. P s ⟧ ⟹ ocondition P L R = L"
  by (rule ext, clarsimp simp: ocondition_def)

(* When the basic combinators return Some (in_* lemmas), and the
   corresponding elimination rules ( *E lemmas). *)
lemma in_oreturn [simp]:
  "(oreturn x s = Some v) = (v = x)"
  by (auto simp: oreturn_def K_def)

lemma oreturnE:
  "⟦oreturn x s = Some v; v = x ⟹ P⟧ ⟹ P"
  by simp

(* ofail never returns Some. *)
lemma in_ofail [simp]:
  "ofail s ≠ Some v"
  by (auto simp: ofail_def K_def)

lemma ofailE:
  "ofail s = Some v ⟹ P"
  by simp

(* oassert P returns Some exactly when P holds. *)
lemma in_oassert_eq [simp]:
  "(oassert P s = Some v) = P"
  by (simp add: oassert_def)

lemma oassert_True [simp]:
  "oassert True = oreturn ()"
  by (simp add: oassert_def)

lemma oassert_False [simp]:
  "oassert False = ofail"
  by (simp add: oassert_def)

lemma oassertE:
  "⟦ oassert P s = Some v; P ⟹ Q ⟧ ⟹ Q"
  by simp

(* A bind returns Some v exactly when f succeeds with some v' and g v'
   then returns Some v at the same state. *)
lemma in_obind_eq:
  "((f |>> g) s = Some v) = (∃v'. f s = Some v' ∧ g v' s = Some v)"
  by (simp add: obind_def split: option.splits)

(* Binds with the same first computation agree at states s and s' if f gives
   the same result at both and the continuations agree on that result. *)
lemma obind_eqI:
  "⟦ f s = f s' ; ⋀x. f s = Some x ⟹ g x s = g' x s' ⟧ ⟹ obind f g s = obind f g' s'"
  by (simp add: obind_def split: option.splits)

(* Full form of obind_eqI. The second equality is flipped here because we end up
   with "f s = Some x ; f s' = f s", which prevents "Some x = ..." *)
lemma obind_eqI_full:
  "⟦ f s = f s' ; ⋀x. ⟦ f s = Some x; f s' = f s ⟧ ⟹ g x s = g' x s' ⟧
   ⟹ obind f g s = obind f g' s'"
  by (drule sym[where s="f s"]) (* prevent looping *)
     (clarsimp simp: obind_def split: option.splits)

lemma obindE:
  "⟦ (f |>> g) s = Some v;
    ⋀v'. ⟦f s = Some v'; g v' s = Some v⟧ ⟹ P⟧ ⟹ P"
  by (auto simp: in_obind_eq)

(* The exception variants: othrow yields Inl, oreturnOk yields Inr. *)
lemma in_othrow_eq [simp]:
  "(othrow e s = Some v) = (v = Inl e)"
  by (auto simp: othrow_def K_def)

lemma othrowE:
  "⟦othrow e s = Some v; v = Inl e ⟹ P⟧ ⟹ P"
  by simp

lemma in_oreturnOk_eq [simp]:
  "(oreturnOk x s = Some v) = (v = Inr x)"
  by (auto simp: oreturnOk_def K_def)

lemma oreturnOkE:
  "⟦oreturnOk x s = Some v; v = Inr x ⟹ P⟧ ⟹ P"
  by simp

(* The option-monad elimination rules, declared as safe elim rules. *)
lemmas omonadE [elim!] =
  opt_mapE obindE oreturnE ofailE othrowE oreturnOkE oassertE
(* Success and failure of f ||> g are determined by f alone. *)
lemma in_opt_map_Some_eq:
  "((f ||> g) x = Some y) = (∃v. f x = Some v ∧ g v = y)"
  by (simp add: in_opt_map_eq)

lemma in_opt_map_None_eq[simp]:
  "((f ||> g) x = None) = (f x = None)"
  by (simp add: opt_map_def split: option.splits)

(* Lookups that ignore the state are unchanged by precomposing a function
   on states. *)
lemma oreturn_comp[simp]:
  "oreturn x ∘ f = oreturn x"
  by (simp add: oreturn_def K_def o_def)

lemma ofail_comp[simp]:
  "ofail ∘ f = ofail"
  by (auto simp: ofail_def K_def)

lemma oassert_comp[simp]:
  "oassert P ∘ f = oassert P"
  by (simp add: oassert_def)

(* Simplification rules for applying the basic combinators to a state. *)
lemma fail_apply[simp]:
  "ofail s = None"
  by (simp add: ofail_def K_def)

lemma oassert_apply[simp]:
  "oassert P s = (if P then Some () else None)"
  by (simp add: oassert_def)

lemma oreturn_apply[simp]:
  "oreturn x s = Some x"
  by simp

lemma oapply_apply[simp]:
  "oapply x s = s x"
  by (simp add: oapply_def)

(* Precomposition with h distributes over obind and over if-then-else. *)
lemma obind_comp_dist:
  "obind f g o h = obind (f o h) (λx. g x o h)"
  by (auto simp: obind_def split: option.splits)

lemma if_comp_dist:
  "(if P then f else g) o h = (if P then f o h else g o h)"
  by auto


section ‹"While" loops over option monad.›

text ‹
  This is an inductive definition of a while loop over the plain option monad
  (without passing through a state)
›

(* option_while' C B relates a start value Some r to the outcome of applying
   B while C holds: Some of the final value if C becomes false, or None if B
   fails along the way. Only outcomes reachable in finitely many steps are
   included. *)
inductive_set
  option_while' :: "('a ⇒ bool) ⇒ ('a ⇒ 'a option) ⇒ 'a option rel"
  for C B
where
    final: "¬ C r ⟹ (Some r, Some r) ∈ option_while' C B"
  | fail: "⟦ C r; B r = None ⟧ ⟹ (Some r, None) ∈ option_while' C B"
  | step: "⟦ C r;  B r = Some r'; (Some r', sr'') ∈ option_while' C B ⟧
           ⟹ (Some r, sr'') ∈ option_while' C B"

(* The loop outcome from r. It is None when no outcome is derivable in
   option_while', which includes the non-terminating case. *)
definition
  "option_while C B r ≡
    (if (∃s. (Some r, s) ∈ option_while' C B) then
      (THE s. (Some r, s) ∈ option_while' C B) else None)"

(* option_while' is functional: a start value has at most one outcome. *)
lemma option_while'_inj:
  assumes "(s,s') ∈ option_while' C B" "(s, s'') ∈ option_while' C B"
  shows "s' = s''"
  using assms by (induct rule: option_while'.induct) (auto elim: option_while'.cases)

(* Taking one loop step does not change the outcome. *)
lemma option_while'_inj_step:
  "⟦ C s; B s = Some s'; (Some s, t) ∈ option_while' C B ; (Some s', t') ∈ option_while' C B ⟧ ⟹ t = t'"
  by (metis option_while'.step option_while'_inj)

lemma option_while'_THE:
  assumes "(Some r, sr') ∈ option_while' C B"
  shows "(THE s. (Some r, s) ∈ option_while' C B) = sr'"
  using assms by (blast dest: option_while'_inj)

(* Unfolding equations for option_while: exit, failing body, one step, and
   agreement with any outcome derived in option_while'. *)
lemma option_while_simps:
  "¬ C s ⟹ option_while C B s = Some s"
  "C s ⟹ B s = None ⟹ option_while C B s = None"
  "C s ⟹ B s = Some s' ⟹ option_while C B s = option_while C B s'"
  "(Some s, ss') ∈ option_while' C B ⟹ option_while C B s = ss'"
  using option_while'_inj_step[of C s B s']
  by (auto simp: option_while_def option_while'_THE
      intro: option_while'.intros
      dest: option_while'_inj
      elim: option_while'.cases)

(* Invariant rule: if the loop returns Some s' and every step preserves I,
   then I holds of s' and the loop condition is false there. *)
lemma option_while_rule:
  assumes "option_while C B s = Some s'"
  assumes "I s"
  assumes istep: "⋀s s'. C s ⟹ I s ⟹ B s = Some s' ⟹ I s'"
  shows "I s' ∧ ¬ C s'"
proof -
  { fix ss ss' assume "(ss, ss') ∈ option_while' C B" "ss = Some s" "ss' = Some s'"
    then have ?thesis using ‹I s›
      by (induct arbitrary: s) (auto intro: istep) }
  then show ?thesis using assms(1)
    by (auto simp: option_while_def option_while'_THE split: if_split_asm)
qed

(* Termination: suppose every step from an I-state decreases along the
   well-founded relation M and preserves I. Then the loop from r has an
   outcome in option_while'. The proof is by well-founded induction on r. *)
lemma option_while'_term:
  assumes "I r"
  assumes "wf M"
  assumes step_less: "⋀r r'. ⟦I r; C r; B r = Some r'⟧ ⟹ (r',r) ∈ M"
  assumes step_I: "⋀r r'. ⟦I r; C r; B r = Some r'⟧ ⟹ I r'"
  obtains sr' where "(Some r, sr') ∈ option_while' C B"
  apply atomize_elim
  using assms(2,1)
proof induct
  case (less r)
  show ?case
  proof (cases "C r" "B r" rule: bool.exhaust[case_product option.exhaust])
    case (True_Some r')
    then have "(r',r) ∈ M" "I r'"
      by (auto intro: less step_less step_I)
    then obtain sr' where "(Some r', sr') ∈ option_while' C B"
      by atomize_elim (rule less)
    then have "(Some r, sr') ∈ option_while' C B"
      using True_Some by (auto intro: option_while'.intros)
    then show ?thesis ..
  qed (auto intro: option_while'.intros)
qed

(* Total-correctness rule. The invariant I is stated on option values, so the
   failure case (None) is covered via `final`. Given a well-founded measure M,
   I holds of the loop outcome ss', and a successful outcome falsifies C. *)
lemma option_while_rule':
  assumes "option_while C B s = ss'"
  assumes "wf M"
  assumes "I (Some s)"
  assumes less: "⋀s s'. C s ⟹ I (Some s) ⟹ B s = Some s' ⟹ (s', s) ∈ M"
  assumes step: "⋀s s'. C s ⟹ I (Some s) ⟹ B s = Some s' ⟹ I (Some s')"
  assumes final: "⋀s. C s ⟹ I (Some s) ⟹ B s = None ⟹ I None"
  shows "I ss' ∧ (case ss' of Some s' ⇒ ¬ C s' | _ ⇒ True)"
proof -
  define ss where "ss ≡ Some s"
  obtain ss1' where "(Some s, ss1') ∈ option_while' C B"
    using assms(3,2,4,5) by (rule option_while'_term)
  then have *: "(ss, ss') ∈ option_while' C B" using ‹option_while C B s = ss'›
    by (auto simp: option_while_simps ss_def)
  show ?thesis
  proof (cases ss')
    case (Some s') with * ss_def show ?thesis using ‹I _›
      by (induct arbitrary:s) (auto intro: step)
  next
    case None with * ss_def show ?thesis using ‹I _›
      by (induct arbitrary:s) (auto intro: step final)
  qed
qed

section ‹Lift @{term option_while} to the @{typ "('a,'s) lookup"} monad›

(* While loop in the lookup monad. The state s is fixed for the whole loop:
   the condition c and body b are both evaluated at the same s. *)
definition
  owhile :: "('a ⇒ 's ⇒ bool) ⇒ ('a ⇒ ('s,'a) lookup) ⇒ 'a ⇒ ('s,'a) lookup"
where
 "owhile c b a ≡ λs. option_while (λa. c a s) (λa. b a s) a"

(* One-step unfolding: if the condition holds, run the body and loop again;
   otherwise return the current value. *)
lemma owhile_unroll:
  "owhile C B r = ocondition (C r) (B r |>> owhile C B) (oreturn r)"
  by (auto simp: ocondition_def obind_def oreturn_def owhile_def
           option_while_simps K_def split: option.split)

text ‹rule for terminating loops›

(* Hoare-style rule for owhile at a fixed state s. It needs an invariant I
   and a well-founded measure M. fail handles a failing body and final handles
   normal exit; Q then holds of the loop result. *)
lemma owhile_rule:
  assumes "I r s"
  assumes "wf M"
  assumes less: "⋀r r'. ⟦I r s; C r s; B r s = Some r'⟧ ⟹ (r',r) ∈ M"
  assumes step: "⋀r r'. ⟦I r s; C r s; B r s = Some r'⟧ ⟹ I r' s"
  assumes fail: "⋀r. ⟦I r s; C r s; B r s = None⟧ ⟹ Q None"
  assumes final: "⋀r. ⟦I r s; ¬C r s⟧ ⟹ Q (Some r)"
  shows "Q (owhile C B r s)"
proof -
  let ?rs' = "owhile C B r s"
  have "(case ?rs' of Some r ⇒ I r s | _ ⇒ Q None)
      ∧ (case ?rs' of Some r' ⇒ ¬ C r' s | _ ⇒ True)"
    by (rule option_while_rule'[where B="λr. B r s" and s=r, OF _ ‹wf _›])
       (auto simp: owhile_def intro: assms)
  then show ?thesis by (auto intro: final split: option.split_asm)
qed

end