Theory SepInv

(*
 * Copyright 2020, Data61, CSIRO (ABN 41 687 119 230)
 * Copyright (c) 2022 Apple Inc. All rights reserved.
 *
 * SPDX-License-Identifier: BSD-2-Clause
 *)

(* License: BSD, terms see file ./LICENSE *)

theory SepInv
imports SepCode
begin

(* fixme: temporary hack for compatibility - should generalise earlier proofs
   to avoid all the duplication in here *)

(* inv_footprint p holds of a heap state s exactly when the domain of s is
   the set of all pairs whose first component lies in the address range
   {ptr_val p..+size_of TYPE('a)}, with s_footprint p removed. *)
definition inv_footprint :: "'a::c_type ptr ⇒ heap_assert" where
  "inv_footprint p ≡
     λs. dom s = {(x,y). x ∈ {ptr_val p..+size_of TYPE('a)}} - s_footprint p"

text ‹
  Like in Separation.thy, these arrows are defined using bsub and esub but
  have an \emph{input} syntax abbreviation with just sub.
  See original comment there for justification.
›

(* sep_map_inv: the guarded points-to assertion p ↦⇩g v separately conjoined
   with inv_footprint p.  The definition carries the bsub/esub form of the
   arrow; the (input) notation below accepts the plain sub form. *)
definition
  sep_map_inv :: "'a::c_type ptr ⇒ 'a ptr_guard ⇒ 'a ⇒ heap_assert" ("_ ↦⇧i⇘_⇙ _" [56,0,51] 56)
where
  "p ↦⇧i⇘g⇙ v ≡ p ↦⇩g v ∧⇧* inv_footprint p"

notation (input)
  sep_map_inv ("_ ↦⇧i⇩_ _" [56,1000,51] 56)

(* sep_map_any_inv: the value-agnostic assertion p ↦⇩g - separately conjoined
   with inv_footprint p.  Input-only notation with a plain subscript follows. *)
definition
  sep_map_any_inv :: "'a ::c_type ptr ⇒ 'a ptr_guard ⇒ heap_assert" ("_ ↦⇧i⇘_⇙ -" [56,0] 56)
where
  "p ↦⇧i⇘g⇙ - ≡ p ↦⇩g - ∧⇧* inv_footprint p"

notation (input)
  sep_map_any_inv ("_ ↦⇧i⇩_ -" [56,0] 56)

(* sep_map'_inv: the assertion p ↪⇩g v separately conjoined with
   inv_footprint p.  Input-only notation with a plain subscript follows. *)
definition
  sep_map'_inv :: "'a::c_type ptr ⇒ 'a ptr_guard ⇒ 'a ⇒ heap_assert" ("_ ↪⇧i⇘_⇙ _" [56,0,51] 56)
where
  "p ↪⇧i⇘g⇙ v ≡ p ↪⇩g v ∧⇧* inv_footprint p"

notation (input)
  sep_map'_inv ("_ ↪⇧i⇩_ _" [56,1000,51] 56)

(* sep_map'_any_inv: the value-agnostic assertion p ↪⇩g - separately
   conjoined with inv_footprint p.  Input-only notation follows. *)
definition
  sep_map'_any_inv :: "'a::c_type ptr ⇒ 'a ptr_guard ⇒ heap_assert" ("_ ↪⇧i⇘_⇙ -" [56,0] 56)
where
  "p ↪⇧i⇘g⇙ - ≡ p ↪⇩g - ∧⇧* inv_footprint p"

notation (input)
  sep_map'_any_inv ("_ ↪⇧i⇩_ -" [56,0] 56)

(* tagd_inv: the assertion g ⊢⇩s p separately conjoined with inv_footprint p. *)
definition
  tagd_inv :: "'a ptr_guard ⇒ 'a::c_type ptr ⇒ heap_assert" (infix "⊢⇩s⇧i" 100)
where
  "g ⊢⇩s⇧i p ≡ g ⊢⇩s p ∧⇧* inv_footprint p"

text ‹----›

(* If sep_map'_inv holds of some state, the guard g holds of the pointer p. *)
lemma sep_map'_g:
  "(p ↪⇧i⇩g v) s ⟹ g p"
  unfolding sep_map'_inv_def by (fastforce dest: sep_conjD sep_map'_g_exc)

(* sep_map'_inv is unchanged by separately conjoining sep_true. *)
lemma sep_map'_unfold:
  "(p ↪⇧i⇩g v) = ((p ↪⇧i⇩g v) ∧⇧* sep_true)"
  by (simp add: sep_map'_inv_def sep_map'_def sep_conj_ac)

(* sep_map'_any_inv is unchanged by separately conjoining sep_true.
   Proved pointwise (rule ext), one direction at a time (rule iffI). *)
lemma sep_map'_any_unfold:
  "(i ↪⇧i⇩g -) = ((i ↪⇧i⇩g -) ∧⇧* sep_true)"
  apply(rule ext, simp add: sep_map'_any_inv_def sep_map'_any_def sep_conj_ac)
  apply(rule iffI)
   apply(subst sep_conj_com)
   apply(subst sep_conj_assoc)+
   apply(erule (1) sep_conj_impl)
   apply(clarsimp simp: sep_conj_ac)
   apply(subst (asm) sep_map'_unfold_exc, subst sep_conj_com)
   apply(subst sep_conj_exists, fast)
  apply(subst (asm) sep_conj_com)
  apply(subst (asm) sep_conj_assoc)+
  apply(erule (1) sep_conj_impl)
  apply(subst sep_map'_unfold_exc)
  apply(subst (asm) sep_conj_exists, fast)
  done

(* If the left conjunct of a separating conjunction implies sep_map'_inv,
   then sep_map'_inv holds of the whole state. *)
lemma sep_map'_conjE1:
  "⟦ (P ∧⇧* Q) s; ⋀s. P s ⟹ (i ↪⇧i⇩g v) s ⟧ ⟹ (i ↪⇧i⇩g v) s"
  by (subst sep_map'_unfold, erule sep_conj_impl, simp+)

(* As sep_map'_conjE1, but the implication is assumed for the right conjunct;
   obtained by commuting the conjunction. *)
lemma sep_map'_conjE2:
  "⟦ (P ∧⇧* Q) s; ⋀s. Q s ⟹ (i ↪⇧i⇩g v) s ⟧ ⟹ (i ↪⇧i⇩g v) s"
  by (subst (asm) sep_conj_com, erule sep_map'_conjE1, simp)

(* If the left conjunct of a separating conjunction implies sep_map'_any_inv,
   then sep_map'_any_inv holds of the whole state. *)
lemma sep_map'_any_conjE1:
  "⟦ (P ∧⇧* Q) s; ⋀s. P s ⟹ (i ↪⇧i⇩g -) s ⟧ ⟹ (i ↪⇧i⇩g -) s"
  by (subst sep_map'_any_unfold, erule sep_conj_impl, simp+)

(* As sep_map'_any_conjE1, but the implication is assumed for the right
   conjunct; obtained by commuting the conjunction. *)
lemma sep_map'_any_conjE2:
  "⟦ (P ∧⇧* Q) s; ⋀s. Q s ⟹ (i ↪⇧i⇩g -) s ⟧ ⟹ (i ↪⇧i⇩g -) s"
  by (subst (asm) sep_conj_com, erule sep_map'_any_conjE1, simp)

(* sep_map_any_inv is sep_map_inv with the value existentially quantified. *)
lemma sep_map_any_old:
  "(p ↦⇧i⇩g -) = (λs. ∃v. (p ↦⇧i⇩g v) s)"
  by (simp add: sep_map_inv_def sep_map_any_inv_def sep_map_any_def sep_conj_ac sep_conj_exists)

(* sep_map'_inv equals sep_map_inv separately conjoined with sep_true. *)
lemma sep_map'_old:
  "(p ↪⇧i⇩g v) = ((p ↦⇧i⇩g v) ∧⇧* sep_true)"
  by (simp add: sep_map'_inv_def sep_map_inv_def sep_map'_def sep_conj_ac)

(* sep_map'_any_inv is sep_map'_inv with the value existentially quantified. *)
lemma sep_map'_any_old:
  "(p ↪⇧i⇩g -) = (λs. ∃v. (p ↪⇧i⇩g v) s)"
  by (simp add: sep_map'_inv_def sep_map'_any_inv_def sep_map'_any_def sep_conj_exists)

(* sep_map_inv implies sep_map'_inv on the same state (simp rule). *)
lemma sep_map_sep_map' [simp]:
  "(p ↦⇧i⇩g v) s ⟹ (p ↪⇧i⇩g v) s"
  unfolding sep_map_inv_def sep_map'_inv_def sep_map'_def
  apply(simp add: sep_conj_ac)
  apply(subst sep_conj_com)
  apply(simp add: sep_conj_assoc sep_conj_impl sep_conj_sep_true)
  done

(* Guard extraction for sep_map_inv: sep_map_sep_map' chained into sep_map'_g. *)
lemmas guardI = sep_map_sep_map'[THEN sep_map'_g]

(* sep_map_inv for a specific value implies sep_map_any_inv (simp rule). *)
lemma sep_map_anyI [simp]:
  "(p ↦⇧i⇩g v) s ⟹ (p ↦⇧i⇩g -) s"
  by (fastforce simp: sep_map_any_inv_def sep_map_inv_def sep_map_any_def sep_conj_ac
                elim: sep_conj_impl)

(* sep_map_any_inv yields some value v for which sep_map_inv holds. *)
lemma sep_map_anyD:
  "(p ↦⇧i⇩g -) s ⟹ ∃v. (p ↦⇧i⇩g v) s"
  apply(simp add: sep_map_any_def sep_map_any_inv_def sep_map_inv_def sep_conj_ac)
  apply(subst (asm) sep_conj_com)
  apply(clarsimp simp: sep_conj_exists sep_conj_ac)
  done

(* From sep_map_inv conjoined with P, obtain both sep_map'_inv on the whole
   state and the same conjunction with the value forgotten. *)
lemma sep_conj_mapD:
  "((i ↦⇧i⇩g v) ∧⇧* P) s ⟹ (i ↪⇧i⇩g v) s ∧ ((i ↦⇧i⇩g -) ∧⇧* P) s"
  by (simp add: sep_conj_impl sep_map'_conjE2 sep_conj_ac)

(* sep_map'_inv on a lifted state (h,d) implies ptr_safe p d. *)
lemma sep_map'_ptr_safe:
  "(p ↪⇧i⇩g (v::'a::mem_type)) (lift_state (h,d)) ⟹ ptr_safe p d"
  unfolding sep_map'_inv_def
  apply(rule sep_map'_ptr_safe_exc)
  apply(subst sep_map'_unfold_exc)
  apply(fastforce elim: sep_conj_impl)
  done

(* ptr_safe from sep_map_inv: sep_map_sep_map' chained into sep_map'_ptr_safe. *)
lemmas sep_map_ptr_safe = sep_map_sep_map'[THEN sep_map'_ptr_safe]

(* sep_map_any_inv on a lifted state (h,d) implies ptr_safe p d. *)
lemma sep_map_any_ptr_safe:
  fixes p::"'a::mem_type ptr"
  shows "(p ↦⇧i⇩g -) (lift_state (h, d)) ⟹ ptr_safe p d"
  by (blast dest: sep_map_anyD intro: sep_map_ptr_safe)

(* Heap-update rule: if tagd_inv for p is separately conjoined with the
   separating implication (p ↦⇧i⇩g v ⟶⇧* P) on lift_state (h,d), then P holds
   on the lifted state after heap_update p v h.  Reduces to the _exc version
   of the rule. *)
lemma sep_heap_update':
  "(g ⊢⇩s⇧i p ∧⇧* (p ↦⇧i⇩g v ⟶⇧* P)) (lift_state (h,d)) ⟹
      P (lift_state (heap_update p (v::'a::mem_type) h,d))"
  apply(rule sep_heap_update'_exc [where g=g])
  apply(unfold tagd_inv_def)
  apply(subst (asm) sep_conj_assoc)+
  apply(erule (1) sep_conj_impl)
  apply(subst (asm) sep_map_inv_def)
  apply(simp add: sep_conj_ac)
  apply(drule sep_conjD, clarsimp)
  apply(rule sep_implI, clarsimp)
  apply(drule sep_implD)
  (* NOTE(review): s⇩0 must match the variable name that sep_conjD/clarsimp
     introduces - confirm against Separation.thy *)
  apply(drule_tac x="s⇩0 ++ s'" in spec)
  apply(simp add: map_disj_com map_add_disj)
  apply(clarsimp simp: map_disj_com)
  apply(erule notE)
  apply(erule (1) sep_conjI)
   apply(simp add: map_disj_com)
  apply(subst map_add_com; simp)
  done

(* tagd_inv conjoined with any P implies the guard g holds of p. *)
lemma tagd_g:
  "(g ⊢⇩s⇧i p ∧⇧* P) s ⟹ g p"
  by (auto simp: tagd_inv_def tagd_def dest!: sep_conjD elim: s_valid_g)

(* tagd_inv conjoined with sep_true on a lifted state (h,d) implies
   ptr_safe p d. *)
lemma tagd_ptr_safe:
  "(g ⊢⇩s⇧i p ∧⇧* sep_true) (lift_state (h,d)) ⟹ ptr_safe p d"
  apply(rule tagd_ptr_safe_exc)
  apply(unfold tagd_inv_def)
  apply(subst (asm) sep_conj_assoc)
  apply(erule (1) sep_conj_impl)
  apply simp
  done

(* sep_map_inv implies tagd_inv for the same guard and pointer. *)
lemma sep_map_tagd:
  "(p ↦⇧i⇩g (v::'a::mem_type)) s ⟹ (g ⊢⇩s⇧i p) s"
  apply(unfold sep_map_inv_def tagd_inv_def)
  apply(erule sep_conj_impl)
   apply(erule sep_map_tagd_exc)
  apply assumption
  done

(* sep_map_any_inv implies tagd_inv for the same guard and pointer. *)
lemma sep_map_any_tagd:
  "(p ↦⇧i⇩g -) s ⟹ (g ⊢⇩s⇧i (p::'a::mem_type ptr)) s"
  by (clarsimp dest!: sep_map_anyD, erule sep_map_tagd)

(* Heap-update rule with sep_map_any_inv in place of tagd_inv as the
   ownership premise; follows from sep_heap_update'. *)
lemma sep_heap_update:
  "⟦ (p ↦⇧i⇩g - ∧⇧* (p ↦⇧i⇩g v ⟶⇧* P)) (lift_state (h,d)) ⟧ ⟹
      P (lift_state (heap_update p (v::'a::mem_type) h,d))"
  by (force intro: sep_heap_update' dest: sep_map_anyD sep_map_tagd
            elim: sep_conj_impl)

(* Frame form of the update rule: tagd_inv for p conjoined with a frame R
   becomes sep_map_inv with the new value v conjoined with R after
   heap_update p v h. *)
lemma sep_heap_update_global':
  "(g ⊢⇩s⇧i p ∧⇧* R) (lift_state (h,d)) ⟹
      ((p ↦⇧i⇩g v) ∧⇧* R) (lift_state (heap_update p (v::'a::mem_type) h,d))"
  by (rule sep_heap_update', erule sep_conj_sep_conj_sep_impl_sep_conj)

(* As sep_heap_update_global', with sep_map_any_inv as the premise. *)
lemma sep_heap_update_global:
  "(p ↦⇧i⇩g - ∧⇧* R) (lift_state (h,d)) ⟹
      ((p ↦⇧i⇩g v) ∧⇧* R) (lift_state (heap_update p (v::'a::mem_type) h,d))"
  by (fast intro: sep_heap_update_global' sep_conj_impl sep_map_any_tagd)

(* Field-update rule: writing v through the field pointer Ptr &(p→f), where
   field_lookup finds field f of 'b with type info t matching 'a, updates the
   value in sep_map_inv for p to update_ti_t t (to_bytes_p v) u and keeps the
   frame R.  Reduces to sep_heap_update_global_super_fl after unfolding. *)
lemma sep_heap_update_global_super_fl_inv:
  "⟦ (p ↦⇧i⇩g u ∧⇧* R) (lift_state (h,d));
      field_lookup (typ_info_t TYPE('b::mem_type)) f 0 = Some (t,n);
      export_uinfo t = (typ_uinfo_t TYPE('a)) ⟧ ⟹
      ((p ↦⇧i⇩g update_ti_t t (to_bytes_p v) u) ∧⇧* R)
      (lift_state (heap_update (Ptr &(p→f)) (v::'a::mem_type) h,d))"
  apply(unfold sep_map_inv_def)
  apply(simp only: sep_conj_assoc)
  apply(erule (2) sep_heap_update_global_super_fl)
  done

(* sep_map'_inv implies the plain assertion p ↪⇩g v (drops inv_footprint). *)
lemma sep_map'_inv:
  "(p ↪⇧i⇩g v) s ⟹ (p ↪⇩g v) s"
  apply(unfold sep_map'_inv_def)
  apply(subst sep_map'_unfold_exc)
  apply(erule (1) sep_conj_impl, simp)
  done

(* sep_map'_inv on a lifted state determines the value read by lift h p. *)
lemma sep_map'_lift:
  "(p ↪⇧i⇩g (v::'a::mem_type)) (lift_state (h,d)) ⟹ lift h p = v"
  apply(drule sep_map'_inv)
  apply(erule sep_map'_lift_exc)
  done

(* sep_map_any_inv on a lifted state can be refined to sep_map_inv with the
   concrete value lift h p. *)
lemma sep_map_lift:
  "((p::'a::mem_type ptr) ↦⇧i⇩g -) (lift_state (h,d)) ⟹
        (p ↦⇧i⇩g lift h p) (lift_state (h,d))"
  apply(frule sep_map_anyD)
  apply clarsimp
  apply(frule sep_map_sep_map')
  apply(drule sep_map'_lift)
  apply simp
  done

(* Read rule in weakest-precondition form: if for some v the state splits
   into sep_map_inv for v and the implication (p ↦⇧i⇩g v ⟶⇧* P v), then P holds
   for the value actually read, lift h p. *)
lemma sep_map_lift_wp:
  "⟦ ∃v. (p ↦⇧i⇩g v ∧⇧* (p ↦⇧i⇩g v ⟶⇧* P v)) (lift_state (h,d)) ⟧
      ⟹ P (lift h (p::'a::mem_type ptr)) (lift_state (h,d))"
  apply clarsimp
  subgoal for v
    (* first show lift h p = v, then discharge the implication *)
    apply(subst sep_map'_lift [where g=g and d=d])
     apply(subst sep_map'_inv_def)
     apply(subst sep_map'_def)
     apply(subst sep_conj_assoc)+
     apply(subst sep_conj_com[where P=sep_true])
     apply(subst sep_conj_assoc [symmetric])
     apply(erule sep_conj_impl)
      apply(simp add: sep_map_inv_def)
     apply simp
    apply(rule sep_conj_impl_same [where P="p ↦⇧i⇩g v" and Q="P v"])
    apply(unfold sep_map_inv_def)
    apply(erule (2) sep_conj_impl)
    done
  done

(* sep_map'_inv for a specific value implies sep_map'_any_inv (simp rule). *)
lemma sep_map'_anyI [simp]:
  "(p ↪⇧i⇩g v) s ⟹ (p ↪⇧i⇩g -) s"
  apply(unfold sep_map'_inv_def sep_map'_any_inv_def)
  apply(erule sep_conj_impl)
   apply(erule sep_map'_anyI_exc)
  apply assumption
  done

(* sep_map'_any_inv yields some value v for which sep_map'_inv holds. *)
lemma sep_map'_anyD:
  "(p ↪⇧i⇩g -) s ⟹ ∃v. (p ↪⇧i⇩g v) s"
  unfolding sep_map'_inv_def sep_map'_any_inv_def sep_map'_any_def
  by (clarsimp simp: sep_conj_exists)

(* Converse of sep_map'_lift: if lift h p = v and sep_map'_any_inv holds on
   the lifted state, then sep_map'_inv holds for v. *)
lemma sep_map'_lift_rev:
  "⟦ lift h p = (v::'a::mem_type); (p ↪⇧i⇩g -) (lift_state (h,d)) ⟧ ⟹
      (p ↪⇧i⇩g v) (lift_state (h,d))"
  by (fastforce dest: sep_map'_anyD simp: sep_map'_lift)

(* sep_map'_any_inv implies the guard g holds of p. *)
lemma sep_map'_any_g:
  "(p ↪⇧i⇩g -) s ⟹ g p"
  by (blast dest: sep_map'_anyD intro: sep_map'_g)

(* sep_map_any_inv implies the guard g holds of p. *)
lemma any_guardI:
  "(p ↦⇧i⇩g -) s ⟹ g p"
  by (drule sep_map_anyD) (blast intro: guardI)

(* Alias of sep_map_anyI: sep_map_inv implies sep_map_any_inv. *)
lemma sep_map_sep_map_any:
  "(p ↦⇧i⇩g v) s ⟹ (p ↦⇧i⇩g -) s"
  by (rule sep_map_anyI)


(* Eliminates an existentially quantified value under a separating
   conjunction: when the left conjunct asserts sep_map'_inv for some v
   together with P v, the witness can be replaced by lift h p. *)
lemma sep_lift_exists:
  fixes p :: "'a::mem_type ptr"
  assumes ex: "((λs. ∃v. (p ↪⇧i⇩g v) s ∧ P v s) ∧⇧* Q) (lift_state (h,d))"
  shows "(P (lift h p) ∧⇧* Q) (lift_state (h,d))"
proof -
  (* pull the existential out of the conjunction to fix a witness v *)
  from ex obtain v where "((λs. (p ↪⇧i⇩g v) s ∧ P v s) ∧⇧* Q)
      (lift_state (h,d))"
    by (subst (asm) sep_conj_exists, clarsimp)
  thus ?thesis
    by (force simp: sep_map'_lift sep_conj_ac
              dest: sep_map'_conjE2 dest!: sep_conj_conj)
qed

(* Under sep_map_inv the domain of the state is exactly the set of pairs
   whose address lies in {ptr_val p..+size_of TYPE('a)}. *)
lemma sep_map_dom:
  "(p ↦⇧i⇩g (v::'a::c_type)) s ⟹ dom s = {(a,b). a ∈ {ptr_val p..+size_of TYPE('a)}}"
  unfolding sep_map_inv_def
  by (drule sep_conjD, clarsimp)
     (auto dest!: sep_map_dom_exc elim: s_footprintD simp: inv_footprint_def)

(* Under sep_map'_inv the entry (ptr_val p, SIndexVal) is in the domain of
   the state. *)
lemma sep_map'_dom:
  "(p ↪⇧i⇩g (v::'a::mem_type)) s ⟹ (ptr_val p,SIndexVal) ∈ dom s"
  unfolding sep_map'_inv_def
  by (drule sep_conjD, clarsimp) (drule sep_map'_dom_exc, clarsimp)

(* Two sep_map'_inv assertions about the same pointer on the same state
   (possibly with different guards g and h) agree on the value. *)
lemma sep_map'_inj:
  "⟦ (p ↪⇧i⇩g (v::'a::c_type)) s; (p ↪⇧i⇩h v') s ⟧ ⟹ v=v'"
  by (drule sep_map'_inv)+ (drule (2) sep_map'_inj_exc)

(* Retyping rule: if the region sep_cut' (ptr_val p) (size_of TYPE('a)) is
   separately conjoined with P on lift_state (h,d) and the guard g holds of
   p, then after ptr_retyp p d the region satisfies tagd_inv for p and P is
   preserved. *)
lemma ptr_retyp_sep_cut':
  fixes p::"'a::mem_type ptr"
  assumes sc: "(sep_cut' (ptr_val p) (size_of TYPE('a)) ∧⇧* P)
      (lift_state (h,d))" and "g p"
  shows "(g ⊢⇩s⇧i p ∧⇧* P) (lift_state (h,(ptr_retyp p d)))"
proof -
  (* split the state into the cut region s⇩0 and the frame s⇩1 *)
  from sc
  obtain s⇩0 and s⇩1
    where "s⇩0 ⊥ s⇩1" and "lift_state (h,d) = s⇩1 ++ s⇩0"
      and "P s⇩1" and d: "dom s⇩0 = {(a,b). a ∈ {ptr_val p..+size_of TYPE('a)}}"
      and k: "dom s⇩0 ⊆ dom_s d"
    by (auto dest!: sep_conjD sep_cut'_dom simp: dom_lift_state_dom_s [where h=h,symmetric])
  (* the retyped state is the frame plus the retyped state restricted to dom s⇩0 *)
  moreover from this
  have "lift_state (h, ptr_retyp p d) = s⇩1 ++ lift_state (h, ptr_retyp p d) |` (dom s⇩0)"
    apply -
    apply(rule ext)
    subgoal for x
      apply(cases "x ∈ dom s⇩0")
       apply(cases "x ∈ dom s⇩1")
        apply(fastforce simp: map_disj_def)
       apply(subst map_add_com)
        apply(fastforce simp: map_disj_def)
       apply(clarsimp simp: map_add_def split: option.splits)
      apply(cases x, clarsimp)
      apply(clarsimp simp: lift_state_ptr_retyp_d merge_dom2)
      done
    done
  (* the restricted part satisfies tagd_inv *)
  moreover have "g p" by fact
  with d k have "(g ⊢⇩s⇧i p) (lift_state (h, ptr_retyp p d) |` dom s⇩0)"
    apply(clarsimp simp: lift_state_ptr_retyp_restrict sep_conj_ac tagd_inv_def)
    (* NOTE(review): the instantiated name s⇩0 must match the variable name
       used in sep_conjI - confirm against Separation.thy *)
    apply(rule sep_conjI [where s⇩0="lift_state (h,d) |` ({(a, b). a ∈ {ptr_val p..+size_of TYPE('a)}} - s_footprint p)"])
       apply(fastforce simp: inv_footprint_def)
      apply(erule ptr_retyp_tagd_exc[where h=h])
     apply(fastforce  simp: map_disj_def)
    apply(subst map_add_comm[of "lift_state (h, ptr_retyp p empty_htd)"])
     apply force
    apply(rule ext)
    apply(clarsimp simp: map_add_def split: option.splits)
    by (metis (mono_tags) Diff_iff dom_ptr_retyp_empty_htd non_dom_eval_eq restrict_in_dom restrict_out)
  ultimately show ?thesis
    by (metis restrict_map_on_disj sep_conjI)
qed

(* Weakest-precondition form of ptr_retyp_sep_cut': the frame is the
   separating implication (g ⊢⇩s⇧i p ⟶⇧* P), so P itself holds after retyping. *)
lemma ptr_retyp_sep_cut'_wp:
  "⟦ (sep_cut' (ptr_val p) (size_of TYPE('a)) ∧⇧* (g ⊢⇩s⇧i p ⟶⇧* P))
      (lift_state (h,d)); g (p::'a::mem_type ptr) ⟧
      ⟹ P (lift_state (h,(ptr_retyp p d)))"
  by (rule sep_conj_impl_same [where P="g ⊢⇩s⇧i p" and Q=P]) (simp add: ptr_retyp_sep_cut')

(* Under tagd_inv the domain of the state is exactly the set of pairs whose
   address lies in {ptr_val p..+size_of TYPE('a)}. *)
lemma tagd_dom:
  "(g ⊢⇩s⇧i p) s ⟹ dom s = {(a,b). a ∈ {ptr_val (p::'a::c_type ptr)..+size_of TYPE('a)}}"
  unfolding tagd_inv_def
  by (drule sep_conjD, clarsimp)
     (auto simp: inv_footprint_def dest!: tagd_dom_exc elim: s_footprintD)

(* Under tagd_inv the entry (ptr_val p, SIndexVal) is in the domain of the
   state; a corollary of tagd_dom. *)
lemma tagd_dom_p:
  "(g ⊢⇩s⇧i p) s ⟹ (ptr_val (p::'a::mem_type ptr),SIndexVal) ∈ dom s"
  by (drule tagd_dom) clarsimp


end