(* Theory Guard_Simp *)

(*
 * Copyright (c) 2023 Apple Inc. All rights reserved.
 *
 * SPDX-License-Identifier: BSD-2-Clause
 *)
theory Guard_Simp imports "AutoCorres2_Main.AutoCorres_Main" 
begin

(* Parse the C file and run AutoCorres on it.
   The init-autocorres options declare parameter y of inc and parameters y, z of inc2
   as in_out parameters, list shuffle under in_out_globals, and apply "ts_force nondet"
   to shuffle (presumably forcing its type-strengthened result into the nondet monad --
   TODO confirm against the AutoCorres2 option documentation). *)
install_C_file "guard_simp.c"

init-autocorres [   
  in_out_parameters = 
    inc(y:in_out) and
    inc2(y:in_out, z:in_out),
  in_out_globals = 
    shuffle,
  ts_force nondet = 
    shuffle
] "guard_simp.c"

autocorres guard_simp.c

section @{method monad_simp}

text ‹
The idea of @{method monad_simp} is to simplify / normalise an expression of type
@{typ "('e::default, 'a, 's) spec_monad"} using a custom setup of congruence
rules that propagates properties gathered from guards or conditions.
A central property of the spec monad is that partial correctness properties in the sense of
@{const runs_to_partial} can be used to obtain properties of an intermediate state, which can
then be used to derive equality of the remaining program. This is illustrated by
the following lemma:
›

(* Bind congruence under a partial-correctness premise: if P holds for every result and
   state that f may reach when started in s, then the continuations g and g' only have to
   agree on those intermediate results and states. The premise is registered for
   runs_to_vcg, which discharges the second subgoal together with g. *)
lemma 
  assumes partial[runs_to_vcg]: "f  s ?⦃P"
  assumes g: "r t. P (Result r) t  run (g r) t = run (g' r) t"
  shows "run (f  g) s = run (f  g') s"
  apply (rule run_bind_cong)
  subgoal by simp
  subgoal by runs_to_vcg (rule g)
  done

text ‹Note that the analogous lemma does not hold for the legacy ‹('s, 'a) nondet_monad›.
There you would need total correctness, i.e. @{term "runs_to f s P"}.
The problem lies in the cases where @{term f} might fail, e.g. due to non-termination or
a failing guard. Those cases do not collapse the complete computation, so we know essentially
nothing about the rest of the computation and thus cannot derive equality. A
similar rule with partial correctness can be derived for 
refinement, but of course this is not as convenient as an
equality.
›

text ‹Here are examples of @{method monad_simp} propagating state dependent guards.›
(* The outer guard already establishes the IS_VALID fact for y, so the guard inside the
   catch and the guard in its handler are redundant; the whole block collapses to a single
   guard. *)
lemma 
"do {
   guard (λs. IS_VALID(32 word) s y);
   guard (λs. IS_VALID(32 word) s y) <catch> (λ_. guard (λs. IS_VALID(32 word) s y))
 } = 
  guard (λs. IS_VALID(32 word) s y)"
  apply (monad_simp)
  done


(* The validity guards on y0 and z0 that are re-checked before the two heap writes are
   removed. They are established by the first two guards and survive the intermediate
   gets/return steps. For the second write they also have to survive the first
   heap_w32 update (presumably via the unchanged-typing facts -- TODO confirm). *)
