(* Theory Quicksort_Ex *)

(*
 * Copyright 2020, Data61, CSIRO (ABN 41 687 119 230)
 * Copyright (c) 2022 Apple Inc. All rights reserved.
 *
 * SPDX-License-Identifier: BSD-2-Clause
 *)

(*
 * Verifying quicksort implementation using AutoCorres!
 *)
theory Quicksort_Ex
imports
  "AutoCorres2_Main.AutoCorres_Main"
  "HOL-Library.Multiset"
begin


(* Tag creturn_def with the vcg_simp attribute (presumably so the VCG
   simplifier unfolds C return statements -- confirm), then parse
   quicksort.c with the C parser and lift it with AutoCorres. *)
declare creturn_def [vcg_simp]

install_C_file "quicksort.c"
autocorres "quicksort.c"

(* Print the generated definitions for inspection: the C-parser bodies of
   partition and quicksort, and their AutoCorres-lifted counterparts. *)
context quicksort_all_corres begin

thm partition_body_def
thm quicksort_body_def

thm partition'_def
thm quicksort'.simps
end

(* Some rules for pointer addition *)

(* fixme: move *)
(* Offsetting a pointer by i + j is the same as offsetting by i, then by j. *)
lemma ptr_add_assoc [simp]:
  "p +p (i + j) = p +p i +p j"
  by (simp add: CTypesDefs.ptr_add_def distrib_right)

(* fixme: move *)
(* Two successive pointer offsets may be applied in either order. *)
lemma ptr_add_commute [simp]:
  "p +p i +p j = p +p j +p i"
  by (metis ptr_add_assoc add.commute)


(*
 * Array validity definitions
 *)

(* The n-element word32 array starting at a fits in the address space:
   the byte address just past its last element is at most 2 ^ addr_bitsize.
   Equality is allowed, i.e. the array may end exactly at the top of memory. *)
definition
  array_loc_valid :: "word32 ptr \<Rightarrow> nat \<Rightarrow> bool"
where
  "array_loc_valid a n \<equiv>
   (unat (ptr_val a) + size_of TYPE(word32) * n \<le> 2 ^ len_of TYPE(addr_bitsize))"

(* Each of the n word32 cells a, a +p 1, ..., a +p (n - 1) is valid
   according to the heap typing of s. *)
fun
  array_elems_valid :: "lifted_globals \<Rightarrow> word32 ptr \<Rightarrow> nat \<Rightarrow> bool"
where
  "array_elems_valid s a 0 = True" |
  "array_elems_valid s a (Suc n) = (ptr_valid (heap_typing s) a \<and> array_elems_valid s (a +p 1) n)"

(* Non-recursive characterisation of array_elems_valid: every index below n
   points to a valid cell. *)
lemma array_all_elems_valid:
  (* equivalent characterisation of array validity *)
  "array_elems_valid s a n = (\<forall>m. m < n \<longrightarrow> ptr_valid (heap_typing s) (a +p int m))"
  apply (induct n arbitrary: a)
   apply simp
  apply (case_tac "n = 0")
   apply simp
  apply (rule iffI)
   apply clarsimp
   apply (case_tac "m = 0")
    apply simp
   apply (case_tac "m = n")
    apply clarsimp
    apply (drule_tac x = "n - 1" in spec)
    apply (simp add: CTypesDefs.ptr_add_def)
   apply (drule_tac x = "m - 1" in spec)
   apply (simp add: CTypesDefs.ptr_add_def)
  apply simp
  apply (frule_tac x = "0" in spec)
  apply clarsimp
  apply (frule_tac x = "m + 1" in spec)
  apply simp
  done

(* An n-element word32 array at a in state s: it fits in the address space
   and every one of its cells is valid in the heap typing. *)
definition
  is_array :: "lifted_globals \<Rightarrow> word32 ptr \<Rightarrow> nat \<Rightarrow> bool"
where
  "is_array s a n \<equiv> (array_loc_valid a n \<and> array_elems_valid s a n)"

(* Necessary condition for many pointer comparison lemmas *)
(* The array's end address is strictly below 2 ^ addr_bitsize, so the
   one-past-the-end pointer a +p int n does not wrap around. *)
definition array_not_at_mem_end :: "word32 ptr \<Rightarrow> nat \<Rightarrow> bool"
where
  "array_not_at_mem_end a n \<equiv>
   (unat (ptr_val a) + size_of TYPE(word32) * n < 2 ^ len_of TYPE(addr_bitsize))"
  (* same as "array_loc_valid a n" but excluding equality *)


(* Some obvious but useful corollaries *)

(* Any in-bounds element of an array is a valid pointer. *)
lemma array_valid_elem:
  "\<lbrakk> is_array s a n; m < n \<rbrakk> \<Longrightarrow> ptr_valid (heap_typing s) (a +p int m)"
  by (metis is_array_def array_all_elems_valid)

(* Word-indexed version of array_valid_elem. *)
lemma array_valid_elem2:
  "\<lbrakk> is_array s a (unat n); m < n \<rbrakk> \<Longrightarrow> ptr_valid (heap_typing s) (a +p uint m)"
  by (metis array_valid_elem uint_nat word_less_nat_alt)

(* The empty array is an array at any address. *)
lemma empty_array_is_array:
  "is_array s a 0"
  apply (insert unat_lt2p[of "ptr_val a"])
  apply (simp add: is_array_def array_loc_valid_def)
  done

(* An empty array never reaches the end of memory, since every address
   is strictly below 2 ^ addr_bitsize. *)
lemma empty_array_not_at_mem_end:
  "array_not_at_mem_end a 0"
  apply (insert unat_lt2p[of "ptr_val a"])
  apply (simp add: array_not_at_mem_end_def)
  done

(* A strict prefix of a location-valid array does not reach the end of memory. *)
lemma subarray_not_at_mem_end1:
  "\<lbrakk> array_loc_valid a n; m < n \<rbrakk> \<Longrightarrow> array_not_at_mem_end a m"
  by (simp add: array_loc_valid_def array_not_at_mem_end_def)

(* is_array version of subarray_not_at_mem_end1. *)
lemma subarray_not_at_mem_end2:
  "\<lbrakk> is_array s a n; m < n \<rbrakk> \<Longrightarrow> array_not_at_mem_end a m"
  by (clarsimp simp: is_array_def subarray_not_at_mem_end1)

(* array_elems_valid reads only heap_typing, so updating the word32 heap
   contents does not affect it. *)
lemma updated_array_elems_valid [simp]:
  "array_elems_valid (heap_w32_update f s) a n = array_elems_valid s a n"
  by (induct n arbitrary: a, auto)

(* Consequently, is_array is also unaffected by word32 heap updates. *)
lemma updated_array_is_array [simp]:
  "is_array (heap_w32_update f s) a n = is_array s a n"
  by (simp add: is_array_def)


(* Some word arithmetic *)

(* fixme: move *)
(* If x + y does not overflow (x is not above the sum), unat distributes over +. *)
lemma unat_plus_weak:
  "(x::addr) \<le> x + y \<Longrightarrow> unat (x + y) = unat x + unat y"
  by (simp add: unat_plus_simple)

