(* Theory DetMonad *)

(*
 * Copyright 2014, NICTA
 *
 * This software may be distributed and modified according to the terms of
 * the BSD 2-Clause license. Note that NO WARRANTY is provided.
 * See "LICENSE_BSD2.txt" for details.
 *
 * @TAG(NICTA_BSD)
 *)

(*
* Zhe Hou: I modified this file to model a deterministic monad instead of
* a non-deterministic monad. Features that are irrelevant to deterministic 
* monads are removed. I also removed the section for loops, which are not
* used in the modelling of the SPARC architecture.
*)

(* 
   Deterministic state and error monads with failure in Isabelle.
*)

(*chapter "Deterministic State Monad with Failure"*)

theory DetMonad
imports "../Lib"
begin

text ‹
  \label{c:monads}

  State monads are used extensively in the seL4 specification. They are
  defined below.
›

section "The Monad"

text ‹
  The basic type of the deterministic state monad with failure is
  very similar to the normal state monad. Instead of a pair consisting
  of result and new state, we return a pair coupled with
  a failure flag. The flag is @{const True} if the computation has failed.
  Conversely, if the flag is @{const False}, the computation that produced
  the returned result has succeeded.›
(* A computation maps a start state to ((result, new state), failure flag). *)
type_synonym ('s,'a) det_monad = "'s ⇒ ('a × 's) × bool"


text ‹
  The definition of the fundamental monad functions ‹return› and
  ‹bind›. The monad function ‹return x› does not change
  the state, does not fail, and returns ‹x›.
› 
(* Yield a, leave the state unchanged and clear the failure flag. *)
definition
  return :: "'a ⇒ ('s,'a) det_monad" where
  "return a ≡ λs. ((a,s),False)"

text ‹
  The monad function ‹bind f g›, also written ‹f >>= g›,
  is the execution of @{term f} followed by the execution of ‹g›.
  The function ‹g› takes the result value \emph{and} the result
  state of ‹f› as parameter. The definition says that the result of
  the combined operation is the result which is created
  by ‹g› applied to the result of ‹f›. The combined
  operation has failed if ‹f› has failed or ‹g› has
  failed on the result of ‹f›.
›

text ‹
 David Sanan and Zhe Hou: The original definition of bind is very inefficient 
 when converted to executable code. Here we change it to a more efficient
 version for execution. The idea remains the same.
›

(* h1: run the first computation on the start state. *)
definition "h1 f s = f s"
(* h2: feed the (result, state) pair produced by the first computation
   into the continuation g. *)
definition "h2 g fs = (let (a,b) = fst fs in g a b)"
(* Sequential composition. The continuation g is evaluated on the result
   of f even when f has set its failure flag; the two flags are combined
   by disjunction, so once set the failure flag is never cleared. *)
definition bind:: "('s, 'a) det_monad ⇒ ('a ⇒ ('s, 'b) det_monad) ⇒
           ('s, 'b) det_monad" (infixl ">>=" 60)
where
"bind f g ≡ λs. (
  let fs = h1 f s;
      v = h2 g fs
  in
  (fst v, (snd v ∨ snd fs)))"

text ‹
  Sometimes it is convenient to write ‹bind› in reverse order.
›
(* Reversed bind: g =<< f stands for f >>= g. Being an input-only
   abbreviation, it is expanded when parsed and never used for printing. *)
abbreviation(input)
  bind_rev :: "('c ⇒ ('a, 'b) det_monad) ⇒ ('a, 'c) det_monad ⇒
               ('a, 'b) det_monad" (infixl "=<<" 60) where
  "g =<< f ≡ f >>= g"

text ‹
  The basic accessor functions of the state monad. ‹get› returns
  the current state as result, does not fail, and does not change the state.
  ‹put s› returns nothing (@{typ unit}), changes the current state
  to ‹s› and does not fail.
›
(* Return the current state as the result, leaving it unchanged. *)
definition
  get :: "('s,'s) det_monad" where
  "get ≡ λs. ((s,s), False)"

(* Discard the current state, install s, and return unit. *)
definition
  put :: "'s ⇒ ('s, unit) det_monad" where
  "put s ≡ λ_. (((),s), False)"

subsection "Failure"

text ‹The monad function that always fails. Returns the current 
  state and sets the failure flag.›
