(* Theory Automatic_Refinement.Param_HOL *)

section ‹Parametricity Theorems for HOL›
theory Param_HOL
imports Param_Tool
begin

subsection ‹Sets›

(* Parametricity of the empty set: {} is related to {} under set_rel.
   Proved by auto with set_rel_def as a simp rule. *)
lemma param_empty[param]:
  "({},{})Rset_rel" by (auto simp: set_rel_def)

(* Parametricity of set membership (∈) with result relation bool_rel.
   Side conditions: R and its converse (R¯) are both single-valued.
   The proof unfolds set_rel_def and uses single_valuedD as a destruction rule. *)
lemma param_member[param]:
  "single_valued R; single_valued (R¯)  ((∈), (∈))  R  Rset_rel  bool_rel"  
  unfolding set_rel_def
  by (blast dest: single_valuedD)

    
(* Parametricity of insert: an element related by R and a set related by
   set_rel yield sets related by set_rel. No side conditions on R. *)
lemma param_insert[param]:
  "(insert,insert)RRset_relRset_rel"
  by (auto simp: set_rel_def)

(* Parametricity of binary union (∪) over set_rel. No side conditions on R. *)
lemma param_union[param]:
  "((∪), (∪))  Rset_rel  Rset_rel  Rset_rel"
  by (auto simp: set_rel_def)

(* Parametricity of binary intersection (∩) over set_rel.
   Unlike union, this needs R and its converse (R¯) to be single-valued;
   the assumptions are fed to blast together with single_valuedD. *)
lemma param_inter[param]:
  assumes "single_valued R" "single_valued (R¯)"
  shows "((∩), (∩))  Rset_rel  Rset_rel  Rset_rel"
  using assms  
  unfolding set_rel_def
  by (blast dest: single_valuedD)

(* Parametricity of set difference (-) over set_rel.
   Requires R and its converse (R¯) to be single-valued, as for intersection. *)
lemma param_diff[param]:
  assumes "single_valued R" "single_valued (R¯)"
  shows "((-), (-))  Rset_rel  Rset_rel  Rset_rel"
  using assms 
  unfolding set_rel_def
  by (blast dest: single_valuedD)
    
(* Parametricity of the subset-or-equal test (⊆) with result relation bool_rel.
   Side conditions: R and its converse (R¯) are single-valued. *)
lemma param_subseteq[param]: 
  "single_valued R; single_valued (R¯)  ((⊆), (⊆))  Rset_rel  Rset_rel  bool_rel"
  unfolding set_rel_def
  by (blast dest: single_valuedD)

(* Parametricity of the strict subset test (⊂) with result relation bool_rel.
   Same side conditions and proof pattern as param_subseteq. *)
lemma param_subset[param]: 
  "single_valued R; single_valued (R¯)  ((⊂), (⊂))  Rset_rel  Rset_rel  bool_rel"
  unfolding set_rel_def 
  by (blast dest: single_valuedD)

(* Parametricity of the bounded quantifier constant Ball: the set argument is
   related by set_rel over Ra, the predicate maps Ra to Id, the result is in Id.
   Note: this proof uses set_rel_alt, whereas param_Bex uses set_rel_def. *)
lemma param_Ball[param]: "(Ball,Ball)Raset_rel(RaId)Id"
  by (force simp: set_rel_alt dest: fun_relD) 
  
(* Parametricity of the bounded quantifier constant Bex, with the same
   relation shape as param_Ball. Proved by fastforce after unfolding
   set_rel_def, using fun_relD as a destruction rule. *)
lemma param_Bex[param]: "(Bex,Bex)Raset_rel(RaId)Id"
  by (fastforce simp: set_rel_def dest: fun_relD)
    
    
(* Parametricity of set (list to set conversion): lists related by list_rel
   over Ra are mapped to sets related by set_rel over Ra.
   Side condition: Ra is single-valued. *)
lemma param_set[param]: 
  "single_valued Ra  (set,set)Ralist_rel  Raset_rel"
proof 
  fix l l'
  assume A: "single_valued Ra"
  assume "(l,l')Ralist_rel"
  (* Induction over the list_rel fact; after the two simp steps the remaining
     goal is solved by the parametricity method, which needs assumption A. *)
  thus "(set l, set l')Raset_rel"
    apply (induct)
    apply simp
    apply simp
    using A apply (parametricity)
    done
qed
  
(* Parametricity of Collect (set comprehension): predicates related by
   A to bool_rel are mapped to sets related by set_rel over A.
   Side conditions: Domain A = UNIV and Range A = UNIV.
   After unfolding set_rel_def, clarsimp; safe leaves two subgoals, solved
   with fun_relD1 and fun_relD2 respectively. *)
lemma param_Collect[param]: 
  "Domain A = UNIV; Range A = UNIV  (Collect,Collect)(Abool_rel)  Aset_rel"
  unfolding set_rel_def
  apply (clarsimp; safe)
  subgoal using fun_relD1 by fastforce
  subgoal using fun_relD2 by fastforce  
  done  
  