lemma "do {
    x  guard (λs. IS_VALID(32 word) s y0);
    _  guard (λs. IS_VALID(32 word) s z0);
    (y, z)  gets (λs. (heap_w32 s y0, heap_w32 s z0));
    (y, z)  return (inc2' y z);
    (y, z)  return (y, z);
    _  guard (λs. IS_VALID(32 word) s y0);
    _  modify (heap_w32_update (λh. h(y0 := y)));
    _  guard (λs. IS_VALID(32 word) s z0);
    modify (heap_w32_update (λh. h(z0 := z)))
  } =
  do {
    _  guard (λs. IS_VALID(32 word) s y0);
    _  guard (λs. IS_VALID(32 word) s z0);
    (y, z)  gets (λs. (heap_w32 s y0, heap_w32 s z0));
    (y, z)  return (inc2' y z);
    (y, z)  return (y, z);
    _  modify (heap_w32_update (λh. h(y0 := y)));
    modify (heap_w32_update (λh. h(z0 := z)))
  }"
  apply (monad_simp)
  done

(* If "when (x > 42) (throw x)" returns normally, then ¬ x > 42 holds, so the following
   guard is dropped. *)
lemma fixes x::"32 word" shows "do {when (x > 42) (throw x); guard (λ_. ¬ x > 42); return x } = 
  do {when (x > 42) (throw x); return x }"
  apply monad_simp
  done

(* Dual of the previous lemma: after "unless (x > 42) (throw x)" returns normally,
   x > 42 holds, so the following guard is dropped. *)
lemma fixes x::"32 word" shows "do {unless (x > 42) (throw x); guard (λ_. x > 42); return x } = 
  do {unless (x > 42) (throw x); return x }"
  apply monad_simp
  done

(* "condition (λ_. c) (throw ()) skip" is normalised to "when c (throw ())".
   The guards repeated after it are removed because they were already established before
   it on the path that returns normally. *)
lemma "do {
    x  guard (λs::lifted_globals. IS_VALID(32 word) s y0);
    _  guard (λs. IS_VALID(32 word) s z0);
    condition (λ_. c) (throw ()) skip;

    _  guard (λs. IS_VALID(32 word) s y0);
    _  guard (λs. IS_VALID(32 word) s z0);
    modify (heap_w32_update (λh. h(z0 := z)))
  } =  
  do {
      _  guard (λs. IS_VALID(32 word) s y0);
      _  guard (λs. IS_VALID(32 word) s z0);
      _  when c (throw ());
      modify (heap_w32_update (λh. h(z0 := z)))
    }"
  apply monad_simp
  done

text ‹The ‹simp_depth_limit› is derived from the depth of the term.›
(* Deep-term test: sixteen conditions, each on the state-independent c, interleaved with
   re-checked validity guards on y0 and z0. This exercises the simp_depth_limit, which is
   derived from the depth of the term. All repeated guards are removed. The expected
   result still keeps every condition, even though both of its branches are identical. *)
