(* Theory Regex_Equivalence *)

section ‹ Key algorithms for WEST ›

theory Regex_Equivalence

imports WEST_Algorithms WEST_Proofs

begin

(* Weighted size of a state regex: a One or Zero bit costs 1, an S bit
   costs 2 plus twice the weight of the tail.  With this weighting both
   Zero#T and One#T weigh strictly less than S#T. *)
fun depth_dataype_list:: "state_regex ⇒ nat"
  where "depth_dataype_list [] = 0"
  | "depth_dataype_list (One#T) = 1 + depth_dataype_list T"
  | "depth_dataype_list (Zero#T) = 1 + depth_dataype_list T"
  | "depth_dataype_list (S#T) = 2 + 2*(depth_dataype_list T)"


(* Expands a state regex into the list of all states without S bits that
   it covers: One and Zero bits are kept as they are, and each S bit is
   split into the Zero branch followed by the One branch. *)
function enumerate_list:: "state_regex ⇒ trace_regex"
  where "enumerate_list [] = [[]]"
  | "enumerate_list (One#T) =  (map (λx. One#x) (enumerate_list T))"
  | "enumerate_list (Zero#T) =  (map (λx. Zero#x) (enumerate_list T))"
  | "enumerate_list (S#T) =  (enumerate_list (Zero#T))@(enumerate_list (One#T))"
  (* first goal: pattern completeness; remaining goals: compatibility *)
  apply (metis WEST_and_bitwise.elims list.exhaust)
  by simp_all
(* The S equation recurses on lists of the same length as its argument, so
   termination uses the weighted measure depth_dataype_list, not length. *)
termination  apply (relation "measure (λL. depth_dataype_list L)")
  by simp_all



(* Concatenates a list of lists into a single list, keeping the order of
   the sublists and of their elements. *)
fun flatten_list:: "'a list list ⇒ 'a list"
  where "flatten_list L = foldr (@) L []"

(* Sanity check: flattening two lists of naturals. *)
value "flatten_list [[12, 13::nat], [15]]"

(* Sanity check: prepend every enumerated head state of [S, One] to the
   single empty tail and flatten the result. *)
value "flatten_list (let enumerate_H = enumerate_list [S, One] in
let enumerate_T = [[]] in
map (λt. (map (λh. h#t) enumerate_H)) enumerate_T)"


(* Expands a trace regex into the list of all traces obtained by choosing,
   for every timestep, one of the states enumerated by enumerate_list.
   For each enumerated tail, all enumerated heads are prepended in turn. *)
fun enumerate_trace:: "trace_regex ⇒ WEST_regex"
  where "enumerate_trace [] = [[]]"
  | "enumerate_trace (H#T) = flatten_list
  (let enumerate_H = enumerate_list H in
   let enumerate_T = enumerate_trace T in
   map (λt. (map (λh. h#t) enumerate_H)) enumerate_T)"

(* Sanity checks: a three-step trace regex containing S bits, and a trace
   regex consisting of one empty state. *)
value "enumerate_trace [[S, One], [S], [One]]"
value "enumerate_trace [[]]"

(* The set of all enumerated traces of a WEST regex: the union of the
   enumerations of its member trace regexes. *)
fun enumerate_sets:: "WEST_regex ⇒ trace_regex set"
  where "enumerate_sets [] = {}"
  | "enumerate_sets (h#T) = (set (enumerate_trace h)) ∪ (enumerate_sets T)"

(* Two WEST regexes are considered equivalent when they are syntactically
   equal or when their sets of enumerated traces coincide. *)
fun naive_equivalence:: "WEST_regex ⇒ WEST_regex ⇒ bool"
  where "naive_equivalence A B = (A = B ∨ (enumerate_sets A) = (enumerate_sets B))"


section ‹ Regex Equivalence Correctness ›

(* Every state produced by enumerate_list has the same length as the state
   regex it was expanded from (bounded-quantifier formulation).
   Proof: induction on the state regex, with a case split on the head bit. *)
lemma enumerate_list_len_alt:
  shows "∀ state ∈ set (enumerate_list state_regex).
         length state = length state_regex"
proof(induct state_regex)
  case Nil
  then show ?case by simp
next
  case (Cons a state_regex)
  {assume zero: "a = Zero"
    then have "∀ state ∈ set (enumerate_list state_regex).
         length state = length state_regex"
      using Cons by blast
    then have ?case unfolding zero
      by simp
  } moreover {
    assume one: "a = One"
    then have "∀ state ∈ set (enumerate_list state_regex).
         length state = length state_regex"
      using Cons by blast
    then have ?case unfolding one
      by simp
  } moreover {
    assume s: "a = S"
    then have "∀ state ∈ set (enumerate_list state_regex).
         length state = length state_regex"
      using Cons by blast
    then have ?case unfolding s by auto
  }
  ultimately show ?case
    using WEST_bit.exhaust by blast
qed


(* Rule form of enumerate_list_len_alt: any member of the enumeration has
   the length of the original state regex. *)
lemma enumerate_list_len:
  assumes "state ∈ set (enumerate_list state_regex)"
  shows "length state = length state_regex"
  using assms enumerate_list_len_alt by blast


(* A state regex without any S bit enumerates to just itself. *)
lemma enumerate_list_prop:
  assumes "(∀k. List.member j k ⟶ k ≠ S)"
  shows "enumerate_list j = [j]"
  using assms
proof (induct j)
  case Nil
  then show ?case by auto
next
  case (Cons h t)
  (* the tail also contains no S, so the induction hypothesis applies *)
  then have elt: "enumerate_list t = [t]"
    by (simp add: member_rec(1))
  then have "h = One ∨ h = Zero"
    using Cons
    by (meson WEST_bit.exhaust member_rec(1))
  then show ?case using enumerate_list.simps(2-3) elt
    by fastforce
qed


(* A trace regex none of whose states contains an S bit enumerates to just
   itself. *)
lemma enumerate_fixed_trace:
  fixes h1:: "trace_regex"
  assumes "∀j. List.member h1 j ⟶ (∀k. List.member j k ⟶ k ≠ S)"
  shows "(enumerate_trace h1) = [h1]"
  using assms
proof (induct h1)
  case Nil
  then show ?case by auto
next
  case (Cons h t)
  then have ind: "enumerate_trace t = [t]"
    by (meson member_rec(1))
  (* the head state has no S bit either, so it enumerates to itself *)
  have "enumerate_list h = [h]"
    using enumerate_list_prop Cons
    by (meson member_rec(1))
  then show ?case
    using Cons ind unfolding enumerate_trace.simps
    by auto
qed

text ‹ If two trace regexes contain no S's and have different sets of
   states, then enumerate trace yields different sets for them. ›
lemma enum_trace_prop:
  fixes h1 h2:: "trace_regex"
  assumes "∀j. List.member h1 j ⟶ (∀k. List.member j k ⟶ k ≠ S)"
  assumes "∀j. List.member h2 j ⟶ (∀k. List.member j k ⟶ k ≠ S)"
  assumes "(set h1) ≠ (set h2)"
  shows "set (enumerate_trace h1) ≠ set (enumerate_trace h2)"
  (* both enumerations are singletons by enumerate_fixed_trace *)
  using enumerate_fixed_trace[of h1] enumerate_fixed_trace[of h2] assms
  by auto

(* If head_t#tail_t is enumerated from h#trace, then tail_t is enumerated
   from trace.  Proof: case split on the bit h. *)
lemma enumerate_list_tail_in:
  assumes "head_t#tail_t ∈ set (enumerate_list (h#trace))"
  shows "tail_t ∈ set (enumerate_list trace)"
proof-
  {assume one: "h = One"
    have ?thesis
      using assms unfolding one enumerate_list.simps by auto
  } moreover {
    assume zero: "h = Zero"
    have ?thesis
      using assms unfolding zero enumerate_list.simps by auto
  } moreover {
    assume s: "h = S"
    have ?thesis
      using assms unfolding s enumerate_list.simps by auto
  }
  ultimately show ?thesis using WEST_bit.exhaust by blast
qed

(* No state produced by enumerate_list contains an S bit. *)
lemma enumerate_list_fixed:
  assumes "t ∈ set (enumerate_list trace)"
  shows "(∀k. List.member t k ⟶ k ≠ S)"
  using assms
proof (induct trace arbitrary: t)
  case Nil
  then show ?case using member_rec(2) by force
next
  case (Cons h trace)
  (* t has the length of h#trace, hence it is non-empty *)
  obtain head_t tail_t where obt: "t = head_t#tail_t"
    using Cons.prems enumerate_list_len
    by (metis length_0_conv neq_Nil_conv)
  have "tail_t ∈ set (enumerate_list trace)"
    using enumerate_list.simps obt Cons.prems enumerate_list_tail_in by blast
  then have hyp: "∀k. List.member tail_t k ⟶ k ≠ S"
    using Cons.hyps(1)[of tail_t] by auto
  {assume one: "h = One"
    then have "head_t = One"
      using obt Cons.prems unfolding enumerate_list.simps by auto
    then have ?case
      using hyp obt
      by (simp add: member_rec(1))
  } moreover {
    assume zero: "h = Zero"
    then have "head_t = Zero"
      using obt Cons.prems unfolding enumerate_list.simps by auto
    then have ?case
      using hyp obt
      by (simp add: member_rec(1))
  } moreover {
    assume s: "h = S"
    then have "head_t = Zero ∨ head_t = One"
      using obt Cons.prems unfolding enumerate_list.simps by auto
    then have ?case
      using hyp obt
      by (metis calculation(1) calculation(2) member_rec(1) s)
  }
  ultimately show ?case using WEST_bit.exhaust by blast
qed


(* Prepending the enumerated states of any state regex to a fixed tail
   never yields the empty list, i.e. enumerate_list is never empty. *)
lemma map_enum_list_nonempty:
  fixes t::"WEST_bit list list"
  fixes head::"WEST_bit list"
  shows "map (λh. h # t) (enumerate_list head) ≠ []"
proof(induct head arbitrary: t)
  case Nil
  then show ?case by simp
next
  case (Cons a head)
  {assume a: "a = One"
    then have ?case unfolding a enumerate_list.simps
      using Cons by auto
  } moreover {
    assume a: "a = Zero"
    then have ?case unfolding a enumerate_list.simps
      using Cons by auto
  } moreover {
    assume a: "a = S"
    then have ?case unfolding a enumerate_list.simps
      using Cons by auto
  }
  ultimately show ?case using WEST_bit.exhaust by blast
qed



(* Flattening the list that, for every t in T, prepends each element of H
   to t has length  length T * length H.  Stated directly on the foldr
   form (the unfolded right-hand side of flatten_list). *)
lemma length_of_flatten_list:
assumes "flat =
  foldr (@)
   (map (λt. map (λh. h # t) H) T) []"
shows " length flat = length T * length H"
  using assms
proof (induct T arbitrary: flat)
  case Nil
  then show ?case by auto
next
  case (Cons t1 T2)
  then have "flat = foldr (@)
     (map (λt. map (λh. h # t) H) (t1 # T2)) []"
    by auto
  (* unfold the outer map by one step *)
  then have "flat = foldr (@)
    (map (λh. h # t1) H #(map (λt. map (λh. h # t) H) T2)) []"
    by auto
  (* unfold foldr by one step: first block appended to the flattened rest *)
  then have "flat = map (λh. h # t1) H @ (foldr (@) (map (λt. map (λh. h # t) H) T2)) []"
    by simp
  (* the first block has length H, the rest is covered by the hypothesis *)
  then have "length flat = length H + length (T2) * length H"
    using Cons by auto
  then show ?case by simp
qed


(* Index characterisation of the flattened product: the element built from
   head!j and tail!i sits at position i * length head + j of flat, and
   that position is within bounds.
   Proof: induction on tail; i = 0 lands in the first block, i > 0 is
   shifted by one block and handled by the induction hypothesis. *)
lemma flatten_list_idx:
  assumes "flat = flatten_list (map (λt. map (λh. h # t) head) tail)"
  assumes "i < length tail"
  assumes "j < length head"
  shows "(head!j)#(tail!i) = flat!(i*(length head) + j) ∧ i*(length head) + j < length flat"
  using assms
proof(induct tail arbitrary: head i j flat)
  case Nil
  then show ?case
    by auto
next
  case (Cons a tail)
  let ?flat = "flatten_list (map (λt. map (λh. h # t) head) tail)"
  (* trivial equation used to instantiate the hypothesis with ?flat *)
  have cond1: "?flat = ?flat" by auto
  have equiv: "(map (λt. map (λh. h # t) head) (a # tail)) =
      (map (λh. h # a) head) # (map (λt. map (λh. h # t) head) tail)"
      by auto
  then have flat_is: "flat = (map (λh. h # a) head) @ flatten_list (map (λt. map (λh. h # t) head) tail)"
    using Cons(2) unfolding flatten_list.simps by simp

  {assume i0: "i = 0"
    then have bound: "i * length head + j < length flat"
      using Cons by simp
     have "length (map (λh. h # a) head) > j"
      using Cons(4) by auto
    then have "(map (λh. h # a) head) ! j = flat ! j"
      using flat_is
      by (simp add: nth_append)
    then have "(head ! j)#a  = flat ! j"
      using Cons(4) by simp
    then have "head ! j # (a # tail) ! i = flat ! (i * length head + j)"
      unfolding i0 by simp
    then have ?case using bound by auto
  } moreover {
    assume i_ge_0: "i > 0"
    have len_flat: "length flat = length head * length (a # tail)"
     using Cons(3-4) length_of_flatten_list[of flat head "a#tail"]
      Cons(2) unfolding flatten_list.simps
     by simp
   have "i * length head ≤ (length (a # tail) - 1)*length head"
     using Cons(3) by auto
   then have "i * length head ≤ (length (a # tail))*length head - length head"
     by auto
   then have "i * length head + j < (length (a # tail))*length head - length head + length head"
     using Cons(4) by linarith
   then have "i * length head + j < (length (a # tail))*length head"
     by auto
    then have bound: "i * length head + j < length flat"
      using len_flat
      by (simp add: mult.commute)
    have i_minus: " i - 1 < length tail"
      using i_ge_0 Cons(3)
      by auto
    (* split off one full block of length head from the index *)
    have "flat ! (i * length head + j) = flat ! ((i-1) * length head + j + length head)"
      using i_ge_0
      by (smt (z3) add.commute bot_nat_0.not_eq_extremum group_cancel.add1 mult_eq_if)
    then have "flat ! (i * length head + j) = flatten_list
     (map (λt. map (λh. h # t) head) tail) !
    ((i - 1) * length head + j)"
      using flat_is
      by (smt (verit, ccfv_threshold) add.commute length_map nth_append_length_plus)
    then have  "flat ! (i * length head + j) = head ! j # tail ! (i - 1)"
          using Cons.hyps[OF cond1 i_minus Cons(4)]
          by argo
    then have access: "head ! j # (a # tail) ! i =
    flat ! (i * length head + j)"
      using i_ge_0
      by simp
    have ?case
      using bound access
      by auto
  }
  ultimately show ?case by blast
qed


(* Every member of the flattened product is a cons cell whose head is a
   member of H and whose tail is a member of T.
   Proof: pick the index k of x1 in flat, decompose it as
   k = i * length H + j, and apply flatten_list_idx. *)
lemma flatten_list_shape:
  assumes "List.member flat x1"
  assumes "flat = flatten_list (map (λt. map (λh. h # t) H) T)"
  shows "∃ x1_head x1_tail. x1 = x1_head#x1_tail ∧ List.member H x1_head ∧ List.member T x1_tail"
  using assms
proof(induction T arbitrary: flat H)
  case Nil
  have "flat = (flatten_list (map (λt. map (λh. h # t) H) []))"
    using Nil(1) unfolding Nil by blast
  then have "flat = []"
    by simp
  then show ?case
    using Nil
    by (simp add: member_rec(2))
next
  case (Cons a T)
  have "∃k. x1 = flat ! k ∧ k < length flat"
     using Cons(2)
     by (metis in_set_conv_nth member_def)
  then obtain k where k_is: "x1 = flat ! k ∧ k < length flat"
    by auto
  have len_flat: "length flat = (length (a#T)*length H)"
    using Cons(3) length_of_flatten_list
    by auto
  let ?j = "k mod (length H)"
  have "∃i . k = (i*(length H)+?j)"
    by (meson mod_div_decomp)
  then obtain i where i_is: "k = (i*(length H)+?j)"
    by auto
  then have i_lt: "i < length (a#T)"
    using len_flat k_is
    by (metis add_lessD1 mult_less_cancel2)
  have j_lt: "?j < length H"
    by (metis k_is len_flat length_0_conv length_greater_0_conv mod_by_0 mod_less_divisor mult_0_right)
  have "∃i < length (a # T). k = (i*(length H)+?j)"
    using i_is i_lt by blast
  then have "∃i < length (a # T). ∃j < length H. k = (i*(length H)+j)"
    using j_lt by blast
  then obtain i j where ij_props: "i < length (a#T)" "j < length H" "k = (i*(length H)+j)"
    by blast
  then have "flat ! k =  H ! j # (a # T) ! i"
    using flatten_list_idx[OF Cons(3) ij_props(1) ij_props(2) ]
      Cons(2) k_is ij_props(3)
    by argo
  then obtain x1_head x1_tail where "x1 = x1_head#x1_tail"
  and "List.member H x1_head" and "List.member (a#T) x1_tail"
    using ij_props
    by (simp add: index_of_L_in_L k_is)
  then show ?case
    using Cons(3) by simp
qed


(* If every member of T has length n, then every member of the flattened
   product of H and T has length n+1.
   NOTE(review): the meta-level form (⋀ / ⟹) of the first assumption and
   of the conclusion is reconstructed from how the case facts Cons(2..4)
   and the rule's premises are used; confirm against the checked theory. *)
lemma flatten_list_len:
  assumes "⋀t. List.member T t ⟹ length t = n"
  assumes "flat = flatten_list (map (λt. map (λh. h # t) H) T)"
  shows "⋀x1. List.member flat x1 ⟹ length x1 = n+1"
  using assms
proof(induction T arbitrary: flat n H)
  case Nil
  have "flat = (flatten_list (map (λt. map (λh. h # t) H) []))"
    using Nil(1) unfolding Nil(3) by blast
  then have "flat = []"
    by simp
  then show ?case
    using Nil by (simp add: member_rec(2))
next
  case (Cons a T)
  have "∃k. x1 = flat ! k ∧ k < length flat"
     using Cons(2)
     by (metis in_set_conv_nth member_def)
  then obtain k where k_is: "x1 = flat ! k ∧ k < length flat"
    by auto
  have len_flat: "length flat = (length (a#T)*length H)"
    using Cons(4) length_of_flatten_list
    by auto
  let ?j = "k mod (length H)"
  have "∃i . k = (i*(length H)+?j)"
    by (meson mod_div_decomp)
  then obtain i where i_is: "k = (i*(length H)+?j)"
    by auto
  then have i_lt: "i < length (a#T)"
    using len_flat k_is
    by (metis add_lessD1 mult_less_cancel2)
  have j_lt: "?j < length H"
    by (metis k_is len_flat length_0_conv length_greater_0_conv mod_by_0 mod_less_divisor mult_0_right)
  have "∃i < length (a # T). k = (i*(length H)+?j)"
    using i_is i_lt by blast
  then have "∃i < length (a # T). ∃j < length H. k = (i*(length H)+j)"
    using j_lt by blast
  then obtain i j where ij_props: "i < length (a#T)" "j < length H" "k = (i*(length H)+j)"
    by blast
  then have "flat ! k =  H ! j # (a # T) ! i"
    using flatten_list_idx[OF Cons(4) ij_props(1) ij_props(2) ]
      Cons(2) k_is ij_props(3)
    by argo
  then obtain x1_head x1_tail where "x1 = x1_head#x1_tail"
  and "List.member H x1_head" and "List.member (a#T) x1_tail"
    using ij_props
    by (simp add: index_of_L_in_L k_is)
  then show ?case
    using Cons(3) by simp
qed


(* If all elements of all sublists of to_flatten have the length of trace,
   then so does every element of the flattened list.
   NOTE(review): the meta-level form (⋀ / ⟹) of the first assumption is
   reconstructed from the instantiation Cons(2)[of h] and from the shape
   of ind_h_setup; confirm against the checked theory. *)
lemma flatten_list_lemma:
  assumes "⋀x1. List.member to_flatten x1 ⟹ (⋀x2. List.member x1 x2 ⟹ length x2 = length trace)"
  assumes "a ∈ set (flatten_list to_flatten)"
  shows "length a = length trace"
  using assms proof (induct to_flatten)
  case Nil
  then show ?case by auto
next
  case (Cons h t)
   (* a comes either from the first sublist or from the flattened rest *)
   have a_in: "a ∈ set h ∨ a ∈ set (flatten_list t)"
     using Cons(3) unfolding flatten_list.simps foldr_def by simp
  {assume *: "a ∈ set h"
    then have ?case
      using Cons(2)[of h]
      by (simp add: in_set_member member_rec(1))
  } moreover {assume *: "a ∈ set (flatten_list t)"
    have ind_h_setup: "(⋀x1 x2. List.member t x1 ⟹ List.member x1 x2 ⟹
        length x2 = length trace)"
      using Cons(2) by (meson member_rec(1))
    have " a ∈ set (flatten_list t) ⟹ length a = length trace"
      using Cons(1) ind_h_setup
      by auto
    then have ?case
      using * by auto
  }
  ultimately show ?case
    using a_in by blast
qed


(* Every trace produced by enumerate_trace has the same number of
   timesteps as the trace regex it was expanded from.
   Proof: induction on the length of the trace regex, using
   flatten_list_len for the step.
   NOTE(review): the form of all_w (⋀ / ⟹) is reconstructed to match the
   first premise of flatten_list_len; confirm against the checked theory. *)
lemma enumerate_trace_len:
  assumes "a ∈ set (enumerate_trace trace)"
  shows "length a = length trace"
  using assms
proof(induct "length trace" arbitrary: trace a)
  case 0
  then show ?case by auto
next
  case (Suc x)
  then obtain h t where trace_is: "trace = h#t"
    by (meson Suc_length_conv)
  obtain i where "(enumerate_trace trace)!i = a"
    using Suc.prems
    by (meson in_set_conv_nth)
  let ?enumerate_H = "enumerate_list h"
  let ?enumerate_t = "enumerate_trace t"
  have enum_tr_is: "enumerate_trace trace =
      flatten_list (map (λt. map (λh. h # t) ?enumerate_H) ?enumerate_t)"
    using trace_is by auto
  let ?to_flatten = "map (λt. map (λh. h # t) ?enumerate_H) ?enumerate_t"

  have all_w: "(⋀w. List.member (enumerate_trace t) w ⟹ length w = length t)"
    using Suc(1)[of t] Suc(2) trace_is
    by (simp add: in_set_member)
  have a_mem: "List.member (enumerate_trace trace) a"
    using Suc(3) in_set_member by fast
  show ?case
    using flatten_list_len[OF _ enum_tr_is a_mem, of "length t"] all_w
    trace_is by simp
qed

(* Holds when no state of the trace regex contains an S bit, i.e. every
   bit is Zero or One. *)
definition regex_zeros_and_ones:: "trace_regex ⇒ bool"
  where "regex_zeros_and_ones tr =
    (∀j. List.member tr j ⟶ (∀k. List.member j k ⟶ k ≠ S))"


(* If an enumerated state starts with a fixed bit a_head, then the state
   regex it came from starts with the same bit or with S. *)
lemma match_enumerate_state_aux_first_bit:
  assumes "a_head = Zero ∨ a_head = One"
  assumes "a_head # a_tail ∈ set (enumerate_list (h_head # h))"
  shows "h_head = a_head ∨ h_head = S"
proof-
  {assume h_head: "h_head = One"
    then have ?thesis
      using assms unfolding h_head enumerate_list.simps by auto
  } moreover {
    assume h_head: "h_head = Zero"
    then have ?thesis
      using assms unfolding h_head enumerate_list.simps by auto
  } moreover {
    assume "h_head = S"
    then have ?thesis by auto
  }
  ultimately show ?thesis using WEST_bit.exhaust by blast
qed

(* For a positive index x, membership of x in a state is equivalent to
   membership of x-1 in the advanced state. *)
lemma advance_state_iff:
  assumes "x > 0"
  shows "x ∈ state ⟷ (x-1) ∈ advance_state state"
proof-
  have forward: "x ∈ state ⟶ (x-1) ∈ advance_state state"
    using assms by auto
  have converse: "(x-1) ∈ advance_state state ⟶ x ∈ state"
    unfolding advance_state.simps using assms
    by (smt (verit, best) Suc_diff_1 diff_0_eq_0 diff_Suc_1' diff_self_eq_0 less_one mem_Collect_eq nat.distinct(1) not0_implies_Suc not_gr_zero old.nat.exhaust)
  show ?thesis using forward converse by blast
qed

(* If a state matches one of the enumerated states of a state regex h,
   then it matches h itself.
   Proof: induction on h.  The case split is on whether index 0 is in the
   state; this fixes the first bit of a, hence restricts the first bit of
   h, and the remaining indices are shifted by one via advance_state. *)
lemma match_enumerate_state_aux:
  assumes "a ∈ set (enumerate_list h)"
  assumes "match_timestep state a"
  shows "match_timestep state h"
  using assms
proof(induct h arbitrary: state a)
  case Nil
  have "a = []"
    using Nil by auto
  then show ?case using Nil by blast
next
  case (Cons h_head h)
  then obtain a_head a_tail where obt: "a = a_head#a_tail"
    using enumerate_list_len Cons
    by (metis length_0_conv list.exhaust)
  let ?adv_state = "advance_state state"
  {assume in_state: "0 ∈ state"
    (* a has no S bit and matches at index 0, so its first bit is One *)
    then have "a_head = One"
      using Cons.prems(2) unfolding obt match_timestep_def
      using enumerate_list_fixed
      by (metis WEST_bit.exhaust Cons(2) length_pos_if_in_set list.set_intros(1) member_rec(1) nth_Cons_0 obt)
    then have h_head: "h_head = One ∨ h_head = S"
      using Cons.prems(1) unfolding obt
      using match_enumerate_state_aux_first_bit by blast
    have match_adv: "match_timestep (advance_state state) h"
      using Cons.hyps[of a_tail ?adv_state]
      using Cons.prems(1) Cons.prems(2) advance_state_match_timestep enumerate_list_tail_in obt by blast
    have "⋀x. x<length (h_head # h) ⟹
       ((h_head # h) ! x = One ⟶ x ∈ state) ∧
       ((h_head # h) ! x = Zero ⟶ x ∉ state)"
    proof-
      fix x
      assume x: "x<length (h_head # h)"
      let ?thesis = "((h_head # h) ! x = One ⟶ x ∈ state) ∧
       ((h_head # h) ! x = Zero ⟶ x ∉ state)"
      {assume x0: "x = 0"
        then have ?thesis unfolding x0 using h_head in_state by auto
      } moreover {
        assume x_ge_0: "x > 0"
        then have "x-1 < length h"
          using x by simp
        then have *:"(h ! (x-1) = One ⟶ (x-1) ∈ advance_state state) ∧
                   (h ! (x-1) = Zero ⟶ (x-1) ∉ advance_state state)"
          using match_adv unfolding match_timestep_def by blast
        have "h ! (x-1) = (h_head # h) ! x" using x_ge_0 by auto
        then have *: "((h_head # h) ! x = One ⟶ (x-1) ∈ advance_state state) ∧
                   ((h_head # h) ! x = Zero ⟶ (x-1) ∉ advance_state state)"
          using * by argo
        then have ?thesis using advance_state_iff x_ge_0 by blast
      }
      ultimately show ?thesis by blast
    qed
    then have ?case
      using h_head unfolding match_timestep_def by blast
  } moreover {
    assume not_in: "0 ∉ state"
    (* symmetric case: the first bit of a must be Zero *)
    then have "a_head = Zero"
      using Cons.prems(2) unfolding obt match_timestep_def
      using enumerate_list_fixed
      by (metis WEST_bit.exhaust Cons(2) length_pos_if_in_set list.set_intros(1) member_rec(1) nth_Cons_0 obt)
    then have h_head: "h_head = Zero ∨ h_head = S"
      using Cons.prems(1) unfolding obt
      using match_enumerate_state_aux_first_bit by blast
    have match_adv: "match_timestep (advance_state state) h"
      using Cons.hyps[of a_tail ?adv_state]
      using Cons.prems(1) Cons.prems(2) advance_state_match_timestep enumerate_list_tail_in obt by blast
    have "⋀x. x<length (h_head # h) ⟹
       ((h_head # h) ! x = One ⟶ x ∈ state) ∧
       ((h_head # h) ! x = Zero ⟶ x ∉ state)"
    proof-
      fix x
      assume x: "x<length (h_head # h)"
      let ?thesis = "((h_head # h) ! x = One ⟶ x ∈ state) ∧
       ((h_head # h) ! x = Zero ⟶ x ∉ state)"
      {assume x0: "x = 0"
        then have ?thesis unfolding x0 using h_head not_in by auto
      } moreover {
        assume x_ge_0: "x > 0"
        then have "x-1 < length h"
          using x by simp
        then have *:"(h ! (x-1) = One ⟶ (x-1) ∈ advance_state state) ∧
                   (h ! (x-1) = Zero ⟶ (x-1) ∉ advance_state state)"
          using match_adv unfolding match_timestep_def by blast
        have "h ! (x-1) = (h_head # h) ! x" using x_ge_0 by auto
        then have *: "((h_head # h) ! x = One ⟶ (x-1) ∈ advance_state state) ∧
                   ((h_head # h) ! x = Zero ⟶ (x-1) ∉ advance_state state)"
          using * by argo
        then have ?thesis using advance_state_iff x_ge_0 by blast
      }
      ultimately show ?thesis by blast
    qed
    then have ?case
      using h_head unfolding match_timestep_def by blast
  }
  ultimately show ?case using WEST_bit.exhaust by blast
qed


(* In the enumeration of S#a, the One branch occupies the second half:
   entry j of enumerate_list a, prefixed with One, is found at offset
   length (enumerate_list a) + j, and that offset is within bounds. *)
lemma enumerate_list_index_one:
  assumes "j < length (enumerate_list a)"
  shows "One # enumerate_list a ! j = enumerate_list (S # a) ! (length (enumerate_list a) + j) ∧
    (length (enumerate_list a) + j < length (enumerate_list (S # a)))"
  using assms
proof(induct a arbitrary: j)
  case Nil
  then show ?case by auto
next
  case (Cons a1 a2)
  then show ?case unfolding enumerate_list.simps
    by (metis (mono_tags, lifting) length_append length_map nat_add_left_cancel_less nth_append_length_plus nth_map)
qed

(* Indexing into an appended list below the length of the first part
   reads from the first part. *)
lemma list_concat_index:
  assumes "j < length L1"
  shows "(L1@L2)!j = L1!j"
proof -
  from assms show ?thesis
    by (simp add: nth_append)
qed

(* In the enumeration of S#a, the Zero branch occupies the first half:
   entry j of enumerate_list a, prefixed with Zero, is found at the same
   index j, and that index is within bounds. *)
lemma enumerate_list_index_zero:
  assumes "j < length (enumerate_list a)"
  shows "Zero # enumerate_list a ! j = enumerate_list (S # a) ! j ∧
    j < length (enumerate_list (S # a))"
  using assms unfolding enumerate_list.simps
proof(induct a arbitrary: j)
  case Nil
  then show ?case by simp
next
  case (Cons a1 a2)
  then have j_bound: "j < length (enumerate_list (S # a1 # a2))"
    by simp
  let ?subgoal = "Zero # enumerate_list (a1 # a2) ! j = enumerate_list (S # a1 # a2) ! j"
  have "j < length (map ((#) Zero) (enumerate_list (a1 # a2)))"
    using j_bound Cons by simp
  (* j lies in the first (Zero) part of the appended enumeration *)
  then have "(((map ((#) Zero) (enumerate_list (a1 # a2)) @
     map ((#) One) (enumerate_list (a1 # a2)))) !
    j) = (map ((#) Zero) (enumerate_list (a1 # a2)))!j"
    using Cons.prems j_bound list_concat_index by blast
  then have ?subgoal using Cons unfolding enumerate_list.simps
    by simp
  then show ?case using j_bound by auto
qed


(* If a state matches a state regex a, then it matches at least one of the
   enumerated states of a.
   Proof: induction on a.  The case split is on whether index 0 is in the
   state, which determines the first bit (One or Zero) of the witness;
   the tail of the witness comes from the induction hypothesis applied to
   the advanced state.  When the head of the regex is S, the witness is
   located in the One half or the Zero half of the enumeration via
   enumerate_list_index_one / enumerate_list_index_zero. *)
lemma match_enumerate_list:
  assumes "match_timestep state a"
  shows "∃j<length (enumerate_list a).
         match_timestep state (enumerate_list a ! j)"
  using assms
proof(induct a arbitrary: state)
  case Nil
  then show ?case by simp
next
  case (Cons head a)
  let ?adv_state = "advance_state state"
  {assume in_state: "0 ∈ state"
    then have "(head # a) ! 0 ≠ Zero"
      using Cons.prems unfolding match_timestep_def by blast
    then have head: "head = One ∨ head = S"
      using WEST_bit.exhaust by auto
    have "match_timestep ?adv_state a"
      using Cons.prems
      using advance_state_match_timestep by auto
    then obtain j where obt: "match_timestep ?adv_state (enumerate_list a ! j)
                           ∧ j < length (enumerate_list a)"
      using Cons.hyps[of ?adv_state] by blast
    let ?state = "(enumerate_list a ! j)"
    {assume headcase: "head = One"
      let ?s = "One # ?state"
      have "⋀x. x<length ?s ⟹
       ((?s ! x = One ⟶ x ∈ state) ∧ (?s ! x = Zero ⟶ x ∉ state))"
      proof-
        fix x
        assume x: "x<length ?s"
        let ?thesis = "((?s ! x = One ⟶ x ∈ state) ∧ (?s ! x = Zero ⟶ x ∉ state))"
        {assume x0: "x = 0"
          then have ?thesis using in_state by simp
        } moreover {
          assume x_ge_0: "x > 0"
          have cond1: "(One = One ⟶ 0 ∈ state) ∧ (One = Zero ⟶ 0 ∉ state)"
            using in_state by blast
          have cond2: "∀x<length (enumerate_list a ! j).
       (enumerate_list a ! j ! x = One ⟶ x + 1 ∈ state) ∧
       (enumerate_list a ! j ! x = Zero ⟶ x + 1 ∉ state)"
            using obt unfolding match_timestep_def advance_state_iff by fastforce
          have "x<length (One # enumerate_list a ! j)"
            using x by blast
          then have ?thesis
            using index_shift[of "One" state ?state, OF cond1 cond2] by blast
        }
        ultimately show ?thesis by blast
      qed
      then have match: "match_timestep state ?s"
        using obt headcase in_state unfolding match_timestep_def by blast
      have "(map ((#) One) (enumerate_list a) ! j) = One # (enumerate_list a ! j)"
        using obt by simp
      then have ?case unfolding headcase enumerate_list.simps
        using match obt by auto
    } moreover {
      assume headcase: "head = S"
      let ?s = "One # ?state"
      have "⋀x. x<length ?s ⟹
       ((?s ! x = One ⟶ x ∈ state) ∧ (?s ! x = Zero ⟶ x ∉ state))"
      proof-
        fix x
        assume x: "x<length ?s"
        let ?thesis = "((?s ! x = One ⟶ x ∈ state) ∧ (?s ! x = Zero ⟶ x ∉ state))"
        {assume x0: "x = 0"
          then have ?thesis using in_state by simp
        } moreover {
          assume x_ge_0: "x > 0"
          have cond1: "(One = One ⟶ 0 ∈ state) ∧ (One = Zero ⟶ 0 ∉ state)"
            using in_state by blast
          have cond2: "∀x<length (enumerate_list a ! j).
       (enumerate_list a ! j ! x = One ⟶ x + 1 ∈ state) ∧
       (enumerate_list a ! j ! x = Zero ⟶ x + 1 ∉ state)"
            using obt unfolding match_timestep_def advance_state_iff by fastforce
          have "x<length (One # enumerate_list a ! j)"
            using x by blast
          then have ?thesis
            using index_shift[of "One" state ?state, OF cond1 cond2] by blast
        }
        ultimately show ?thesis by blast
      qed
      then have match: "match_timestep state ?s"
        using obt headcase in_state unfolding match_timestep_def by blast
      have "⋀x. x<length (S # enumerate_list a ! j) ⟹
       ((S # enumerate_list a ! j) ! x = One ⟶ x ∈ state) ∧
       ((S # enumerate_list a ! j) ! x = Zero ⟶ x ∉ state)"
      proof-
        fix x
        assume x: "x<length (S # enumerate_list a ! j)"
        let ?thesis = "((S # enumerate_list a ! j) ! x = One ⟶ x ∈ state) ∧
       ((S # enumerate_list a ! j) ! x = Zero ⟶ x ∉ state)"
        {assume x0: "x = 0"
          then have ?thesis by auto
        } moreover {
          assume x_ge_0: "x > 0"
          then have ?thesis using x match unfolding match_timestep_def by simp
        }
        ultimately show ?thesis by blast
      qed
      then have match_S: "match_timestep state (S # enumerate_list a ! j)"
        using match unfolding match_timestep_def by blast
      have j_bound: "j < length (enumerate_list a)"
        using obt by blast
      (* the witness lives in the One half of the enumeration of S#a *)
      have "?s = enumerate_list (S # a)!((length (enumerate_list a))+j)
           ∧ (length (enumerate_list a))+j < length (enumerate_list (S # a))"
        using j_bound enumerate_list_index_one by blast
      then have ?case unfolding headcase
        using match obt match_S by metis
    }
    ultimately have ?case using head by blast
  } moreover {
    assume not_in: "0 ∉ state"
    then have "(head # a) ! 0 ≠ One"
      using Cons.prems unfolding match_timestep_def by blast
    then have head: "head = Zero ∨ head = S"
      using WEST_bit.exhaust by auto
    have "match_timestep ?adv_state a"
      using Cons.prems
      using advance_state_match_timestep by auto
    then obtain j where obt: "match_timestep ?adv_state (enumerate_list a ! j)
                           ∧ j < length (enumerate_list a)"
      using Cons.hyps[of ?adv_state] by blast
    let ?state = "(enumerate_list a ! j)"
    {assume headcase: "head = Zero"
      let ?s = "Zero # ?state"
      have "⋀x. x<length ?s ⟹
       ((?s ! x = One ⟶ x ∈ state) ∧ (?s ! x = Zero ⟶ x ∉ state))"
      proof-
        fix x
        assume x: "x<length ?s"
        let ?thesis = "((?s ! x = One ⟶ x ∈ state) ∧ (?s ! x = Zero ⟶ x ∉ state))"
        {assume x0: "x = 0"
          then have ?thesis using not_in headcase by simp
        } moreover {
          assume x_ge_0: "x > 0"
          have cond1: "(Zero = One ⟶ 0 ∈ state) ∧ (Zero = Zero ⟶ 0 ∉ state)"
            using not_in by blast
          have cond2: "∀x<length (enumerate_list a ! j).
       (enumerate_list a ! j ! x = One ⟶ x + 1 ∈ state) ∧
       (enumerate_list a ! j ! x = Zero ⟶ x + 1 ∉ state)"
            using obt unfolding match_timestep_def advance_state_iff by fastforce
          have "x<length (Zero # enumerate_list a ! j)"
            using x by blast
          then have ?thesis
            using index_shift[of "Zero" state ?state, OF cond1 cond2] by blast
        }
        ultimately show ?thesis by blast
      qed
      then have match: "match_timestep state ?s"
        using obt headcase not_in unfolding match_timestep_def by blast
      have ?case unfolding headcase enumerate_list.simps
        using match obt by auto
    } moreover {
      assume headcase: "head = S"
      let ?s = "Zero # ?state"
      have "⋀x. x<length ?s ⟹
       ((?s ! x = One ⟶ x ∈ state) ∧ (?s ! x = Zero ⟶ x ∉ state))"
      proof-
        fix x
        assume x: "x<length ?s"
        let ?thesis = "((?s ! x = One ⟶ x ∈ state) ∧ (?s ! x = Zero ⟶ x ∉ state))"
        {assume x0: "x = 0"
          then have ?thesis using not_in by simp
        } moreover {
          assume x_ge_0: "x > 0"
          have cond1: "(Zero = One ⟶ 0 ∈ state) ∧ (Zero = Zero ⟶ 0 ∉ state)"
            using not_in by blast
          have cond2: "∀x<length (enumerate_list a ! j).
       (enumerate_list a ! j ! x = One ⟶ x + 1 ∈ state) ∧
       (enumerate_list a ! j ! x = Zero ⟶ x + 1 ∉ state)"
            using obt unfolding match_timestep_def advance_state_iff by fastforce
          have "x<length (Zero # enumerate_list a ! j)"
            using x by blast
          then have ?thesis
            using index_shift[of "Zero" state ?state, OF cond1 cond2] by blast
        }
        ultimately show ?thesis by blast
      qed
      then have match: "match_timestep state ?s"
        using obt headcase not_in unfolding match_timestep_def by blast
      have "⋀x. x<length (S # enumerate_list a ! j) ⟹
       ((S # enumerate_list a ! j) ! x = One ⟶ x ∈ state) ∧
       ((S # enumerate_list a ! j) ! x = Zero ⟶ x ∉ state)"
      proof-
        fix x
        assume x: "x<length (S # enumerate_list a ! j)"
        let ?thesis = "((S # enumerate_list a ! j) ! x = One ⟶ x ∈ state) ∧
       ((S # enumerate_list a ! j) ! x = Zero ⟶ x ∉ state)"
        {assume x0: "x = 0"
          then have ?thesis by auto
        } moreover {
          assume x_ge_0: "x > 0"
          then have ?thesis using x match unfolding match_timestep_def by simp
        }
        ultimately show ?thesis by blast
      qed
      then have match_S: "match_timestep state (S # enumerate_list a ! j)"
        using match unfolding match_timestep_def by blast
      have j_bound: "j < length (enumerate_list a)"
        using obt by blast
      (* the witness lives in the Zero half of the enumeration of S#a *)
      have "?s = enumerate_list (S # a)!(j)
           ∧ j < length (enumerate_list (S # a))"
        using j_bound enumerate_list_index_zero by blast
      then have ?case unfolding headcase
        using match obt match_S by metis
    }
    ultimately have ?case using head by blast
  }
  ultimately show ?case by blast
qed


(* The first state of any trace enumerated from h#trace is one of the
   states enumerated from h. *)
lemma enumerate_trace_head_in:
  assumes "a_head # a_tail ∈ set (enumerate_trace (h # trace))"
  shows " a_head ∈ set (enumerate_list h)"
proof -
    let ?flat = "flatten_list
     (map (λt. map (λh. h # t)
                 (enumerate_list h))
       (enumerate_trace trace))"
    (* trivial equation used to instantiate flatten_list_shape *)
    have flat_is: "?flat = ?flat"
      by auto
    have mem: "List.member
     ?flat
     (a_head # a_tail)"
      using assms unfolding enumerate_trace.simps
      using in_set_member by metis
    then obtain x1_head x1_tail where
      x1_props: "a_head # a_tail = x1_head # x1_tail ∧
     List.member (enumerate_list h) x1_head ∧
     List.member (enumerate_trace trace) x1_tail"
     using flatten_list_shape[OF mem flat_is]  by auto
   then have "a_head = x1_head"
     by auto
   then have "List.member (enumerate_list h) a_head "
     using x1_props
     by auto
   then show ?thesis
    using in_set_member
    by fast
qed


(* Tail projection for enumerate_trace: whenever the list a_head # a_tail
   occurs in enumerate_trace (h # trace), its remainder a_tail occurs in
   enumerate_trace trace.  The proof mirrors enumerate_trace_head_in and
   differs only in which component of the decomposition is kept. *)
lemma enumerate_trace_tail_in:
  assumes "a_head # a_tail  set (enumerate_trace (h # trace))"
  shows "a_tail  set (enumerate_trace trace)"
proof -
    (* ?flat is the right-hand side of the enumerate_trace equation for
       h # trace, written with its two let-bindings expanded. *)
    let ?flat = "flatten_list
     (map (λt. map (λh. h # t)
                 (enumerate_list h))
       (enumerate_trace trace))"
    (* Reflexive equation; it exists only to be supplied as the second
       premise of flatten_list_shape below. *)
    have flat_is: "?flat = ?flat"
      by auto
    (* Restate the assumption with List.member, the form in which it is
       passed to flatten_list_shape. *)
    have mem: "List.member
     ?flat
     (a_head # a_tail)"
      using assms unfolding enumerate_trace.simps
      using in_set_member by metis
    (* flatten_list_shape yields a head drawn from enumerate_list h and a
       tail drawn from enumerate_trace trace whose cons is the member. *)
    then obtain x1_head x1_tail where
      x1_props: "a_head # a_tail = x1_head # x1_tail 
     List.member (enumerate_list h) x1_head 
     List.member (enumerate_trace trace) x1_tail"
     using flatten_list_shape[OF mem flat_is]  by auto
   (* The cons equation in x1_props identifies a_tail with x1_tail. *)
   then have "a_tail = x1_tail"
     by auto
   then have "List.member (enumerate_trace trace) a_tail "
     using x1_props
     by auto
   (* Translate List.member back into set membership. *)
   then show ?thesis
    using in_set_member
    by fast
qed


text ‹ Intuitively, this says that the traces in enumerate trace h are
``more specific'' than h, which is ``more generic''---i.e., h
matches everything that each element of enumerate trace h matches. ›
(* If a is one of the entries of enumerate_trace trace and π satisfies
   match_regex against a, then π also satisfies match_regex against trace.
   Proved by induction on trace, generalising over a and π. *)
lemma match_enumerate_trace_aux:
  assumes "a  set (enumerate_trace trace)"
  assumes "match_regex π a"
  shows "match_regex π trace"
proof -
  show ?thesis using assms proof (induct trace arbitrary: a π)
    case Nil
    (* enumerate_trace [] is [[]], so the only candidate for a is []. *)
    then show ?case by auto
  next
    case (Cons h trace)
    (* a is non-empty (via enumerate_trace_len), so name its head and tail. *)
    then obtain a_head a_tail where obt_a: "a = a_head#a_tail"
      using enumerate_trace_len
      by (metis length_0_conv neq_Nil_conv)
    (* π is non-empty as well, so it can be split in the same way. *)
    have "length π > 0"
      using Cons unfolding match_regex_def obt_a by auto
    then obtain π_head π_tail where obt_π: "π = π_head#π_tail"
      using min_list.cases by auto
    (* cond1 and cond2 are exactly the two premises of the induction
       hypothesis, instantiated with the tails. *)
    have cond1: "a_tail  set (enumerate_trace trace)"
      using Cons.prems(1) unfolding obt_a
      using enumerate_trace_tail_in by blast
    have cond2: "match_regex π_tail a_tail"
      using Cons.prems(2) unfolding obt_a obt_π match_regex_def by auto
    have match_tail: "match_regex π_tail trace"
      using Cons.hyps[OF cond1 cond2] by blast
    (* Head step: a_head comes from enumerate_list h, and
       match_enumerate_state_aux lifts a timestep match on a_head to h. *)
    have a_head: "a_head  set (enumerate_list h)"
      using Cons.prems(1) unfolding obt_a
      using enumerate_trace_head_in by blast
    have "match_timestep π_head a_head"
      using Cons.prems(2) unfolding obt_π match_regex_def
      using obt_a by auto
    then have match_head: "match_timestep π_head h"
      using match_enumerate_state_aux[of a_head h π_head] a_head by blast
    (* Pointwise timestep match for every index of h # trace, split into
       index 0 (the head) and positive indices (the tail). *)
    have "time. time<length (h # trace) 
        match_timestep ((π_head # π_tail) ! time) ((h # trace) ! time)"
    proof-
      fix time
      assume time: "time<length (h # trace)"
      let ?thesis = "match_timestep ((π_head # π_tail) ! time) ((h # trace) ! time)"
      {assume time0: "time = 0"
        then have ?thesis using match_head by simp
      } moreover {
        assume time_ge_0: "time > 0"
        then have ?thesis
          using match_tail time_ge_0 time unfolding match_regex_def by simp
      }
      ultimately show ?thesis by blast
    qed
    (* Reassemble match_regex for the whole list from the pointwise fact and
       match_tail. *)
    then show ?case using match_tail unfolding match_regex_def obt_a obt_π
      by simp
  qed
qed


(* Lifts match_enumerate_trace_aux to a list of trace regexes: if π matches
   some entry a of enumerate_trace h, then π matches any list whose first
   element is h.  The tail T is arbitrary and places no constraint. *)
lemma match_enumerate_trace:
  assumes "a  set (enumerate_trace h)  match_regex π a"
  shows "match π (h # T)"
proof-
  show ?thesis
    unfolding match_def
    using match_enumerate_trace_aux assms
    by auto
qed


(* First direction of match_enumerate_sets: if π satisfies match_regex
   against some member of enumerate_sets R, then π matches R.
   Proved by induction on R. *)
lemma match_enumerate_sets1:
  assumes "(r  (enumerate_sets R). match_regex π r)"
  shows "(match π R)"
  using assms
proof (induct R)
  case Nil
  (* enumerate_sets [] is the empty set, so the premise cannot hold. *)
  then show ?case by simp
next
  case (Cons h T)
  (* Pick the witness a; by the enumerate_sets equation for h # T it lies in
     set (enumerate_trace h) or in enumerate_sets T. *)
  then obtain a where a_prop: "aset (enumerate_trace h)  enumerate_sets T  match_regex π a"
    by auto
  (* Case 1: the witness comes from the head h. *)
  { assume *: "a  set (enumerate_trace h)"
    then have ?case
      using match_enumerate_trace a_prop
      by blast
  (* Case 2: the witness comes from the tail T; apply the induction
     hypothesis and then extend the match from T to h # T. *)
  } moreover {assume *: "a  enumerate_sets T"
    then have "match π T"
      using Cons a_prop by blast
    then have ?case
      by (metis Suc_leI le_imp_less_Suc length_Cons match_def nth_Cons_Suc)
  }
  ultimately show ?case
    using a_prop by auto
qed

(* Case split for match on a non-empty list: a match of a # R is either a
   match of the singleton list [a] or a match of the remainder R. *)
lemma match_cases:
  assumes "match π (a # R)"
  shows "match π [a]  match π R"
proof-
  (* match_def provides an in-bounds index i whose entry π satisfies. *)
  obtain i where obt: "match_regex π ((a # R)!i)  i < length (a # R)"
    using assms unfolding match_def by blast
  (* Index 0 selects a itself. *)
  {assume i0: "i = 0"
    then have ?thesis
      using assms unfolding match_def using obt by simp
  } moreover {
    (* A positive index selects position i - 1 of R. *)
    assume i_ge_0: "i > 0"
    then have "match_regex π (R ! (i-1))"
      using assms obt unfolding match_def by simp
    then have "match π R"
      unfolding match_def using obt i_ge_0
      by (metis Suc_diff_1 Suc_less_eq length_Cons)
    then have ?thesis by blast
  }
  ultimately show ?thesis using assms unfolding match_def by blast
qed


(* Converse of the head/tail projection lemmas: a state taken from
   enumerate_list h consed onto a trace taken from enumerate_trace T is an
   entry of enumerate_trace (h # T). *)
lemma enumerate_trace_decompose:
  assumes "state  set (enumerate_list h)"
  assumes "trace  set (enumerate_trace T)"
  shows "state#trace  set (enumerate_trace (h#T))"
proof-
  let ?enumh = "enumerate_list h"
  let ?enumT = "enumerate_trace T"
  (* ?flat is the right-hand side of the enumerate_trace equation for h # T
     with the let-bindings expanded; enum records that equality. *)
  let ?flat = "flatten_list (map (λt. map (λh. h # t) ?enumh) ?enumT)"
  have enum: "enumerate_trace (h#T) = ?flat"
    unfolding enumerate_trace.simps by simp
  (* Turn both set memberships into in-bounds list positions i and j. *)
  obtain i where i: "?enumT!i = trace  i < length ?enumT"
    using assms(2) by (meson in_set_conv_nth)
  obtain j where j: "?enumh!j = state  j < length ?enumh"
    using assms(1) by (meson in_set_conv_nth)
  (* flatten_list_idx locates the combined element in the flattened list at
     position i * length (enumerate_list h) + j and shows that position is
     in bounds. *)
  have "enumerate_list h ! j # enumerate_trace T ! i =
    flatten_list (map (λt. map (λh. h # t) (enumerate_list h)) (enumerate_trace T)) !
    (i * length (enumerate_list h) + j) 
    i * length (enumerate_list h) + j
    < length
       (flatten_list
         (map (λt. map (λh. h # t) (enumerate_list h)) (enumerate_trace T)))"
    using flatten_list_idx[of ?flat ?enumh ?enumT i j] enum i j by blast
  (* An in-bounds position is a member of the list's set. *)
  then show ?thesis
    using i j enum by simp
qed


(* Converse of match_enumerate_trace_aux: if π satisfies match_regex against
   trace, then π matches the list enumerate_trace trace, i.e. some entry of
   that list is satisfied by π.  Proved by induction on trace, generalising
   over π. *)
lemma match_enumerate_trace_aux_converse:
  assumes "match_regex π trace"
  shows "match π (enumerate_trace trace)"
  using assms
proof(induct trace arbitrary: π)
  case Nil
  (* enumerate_trace [] is the singleton list [[]]. *)
  have enum: "enumerate_trace [] = [[]]"
    by simp
  show ?case unfolding enum match_def match_regex_def by auto
next
  case (Cons a trace)
  (* π is non-empty, so split it into its first state and the rest. *)
  have "length π > 0"
    using Cons.prems unfolding match_regex_def by auto
  then obtain pi_head pi_tail where pi_obt: "π = pi_head#pi_tail"
    using list.exhaust by auto
  (* Tail: the induction hypothesis gives an index i into
     enumerate_trace trace whose entry is satisfied by pi_tail. *)
  have cond: "match_regex pi_tail trace"
    using Cons.prems pi_obt unfolding match_regex_def by auto
  then have match_tail: "match pi_tail (enumerate_trace trace)"
    using Cons.hyps by blast
  then obtain i where obt_i: "match_regex pi_tail (enumerate_trace trace ! i) 
         i<length (enumerate_trace trace)"
    unfolding match_def by blast
  let ?enum_tail = "(enumerate_trace trace ! i)"

  (* Head: pi_head matches a, and match_enumerate_list gives an index j into
     enumerate_list a whose entry is matched by pi_head. *)
  have match_head: "match_timestep pi_head a"
    using Cons.prems unfolding match_regex_def
    by (metis Cons.prems WEST_and_trace_correct_forward_aux nth_Cons' pi_obt)
  then have "j < length (enumerate_list a).
             match_timestep pi_head ((enumerate_list a)!j)"
    using match_enumerate_list by blast
  then obtain j where obt_j: "match_timestep pi_head ((enumerate_list a)!j) 
                       j < length (enumerate_list a)"
    by blast
  let ?enum_head = "(enumerate_list a)!j"

  (* Facts about the chosen entries.  They carry their own names so that
     match_head and match_tail above keep referring to the facts about a and
     about the whole list enumerate_trace trace. *)
  have enum_match_tail: "match_regex pi_tail ?enum_tail"
    using obt_i by blast
  have enum_match_head: "match_timestep pi_head ((enumerate_list a)!j)"
    using obt_j by blast
  (* Glue head and tail into a match of the combined candidate. *)
  have match: "match_regex π (?enum_head#?enum_tail)"
    using enum_match_head enum_match_tail
    using WEST_and_trace_correct_forward_aux_converse[OF pi_obt enum_match_head enum_match_tail] by auto
  let ?flat = "flatten_list
     (map (λt. map (λh. h # t) (enumerate_list a))
       (enumerate_trace trace))"
  (* flatten_list_idx locates the candidate inside the flattened list at
     position i * length (enumerate_list a) + j and shows that position is
     in bounds; this is the index witness required by match_def. *)
  have "enumerate_list a ! j # enumerate_trace trace ! i =
  flatten_list
   (map (λt. map (λh. h # t) (enumerate_list a)) (enumerate_trace trace)) !
  (i * length (enumerate_list a) + j) 
  i * length (enumerate_list a) + j
  < length
     (flatten_list
       (map (λt. map (λh. h # t) (enumerate_list a)) (enumerate_trace trace)))"
    using flatten_list_idx[of ?flat "enumerate_list a" "enumerate_trace trace" i j]
    using obt_i obt_j by blast
  then show ?case
    unfolding match_def using match
    by auto
qed


(* Second direction of match_enumerate_sets: if π matches R, then some member
   of enumerate_sets R is satisfied by π under match_regex.
   Proved by induction on R, generalising over π. *)
lemma match_enumerate_sets2:
  assumes "(match π R)"
  shows "(r  enumerate_sets R. match_regex π r)"
  using assms
proof(induct R arbitrary: π)
  case Nil
  (* match_def needs an index below length [], so the premise cannot hold. *)
  then show ?case unfolding match_def by auto
next
  case (Cons a R)
  (* Case 1: the match is on the head a.  The converse lemma turns it into a
     match on the list enumerate_trace a, which yields a witness in
     set (enumerate_trace a) and hence in enumerate_sets (a # R). *)
  {assume match_a: "match π [a]"
    then have "match_regex π a"
      unfolding match_def by simp
    then have "match π (enumerate_trace a)"
      using match_enumerate_trace_aux_converse by blast
    then have "b  set (enumerate_trace a). match_regex π b"
      unfolding match_def by auto
    then have ?case by auto
  } moreover {
    (* Case 2: the match is on the tail R; use the induction hypothesis. *)
    assume match_R: "match π R"
    then have ?case
      using Cons by auto
  }
  (* match_cases shows that these two cases are exhaustive. *)
  ultimately show ?case
    using Cons.prems match_cases by blast
qed


(* Combines match_enumerate_sets1 and match_enumerate_sets2: π satisfies
   match_regex against some member of enumerate_sets R exactly when π
   matches R. *)
lemma match_enumerate_sets:
  shows "(r  enumerate_sets R. match_regex π r)  (match π R)"
  using match_enumerate_sets1 match_enumerate_sets2
  by blast

(* Pointwise form of correctness: if naive_equivalence accepts A and B, then
   every π matches A exactly when it matches B.  Both sides are rewritten via
   match_enumerate_sets into statements about enumerate_sets, which the
   unfolded definition of naive_equivalence relates. *)
lemma regex_equivalence_correct1:
  assumes "(naive_equivalence A B)"
  shows "match π A = match π B"
  using match_enumerate_sets[of A π] match_enumerate_sets[of B π]
  using assms
  unfolding naive_equivalence.simps
  by blast


(* Main correctness statement of this section: naive_equivalence A B implies
   regex_equiv A B.  Obtained from regex_equivalence_correct1 after unfolding
   regex_equiv_def (regex_equiv is defined outside this part of the file). *)
lemma regex_equivalence_correct:
  shows "(naive_equivalence A B)  (regex_equiv A B)"
  using regex_equivalence_correct1
  unfolding regex_equiv_def
  by metis


(* Ask the code generator to produce Haskell code for naive_equivalence,
   using regex_equiv as the requested module name. *)
export_code naive_equivalence in Haskell module_name regex_equiv

end