(* Unlike the non-deterministic original, failure still yields a result
   (the given a) and the unchanged state; only the flag is set. *)
definition
  fail :: "'a ⇒ ('s, 'a) det_monad" where
 "fail a ≡ λs. ((a,s), True)"

text ‹Assertions: fail if the property ‹P› is not true›
(* Both branches return () and keep the state; only the failure flag
   records whether P held. *)
definition
  assert :: "bool ⇒ ('a, unit) det_monad" where
 "assert P ≡ if P then return () else fail ()"

text ‹An assertion that also can introspect the current state.›

(* Read the state and assert P on it; the state is left unchanged. *)
definition
  state_assert :: "('s ⇒ bool) ⇒ ('s, unit) det_monad"
where
  "state_assert P ≡ get >>= (λs. assert (P s))"

subsection "Generic functions on top of the state monad"

text ‹Apply a function to the current state and return the result
without changing the state.›
(* Observe f applied to the current state without modifying it. *)
definition
  gets :: "('s ⇒ 'a) ⇒ ('s, 'a) det_monad" where
 "gets f ≡ get >>= (λs. return (f s))"

text ‹Modify the current state using the function passed in.›
(* Replace the state s by f s, returning unit. *)
definition
  modify :: "('s ⇒ 's) ⇒ ('s, unit) det_monad" where
 "modify f ≡ get >>= (λs. put (f s))"

(* Closed form of gets: apply f to the state, keep the state, never fail. *)
lemma simpler_gets_def: "gets f = (λs. ((f s, s), False))"
  by (simp add: gets_def return_def bind_def h1_def h2_def get_def)

(* Closed form of modify: map the state through f, return unit, never fail. *)
lemma simpler_modify_def:
  "modify f = (λs. (((), f s), False))"
  apply (simp add: put_def get_def modify_def bind_def h1_def h2_def)
  done

text ‹Execute the given monad when the condition is true, 
  ‹return ()› otherwise.›