lemma "do {
    _  guard (λs::lifted_globals. IS_VALID(32 word) s y0);
    _  guard (λs. IS_VALID(32 word) s z0);
    _  guard (λs. IS_VALID(32 word) s y0);
      condition (λ_. c) (modify (heap_w32_update (λh. h(z0 := z)))) ( modify (heap_w32_update (λh. h(z0 := z))));
    _  guard (λs. IS_VALID(32 word) s z0);
    _  guard (λs. IS_VALID(32 word) s y0);
      condition (λ_. c) (modify (heap_w32_update (λh. h(z0 := z)))) ( modify (heap_w32_update (λh. h(z0 := z))));
    _  guard (λs. IS_VALID(32 word) s z0);
    _  guard (λs. IS_VALID(32 word) s y0);
      condition (λ_. c) (modify (heap_w32_update (λh. h(z0 := z)))) ( modify (heap_w32_update (λh. h(z0 := z))));
    _  guard (λs. IS_VALID(32 word) s z0);
    _  guard (λs. IS_VALID(32 word) s y0);
      condition (λ_. c) (modify (heap_w32_update (λh. h(z0 := z)))) ( modify (heap_w32_update (λh. h(z0 := z))));
    _  guard (λs. IS_VALID(32 word) s z0);
    _  guard (λs. IS_VALID(32 word) s y0);
      condition (λ_. c) (modify (heap_w32_update (λh. h(z0 := z)))) ( modify (heap_w32_update (λh. h(z0 := z))));
    _  guard (λs. IS_VALID(32 word) s z0);
    _  guard (λs. IS_VALID(32 word) s y0);
      condition (λ_. c) (modify (heap_w32_update (λh. h(z0 := z)))) ( modify (heap_w32_update (λh. h(z0 := z))));
    _  guard (λs. IS_VALID(32 word) s z0);
    _  guard (λs. IS_VALID(32 word) s y0);
      condition (λ_. c) (modify (heap_w32_update (λh. h(z0 := z)))) ( modify (heap_w32_update (λh. h(z0 := z))));
    _  guard (λs. IS_VALID(32 word) s z0);
    _  guard (λs. IS_VALID(32 word) s y0);
      condition (λ_. c) (modify (heap_w32_update (λh. h(z0 := z)))) ( modify (heap_w32_update (λh. h(z0 := z))));
    _  guard (λs. IS_VALID(32 word) s z0);
    _  guard (λs. IS_VALID(32 word) s y0);
      condition (λ_. c) (modify (heap_w32_update (λh. h(z0 := z)))) ( modify (heap_w32_update (λh. h(z0 := z))));
    _  guard (λs. IS_VALID(32 word) s z0);
    _  guard (λs. IS_VALID(32 word) s y0);
      condition (λ_. c) (modify (heap_w32_update (λh. h(z0 := z)))) ( modify (heap_w32_update (λh. h(z0 := z))));
    _  guard (λs. IS_VALID(32 word) s z0);
    _  guard (λs. IS_VALID(32 word) s y0);
      condition (λ_. c) (modify (heap_w32_update (λh. h(z0 := z)))) ( modify (heap_w32_update (λh. h(z0 := z))));
    _  guard (λs. IS_VALID(32 word) s z0);
    _  guard (λs. IS_VALID(32 word) s y0);
      condition (λ_. c) (modify (heap_w32_update (λh. h(z0 := z)))) ( modify (heap_w32_update (λh. h(z0 := z))));
    _  guard (λs. IS_VALID(32 word) s z0);
    _  guard (λs. IS_VALID(32 word) s y0);
      condition (λ_. c) (modify (heap_w32_update (λh. h(z0 := z)))) ( modify (heap_w32_update (λh. h(z0 := z))));
    _  guard (λs. IS_VALID(32 word) s z0);
    _  guard (λs. IS_VALID(32 word) s y0);
      condition (λ_. c) (modify (heap_w32_update (λh. h(z0 := z)))) ( modify (heap_w32_update (λh. h(z0 := z))));
    _  guard (λs. IS_VALID(32 word) s z0);
    _  guard (λs. IS_VALID(32 word) s y0);
      condition (λ_. c) (modify (heap_w32_update (λh. h(z0 := z)))) ( modify (heap_w32_update (λh. h(z0 := z))));
    _  guard (λs. IS_VALID(32 word) s z0);
    _  guard (λs. IS_VALID(32 word) s y0);
      condition (λ_. c) (modify (heap_w32_update (λh. h(z0 := z)))) ( modify (heap_w32_update (λh. h(z0 := z))));
    _  guard (λs. IS_VALID(32 word) s y0);
    _  guard (λs. IS_VALID(32 word) s z0);
    modify (heap_w32_update (λh. h(z0 := z)))
  } =  
 do {
      _  guard (λs. IS_VALID(32 word) s y0);
      _  guard (λs. IS_VALID(32 word) s z0);
      _  condition (λ_. c)
             (modify (heap_w32_update (λh. h(z0 := z))))
             (modify (heap_w32_update (λh. h(z0 := z))));
      _  condition (λ_. c)
             (modify (heap_w32_update (λh. h(z0 := z))))
             (modify (heap_w32_update (λh. h(z0 := z))));
      _  condition (λ_. c)
             (modify (heap_w32_update (λh. h(z0 := z))))
             (modify (heap_w32_update (λh. h(z0 := z))));
      _  condition (λ_. c)
             (modify (heap_w32_update (λh. h(z0 := z))))
             (modify (heap_w32_update (λh. h(z0 := z))));
      _  condition (λ_. c)
             (modify (heap_w32_update (λh. h(z0 := z))))
             (modify (heap_w32_update (λh. h(z0 := z))));
      _  condition (λ_. c)
             (modify (heap_w32_update (λh. h(z0 := z))))
             (modify (heap_w32_update (λh. h(z0 := z))));
      _  condition (λ_. c)
             (modify (heap_w32_update (λh. h(z0 := z))))
             (modify (heap_w32_update (λh. h(z0 := z))));
      _  condition (λ_. c)
             (modify (heap_w32_update (λh. h(z0 := z))))
             (modify (heap_w32_update (λh. h(z0 := z))));
      _  condition (λ_. c)
             (modify (heap_w32_update (λh. h(z0 := z))))
             (modify (heap_w32_update (λh. h(z0 := z))));
      _  condition (λ_. c)
             (modify (heap_w32_update (λh. h(z0 := z))))
             (modify (heap_w32_update (λh. h(z0 := z))));
      _  condition (λ_. c)
             (modify (heap_w32_update (λh. h(z0 := z))))
             (modify (heap_w32_update (λh. h(z0 := z))));
      _  condition (λ_. c)
             (modify (heap_w32_update (λh. h(z0 := z))))
             (modify (heap_w32_update (λh. h(z0 := z))));
      _  condition (λ_. c)
             (modify (heap_w32_update (λh. h(z0 := z))))
             (modify (heap_w32_update (λh. h(z0 := z))));
      _  condition (λ_. c)
             (modify (heap_w32_update (λh. h(z0 := z))))
             (modify (heap_w32_update (λh. h(z0 := z))));
      _  condition (λ_. c)
             (modify (heap_w32_update (λh. h(z0 := z))))
             (modify (heap_w32_update (λh. h(z0 := z))));
      _  condition (λ_. c)
             (modify (heap_w32_update (λh. h(z0 := z))))
             (modify (heap_w32_update (λh. h(z0 := z))));
      modify (heap_w32_update (λh. h(z0 := z)))
    }"
  apply monad_simp
  done