(* When the array does not reach the end of memory, the word computation of
   its end address does not overflow, so unat of it is the natural-number sum. *)
lemma unat_linear_over_array_loc:
  "array_not_at_mem_end a n \<Longrightarrow>
   unat (ptr_val a + of_nat (size_of TYPE(word32)) * (of_nat n)) =
   unat (ptr_val a) + size_of TYPE(word32) * n"
  apply (simp add: array_not_at_mem_end_def)
  apply (subgoal_tac "unat (4 * (of_nat n)) = 4 * n")
   apply (subst unat_plus_weak)
    apply (subst no_olen_add_nat)
    apply (rule_tac s = "4 * n" and t = "unat (4 * of_nat n)" in subst)
     apply (erule sym)
    apply simp+
  apply (simp add: unat_mult_simple unat_of_nat_eq)
  done

(* unat_linear_over_array_loc with size_of TYPE(word32) written as 4. *)
lemma unat_linear_over_array_loc2:
  "array_not_at_mem_end a n \<Longrightarrow>
   unat (ptr_val a + 4 * (of_nat n)) = unat (ptr_val a) + 4 * n"
  by (frule unat_linear_over_array_loc, simp)


(* Pointer inequalities *)

(* In-bounds element pointers of a location-valid array never wrap below
   the base pointer. *)