(* Parametricity of finite over set_rel with result relation bool_rel.
   Side conditions: R and its converse (R¯) are single-valued.
   Obtained by blast from finite_set_rel_transfer and
   finite_set_rel_transfer_back. *)
lemma param_finite[param]: "
    single_valued R; single_valued (R¯)
    (finite,finite)  Rset_rel  bool_rel"
  using finite_set_rel_transfer finite_set_rel_transfer_back by blast

(* Parametricity of card over set_rel with result relation nat_rel.
   Side conditions: R and its converse (R¯) are single-valued.
   The goal is rephrased with rel2pD and the rel2p simp rules so that
   card_transfer applies; its side condition is closed with rel2p_bi_unique. *)
lemma param_card[param]: "single_valued R; single_valued (R¯) 
   (card, card)  Rset_rel nat_rel"
  apply (rule rel2pD)
  apply (simp only: rel2p)
  apply (rule card_transfer)
  by (simp add: rel2p_bi_unique)
  
  
subsection ‹Standard HOL Constructs›  
  
(* Pointwise parametricity rule for If.
   - the conditions c and c' are related by Id;
   - the then-branches need only be related by R when c and c' hold;
   - the else-branches need only be related by R when ¬c and ¬c' hold. *)
lemma param_if[param]: 
  assumes "(c,c')Id"
  assumes "c;c'  (t,t')R"
  assumes "¬c;¬c'  (e,e')R"
  shows "(If c t e, If c' t' e')R"
  using assms by auto

(* Parametricity of Let: a bound value related by Ra and a body mapping
   Ra to Rr give results related by Rr. *)
lemma param_Let[param]: 
  "(Let,Let)Ra  (RaRr)  Rr"
  by (auto dest: fun_relD)

subsection ‹Functions›  
    
(* Parametricity of id for an arbitrary relation R; proved by unfolding
   id_def and running the parametricity method. *)
lemma param_id[param]: "(id,id)RR" unfolding id_def by parametricity

(* Parametricity of function composition (o): composing a function from Ra
   to Rb with one from Rc to Ra gives a function from Rc to Rb. *)
lemma param_fun_comp[param]: "((o), (o))  (RaRb)  (RcRa)  RcRb" 
  unfolding comp_def[abs_def] by parametricity

(* Parametricity of fun_upd (function update).
   Premise: equality is itself parametric for Ra (first line of the
   statement), which is needed because fun_upd_def compares arguments. *)
lemma param_fun_upd[param]: "
  ((=), (=))  RaRaId 
   (fun_upd,fun_upd)  (RaRb)  Ra  Rb  Ra  Rb"
  unfolding fun_upd_def[abs_def]
  by (parametricity)


    
subsection ‹Boolean›  
    
(* The recursor old.rec_bool coincides with case_bool. Used as a simp rule
   in param_bool to reduce the recursor case to the case-combinator case. *)
lemma rec_bool_is_case: "old.rec_bool = case_bool"
  by (rule ext)+ (auto split: bool.split)

(* Parametricity facts for booleans, all with respect to Id: the constants
   True and False, conj, disj, Not, the case combinator and recursor
   (whose branches may use an arbitrary relation R), and the connectives
   (⟷) and (⟶). *)
lemma param_bool[param]:
  "(True,True)Id"
  "(False,False)Id"
  "(conj,conj)IdIdId"
  "(disj,disj)IdIdId"
  "(Not,Not)IdId"
  "(case_bool,case_bool)RRIdR"
  "(old.rec_bool,old.rec_bool)RRIdR"
  "((⟷), (⟷))IdIdId"
  "((⟶), (⟶))IdIdId"
  by (auto split: bool.split simp: rec_bool_is_case)

(* Congruence-style rules for a binary boolean connective over bool_rel:
   b and b' only have to be shown related under the extra assumptions a and a'.
   cong1 has a as the first operand, cong2 has a as the second operand.
   Neither rule is declared [param].
   NOTE(review): the connective between the operands is not visible in the
   statements; going by the lemma names it is presumably conjunction - confirm. *)
lemma param_and_cong1: " (a,a')bool_rel; a; a'  (b,b')bool_rel   (ab,a'b')bool_rel"
  by blast
lemma param_and_cong2: " (a,a')bool_rel; a; a'  (b,b')bool_rel   (ba,b'a')bool_rel"
  by blast
    
    
subsection ‹Nat›  
    
(* Parametricity facts for nat with respect to Id: the constants 0, 1 and
   numerals, Suc, the comparisons (<), (≤), (=), and the arithmetic operations
   plus, minus, times, div and mod. The type annotations pin each fact to nat. *)
lemma param_nat1[param]:
  "(0, 0::nat)  Id"
  "(Suc, Suc)  Id  Id"
  "(1, 1::nat)  Id"
  "(numeral n::nat,numeral n::nat)  Id"
  "((<), (<) ::nat  _)  Id  Id  Id"
  "((≤), (≤) ::nat  _)  Id  Id  Id"
  "((=), (=) ::nat  _)  Id  Id  Id"
  "((+) ::nat_,(+))IdIdId"
  "((-) ::nat_,(-))IdIdId"
  "((*) ::nat_,(*))IdIdId"
  "((div) ::nat_,(div))IdIdId"
  "((mod) ::nat_,(mod))IdIdId"
  by auto

(* Parametricity of case_nat: the zero branch is related by Ra, the successor
   branch maps Id to Ra, the scrutinee is related by Id, the result by Ra. *)
lemma param_case_nat[param]:
  "(case_nat,case_nat)Ra  (Id  Ra)  Id  Ra"
  apply (intro fun_relI)
  apply (auto split: nat.split dest: fun_relD)
  done

(* Parametricity of the recursor rec_nat: start value related by R, step
   function mapping Id and R to R, argument related by Id, result in R. *)
lemma param_rec_nat[param]: 
  "(rec_nat,rec_nat)  R  (Id  R  R)  Id  R"
proof (intro fun_relI, goal_cases)
  (* Induction on the right-hand argument n', generalizing n, s and s'. *)
  case (1 s s' f f' n n') thus ?case
    apply (induct n' arbitrary: n s s')
    apply (fastforce simp: fun_rel_def)+
    done
qed

subsection ‹Int›  
  
(* Parametricity facts for int with respect to Id: the constants 0, 1 and
   numerals, the comparisons (<), (≤), (=), and the arithmetic operations
   plus, minus, times, div and mod. The type annotations pin each fact to int. *)
lemma param_int[param]:
  "(0, 0::int)  Id"
  "(1, 1::int)  Id"
  "(numeral n::int,numeral n::int)  Id"
  "((<), (<) ::int  _)  Id  Id  Id"
  "((≤), (≤) ::int  _)  Id  Id  Id"
  "((=), (=) ::int  _)  Id  Id  Id"
  "((+) ::int_,(+))IdIdId"
  "((-) ::int_,(-))IdIdId"
  "((*) ::int_,(*))IdIdId"
  "((div) ::int_,(div))IdIdId"
  "((mod) ::int_,(mod))IdIdId"
  by auto

subsection ‹Product›  
    
(* The unit value () is related to itself under unit_rel. *)
lemma param_unit[param]: "((),())unit_rel" by auto
    
(* The recursor old.rec_prod coincides with case_prod. Used as a simp rule
   in param_prod to reduce the recursor case to the case-combinator case.
   The goal mentions no case_bool, so no bool split rule is needed here. *)
lemma rec_prod_is_case: "old.rec_prod = case_prod"
  by (rule ext)+ auto

(* Parametricity facts for products over prod_rel: the constructor Pair,
   the case combinator case_prod, the recursor old.rec_prod, and the
   projections fst and snd. The recursor fact goes through rec_prod_is_case. *)
lemma param_prod[param]:
  "(Pair,Pair)Ra  Rb  Ra,Rbprod_rel"
  "(case_prod,case_prod)  (Ra  Rb  Rr)  Ra,Rbprod_rel  Rr"
  "(old.rec_prod,old.rec_prod)  (Ra  Rb  Rr)  Ra,Rbprod_rel  Rr"
  "(fst,fst)Ra,Rbprod_rel  Ra"
  "(snd,snd)Ra,Rbprod_rel  Rb"
  by (auto dest: fun_relD split: prod.split 
    simp: prod_rel_def rec_prod_is_case)

(* Pointwise variant of the case_prod rule (not declared [param]).
   Given pairs p and p' related by prod_rel, the bodies f and f' only have to
   be related for components a b a' b' with p=(a,b), p'=(a',b') and the
   components related by Ra and Rb. *)
lemma param_case_prod':
  " (p,p')Ra,Rbprod_rel;
     a b a' b'.  p=(a,b); p'=(a',b'); (a,a')Ra; (b,b')Rb  
       (f a b, f' a' b')R
      (case_prod f p, case_prod f' p')  R"
  by (auto split: prod.split)

(* Weaker pointwise variant of the case_prod rule (not declared [param]):
   there is no prod_rel premise on p and p'; the bodies only have to be
   related for all decompositions p=(a,b) and p'=(a',b'). *)
lemma param_case_prod'': (* TODO: Really needed? *)
  " 
    a b a' b'. p=(a,b); p'=(a',b')  (f a b,f' a' b')R  
    (case_prod f p, case_prod f' p')R"
  by (auto split: prod.split)


(* Parametricity of map_prod: component functions from Ra to Rb and from Rc
   to Rd lift to a function between the corresponding prod_rel relations. *)
lemma param_map_prod[param]: 
  "(map_prod, map_prod) 
   (RaRb)  (RcRd)  Ra,Rcprod_rel  Rb,Rdprod_rel"
  unfolding map_prod_def[abs_def]
  by parametricity

(* Parametricity of apfst: the function acts on the first component
   (Ra to Rb); the second component keeps relation Rc. *)
lemma param_apfst[param]: 
  "(apfst,apfst)(RaRb)Ra,Rcprod_relRb,Rcprod_rel"
  unfolding apfst_def[abs_def] by parametricity

(* Parametricity of apsnd: the function acts on the second component
   (Rb to Rc); the first component keeps relation Ra. *)
lemma param_apsnd[param]: 
  "(apsnd,apsnd)(RbRc)Ra,Rbprod_relRa,Rcprod_rel"
  unfolding apsnd_def[abs_def] by parametricity

(* Parametricity of curry: a function on pairs (prod_rel over Ra, Rb) to Rc
   becomes a two-argument function from Ra and Rb to Rc. *)
lemma param_curry[param]: 
  "(curry,curry)  (Ra,Rbprod_rel  Rc)  Ra  Rb  Rc"
  unfolding curry_def by parametricity

(* Parametricity of uncurry, the converse direction of param_curry.
   This statement uses the infix product-relation notation ×r. *)
lemma param_uncurry[param]: "(uncurry,uncurry)  (ABC)  A×rBC"
  unfolding uncurry_def[abs_def] by parametricity
    
(* Parametricity of prod.swap: the component relations A and B trade places. *)
lemma param_prod_swap[param]: "(prod.swap, prod.swap)A×rB  B×rA" by auto
    
(* Parametricity of fixp_fun inside the partial_function_definitions context.
   Shows that fixp_fun F x and fixp_fun F' x' are related by Ra, given:
   - M, M': F and F' are monotone w.r.t. le_fun;
   - ADM: the relatedness property is admissible;
   - bot: lub {} is related to fixp_fun F' xa for related arguments;
   - F: F and F' are related; A: the arguments x and x' are related by Rb.
   NOTE(review): the lemma has no name and no attribute, so it cannot be
   referenced after this point - consider naming it if it is meant to be used. *)
context partial_function_definitions begin
  lemma 
    assumes M: "monotone le_fun le_fun F" 
    and M': "monotone le_fun le_fun F'"
    assumes ADM: 
      "admissible (λa. x xa. (x, xa)  Rb  (a x, fixp_fun F' xa)  Ra)"
    assumes bot: "x xa. (x, xa)  Rb  (lub {}, fixp_fun F' xa)  Ra"
    assumes F: "(F,F')(RbRa)RbRa"
    assumes A: "(x,x')Rb"
    shows "(fixp_fun F x, fixp_fun F' x')Ra"
    using A
    (* Fixed-point induction on the left side (using M). The three goals are
       admissibility (ADM), the base case (bot), and the step, where the right
       side is unfolded once via M' before applying parametricity with F. *)
    apply (induct arbitrary: x x' rule: ccpo.fixp_induct[OF ccpo _ M])
    apply (rule ADM)
    apply(simp add: fun_lub_def bot)
    apply (subst ccpo.fixp_unfold[OF ccpo M'])
    apply (parametricity add: F)
    done
end

subsection ‹Option›  

(* Parametricity facts for option over option_rel: the constructors None and
   Some, the case combinator case_option and the recursor rec_option. *)
lemma param_option[param]:
  "(None,None)Roption_rel"
  "(Some,Some)R  Roption_rel"
  "(case_option,case_option)Rr(R  Rr)Roption_rel  Rr"
  "(rec_option,rec_option)Rr(R  Rr)Roption_rel  Rr"
  by (auto split: option.split 
    simp: option_rel_def case_option_def[symmetric]
    dest: fun_relD)
  
(* Parametricity of map_option: a function from A to B lifts to a function
   between option_rel over A and option_rel over B. *)
lemma param_map_option[param]: "(map_option, map_option)  (A  B)  Aoption_rel  Boption_rel"
  apply (intro fun_relI)
  apply (auto elim!: option_relE dest: fun_relD)
  done

(* Pointwise variant of the case_option rule (not declared [param]).
   Given x and x' related by option_rel, the None branches only have to be
   related when x=None and x'=None, and the Some branches only for contents
   v v' with x=Some v, x'=Some v' and (v,v') related by Rv. *)
lemma param_case_option':
  " (x,x')Rvoption_rel; 
     x=None; x'=None   (fn,fn')R;  
     v v'.  x=Some v; x'=Some v'; (v,v')Rv   (fs v, fs' v')R
     (case_option fn fs x, case_option fn' fs' x')  R"
  by (auto split: option.split)

(* Pointwise rule for `the` (not declared [param]): for l and r related by
   option_rel, `the l` and `the r` are related by R, under a side condition
   on the left value l.
   NOTE(review): the side condition reads "l None" with its relation symbol
   missing; presumably it states that l is not None - confirm. *)
lemma the_paramL: "lNone; (l,r)Roption_rel  (the l, the r)R"
  apply (cases l)
  by (auto elim: option_relE)

(* Pointwise rule for `the` (not declared [param]), the counterpart of
   the_paramL with the side condition on the right value r. The proof still
   does its case distinction on l.
   NOTE(review): the side condition reads "r None" with its relation symbol
   missing; presumably it states that r is not None - confirm. *)
lemma the_paramR: "rNone; (l,r)Roption_rel  (the l, the r)R"
  apply (cases l)
  by (auto elim: option_relE)

(* Parametricity of the_default: a default value related by R and an option
   related by option_rel over R give a result related by R. *)
lemma the_default_param[param]: 
  "(the_default, the_default)  R  Roption_rel  R"
  unfolding the_default_def
  by parametricity

subsection ‹Sum›  
    
(* The recursor old.rec_sum coincides with case_sum. Used as a simp rule
   in param_sum to reduce the recursor case to the case-combinator case. *)
lemma rec_sum_is_case: "old.rec_sum = case_sum"
  by (rule ext)+ (auto split: sum.split)

(* Parametricity facts for sums over sum_rel: the injections Inl and Inr,
   the case combinator case_sum and the recursor old.rec_sum. *)
lemma param_sum[param]:
  "(Inl,Inl)  Rl  Rl,Rrsum_rel"
  "(Inr,Inr)  Rr  Rl,Rrsum_rel"
  "(case_sum,case_sum)  (Rl  R)  (Rr  R)  Rl,Rrsum_rel  R"
  "(old.rec_sum,old.rec_sum)  (Rl  R)  (Rr  R)  Rl,Rrsum_rel  R"
  by (fastforce split: sum.split dest: fun_relD 
    simp: rec_sum_is_case)+

(* Pointwise variant of the case_sum rule (not declared [param]).
   Given s and s' related by sum_rel, the left branches only have to be
   related for l l' with s=Inl l, s'=Inl l' and (l,l') in Rl; likewise the
   right branches for r r' with s=Inr r, s'=Inr r' and (r,r') in Rr. *)
lemma param_case_sum':
  " (s,s')Rl,Rrsum_rel;
     l l'.  s=Inl l; s'=Inl l'; (l,l')Rl   (fl l, fl' l')R;
     r r'.  s=Inr r; s'=Inr r'; (r,r')Rr   (fr r, fr' r')R
     (case_sum fl fr s, case_sum fl' fr' s')R"
  by (auto split: sum.split)

(* Discriminators for sum values: is_Inl is True on Inl and False on Inr;
   is_Inr is True on Inr and False on Inl. *)
primrec is_Inl where "is_Inl (Inl _) = True" | "is_Inl (Inr _) = False"
primrec is_Inr where "is_Inr (Inr _) = True" | "is_Inr (Inl _) = False"

(* Parametricity of the discriminators is_Inl and is_Inr over sum_rel, with
   result relation bool_rel. Both proofs unfold the primrec definition
   (is_Inl_def / is_Inr_def) and run the parametricity method. *)
lemma is_Inl_param[param]: "(is_Inl,is_Inl)  Ra,Rbsum_rel  bool_rel"
  unfolding is_Inl_def by parametricity
lemma is_Inr_param[param]: "(is_Inr,is_Inr)  Ra,Rbsum_rel  bool_rel"
  unfolding is_Inr_def by parametricity

(* Pointwise parametricity of the left projection sum.projl.
   Premises: is_Inl holds for the right-hand value s, and (s',s) is related
   by sum_rel. Note the argument order: s' is on the left, s on the right. *)
lemma sum_projl_param[param]: 
  "is_Inl s; (s',s)Ra,Rbsum_rel 
   (Sum_Type.sum.projl s',Sum_Type.sum.projl s)  Ra"
  apply (cases s)
  apply (auto elim: sum_relE)
  done

(* Pointwise parametricity of the right projection sum.projr.
   Premises: is_Inr holds for the right-hand value s, and (s',s) is related
   by sum_rel. Note the argument order: s' is on the left, s on the right. *)
lemma sum_projr_param[param]: 
  "is_Inr s; (s',s)Ra,Rbsum_rel 
   (Sum_Type.sum.projr s',Sum_Type.sum.projr s)  Rb"
  apply (cases s)
  apply (auto elim: sum_relE)
  done

subsection ‹List›  
        
(* Connects relatedness of an append on the left, (as @ bs, l), with a split
   l = cs@ds of the right list whose parts are related to as and bs.
   Proved from list_rel_def and list_all2_append1; the remaining goal is
   closed by metis with list_all2_lengthD. *)
lemma list_rel_append1: "(as @ bs, l)  Rlist_rel 
   (cs ds. l = cs@ds  (as,cs)Rlist_rel  (bs,ds)Rlist_rel)"
  apply (simp add: list_rel_def list_all2_append1)
  apply auto
  apply (metis list_all2_lengthD)
  done

(* Mirror image of list_rel_append1: the append is on the right, (l, as @ bs),
   and the left list l is split as cs@ds with parts related to as and bs.
   Proved from list_rel_def and list_all2_append2. *)
lemma list_rel_append2: "(l,as @ bs)  Rlist_rel 
   (cs ds. l = cs@ds  (cs,as)Rlist_rel  (ds,bs)Rlist_rel)"
  apply (simp add: list_rel_def list_all2_append2)
  apply auto
  apply (metis list_all2_lengthD)
  done


(* Parametricity of list append over list_rel; follows from list_rel_def and
   list_all2_appendI. *)
lemma param_append[param]: 
  "(append, append)Rlist_rel  Rlist_rel  Rlist_rel"
  by (auto simp: list_rel_def list_all2_appendI)

(* Parametricity facts for the list constructors Nil and Cons and for the
   case combinator case_list, all over list_rel. *)
lemma param_list1[param]:
  "(Nil,Nil)Rlist_rel"
  "(Cons,Cons)R  Rlist_rel  Rlist_rel"
  "(case_list,case_list)Rr(RRlist_relRr)Rlist_relRr"
  apply (force dest: fun_relD split: list.split)+
  done

(* Parametricity of the list recursor rec_list: start value related by Ra,
   step function taking an element (Rb), the tail (list_rel over Rb) and the
   recursive result (Ra), list argument related by list_rel over Rb. *)
lemma param_rec_list[param]: 
  "(rec_list,rec_list) 
   Ra  (Rb  Rblist_rel  Ra  Ra)  Rblist_rel  Ra"
proof (intro fun_relI, goal_cases)
  case prems: (1 a a' f f' l l')
  (* Induction over the list_rel fact prems(3), generalizing the start
     values; prems(1,2) relate the start values and the step functions. *)
  from prems(3) show ?case
    using prems(1,2)
    apply (induct arbitrary: a a')
    apply simp
    apply (fastforce dest: fun_relD)
    done
qed

(* Pointwise variant of the case_list rule (not declared [param]).
   Given l and l' related by list_rel, the Nil branches only have to be
   related when l=[] and l'=[], and the Cons branches only for heads and
   tails with l=x#xs, l'=x'#xs' that are themselves related. *)
lemma param_case_list':
  " (l,l')Rblist_rel;
     l=[]; l'=[]  (n,n')Ra;  
     x xs x' xs'.  l=x#xs; l'=x'#xs'; (x,x')Rb; (xs,xs')Rblist_rel  
      (c x xs, c' x' xs')Ra
     (case_list n c l, case_list n' c' l')  Ra"
  by (auto split: list.split)
    
(* Parametricity of map: a function from R1 to R2 lifts to lists.
   Proved by rewriting map with map_rec and running parametricity. *)
lemma param_map[param]: 
  "(map,map)(R1R2)  R1list_rel  R2list_rel"
  unfolding map_rec[abs_def] by (parametricity)
    
(* Parametricity of the three list folds fold, foldl and foldr, with element
   relation Re and state relation Rs. Note that foldl takes the state before
   the list, and its step function takes the state before the element. *)
lemma param_fold[param]: 
  "(fold,fold)(ReRsRs)  Relist_rel  Rs  Rs"
  "(foldl,foldl)(RsReRs)  Rs  Relist_rel  Rs"
  "(foldr,foldr)(ReRsRs)  Relist_rel  Rs  Rs"
  unfolding List.fold_def List.foldr_def List.foldl_def
  by (parametricity)+

  (* Parametricity of list_all with a predicate from A to bool_rel. Obtained
     from List.list_all_transfer after folding rel2p_def and simplifying with
     the rel2p rules. *)
  lemma param_list_all[param]: "(list_all,list_all)  (Abool_rel)  Alist_rel  bool_rel"
    by (fold rel2p_def) (simp add: rel2p List.list_all_transfer)

(* Anonymous context: the helper list_all2_alt and its correctness lemma are
   private; only param_list_all2 is declared without `private`. *)
context begin
  (* Primitive-recursive formulation of list_all2: recursion on the first
     list, case distinction on the second. Lists of different shape give False. *)
  private primrec list_all2_alt :: "('a  'b  bool)  'a list  'b list  bool" where
    "list_all2_alt P [] ys  (case ys of []  True | _  False)"
  | "list_all2_alt P (x#xs) ys  (case ys of []  False | y#ys  P x y  list_all2_alt P xs ys)"
  
  (* list_all2 agrees with the helper; induction on xs generalizing ys. *)
  private lemma list_all2_alt: "list_all2 P xs ys = list_all2_alt P xs ys"
    by (induction xs arbitrary: ys) (auto split: list.splits)
  
  (* Parametricity of list_all2: rewrite to list_all2_alt, unfold the helper's
     definition, then run the parametricity method. *)
  lemma param_list_all2[param]: "(list_all2, list_all2)  (ABbool_rel)  Alist_rel  Blist_rel  bool_rel"
    unfolding list_all2_alt[abs_def] 
    unfolding list_all2_alt_def[abs_def] 
    by parametricity
  
end
  
(* Pointwise parametricity of hd for lists l' and l related by list_rel,
   under a side condition on the right-hand list l.
   NOTE(review): the side condition reads "l[]" with its relation symbol
   missing; presumably it states that l is non-empty - confirm. *)
lemma param_hd[param]: "l[]  (l',l)Alist_rel  (hd l', hd l)A"
  unfolding hd_def by (auto split: list.splits)

(* Pointwise parametricity of last for lists x and y related by list_rel,
   under a side condition on the right-hand list y.
   NOTE(review): the first assumption reads "y  []" with its relation symbol
   missing; presumably it states that y is non-empty - confirm.
   The assumptions are passed as assms(2,1) so that the list_rel fact comes
   first for list_rel_induct. *)
lemma param_last[param]: 
  assumes "y  []" 
  assumes "(x, y)  Alist_rel"  
  shows "(last x, last y)  A"
  using assms(2,1)
  by (induction rule: list_rel_induct) auto

(* Parametricity of rotate1 over list_rel; proved by unfolding rotate1_def
   and running the parametricity method. *)
lemma param_rotate1[param]: "(rotate1, rotate1)  Alist_rel  Alist_rel"
  unfolding rotate1_def by parametricity
    
(* Parametricity of take, stated as a schematic goal: the relation ?R is left
   open and is determined by the parametricity proof after unfolding take_def. *)
schematic_goal param_take[param]: "(take,take)(?R::(_×_) set)"
  unfolding take_def 
  by (parametricity)

(* Parametricity of drop, stated as a schematic goal: the relation ?R is left
   open and is determined by the parametricity proof after unfolding drop_def. *)
schematic_goal param_drop[param]: "(drop,drop)(?R::(_×_) set)"
  unfolding drop_def 
  by (parametricity)

(* Parametricity of length, stated as a schematic goal with open relation ?R.
   length is unfolded via size_list_overloaded_def and size_list_def before
   the parametricity method determines ?R. *)
schematic_goal param_length[param]: 
  "(length,length)(?R::(_×_) set)"
  unfolding size_list_overloaded_def size_list_def 
  by (parametricity)

(* Equality test on lists parameterized by an element comparison eq:
   - two empty lists are equal;
   - two non-empty lists are compared head-first with eq, then recursively;
   - every other combination (different shapes) yields False. *)
fun list_eq :: "('a  'a  bool)  'a list  'a list  bool" where
  "list_eq eq [] []  True"
| "list_eq eq (a#l) (a'#l') 
      (if eq a a' then list_eq eq l l' else False)"
| "list_eq _ _ _  False"

(* Parametricity of list_eq: an element comparison related by R, R to Id
   lifts to a list comparison over list_rel with result relation Id. *)
lemma param_list_eq[param]: "
  (list_eq,list_eq)  
    (R  R  Id)  Rlist_rel  Rlist_rel  Id"
proof (intro fun_relI, goal_cases)
  case prems: (1 eq eq' l1 l1' l2 l2')
  (* Induction follows the equations of list_eq on the right-hand arguments;
     each resulting goal is solved by repeatedly rewriting with list_eq.simps,
     eliminating list_rel facts with list_relE, or applying parametricity. *)
  thus ?case
    apply -
    apply (induct eq' l1' l2' arbitrary: l1 l2 rule: list_eq.induct)
    apply (simp_all only: list_eq.simps |
      elim list_relE |
      parametricity)+
    done
qed

(* Instantiating list_eq with (=) gives ordinary list equality. Declared as a
   simp rule; param_list_equals uses it in the symmetric direction. *)
lemma id_list_eq_aux[simp]: "(list_eq (=)) = (=)"
proof (intro ext)
  fix l1 l2 :: "'a list"
  show "list_eq (=) l1 l2 = (l1 = l2)"
    apply (induct "(=) :: 'a  _" l1 l2 rule: list_eq.induct)
    apply simp_all
    done
qed

(* Parametricity of equality on lists, assuming equality is parametric on the
   elements (premise). The goal's (=) is rewritten to list_eq (=) via
   id_list_eq_aux[symmetric], after which parametricity applies. *)
lemma param_list_equals[param]:
  " ((=), (=))  RRId  
   ((=), (=))  Rlist_rel  Rlist_rel  Id"
  unfolding id_list_eq_aux[symmetric]
  by (parametricity) 

(* Parametricity of tl over list_rel; proved by unfolding tl_def and running
   the parametricity method. *)
lemma param_tl[param]:
  "(tl,tl)  Rlist_rel  Rlist_rel"
  unfolding tl_def[abs_def]
  by (parametricity)


(* Primitive-recursive predicate over a list: True on the empty list; on a#l
   it combines P a with the recursive call on l. Related to bounded
   quantification over set l by list_all_rec_eq. *)
primrec list_all_rec where
  "list_all_rec P []  True"
| "list_all_rec P (a#l)  P a  list_all_rec P l"

(* Primitive-recursive predicate over a list: False on the empty list; on a#l
   it combines P a with the recursive call on l. Related to bounded
   quantification over set l by list_ex_rec_eq. *)
primrec list_ex_rec where
  "list_ex_rec P []  False"
| "list_ex_rec P (a#l)  P a  list_ex_rec P l"

(* Rewrites a bounded quantification over set l into list_all_rec P l;
   proved by induction on l. Used by param_list_ball. *)
lemma list_all_rec_eq: "(xset l. P x) = list_all_rec P l"
  by (induct l) auto

(* Rewrites a bounded quantification over set l into list_ex_rec P l;
   proved by induction on l. Used by param_list_bex. *)
lemma list_ex_rec_eq: "(xset l. P x) = list_ex_rec P l"
  by (induct l) auto

(* Pointwise parametricity of a bounded quantification over set l, for
   predicates related by Ra to Id and lists related by list_rel.
   The quantifier is first rewritten with list_all_rec_eq, then the primrec
   definition list_all_rec_def is unfolded for the parametricity method. *)
lemma param_list_ball[param]:
  "(P,P')(RaId); (l,l')Ra list_rel 
     (xset l. P x, xset l'. P' x)  Id"
  unfolding list_all_rec_eq
  unfolding list_all_rec_def
  by (parametricity)

(* Pointwise parametricity of a bounded quantification over set l, for
   predicates related by Ra to Id and lists related by list_rel.
   The quantifier is first rewritten with list_ex_rec_eq, then the primrec
   definition list_ex_rec_def is unfolded for the parametricity method. *)
lemma param_list_bex[param]:
  "(P,P')(RaId); (l,l')Ra list_rel 
     (xset l. P x, xset l'. P' x)  Id"
  unfolding list_ex_rec_eq[abs_def]
  unfolding list_ex_rec_def
  by (parametricity)

(* Parametricity of rev over list_rel; proved by unfolding rev_def and
   running the parametricity method. *)
lemma param_rev[param]: "(rev,rev)  Rlist_rel  Rlist_rel"
  unfolding rev_def
  by (parametricity)
  
(* Parametricity of foldli. Arguments in order: the list (list_rel over Re),
   a function from the state to Id, the step function (Re, Rs to Rs) and the
   initial state (Rs); the result is related by Rs. *)
lemma param_foldli[param]: "(foldli, foldli) 
   Relist_rel  (RsId)  (ReRsRs)  Rs  Rs"
  unfolding foldli_def
  by parametricity

(* Parametricity of foldri, with the same relation shape as param_foldli:
   list, function from the state to Id, step function, initial state. *)
lemma param_foldri[param]: "(foldri, foldri) 
   Relist_rel  (RsId)  (ReRsRs)  Rs  Rs"
  unfolding foldri_def[abs_def]
  by parametricity

(* Pointwise parametricity of list indexing (!).
   I:  the right-hand index i' is within the bounds of the right-hand list l';
   IR: the indices are related by nat_rel;
   LR: the lists are related by list_rel.
   Proved by list_rel_induct on LR, generalizing both indices. *)
lemma param_nth[param]: 
  assumes I: "i'<length l'"
  assumes IR: "(i,i')nat_rel"
  assumes LR: "(l,l')Rlist_rel" 
  shows "(l!i,l'!i')  R"
  using LR I IR
  by (induct arbitrary: i i' rule: list_rel_induct) 
     (auto simp: nth.simps split: nat.split)

(* Parametricity of replicate: a count related by nat_rel and an element
   related by R give lists related by list_rel over R. *)
lemma param_replicate[param]:
  "(replicate,replicate)  nat_rel  R  Rlist_rel"
  unfolding replicate_def by parametricity

(* Parametricity of list_update: a list related by list_rel over Ra, an index
   related by nat_rel and a new element related by Ra give lists related by
   list_rel over Ra. Proved by unfolding list_update_def and running the
   parametricity method. *)
lemma param_list_update[param]: 
  "(list_update,list_update)  Ralist_rel  nat_rel  Ra  Ralist_rel"
  unfolding list_update_def[abs_def] by parametricity

(* Parametricity of zip: lists related by list_rel over Ra and over Rb give
   a list of pairs related by list_rel over the prod_rel of Ra and Rb. *)
lemma param_zip[param]:
  "(zip, zip)  Ralist_rel  Rblist_rel  Ra,Rbprod_rellist_rel"
    unfolding zip_def by parametricity

(* Parametricity of upt: both bounds are related by nat_rel and the result
   list by list_rel over nat_rel. *)
lemma param_upt[param]:
  "(upt, upt)  nat_rel  nat_rel  nat_rellist_rel"
   unfolding upt_def[abs_def] by parametricity

(* Parametricity of concat: a list of lists (list_rel nested twice over R)
   is flattened to a list related by list_rel over R. *)
lemma param_concat[param]: "(concat, concat)  
    Rlist_rellist_rel  Rlist_rel"
unfolding concat_def[abs_def] by parametricity

(* Parametricity of List.all_interval_nat: a predicate from nat_rel to
   bool_rel and two bounds related by nat_rel give a result in bool_rel.
   The parametricity method leaves one goal, which is closed by simp. *)
lemma param_all_interval_nat[param]: 
  "(List.all_interval_nat, List.all_interval_nat) 
   (nat_rel  bool_rel)  nat_rel  nat_rel  bool_rel"
  unfolding List.all_interval_nat_def[abs_def]
  apply parametricity
  apply simp
  done

(* Parametricity of dropWhile: a predicate from a to bool_rel and a list
   related by list_rel over a give a list related by list_rel over a. *)
lemma param_dropWhile[param]: 
  "(dropWhile, dropWhile)  (a  bool_rel)  alist_rel  alist_rel"
  unfolding dropWhile_def by parametricity

(* Parametricity of takeWhile: a predicate from a to bool_rel and a list
   related by list_rel over a give a list related by list_rel over a. *)
lemma param_takeWhile[param]: 
  "(takeWhile, takeWhile)  (a  bool_rel)  alist_rel  alist_rel"
  unfolding takeWhile_def by parametricity



end