text ‹
To propagate even state-dependent guards, @{method monad_simp} is
configured with special congruence rules that contain syntactic markers controlling
how @{method monad_simp} descends into an expression. The basic idea is to gather
properties of guards or conditions as we descend into the expression and to propagate
these properties by proving that they stay invariant.

The syntactic markers are:
 \<^item> @{term "ADD_FACT P s"}: add the fact that property @{term P} holds for the 
  current state @{term s}.
 \<^item> @{term "PRESERVED_FACTS f s r t"}: maintain those facts about state @{term s} that 
  still hold in state @{term t} when function @{term f} transitions from state @{term s} 
  to state @{term t} while producing result @{term r}. Moreover, the theorems in @{attribute monad_simp_derive_rule} are used
  to derive facts that hold after the transition. These are added in the same manner as
  with @{const ADD_FACT}.
 \<^item> @{term "PRESERVED_FACTS_WHILE C B i s r t"}: maintain those facts about state @{term s} that 
  still hold in any state @{term t} that can be reached by unrolling the while loop
  @{term "whileLoop C B i"} and transitioning from state @{term s} 
  to state @{term t} while producing result @{term r}.

Here is the set of rules that are configured:
›
(* Inspect the monad_simp configuration: the congruence rules, the derive rules and the
   generated stop congruences. The ML text of these commands must be enclosed in a
   cartouche. ML_val only evaluates the text for its output and leaves the ML environment
   unchanged, matching the diagnostic calls further below. *)
ML_val ‹Monad_Cong_Simp.print_rules (Context.Proof @{context})›
ML_val ‹Monad_Cong_Simp.print_derive_rules (Context.Proof @{context})›

ML_val ‹Monad_Cong_Simp.print_stop_congs (Context.Proof @{context})›
text ‹The ‹stop_congs› are used to prevent the simplifier from descending into subterms. They are
generated from the supplied congruence rules. For local definitions, global congruence
rules can be supplied via @{attribute monad_simp_global_stop_cong}.›