(* Run m only when P holds; otherwise do nothing and succeed. *)
definition
  when1 :: "bool ⇒ ('s, unit) det_monad ⇒
           ('s, unit) det_monad" where
  "when1 P m ≡ if P then m else return ()"

text ‹Execute the given monad unless the condition is true, 
  ‹return ()› otherwise.›
(* Run m only when P does not hold; defined via when1 on the negation. *)
definition 
  unless :: "bool ⇒ ('s, unit) det_monad ⇒
            ('s, unit) det_monad" where
  "unless P m ≡ when1 (¬P) m"

text ‹
  Perform a test on the current state, performing the left monad if
  the result is true or the right monad if the result is false.
›
(* State-dependent branch: P is evaluated on the start state, and the
   chosen monad runs from that same state. *)
definition
  condition :: "('s ⇒ bool) ⇒ ('s, 'r) det_monad ⇒ ('s, 'r) det_monad ⇒ ('s, 'r) det_monad"
where
  "condition P L R ≡ λs. if (P s) then (L s) else (R s)"

notation (output)
  condition  ("(condition (_)//  (_)//  (_))" [1000,1000,1000] 1000)

subsection ‹The Monad Laws›

text ‹Each monad satisfies at least the following three laws.›

text ‹@{term return} is absorbed at the left of a @{term bind},
  applying the return value directly:›
(* First monad law: a leading return is eliminated by applying f directly. *)
lemma return_bind [simp]: "(return x >>= f) = f x"
  apply (simp add: bind_def h1_def h2_def return_def)
  done

text ‹@{term return} is absorbed on the right of a @{term bind}›
(* Second monad law, shown pointwise by extensionality over the start state. *)
lemma bind_return [simp]: "(m >>= return) = m"
  by (rule ext) (simp add: return_def bind_def h1_def h2_def split_def)
 
text ‹@{term bind} is associative›
(* Third monad law. After unfolding, both sides compute the same result;
   the failure flags differ only in how the disjunctions are nested. *)
lemma bind_assoc: 
  fixes m :: "('a,'b) det_monad"
  fixes f :: "'b ⇒ ('a,'c) det_monad"
  fixes g :: "'c ⇒ ('a,'d) det_monad"
  shows "(m >>= f) >>= g  =  m >>= (λx. f x >>= g)"
  apply (unfold bind_def h1_def h2_def Let_def split_def)
  apply (rule ext)
  apply clarsimp
  done


section ‹Adding Exceptions›

text ‹
  The type @{typ "('s,'a) det_monad"} gives us determinism and
  failure. We now extend this monad with exceptional return values
  that abort normal execution, but can be handled explicitly.
  We use the sum type to indicate exceptions. 

  In @{typ "('s, 'e + 'a) det_monad"}, @{typ "'s"} is the state,
  @{typ 'e} is an exception, and @{typ 'a} is a normal return value.

  This new type itself forms a monad again. Since type classes in 
  Isabelle are not powerful enough to express the class of monads,
  we provide new names for the @{term return} and @{term bind} functions
  in this monad. We call them ‹returnOk› (for normal return values)
  and ‹bindE› (for composition). We also define ‹throwError›
  to return an exceptional value.
›
(* Normal return in the exception monad: the value is tagged with Inr. *)
definition
  returnOk :: "'a ⇒ ('s, 'e + 'a) det_monad" where
  "returnOk ≡ return o Inr"

(* Exceptional return: the exception is tagged with Inl. This does not set
   the failure flag; exceptions and failure are independent. *)
definition
  throwError :: "'e ⇒ ('s, 'e + 'a) det_monad" where
  "throwError ≡ return o Inl"

text ‹
  Lifting a function over the exception type: if the input is an
  exception, return that exception; otherwise continue execution.
›
(* Propagate an Inl exception unchanged; apply f to an Inr value. *)
definition
  lift :: "('a ⇒ ('s, 'e + 'b) det_monad) ⇒
           'e +'a ⇒ ('s, 'e + 'b) det_monad"
where
  "lift f v ≡ case v of Inl e ⇒ throwError e
                      | Inr v' ⇒ f v'"

text ‹
  The definition of @{term bind} in the exception monad (new
  name ‹bindE›): the same as normal @{term bind}, but 
  the right-hand side is skipped if the left-hand side
  produced an exception.
›
(* Exception-aware bind: ordinary bind with g wrapped by lift, so an Inl
   result of f skips g. *)
definition
  bindE :: "('s, 'e + 'a) det_monad ⇒
            ('a ⇒ ('s, 'e + 'b) det_monad) ⇒
            ('s, 'e + 'b) det_monad"  (infixl ">>=E" 60)
where
  "bindE f g ≡ bind f (lift g)"


text ‹
  Lifting a normal deterministic monad into the 
  exception monad is achieved by always returning its
  result as normal result and never throwing an exception.
›
(* Embed a plain computation into the exception monad by tagging its
   result with Inr. *)
definition
  liftE :: "('s,'a) det_monad ⇒ ('s, 'e+'a) det_monad"
where
  "liftE f ≡ f >>= (λr. return (Inr r))"


text ‹
  Since the underlying type and ‹return› function changed, 
  we need new definitions for when and unless:
›
(* Run f only when P holds; otherwise return normally with (). *)
definition
  whenE :: "bool ⇒ ('s, 'e + unit) det_monad ⇒
            ('s, 'e + unit) det_monad" 
  where
  "whenE P f ≡ if P then f else returnOk ()"

(* Run f only when P does not hold; otherwise return normally with (). *)
definition
  unlessE :: "bool ⇒ ('s, 'e + unit) det_monad ⇒
            ('s, 'e + unit) det_monad" 
  where
  "unlessE P f ≡ if P then returnOk () else f"


text ‹
  Throwing an exception when the parameter is @{term None}, otherwise
  returning @{term "v"} for @{term "Some v"}.
›
(* Convert an option into the exception monad, with ex as the exception
   used for None. *)
definition
  throw_opt :: "'e ⇒ 'a option ⇒ ('s, 'e + 'a) det_monad" where
  "throw_opt ex x ≡
  case x of None ⇒ throwError ex | Some v ⇒ returnOk v"

subsection "Monad Laws for the Exception Monad"

text ‹More direct definition of @{const liftE}:›
(* Pointwise form of liftE: tag f's result with Inr, keep f's new state,
   and keep f's own failure flag (the trailing return adds none). In this
   deterministic monad, fst (f s) is a single pair, so the retagging
   function is applied to it directly. *)
lemma liftE_def2:
  "liftE f = (λs. ((λ(v,s'). (Inr v, s')) (fst (f s)), snd (f s)))"
  by (auto simp: Let_def liftE_def return_def split_def bind_def h1_def h2_def)
  
text ‹Left @{const returnOk} absorption over @{term bindE}:›
(* Left identity for bindE. *)
lemma returnOk_bindE [simp]: "(returnOk x >>=E f) = f x"
  unfolding returnOk_def bindE_def
  by (clarsimp simp: lift_def)

(* lift of a normal return is just return: Inl is rethrown via
   throwError = return ∘ Inl, and Inr goes back through return. *)
lemma lift_return [simp]:
  "lift (return ∘ Inr) = return"
  by (rule ext)
     (simp add: lift_def throwError_def split: sum.splits)

text ‹Right @{const returnOk} absorption over @{term bindE}:›
(* Right identity for bindE; relies on lift_return being a simp rule. *)
lemma bindE_returnOk [simp]: "(m >>=E returnOk) = m"
  apply (simp add: returnOk_def bindE_def)
  done

text ‹Associativity of @{const bindE}:›
(* Associativity of bindE: reduce to bind_assoc, then compare the two
   continuations pointwise, splitting on Inl/Inr. *)
lemma bindE_assoc:
  "(m >>=E f) >>=E g = m >>=E (λx. f x >>=E g)"
  apply (simp add: bindE_def bind_assoc)
  apply (rule arg_cong [where f="λx. m >>= x"])
  apply (rule ext)
  apply (case_tac x)
  apply (simp_all add: lift_def throwError_def)
  done

text ‹@{const returnOk} could also be defined via @{const liftE}:›
(* returnOk coincides with lifting a plain return. *)
lemma returnOk_liftE:
  "returnOk x = liftE (return x)"
  unfolding returnOk_def liftE_def by simp

text ‹Execution after throwing an exception is skipped:›
(* A thrown exception short-circuits the continuation of bindE. *)
lemma throwError_bindE [simp]:
  "(throwError E >>=E f) = throwError E"
  apply (simp add: throwError_def return_def bindE_def lift_def bind_def h1_def h2_def)
  done


section "Syntax"

text ‹This section defines traditional Haskell-like do-syntax 
  for the state monad in Isabelle.›

subsection "Syntax for the Deterministic State Monad"

text ‹We use ‹K_bind› to syntactically indicate the 
  case where the return argument of the left side of a @{term bind}
  is ignored›
(* Constant function that ignores its second argument; marks do-steps
   whose result is discarded. *)
definition
  K_bind_def [iff]: "K_bind ≡ λx y. x"

(* Grammar for do-notation: a binding "x ← m", a bare step "m", and
   ;-separated sequences of them. *)
nonterminal
  dobinds and dobind and nobind

syntax
  "_dobind"    :: "[pttrn, 'a] => dobind"             ("(_ ←/ _)" 10)
  ""           :: "dobind => dobinds"                 ("_")
  "_nobind"    :: "'a => dobind"                      ("_")
  "_dobinds"   :: "[dobind, dobinds] => dobinds"      ("(_);//(_)")

  "_do"        :: "[dobinds, 'a] => 'a"               ("(do ((_);//(_))//od)" 100)
(* A sequence is unfolded one step at a time; a step without a binder
   discards its result through K_bind. *)
translations
  "_do (_dobinds b bs) e"  == "_do b (_do bs e)"
  "_do (_nobind b) e"      == "b >>= (CONST K_bind e)"
  "do x ← a; e od"        == "a >>= (λx. e)"

text ‹Syntax examples:›
(* The do-block desugars to nested binds; the unnamed middle step is
   wrapped in K_bind. *)
lemma "do x ← return 1; 
          return (2::nat); 
          return x 
       od = 
       return 1 >>= 
       (λx. return (2::nat) >>= 
            K_bind (return x))" 
  by (rule refl)

(* With the monad laws as simp rules, the example collapses to return 1. *)
lemma "do x ← return 1; 
          return 2; 
          return x 
       od = return 1" 
  by simp

subsection "Syntax for the Exception Monad"

text ‹
  Since the exception monad is a different type, we
  need to syntactically distinguish it in the syntax.
  We use ‹doE›/‹odE› for this, but can re-use
  most of the productions from ‹do›/‹od›
  above.
›

(* doE/odE reuses the dobinds grammar but desugars to bindE. *)
syntax
  "_doE" :: "[dobinds, 'a] => 'a"  ("(doE ((_);//(_))//odE)" 100)

translations
  "_doE (_dobinds b bs) e"  == "_doE b (_doE bs e)"
  "_doE (_nobind b) e"      == "b >>=E (CONST K_bind e)"
  "doE x ← a; e odE"       == "a >>=E (λx. e)"

text ‹Syntax examples:›
(* The doE-block desugars to nested bindE applications. *)
lemma "doE x ← returnOk 1; 
           returnOk (2::nat); 
           returnOk x 
       odE =
       returnOk 1 >>=E 
       (λx. returnOk (2::nat) >>=E 
            K_bind (returnOk x))"
  by (rule refl)

(* returnOk_bindE as a simp rule collapses the example to returnOk 1. *)
lemma "doE x ← returnOk 1; 
           returnOk 2; 
           returnOk x 
       odE = returnOk 1" 
  by simp



section "Library of Monadic Functions and Combinators"


text ‹Lifting a normal function into the monad type:›
(* Map a pure function over the result of m. *)
definition
  liftM :: "('a ⇒ 'b) ⇒ ('s,'a) det_monad ⇒ ('s, 'b) det_monad"
where
  "liftM f m ≡ do x ← m; return (f x) od"

text ‹The same for the exception monad:›
(* Map a pure function over a normal (Inr) result; exceptions pass through. *)
definition
  liftME :: "('a ⇒ 'b) ⇒ ('s,'e+'a) det_monad ⇒ ('s,'e+'b) det_monad"
where
  "liftME f m ≡ doE x ← m; returnOk (f x) odE"

text ‹
  Run a sequence of monads from left to right, ignoring return values.›
(* foldr nests the binds so that the head of the list runs first. *)
definition
  sequence_x :: "('s, 'a) det_monad list ⇒ ('s, unit) det_monad" 
where
  "sequence_x xs ≡ foldr (λx y. x >>= (λ_. y)) xs (return ())"

text ‹
  Map a monadic function over a list by applying it to each element
  of the list from left to right, ignoring return values.
›
(* Apply f to every element and run the results in list order. *)
definition
  mapM_x :: "('a ⇒ ('s,'b) det_monad) ⇒ 'a list ⇒ ('s, unit) det_monad"
where
  "mapM_x f xs ≡ sequence_x (map f xs)"

text ‹
  Map a monadic function with two parameters over two lists,
  going through both lists simultaneously, left to right, ignoring
  return values.
›
(* Pairing of the two lists is delegated to zipWith (presumably from the
   imported Lib; its behaviour on unequal lengths is defined there). *)
definition
  zipWithM_x :: "('a ⇒ 'b ⇒ ('s,'c) det_monad) ⇒
                 'a list ⇒ 'b list ⇒ ('s, unit) det_monad"
where
  "zipWithM_x f xs ys ≡ sequence_x (zipWith f xs ys)"


text ‹The same three functions as above, but returning a list of
return values instead of ‹unit››
(* Run the computations head-first and collect their results in list order. *)
definition
  sequence :: "('s, 'a) det_monad list ⇒ ('s, 'a list) det_monad" 
where
  "sequence xs ≡ let mcons = (λp q. p >>= (λx. q >>= (λy. return (x#y))))
                 in foldr mcons xs (return [])"

definition
  mapM :: "('a ⇒ ('s,'b) det_monad) ⇒ 'a list ⇒ ('s, 'b list) det_monad"
where
  "mapM f xs ≡ sequence (map f xs)"

definition
  zipWithM :: "('a ⇒ 'b ⇒ ('s,'c) det_monad) ⇒
                 'a list ⇒ 'b list ⇒ ('s, 'c list) det_monad"
where
  "zipWithM f xs ys ≡ sequence (zipWith f xs ys)"

(* Monadic fold. Because of foldr, the elements are consumed from the LAST
   to the first: foldM m [x1, x2] a = (return a >>= m x2) >>= m x1.
   The element is m's first argument, the accumulator its second. *)
definition
  foldM :: "('b ⇒ 'a ⇒ ('s, 'a) det_monad) ⇒ 'b list ⇒ 'a ⇒ ('s, 'a) det_monad" 
where
  "foldM m xs a ≡ foldr (λp q. q >>= m p) xs (return a) "

text ‹The sequence and map functions above for the exception monad,
with and without lists of return values›
(* Run the computations head-first, stopping at the first exception. *)
definition
  sequenceE_x :: "('s, 'e+'a) det_monad list ⇒ ('s, 'e+unit) det_monad" 
where
  "sequenceE_x xs ≡ foldr (λx y. doE _ ← x; y odE) xs (returnOk ())"

definition
  mapME_x :: "('a ⇒ ('s,'e+'b) det_monad) ⇒ 'a list ⇒
              ('s,'e+unit) det_monad"
where
  "mapME_x f xs ≡ sequenceE_x (map f xs)"

(* As sequence, but in the exception monad: the first exception aborts the
   rest, otherwise the results are collected in list order. *)
definition
  sequenceE :: "('s, 'e+'a) det_monad list ⇒ ('s, 'e+'a list) det_monad" 
where
  "sequenceE xs ≡ let mcons = (λp q. p >>=E (λx. q >>=E (λy. returnOk (x#y))))
                 in foldr mcons xs (returnOk [])"

definition
  mapME :: "('a ⇒ ('s,'e+'b) det_monad) ⇒ 'a list ⇒
              ('s,'e+'b list) det_monad"
where
  "mapME f xs ≡ sequenceE (map f xs)"


text ‹Filtering a list using a monadic function as predicate:›
(* The predicate is run on each element head-first; an element is kept
   when its predicate returns True. *)
primrec
  filterM :: "('a ⇒ ('s, bool) det_monad) ⇒ 'a list ⇒ ('s, 'a list) det_monad"
where
  "filterM P []       = return []"
| "filterM P (x # xs) = do
     b  ← P x;
     ys ← filterM P xs; 
     return (if b then (x # ys) else ys)
   od"


section "Catching and Handling Exceptions"

text ‹
  Turning an exception monad into a normal state monad
  by catching and handling any potential exceptions:
›
(* Leave the exception monad: a normal result is returned as-is, and an
   exception is passed to the handler. *)
definition
  catch :: "('s, 'e + 'a) det_monad ⇒
            ('e ⇒ ('s, 'a) det_monad) ⇒
            ('s, 'a) det_monad" (infix "<catch>" 10)
where
  "f <catch> handler ≡
     do x ← f;
        case x of
          Inr b ⇒ return b
        | Inl e ⇒ handler e
     od"

text ‹
  Handling exceptions, but staying in the exception monad.
  The handler may throw a type of exceptions different from
  the left side.
›
(* Handle exceptions while staying in the exception monad; the handler may
   change the exception type from 'e1 to 'e2. *)
definition
  handleE' :: "('s, 'e1 + 'a) det_monad ⇒
               ('e1 ⇒ ('s, 'e2 + 'a) det_monad) ⇒
               ('s, 'e2 + 'a) det_monad" (infix "<handle2>" 10)
where
  "f <handle2> handler ≡
   do
      v ← f;
      case v of
        Inl e ⇒ handler e
      | Inr v' ⇒ return (Inr v')
   od"

text ‹
  A type restriction of the above that is used more commonly in
  practice: the exception handler (potentially) throws exceptions
  of the same type as the left-hand side.
›
(* handleE' restricted to handlers that keep the exception type. *)
definition
  handleE :: "('s, 'x + 'a) det_monad ⇒
              ('x ⇒ ('s, 'x + 'a) det_monad) ⇒
              ('s, 'x + 'a) det_monad" (infix "<handle>" 10)
where
  "handleE ≡ handleE'"


text ‹
  Handling exceptions, and additionally providing a continuation
  if the left-hand side throws no exception:
›
(* Two-way continuation: handler on an exception, continue on a normal
   result. *)
definition
  handle_elseE :: "('s, 'e + 'a) det_monad ⇒
                   ('e ⇒ ('s, 'ee + 'b) det_monad) ⇒
                   ('a ⇒ ('s, 'ee + 'b) det_monad) ⇒
                   ('s, 'ee + 'b) det_monad"
  ("_ <handle> _ <else> _" 10)
where
  "f <handle> handler <else> continue ≡
   do v ← f;
   case v of Inl e  ⇒ handler e
           | Inr v' ⇒ continue v'
   od"

section "Hoare Logic"

subsection "Validity"

text ‹This section defines a Hoare logic for partial correctness for
  the deterministic state monad as well as the exception monad.
  The logic talks only about the behaviour part of the monad and ignores
  the failure flag.

  The logic is defined semantically. Rules work directly on the
  validity predicate.

  In the deterministic state monad, validity is a triple of precondition,
  monad, and postcondition. The precondition is a function from state to 
  bool (a state predicate), the postcondition is a function from return value
  to state to bool. A triple is valid if for every state that satisfies the
  precondition, the result value and result state returned by the monad
  satisfy the postcondition. Unlike in the non-deterministic monad, a failing
  computation here still returns a result and a state (see @{const fail}),
  and validity constrains them like any other result. Consequently
  @{term "assert P"} neither requires us to prove @{term P} nor allows us to
  assume it: the postcondition must hold for the returned @{term "()"} and
  the unchanged state either way. Proving non-failure is done via a separate
  predicate and calculus (see below).
›
(* Partial correctness: from every state satisfying P, the single
   (result, state) pair produced by f satisfies Q. The failure flag
   (snd (f s)) is not consulted. *)
definition
  valid :: "('s ⇒ bool) ⇒ ('s,'a) det_monad ⇒ ('a ⇒ 's ⇒ bool) ⇒ bool" 
  ("⦃_⦄/ _ /⦃_⦄")
where
  "⦃P⦄ f ⦃Q⦄ ≡ ∀s. P s ⟶ (∀r s'. ((r,s') = fst (f s) ⟶ Q r s'))"

text ‹
  Validity for the exception monad is similar and built on the standard 
  validity above. Instead of one postcondition, we have two: one for
  normal and one for exceptional results.
›
(* Exception-monad validity: Q must hold for Inr results and E for Inl
   exceptions, reduced to plain validity by a case split on the result. *)
definition
  validE :: "('s ⇒ bool) ⇒ ('s, 'a + 'b) det_monad ⇒
             ('b ⇒ 's ⇒ bool) ⇒
             ('a ⇒ 's ⇒ bool) ⇒ bool" 
("⦃_⦄/ _ /(⦃_⦄,/ ⦃_⦄)")
where
  "⦃P⦄ f ⦃Q⦄,⦃E⦄ ≡ ⦃P⦄ f ⦃ λv s. case v of Inr r ⇒ Q r s | Inl e ⇒ E e s ⦄"


text ‹
  The following two instantiations are convenient to separate reasoning
  for exceptional and normal case.
›
(* Constrain only normal results; exceptions are unconstrained. *)
definition
  validE_R :: "('s ⇒ bool) ⇒ ('s, 'e + 'a) det_monad ⇒
               ('a ⇒ 's ⇒ bool) ⇒ bool"
   ("⦃_⦄/ _ /⦃_⦄, -")
where
 "⦃P⦄ f ⦃Q⦄,- ≡ validE P f Q (λx y. True)"

(* Constrain only exceptional results; normal results are unconstrained. *)
definition
  validE_E :: "('s ⇒ bool) ⇒  ('s, 'e + 'a) det_monad ⇒
               ('e ⇒ 's ⇒ bool) ⇒ bool"
   ("⦃_⦄/ _ /-, ⦃_⦄")
where
 "⦃P⦄ f -,⦃Q⦄ ≡ validE P f (λx y. True) Q"


text ‹Abbreviations for trivial preconditions:›
(* Always-true / always-false predicates over one argument. *)
abbreviation(input)
  top :: "'a ⇒ bool" ("⊤")
where
  "⊤ ≡ λ_. True"

abbreviation(input)
  bottom :: "'a ⇒ bool" ("⊥")
where
  "⊥ ≡ λ_. False"

text ‹Abbreviations for trivial postconditions (taking two arguments):›
abbreviation(input)
  toptop :: "'a ⇒ 'b ⇒ bool" ("⊤⊤")
where
 "⊤⊤ ≡ λ_ _. True"

abbreviation(input)
  botbot :: "'a ⇒ 'b ⇒ bool" ("⊥⊥")
where
 "⊥⊥ ≡ λ_ _. False"

text ‹
  Lifting ‹∧› and ‹∨› over two arguments. 
  Lifting ‹∧› and ‹∨› over one argument is already
  defined (written ‹and› and ‹or›).
›
(* Pointwise conjunction of two-argument predicates (e.g. postconditions). *)
definition
  bipred_conj :: "('a ⇒ 'b ⇒ bool) ⇒ ('a ⇒ 'b ⇒ bool) ⇒ ('a ⇒ 'b ⇒ bool)" 
  (infixl "And" 96)
where
  "bipred_conj P Q ≡ λx y. P x y ∧ Q x y"

(* Pointwise disjunction of two-argument predicates. *)
definition
  bipred_disj :: "('a ⇒ 'b ⇒ bool) ⇒ ('a ⇒ 'b ⇒ bool) ⇒ ('a ⇒ 'b ⇒ bool)" 
  (infixl "Or" 91)
where
  "bipred_disj P Q ≡ λx y. P x y ∨ Q x y"


subsection "Determinism"

text ‹A monad of type ‹det_monad› always returns exactly one result and
state; it is called deterministic (‹det›) iff it never sets the failure
flag, whatever the start state.›
(* f never fails, from any start state. The type arguments are named as in
   det_monad (state first, then result); the previous ('a,'s) naming was
   equivalent but misleading. *)
definition
  det :: "('s,'a) det_monad ⇒ bool"
where
  "det f ≡ ∀s. ∃r. f s = (r,False)" 

text ‹A deterministic ‹det_monad› can be turned
  into a normal state monad:›
(* THE selects the unique pair equal to fst (M s), i.e. M's (result, state)
   output; the failure flag is discarded. *)
definition
  the_run_state :: "('s,'a) det_monad ⇒ 's ⇒ 'a × 's"
where
  "the_run_state M ≡ λs. THE s'. fst (M s) = s'"


subsection "Non-Failure"

text ‹
  With the failure flag, we can formulate non-failure separately
  from validity. A monad ‹m› does not fail under precondition
  ‹P›, if for no start state in that precondition it sets
  the failure flag.
›
(* m leaves the failure flag cleared on every start state satisfying P. *)
definition
  no_fail :: "('s ⇒ bool) ⇒ ('s,'a) det_monad ⇒ bool"
where
  "no_fail P m ≡ ∀s. P s ⟶ ¬ (snd (m s))"


text ‹
  It is often desired to prove non-failure and a Hoare triple
  simultaneously, as the reasoning is often similar. The following
  definitions allow such reasoning to take place.
›

(* Total-correctness style triple: validity plus non-failure under P. *)
definition
  validNF ::"('s ⇒ bool) ⇒ ('s,'a) det_monad ⇒ ('a ⇒ 's ⇒ bool) ⇒ bool"
      ("⦃_⦄/ _ /⦃_⦄!")
where
  "validNF P f Q ≡ valid P f Q ∧ no_fail P f"

(* Exception-monad validity plus non-failure under P. *)
definition
  validE_NF :: "('s ⇒ bool) ⇒ ('s, 'a + 'b) det_monad ⇒
             ('b ⇒ 's ⇒ bool) ⇒
             ('a ⇒ 's ⇒ bool) ⇒ bool"
  ("⦃_⦄/ _ /(⦃_⦄,/ ⦃_⦄!)")
where
  "validE_NF P f Q E ≡ validE P f Q E ∧ no_fail P f"

(* validE_NF expressed as validNF with a case-split postcondition. *)
lemma validE_NF_alt_def:
  "⦃ P ⦄ B ⦃ Q ⦄,⦃ E ⦄! = ⦃ P ⦄ B ⦃ λv s. case v of Inl e ⇒ E e s | Inr r ⇒ Q r s ⦄!"
  by (clarsimp simp: validE_NF_def validE_def validNF_def)

section "Basic exception reasoning"

text ‹
  The following predicates ‹no_throw› and ‹no_return› allow
  reasoning that functions in the exception monad either do
  not throw an exception or never return normally.
›

(* Under P, A never produces an Inl (exceptional) result. *)
definition "no_throw P A ≡ ⦃ P ⦄ A ⦃ λ_ _. True ⦄,⦃ λ_ _. False ⦄"

(* Under P, A never produces an Inr (normal) result. *)
definition "no_return P A ≡ ⦃ P ⦄ A ⦃λ_ _. False⦄,⦃λ_ _. True ⦄"

end