(* Theory DBM_Operations_Impl_Refine *)

theory DBM_Operations_Impl_Refine
  imports
    DBM_Operations_Impl
    "HOL-Library.IArray"
    DBM_Imperative_Loops
begin

(* Appending ys to the result of a cons-accumulating fold is the same as
   running that fold from the already-appended accumulator zs @ ys. *)
lemma rev_map_fold_append_aux:
  "fold (λ x xs. f x # xs) xs zs @ ys = fold (λ x xs. f x # xs) xs (zs@ys)"
proof (induction xs arbitrary: zs)
  case Nil
  show ?case by simp
next
  case (Cons x xs)
  then show ?case by simp
qed

(* rev (map f xs) computed as a single left fold that conses f x onto an
   accumulator starting from []. *)
lemma rev_map_fold:
  "rev (map f xs) = fold (λ x xs. f x # xs) xs []"
proof (induction xs)
  case Nil
  show ?case by simp
next
  case (Cons x xs)
  then show ?case by (simp add: rev_map_fold_append_aux)
qed

(* map f xs as the reversal of a cons-accumulating fold. It is obtained from
   rev_map_fold by reversing both sides. *)
lemma map_rev_fold:
  "map f xs = rev (fold (λ x xs. f x # xs) xs [])"
proof -
  have "rev (map f xs) = fold (λ x xs. f x # xs) xs []"
    by (rule rev_map_fold)
  then have "rev (rev (map f xs)) = rev (fold (λ x xs. f x # xs) xs [])"
    by simp
  then show ?thesis
    by simp
qed

(* Relates pointwise_cmp on two (n+1) x (n+1) matrices M, M' to list_all2 on
   flat lists xs, ys. Entry (i, j) is stored at index i + i * n + j, i.e.
   i * (n + 1) + j (row-major). Only the first (n+1)*(n+1) list elements are
   compared, and both lists must be at least that long. *)
lemma pointwise_cmp_iff:
  "pointwise_cmp P n M M' ⟷ list_all2 P (take ((n + 1) * (n + 1)) xs) (take ((n + 1) * (n + 1)) ys)"
  if "∀i≤n. ∀j≤n. xs ! (i + i * n + j) = M i j"
    "∀i≤n. ∀j≤n. ys ! (i + i * n + j) = M' i j"
    "(n + 1) * (n + 1) ≤ length xs" "(n + 1) * (n + 1) ≤ length ys"
  using that unfolding pointwise_cmp_def
  unfolding list_all2_conv_all_nth
  apply clarsimp
  apply safe
  (* From pointwise_cmp to list_all2: split a flat index x into
     row x div (n + 1) and column x mod (n + 1). *)
  subgoal premises prems for x
  proof -
    let ?i = "x div (n + 1)" let ?j = "x mod (n + 1)"
    from ‹x < _› have "?i < Suc n" "?j ≤ n"
      by (simp add: less_mult_imp_div_less)+
    with prems have
      "xs ! (?i + ?i * n + ?j) = M ?i ?j" "ys ! (?i + ?i * n + ?j) = M' ?i ?j"
      "P (M ?i ?j) (M' ?i ?j)"
      by auto
    moreover have "?i + ?i * n + ?j = x"
      by (metis ab_semigroup_add_class.add.commute mod_div_mult_eq mult_Suc_right plus_1_eq_Suc)
    ultimately show ‹P (xs ! x) (ys ! x)›
      by auto
  qed
  (* From list_all2 to pointwise_cmp: instantiate the list hypothesis at the
     flat index i + i * n + j, which is below (n + 1) * (n + 1). *)
  subgoal for i j
    apply (erule allE[of _ i], erule impE, simp)
    apply (erule allE[of _ i], erule impE, simp)
    apply (erule allE[of _ "i + i * n + j"], erule impE)
    subgoal
      by (rule le_imp_less_Suc) (auto intro!: add_mono simp: algebra_simps)
    apply (erule allE[of _ j], erule impE, simp)
    apply (erule allE[of _ j], erule impE, simp)
    apply simp
    done
  done

(* intersperse sep xs puts sep between every two adjacent elements of xs.
   Lists with fewer than two elements are returned unchanged. *)
fun intersperse :: "'a ⇒ 'a list ⇒ 'a list" where
  "intersperse sep (x # y # xs) = x # sep # intersperse sep (y # xs)" |
  "intersperse _ xs = xs"