text ‹To facilitate the preservation proofs for those accumulated facts we use the 
collection of theorems @{thm monad_simp_simp}.
›
(* Show the monad_simp_simp collection used in the preservation proofs. *)
thm monad_simp_simp

text ‹Currently the setup is tuned towards preservation of ‹ptr_valid› predicates.
To preserve such predicates we use the property that most programs do not change the 
typing information at all. This can be expressed by a @{const runs_to_partial} statement:

@{term [display] "runs_to_partial f s (λ_. unchanged_typing_on UNIV s)"}

Note that AutoCorres attempts to prove such properties for all its outputs and collects
the successful lemmas in @{thm unchanged_typing}.
›
(* Display the unchanged_typing facts collected by AutoCorres together with the
   unchanged_typing_on_simps collection, using a single thm command. *)
thm unchanged_typing unchanged_typing_on_simps


text ‹AutoCorres applies @{method monad_simp} to all the final results. 
This is done in a staged approach: 
 \<^item> First, an initial (raw) definition is derived.
 \<^item> Then an @{const unchanged_typing_on} theorem is derived for that definition (if possible).
 \<^item> Finally, this theorem is used by @{method monad_simp} to simplify the raw definition,
  arriving at the final definition.
›


context guard_simp_all_corres
begin
(* Compare the raw definitions (before monad_simp) with the final ones. *)
thm ts_def
thm raw.inc2_loop'_def inc2_loop'_def
thm raw.inc2_while'_def inc2_while'_def
thm raw.odd'.simps odd'.simps
thm raw.even'.simps even'.simps
thm raw.heap_inc2'_def heap_inc2'_def 
thm raw.cond'_def cond'_def

thm raw.fac_exit'.simps fac_exit'.simps
thm raw.dec'_def dec'_def

(* The ML text of ML_val must be enclosed in a cartouche. *)
ML_val ‹Monad_Cong_Simp.print_rules (Context.Proof @{context})›

declare [[ML_print_depth=1000]]
ML_val ‹Monad_Cong_Simp.Data.get (Context.Proof @{context})›