lemma array_no_wrap:
  "\<lbrakk> array_loc_valid a n; m < n \<rbrakk> \<Longrightarrow> a \<le> a +p int m"
  apply (clarsimp simp: array_loc_valid_def CTypesDefs.ptr_add_def
                        ptr_le_def' ptr_le_def)
  apply (subst no_olen_add_nat)
  apply (subst unat_mult_simple)
   apply (subst unat_of_nat_eq, simp+)+
  done

(* Word-indexed version of array_no_wrap. *)
lemma array_no_wrap2:
  "\<lbrakk> array_loc_valid a (unat n); m < n \<rbrakk> \<Longrightarrow> a \<le> a +p uint m"
  by (metis array_no_wrap uint_nat word_less_nat_alt)

(* Element pointers are monotone in the index up to the non-wrapping end. *)
lemma array_loc_mono:
  "\<lbrakk> array_not_at_mem_end a n; m \<le> n \<rbrakk> \<Longrightarrow> a +p int m \<le> a +p int n"
  apply (clarsimp simp: array_not_at_mem_end_def CTypesDefs.ptr_add_def
                        ptr_le_def' ptr_le_def)
  apply (rule word_plus_mono_right)
   apply (rule word_mult_le_mono1)
     apply (rule PackedTypes.of_nat_mono_maybe_le, simp+)
   apply (subst unat_of_nat_eq, simp+)
  apply (subst no_olen_add_nat)
  apply (subst unat_mult_simple)
   apply (subst unat_of_nat_eq, simp+)+
  done

(* Word-indexed version of array_loc_mono. *)
lemma array_loc_mono2:
  "\<lbrakk> array_not_at_mem_end a (unat n); m \<le> n \<rbrakk> \<Longrightarrow> a +p uint m \<le> a +p uint n"
  by (metis array_loc_mono uint_nat word_le_nat_alt)

(* Element pointers are strictly monotone in the index up to the
   non-wrapping end. *)
lemma array_loc_strict_mono:
  "\<lbrakk> array_not_at_mem_end a n; m < n \<rbrakk> \<Longrightarrow> a +p int m < a +p int n"
  apply (clarsimp simp: array_not_at_mem_end_def CTypesDefs.ptr_add_def
                        ptr_less_def' ptr_less_def)
  apply (rule word_plus_strict_mono_right)
   apply (rule word_mult_less_mono1)
     apply (subst of_nat_mono_maybe, simp+)
   apply (subst unat_of_nat_eq, simp+)
  apply (subst no_olen_add_nat)
  apply (subst unat_mult_simple)
   apply (subst unat_of_nat_eq, simp+)+
  done

(* Word-indexed version of array_loc_strict_mono. *)
lemma array_loc_strict_mono2:
  "\<lbrakk> array_not_at_mem_end a (unat n); m < n \<rbrakk> \<Longrightarrow> a +p uint m < a +p uint n"
  by (metis array_loc_strict_mono uint_nat word_less_nat_alt)


(* Concatenation lemmas *)

(* Element validity of an (n + m)-array splits into its first n elements
   and the m elements starting at a +p int n. *)
lemma array_concat_elems_valid:
  "array_elems_valid s a (n + m) =
   (array_elems_valid s a n \<and> array_elems_valid s (a +p int n) m)"
  apply (subst array_all_elems_valid)+
  apply (rule iffI)
   apply clarsimp
   apply (frule_tac x = "n + ma" in spec)
   apply simp
  apply clarsimp
  apply (case_tac "ma < n")
   apply simp
  apply (frule_tac x = "ma - n" and
                   P = "\<lambda>ma. (ma < m \<longrightarrow> ptr_valid (heap_typing s) (a +p int ma +p int n))"
                   in spec)
  apply (subgoal_tac "ma - n < m")
   apply (simp add: CTypesDefs.ptr_add_def)
  apply simp
  done

(* is_array splits at n, provided the second part is empty or the first part
   does not reach the end of memory. *)
lemma is_array_concat:
  "\<lbrakk> m = 0 \<or> array_not_at_mem_end a n \<rbrakk> \<Longrightarrow>
   is_array s a (n + m) = (is_array s a n \<and> is_array s (a +p int n) m)"
  apply (erule disjE)
   apply (simp add: empty_array_is_array)
  apply (frule unat_linear_over_array_loc)
  apply (subgoal_tac "array_loc_valid a n")
   apply (simp add: is_array_def array_loc_valid_def
                    array_concat_elems_valid add.assoc conj_commute)
  apply (simp add: array_loc_valid_def array_not_at_mem_end_def)
  done

(* is_array_concat phrased as a split point m within an n-array. *)
lemma is_array_concat2:
  "\<lbrakk> m \<le> n; array_not_at_mem_end a m \<or> m = n \<rbrakk> \<Longrightarrow>
   is_array s a n = (is_array s a m \<and> is_array s (a +p int m) (n - m))"
  apply (subst diff_add_inverse[symmetric, where n = "m"])
  apply (subst diff_add_assoc, assumption)
  apply (rule is_array_concat, force)
  done

(* Any prefix of an array is an array. *)
lemma subarray1_is_array:
  "\<lbrakk> is_array s a n; m \<le> n \<rbrakk> \<Longrightarrow> is_array s a m"
  apply (case_tac "m = n", simp)
  apply (frule_tac m = "m" in subarray_not_at_mem_end2, simp)
  apply (simp add: is_array_concat2)
  done

(* The suffix of an array starting at index m is an array. *)
lemma subarray2_is_array:
  "\<lbrakk> is_array s a n; m \<le> n; array_not_at_mem_end a m \<or> m = n \<rbrakk> \<Longrightarrow>
   is_array s (a +p int m) (n - m)"
  by (simp add: is_array_concat2)

(* A prefix array and the adjacent suffix array together form an array. *)
lemma concat_is_array:
  "\<lbrakk> is_array s a m; is_array s (a +p int m) (n - m);
     m \<le> n; array_not_at_mem_end a m \<or> m = n \<rbrakk> \<Longrightarrow>
   is_array s a n"
  by (subst is_array_concat2[where m = "m"], simp_all)


(* Array contents, definitions and lemmas *)

(* The contents of the n-element array at a in state s, read through
   heap_w32, as a list. *)
primrec
  the_array :: "lifted_globals \<Rightarrow> word32 ptr \<Rightarrow> nat \<Rightarrow> word32 list"
where
  "the_array s a 0 = []" |
  "the_array s a (Suc n) = (heap_w32 s a) # (the_array s (a +p 1) n)"

(* the_array s a n always has exactly n elements. *)
lemma the_array_length [simp]:
  "length (the_array s a n) = n"
  by (induct n arbitrary: a, auto)

(* The m-th list element is the heap word at a +p int m. *)
lemma the_array_elem:
  "m < n \<Longrightarrow> the_array s a n ! m = heap_w32 s (a +p int m)"
  apply (induct n arbitrary: a m)
   apply simp
  apply (case_tac "m = 0")
   apply (auto simp: CTypesDefs.ptr_add_def)
  done

(* Two arrays of the same length are equal iff they agree elementwise. *)
lemma the_arrays_equal:
  "(the_array s a n = the_array s' a' n) =
   (\<forall>m < n. heap_w32 s (a +p int m) = heap_w32 s' (a' +p int m))"
  by (simp add: list_eq_iff_nth_eq the_array_elem)

(* The contents of an (n + m)-array are its first n elements followed by the
   m elements starting at a +p int n. *)
lemma the_array_concat:
  "the_array s a (n + m) = the_array s a n @ the_array s (a +p int n) m"
  by (induct n arbitrary: a m, auto)

(* the_array_concat phrased as a split point m within an n-array. *)
lemma the_array_concat2:
  "m \<le> n \<Longrightarrow>
   the_array s a n = the_array s a m @ the_array s (a +p int m) (n - m)"
  apply (subst diff_add_inverse[symmetric, where n = "m"])
  apply (subst diff_add_assoc)
  apply (simp_all only: the_array_concat)
  done


(* Pointer simplification rules *)

(* fixme: move *)
(* Right cancellation for addr multiplication when neither product
   overflows and the multiplier is non-zero. *)
lemma word32_mult_cancel_right:
  fixes a :: "addr" and b :: "addr" and c :: "addr"
  shows
  "\<lbrakk> unat a * unat c < 2 ^ len_of TYPE(addr_bitsize);
     unat b * unat c < 2 ^ len_of TYPE(addr_bitsize); c \<noteq> 0 \<rbrakk> \<Longrightarrow>
   (a * c = b * c) = (a = b)"
  apply (rule iffI)
   apply simp_all
  apply (subgoal_tac "a = a * c div c")
   apply simp
   apply (rule sym)
   apply (rule Word.word_div_mult[symmetric, where 'a = addr_bitsize],
          unat_arith, simp)+
  done

(* Non-overflowing offsets from the same base are equal iff the indices are. *)
lemma ptr_offsets_eq [simp]:
  fixes i :: "nat" and j :: "nat" and a :: "word32 ptr"
  shows
  "\<lbrakk> i * size_of TYPE(word32) < 2 ^ len_of TYPE(addr_bitsize);
     j * size_of TYPE(word32) < 2 ^ len_of TYPE(addr_bitsize) \<rbrakk> \<Longrightarrow>
   (a +p int i = a +p int j) = (i = j)"
  by (simp add: CTypesDefs.ptr_add_def word32_mult_cancel_right
                unat_of_nat_eq of_nat_inj)

(* Word-indexed version of ptr_offsets_eq. *)
lemma ptr_offsets_eq2 [simp]:
  fixes i :: "addr" and j :: "addr" and a :: "word32 ptr"
  shows
  "\<lbrakk> unat i * size_of TYPE(word32) < 2 ^ len_of TYPE(addr_bitsize);
     unat j * size_of TYPE(word32) < 2 ^ len_of TYPE(addr_bitsize) \<rbrakk> \<Longrightarrow>
   (a +p uint i = a +p uint j) = (i = j)"
  by (metis  Word.of_nat_unat len_of_addr_card ptr_offsets_eq size_of_words(3)
  ucast_id uint_nat)


(* Array update simplification *)

(* fixme: move? *)
(* Updating the word32 heap with the identity function is a no-op. *)
lemma trivial_heap_update [simp]:
  "heap_w32_update (λx. x) s = s"
  by simp

(* Writing x to element m of an array corresponds to a list update of its
   contents at position m. *)
lemma the_updated_array:
  "\<lbrakk> is_array s p n; m < n \<rbrakk> \<Longrightarrow>
   the_array (heap_w32_update (\<lambda>h q. if q = p +p int m then x else f h q) s) p n =
   (the_array (heap_w32_update f s) p n)[m := x]"
  by (auto simp: is_array_def array_loc_valid_def the_array_elem
           intro: nth_equalityI)


(* Rotating three list positions preserves the multiset of elements.  When
   i = k, the rotation degenerates and ls ! i must equal ls ! j; an
   unconditional disjunction here would make the lemma false, since for
   [0, 1] with i = k = 0 and j = 1 the result is [0, 0]. *)
lemma multiset_of_cycle:
  (* Courtesy of Dave G *)
  "\<lbrakk> i < length ls; j < length ls; k < length ls; i = k \<longrightarrow> ls ! i = ls ! j \<rbrakk> \<Longrightarrow>
   mset (ls[i := (ls ! j), j := ls ! k, k := ls ! i]) = mset ls"
  apply (subst (2) mset_swap[symmetric, where i = j and j = i], assumption+)
  apply (subst (2) mset_swap[symmetric, where i = j and j = k], simp+)
  apply (clarsimp simp: nth_list_update)
  apply (metis list_update_overwrite list_update_swap)
  done


(* Defining sanity of program, i.e. that only the array can change,
   and some related lemmas *)

(* Going from s to s' changes neither validity nor contents of any word32
   pointer outside the n-array at a.  A pointer counts as outside if it is
   below a, or if the array does not reach the end of memory and the pointer
   is at or above the one-past-the-end pointer a +p int n.  The conjunction
   matters: an implication would make every pointer count as outside when the
   array does end at the top of memory. *)
definition unmodified_outside_range ::
             "lifted_globals \<Rightarrow> lifted_globals \<Rightarrow> word32 ptr \<Rightarrow> nat \<Rightarrow> bool"
where
  "unmodified_outside_range s s' a n \<equiv>
   (\<forall>p. (p < a \<or> (array_not_at_mem_end a n \<and> p \<ge> a +p int n)) \<longrightarrow>
        (ptr_valid (heap_typing s') p = ptr_valid (heap_typing s) p \<and> heap_w32 s' p = heap_w32 s p))"

(* A state is trivially unmodified outside any range relative to itself. *)
lemma unmodified_outside_range_refl [simp]:
  "unmodified_outside_range s s a n"
  by (simp add: unmodified_outside_range_def)

(* unmodified_outside_range composes across consecutive state changes. *)
lemma unmodified_outside_range_trans:
  "\<lbrakk> unmodified_outside_range s1 s2 a n; unmodified_outside_range s2 s3 a n \<rbrakk>
   \<Longrightarrow> unmodified_outside_range s1 s3 a n"
  by (simp add: unmodified_outside_range_def)

(* Being unmodified outside an empty range means nothing changed at all. *)
lemma unmodified_outside_empty_range:
  "unmodified_outside_range s s' p 0
   \<Longrightarrow> \<forall>p. (ptr_valid (heap_typing s') p = ptr_valid (heap_typing s) p \<and>
            heap_w32 s' p = heap_w32 s p)"
  apply (clarsimp simp: unmodified_outside_range_def empty_array_not_at_mem_end)
  apply (case_tac "pa < p", simp+)
  done

(* No change at all (empty range) implies no change outside any array. *)
lemma unmodified_outside_empty_range2:
  "\<lbrakk> unmodified_outside_range s s' p 0; array_loc_valid a n \<rbrakk>
   \<Longrightarrow> unmodified_outside_range s s' a n"
  apply (clarsimp simp: unmodified_outside_range_def empty_array_not_at_mem_end)
  apply (case_tac "pa < p", simp+)
  done

(* Being unmodified outside a prefix implies being unmodified outside the
   whole array. *)
lemma unmodified_outside_subrange1:
  "\<lbrakk> array_loc_valid a n; unmodified_outside_range s s' a m; m \<le> n \<rbrakk>
   \<Longrightarrow> unmodified_outside_range s s' a n"
  apply (unfold unmodified_outside_range_def)
  apply clarsimp
  apply (subgoal_tac "array_not_at_mem_end a m \<and> a +p int m \<le> p", simp)
  apply (simp add: array_not_at_mem_end_def)
  apply (rule_tac y = "a +p int n" in order_trans)
   apply (simp add: array_loc_mono array_not_at_mem_end_def)
  apply assumption
  done

(* Being unmodified outside a suffix (from index m) implies being unmodified
   outside the whole array. *)
lemma unmodified_outside_subrange2:
  "\<lbrakk> array_loc_valid a n;
     unmodified_outside_range s s' (a +p int m) (n - m); m \<le> n \<rbrakk>
   \<Longrightarrow> unmodified_outside_range s s' a n"
  apply (case_tac "m = n", simp)
   apply (rule_tac p = "a +p int n" in unmodified_outside_empty_range2, assumption+)
  apply (unfold unmodified_outside_range_def)
  apply clarsimp
  apply (simp add: ptr_add_assoc[symmetric])
  apply (rule conjI)
   apply clarsimp
   apply (frule_tac z = "a +p int m" in order_less_le_trans)
    apply (simp add: array_no_wrap)
   apply simp
  apply (auto simp: array_not_at_mem_end_def unat_linear_over_array_loc2)
  done


(* Sanity in terms of arrays not changing *)

(* An array lying entirely before or entirely after the modified range (or
   any array, if the range is empty) remains an array in the new state. *)
lemma is_still_array:
  "\<lbrakk> unmodified_outside_range s s' a n;
     (array_not_at_mem_end a' n' \<and> a' +p int n' \<le> a) \<or>
     (array_not_at_mem_end a n \<and> a +p int n \<le> a') \<or> n = 0;
     is_array s a' n' \<rbrakk> \<Longrightarrow> is_array s' a' n'"
  apply (clarsimp simp: unmodified_outside_range_def is_array_def
                        array_all_elems_valid)
  apply (drule_tac x = "m" in spec, clarsimp)
  apply (drule_tac R = "ptr_valid (heap_typing s') (a' +p int m)" in disjE, simp_all)
   apply clarsimp
   apply (subgoal_tac "a' +p int m < a", simp)
   apply (rule_tac y = "a' +p int n'" in less_le_trans)
    apply (rule array_loc_strict_mono, assumption+)
  apply (drule_tac R = "ptr_valid (heap_typing s') (a' +p int m)" in disjE, simp_all)
   apply clarsimp
   apply (subgoal_tac "a +p int n \<le> a' +p int m", simp)
   apply (erule_tac y = "a'" in order_trans)
   apply (rule_tac n = "n'" in array_no_wrap, assumption+)
  apply (case_tac "a' +p int m < a", simp)
  apply (simp add: empty_array_not_at_mem_end)
  done

(* The contents of an array disjoint from the modified range (or of any
   array, if the range is empty) are unchanged. *)
lemma the_same_array:
  "\<lbrakk> unmodified_outside_range s s' a n; array_loc_valid a' n';
     (array_not_at_mem_end a' n' \<and> a' +p int n' \<le> a) \<or>
     (array_not_at_mem_end a n \<and> a' \<ge> a +p int n) \<or> n = 0 \<rbrakk> \<Longrightarrow>
   the_array s' a' n' = the_array s a' n'"
  apply (clarsimp simp: unmodified_outside_range_def the_arrays_equal)
  apply (drule_tac R = "heap_w32 s' (a' +p int m) = heap_w32 s (a' +p int m)"
                in disjE)
    apply simp_all
   apply clarsimp
   apply (subgoal_tac "a' +p int m < a", simp)
   apply (rule_tac y = "a' +p int n'" in less_le_trans)
    apply (rule array_loc_strict_mono, assumption+)
  apply (drule_tac R = "heap_w32 s' (a' +p int m) = heap_w32 s (a' +p int m)"
                in disjE)
    apply simp_all
   apply (subgoal_tac "a +p int n \<le> a' +p int m", simp)
   apply clarsimp
   apply (erule_tac y = "a'" in order_trans)
   apply (rule_tac n = "n'" in array_no_wrap, assumption+)
  apply (case_tac "a' +p int m < a", simp)
  apply (simp add: empty_array_not_at_mem_end)
  done


(*
 * Proof of partition function!
 *)

(* Within the first n elements, an element is strictly smaller than the
   pivot element exactly when its index is below pivot_idx. *)
definition partitioned
where
  "partitioned s a n pivot_idx \<equiv>
   (\<forall>i. i < n \<longrightarrow> (i < pivot_idx \<longleftrightarrow> heap_w32 s (a +p int i) < heap_w32 s (a +p int pivot_idx)))"

(* runs_to_whileLoop_res' instantiated via the split_tuple attribute for a
   loop condition C and body B of arity 2; used below for partition's
   two-component loop state. *)
lemmas runs_to_whileLoop2 =  runs_to_whileLoop_res' [split_tuple C and B arity: 2]

(* Correctness of partition' on a non-empty array of length n.  It returns
   some r < n and the final state satisfies all of the following:
   - the array is still valid;
   - its contents are a permutation of the original (same multiset);
   - it is partitioned around index r;
   - nothing outside the array changed.
   The loop invariant tracks the scanned prefix [0, i) with the pivot at
   pivot_idx < i; the termination measure is n - i. *)
lemma (in ts_definition_partition) partition_correct:
  "is_array s0 a (unat n) \<Longrightarrow> n > 0 \<Longrightarrow>
        partition' a n \<bullet> s0
         \<lbrace> \<lambda>rv s. \<exists>r. rv = Result r \<and> is_array s a (unat n) \<and>
                mset (the_array s a (unat n)) = mset (the_array s0 a (unat n)) \<and>
                r < n \<and> partitioned s a (unat n) (unat r) \<and>
                unmodified_outside_range s0 s a (unat n) \<rbrace>"
  apply (unfold partition'_def fun_upd_apply)
  apply runs_to_vcg
  apply simp
  apply (rule runs_to_whileLoop2 [where
         I = "\<lambda>(i, pivot_idx) s. is_array s a (unat n) \<and>
                                 mset (the_array s a (unat n)) =
                                 mset (the_array s0 a (unat n)) \<and>
                                 i \<le> n \<and> pivot_idx < i \<and>
                                 partitioned s a (unat i) (unat pivot_idx) \<and>
                                 unmodified_outside_range s0 s a (unat n)" and
         R = "measure (\<lambda>((i, pivot_idx), s). unat (n - i))"])
  subgoal by simp
  subgoal
    apply simp
    apply unat_arith
    by (simp add: partitioned_def)
  subgoal
    by auto
  subgoal for i pivot_index s
    apply runs_to_vcg
    subgoal by (erule_tac n = "n" in array_valid_elem2, unat_arith)
    subgoal by (rule_tac n = "n" in array_valid_elem2, assumption+)
    subgoal by (erule_tac n = "n" in array_valid_elem2, unat_arith)
    subgoal
      apply (simp add: o_def)
      apply (subst uint_nat, subst the_updated_array, assumption, unat_arith)+
      apply (clarsimp simp: is_array_def array_loc_valid_def)
      apply (intro conjI impI)
       apply (subst (asm) ptr_offsets_eq2, simp, unat_arith, simp, unat_arith)
       apply simp
      apply (simp only: uint_nat)
      apply (subst the_array_elem[symmetric, where n = "unat n"], unat_arith)+
      apply (subst multiset_of_cycle)
          apply (simp, unat_arith)
         apply (simp, unat_arith)
        apply (simp, unat_arith)
       apply simp+
      done
    subgoal
      by unat_arith
    subgoal
      by unat_arith
    subgoal unfolding partitioned_def
      apply clarsimp
      apply (clarsimp simp: is_array_def array_loc_valid_def uint_nat simp del: Word.of_nat_unat)
      apply (intro conjI impI allI)
                apply simp
               apply (simp add: CTypesDefs.ptr_add_def)
              apply (simp add: CTypesDefs.ptr_add_def )
             apply (simp add: CTypesDefs.ptr_add_def)
            apply unat_arith
           apply (simp add: CTypesDefs.ptr_add_def)
          apply (subst (asm) (2) ptr_offsets_eq)
            apply (simp, unat_arith)
           apply (simp, unat_arith)
          apply (simp, unat_arith)
         apply (subst (asm) ptr_offsets_eq, simp, unat_arith, simp, unat_arith)+
         apply simp
        apply (subst (asm) (3) ptr_offsets_eq)
          apply (simp, unat_arith)
         apply (simp, unat_arith)
        apply (subst (asm) ptr_offsets_eq, simp, unat_arith, simp, unat_arith)+
        apply clarsimp
        apply (drule_tac x = "unat (pivot_index + 1)" in spec)
        apply (subgoal_tac "unat (pivot_index + 1) < unat i")
         apply clarsimp
         apply (subgoal_tac "\<not> unat (pivot_index + 1) < unat pivot_index")
          apply (subgoal_tac "\<not> unat i < unat (pivot_index + 1)", simp)
          apply simp
         apply unat_arith
        apply unat_arith
       apply (subst (asm) (4) ptr_offsets_eq)
         apply (simp, unat_arith)
        apply (simp, unat_arith)
       apply simp
      apply (subst (asm) ptr_offsets_eq, simp, unat_arith, simp, unat_arith)+
      apply clarsimp
      subgoal for i'
        apply (drule_tac x = "i'" in spec)
        apply (subgoal_tac "i' < unat i")
         apply clarsimp
         apply (subgoal_tac "(i' < unat pivot_index) = (i' < unat (pivot_index + 1))", simp)
         apply unat_arith
        apply unat_arith
        done
      done
    subgoal
      apply (clarsimp simp: is_array_def unmodified_outside_range_def)
      apply (intro conjI impI)
               apply (subst (asm) ptr_offsets_eq2)
                 apply simp+
             apply (subgoal_tac "a \<le> a +p uint pivot_index", simp)
             apply (rule_tac n = "n" in array_no_wrap2, simp, unat_arith)
            apply clarsimp
            apply (subgoal_tac "a +p uint pivot_index < a +p uint n", simp add: uint_nat del: Word.of_nat_unat)
            apply (erule_tac n = "n" in array_loc_strict_mono2, unat_arith)
           apply (subgoal_tac "a \<le> a +p uint (pivot_index + 1)", simp)
           apply (erule_tac n = "n" in array_no_wrap2, unat_arith)
          apply clarsimp
          apply (subgoal_tac "a +p uint i < a +p uint n", simp add: uint_nat del: Word.of_nat_unat)
          apply (rule_tac n = "n" in array_loc_strict_mono2, assumption+)
         apply (subgoal_tac "a \<le> a +p uint i", simp)
         apply (rule_tac n = "n" in array_no_wrap2, assumption+)
        apply clarsimp
        apply (subgoal_tac "a +p uint i < a +p uint n", simp add: uint_nat del: Word.of_nat_unat)
        apply (rule_tac n = "n" in array_loc_strict_mono2, assumption+)
       apply (subgoal_tac "a \<le> a +p uint (pivot_index + 1)", simp)
       apply (erule_tac n = "n" in array_no_wrap2, unat_arith)
      apply clarsimp
      apply (subgoal_tac "a +p uint (pivot_index + 1) < a +p uint n", simp add: uint_nat del: Word.of_nat_unat)
      apply (erule_tac n = "n" in array_loc_strict_mono2, unat_arith)
      done
    subgoal
      by unat_arith
    subgoal
      by unat_arith
    subgoal
      by unat_arith
    subgoal
      unfolding partitioned_def
      apply clarsimp
      subgoal for i'
          apply (case_tac "i' = unat i")
          apply (simp, unat_arith)
           apply (subgoal_tac "i' < unat i", simp)
        apply unat_arith
        done
      done
    subgoal by unat_arith
    done
  done

(* Induction rule used for quicksort proof *)
(* Strong (course-of-values) induction over addr words. *)
lemma word_strong_induct[case_names Ind]:
  "(\<And>n. (\<forall>k < n. P k) \<Longrightarrow> P n) \<Longrightarrow> P (m::addr)"
  by (rule less_induct, blast)


(* Some extra Hoare logic rules *)

(* when P A reduces to A when the guard holds. *)
lemma when_True:
  "P \<Longrightarrow> when P A = A"
  by simp

(* when P A reduces to return () when the guard fails. *)
lemma when_False:
  "\<not> P \<Longrightarrow> when P A = return ()"
  by simp

(* If only a prefix of length m was touched and it is still an array, the
   whole n-array is still an array. *)
lemma is_array_after_changing_left:
  "\<lbrakk> m \<le> n; is_array s1 a n; is_array s2 a m;
     unmodified_outside_range s1 s2 a m \<rbrakk>
   \<Longrightarrow> is_array s2 a n"
  apply (case_tac "m = n", simp)
  apply (frule_tac m = "m" in subarray_not_at_mem_end2, simp)
  apply (erule_tac m = "m" in concat_is_array)
    apply (auto simp: is_still_array subarray2_is_array)
  done

(* If only the suffix from index m was touched and it is still an array, the
   whole n-array is still an array. *)
lemma is_array_after_changing_right:
  "\<lbrakk> m \<le> n; is_array s1 a n; is_array s2 (a +p int m) (n - m);
     unmodified_outside_range s1 s2 (a +p int m) (n - m) \<rbrakk>
   \<Longrightarrow> is_array s2 a n"
  apply (case_tac "m = n")
   apply (simp add: is_array_def array_all_elems_valid unmodified_outside_empty_range)
  apply (frule_tac m = "m" in subarray_not_at_mem_end2, simp)
  apply (rule_tac m = "m" in concat_is_array)
     apply (rule_tac s = "s1" and a = "a +p int m" and n = "n - m" in is_still_array)
       apply simp+
     apply (rule_tac n = "n" in subarray1_is_array)
      apply simp+
  done


(*
 * There is an issue later on with using "subst" substitutions on expressions
 * involving certain variables, so we have to use "rule_tac" with rules like
 * the following.
 *)
(* Equal multisets on a prefix and on the matching suffix give equal
   multisets on the whole array. *)
lemma multiset_of_concat_array:
  "\<lbrakk> m \<le> n; mset (the_array s a m) = mset (the_array s' a m);
     mset (the_array s (a +p int m) (n - m)) = mset (the_array s' (a +p int m) (n - m)) \<rbrakk>
   \<Longrightarrow> mset (the_array s a n) = mset (the_array s' a n)"
  by (simp add: the_array_concat2)

(* Permuting only a prefix (and changing nothing else) keeps the multiset of
   the whole array. *)
lemma multiset_same_after_shuffling_left:
  "\<lbrakk> m \<le> n; array_loc_valid a n;
     mset (the_array s1 a m) = mset (the_array s2 a m);
     unmodified_outside_range s1 s2 a m \<rbrakk>
   \<Longrightarrow> mset (the_array s1 a n) = mset (the_array s2 a n)"
  apply (case_tac "m = n", simp)
  apply (frule_tac m = "m" in subarray_not_at_mem_end1, simp)
  apply (rule_tac m = "m" in multiset_of_concat_array, assumption+)
  apply (subgoal_tac "the_array s2 (a +p int m) (n - m) =
                      the_array s1 (a +p int m) (n - m)", simp)
  apply (erule_tac a = "a" and n = "m" in the_same_array)
   apply (clarsimp simp: is_array_def array_loc_valid_def
                         unat_linear_over_array_loc2, unat_arith)
  apply simp
  done

(* Permuting only the suffix from index m (and changing nothing else) keeps
   the multiset of the whole array. *)
lemma multiset_same_after_shuffling_right:
  "\<lbrakk> m \<le> n; array_loc_valid a n;
     mset (the_array s1 (a +p int m) (n - m)) =
     mset (the_array s2 (a +p int m) (n - m));
     unmodified_outside_range s1 s2 (a +p int m) (n - m) \<rbrakk>
   \<Longrightarrow> mset (the_array s1 a n) = mset (the_array s2 a n)"
  apply (rule_tac m = "m" in multiset_of_concat_array)
    apply assumption
   apply (subgoal_tac "the_array s2 a m = the_array s1 a m", simp)
   apply (case_tac "m = n")
    apply (simp add: unmodified_outside_empty_range the_arrays_equal)
   apply (frule_tac m = "m" in subarray_not_at_mem_end1, simp)
   apply (rule_tac a = "a +p int m" and n = "n - m" in the_same_array)
    apply (auto simp: array_not_at_mem_end_def array_loc_valid_def)
  done


(* Preparing for lemmas about partitioned-ness being preserved *)

(* After a permutation, every current element equals some element of the
   original array at an in-bounds index. *)
lemma old_array_elem:
  "\<lbrakk> mset (the_array s a n) =
     mset (the_array s0 a n); i < n \<rbrakk>
   \<Longrightarrow> \<exists>j < n. heap_w32 s (a +p int i) = heap_w32 s0 (a +p int j)"
  apply (drule mset_eq_setD)
  apply (subgoal_tac "heap_w32 s (a +p int i) \<in> set (the_array s a n)")
   apply simp
   apply (frule nth_the_index)
   apply (rule_tac x = "the_index (the_array s0 a n) (heap_w32 s (a +p int i))"
          in exI)
   apply (rule conjI)
    apply (subst (4) the_array_length[symmetric, where s = "s0" and a = "a"])
    apply (erule the_index_bounded)
   apply (subst (asm) the_array_elem)
    apply (subst (4) the_array_length[symmetric, where s = "s0" and a = "a"])
    apply (erule the_index_bounded)
   apply simp
  apply (subst the_array_elem[symmetric, where n = "n"], assumption)
  apply (rule nth_mem, simp)
  done

(* Suffix version of old_array_elem: an element of the permuted suffix from
   index k came from some original index j with k \<le> j < n. *)
lemma old_array_elem2:
  "\<lbrakk> mset (the_array s (a +p int k) (n - k)) =
     mset (the_array s0 (a +p int k) (n - k)); k + i < n \<rbrakk>
   \<Longrightarrow> \<exists>j < n. j \<ge> k \<and> heap_w32 s (a +p int (k + i)) = heap_w32 s0 (a +p int j)"
  apply (drule_tac i = "i" in old_array_elem, simp)
  apply clarsimp
  apply (rule_tac x = "k + j" in exI)
  apply simp
  done

(* Permuting only the part before the pivot preserves partitioned-ness. *)
lemma partitioned_after_shuffling_left:
  "\<lbrakk> is_array s0 a (unat n); pivot_idx < n;
     partitioned s0 a (unat n) (unat pivot_idx);
     mset (the_array s a (unat pivot_idx)) =
     mset (the_array s0 a (unat pivot_idx));
     unmodified_outside_range s0 s a (unat pivot_idx) \<rbrakk>
   \<Longrightarrow> partitioned s a (unat n) (unat pivot_idx)"
  apply (clarsimp simp: partitioned_def unmodified_outside_range_def)
  apply (subgoal_tac "heap_w32 s (a +p uint pivot_idx) =
                      heap_w32 s0 (a +p uint pivot_idx)", simp add: uint_nat del: Word.of_nat_unat)
   apply (subgoal_tac "array_not_at_mem_end a (unat pivot_idx)")
    apply (case_tac "i < unat pivot_idx", clarsimp)
     apply (subgoal_tac "\<exists>j. (j < unat pivot_idx \<and>
                              heap_w32 s (a +p int i) = heap_w32 s0 (a +p int j))")
      apply clarsimp
      apply (drule_tac x = "j" in spec)
      apply (subgoal_tac "j < unat n")
       apply (clarsimp simp: uint_nat simp del: Word.of_nat_unat)
      apply unat_arith
     apply (erule old_array_elem, simp)
    apply (drule_tac x = "a +p int i" in spec)
    apply (subgoal_tac "a +p int (unat pivot_idx) \<le> a +p int i")
     apply simp
    apply (rule array_loc_mono)
     apply (rule_tac s = "s0" and n = "unat n" in subarray_not_at_mem_end2,
            assumption+)
    apply unat_arith
   apply (erule_tac s = "s0" and n = "unat n" in subarray_not_at_mem_end2,
          unat_arith)
  apply (drule_tac x = "a +p uint pivot_idx" in spec)
  apply (subgoal_tac "array_not_at_mem_end a (unat pivot_idx)")
   apply (simp add: uint_nat del: Word.of_nat_unat)
  apply (erule_tac s = "s0" and n = "unat n" in subarray_not_at_mem_end2,
         unat_arith)
  done

(* Permuting only the part after the pivot preserves partitioned-ness. *)
lemma partitioned_after_shuffling_right:
  "\<lbrakk> is_array s0 a (unat n); pivot_idx < n;
     partitioned s0 a (unat n) (unat pivot_idx);
     mset (the_array s (a +p int (Suc (unat pivot_idx)))
                       (unat n - Suc (unat pivot_idx))) =
     mset (the_array s0 (a +p int (Suc (unat pivot_idx)))
                        (unat n - Suc (unat pivot_idx)));
     unmodified_outside_range s0 s (a +p int (Suc (unat pivot_idx)))
                                   (unat n - unat pivot_idx - 1) \<rbrakk>
   \<Longrightarrow> partitioned s a (unat n) (unat pivot_idx)"
  apply (unfold partitioned_def)
  apply (case_tac "unat n = Suc (unat pivot_idx)")
   apply (simp add: unmodified_outside_empty_range)
  apply (subgoal_tac "Suc (unat pivot_idx) < unat n")
   prefer 2
   apply unat_arith
  apply (subgoal_tac "array_not_at_mem_end a (Suc (unat pivot_idx))")
   prefer 2
   apply (rule_tac s = "s0" and n = "unat n" in subarray_not_at_mem_end2)
    apply assumption+
  apply (unfold unmodified_outside_range_def, clarify)
  apply (frule_tac x = "a +p int (unat pivot_idx)" in spec)
  apply (subgoal_tac "a +p int (unat pivot_idx) < a +p int (Suc (unat pivot_idx))")
   prefer 2
   apply (erule array_loc_strict_mono, simp)
  apply (case_tac "i \<le> unat pivot_idx")
   apply (frule_tac x = "a +p int i" in spec)
   apply (subgoal_tac "a +p int i < a +p int (Suc (unat pivot_idx))")
     apply (simp add: uint_nat del: Word.of_nat_unat)
    apply (erule array_loc_strict_mono, simp)
  apply (subgoal_tac "\<not> i < unat pivot_idx")
   prefer 2
   apply unat_arith
  apply (drule_tac i = "i - Suc (unat pivot_idx)" in old_array_elem2)
   apply unat_arith
  apply clarify
  apply (frule_tac x = "j" in spec)
  apply (subgoal_tac "\<not> j < unat pivot_idx")
   apply (subgoal_tac "a +p 1 +p int (unat pivot_idx) +p
                       int (i - Suc (unat pivot_idx)) = a +p int i")
    apply simp
   apply (subst of_nat_diff, unat_arith)
   apply (subst ptr_add_assoc[symmetric])+
   apply simp
  apply unat_arith
  done


(*
 * The basic idea of quicksort: if the array is partitioned and
 * both halves are sorted then the array is sorted.
 *)

(* Core quicksort step: a partitioned array whose two halves around the
   pivot m are sorted is sorted.  Every element right of the pivot is at
   least the pivot, which in turn exceeds every element on the left. *)
lemma partitioned_array_sorted:
  "\<lbrakk> m < n; sorted (the_array s a m);
     sorted (the_array s (a +p int (Suc m)) (n - Suc m));
     partitioned s a n m \<rbrakk> \<Longrightarrow>
   sorted (the_array s a n)"
  apply (subst the_array_concat2[where m = "m"], simp)
  apply (subst sorted_append, simp)
  apply (rule conjI)
   apply (subst the_array_concat2[where m = "1"], simp+)
   apply (simp add: partitioned_def)
   apply (subst all_set_conv_all_nth)
   apply (clarsimp simp: the_array_elem)
   apply (drule_tac x = "m + i + 1" in spec)
   apply (subgoal_tac "m + i + 1 < n")
    apply clarsimp
    apply unat_arith
   apply unat_arith
  apply (subst all_set_conv_all_nth)
  apply (clarsimp simp: the_array_elem partitioned_def)
  apply (frule_tac x = "i" in spec)
  apply simp
  apply (subgoal_tac "x \<ge> heap_w32 s (a +p int m)")
   apply simp
  apply (frule nth_the_index)
  apply (subst (asm) the_array_elem)
  apply (frule the_index_bounded, simp)
  apply (subgoal_tac "m + the_index (the_array s (a +p int m) (n - m)) x < n")
   apply (frule_tac x = "m + the_index (the_array s (a +p int m) (n - m)) x"
          in spec)
   apply simp
  apply (subgoal_tac "the_index (the_array s (a +p int m) (n - m)) x < n - m")
   apply simp
  apply (frule the_index_bounded, simp)
  done


(* Some more trivial lemmas *)

(* Stepping one element past a word index equals indexing by Suc of its
   natural-number value. *)
lemma array_index_Suc:
  "a +p uint m +p 1 = a +p int (Suc (unat m))"
  by (simp add: uint_nat del: Word.of_nat_unat)

(* Element m lies at or before element m + 1 when the array up to m + 1
   does not reach the end of memory. *)
lemma array_loc_le_Suc:
  "array_not_at_mem_end a (Suc (unat m)) \<Longrightarrow> a +p int (unat m) \<le> a +p 1 +p uint m"
  apply (subst ptr_add_commute)
  apply (subst array_index_Suc)
  apply (rule array_loc_mono, simp+)
  done

(* With m < n, the word subtraction n - m - 1 cannot underflow. *)
lemma unat_sub_sub1 [simp]:
  "(m::addr) < n \<Longrightarrow> unat (n - m - 1) = unat n - unat m - 1"
  by unat_arith

(* With m < n, the word increment m + 1 cannot overflow. *)
lemma unat_inc:
  "(m::addr) < n \<Longrightarrow> unat (m + 1) = Suc (unat m)"
  by unat_arith

(*
 * Proof of recursive quicksort function!
 *)

(*
 * Correctness of the AutoCorres-lifted quicksort' function.
 *
 * If is_array s0 a (unat n) holds, then quicksort' a n run from s0 only
 * reaches states s in which:
 *   - is_array s a (unat n) still holds,
 *   - the array contents have the same multiset as in s0 (a permutation),
 *   - the array contents are sorted, and
 *   - unmodified_outside_range s0 s a (unat n) holds.
 *
 * NOTE(review): the bound m in "unat n < m" is otherwise unconstrained.  It
 * is carried through the induction (via Ind.prems); its instances for the
 * recursive calls are presumably among the unat_arith side goals below.
 * Confirm whether it is still needed.
 *)
lemma (in ts_definition_quicksort) quicksort_correct:
  assumes "is_array s0 a (unat n)" "unat n < m"
  shows "
         quicksort' a n ∙ s0
         ⦃ λr s. is_array s a (unat n) ∧
                 mset (the_array s a (unat n)) =
                 mset (the_array s0 a (unat n)) ∧
                 sorted (the_array s a (unat n)) ∧
                 unmodified_outside_range s0 s a (unat n) ⦄"
   (is "quicksort' a n ∙ s0 ⦃ ?post a n s0 ⦄" )
  using assms
  (* Strong induction on the word n, generalising over the array start a and
     the initial state s0 so the hypothesis applies to both sub-arrays. *)
proof (induction n arbitrary: a s0 rule: word_strong_induct )
  case (Ind n)
  (* Register the induction hypothesis and partition_correct as runs_to_vcg
     rules.  This presumably lets the VCG discharge the two recursive calls
     and the call to partition' directly. *)
  note Ind.IH [rule_format, runs_to_vcg]
  note partition_correct [runs_to_vcg]
  show ?case
    using Ind.prems
    (* Unfold one step of the recursive definition and generate the
       verification conditions. *)
    apply (subst quicksort'.simps)
    apply runs_to_vcg
    (* Arithmetic side conditions and is_array for the left sub-array. *)
    subgoal by unat_arith
    subgoal by (erule subarray1_is_array, unat_arith)
    subgoal by unat_arith
    subgoal by unat_arith
    (* is_array for the right-hand sub-array in the state after the first
       recursive call (rv is presumably the index returned by partition'). *)
    subgoal for rv s' s'a
      apply (case_tac "unat n = Suc (unat rv)", simp add: empty_array_is_array)
        (* Add a few useful facts into the assumption set *)
      apply (frule_tac s = "s'" and m = "unat rv" in subarray_not_at_mem_end2, unat_arith)

      apply simp
      apply (frule_tac s = "s'" and m = "Suc (unat rv)" in subarray_not_at_mem_end2, unat_arith, unat_arith)
      apply (frule_tac s = "s'" and m = "unat rv" in subarray1_is_array, unat_arith)
      apply (frule_tac s = "s'" and m = "unat rv" in subarray2_is_array, unat_arith, simp)
      apply (frule_tac s = "s'" and m = "Suc (unat rv)" in subarray1_is_array, unat_arith)
      apply (frule_tac s = "s'" and m = "Suc (unat rv)" in subarray2_is_array,
          unat_arith, simp)
        (* ...and back to the proof *)
      apply (erule is_still_array)
       apply (rule disjI2, rule disjI1)
      using array_loc_le_Suc apply blast
      apply (simp add: uint_nat del: Word.of_nat_unat)
      done
    subgoal by unat_arith
    (* Postcondition 1: the whole array is still an array in the final state. *)
    subgoal for rv s' s'a s'b
      apply (rule_tac ?s1.0 = "s'a" and m = "Suc (unat rv)"
          in is_array_after_changing_right)
         apply unat_arith
        apply (rule_tac ?s1.0 = "s'" and m = "unat rv"
          in is_array_after_changing_left)
           apply unat_arith
          apply assumption+
       apply (simp add: uint_nat del: Word.of_nat_unat)+
      done
    (* Postcondition 2: the multiset of elements is preserved, chained through
       the intermediate states s'a and s'b. *)
    subgoal for rv s' s'a s'b
      apply (subgoal_tac "mset (the_array s'b a (unat n)) =
                           mset (the_array s'a a (unat n))")
       apply (subgoal_tac "mset (the_array s'a a (unat n)) =
                            mset (the_array s' a (unat n))", simp)
       apply (rule_tac m = "unat rv" in multiset_same_after_shuffling_left)
          apply unat_arith
         apply (simp add: is_array_def)
        apply assumption
       apply (simp add: unmodified_outside_range_def)
      apply (rule_tac m = "Suc (unat rv)" in multiset_same_after_shuffling_right)
         apply unat_arith
        apply (simp add: is_array_def)
       apply simp
      apply (simp add: unmodified_outside_range_def)
      done
    (* Postcondition 3: sortedness, via partitioned_array_sorted applied at
       index unat rv. *)
    subgoal for rv s' s'a s'b
      apply (rule_tac m = "unat rv" in partitioned_array_sorted)
         apply unat_arith
        apply (subgoal_tac "the_array s'b a (unat rv) = the_array s'a a (unat rv)", simp)
        apply (case_tac "Suc (unat rv) = unat n")
         apply (simp add: unmodified_outside_empty_range the_arrays_equal)
        apply simp
        apply (erule_tac a = "a +p 1 +p uint rv" and n = "unat n - Suc (unat rv)"
          in the_same_array)
         apply (simp add: is_array_def)
        apply (rule disjI1, rule conjI)
         apply (erule subarray_not_at_mem_end2, unat_arith)
        apply (rule array_loc_le_Suc)
        apply (erule subarray_not_at_mem_end2, unat_arith)
       apply (simp add: uint_nat del: Word.of_nat_unat)
      apply (rule_tac ?s0.0 = "s'a" in partitioned_after_shuffling_right)
          apply (rule_tac ?s1.0 = "s'" and m = "unat rv" in is_array_after_changing_left)
             apply unat_arith
            apply assumption+
        apply (rule_tac ?s0.0 = "s'" in partitioned_after_shuffling_left, assumption+)
       apply (simp add: uint_nat del: Word.of_nat_unat)+
      done
    (* Postcondition 4: memory outside the array is unmodified, by
       transitivity across the partition step and both recursive calls. *)
    subgoal for rv s' s'a s'b
      apply (erule unmodified_outside_range_trans)
      apply (rule_tac ?s2.0 = "s'a" in unmodified_outside_range_trans)
       apply (rule_tac m = "unat rv" in unmodified_outside_subrange1)
         apply (simp add: is_array_def)
        apply assumption
       apply unat_arith
      apply (rule_tac m = "Suc (unat rv)" in unmodified_outside_subrange2)
        apply (simp add: is_array_def)
       apply (simp add: uint_nat del: Word.of_nat_unat)
      apply unat_arith
      done
    (* Remaining goal: the split on n = 1 and n = 0 suggests this is the
       non-recursive branch for arrays of length at most one (NOTE(review):
       confirm against quicksort'.simps). *)
    subgoal
      apply (case_tac "n = 1", simp)
      apply (subgoal_tac "n = 0", simp, unat_arith)
      done

    done
qed


end