(* Theory Step_Spec: speculative semantics built on top of Step_Basic *)

section "Misprediction and Speculative Semantics"

text ‹This theory formalizes an optimized speculative semantics that allows for a characterization of the Spectre vulnerability. This work is inspired by, and based on, the speculative semantics introduced by Cheang et al. \cite{Cheang}.›

theory Step_Spec
imports Step_Basic
begin

subsection "Misprediction Oracle"
text ‹The speculative semantics is parameterised by a misprediction oracle. 
This consists of a predictor state:›
(* Uninterpreted type: the predictor state of the misprediction oracle is kept abstract *)
typedecl predState

text ‹Along with the predicates "mispred" (which decides when a misprediction occurs) and "resolve" (which decides when a speculation is resolved), and the function "update".›
text ‹Both predicates depend on the predictor state (which evolves via the update function) and on the list of program counters of the nested speculation levels.›
(* Extends the Prog locale (for the program prog) with the misprediction oracle.
   All three oracle components take the predictor state and a list of program
   counters (one per active speculation level). *)
locale Prog_Mispred = 
Prog prog 
for prog :: "com list"
+ 
fixes mispred :: "predState  pcounter list  bool" (* should we mispredict here? *)
and resolve :: "predState  pcounter list  bool" (* should one speculation level be resolved? *)
and update :: "predState  pcounter list  predState" (* evolves the predictor state *)
begin 

subsection "Mispredicting Step"
text "stepM simply takes the opposite branch to stepB at conditional jumps (IfJump); it has no transitions for any other command"
(* Mispredicting step. Only defined at an IfJump strictly before endPC; it
   jumps to the target that the branch condition does NOT select and leaves
   the state s and both input buffers unchanged. *)
inductive
stepM :: "config × val llist × val llist  config × val llist × val llist  bool" (infix "→M" 55)
where
(* condition b holds: mispredict to pc2 *)
IfTrue[intro]: 
"pc < endPC  prog!pc = IfJump b pc1 pc2  
 bval b s  
 (Config pc s, ibT, ibUT) →M (Config pc2 s, ibT, ibUT)" 
|
(* condition b does not hold: mispredict to pc1 *)
IfFalse[intro]: 
"pc < endPC  prog!pc = IfJump b pc1 pc2  
 ¬ bval b s  
 (Config pc s, ibT, ibUT) →M (Config pc1 s, ibT, ibUT)"

subsubsection "State Transitions"
(* finalM: no stepM transition is possible (final is the generic notion from the
   imported theories); finalM_iff gives a syntactic characterization *)
definition "finalM = final stepM"

(* stepM can fire from pc exactly when pc < endPC and the command at pc is a
   conditional jump; the value of the branch condition only decides which of
   IfTrue/IfFalse applies. *)
lemma finalM_iff_aux:  
"pc < endPC  is_IfJump (prog!pc) 
  
 (cfg'. (Config pc s, ibT, ibUT) →M cfg')"
(* destructure the state, then split on the command stored at pc *)
apply (cases s) subgoal for vst avst h p apply clarsimp
apply (cases "prog!pc")
  (* the following cases are closed by rule inversion on stepM alone *)
  subgoal by (auto elim: stepM.cases)
  subgoal by (auto elim: stepM.cases)
  subgoal by (auto elim: stepM.cases)  
  subgoal by (auto elim: stepM.cases)
  subgoal by (auto elim: stepM.cases)
  subgoal by (auto elim: stepM.cases) 
  subgoal by (auto elim: stepM.cases) 
  subgoal by (auto elim: stepM.cases) 
  subgoal by (auto elim: stepM.cases) 
  (* the last case also needs the introduction rules IfTrue/IfFalse to build a
     successor; presumably this is the IfJump constructor
     NOTE(review): confirm against the com datatype in the imported theories *)
  subgoal by (auto elim: stepM.cases,meson IfFalse IfTrue) . .

(* A configuration is final for stepM exactly when pc is not below endPC or the
   command at pc is not a conditional jump (negation of finalM_iff_aux). *)
lemma finalM_iff: 
"finalM (Config pc (State vst avst h p), ibT, ibUT) 
 
 (pc  endPC  ¬ is_IfJump (prog!pc))"
using finalM_iff_aux unfolding finalM_def final_def  
by (metis linorder_not_less)

(* Configurations that are final for the basic semantics are also final for
   the mispredicting step. *)
lemma finalB_imp_finalM: 
"finalB (cfg, ibT, ibUT)  finalM (cfg, ibT, ibUT)"
(* expose pc and the state components so that finalB_iff / finalM_iff apply *)
apply(cases cfg) subgoal for pc s apply(cases s)
subgoal for vst avst h p apply clarsimp unfolding finalB_iff finalM_iff by auto . .

(* Contrapositive of finalB_imp_finalM. *)
lemma not_finalM_imp_not_finalB: 
"¬ finalM (cfg, ibT, ibUT)  ¬ finalB (cfg, ibT, ibUT)"
using finalB_imp_finalM by blast

(* *)

(* Determinism: every configuration has at most one stepM successor.
   Rule induction on the first step, then inversion of the second. *)
lemma stepM_determ:
"cfg_ib →M cfg_ib'  cfg_ib →M cfg_ib''  cfg_ib'' = cfg_ib'"
by (induction arbitrary: cfg_ib'' rule: stepM.induct) (auto elim: stepM.cases)
 
(* The stepM successor of cfg_ib, picked by Hilbert choice (SOME). It is unique
   whenever it exists (stepM_determ); for finalM configurations the result is
   unspecified. *)
definition nextM :: "config × val llist × val llist  config × val llist × val llist" where 
"nextM cfg_ib  SOME cfg'_ib'. cfg_ib →M cfg'_ib'"
 
(* If stepM can move at all, nextM is a genuine stepM successor. *)
lemma nextM_stepM: "¬ finalM cfg_ib  cfg_ib →M (nextM cfg_ib)"
unfolding nextM_def apply(rule someI_ex) 
unfolding finalM_def final_def by auto

(* Any stepM successor coincides with nextM (uniqueness via stepM_determ). *)
lemma stepM_nextM: "cfg_ib →M cfg'_ib'  cfg'_ib' = nextM cfg_ib"
unfolding nextM_def apply(rule sym) apply(rule some_equality)
using stepM_determ by auto

(* For non-final configurations, nextM and stepM describe the same transition. *)
lemma nextM_iff_stepM: "¬ finalM cfg_ib  nextM cfg_ib  = cfg'_ib'  cfg_ib →M cfg'_ib'"
using nextM_stepM stepM_nextM by blast

(* A stepM transition is exactly a non-final configuration moving to its nextM. *)
lemma stepM_iff_nextM: "cfg_ib →M cfg'_ib'  ¬ finalM cfg_ib  nextM cfg_ib  = cfg'_ib'"
by (metis finalM_def final_def stepM_nextM)

(* *)

(* nextM at an IfJump whose condition is FALSE (¬ bval b s): moves to pc1.
   NOTE(review): the name is swapped with respect to the stepM rules; this
   lemma corresponds to stepM.IfFalse. Kept as is because other theories may
   refer to it by name. *)
lemma nextM_IfTrue[simp]:  
"pc < endPC  prog!pc = IfJump b pc1 pc2  
 ¬ bval b s  
 nextM (Config pc s, ibT, ibUT) = (Config pc1 s, ibT, ibUT)" 
by(intro stepM_nextM[THEN sym] stepM.intros)

(* nextM at an IfJump whose condition is TRUE (bval b s): moves to pc2.
   NOTE(review): the name is swapped with respect to the stepM rules; this
   lemma corresponds to stepM.IfTrue. Kept as is because other theories may
   refer to it by name. *)
lemma nextM_IfFalse[simp]:  
"pc < endPC  prog!pc = IfJump b pc1 pc2  
 bval b s  
 nextM (Config pc s, ibT, ibUT) = (Config pc2 s, ibT, ibUT)" 
by(intro stepM_nextM[THEN sym] stepM.intros)

end (* context Prog_Mispred *)

subsection "Speculative Semantics"
text ‹A "speculative" configuration is a tuple of six components:
\begin{itemize}
\item The predictor's state
\item The nonspeculative configuration (at level 0, so to speak)
\item The list of speculative configurations (modelling nested speculation, levels 1 to n from left to right, so the last element of this list is at the current speculation level, n)
\item The two input buffers, ibT and ibUT (lazy lists of values)
\item The set of locations read so far
\end{itemize}
›

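text ‹We think of cfgs as a stack of configurations, one for each speculation level in a nested speculative execution. 
At level 0 (empty list) only the nonspeculative configuration is active, i.e. we have normal, non-speculative execution. 
Otherwise, only the top of the configuration stack, "last cfgs", is active; the configurations below it, and the nonspeculative configuration, stay unchanged until speculation is resolved.›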

(* (predictor state, nonspeculative cfg, stack of speculative cfgs cfgs,
    input buffer ibT, input buffer ibUT, set of read locations ls) *)
type_synonym configS = "predState × config × config list × val llist × val llist × loc set"

context Prog_Mispred
begin

text ‹The speculative semantics is more involved than both the normal and basic semantics, so a short description of each rule is provided:
\begin{itemize}
\item Nonspec\_normal: there is no current speculation and we are either not at a branch or not mispredicting, i.e. normal execution

\item Nonspec\_mispred: there is no current speculation, we are at a branch and we mispredict; the nonspeculative configuration takes its normal step, while a new speculation level is opened down the wrong branch, i.e. branch misprediction

\item Spec\_normal: there is current speculation that is not being resolved, and we are either not at a branch or not mispredicting; the current speculative configuration takes a normal step, provided its command is neither a Fence, an input nor an output, i.e. standard speculative execution

\item Spec\_mispred: there is current speculation that is not being resolved, we are at a branch and we mispredict; the current speculative configuration takes its normal step and speculation continues down the wrong branch at a new speculation level, i.e. nested speculative execution

\item Spec\_Fence: there is current speculation that is not being resolved and a Fence is hit: all speculation resolves at once

\item Spec\_resolve: if the resolve predicate holds, resolution occurs for one speculation level (the top one). In contrast to Fences, resolve does not necessarily kill all speculation levels, but allows resolution one level at a time
\end{itemize}›
(* One step of the speculative semantics. The rule is selected by whether
   speculation is active (cfgs empty or not), by the oracle predicates resolve
   and mispred, and by the command at the active configuration (cfg when cfgs
   is empty, last cfgs otherwise). The oracle is queried with [pcOf cfg]
   without speculation and with pcOf cfg # map pcOf cfgs during speculation. *)
inductive
stepS :: "configS  configS  bool" (infix "→S" 55)
where 
(* no speculation, no mispredicted branch: cfg takes its nextB step,
   readLocs cfg is added to ls *)
nonspec_normal: 
"cfgs = []   
 ¬ is_IfJump (prog!(pcOf cfg))  ¬ mispred pstate [pcOf cfg]  
 pstate' = pstate  
 ¬ finalB (cfg, ibT, ibUT)  (cfg', ibT', ibUT') = nextB (cfg, ibT, ibUT)  
 cfgs' = []  
 ls' = ls  readLocs cfg
  
 (pstate, cfg, cfgs, ibT, ibUT, ls) →S (pstate', cfg', cfgs', ibT', ibUT', ls')"
|
(* no speculation, mispredicted IfJump: cfg takes its nextB step and the
   nextM (wrong-branch) configuration becomes the single speculation level;
   the predictor state is updated *)
nonspec_mispred: 
"cfgs = []  
 is_IfJump (prog!(pcOf cfg))  mispred pstate [pcOf cfg]  
 pstate' = update pstate [pcOf cfg]  
 ¬ finalM (cfg, ibT, ibUT)  (cfg', ibT', ibUT') = nextB (cfg, ibT, ibUT)  (cfg1', ibT1', ibUT1') = nextM (cfg, ibT, ibUT)  
 cfgs' = [cfg1']  
 ls' = ls  readLocs cfg
  
 (pstate, cfg, cfgs, ibT, ibUT, ls) →S (pstate', cfg', cfgs', ibT', ibUT', ls')" 
|
(* speculation active and not resolved, no mispredicted branch: the top
   configuration (last cfgs) takes its nextB step; not applicable at Fence,
   getInput or Output commands; cfg is unchanged *)
spec_normal: 
"cfgs  []  
 ¬ resolve pstate (pcOf cfg # map pcOf cfgs)   
 ¬ is_IfJump (prog!(pcOf (last cfgs)))  ¬ mispred pstate (pcOf cfg # map pcOf cfgs)  
 prog!(pcOf (last cfgs))  Fence 
 pstate' = pstate  
 ¬ is_getInput (prog!(pcOf (last cfgs))) 
 ¬ is_Output (prog!(pcOf (last cfgs)))  
 ¬ finalB (last cfgs, ibT, ibUT)  (cfg1',ibT', ibUT') = nextB (last cfgs, ibT, ibUT)  
 cfg' = cfg  cfgs' = butlast cfgs @ [cfg1']  
 ls' = ls  readLocs (last cfgs)
  
 (pstate, cfg, cfgs, ibT, ibUT, ls) →S (pstate', cfg', cfgs', ibT', ibUT', ls')"
|
(* speculation active and not resolved, mispredicted IfJump at the top: the
   top configuration is replaced by its nextB successor and a new level with
   the nextM configuration is pushed on top *)
spec_mispred: 
"cfgs  []  
 ¬ resolve pstate (pcOf cfg # map pcOf cfgs)  
 is_IfJump (prog!(pcOf (last cfgs)))  mispred pstate (pcOf cfg # map pcOf cfgs)  
 pstate' = update pstate (pcOf cfg # map pcOf cfgs)  
 ¬ finalM (last cfgs, ibT, ibUT)  
 (lcfg', ibT', ibUT') = nextB (last cfgs, ibT, ibUT)  (cfg1', ibT1', ibUT1') = nextM (last cfgs, ibT, ibUT)  
 cfg' = cfg  cfgs' = butlast cfgs @ [lcfg'] @ [cfg1']  
 ls' = ls  readLocs (last cfgs)
  
 (pstate, cfg, cfgs, ibT, ibUT, ls) →S (pstate', cfg', cfgs', ibT', ibUT', ls')"
|
(* speculation active and not resolved, Fence at the top: every speculation
   level is discarded; buffers, ls and the predictor state are unchanged *)
spec_Fence: 
"cfgs  []  
 ¬ resolve pstate (pcOf cfg # map pcOf cfgs)  
 prog!(pcOf (last cfgs)) = Fence 
 pstate' = pstate  cfg' = cfg  cfgs' = []  
 ibT = ibT'  ibUT = ibUT'  ls' = ls 
  
 (pstate, cfg, cfgs, ibT, ibUT, ls) →S (pstate', cfg', cfgs', ibT', ibUT', ls')"
|
(* resolve holds: only the top speculation level is dropped and the
   predictor state is updated *)
spec_resolve: 
"cfgs  []  
 resolve pstate (pcOf cfg # map pcOf cfgs)   
 pstate' = update pstate (pcOf cfg # map pcOf cfgs) 
 cfg' = cfg  cfgs' = butlast cfgs  
 ibT = ibT'  ibUT = ibUT'  ls' = ls 
  
 (pstate, cfg, cfgs, ibT, ibUT, ls) →S (pstate', cfg', cfgs', ibT', ibUT', ls')"

(* Rule induction for stepS with the configuration tuples split into their
   individual components *)
lemmas stepS_induct = stepS.induct[split_format(complete)]

(* *)
subsubsection "State Transitions"
(* Inversion of nonspec_normal: under that rule's side conditions, a stepS
   transition is exactly one nextB step of cfg that keeps the predictor state,
   opens no speculation and adds readLocs cfg to the read locations. *)
lemma stepS_nonspec_normal_iff[simp]: 
"cfgs = []  ¬ is_IfJump (prog!(pcOf cfg))  ¬ mispred pstate [pcOf cfg]  
  
 (pstate, cfg, cfgs, ibT, ibUT, ls) →S (pstate', cfg', cfgs', ibT', ibUT', ls')
  
 (pstate' = pstate  ¬ finalB (cfg, ibT, ibUT)  
  (cfg', ibT', ibUT') = nextB (cfg, ibT, ibUT)  
  cfgs' = []  ls' = ls  readLocs cfg)"
by (subst stepS.simps) auto

(* Variant of stepS_nonspec_normal_iff with source and target configurations
   written out as Config/State/Vstore terms, so that it applies when the
   components are explicit. *)
lemma stepS_nonspec_normal_iff1[simp]: 
"cfgs = []  ¬ is_IfJump (prog!pc)  ¬ mispred pstate [pc]  
  
 (pstate, (Config pc (State (Vstore vs) avst h p)), cfgs, ibT, ibUT, ls) →S (pstate', (Config pc' (State (Vstore vs') avst' h' p')), cfgs', ibT', ibUT', ls')
  
 (pstate' = pstate  ¬ finalB ((Config pc (State (Vstore vs) avst h p)), ibT, ibUT)  
  ((Config pc' (State (Vstore vs') avst' h' p')), ibT', ibUT') = nextB ((Config pc (State (Vstore vs) avst h p)), ibT, ibUT)  
  cfgs' = []  ls' = ls  readLocs (Config pc (State (Vstore vs) avst h p)))"
  using stepS_nonspec_normal_iff config.sel(1) by presburger


(* Inversion of nonspec_mispred: cfg takes its nextB step, the nextM
   configuration becomes the only speculation level, the predictor state is
   updated with [pcOf cfg] and readLocs cfg is added to the read locations. *)
lemma stepS_nonspec_mispred_iff[simp]: 
"cfgs = []  is_IfJump (prog!(pcOf cfg))  mispred pstate [pcOf cfg]
  
 (pstate, cfg, cfgs, ibT, ibUT, ls) →S (pstate', cfg', cfgs', ibT', ibUT', ls')
  
 (cfg1' ibT1' ibUT1'. pstate' = update pstate [pcOf cfg]  
  ¬ finalM (cfg, ibT, ibUT)  (cfg', ibT', ibUT') = nextB (cfg, ibT, ibUT) 
  (cfg1', ibT1', ibUT1') = nextM (cfg, ibT, ibUT)  
  cfgs' = [cfg1']  ls' = ls  readLocs cfg)" 
by (subst stepS.simps) auto

(* Inversion of spec_normal: the top speculative configuration (last cfgs)
   takes a nextB step, provided its command is neither getInput nor Output;
   cfg and the predictor state are unchanged, and readLocs (last cfgs) is
   added to the read locations.
   (A duplicated ¬ is_getInput conjunct has been removed from the right-hand
   side; the statement is logically unchanged.) *)
lemma stepS_spec_normal_iff[simp]: 
"cfgs  []  
 ¬ resolve pstate (pcOf cfg # map pcOf cfgs)   
 ¬ is_IfJump (prog!(pcOf (last cfgs)))  ¬ mispred pstate (pcOf cfg # map pcOf cfgs)  
 prog!(pcOf (last cfgs))  Fence 
  
 (pstate, cfg, cfgs, ibT, ibUT, ls) →S (pstate', cfg', cfgs', ibT', ibUT', ls')
 
 (cfg1'. pstate' = pstate  
    ¬ is_getInput (prog!(pcOf (last cfgs)))  ¬ is_Output (prog!(pcOf (last cfgs)))  
    ¬ finalB (last cfgs, ibT, ibUT)  (cfg1',ibT',ibUT') = nextB (last cfgs, ibT, ibUT)   
    cfg' = cfg  cfgs' = butlast cfgs @ [cfg1']  ls' = ls  readLocs (last cfgs))"
apply(subst stepS.simps) by auto

(* Inversion of spec_mispred: the top configuration is replaced by its nextB
   successor lcfg', a new level holding the nextM configuration is pushed on
   top, the predictor state is updated and readLocs (last cfgs) is recorded. *)
lemma stepS_spec_mispred_iff[simp]: 
"cfgs  []  
 ¬ resolve pstate (pcOf cfg # map pcOf cfgs)  
 is_IfJump (prog!(pcOf (last cfgs)))  mispred pstate (pcOf cfg # map pcOf cfgs)
  
 (pstate, cfg, cfgs, ibT, ibUT, ls) →S (pstate', cfg', cfgs', ibT', ibUT', ls')
  
 (cfg1' ibT1' ibUT1' lcfg'. pstate' = update pstate (pcOf cfg # map pcOf cfgs)  
  ¬ finalM (last cfgs, ibT, ibUT)  
  (lcfg', ibT', ibUT') = nextB (last cfgs, ibT, ibUT)  
  (cfg1', ibT1', ibUT1') = nextM (last cfgs, ibT, ibUT)  
  cfg' = cfg  cfgs' = butlast cfgs @ [lcfg'] @ [cfg1']  ls' = ls  readLocs (last cfgs))"
by (subst stepS.simps) auto

(* Inversion of spec_Fence: hitting a Fence while speculating empties the
   speculation stack and leaves everything else untouched. *)
lemma stepS_spec_Fence_iff[simp]: 
"cfgs  []  
 ¬ resolve pstate (pcOf cfg # map pcOf cfgs)  
 prog!(pcOf (last cfgs)) = Fence 
  
 (pstate, cfg, cfgs, ibT, ibUT, ls) →S (pstate', cfg', cfgs', ibT', ibUT', ls')
 
 (pstate' = pstate  cfg = cfg'  cfgs' = []  ibT' = ibT  ibUT' = ibUT  ls' = ls)"
by (subst stepS.simps) auto

(* Inversion of spec_resolve: resolution pops exactly the top speculation
   level (butlast cfgs) and updates the predictor state; cfg, both buffers and
   the read locations are unchanged. *)
lemma stepS_spec_resolve_iff[simp]: 
"cfgs  []  
 resolve pstate (pcOf cfg # map pcOf cfgs)
  
 (pstate, cfg, cfgs, ibT, ibUT, ls) →S (pstate', cfg', cfgs', ibT', ibUT', ls')
 
 (pstate' = update pstate (pcOf cfg # map pcOf cfgs) 
  cfg' = cfg  cfgs' = butlast cfgs  ibT' = ibT  ibUT' = ibUT  ls' = ls)"
by (subst stepS.simps) auto

(* *)

(* Case analysis for stepS, registered as the default cases rule for the stepS
   predicate. Each case lists the side conditions of the corresponding
   introduction rule with the tuples split into components; intermediate
   results such as cfg1' / lcfg' are existentially quantified.
   NOTE(review): the spec_normal case does not expose the premise
   ¬ finalB (last cfgs, ibT, ibUT) that the spec_normal rule requires. *)
lemma stepS_cases[cases pred: stepS, 
 consumes 1, 
 case_names nonspec_normal nonspec_mispred 
            spec_normal spec_mispred spec_Fence spec_resolve]:
assumes "(pstate, cfg, cfgs, ibT, ibUT, ls) →S (pstate', cfg', cfgs', ibT', ibUT', ls')"
obtains 
(* nonspec_normal: no speculation, cfg takes its nextB step *)
"cfgs = []"  
   "¬ is_IfJump (prog!(pcOf cfg))  ¬ mispred pstate [pcOf cfg]"
   "pstate' = pstate"
   "¬ finalB (cfg, ibT, ibUT)"
   "(cfg', ibT', ibUT') = nextB (cfg, ibT, ibUT)"
   "cfgs' = []"
   "ls' = ls  readLocs cfg"
| 
(* nonspec_mispred: no speculation, a speculation level is opened with nextM *)
"cfgs = []" 
   "is_IfJump (prog!(pcOf cfg))" "mispred pstate [pcOf cfg]"
   "pstate' = update pstate [pcOf cfg]"
   "¬ finalM (cfg, ibT, ibUT)"
   "(cfg', ibT', ibUT') = nextB (cfg, ibT, ibUT)"
   "cfg1' ibT1' ibUT1'. (cfg1', ibT1', ibUT1') = nextM (cfg, ibT, ibUT) 
                 cfgs' = [cfg1']"
   "ls' = ls  readLocs cfg"
|
(* spec_normal: the top speculative configuration takes its nextB step *)
"cfgs  []" 
   "¬ resolve pstate (pcOf cfg # map pcOf cfgs)"
      "¬ is_IfJump (prog!(pcOf (last cfgs)))  ¬ mispred pstate (pcOf cfg # map pcOf cfgs)"
      "prog!(pcOf (last cfgs))  Fence"
      "pstate' = pstate"
      "¬ is_getInput (prog!(pcOf (last cfgs)))"
      "¬ is_Output (prog!(pcOf (last cfgs)))"
      "cfg' = cfg"
      "ls' = ls  readLocs (last cfgs)"
      "cfg1'. nextB (last cfgs, ibT, ibUT) = (cfg1',ibT',ibUT')
            cfgs' = butlast cfgs @ [cfg1']"
|
(* spec_mispred: top configuration steps with nextB, a nextM level is pushed *) 
"cfgs  []"
   "¬ resolve pstate (pcOf cfg # map pcOf cfgs)" 
      "is_IfJump (prog!(pcOf (last cfgs)))" "mispred pstate (pcOf cfg # map pcOf cfgs)"
      "pstate' = update pstate (pcOf cfg # map pcOf cfgs)"
      "¬ finalM (last cfgs, ibT, ibUT)"
      "cfg' = cfg"
      "lcfg' cfg1' ibT1' ibUT1'. 
       nextB (last cfgs, ibT, ibUT) = (lcfg',ibT',ibUT') 
       (cfg1', ibT1', ibUT1') = nextM (last cfgs, ibT, ibUT) 
       cfgs' = butlast cfgs @ [lcfg'] @ [cfg1']" 
      "ls' = ls  readLocs (last cfgs)"
|
(* spec_Fence: all speculation levels are discarded *)
"cfgs  []"
   "¬ resolve pstate (pcOf cfg # map pcOf cfgs)"
      "prog!(pcOf (last cfgs)) = Fence"
      "pstate' = pstate"
      "cfg' = cfg"
      "cfgs' = []"
      "ibT' = ibT" 
      "ibUT' = ibUT" 
      "ls' = ls"
|
(* spec_resolve: the top speculation level is popped *)
"cfgs  []"   
   "resolve pstate (pcOf cfg # map pcOf cfgs)"
   "pstate' = update pstate (pcOf cfg # map pcOf cfgs)"
   "cfg' = cfg"
   "cfgs' = butlast cfgs"
   "ls' = ls"
   "ibT' = ibT" 
   "ibUT' = ibUT" 
  using assms by (cases rule: stepS.cases, metis+) 
(* *)
(* Without speculation, a configuration at endPC has no stepS successor. *)
lemma stepS_endPC: "pcOf cfg = endPC  ¬ (pstate, cfg, [], ibT, ibUT, ls) →S ss'"
apply(cases ss') 
apply safe apply(cases rule: stepS_cases, auto) 
  (* remaining goals are refuted with finalB_endPC (from the imported theories);
     the last one additionally goes through finalB_imp_finalM *)
  using finalB_endPC apply blast  
  using finalB_endPC apply blast
  using finalB_endPC finalB_imp_finalM by blast

(* Multi-step speculative execution: star of stepS (presumably the
   reflexive-transitive closure defined in the imported theories) *)
abbreviation
  stepsS :: "configS  configS  bool" (infix "→S*" 55)
  where "x →S* y  star stepS x y"

(* finalS: no stepS transition is possible; finalS_defs bundles the
   definitions needed to unfold it *)
definition "finalS = final stepS"
lemmas finalS_defs  = final_def finalS_def

(* From pc 0 with no speculation, stepS moves to pc 1 via nonspec_normal,
   leaving the state, predictor state, buffers and read locations unchanged.
   Relies on prog_0 and stebB_0 from the imported theories.
   NOTE(review): "stebB_0" looks like a misspelling of "stepB_0"; it must match
   the name declared upstream, so rename both together if at all. *)
lemma stepS_0: "(pstate, Config 0 s, [], ibT, ibUT, ls) →S (pstate, Config 1 s, [], ibT, ibUT, ls)"
using prog_0 apply-apply(rule nonspec_normal) 
using One_nat_def stebB_0 stepB_nextB  
by (auto simp: readLocs_def finalB_def final_def, meson)

(* Without speculation, the nonspeculative configuration of any stepS step
   moves by an ordinary stepB step (in the misprediction case, non-finalM
   implies non-finalB via finalB_imp_finalM). *)
lemma stepS_imp_stepB:"(pstate, cfg, [], ibT,ibUT, ls) →S (pstate', cfg', cfgs', ibT',ibUT', ls')  (cfg, ibT,ibUT) →B (cfg', ibT',ibUT')"
  subgoal premises s
    using s apply (cases rule: stepS_cases)
    by (metis finalB_imp_finalM stepB_iff_nextB)+ .

subsubsection "Elimination Rules"

(* Elimination rules for pairs of stepS steps ("step2") *)
(* Two stepS steps without speculation, taken from configurations at the same
   pc whose command is an assignment x ::= a. Both runs stay without
   speculation, set x to aval a in their vstore, advance pc by one, keep avst
   and p, and add readLocs of their source configuration to ls. *)
lemma stepS_Assign2E:
  assumes (ps3, cfg3, cfgs3, ibT3,ibUT3, ls3) →S (ps3', cfg3', cfgs3', ibT3',ibUT3', ls3') 
      and (ps4, cfg4, cfgs4, ibT4,ibUT4, ls4) →S (ps4', cfg4', cfgs4', ibT4',ibUT4', ls4')
      and cfg3 = (Config pc3 (State (Vstore vs3) avst3 h3 p3)) and cfg3' = (Config pc3' (State (Vstore vs3') avst3' h3' p3'))
      and cfg4 = (Config pc4 (State (Vstore vs4) avst4 h4 p4)) and cfg4' = (Config pc4' (State (Vstore vs4') avst4' h4' p4'))
      and cfgs3 = [] and cfgs4 = []
      and prog!pc3 = (x ::= a) and pcOf cfg3 = pcOf cfg4 
    shows cfgs3' = []  cfgs4' = []  
           vs3' = (vs3(x := aval a (stateOf cfg3))) 
           vs4' = (vs4(x := aval a (stateOf cfg4))) 
           pc3' = Suc pc3  pc4' = Suc pc4  ls4' = ls4  readLocs cfg4 
           avst3' = avst3  avst4' = avst4  ls3' = ls3  readLocs cfg3 
           p3 = p3'  p4 = p4'
  using assms apply clarify 
  (* both steps start without speculation, so each is also a basic stepB step *)
  apply-apply(frule stepS_imp_stepB[of ps3])
  apply(frule stepS_imp_stepB[of ps4])
  (* eliminate each stepB step with stepB_AssignE (from the imported theories;
     presumably the elimination rule for assignment steps) *)
  apply (drule stepB_AssignE[of _ _ _ _ _ _ pc3 vs3 avst3 h3 p3
                                pc3' vs3' avst3' h3' p3' x a], clarify+)
  apply (drule stepB_AssignE[of _ _ _ _ _ _ pc4 vs4 avst4 h4 p4
                                pc4' vs4' avst4' h4' p4'], clarify+)
  by fastforce+


end (* context Prog_Mispred *)

end