(* After "swap' np mp", the validity guards on the two fresh stack pointers are removed.
   This presumably relies on an unchanged-typing fact for swap' -- TODO confirm. *)
lemma " heap_w32.assume_with_fresh_stack_ptr 1 (λa. {[n]})
   (λnp. heap_w32.assume_with_fresh_stack_ptr 1 (λs. {[m]})
            (λmp. do {
                  ret  swap' np mp;
                  x  guard (λs. IS_VALID(32 word) s np);
                  _  guard (λs. IS_VALID(32 word) s mp);
                  gets (λs. heap_w32 s np + heap_w32 s mp)
                })) = 
  heap_w32.assume_with_fresh_stack_ptr 1 (λa. {[n]})
     (λt. heap_w32.assume_with_fresh_stack_ptr 1 (λs. {[m]}) (λta. do {
        ret  swap' t ta;
        gets (λs. heap_w32 s t + heap_w32 s ta)
      }))"
  supply [[verbose=6]]
  apply monad_simp
  done

(* The validity guards directly inside the nested with_fresh_stack_ptr on the fresh
   pointers np and mp are removed. *)
lemma " heap_s32.with_fresh_stack_ptr 1 (λa. {[n]})
   (λnp. heap_s32.with_fresh_stack_ptr 1 (λs. {[m]})
            (λmp. do {
  
                  x  guard (λs. IS_VALID(32 signed word) s np);
                  _  guard (λs. IS_VALID(32 signed word) s mp);
                  gets (λs. heap_s32 s np + heap_s32 s mp)
                })) = 
    heap_s32.with_fresh_stack_ptr 1 (λa. {[n]})
     (λnp. heap_s32.with_fresh_stack_ptr 1 (λs. {[m]}) (λmp. gets (λs. heap_s32 s np + heap_s32 s mp)))"
  supply [[verbose=0]]
  apply monad_simp
  done


end


context ts_definition_shuffle
begin
thm shuffle'_def

text ‹Here we use @{method monad_simp} to remove the nested guard relating
‹ptr_span (buf +p - i)› to ‹𝒢›, which is discharged using assumption ‹G›. This kind of
technique can be useful for propagating assumptions about the inputs of a function in order
to simplify the program (e.g. by discharging some guards) before running
@{method runs_to_vcg} on the program.
›

(* Unfold shuffle' and simplify it with assumption G added to the simp set. *)
lemma 
  assumes G: "i. i  4  ptr_span (buf +p - i)  𝒢"
  shows "shuffle' buf len = 
    condition (λs. 4 < uint len)
      (Spec_Monad.return 0x2A)
      (do {
         (i, y) 
           whileLoop (λ(i, tmp) s. i < uint len)
             (λ(i, tmp). do {
                   _  Spec_Monad.guard (λs. 0  i);
                   _  Spec_Monad.guard (λs. IS_VALID(8 word) s (buf +p - i));
                   tmp 
                     Spec_Monad.gets
                      (λs. tmp ||
                            (UCAST(8  32) (heap_w8 s (buf +p - i)) <<
                             unat ((word_of_int i * 8)::32 signed word)));
                   Spec_Monad.return (i + 1, tmp)
                 })
            (0, 0);
         Spec_Monad.return y
       })"
  unfolding shuffle'_def
  apply (monad_simp simp add: G simp del: size_simps)
  done


(* After "unless (x = 42) (throw e)" returns normally, x = 42 is known, so monad_simp
   substitutes 42 for x in the loop initialiser and in the final return. *)
lemma "do {x  return n; 
         unless (x = 42) (throw e);
         (i, y) 
           whileLoop (λ(i, tmp) s. i < uint len)
             (λ(i, tmp). do {
                   _  Spec_Monad.guard (λs. 0  i);
                   _  Spec_Monad.guard (λs. IS_VALID(8 word) s (buf +p - i));
                   tmp 
                     Spec_Monad.gets
                      (λs. tmp ||
                            (UCAST(8  32) (heap_w8 s (buf +p - i)) <<
                             unat ((word_of_int i * 8)::32 signed word)));
                   Spec_Monad.return (i + 1, tmp)
                 })
            (0, x);
         return x} = 
   do {x  return n; 
         unless (x = 42) (throw e);
         (i, y) 
           whileLoop (λ(i, tmp) s. i < uint len)
             (λ(i, tmp). do {
                   _  Spec_Monad.guard (λs. 0  i);
                   _  Spec_Monad.guard (λs. IS_VALID(8 word) s (buf +p - i));
                   tmp 
                     Spec_Monad.gets
                      (λs. tmp ||
                            (UCAST(8  32) (heap_w8 s (buf +p - i)) <<
                             unat ((word_of_int i * 8)::32 signed word)));
                   Spec_Monad.return (i + 1, tmp)
                 })
            (0, 42);
         return 42}"
  apply monad_simp
  done


(* Variant of the previous lemma with an extra loop component p of type nat × nat. The
   expected result shows the loop state tupled-eta expanded, with p split into x3, x4. *)
lemma "do {x  return n; 
         unless (x = 42) (throw e);
         (i, y, p) 
           whileLoop (λ(i, tmp, p::(nat × nat)) s. i < uint len)
             (λ(i, tmp, p). do {
                   _  Spec_Monad.guard (λs. 0  i);
                   _  Spec_Monad.guard (λs. IS_VALID(8 word) s (buf +p - i));
                   tmp 
                     Spec_Monad.gets
                      (λs. tmp ||
                            (UCAST(8  32) (heap_w8 s (buf +p - i)) <<
                             unat ((word_of_int i * 8)::32 signed word)));
                   Spec_Monad.return (i + 1, tmp, p)
                 })
            (0, x, p);
         return x} = 
   do {
      x  return n;
      _  unless (x = 42) (throw e);
      (i, y, x3, x4) 
        whileLoop (λ(i, tmp, x3, x4) s. i < uint len)
          (λ(i, tmp, x3, x4). do {
                _  guard (λs. 0  i);
                _  guard (λs. IS_VALID(8 word) s (buf +p - i));
                tmp 
                  gets (λs. tmp || (UCAST(8  32) (heap_w8 s (buf +p - i)) << unat ((word_of_int i * 8)::32 signed word)));
                return (i + 1, tmp, x3, x4)
              })
         (0, 42, p);
      return 42
    } "
  apply monad_simp
  done