(* The relation underlying the pure equality assertion λa c. ↑(c = a) is Id.
   Proof: rewrite the assertion as pure Id, then simplify. *)
lemma the_pure_id_assn_eq[simp]:
  "the_pure (λa c. ↑(c = a)) = Id"
proof -
  have *: "(λa c. ↑(c = a)) = pure Id"
    unfolding pure_def by simp
  show ?thesis
    by (subst *) simp
qed

(* The pure equality assertion, written as a lambda, coincides with id_assn. *)
lemma pure_eq_conv:
  "(λa c. ↑(c = a)) = id_assn"
  using is_pure_assn_def is_pure_iff_pure_assn is_pure_the_pure_id_eq the_pure_id_assn_eq by blast

section ‹Refinement›

(* DBMEntry over a countable type is countable. Each constructor is encoded
   with to_nat on a (tag, payload) pair; the tags 0/1/2 keep the constructors
   Le, Lt and INF apart. *)
instance DBMEntry :: ("{countable}") countable
  apply (rule
    countable_classI[of
      "(λLe (a::'a) ⇒ to_nat (0::nat,a) |
           DBM.Lt a ⇒ to_nat (1::nat,a) |
            DBM.INF ⇒ to_nat (2::nat,undefined::'a) )"])
  apply (simp split: DBMEntry.splits)
done

(* DBMEntry values over a heap type can be stored in Imperative HOL arrays. *)
instance DBMEntry :: (heap) heap ..

(* Entry-wise inclusion test on (n+1) x (n+1) DBMs: M (i, j) ≤ M' (i, j) for all
   i, j ≤ n. Unlike dbm_subset, it does not check the diagonal of M. *)
definition dbm_subset' :: "nat ⇒ ('t :: {linorder, zero}) DBM' ⇒ 't DBM' ⇒ bool" where
  "dbm_subset' n M M' ≡ pointwise_cmp (≤) n (curry M) (curry M')"

(* dbm_subset' as nested list_all loops over rows and columns 0..n, phrased
   with op_mtx_get. This is the form that sepref translates (dbm_subset'_impl2,
   dbm_subset'_impl). *)
lemma dbm_subset'_alt_def:
  "dbm_subset' n M M' ⟷
    list_all (λi. list_all (λj. (op_mtx_get M (i, j) ≤ op_mtx_get M' (i, j))) [0..<Suc n])
      [0..<Suc n]"
  by (simp add: dbm_subset'_def pointwise_cmp_alt_def neutral)

(* Code equation for dbm_subset. It holds if some diagonal entry of M is
   negative, or if M is entry-wise ≤ M'. *)
lemma dbm_subset_alt_def'[code]:
  "dbm_subset n M M' ⟷
    list_ex (λi. op_mtx_get M (i, i) < 0) [0..<Suc n] ∨
    list_all (λi. list_all (λj. (op_mtx_get M (i, j) ≤ op_mtx_get M' (i, j))) [0..<Suc n])
      [0..<Suc n]"
  by (simp add: dbm_subset_def check_diag_alt_def pointwise_cmp_alt_def neutral)

(* Row 0 of the matrix M, columns 0..m, packed into an immutable IArray. *)
definition mtx_line_to_iarray where
  "mtx_line_to_iarray m M = IArray (map (λj. M (0, j)) [0..<Suc m])"

(* Row 0 of the matrix M, columns 0..m, as a plain list. *)
definition mtx_line where
  "mtx_line m (M :: _ DBM') = map (λj. M (0, j)) [0..<Suc m]"

(* Common setup for refining DBM operations of dimension n. Square
   (n+1) x (n+1) matrices are stored in arrays, clocks are naturals below
   Suc n, and several DBM primitives are imported into sepref as identity
   parametric. *)
locale DBM_Impl =
  fixes n :: nat
begin

(* A DBM refined by an array representing an (Suc n) x (Suc n) matrix with
   entries related by identity. *)
abbreviation
  mtx_assn :: "(nat × nat ⇒ ('a :: {linordered_ab_monoid_add, heap})) ⇒ 'a array ⇒ assn"
where
  "mtx_assn ≡ asmtx_assn (Suc n) id_assn"

(* Clock indices: natural numbers below Suc n. *)
abbreviation "clock_assn ≡ nbn_assn (Suc n)"

(* Identity parametricity for constants and operations on DBM entries. *)
lemmas Relation.IdI[where a = ∞, sepref_import_param]
lemma [sepref_import_param]: "((+),(+)) ∈ Id→Id→Id" by simp
lemma [sepref_import_param]: "(uminus,uminus) ∈ (Id::(_*_)set)→Id" by simp
lemma [sepref_import_param]: "(Lt,Lt) ∈ Id→Id" by simp
lemma [sepref_import_param]: "(Le,Le) ∈ Id→Id" by simp
lemma [sepref_import_param]: "(∞,∞) ∈ Id" by simp
lemma [sepref_import_param]: "(min :: _ DBMEntry ⇒ _, min) ∈ Id → Id → Id" by simp
lemma [sepref_import_param]: "(Suc, Suc) ∈ Id → Id" by simp

(* Identity parametricity for the normalisation helpers. *)
lemma [sepref_import_param]: "(norm_lower, norm_lower) ∈ Id→Id→Id" by simp
lemma [sepref_import_param]: "(norm_upper, norm_upper) ∈ Id→Id→Id" by simp
lemma [sepref_import_param]: "(norm_diag,  norm_diag) ∈ Id→Id" by simp

end


(* A named zero. It is registered with sepref as an opaque constant, and
   zero_clock_def is added to sepref_opt_simps so that it is folded back to 0
   in the generated code. The extrapolation refinements below introduce it
   via zero_clock_def[symmetric]. *)
definition zero_clock :: "_ :: linordered_cancel_ab_monoid_add" where
  "zero_clock = 0"

sepref_register zero_clock

lemma [sepref_import_param]: "(zero_clock, zero_clock) ∈ Id" by simp

lemmas [sepref_opt_simps] = zero_clock_def


context
  fixes n :: nat
begin

(* Bring the DBM_Impl abbreviations (mtx_assn, clock_assn) and its
   sepref_import_param facts into scope for the fixed dimension n. *)
interpretation DBM_Impl n .

(* Imperative refinement of reset_canonical_upd. The matrix is passed as ⇧d
   (destructive), so the input array may be updated in place. Precondition:
   both index arguments are ≤ n. The statement and proof are identical to
   reset_canonical_upd_impl; both names are kept, presumably because other
   theories refer to each of them. *)
sepref_definition reset_canonical_upd_impl' is
  "uncurry2 (uncurry (λx. RETURN ooo reset_canonical_upd x))" ::
  "[λ(((_,i),j),_). i≤n ∧ j≤n]⇩a mtx_assn⇧d *⇩a nat_assn⇧k  *⇩a nat_assn⇧k *⇩a id_assn⇧k → mtx_assn"
  unfolding reset_canonical_upd_alt_def op_mtx_set_def[symmetric] by sepref

(* Imperative refinement of reset_canonical_upd, destructive on the matrix
   (⇧d). Precondition: both index arguments are ≤ n. This is the variant
   exported to code at the end of the theory. *)
sepref_definition reset_canonical_upd_impl is
  "uncurry2 (uncurry (λx. RETURN ooo reset_canonical_upd x))" ::
  "[λ(((_,i),j),_). i≤n ∧ j≤n]⇩a mtx_assn⇧d *⇩a nat_assn⇧k  *⇩a nat_assn⇧k *⇩a id_assn⇧k → mtx_assn"
  unfolding reset_canonical_upd_alt_def op_mtx_set_def[symmetric] by sepref

(* Imperative refinement of up_canonical_upd, destructive on the matrix (⇧d).
   Precondition: the index argument is ≤ n. *)
sepref_definition up_canonical_upd_impl is
  "uncurry (RETURN oo up_canonical_upd)" :: "[λ(_,i). i≤n]⇩a mtx_assn⇧d *⇩a nat_assn⇧k → mtx_assn"
  unfolding up_canonical_upd_def op_mtx_set_def[symmetric] by sepref

(* Import the DBM zero (Le 0, per the neutral rewrite) as an identity
   parameter, so that 0 can occur in the abstract programs below. *)
lemma [sepref_import_param]:
  "(Le 0, 0) ∈ Id"
  unfolding neutral by simp

― ‹Not sure if this is dangerous.›
sepref_register 0

(* Imperative refinement of check_diag with the dimension as an explicit
   argument (precondition: it is ≤ n). The matrix is only read (⇧k). *)
sepref_definition check_diag_impl' is
  "uncurry (RETURN oo check_diag)" ::
  "[λ(i, _). i≤n]⇩a nat_assn⇧k *⇩a mtx_assn⇧k → bool_assn"
  unfolding check_diag_alt_def list_ex_foldli neutral[symmetric] by sepref

(* Code optimisation: drop a redundant comparison with True. *)
lemma [sepref_opt_simps]:
  "(x = True) ⟷ x"
  by simp

(* Sepref-generated refinement of dbm_subset'. It is obtained from the nested
   list_all form dbm_subset'_alt_def. Precondition: the dimension argument
   is ≤ n. *)
sepref_definition dbm_subset'_impl2 is
  "uncurry2 (RETURN ooo dbm_subset')" ::
  "[λ((i, _), _). i≤n]⇩a nat_assn⇧k *⇩a mtx_assn⇧k *⇩a mtx_assn⇧k → bool_assn"
unfolding dbm_subset'_alt_def list_all_foldli by sepref

(* Hand-written imperative check for dbm_subset'. It loops over the first
   (m + 1) * (m + 1) cells of the flat arrays a and b and compares them with ≤.
   The loop body ignores its accumulator. By imp_for_list_all2_spec below, the
   result equals list_all2 (≤) on that prefix. *)
definition
  "dbm_subset'_impl' ≡ λm a b.
    do {
    imp_for 0 ((m + 1) * (m + 1)) Heap_Monad.return
      (λi _. do {
        x ← Array.nth a i; y ← Array.nth b i; Heap_Monad.return (x ≤ y)
      })
      True
    }"

(* Hoare triple for the comparison loop used by dbm_subset'_impl'. Running
   P on indices 0..<n' of arrays a and b returns
   list_all2 P (take n' xs) (take n' ys), and both arrays are unchanged.
   Requires n' to be at most both lengths. *)
lemma imp_for_list_all2_spec:
  "
  <a ↦⇩a xs * b ↦⇩a ys>
  imp_for 0 n' Heap_Monad.return
    (λi _. do {
      x ← Array.nth a i; y ← Array.nth b i; Heap_Monad.return (P x y)
    })
    True
  <λr. ↑(r ⟷ list_all2 P (take n' xs) (take n' ys)) * a ↦⇩a xs * b ↦⇩a ys>⇩t"
  if "n' ≤ length xs" "n' ≤ length ys"
  apply (rule cons_rule[rotated 2])
    apply (rule imp_for_list_all2'[where xs = xs and ys = ys and R = id_assn and S = id_assn])
        apply (use that in simp; fail)+
    apply (sep_auto simp: pure_def array_assn_def is_array_def)+
  done

(* The hand-written loop dbm_subset'_impl' refines dbm_subset' when the
   dimension argument equals n. The proof unfolds the matrix assertion to
   flat lists and uses imp_for_list_all2_spec for the loop. pointwise_cmp_iff
   then connects list_all2 on the flat lists with pointwise_cmp. *)
lemma dbm_subset'_impl'_refine:
  "(uncurry2 dbm_subset'_impl', uncurry2 (RETURN ∘∘∘ dbm_subset'))
∈ [λ((i, _), _). i = n]⇩a nat_assn⇧k *⇩a local.mtx_assn⇧k *⇩a local.mtx_assn⇧k → bool_assn"
  apply sepref_to_hoare
  unfolding dbm_subset'_impl'_def
  unfolding amtx_assn_def hr_comp_def is_amtx_def
  apply (sep_auto heap: imp_for_list_all2_spec simp only:)
    apply (simp; intro add_mono mult_mono; simp; fail)+
  apply sep_auto

  subgoal for b bi ba bia l la a bb
    unfolding dbm_subset'_def by (simp add: pointwise_cmp_iff[where xs = l and ys = la])

  subgoal for b bi ba bia l la a bb
    unfolding dbm_subset'_def by (simp add: pointwise_cmp_iff[where xs = l and ys = la])
  done

(* Register check_diag and dbm_subset' with their interface types. Their
   implementations (the hand-written dbm_subset'_impl' and check_diag_impl')
   are installed as sepref_fr_rules, so later sepref runs can use them. *)
sepref_register check_diag ::
  "nat ⇒ _ :: {linordered_cancel_ab_monoid_add,heap} DBMEntry i_mtx ⇒ bool"

sepref_register dbm_subset' ::
  "nat ⇒ 'a :: {linordered_cancel_ab_monoid_add,heap} DBMEntry i_mtx ⇒ 'a DBMEntry i_mtx ⇒ bool"

lemmas [sepref_fr_rules] = dbm_subset'_impl'_refine check_diag_impl'.refine

(* Refinement of dbm_subset (diagonal check, then inclusion) with the
   dimension as an explicit argument that must equal n. *)
sepref_definition dbm_subset_impl' is
  "uncurry2 (RETURN ooo dbm_subset)" ::
  "[λ((i, _), _). i=n]⇩a nat_assn⇧k *⇩a mtx_assn⇧k *⇩a mtx_assn⇧k → bool_assn"
unfolding dbm_subset_def dbm_subset'_def[symmetric] short_circuit_conv by sepref

(* Variants that take only the matrices. The dimension n is treated as a
   constant of type nat by sepref (id_rules / IdI). *)
context
  notes [id_rules] = itypeI[of n "TYPE (nat)"]
    and [sepref_import_param] = IdI[of n]
begin

sepref_definition dbm_subset_impl is
  "uncurry (RETURN oo PR_CONST (dbm_subset n))" :: "mtx_assn⇧k *⇩a mtx_assn⇧k →⇩a bool_assn"
  unfolding dbm_subset_def dbm_subset'_def[symmetric] short_circuit_conv PR_CONST_def by sepref

sepref_definition check_diag_impl is
  "RETURN o PR_CONST (check_diag n)" :: "mtx_assn⇧k →⇩a bool_assn"
  unfolding check_diag_alt_def list_ex_foldli neutral[symmetric] PR_CONST_def by sepref

sepref_definition dbm_subset'_impl is
  "uncurry (RETURN oo PR_CONST (dbm_subset' n))" :: "mtx_assn⇧k *⇩a mtx_assn⇧k →⇩a bool_assn"
  unfolding dbm_subset'_alt_def list_all_foldli PR_CONST_def by sepref

end

(* Pure assertion relating an abstract list to a concrete IArray built from it
   via br IArray. The rule below refines list indexing (op_list_get) by
   IArray.sub. *)
abbreviation
  "iarray_assn x y ≡ pure (br IArray (λ_. True)) y x"

lemma [sepref_fr_rules]:
  "(uncurry (return oo IArray.sub), uncurry (RETURN oo op_list_get))
  ∈ iarray_assn⇧k *⇩a id_assn⇧k →⇩a id_assn"
unfolding br_def by sepref_to_hoare sep_auto

(* Definitions unfolded before running sepref on the extrapolation operations
   (norm_upd, extra_lu_upd, extra_lup_upd). *)
lemmas extra_defs = extra_upd_def upd_line_def upd_line_0_def

(* Refinement of norm_upd with the bound vector given as an IArray.
   Preconditions: the vector has more than n entries and the index is ≤ n.
   The matrix is updated destructively (⇧d). *)
sepref_definition norm_upd_impl is
  "uncurry2 (RETURN ooo norm_upd)" ::
   "[λ((_, xs), i). length xs > n ∧ i≤n]⇩a mtx_assn⇧d *⇩a iarray_assn⇧k *⇩a nat_assn⇧k → mtx_assn"
  unfolding norm_upd_def extra_defs zero_clock_def[symmetric] by sepref

(* Same as norm_upd_impl, except that the bound vector is a plain list
   (list_assn id_assn) rather than an IArray. *)
sepref_definition norm_upd_impl' is
  "uncurry2 (RETURN ooo norm_upd)" ::
   "[λ((_, xs), i). length xs > n ∧ i≤n]⇩a mtx_assn⇧d *⇩a (list_assn id_assn)⇧k *⇩a nat_assn⇧k → mtx_assn"
  unfolding norm_upd_def extra_defs zero_clock_def[symmetric] by sepref

(* Refinement of extra_lu_upd with both bound vectors given as IArrays.
   Preconditions: each vector has more than n entries and the index is ≤ n. *)
sepref_definition extra_lu_upd_impl is
  "uncurry3 (λx. RETURN ooo (extra_lu_upd x))" ::
  "[λ(((_, ys), xs), i). length xs > n ∧ length ys > n ∧ i≤n]⇩a
    mtx_assn⇧d *⇩a iarray_assn⇧k *⇩a iarray_assn⇧k *⇩a nat_assn⇧k → mtx_assn"
  unfolding extra_lu_upd_def extra_defs zero_clock_def[symmetric] by sepref

(* Refinement of mtx_line (row 0, columns 0..m) for m ≤ n. It is derived from
   the fold-based form map_rev_fold with an empty list as the start value. *)
sepref_definition mtx_line_to_list_impl is
  "uncurry (RETURN oo PR_CONST mtx_line)" ::
  "[λ(m, _). m ≤ n]⇩a nat_assn⇧k *⇩a mtx_assn⇧k → list_assn id_assn"
  unfolding mtx_line_def HOL_list.fold_custom_empty PR_CONST_def map_rev_fold by sepref

(* Variant of mtx_line_to_list_impl with the row length m fixed (m ≤ n).
   The assumption m ≤ n is inserted by hand during translation, which is why
   the sepref debug steps are spelled out individually. *)
context
  fixes m :: nat assumes "m ≤ n"
  notes [id_rules] = itypeI[of m "TYPE (nat)"]
    and [sepref_import_param] = IdI[of m]
begin

sepref_definition mtx_line_to_list_impl2 is
  "RETURN o PR_CONST mtx_line m" :: "mtx_assn⇧k →⇩a list_assn id_assn"
  unfolding mtx_line_def HOL_list.fold_custom_empty PR_CONST_def map_rev_fold
  apply sepref_dbg_keep
  using ‹m ≤ n›
      apply sepref_dbg_trans_keep
     apply sepref_dbg_opt
    apply sepref_dbg_cons_solve
   apply sepref_dbg_cons_solve
  apply sepref_dbg_constraints
  done

end

(* Converting a list into an IArray implements the identity on lists, viewed
   through iarray_assn. *)
lemma IArray_impl:
  "(return o IArray, RETURN o id) ∈ (list_assn id_assn)⇧k →⇩a iarray_assn"
  by sepref_to_hoare (sep_auto simp: br_def list_assn_pure_conv pure_eq_conv)

(* Row 0 of the matrix as an IArray: build the list with
   mtx_line_to_list_impl2, then wrap it with IArray. *)
definition
  "mtx_line_to_iarray_impl m M = (mtx_line_to_list_impl2 m M ⤜ return o IArray)"

(* Hoare-triple forms of the two refinement theorems, obtained by unfolding
   hn_refine. They are used as heap rules in mtx_line_to_iarray_impl_refine. *)
lemmas mtx_line_to_iarray_impl_ht =
  mtx_line_to_list_impl2.refine[to_hnr, unfolded hn_refine_def hn_ctxt_def, simplified]

lemmas IArray_ht = IArray_impl[to_hnr, unfolded hn_refine_def hn_ctxt_def, simplified]

(* mtx_line_to_iarray_impl refines mtx_line (result viewed as an IArray) for
   m ≤ n. It is installed as a sepref rule for later refinements. *)
lemma mtx_line_to_iarray_impl_refine[sepref_fr_rules]:
  "(uncurry mtx_line_to_iarray_impl, uncurry (RETURN ∘∘ mtx_line))
  ∈ [λ(m, _). m ≤ n]⇩a nat_assn⇧k *⇩a mtx_assn⇧k → iarray_assn"
  unfolding mtx_line_to_iarray_impl_def hfref_def
  apply clarsimp
  apply sepref_to_hoare
  apply (sep_auto
    heap: mtx_line_to_iarray_impl_ht IArray_ht simp: br_def pure_eq_conv list_assn_pure_conv)
  apply (simp add: pure_def)
  done

(* Interface type of mtx_line for sepref, and identity parametricity of dbm_lt
   on DBM entries. *)
sepref_register "mtx_line" :: "nat ⇒ ('ef) DBMEntry i_mtx ⇒ 'ef DBMEntry list"

lemma [sepref_import_param]: "(dbm_lt :: _ DBMEntry ⇒ _, dbm_lt) ∈ Id → Id → Id" by simp

(* Refinement of extra_lup_upd. Row 0 is obtained through mtx_line, which is
   implemented by mtx_line_to_iarray_impl_refine. Preconditions: both bound
   vectors have more than n entries and the index is ≤ n. *)
sepref_definition extra_lup_upd_impl is
  "uncurry3 (λx. RETURN ooo (extra_lup_upd x))" ::
   "[λ(((_, ys), xs), i). length xs > n ∧ length ys > n ∧ i≤n]⇩a
    mtx_assn⇧d *⇩a iarray_assn⇧k *⇩a iarray_assn⇧k *⇩a nat_assn⇧k → mtx_assn"
  unfolding extra_lup_upd_alt_def2 extra_defs zero_clock_def[symmetric] mtx_line_def[symmetric]
  by sepref


context
  notes [id_rules] = itypeI[of n "TYPE (nat)"]
    and [sepref_import_param] = IdI[of n]
begin

(* unbounded_dbm instantiated to the fixed dimension n, as a named constant. *)
definition
  "unbounded_dbm' = unbounded_dbm n"

(* unbounded_dbm n written via op_amtx_new with (Suc n) x (Suc n) dimensions.
   unbounded_dbm_impl substitutes this form before running sepref. *)
lemma unbounded_dbm_alt_def:
  "unbounded_dbm n = op_amtx_new (Suc n) (Suc n) (unbounded_dbm')"
  by (simp add: unbounded_dbm'_def)

text ‹We need the custom rule here because ‹unbounded_dbm› is a higher-order constant›
(* unbounded_dbm' is returned as a pure function value: its relation maps
   index pairs by nat_rel and entries by Id. *)
lemma [sepref_fr_rules]:
  "(uncurry0 (return unbounded_dbm'), uncurry0 (RETURN (PR_CONST (unbounded_dbm'))))
  ∈ unit_assn⇧k →⇩a pure (nat_rel ×⇩r nat_rel → Id)"
  by sepref_to_hoare sep_auto

(* Interface types: unbounded_dbm n as a matrix, unbounded_dbm' as a function
   on index pairs. *)
sepref_register "PR_CONST (unbounded_dbm n) :: nat × nat ⇒ int DBMEntry" :: "'b DBMEntry i_mtx"
sepref_register "unbounded_dbm' :: nat × nat ⇒ _ DBMEntry"

text ‹Necessary to solve side conditions of @{term op_amtx_new}›
(* All non-zero entries of unbounded_dbm' lie inside the
   (Suc n) x (Suc n) index range. *)
lemma unbounded_dbm'_bounded:
  "mtx_nonzero unbounded_dbm' ⊆ {0..<Suc n} × {0..<Suc n}"
  unfolding mtx_nonzero_def unbounded_dbm'_def unbounded_dbm_def neutral by auto

text ‹We need to pre-process the lemmas due to a failure of TRADE›
(* Row bound on a non-zero entry of unbounded_dbm', as a simp-friendly rule. *)
lemma unbounded_dbm'_bounded_1:
  "(a, b) ∈ mtx_nonzero unbounded_dbm' ⟹ a < Suc n"
  using unbounded_dbm'_bounded by auto

(* Column bound on a non-zero entry of unbounded_dbm', as a simp-friendly rule. *)
lemma unbounded_dbm'_bounded_2:
  "(a, b) ∈ mtx_nonzero unbounded_dbm' ⟹ b < Suc n"
  using unbounded_dbm'_bounded by auto

(* Make dbm_subset_impl available to later sepref runs. dbm_subset applied to
   n is rewritten to the registered PR_CONST constant. *)
lemmas [sepref_fr_rules] = dbm_subset_impl.refine

sepref_register "PR_CONST (dbm_subset n)" :: "'e DBMEntry i_mtx ⇒ 'e DBMEntry i_mtx ⇒ bool"

lemma [def_pat_rules]:
  "dbm_subset $ n ≡ PR_CONST (dbm_subset n)"
  by simp

(* Allocate the unbounded DBM of dimension n. The op_amtx_new form comes from
   unbounded_dbm_alt_def, and its index-range side conditions are discharged
   by the unbounded_dbm'_bounded lemmas. *)
sepref_definition unbounded_dbm_impl is
  "uncurry0 (RETURN (PR_CONST (unbounded_dbm n)))" :: "unit_assn⇧k →⇩a mtx_assn"
  supply unbounded_dbm'_bounded_1[simp] unbounded_dbm'_bounded_2[simp]
  using unbounded_dbm'_bounded
  apply (subst unbounded_dbm_alt_def)
  unfolding PR_CONST_def by sepref


text ‹DBM to List›
(* Row-major listing of all (Suc n) x (Suc n) entries: (0,0), (0,1), ...
   Entries are consed in a nested fold, and the result is reversed once at
   the end. *)
definition dbm_to_list :: "(nat × nat ⇒ 'a) ⇒ 'a list" where
  "dbm_to_list M ≡
  rev $ fold (λi xs. fold (λj xs. M (i, j) # xs) [0..<Suc n] xs) [0..<Suc n] []"

sepref_definition dbm_to_list_impl is
  "RETURN o PR_CONST dbm_to_list" :: "mtx_assn⇧k →⇩a list_assn id_assn"
  unfolding dbm_to_list_def HOL_list.fold_custom_empty PR_CONST_def by sepref


section ‹Pretty-Printing›

context
  fixes show_clock :: "nat ⇒ string"
    and show_num :: "'a :: {linordered_ab_group_add,heap} ⇒ string"
begin

(* Render entry e = M (i, j) as a constraint string, or None if it imposes
   nothing worth printing.
   - Diagonal (i = j): "EMPTY" if e < 0, otherwise None.
   - Row 0: lower bound on clock j, with the bound negated; Le 0 is skipped.
   - Column 0: upper bound on clock i.
   - Otherwise: the difference constraint "i - j".
   The remaining constructor (INF) yields None. *)
definition
  "make_string e i j ≡
    if i = j then if e < 0 then Some (''EMPTY'') else None
    else
    if i = 0 then
    case e of
      DBMEntry.Le a ⇒ if a = 0 then None else Some (show_clock j @ '' >= '' @ show_num (- a))
    | DBMEntry.Lt a ⇒ Some (show_clock j @ '' > ''  @ show_num (- a))
    | _ ⇒ None
    else if j = 0 then
    case e of
      DBMEntry.Le a ⇒ Some (show_clock i @ '' <= '' @ show_num a)
    | DBMEntry.Lt a ⇒ Some (show_clock i @ '' < ''  @ show_num a)
    | _ ⇒ None
    else
    case e of
      DBMEntry.Le a ⇒ Some (show_clock i @ '' - '' @ show_clock j @ '' <= '' @ show_num a)
    | DBMEntry.Lt a ⇒ Some (show_clock i @ '' - '' @ show_clock j @ '' < '' @ show_num a)
    | _ ⇒ None
"

(* Walk a row-major entry list and track the position (i, j). j wraps at
   n + 1, and i advances when j wraps; the let-bindings are sequential, so the
   new j is used. The non-None rendered strings are collected, restored to
   their original order and joined with ", ". *)
definition
  "dbm_list_to_string xs ≡
  (concat o intersperse '', '' o rev o snd o snd) $ fold (λe (i, j, acc).
    let
      v = make_string e i j;
      j = (j + 1) mod (n + 1);
      i = (if j = 0 then i + 1 else i)
    in
    case v of
      None ⇒ (i, j, acc)
    | Some s ⇒ (i, j, s # acc)
  ) xs (0, 0, [])
"

(* dbm_list_to_string is used as-is in the generated code. *)
lemma [sepref_import_param]:
  "(dbm_list_to_string, PR_CONST dbm_list_to_string) ∈ ⟨Id⟩list_rel → ⟨Id⟩list_rel"
  by simp

(* Pretty-print a DBM: flatten it with dbm_to_list, then render the list. *)
definition show_dbm where
  "show_dbm M ≡ PR_CONST dbm_list_to_string (dbm_to_list M)"

sepref_register "PR_CONST local.dbm_list_to_string"
sepref_register dbm_to_list :: "'b i_mtx ⇒ 'b list"

lemmas [sepref_fr_rules] = dbm_to_list_impl.refine

sepref_definition show_dbm_impl is
  "RETURN o show_dbm" :: "mtx_assn⇧k →⇩a list_assn id_assn"
  unfolding show_dbm_def by sepref

end (* Context for show functions *)

end (* Context for importing n *)

end (* Context for DBM dimension n *)


section ‹Generate Code›

(* Code equation for dbm_le: equal, or strictly smaller in the DBM order (≺). *)
lemma [code]:
  "dbm_le a b = (a = b ∨ (a ≺ b))"
unfolding dbm_le_def by auto

(* Check that the refined implementations, together with dbm_subset, compile
   in both the SML and the SML_imp code generation targets. *)
export_code
  norm_upd_impl
  reset_canonical_upd_impl
  up_canonical_upd_impl
  dbm_subset_impl
  dbm_subset
  show_dbm_impl
checking SML SML_imp

end