(* Both sides are identical: monad_simp is expected to solve the goal with a
   state-dependent condition (presumably a regression test -- TODO confirm). *)
lemma 
 "condition (λs. n = g_'' s) 
   (return n) (return (n + 1)) = 
 condition (λs. n = g_'' s) 
   (return n) (return (n + 1))"
  apply monad_simp
  done

text ‹This lemma demonstrates the tupled-eta expansion of @{method monad_simp}. Note that on
the left-hand side the initialiser of the @{const whileLoop} is contracted, so there are no
bound variables in the @{const bind}. The duplicated guard in the loop body is removed as well.›

lemma "(bind (return (0,0)) (
           whileLoop (λ(i, tmp) s. i < uint len)
             (λ(i, tmp). do {
                   _  Spec_Monad.guard (λs. 0  i);
                   _  Spec_Monad.guard (λs. IS_VALID(8 word) s (buf +p - i));
                   _  Spec_Monad.guard (λs. 0  i);
                   tmp 
                     Spec_Monad.gets
                      (λs. tmp ||
                            (UCAST(8  32) (heap_w8 s (buf +p - i)) <<
                             unat ((word_of_int i * 8)::32 signed word)));
                   Spec_Monad.return (i + 1, tmp)
                 })
  )) = do {
      (x1, x2)  return (0, 0);
      whileLoop (λ(i, tmp) s. i < uint len)
        (λ(i, tmp). do {
              _  guard (λs. 0  i);
              _  guard (λs. IS_VALID(8 word) s (buf +p - i));
              tmp 
                gets
                 (λs. tmp ||
                       (UCAST(8  32) (heap_w8 s (buf +p - i)) << unat ((word_of_int i * 8)::32 signed word)));
              return (i + 1, tmp)
            })
       (x1, x2)
    }"
  apply monad_simp
  done

end

(* Identity goal with a case split on a pair inside a condition. monad_simp is expected
   to solve it outright. *)
lemma "condition (λs. x < (5::nat)) (case v of (x, y)  (g x y::('a, 's) res_monad)) (h x) = 
       condition (λs. x < (5::nat)) (case v of (x, y)  (g x y::('a, 's) res_monad)) (h x) "
  by (monad_simp)

(* The guard x < 5 in the then-branch is implied by the condition and is removed, even
   underneath a case split on a triple. *)
lemma "condition (λs. x < (5::nat)) 
         (case v of (x1, y1, z1)  (do {guard (λ_. x < 5); g x1 z1 y1})) 
         (h x) = 
       condition (λs. x < 5)
         (case v of (x1, y1, z1)  g x1 z1 y1)
          (h x) "
  apply (monad_simp)
  done


(* The same simplification as in the previous lemma, but inside a larger goal.
   monad_simp solves the monad equation and leaves X = Y, which follows from the
   assumption. *)
lemma 
  assumes "X = Y"
  shows
"condition (λs. x < (5::nat)) 
         (case v of (x1, y1, z1)  (do {guard (λ_. x < 5); g x1 z1 y1})) 
         (h x) = 
       condition (λs. x < 5)
         (case v of (x1, y1, z1)  g x1 z1 y1)
          (h x)   X = Y"
  apply (monad_simp) ― ‹no ‹STOP› marker should remain in the premises›
  by (rule assms)

(* The guard x > 2 inside "when (x > 2)" is implied by the when-condition and is removed.
   The remaining "return ()" becomes skip, and the condition is stated as 2 < x. *)
lemma "(when ((x::nat) > 2) (do {guard (λ_. x > 2); return ()})) = when (2 < x) skip"
  by (monad_simp)

end