(* Theory Shuffle *)

theory Shuffle
  imports 
    CryptHOL.CryptHOL
    Additive_Sharing
    Spmf_Common
    Sharing_Lemmas
begin

text ‹
This is the formalization of the array shuffling protocol defined in \cite{laud2016secure},
adapted to the ABY3 sharing scheme.
For the moment, we assume an oracle that generates uniformly distributed permutations
instead of instantiating it with, e.g., the Fisher-Yates algorithm.
›

(* Remove the ASCII infix "o" for composition and the indexed "inv" syntax of
   m_inv, presumably so that "inv" below denotes the plain HOL function inverse
   (it is used that way in the permutation lemmas). *)
no_notation (ASCII) comp  (infixl "o" 55)
no_notation m_inv ("invı _" [81] 80)

(* Do not let do-notation's bind resolve to bind_pmf; the do-blocks in this
   theory are meant to be in the spmf monad. *)
no_adhoc_overloading Monad_Syntax.bind bind_pmf

(* Reference shuffle: the uniform distribution over the set of all lists whose
   multiset of elements equals that of the input list. *)
fun shuffleF :: "natL sharing list ⇒ natL sharing list spmf" where
  "shuffleF xsl = spmf_of_set (permutations_of_multiset (mset xsl))"

(* Components of the message produced by one permutation step (aby3_do_permute):
   the list of zero sharings, the prev_role shares of the step's input, and the
   permutation that was applied. *)
type_synonym zero_sharing = "natL sharing list"
type_synonym party2_data = "natL list"
type_synonym party01_permutation = "nat ⇒ nat"
type_synonym phase_msg = "zero_sharing × party2_data × party01_permutation"

(* The view of one party after the whole shuffle: its shares of the three
   zero-sharing lists, one list of shares revealed by a step, and the two
   permutations it knows. *)
type_synonym role_msg = "(natL list × natL list × natL list) × party2_data × (party01_permutation × party01_permutation)"

(* (a, b, c) → (a, b+c, 0)
   Re-arrange the sharing s so that role r keeps its share, next_role r holds
   the sum of the shares of next_role r and prev_role r, and prev_role r holds 0.
   The reconstructed value is unchanged
   (lemma reconstruct_stack_sharing_eq_reconstruct). *)
definition aby3_stack_sharing :: "Role ⇒ natL sharing ⇒ natL sharing" where
  "aby3_stack_sharing r s = make_sharing' r (next_role r) (prev_role r)
                             (get_party r s)
                             (get_party (next_role r) s + get_party (prev_role r) s)
                             0"

(* One permutation step with role r as the permuting side.
   Every sharing of x is stacked (aby3_stack_sharing r). The list is permuted by
   a fresh uniformly random permutation π of {..<n}. Each element is then
   re-randomised by adding, share-wise, a fresh sharing drawn from zero_sharing
   (presumably a random sharing of 0).
   Returns the step message (ζ, prev_role r's shares of x, π) and the new list. *)
definition aby3_do_permute :: "Role ⇒ natL sharing list ⇒ (phase_msg × natL sharing list) spmf" where
  "aby3_do_permute r x = (do {
    let n = length x;
    ζ ← sequence_spmf (replicate n zero_sharing);
    π ← spmf_of_set {π. π permutes {..<n}};
    let x2 = map (get_party (prev_role r)) x;
    let y' = map (aby3_stack_sharing r) x;
    let y = map2 (map_sharing2 (+)) (permute_list π y') ζ;
    let msg = (ζ, x2, π);
    return_spmf (msg, y)
  })"

(* The real shuffling protocol: three permutation steps, led in turn by Party1,
   Party2 and Party3. msg1, msg2 and msg3 are the views of Party1, Party2 and
   Party3. Each view holds:
   - the party's shares of ζa, ζb and ζc;
   - one revealed share list (b', x' and a' respectively);
   - the two permutations the party knows (πa,πc; πa,πb; πb,πc).
   The views are packed into one sharing; the protocol output is the list c. *)
definition aby3_shuffleR :: "natL sharing list ⇒ (role_msg sharing × natL sharing list) spmf" where
  "aby3_shuffleR x = (do {
    ((ζa,x',πa), a) ← aby3_do_permute Party1 x;     ― ‹1st round›
    ((ζb,a',πb), b) ← aby3_do_permute Party2 a;     ― ‹2nd round›
    ((ζc,b',πc), c) ← aby3_do_permute Party3 b;     ― ‹3rd round›
    let msg1 = ((map (get_party Party1) ζa, map (get_party Party1) ζb, map (get_party Party1) ζc), b', πa, πc);
    let msg2 = ((map (get_party Party2) ζa, map (get_party Party2) ζb, map (get_party Party2) ζc), x', πa, πb);
    let msg3 = ((map (get_party Party3) ζa, map (get_party Party3) ζb, map (get_party Party3) ζc), a', πb, πc);
    let msg = make_sharing msg1 msg2 msg3;
    return_spmf (msg, c)
  })"


(* The ideal functionality of shuffling: reconstruct the secret list, permute it
   with a uniformly random permutation of {..<length x}, and return a fresh
   sharing (share_nat) of every element of the permuted list. *)
definition aby3_shuffleF :: "natL sharing list ⇒ natL sharing list spmf" where
  "aby3_shuffleF x = (do {
    π ← spmf_of_set {π. π permutes {..<length x}};
    let x1 = map reconstruct x;
    let x_perm = permute_list π x1;
    y ← sequence_spmf (map share_nat x_perm);
    return_spmf y
  })"

(* The simulator for party 1.
   From Party1's input shares x1 and output shares yc1 it samples the two
   permutations Party1 knows (πa, πc) and uniform lists ya1, yb1, yb2, and
   derives a view of the shape of msg1 in aby3_shuffleR. Only ζc1 depends on
   the output shares (the non-free message). *)
definition S1 :: "natL list ⇒ natL list ⇒ role_msg spmf" where
    "S1 x1 yc1 = (do {
       let n = length x1;

         πa ← spmf_of_set {π. π permutes {..<n}};
         πc ← spmf_of_set {π. π permutes {..<n}};

         ya1::natL list ← sequence_spmf (replicate n (spmf_of_set UNIV));

         yb1::natL list ← sequence_spmf (replicate n (spmf_of_set UNIV));
         yb2::natL list ← sequence_spmf (replicate n (spmf_of_set UNIV));

― ‹round 1›
         let ζa1 = map2 (-) ya1 (permute_list πa x1);

― ‹round 2›
         let ζb1 = yb1;

― ‹round 3›
         let b' = yb2;
         let ζc1 = map2 (-) (yc1) (permute_list πc (map2 (+) yb1 yb2));  ― ‹non-free message›

       let msg1 = ((ζa1, ζb1, ζc1), b', πa, πc);
       return_spmf msg1
    })"

(* The simulator for party 2.
   From Party2's input shares x2 and output shares yc2 it samples a uniform
   stand-in x3 for the unseen shares, the permutations Party2 knows (πa, πb)
   and uniform lists ya2, yb2, and derives a view of the shape of msg2 in
   aby3_shuffleR. ζc2 is exactly the output share list (the non-free message). *)
definition S2 :: "natL list ⇒ natL list ⇒ role_msg spmf" where
    "S2 x2 yc2 = (do {
       let n = length x2;
       x3 ← sequence_spmf (replicate n (spmf_of_set UNIV));

         πa ← spmf_of_set {π. π permutes {..<n}};
         πb ← spmf_of_set {π. π permutes {..<n}};

         ya2::natL list ← sequence_spmf (replicate n (spmf_of_set UNIV));

         yb2::natL list ← sequence_spmf (replicate n (spmf_of_set UNIV));

― ‹round 1›
         let x' = x3;
         let ζa2 = map2 (-) ya2 (permute_list πa (map2 (+) x2 x3));

― ‹round 2›
         let ζb2 = map2 (-) yb2 (permute_list πb ya2);

― ‹round 3›
         let ζc2 = yc2;  ― ‹non-free message›

       let msg2 = ((ζa2, ζb2, ζc2), x', πa, πb);
       return_spmf msg2
    })"

(* The simulator for party 3.
   From Party3's input shares x3 and output shares yc3 it samples the
   permutations Party3 knows (πb, πc) and uniform lists ya3, ya1, yb3, and
   derives a view of the shape of msg3 in aby3_shuffleR. Only ζc3 depends on
   the output shares (the non-free message). The argument x3 is only used for
   its length. *)
definition S3 :: "natL list ⇒ natL list ⇒ role_msg spmf" where
    "S3 x3 yc3 = (do {
       let n = length x3;

         πb ← spmf_of_set {π. π permutes {..<n}};
         πc ← spmf_of_set {π. π permutes {..<n}};

         ya3::natL list ← sequence_spmf (replicate n (spmf_of_set UNIV));
         ya1::natL list ← sequence_spmf (replicate n (spmf_of_set UNIV));

         yb3::natL list ← sequence_spmf (replicate n (spmf_of_set UNIV));

― ‹round 1›
         let ζa3 = ya3;

― ‹round 2›
         let a' = ya1;
         let ζb3 = map2 (-) yb3 (permute_list πb (map2 (+) ya3 ya1));

― ‹round 3›
         let ζc3 = map2 (-) yc3 (permute_list πc yb3);  ― ‹non-free message›

       let msg3 = ((ζa3, ζb3, ζc3), a', πb, πc);
       return_spmf msg3

    })"

(* The simulator selected by role: the component of make_sharing S1 S2 S3 that
   belongs to role r. *)
definition S :: "Role ⇒ natL list ⇒ natL list ⇒ role_msg spmf" where
  "S r = get_party r (make_sharing S1 S2 S3)"
  
(* A distribution over sharing lists is "uniform" if it arises by sampling a
   list of secrets from some distribution xs and then sharing every element
   with share_nat (combined via sequence_spmf). *)
definition is_uniform_sharing_list :: "natL sharing list spmf ⇒ bool" where
  "is_uniform_sharing_list xss ⟷ (∃xs. xss = bind_spmf xs (sequence_spmf ∘ map share_nat))"

(* An inner case_prod on the same pair x as the outer one can use the
   components already bound by the outer case. *)
lemma case_prod_nesting_same:
  "case_prod (λa b. f (case_prod g x) a b ) x = case_prod (λa b. f (g a b) a b ) x"
  by (cases x) simp

(* Mapping a pairing function over xs equals zipping the two component maps. *)
lemma zip_map_map_same:
  "map (λx. (f x, g x)) xs = zip (map f xs) (map g xs)"
  unfolding zip_map_map
  unfolding zip_same_conv_map
  by simp

(* For equal-length lists, the pair (xs, map2 f ys xs) can be computed as the
   unzip of one map2 that builds both components element-wise. *)
lemma dup_map_eq:
  "length xs = length ys ⟹ (xs, map2 f ys xs) = (λxys. (map fst xys, map snd xys)) (map2 (λx y. (y, f x y)) ys xs)"
  by (auto simp: map_snd_zip[unfolded snd_def])


(* Apply a binary / ternary function to arguments sampled independently
   (via pair_spmf) from the given sub-distributions. *)
abbreviation "map2_spmf f xs ys ≡ map_spmf (case_prod f) (pair_spmf xs ys)"
abbreviation "map3_spmf f xs ys zs ≡ map2_spmf (λa. case_prod (f a)) xs (pair_spmf ys zs)"

(* Congruence for map_spmf when p is an image of q under m: it suffices that
   f after m agrees with g on the support of q. *)
lemma map_spmf_cong2:
  assumes "p = map_spmf m q" "⋀x. x∈set_spmf q ⟹ f (m x) = g x"
  shows "map_spmf f p = map_spmf g q"
  using assms by (simp add: spmf.map_comp cong: map_spmf_cong)

(* Congruence for bind_spmf when p is an image of q under m: it suffices that
   f after m agrees with g on the support of q. *)
lemma bind_spmf_cong2:
  assumes "p = map_spmf m q" "⋀x. x∈set_spmf q ⟹ f (m x) = g x"
  shows "bind_spmf p f = bind_spmf q g"
  using assms by (simp add: map_spmf_conv_bind_spmf cong: bind_spmf_cong)

(* Combining two sequenced lists of distributions element-wise equals
   sequencing the element-wise combined distributions (equal lengths). *)
lemma map2_spmf_map2_sequence:
  "length xss = length yss ⟹ map2_spmf (map2 f) (sequence_spmf xss) (sequence_spmf yss) = sequence_spmf (map2 (map2_spmf f) xss yss)"
  apply (induction xss yss rule: list_induct2)
  subgoal by simp
  subgoal premises IH for x xs y ys
    apply simp
    apply (fold IH)
    apply (unfold pair_map_spmf)
    apply (unfold spmf.map_comp)
    apply (rule map_spmf_cong2[where m="λ((x,y),(xs,ys)). ((x,xs),(y,ys))"])
    subgoal 
      unfolding pair_spmf_alt_def
      apply (simp add: map_spmf_conv_bind_spmf)
      apply (subst bind_commute_spmf[where q=y])
      ..
    subgoal by auto
    done
  done

(* Ternary analogue of map2: zips the last two lists and applies f point-wise
   (truncating to the shortest list, as zip does). *)
abbreviation map3 :: "('a ⇒ 'b ⇒ 'c ⇒ 'd) ⇒ 'a list ⇒ 'b list ⇒ 'c list ⇒ 'd list" where
  "map3 f a b c ≡ map2 (λa (b,c). f a b c) a (zip b c)"

(* Ternary analogue of map2_spmf_map2_sequence for three equal-length lists
   of distributions. *)
lemma map3_spmf_map3_sequence:
  "length xss = length yss ⟹ length yss = length zss ⟹ map3_spmf (map3 f) (sequence_spmf xss) (sequence_spmf yss) (sequence_spmf zss) = sequence_spmf (map3 (map3_spmf f) xss yss zss)"
  apply (induction xss yss zss rule: list_induct3)
  subgoal by simp
  subgoal premises IH for x xs y ys z zs
    apply simp
    apply (fold IH)
    apply (unfold pair_map_spmf)
    apply (unfold spmf.map_comp)
    apply (rule map_spmf_cong2[where m="λ((x,y,z),(xs,ys,zs)). ((x,xs),(y,ys),(z,zs))"])
    subgoal 
      unfolding pair_spmf_alt_def
      apply (simp add: map_spmf_conv_bind_spmf)
      apply (subst bind_commute_spmf[where q=y])
      apply (subst bind_commute_spmf[where q=z])
      apply (subst bind_commute_spmf[where q=z])
      ..
    subgoal by auto
    done
  done


(* The second component of an element of A × B lies in B. *)
lemma in_pairD2:
  "x ∈ A × B ⟹ snd x ∈ B"
  by auto

(* List analogue of map_spmf_cong2: if x = map m y and f after m agrees with g
   on the elements of y, then map f x = map g y. *)
lemma list_map_cong2:
  "x = map m y ⟹ (⋀z. z∈set y =simp=> f (m z) = g z) ⟹ map f x = map g y"
  unfolding simp_implies_def
  by simp

(* Swapping every pair of a zip equals zipping in the opposite order; no length
   assumption is needed. *)
lemma map_swap_zip:
  "map prod.swap (zip xs ys) = zip ys xs"
  apply (induction xs arbitrary: ys)
  subgoal by simp
  subgoal for x xs ys
    by (cases ys) auto
  done

(* Re-randomisation with zero sharings. Adding fresh zero sharings ζs
   share-wise to the sharings x gives the same joint distribution as:
   - freshly re-sharing the reconstructed values of x into ys, and
   - taking ζs to be ys minus x, share-wise. *)
lemma inv_zero_sharing_sequence:
  "n = length x ⟹
   map_spmf (λζs. (ζs, map2 (map_sharing2 (+)) x ζs)) (sequence_spmf (replicate n zero_sharing))
   =
   map_spmf (λys. (map2 (map_sharing2 (-)) ys x, ys)) (sequence_spmf (map (share_nat ∘ reconstruct) x))"
proof -
  assume n: "n = length x"
  have "map_spmf (λζs. (ζs, map2 (map_sharing2 (+)) x ζs)) (sequence_spmf (replicate n zero_sharing))
=
map2_spmf (λζs x. (ζs, map2 (map_sharing2 (+)) x ζs)) (sequence_spmf (replicate n zero_sharing)) (sequence_spmf (map return_spmf x))"
    unfolding sequence_map_return_spmf
    apply (rule map_spmf_cong2[where m="fst"])
    subgoal by simp
    subgoal by (auto simp: case_prod_unfold dest: in_pairD2)
    done

  also have "... = map_spmf (λζxs. (map fst ζxs, map snd ζxs)) (map2_spmf (map2 (λζ x. (ζ, map_sharing2 (+) x ζ))) (sequence_spmf (replicate n zero_sharing)) (sequence_spmf (map return_spmf x)))"
    apply (unfold spmf.map_comp)
    apply (rule map_spmf_cong[OF refl])
    using n by (auto simp: case_prod_unfold comp_def set_sequence_spmf list_all2_iff map_swap_zip intro: list_map_cong2[where m=prod.swap])

  also have "... = map_spmf (λζxs. (map fst ζxs, map snd ζxs)) (map2_spmf (map2 (λy x. (map_sharing2 (-) y x, y))) (sequence_spmf (map (share_nat ∘ reconstruct) x)) (sequence_spmf (map return_spmf x)))"
    apply (rule arg_cong[where f="map_spmf _"])
    using n     apply (simp add: map2_spmf_map2_sequence)
    apply (rule arg_cong[where f=sequence_spmf])
    apply (unfold list_eq_iff_nth_eq)
    apply safe
    subgoal by simp
    apply (simp add: )
    apply (subst map_spmf_cong2[where p="pair_spmf _ (return_spmf _)"])
      apply (rule pair_spmf_return_spmf2)
     apply simp
    apply (subst map_spmf_cong2[where p="pair_spmf _ (return_spmf _)"])
      apply (rule pair_spmf_return_spmf2)
     apply simp
    using inv_zero_sharing .

  also have "... = map2_spmf (λys x. (map2 (map_sharing2 (-)) ys x, ys)) (sequence_spmf (map (share_nat ∘ reconstruct) x)) (sequence_spmf (map return_spmf x))"
    apply (unfold spmf.map_comp)
    apply (rule map_spmf_cong[OF refl])
    using n by (auto simp: case_prod_unfold comp_def set_sequence_spmf list_all2_iff map_swap_zip intro: list_map_cong2[where m=prod.swap])

  also have "... = map_spmf (λys. (map2 (map_sharing2 (-)) ys x, ys)) (sequence_spmf (map (share_nat ∘ reconstruct) x))"
    unfolding sequence_map_return_spmf
    apply (rule map_spmf_cong2[where m="fst", symmetric])
    subgoal by simp
    subgoal by (auto simp: case_prod_unfold dest: in_pairD2)
    done

  finally show ?thesis .
qed

(* Projecting to party p commutes with combining two sharings share-wise. *)
lemma get_party_map_sharing2:
  "get_party p ∘ (case_prod (map_sharing2 f)) = case_prod f ∘ map_prod (get_party p) (get_party p)"
  by auto

(* Mapping map_prod f g over a zip distributes into the two zipped lists. *)
lemma map_map_prod_zip:
  "map (map_prod f g) (zip xs ys) = zip (map f xs) (map g ys)"
  by (simp add: map_prod_def zip_map_map)

(* Variant of map_map_prod_zip with a further function h applied afterwards. *)
lemma map_map_prod_zip':
  "map (h ∘ map_prod f g) (zip xs ys) = map h (zip (map f xs) (map g ys))"
  by (simp add: map_prod_def zip_map_map)

(* Transport map_spmf f' x = y back along f when f undoes f'
   (f (f' x) = x for all x): then x = map_spmf f y. *)
lemma eq_map_spmf_conv:
  assumes "⋀x. f (f' x) = x" "f' = inv f" "map_spmf f' x = y"
  shows "x = map_spmf f y"
proof -
  have surj: "surj f"
    apply (rule surjI) using assms(1) .
  have "map_spmf f (map_spmf f' x) = map_spmf f y"
    unfolding assms(3) ..
  thus ?thesis
    using assms(1) by (simp add: spmf.map_comp surj_iff comp_def)
qed


(* Rewrite a concrete map over pair_spmf using a known closed form F of
   map2_spmf f. *)
lemma lift_map_spmf_pairs:
  "map2_spmf f = F ⟹ map_spmf (case_prod f) (pair_spmf A B) = F A B"
  by auto

(* measure_pair_spmf_times with the product set given by an equation, so it
   can be instantiated by unification on A and B. *)
lemma measure_pair_spmf_times':
    "C = A × B ⟹ measure (measure_spmf (pair_spmf p q)) C = measure (measure_spmf p) A * measure (measure_spmf q) B"
  by (simp add: measure_pair_spmf_times)

(* Re-ordering the components of a 7-fold product of independent spmfs:
   mapping (a,b,c,d,e,f,g) to (a,e,b,f,c,g,d) equals the product with the
   factors in that order. Proved point-wise by splitting the probability of
   each singleton with measure_pair_spmf_times'. *)
lemma map_spmf_pairs_tmp:
  "map_spmf (λ(a,b,c,d,e,f,g). (a,e,b,f,c,g,d)) (pair_spmf A (pair_spmf B (pair_spmf C (pair_spmf D (pair_spmf E (pair_spmf F G))))))
       = (pair_spmf A (pair_spmf E (pair_spmf B (pair_spmf F (pair_spmf C (pair_spmf G  D))))))"
  apply (rule spmf_eqI)
  apply (clarsimp simp add: spmf_map)
  subgoal for a e b f c g d
    apply (subst measure_pair_spmf_times'[where A="{a}"]) defer
    apply (subst measure_pair_spmf_times'[where A="{b}"]) defer
    apply (subst measure_pair_spmf_times'[where A="{c}"]) defer
    apply (subst measure_pair_spmf_times'[where A="{d}"]) defer
    apply (subst measure_pair_spmf_times'[where A="{e}"]) defer
         apply (subst measure_pair_spmf_times'[where A="{f}" and B="{g}"]) defer
          apply (auto simp: spmf_conv_measure_spmf)
    done
  done

(* Fuse a case on a re-ordered 7-tuple with the case that re-orders it. *)
lemma case_case_prods_tmp:
  "(case case x of (a, b, c, d, e, f, g) ⇒ (a, e, b, f, c, g, d) of
                (ya, yb, yc, yd, ye, yf, yg) ⇒ F ya yb yc yd ye yf yg)
      = (case x of (a,b,c,d,e,f,g) ⇒ F a e b f c g d)"
  by (cases x) simp

(* Congruence for binding a uniformly random permutation of {..<x}: the
   continuations only need to agree on permutations of {..<x}. *)
lemma bind_spmf_permutes_cong:
  "(⋀π. π permutes {..<(x::nat)} ⟹ f π = g π)
    ⟹ bind_spmf (spmf_of_set {π. π permutes {..<x}}) f = bind_spmf (spmf_of_set {π. π permutes {..<x}}) g"
  by (rule bind_spmf_cong[OF refl]) (simp add: set_spmf_of_set finite_permutations set_sequence_spmf[unfolded list_all2_iff])

(* Two mapped lists are equal iff the lists are related point-wise by
   "f x = g y" (which also forces equal lengths). *)
lemma map_eq_iff_list_all2:
  "map f xs = map g ys ⟷ list_all2 (λx y. f x = g y) xs ys"
  apply (induction xs arbitrary: ys)
  subgoal by auto
  subgoal for x xs ys by (cases ys) (auto)
  done

(* Congruence for binding a fresh sharing of x: the continuations only need to
   agree on sharing lists that reconstruct to x. *)
lemma bind_spmf_sequence_map_share_nat_cong:
  "(⋀l. map reconstruct l = x ⟹ f l = g l)
    ⟹ bind_spmf (sequence_spmf (map share_nat x)) f = bind_spmf (sequence_spmf (map share_nat x)) g"
  subgoal premises prems
    apply (rule bind_spmf_cong[OF refl])
    apply (rule prems)
    unfolding set_sequence_spmf mem_Collect_eq
    apply (simp add: map_eq_iff_list_all2[where g=id, simplified])
    apply (simp add: list_all2_map2)
    apply (erule list_all2_mono)
    unfolding share_nat_def
    by simp
  done

(* If f preserves the reconstructed value of every element of xs, then so
   does mapping f over xs. *)
lemma map_reconstruct_comp_eq_iff:
  "(⋀x. x∈set xs ⟹ reconstruct (f x) = reconstruct x) ⟹ map (reconstruct ∘ f) xs = map reconstruct xs"
  by (induction xs) auto

(* Permuting a constant list of length n leaves it unchanged. *)
lemma permute_list_replicate:
  "p permutes {..<n} ⟹ permute_list p (replicate n x) = replicate n x"
  apply (fold map_replicate_const[where lst="[0..<n]", simplified])
  apply (simp add: permute_list_map)
  unfolding map_replicate_const
  by simp

(* Subtracting a list of zeros point-wise is the identity. *)
lemma map2_minus_zero:
  "length xs = length ys ⟹ (⋀y::natL. y∈set ys ⟹ y = 0) ⟹ map2 (-) xs ys = xs"
  by (induction xs ys rule: list_induct2) auto

(* Composing with a permutation on the left is injective. *)
lemma permute_comp_left_inj:
  "p permutes {..<n} ⟹ inj (λp'. p ∘ p')"
  by (rule fun.inj_map) (rule permutes_inj_on)

(* permute_comp_left_inj restricted to an arbitrary set A. *)
lemma permute_comp_left_inj_on:
  "p permutes {..<n} ⟹ inj_on (λp'. p ∘ p') A"
  using permute_comp_left_inj inj_on_subset by blast

(* Composing with a permutation on the right is injective. *)
lemma permute_comp_right_inj:
  "p permutes {..<n} ⟹ inj (λp'. p' ∘ p)"
  using inj_onI comp_id o_assoc permutes_surj surj_iff
  by (smt (verit))

(* permute_comp_right_inj restricted to an arbitrary set A. *)
lemma permute_comp_right_inj_on:
  "p permutes {..<n} ⟹ inj_on (λp'. p' ∘ p) A"
  using permute_comp_right_inj inj_on_subset by blast

(* The inverse of left composition with p is left composition with inv p. *)
lemma permutes_inv_comp_left:
  "p permutes {..<n} ⟹ inv (λp'. p ∘ p') = (λp'. inv p ∘ p')"
  by (rule inv_unique_comp; rule ext, simp add: permutes_inv_o comp_assoc[symmetric])

(* The inverse of right composition with p is right composition with inv p. *)
lemma permutes_inv_comp_right:
  "p permutes {..<n} ⟹ inv (λp'. p' ∘ p) = (λp'. p' ∘ inv p)"
  by (rule inv_unique_comp; rule ext, simp add: permutes_inv_o comp_assoc)

(* Inverse of p' mapped to πa ∘ p' ∘ πb. *)
lemma permutes_inv_comp_left_right:
  "πa permutes {..<n} ⟹ πb permutes {..<n} ⟹ inv (λp'. πa ∘ p' ∘ πb) = (λp'. inv πa ∘ p' ∘ inv πb)"
  by (rule inv_unique_comp; rule ext, simp add: permutes_inv_o comp_assoc, simp add: permutes_inv_o comp_assoc[symmetric])

(* Inverse of p' mapped to πa ∘ πb ∘ p'. *)
lemma permutes_inv_comp_left_left:
  "πa permutes {..<n} ⟹ πb permutes {..<n} ⟹ inv (λp'. πa ∘ πb ∘ p') = (λp'. inv πb ∘ inv πa ∘ p')"
  by (rule inv_unique_comp; rule ext, simp add: permutes_inv_o comp_assoc, simp add: permutes_inv_o comp_assoc[symmetric])

(* Inverse of p' mapped to p' ∘ πa ∘ πb. *)
lemma permutes_inv_comp_right_right:
  "πa permutes {..<n} ⟹ πb permutes {..<n} ⟹ inv (λp'. p' ∘ πa ∘ πb) = (λp'. p' ∘ inv πb ∘ inv πa)"
  by (rule inv_unique_comp; rule ext, simp add: permutes_inv_o comp_assoc, simp add: permutes_inv_o comp_assoc[symmetric])

(* Sandwiching every permutation of S between two fixed permutations of S
   yields all permutations of S again. *)
lemma image_compose_permutations_left_right:
  fixes S
  assumes "πa permutes S" "πb permutes S"
  shows "{πa ∘ π ∘ πb |π. π permutes S} = {π. π permutes S}"
proof -
  have *: "(λπ. πa ∘ π ∘ πb) = (λπ'. πa ∘ π') ∘ (λπ. π ∘ πb)"
    by (simp add: comp_def)
  then show ?thesis
    apply (fold image_Collect)
    apply (unfold *)
    apply (fold image_comp)
    apply (subst image_Collect)
    apply (unfold image_compose_permutations_right[OF assms(2)])
    apply (subst image_Collect)
    apply (unfold image_compose_permutations_left[OF assms(1)])
    ..
qed

(* Pre-composing every permutation of S with two fixed permutations of S
   yields all permutations of S again. *)
lemma image_compose_permutations_left_left:
  fixes S
  assumes "πa permutes S" "πb permutes S"
  shows "{πa ∘ πb ∘ π |π. π permutes S} = {π. π permutes S}"
  using image_compose_permutations_left image_compose_permutations_right 
proof -
  have *: "(λπ. πa ∘ πb ∘ π) = (λπ'. πa ∘ π') ∘ (λπ. πb ∘ π)"
    by (simp add: comp_def)
  show ?thesis
    apply (fold image_Collect)
    apply (unfold *)
    apply (fold image_comp)
    apply (subst image_Collect)
    apply (unfold image_compose_permutations_left[OF assms(2)])
    apply (subst image_Collect)
    apply (unfold image_compose_permutations_left[OF assms(1)])
    ..
qed

(* Post-composing every permutation of S with two fixed permutations of S
   yields all permutations of S again. *)
lemma image_compose_permutations_right_right:
  fixes S
  assumes "πa permutes S" "πb permutes S"
  shows "{π ∘ πa ∘ πb |π. π permutes S} = {π. π permutes S}"
  using image_compose_permutations_left image_compose_permutations_right 
proof -
  have *: "(λπ. π ∘ πa ∘ πb) = (λπ. π ∘ πb) ∘ (λπ'. π' ∘ πa)"
    by (simp add: comp_def)
  show ?thesis
    apply (fold image_Collect)
    apply (unfold *)
    apply (fold image_comp)
    apply (subst image_Collect)
    apply (unfold image_compose_permutations_right[OF assms(1)])
    apply (subst image_Collect)
    apply (unfold image_compose_permutations_right[OF assms(2)])
    ..
qed

(* For independent uniform permutations πa, πb, πc, the triple together with
   its composite πa ∘ πb ∘ πc is distributed like this: sample π, πa, πc
   uniformly and set the middle permutation to inv πa ∘ π ∘ inv πc. *)
lemma random_perm_middle:
  defines "random_perm n ≡ spmf_of_set {π. π permutes {..<n::nat}}"
  shows
    "map_spmf (λ(πa,πb,πc). ((πa,πb,πc),πa ∘ πb ∘ πc)) (pair_spmf (random_perm n) (pair_spmf (random_perm n) (random_perm n)))
     = map_spmf (λ(π,πa,πc). ((πa,inv πa ∘ π ∘ inv πc,πc),π)) (pair_spmf (random_perm n) (pair_spmf (random_perm n) (random_perm n)))"
    (is "?lhs = ?rhs")
proof -
  have "?lhs = (do {πa ← random_perm n; πc ← random_perm n; (πb,p) ← map_spmf (λπb. (πb,πa ∘ πb ∘ πc)) (random_perm n); return_spmf ((πa, πb, πc), p)})"
    unfolding map_spmf_conv_bind_spmf pair_spmf_alt_def
    apply simp
    apply (subst (4) bind_commute_spmf)
    ..
  also
  have "... = (do {πa ← random_perm n; πc ← random_perm n; map_spmf (λp. ((πa, inv πa ∘ p ∘ inv πc,πc),p)) (random_perm n)})"
    unfolding random_perm_def
    supply [intro!] =
      bind_spmf_permutes_cong
    apply rule+
    apply (subst inv_uniform)
    subgoal for πa πc
      apply (rule inj_compose[unfolded comp_def, where f="λp. p ∘ πc"])
      subgoal by (rule permute_comp_right_inj_on)
      subgoal by (rule permute_comp_left_inj_on)
      done
    apply (simp add: permutes_inv_comp_left_right map_spmf_conv_bind_spmf image_Collect image_compose_permutations_left_right)
    done
  also
  have "... = ?rhs"
    unfolding map_spmf_conv_bind_spmf pair_spmf_alt_def
    apply (subst (2) bind_commute_spmf)
    apply (subst (1) bind_commute_spmf)
    apply simp
    done
  finally show ?thesis .
qed

(* Variant of random_perm_middle: the last permutation is the one solved for,
   πc = inv πb ∘ inv πa ∘ π. *)
lemma random_perm_right:
  defines "random_perm n ≡ spmf_of_set {π. π permutes {..<n::nat}}"
  shows
    "map_spmf (λ(πa,πb,πc). ((πa,πb,πc),πa ∘ πb ∘ πc)) (pair_spmf (random_perm n) (pair_spmf (random_perm n) (random_perm n)))
     = map_spmf (λ(π,πa,πb). ((πa,πb,inv πb ∘ inv πa ∘ π),π)) (pair_spmf (random_perm n) (pair_spmf (random_perm n) (random_perm n)))"
    (is "?lhs = ?rhs")
proof -
  have "?lhs = (do {πa ← random_perm n; πb ← random_perm n; (πc,π) ← map_spmf (λπc. (πc,πa ∘ πb ∘ πc)) (random_perm n); return_spmf ((πa, πb, πc), π)})"
    unfolding map_spmf_conv_bind_spmf pair_spmf_alt_def
    by simp
  also
  have "... = (do {πa ← random_perm n; πb ← random_perm n; map_spmf (λπ. ((πa, πb, inv πb ∘ inv πa ∘ π),π)) (random_perm n)})"
    unfolding random_perm_def
    supply [intro!] =
      bind_spmf_permutes_cong
    apply rule+
    apply (subst inv_uniform)
    subgoal
      apply (rule permute_comp_left_inj_on)
      using permutes_compose .
    apply (simp add: permutes_inv_comp_left_left map_spmf_conv_bind_spmf image_Collect image_compose_permutations_left_left)
    done
  also
  have "... = ?rhs"
    unfolding map_spmf_conv_bind_spmf pair_spmf_alt_def
    apply (subst (2) bind_commute_spmf)
    apply (subst (1) bind_commute_spmf)
    apply simp
    done
  finally show ?thesis .
qed

(* Variant of random_perm_middle: the first permutation is the one solved for,
   πa = π ∘ inv πc ∘ inv πb. *)
lemma random_perm_left:
  defines "random_perm n ≡ spmf_of_set {π. π permutes {..<n::nat}}"
  shows
    "map_spmf (λ(πa,πb,πc). ((πa,πb,πc),πa ∘ πb ∘ πc)) (pair_spmf (random_perm n) (pair_spmf (random_perm n) (random_perm n)))
     = map_spmf (λ(π,πb,πc). ((π ∘ inv πc ∘ inv πb,πb,πc),π)) (pair_spmf (random_perm n) (pair_spmf (random_perm n) (random_perm n)))"
    (is "?lhs = ?rhs")
proof -
  have "?lhs = (do {πb ← random_perm n; πc ← random_perm n; (πa,π) ← map_spmf (λπa. (πa,πa ∘ πb ∘ πc)) (random_perm n); return_spmf ((πa, πb, πc), π)})"
    unfolding map_spmf_conv_bind_spmf pair_spmf_alt_def
    apply simp
    apply (subst (4) bind_commute_spmf)
    apply (subst (3) bind_commute_spmf)
    ..
  also
  have "... = (do {πb ← random_perm n; πc ← random_perm n; map_spmf (λπ. ((π ∘ inv πc ∘ inv πb, πb, πc),π)) (random_perm n)})"
    unfolding random_perm_def
    supply [intro!] =
      bind_spmf_permutes_cong
    apply rule+
    apply (subst inv_uniform)
    subgoal
      apply (unfold comp_assoc)
      apply (rule permute_comp_right_inj_on)
      using permutes_compose .
    apply (simp add: permutes_inv_comp_right_right map_spmf_conv_bind_spmf image_Collect image_compose_permutations_right_right)
    done
  also
  have "... = ?rhs"
    unfolding map_spmf_conv_bind_spmf pair_spmf_alt_def
    apply (subst (2) bind_commute_spmf)
    apply (subst (1) bind_commute_spmf)
    apply simp
    done
  finally show ?thesis .
qed

(* A case_prod whose body is a return_spmf can be rewritten as a single
   return_spmf of the case_prod. *)
lemma case_prod_return_spmf:
  "case_prod (λa b. return_spmf (f a b)) = (λx. return_spmf (case_prod f x))"
  by auto

(* Explicit form of sharing a list element-wise, for pairwise distinct roles
   r1, r2, r3: draw two uniform lists dp and dpn, and for each element x build
   the sharing with make_sharing' r1 r2 r3 a b (x - (a + b)), where a and b are
   the corresponding entries of dp and dpn. *)
lemma sequence_share_nat_calc':
  assumes "r1≠r2" "r2≠r3" "r3≠r1"
  shows
  "sequence_spmf (map share_nat xs) = (do {
    let n = length xs;
    let random_seq = sequence_spmf (replicate n (spmf_of_set UNIV));
    (dp, dpn) ← (pair_spmf random_seq random_seq);
    return_spmf (map3 (λx a b. make_sharing' r1 r2 r3 a b (x - (a + b))) xs dp dpn)
  })" (is "_ = ?rhs")
proof -
  have
  "sequence_spmf (map share_nat xs) = (do {
    let n = length xs;
    let random_seq = sequence_spmf (replicate n (spmf_of_set UNIV));
    (xs, dp, dpn) ← pair_spmf (sequence_spmf (map return_spmf xs)) (pair_spmf random_seq random_seq);
    return_spmf (map3 (λx a b. make_sharing' r1 r2 r3 a b (x - (a + b))) xs dp dpn)
  })"
    apply (unfold Let_def)
    apply (unfold case_prod_return_spmf)
    apply (fold map_spmf_conv_bind_spmf)
    apply (subst map3_spmf_map3_sequence)
    subgoal by simp
    subgoal by simp
    apply (rule arg_cong[where f=sequence_spmf])
    apply (unfold map_eq_iff_list_all2)
    apply (rule list_all2_all_nthI)
    subgoal by simp
    unfolding share_nat_def_calc'[OF assms]
    apply (auto simp: map_spmf_conv_bind_spmf pair_spmf_alt_def)
    done
  also have "... = ?rhs"
    by (auto simp: pair_spmf_alt_def sequence_map_return_spmf)
  finally show ?thesis .
qed


(* Stacking a sharing does not change the value it reconstructs to. *)
lemma reconstruct_stack_sharing_eq_reconstruct:
  "reconstruct ∘ aby3_stack_sharing r = reconstruct"
  unfolding aby3_stack_sharing_def reconstruct_def
  by (cases r) (auto simp: make_sharing'_sel)

(* A map2 that ignores its first argument is a map over the second list
   (equal lengths). *)
lemma map2_ignore1:
  "length xs = length ys ⟹ map2 (λ_. f) xs ys = map f ys"
  apply (unfold map_eq_iff_list_all2)
  apply (rule list_all2_all_nthI)
  by auto

(* A map2 that ignores its second argument is a map over the first list
   (equal lengths). *)
lemma map2_ignore2:
  "length xs = length ys ⟹ map2 (λa b. f a) xs ys = map f xs"
  apply (unfold map_eq_iff_list_all2)
  apply (rule list_all2_all_nthI)
  by auto


(* On every sharing list produced by sharing y element-wise, reconstructing
   gives back y. Stated as an equality of the joint distributions of the
   sharing list and its reconstruction. *)
lemma map_sequence_share_nat_reconstruct:
  "map_spmf (λx. (x, map reconstruct x)) (sequence_spmf (map share_nat y)) = map_spmf (λx. (x, y)) (sequence_spmf (map share_nat y))"
  apply (unfold map_spmf_conv_bind_spmf)
  apply (rule bind_spmf_cong[OF refl])
  apply (auto simp: set_sequence_spmf list_eq_iff_nth_eq list_all2_conv_all_nth share_nat_def)
  done

(* the main theorem of security of the shuffling protocol *)
theorem shuffle_secrecy:
  assumes
    "is_uniform_sharing_list x_dist"
  shows
    "(do {
       x  x_dist;
       (msg, y)  aby3_shuffleR x;
       return_spmf (map (get_party r) x,
                    get_party r msg,
                    y)
     })
     =
     (do {
       x  x_dist;
       y  aby3_shuffleF x;
       let xr = map (get_party r) x;
       let yr = map (get_party r) y;
       msg  S r xr yr;
       return_spmf (xr, msg, y)
     })"
    (is "?lhs = ?rhs")
proof -
  obtain xs where xs: "x_dist = xs  (sequence_spmf  map share_nat)"
    using assms unfolding is_uniform_sharing_list_def by auto

  have left_unfolded:
    "(do {
       x  x_dist;
       (msg, y)  aby3_shuffleR x;
       return_spmf (map (get_party r) x, get_party r msg, y)})
     =
     (do {
       xs  xs;
       x  sequence_spmf (map share_nat xs); 

― ‹round 1›
         let n = length x;
         ζa  sequence_spmf (replicate n zero_sharing);
         πa  spmf_of_set {π. π permutes {..<n}};
         let x' = map (get_party (prev_role Party1)) x;
         let y' = map (aby3_stack_sharing Party1) x;
         let a = map2 (map_sharing2 (+)) (permute_list πa y') ζa;

― ‹round 2›
         let n = length a;
         ζb  sequence_spmf (replicate n zero_sharing);
         πb  spmf_of_set {π. π permutes {..<n}};
         let a' = map (get_party (prev_role Party2)) a;
         let y' = map (aby3_stack_sharing Party2) a;
         let b = map2 (map_sharing2 (+)) (permute_list πb y') ζb;

― ‹round 3›
         let n = length b;
         ζc  sequence_spmf (replicate n zero_sharing);
         πc  spmf_of_set {π. π permutes {..<n}};
         let b' = map (get_party (prev_role Party3)) b;
         let y' = map (aby3_stack_sharing Party3) b;
         let c = map2 (map_sharing2 (+)) (permute_list πc y') ζc;

       let msg1 = ((map (get_party Party1) ζa, map (get_party Party1) ζb, map (get_party Party1) ζc), b', πa, πc);
       let msg2 = ((map (get_party Party2) ζa, map (get_party Party2) ζb, map (get_party Party2) ζc), x', πa, πb);
       let msg3 = ((map (get_party Party3) ζa, map (get_party Party3) ζb, map (get_party Party3) ζc), a', πb, πc);
       let msg = make_sharing msg1 msg2 msg3;
       return_spmf (map (get_party r) x, get_party r msg, c)
   })"
    unfolding xs aby3_shuffleR_def aby3_do_permute_def
    by (auto simp: case_prod_unfold Let_def)

  also have clarify_length:
    " = (do {
       xs  xs;
       let n = length xs;
       x  sequence_spmf (map share_nat xs); 

― ‹round 1›
         ζa  sequence_spmf (replicate n zero_sharing);
         πa  spmf_of_set {π. π permutes {..<n}};
         let x' = map (get_party (prev_role Party1)) x;
         let y' = map (aby3_stack_sharing Party1) x;
         let a = map2 (map_sharing2 (+)) (permute_list πa y') ζa;

― ‹round 2›
         ζb  sequence_spmf (replicate n zero_sharing);
         πb  spmf_of_set {π. π permutes {..<n}};
         let a' = map (get_party (prev_role Party2)) a;
         let y' = map (aby3_stack_sharing Party2) a;
         let b = map2 (map_sharing2 (+)) (permute_list πb y') ζb;

― ‹round 3›
         ζc  sequence_spmf (replicate n zero_sharing);
         πc  spmf_of_set {π. π permutes {..<n}};
         let b' = map (get_party (prev_role Party3)) b;
         let y' = map (aby3_stack_sharing Party3) b;
         let c = map2 (map_sharing2 (+)) (permute_list πc y') ζc;

       let msg1 = ((map (get_party Party1) ζa, map (get_party Party1) ζb, map (get_party Party1) ζc), b', πa, πc);
       let msg2 = ((map (get_party Party2) ζa, map (get_party Party2) ζb, map (get_party Party2) ζc), x', πa, πb);
       let msg3 = ((map (get_party Party3) ζa, map (get_party Party3) ζb, map (get_party Party3) ζc), a', πb, πc);
       let msg = make_sharing msg1 msg2 msg3;
       return_spmf (map (get_party r) x, get_party r msg, c)
   })"
    supply [intro!] =
      bind_spmf_cong[OF refl]
      let_cong[OF refl]
      bind_spmf_sequence_map_cong
      bind_spmf_sequence_replicate_cong
      bind_spmf_permutes_cong
    by (auto simp: Let_def)

  also have hoist_permutations:
    " = (do {
       xs  xs;
       let n = length xs;
       x  sequence_spmf (map share_nat xs); 

         πa  spmf_of_set {π. π permutes {..<n}};
         πb  spmf_of_set {π. π permutes {..<n}};
         πc  spmf_of_set {π. π permutes {..<n}};

― ‹round 1›
         let x' = map (get_party (prev_role Party1)) x;
         let y' = map (aby3_stack_sharing Party1) x;
         ζa  sequence_spmf (replicate n zero_sharing);
         let a = map2 (map_sharing2 (+)) (permute_list πa y') ζa;

― ‹round 2›
         let a' = map (get_party (prev_role Party2)) a;
         let y' = map (aby3_stack_sharing Party2) a;
         ζb  sequence_spmf (replicate n zero_sharing);
         let b = map2 (map_sharing2 (+)) (permute_list πb y') ζb;

― ‹round 3›
         let b' = map (get_party (prev_role Party3)) b;
         let y' = map (aby3_stack_sharing Party3) b;
         ζc  sequence_spmf (replicate n zero_sharing);
         let c = map2 (map_sharing2 (+)) (permute_list πc y') ζc;

       let msg1 = ((map (get_party Party1) ζa, map (get_party Party1) ζb, map (get_party Party1) ζc), b', πa, πc);
       let msg2 = ((map (get_party Party2) ζa, map (get_party Party2) ζb, map (get_party Party2) ζc), x', πa, πb);
       let msg3 = ((map (get_party Party3) ζa, map (get_party Party3) ζb, map (get_party Party3) ζc), a', πb, πc);
       let msg = make_sharing msg1 msg2 msg3;
       return_spmf (map (get_party r) x, get_party r msg, c)
   })"
    apply (simp add: Let_def)
    apply (subst (1) bind_commute_spmf[where q="spmf_of_set _"])
    apply (subst (2) bind_commute_spmf[where q="spmf_of_set _"])
    apply (subst (2) bind_commute_spmf[where q="spmf_of_set _"])
    apply (subst (3) bind_commute_spmf[where q="spmf_of_set _"])
    apply (subst (3) bind_commute_spmf[where q="spmf_of_set _"])
    apply (subst (3) bind_commute_spmf[where q="spmf_of_set _"])
    by simp

  also have hoist_permutations:
    " = (do {
       xs  xs;
       let n = length xs;
       x  sequence_spmf (map share_nat xs); 

         πa  spmf_of_set {π. π permutes {..<n}};
         πb  spmf_of_set {π. π permutes {..<n}};
         πc  spmf_of_set {π. π permutes {..<n}};

― ‹round 1›
         let x' = map (get_party (prev_role Party1)) x;
         let y' = map (aby3_stack_sharing Party1) x;
         a  sequence_spmf (map (share_nat  reconstruct) (permute_list πa y'));
         let ζa = map2 (map_sharing2 (-)) a (permute_list πa y');

― ‹round 2›
         let a' = map (get_party (prev_role Party2)) a;
         let y' = map (aby3_stack_sharing Party2) a;
         b  sequence_spmf (map (share_nat  reconstruct) (permute_list πb y'));
         let ζb = map2 (map_sharing2 (-)) b (permute_list πb y');

― ‹round 3›
         let b' = map (get_party (prev_role Party3)) b;
         let y' = map (aby3_stack_sharing Party3) b;
         c  sequence_spmf (map (share_nat  reconstruct) (permute_list πc y'));
         let ζc = map2 (map_sharing2 (-)) c (permute_list πc y');

       let msg1 = ((map (get_party Party1) ζa, map (get_party Party1) ζb, map (get_party Party1) ζc), b', πa, πc);
       let msg2 = ((map (get_party Party2) ζa, map (get_party Party2) ζb, map (get_party Party2) ζc), x', πa, πb);
       let msg3 = ((map (get_party Party3) ζa, map (get_party Party3) ζb, map (get_party Party3) ζc), a', πb, πc);
       let msg = make_sharing msg1 msg2 msg3;
       return_spmf (map (get_party r) x, get_party r msg, c)
   })"
    supply [intro!] =
      bind_spmf_cong[OF refl]
      let_cong[OF refl]
      prod.case_cong[OF refl]
      bind_spmf_sequence_map_cong
      bind_spmf_sequence_replicate_cong
      bind_spmf_permutes_cong

    apply rule+
    apply (subst hoist_map_spmf[where s="sequence_spmf (replicate _ _)" and f = "map2 (map_sharing2 (+)) _"])
    apply (subst hoist_map_spmf'[where s="sequence_spmf (map _ _)" and f = "λys. map2 (map_sharing2 (-)) ys _"])
    apply (subst inv_zero_sharing_sequence)
    subgoal by simp
    apply (unfold map_spmf_conv_bind_spmf)
    apply (unfold bind_spmf_assoc)
    apply (unfold return_bind_spmf)
    apply rule+
    apply (subst (1 12) Let_def)
    apply rule+

    apply (subst hoist_map_spmf[where s="sequence_spmf (replicate _ _)" and f = "map2 (map_sharing2 (+)) _"])
    apply (subst hoist_map_spmf'[where s="sequence_spmf (map _ _)" and f = "λys. map2 (map_sharing2 (-)) ys _"])
    apply (subst inv_zero_sharing_sequence)
    subgoal by simp
    apply (unfold map_spmf_conv_bind_spmf)
    apply (unfold bind_spmf_assoc)
    apply (unfold return_bind_spmf)
    apply rule+
    apply (subst (1 9) Let_def)
    apply rule+

    apply (subst hoist_map_spmf[where s="sequence_spmf (replicate _ _)" and f = "map2 (map_sharing2 (+)) _"])
    apply (subst hoist_map_spmf'[where s="sequence_spmf (map _ _)" and f = "λys. map2 (map_sharing2 (-)) ys _"])
    apply (subst inv_zero_sharing_sequence)
    subgoal by simp
    apply (unfold map_spmf_conv_bind_spmf)
    apply (unfold bind_spmf_assoc)
    apply (unfold return_bind_spmf)
    apply rule+
    apply (subst (1 6) Let_def)
    apply rule+

    done

  also have reconstruct:
    " = (do {
       xs  xs;
       let n = length xs;
       x  sequence_spmf (map share_nat xs); 

         πa  spmf_of_set {π. π permutes {..<n}};
         πb  spmf_of_set {π. π permutes {..<n}};
         πc  spmf_of_set {π. π permutes {..<n}};

― ‹round 1›
         let x' = map (get_party (prev_role Party1)) x;
         let y' = map (aby3_stack_sharing Party1) x;
         a  sequence_spmf (map share_nat (permute_list πa xs));
         let ζa = map2 (map_sharing2 (-)) a (permute_list πa y');

― ‹round 2›
         let a' = map (get_party (prev_role Party2)) a;
         let y' = map (aby3_stack_sharing Party2) a;
         b  sequence_spmf (map share_nat (permute_list (πa  πb) xs));
         let ζb = map2 (map_sharing2 (-)) b (permute_list πb y');

― ‹round 3›
         let b' = map (get_party (prev_role Party3)) b;
         let y' = map (aby3_stack_sharing Party3) b;
         c  sequence_spmf (map share_nat (permute_list (πa  πb  πc) xs));
         let ζc = map2 (map_sharing2 (-)) c (permute_list πc y');

       let msg1 = ((map (get_party Party1) ζa, map (get_party Party1) ζb, map (get_party Party1) ζc), b', πa, πc);
       let msg2 = ((map (get_party Party2) ζa, map (get_party Party2) ζb, map (get_party Party2) ζc), x', πa, πb);
       let msg3 = ((map (get_party Party3) ζa, map (get_party Party3) ζb, map (get_party Party3) ζc), a', πb, πc);
       let msg = make_sharing msg1 msg2 msg3;
       return_spmf (map (get_party r) x, get_party r msg, c)
   })"

    supply [intro!] =
      bind_spmf_cong[OF refl]
      let_cong[OF refl]
      prod.case_cong[OF refl]
      bind_spmf_sequence_map_cong
      bind_spmf_sequence_replicate_cong
      bind_spmf_permutes_cong
      bind_spmf_sequence_map_share_nat_cong

    apply rule+
    apply (subst list.map_comp[symmetric])
    apply (rule bind_spmf_cong)
    subgoal by (auto simp:permute_list_map[symmetric] map_reconstruct_comp_eq_iff reconstruct_def set_sequence_spmf[unfolded list_all2_iff] make_sharing'_sel reconstruct_stack_sharing_eq_reconstruct comp_assoc)
    apply rule+
    apply (subst list.map_comp[symmetric])
    apply (rule bind_spmf_cong)
    subgoal for x l xa πa πb πc xb xc xd xe xf xg
      apply (subst permute_list_map[symmetric] )
      subgoal by (auto simp add: set_sequence_spmf[unfolded list_all2_iff])
      apply simp
      apply (subst map_reconstruct_comp_eq_iff)
      subgoal by (simp add: reconstruct_def make_sharing'_sel aby3_stack_sharing_def)
      unfolding set_sequence_spmf mem_Collect_eq 
      unfolding list_all2_map2
      apply (subst map_eq_iff_list_all2[where f=reconstruct and g=id and xs=xd and ys="permute_list πa x", simplified, THEN iffD2])
      subgoal by (erule list_all2_mono) (simp add: share_nat_def)
      apply (subst permute_list_compose)
      subgoal by auto
      ..
    apply rule+
    apply (subst list.map_comp[symmetric])
    apply (rule bind_spmf_cong)
    subgoal for x l xa πa πb πc xb xc xd xe xf xg xh xi xj xk
      apply (subst permute_list_map[symmetric] )
      subgoal by (auto simp add: set_sequence_spmf[unfolded list_all2_iff])
      apply simp
      apply (subst map_reconstruct_comp_eq_iff)
      subgoal by (simp add: reconstruct_def make_sharing'_sel aby3_stack_sharing_def)
      unfolding set_sequence_spmf mem_Collect_eq 
      unfolding list_all2_map2
      apply (subst map_eq_iff_list_all2[where f=reconstruct and g=id and xs=xh and ys="permute_list (πa  πb) x", simplified, THEN iffD2])
      subgoal by (erule list_all2_mono) (simp add: share_nat_def)
      apply (subst permute_list_compose[symmetric])
      subgoal by auto
      ..
    ..

  also have hoist:
    " = (do {
       xs  xs;
       let n = length xs;
       x  sequence_spmf (map share_nat xs);

         πa  spmf_of_set {π. π permutes {..<n}};
         πb  spmf_of_set {π. π permutes {..<n}};
         πc  spmf_of_set {π. π permutes {..<n}};

         a  sequence_spmf (map share_nat (permute_list πa xs));
         b  sequence_spmf (map share_nat (permute_list (πa  πb) xs));
         c  sequence_spmf (map share_nat (permute_list (πa  πb  πc) xs));

― ‹round 1›
         let x' = map (get_party (prev_role Party1)) x;
         let y' = map (aby3_stack_sharing Party1) x;
         let ζa = map2 (map_sharing2 (-)) a (permute_list πa y');

― ‹round 2›
         let a' = map (get_party (prev_role Party2)) a;
         let y' = map (aby3_stack_sharing Party2) a;
         let ζb = map2 (map_sharing2 (-)) b (permute_list πb y');

― ‹round 3›
         let b' = map (get_party (prev_role Party3)) b;
         let y' = map (aby3_stack_sharing Party3) b;
         let ζc = map2 (map_sharing2 (-)) c (permute_list πc y');

       let msg1 = ((map (get_party Party1) ζa, map (get_party Party1) ζb, map (get_party Party1) ζc), b', πa, πc);
       let msg2 = ((map (get_party Party2) ζa, map (get_party Party2) ζb, map (get_party Party2) ζc), x', πa, πb);
       let msg3 = ((map (get_party Party3) ζa, map (get_party Party3) ζb, map (get_party Party3) ζc), a', πb, πc);
       let msg = make_sharing msg1 msg2 msg3;
       return_spmf (map (get_party r) x, get_party r msg, c)
   })"
    unfolding Let_def ..
  finally have hoisted_save: "?lhs = " .
  let ?hoisted = 

  { assume r: "r = Party1"
    have project_to_Party1:
    "?hoisted = (do {
       xs  xs;
       let n = length xs;
       x  sequence_spmf (map share_nat xs);

         πa  spmf_of_set {π. π permutes {..<n}};
         πb  spmf_of_set {π. π permutes {..<n}};
         πc  spmf_of_set {π. π permutes {..<n}};

         a  sequence_spmf (map share_nat (permute_list πa xs));
         b  sequence_spmf (map share_nat (permute_list (πa  πb) xs));
         c  sequence_spmf (map share_nat (permute_list (πa  πb  πc) xs));

― ‹round 1›
         let x' = map (get_party Party3) x;
         let y' = map (aby3_stack_sharing Party1) x;
         let ζa1 = map (case_prod (-)) (zip (map (get_party Party1) a) (map (get_party Party1) (permute_list πa y')));

― ‹round 2›
         let a' = map (get_party Party1) a;
         let y' = map (aby3_stack_sharing Party2) a;
         let ζb1 = map (case_prod (-)) (zip (map (get_party Party1) b) (map (get_party Party1) (permute_list πb y')));

― ‹round 3›
         let b' = map (get_party Party2) b;
         let y' = map (aby3_stack_sharing Party3) b;
         let ζc1 = map (case_prod (-)) (zip (map (get_party Party1) c) (map (get_party Party1) (permute_list πc y')));

       let msg1 = ((ζa1, ζb1, ζc1), b', πa, πc);
       return_spmf (map (get_party Party1) x, msg1, c)
   })"
      by (simp add: r Let_def get_party_map_sharing2 map_map_prod_zip')

    also have project_to_Party1:
    " = (do {
       xs  xs;
       let n = length xs;
       x  sequence_spmf (map share_nat xs);

         πa  spmf_of_set {π. π permutes {..<n}};
         πb  spmf_of_set {π. π permutes {..<n}};
         πc  spmf_of_set {π. π permutes {..<n}};

         a  sequence_spmf (map share_nat (permute_list πa xs));
         b  sequence_spmf (map share_nat (permute_list (πa  πb) xs));
         c  sequence_spmf (map share_nat (permute_list (πa  πb  πc) xs));

― ‹round 1›
         let x' = map (get_party Party3) x;
         let y' = map (aby3_stack_sharing Party1) x;
         let ζa1 = map2 (-) (map (get_party Party1) a) (permute_list πa (map (get_party Party1) y'));

― ‹round 2›
         let a' = map (get_party Party1) a;
         let y' = map (aby3_stack_sharing Party2) a;
         let ζb1 = map2 (-) (map (get_party Party1) b) (permute_list πb (map (get_party Party1) y'));

― ‹round 3›
         let b' = map (get_party Party2) b;
         let y' = map (aby3_stack_sharing Party3) b;
         let ζc1 = map2 (-) (map (get_party Party1) c) (permute_list πc (map (get_party Party1) y'));

       let msg1 = ((ζa1, ζb1, ζc1), b', πa, πc);
       return_spmf (map (get_party Party1) x, msg1, c)
    })"
    supply [intro!] =
      bind_spmf_cong[OF refl]
      let_cong[OF refl]
      prod.case_cong[OF refl]
      bind_spmf_sequence_map_cong
      bind_spmf_sequence_replicate_cong
      bind_spmf_permutes_cong

      apply rule+
      apply (subst permute_list_map[symmetric])
      subgoal by simp

      apply rule+
      apply (subst permute_list_map[symmetric])
      subgoal by simp

      apply rule+
      apply (subst permute_list_map[symmetric])
      subgoal by simp
      ..

    also have reduce_Lets:
    " = (do {
       xs  xs;
       let n = length xs;
       x  sequence_spmf (map share_nat xs);

         πa  spmf_of_set {π. π permutes {..<n}};
         πb  spmf_of_set {π. π permutes {..<n}};
         πc  spmf_of_set {π. π permutes {..<n}};

         a  sequence_spmf (map share_nat (permute_list πa xs));
         b  sequence_spmf (map share_nat (permute_list (πa  πb) xs));
         c  sequence_spmf (map share_nat (permute_list (πa  πb  πc) xs));

― ‹round 1›
         let ζa1 = map2 (-) (map (get_party Party1) a) (permute_list πa (map (get_party Party1) x));

― ‹round 2›
         let ζb1 = map2 (-) (map (get_party Party1) b) (permute_list πb (replicate (length a) 0));

― ‹round 3›
         let b' = map (get_party Party2) b;
         let ζc1 = map2 (-) (map (get_party Party1) c) (permute_list πc (map2 (+) (map (get_party Party1) b) (map (get_party Party2) b)));

       let msg1 = ((ζa1, ζb1, ζc1), b', πa, πc);
       return_spmf (map (get_party Party1) x, msg1, c)
    })"
      unfolding Let_def
      unfolding aby3_stack_sharing_def
      by (simp add: comp_def make_sharing'_sel map_replicate_const zip_map_map_same[symmetric])

    also have simplify_minus_zero:
    " = (do {
       xs  xs;
       let n = length xs;
       x  sequence_spmf (map share_nat xs);

         πa  spmf_of_set {π. π permutes {..<n}};
         πb  spmf_of_set {π. π permutes {..<n}};
         πc  spmf_of_set {π. π permutes {..<n}};

         a  sequence_spmf (map share_nat (permute_list πa xs));
         b  sequence_spmf (map share_nat (permute_list (πa  πb) xs));
         c  sequence_spmf (map share_nat (permute_list (πa  πb  πc) xs));

― ‹round 1›
         let ζa1 = map2 (-) (map (get_party Party1) a) (permute_list πa (map (get_party Party1) x));

― ‹round 2›
         let ζb1 = (map (get_party Party1) b);

― ‹round 3›
         let b' = map (get_party Party2) b;
         let ζc1 = map2 (-) (map (get_party Party1) c) (permute_list πc (map2 (+) (map (get_party Party1) b) (map (get_party Party2) b)));

       let msg1 = ((ζa1, ζb1, ζc1), b', πa, πc);
       return_spmf (map (get_party Party1) x, msg1, c)
    })"
    supply [intro!] =
      bind_spmf_cong[OF refl]
      let_cong[OF refl]
      prod.case_cong[OF refl]
      bind_spmf_sequence_map_cong
      bind_spmf_sequence_replicate_cong
      bind_spmf_permutes_cong

      apply rule+
      apply (subst permute_list_replicate)
      subgoal by simp

      apply (subst map2_minus_zero)
      subgoal by simp
      subgoal by simp

      ..

    also have break_perms_1:
    " = (do {
       xs  xs;
       let n = length xs;
       x  sequence_spmf (map share_nat xs);

         (πa,πb,πc)  pair_spmf (spmf_of_set {π. π permutes {..<n}}) (pair_spmf (spmf_of_set {π. π permutes {..<n}}) (spmf_of_set {π. π permutes {..<n}}));

         a  sequence_spmf (map share_nat (permute_list πa xs));
         b  sequence_spmf (map share_nat (permute_list (πa  πb) xs));
         c  sequence_spmf (map share_nat (permute_list (πa  πb  πc) xs));

― ‹round 1›
         let ζa1 = map2 (-) (map (get_party Party1) a) (permute_list πa (map (get_party Party1) x));

― ‹round 2›
         let ζb1 = (map (get_party Party1) b);

― ‹round 3›
         let b' = map (get_party Party2) b;
         let ζc1 = map2 (-) (map (get_party Party1) c) (permute_list πc (map2 (+) (map (get_party Party1) b) (map (get_party Party2) b)));

       let msg1 = ((ζa1, ζb1, ζc1), b', πa, πc);
       return_spmf (map (get_party Party1) x, msg1, c)
    })"
      unfolding pair_spmf_alt_def by simp

    also have break_perms_2:
    " = (do {
       xs  xs;
       let n = length xs;
       x  sequence_spmf (map share_nat xs);

         ((πa,πb,πc),π)  map_spmf (λ(πa,πb,πc). ((πa,πb,πc), πa  πb  πc)) (pair_spmf (spmf_of_set {π. π permutes {..<n}}) (pair_spmf (spmf_of_set {π. π permutes {..<n}}) (spmf_of_set {π. π permutes {..<n}})));

         a  sequence_spmf (map share_nat (permute_list πa xs));
         b  sequence_spmf (map share_nat (permute_list (πa  πb) xs));
         c  sequence_spmf (map share_nat (permute_list π xs));

― ‹round 1›
         let ζa1 = map2 (-) (map (get_party Party1) a) (permute_list πa (map (get_party Party1) x));

― ‹round 2›
         let ζb1 = (map (get_party Party1) b);

― ‹round 3›
         let b' = map (get_party Party2) b;
         let ζc1 = map2 (-) (map (get_party Party1) c) (permute_list πc (map2 (+) (map (get_party Party1) b) (map (get_party Party2) b)));

       let msg1 = ((ζa1, ζb1, ζc1), b', πa, πc);
       return_spmf (map (get_party Party1) x, msg1, c)
    })"
      unfolding pair_spmf_alt_def map_spmf_conv_bind_spmf by simp

    also have break_perms_3:
    " = (do {
       xs  xs;
       let n = length xs;
       x  sequence_spmf (map share_nat xs);

         π  spmf_of_set {π. π permutes {..<n}};
         πa  spmf_of_set {π. π permutes {..<n}};
         πc  spmf_of_set {π. π permutes {..<n}};
         let πb = inv πa  π  inv πc;

         a  sequence_spmf (map share_nat (permute_list πa xs));
         b  sequence_spmf (map share_nat (permute_list (πa  πb) xs));
         c  sequence_spmf (map share_nat (permute_list π xs));

― ‹round 1›
         let ζa1 = map2 (-) (map (get_party Party1) a) (permute_list πa (map (get_party Party1) x));

― ‹round 2›
         let ζb1 = (map (get_party Party1) b);

― ‹round 3›
         let b' = map (get_party Party2) b;
         let ζc1 = map2 (-) (map (get_party Party1) c) (permute_list πc (map2 (+) (map (get_party Party1) b) (map (get_party Party2) b)));

       let msg1 = ((ζa1, ζb1, ζc1), b', πa, πc);
       return_spmf (map (get_party Party1) x, msg1, c)
    })"
      apply (unfold random_perm_middle)
      apply (unfold map_spmf_conv_bind_spmf pair_spmf_alt_def)
      by simp

    also have break_seqs_3:
    " = (do {
       xs  xs;
       let n = length xs;

       x  sequence_spmf (map share_nat xs);

         π  spmf_of_set {π. π permutes {..<n}};
         πa  spmf_of_set {π. π permutes {..<n}};
         πc  spmf_of_set {π. π permutes {..<n}};
         let πb = inv πa  π  inv πc;

         a1  sequence_spmf (replicate n (spmf_of_set UNIV));
         a2  sequence_spmf (replicate n (spmf_of_set UNIV));
         let a = map3 (λa b c. make_sharing b c (a - (b + c))) (permute_list πa xs) a1 a2;
         b1  sequence_spmf (replicate n (spmf_of_set UNIV));
         b2  sequence_spmf (replicate n (spmf_of_set UNIV));
         let b = map3 (λa b c. make_sharing b c (a - (b + c))) (permute_list (πa  πb) xs) b1 b2;
         c  sequence_spmf (map share_nat (permute_list π xs));

― ‹round 1›
         let ζa1 = map2 (-) (map (get_party Party1) a) (permute_list πa (map (get_party Party1) x));

― ‹round 2›
         let ζb1 = (map (get_party Party1) b);

― ‹round 3›
         let b' = map (get_party Party2) b;
         let ζc1 = map2 (-) (map (get_party Party1) c) (permute_list πc (map2 (+) (map (get_party Party1) b) (map (get_party Party2) b)));

       let msg1 = ((ζa1, ζb1, ζc1), b', πa, πc);
       return_spmf (map (get_party Party1) x, msg1, c)
    })"
      apply (unfold sequence_share_nat_calc'[of Party1 Party2 Party3, simplified])
      apply (simp add: pair_spmf_alt_def Let_def)
      done

    also have break_seqs_3:
    " = (do {
       xs  xs;
       let n = length xs;

       x  sequence_spmf (map share_nat xs);

         π  spmf_of_set {π. π permutes {..<n}};
         πa  spmf_of_set {π. π permutes {..<n}};
         πc  spmf_of_set {π. π permutes {..<n}};

         a1::natL list  sequence_spmf (replicate n (spmf_of_set UNIV));
         a2::natL list  sequence_spmf (replicate n (spmf_of_set UNIV));

         b1::natL list  sequence_spmf (replicate n (spmf_of_set UNIV));
         b2::natL list  sequence_spmf (replicate n (spmf_of_set UNIV));
         c  sequence_spmf (map share_nat (permute_list π xs));

― ‹round 1›
         let ζa1 = map2 (-) a1 (permute_list πa (map (get_party Party1) x));

― ‹round 2›
         let ζb1 = b1;

― ‹round 3›
         let b' = b2;
         let ζc1 = map2 (-) (map (get_party Party1) c) (permute_list πc (map2 (+) b1 b2));

       let msg1 = ((ζa1, ζb1, ζc1), b', πa, πc);
       return_spmf (map (get_party Party1) x, msg1, c)
    })"
    supply [intro!] =
      bind_spmf_cong[OF refl]
      let_cong[OF refl]
      prod.case_cong[OF refl]
      bind_spmf_sequence_map_cong
      bind_spmf_sequence_replicate_cong
      bind_spmf_permutes_cong
      unfolding Let_def
      apply rule+
      apply (auto simp: map2_ignore1 map2_ignore2 comp_def prod.case_distrib bind_spmf_const)
      done

    also have
      " = (do {x  x_dist; y  aby3_shuffleF x; let xr = map (get_party r) x; let yr = map (get_party r) y; msg  S r xr yr; return_spmf (xr, msg, y)})"
      unfolding xs
      unfolding aby3_shuffleF_def
      apply (simp add: bind_spmf_const map_spmf_conv_bind_spmf)
      apply (subst lossless_sequence_spmf[unfolded lossless_spmf_def])
      subgoal by simp
      apply simp
      apply (subst bind_commute_spmf[where q="sequence_spmf (map share_nat _)"])
      apply (subst bind_commute_spmf[where q="sequence_spmf (map share_nat _)"])
      apply (subst bind_commute_spmf[where q="sequence_spmf (map share_nat _)"])
      apply (subst bind_commute_spmf[where q="sequence_spmf (map share_nat _)"])
      apply (subst bind_commute_spmf[where q="sequence_spmf (map share_nat _)"])

      apply (subst (3) hoist_map_spmf[where s="sequence_spmf (map share_nat _)" and f="map reconstruct"])
      apply (subst map_sequence_share_nat_reconstruct)
      apply (simp add: map_spmf_conv_bind_spmf)
      apply (subst Let_def)

      supply [intro!] =
        bind_spmf_cong[OF refl]
        let_cong[OF refl]
        prod.case_cong[OF refl]
        bind_spmf_sequence_map_cong
        bind_spmf_sequence_replicate_cong
        bind_spmf_permutes_cong

      supply [simp] = finite_permutations

      apply rule
      apply rule
      apply simp
      apply rule
      apply rule
      unfolding S_def S1_def r
      apply (simp add: Let_def)
      done

    finally have "?hoisted = " .
  } note simulate_party1 = this

  { assume r: "r = Party2"
    have project_to_Party2:
    "?hoisted = (do {
       xs  xs;
       let n = length xs;
       x  sequence_spmf (map share_nat xs);

         πa  spmf_of_set {π. π permutes {..<n}};
         πb  spmf_of_set {π. π permutes {..<n}};
         πc  spmf_of_set {π. π permutes {..<n}};

         a  sequence_spmf (map share_nat (permute_list πa xs));
         b  sequence_spmf (map share_nat (permute_list (πa  πb) xs));
         c  sequence_spmf (map share_nat (permute_list (πa  πb  πc) xs));

― ‹round 1›
         let x' = map (get_party Party3) x;
         let y' = map (aby3_stack_sharing Party1) x;
         let ζa2 = map (case_prod (-)) (zip (map (get_party Party2) a) (map (get_party Party2) (permute_list πa y')));

― ‹round 2›
         let a' = map (get_party Party1) a;
         let y' = map (aby3_stack_sharing Party2) a;
         let ζb2 = map (case_prod (-)) (zip (map (get_party Party2) b) (map (get_party Party2) (permute_list πb y')));

― ‹round 3›
         let b' = map (get_party Party2) b;
         let y' = map (aby3_stack_sharing Party3) b;
         let ζc2 = map (case_prod (-)) (zip (map (get_party Party2) c) (map (get_party Party2) (permute_list πc y')));

       let msg2 = ((ζa2, ζb2, ζc2), x', πa, πb);
       return_spmf (map (get_party Party2) x, msg2, c)
   })"
      by (simp add: r Let_def get_party_map_sharing2 map_map_prod_zip')

    also have project_to_Party2:
    " = (do {
       xs  xs;
       let n = length xs;
       x  sequence_spmf (map share_nat xs);

         πa  spmf_of_set {π. π permutes {..<n}};
         πb  spmf_of_set {π. π permutes {..<n}};
         πc  spmf_of_set {π. π permutes {..<n}};

         a  sequence_spmf (map share_nat (permute_list πa xs));
         b  sequence_spmf (map share_nat (permute_list (πa  πb) xs));
         c  sequence_spmf (map share_nat (permute_list (πa  πb  πc) xs));

― ‹round 1›
         let x' = map (get_party Party3) x;
         let y' = map (aby3_stack_sharing Party1) x;
         let ζa2 = map2 (-) (map (get_party Party2) a) (permute_list πa (map (get_party Party2) y'));

― ‹round 2›
         let a' = map (get_party Party1) a;
         let y' = map (aby3_stack_sharing Party2) a;
         let ζb2 = map2 (-) (map (get_party Party2) b) (permute_list πb (map (get_party Party2) y'));

― ‹round 3›
         let b' = map (get_party Party2) b;
         let y' = map (aby3_stack_sharing Party3) b;
         let ζc2 = map2 (-) (map (get_party Party2) c) (permute_list πc (map (get_party Party2) y'));

       let msg2 = ((ζa2, ζb2, ζc2), x', πa, πb);
       return_spmf (map (get_party Party2) x, msg2, c)
    })"
    supply [intro!] =
      bind_spmf_cong[OF refl]
      let_cong[OF refl]
      prod.case_cong[OF refl]
      bind_spmf_sequence_map_cong
      bind_spmf_sequence_replicate_cong
      bind_spmf_permutes_cong

      apply rule+
      apply (subst permute_list_map[symmetric])
      subgoal by simp

      apply rule+
      apply (subst permute_list_map[symmetric])
      subgoal by simp

      apply rule+
      apply (subst permute_list_map[symmetric])
      subgoal by simp
      ..

    also have reduce_Lets:
    " = (do {
       xs  xs;
       let n = length xs;
       x  sequence_spmf (map share_nat xs);

         πa  spmf_of_set {π. π permutes {..<n}};
         πb  spmf_of_set {π. π permutes {..<n}};
         πc  spmf_of_set {π. π permutes {..<n}};

         a  sequence_spmf (map share_nat (permute_list πa xs));
         b  sequence_spmf (map share_nat (permute_list (πa  πb) xs));
         c  sequence_spmf (map share_nat (permute_list (πa  πb  πc) xs));

― ‹round 1›
         let x' = map (get_party Party3) x;
         let ζa2 = map2 (-) (map (get_party Party2) a) (permute_list πa (map2 (+) (map (get_party Party2) x) (map (get_party Party3) x)));

― ‹round 2›
         let ζb2 = map2 (-) (map (get_party Party2) b) (permute_list πb (map (get_party Party2) a));

― ‹round 3›
         let ζc2 = map2 (-) (map (get_party Party2) c) (permute_list πc (replicate (length b) 0));

       let msg2 = ((ζa2, ζb2, ζc2), x', πa, πb);

       return_spmf (map (get_party Party2) x, msg2, c)
    })"
      unfolding Let_def
      unfolding aby3_stack_sharing_def
      by (simp add: comp_def make_sharing'_sel map_replicate_const zip_map_map_same[symmetric])

    also have simplify_minus_zero:
    " = (do {
       xs  xs;
       let n = length xs;
       x  sequence_spmf (map share_nat xs);

         πa  spmf_of_set {π. π permutes {..<n}};
         πb  spmf_of_set {π. π permutes {..<n}};
         πc  spmf_of_set {π. π permutes {..<n}};

         a  sequence_spmf (map share_nat (permute_list πa xs));
         b  sequence_spmf (map share_nat (permute_list (πa  πb) xs));
         c  sequence_spmf (map share_nat (permute_list (πa  πb  πc) xs));

― ‹round 1›
         let x' = map (get_party Party3) x;
         let ζa2 = map2 (-) (map (get_party Party2) a) (permute_list πa (map2 (+) (map (get_party Party2) x) (map (get_party Party3) x)));

― ‹round 2›
         let ζb2 = map2 (-) (map (get_party Party2) b) (permute_list πb (map (get_party Party2) a));

― ‹round 3›
         let ζc2 = (map (get_party Party2) c);

       let msg2 = ((ζa2, ζb2, ζc2), x', πa, πb);
       return_spmf (map (get_party Party2) x, msg2, c)
    })"
    supply [intro!] =
      bind_spmf_cong[OF refl]
      let_cong[OF refl]
      prod.case_cong[OF refl]
      bind_spmf_sequence_map_cong
      bind_spmf_sequence_replicate_cong
      bind_spmf_permutes_cong

      apply rule+
      apply (subst permute_list_replicate)
      subgoal by simp

      apply (subst map2_minus_zero)
      subgoal by simp
      subgoal by simp

      ..

    also have break_perms_1:
    " = (do {
       xs  xs;
       let n = length xs;
       x  sequence_spmf (map share_nat xs);

         (πa,πb,πc)  pair_spmf (spmf_of_set {π. π permutes {..<n}}) (pair_spmf (spmf_of_set {π. π permutes {..<n}}) (spmf_of_set {π. π permutes {..<n}}));

         a  sequence_spmf (map share_nat (permute_list πa xs));
         b  sequence_spmf (map share_nat (permute_list (πa  πb) xs));
         c  sequence_spmf (map share_nat (permute_list (πa  πb  πc) xs));

― ‹round 1›
         let x' = map (get_party Party3) x;
         let ζa2 = map2 (-) (map (get_party Party2) a) (permute_list πa (map2 (+) (map (get_party Party2) x) (map (get_party Party3) x)));

― ‹round 2›
         let ζb2 = map2 (-) (map (get_party Party2) b) (permute_list πb (map (get_party Party2) a));

― ‹round 3›
         let ζc2 = (map (get_party Party2) c);

       let msg2 = ((ζa2, ζb2, ζc2), x', πa, πb);
       return_spmf (map (get_party Party2) x, msg2, c)
    })"
      unfolding pair_spmf_alt_def by simp

    also have break_perms_2:
    " = (do {
       xs  xs;
       let n = length xs;
       x  sequence_spmf (map share_nat xs);

         ((πa,πb,πc),π)  map_spmf (λ(πa,πb,πc). ((πa,πb,πc), πa  πb  πc)) (pair_spmf (spmf_of_set {π. π permutes {..<n}}) (pair_spmf (spmf_of_set {π. π permutes {..<n}}) (spmf_of_set {π. π permutes {..<n}})));

         a  sequence_spmf (map share_nat (permute_list πa xs));
         b  sequence_spmf (map share_nat (permute_list (πa  πb) xs));
         c  sequence_spmf (map share_nat (permute_list π xs));

― ‹round 1›
         let x' = map (get_party Party3) x;
         let ζa2 = map2 (-) (map (get_party Party2) a) (permute_list πa (map2 (+) (map (get_party Party2) x) (map (get_party Party3) x)));

― ‹round 2›
         let ζb2 = map2 (-) (map (get_party Party2) b) (permute_list πb (map (get_party Party2) a));

― ‹round 3›
         let ζc2 = (map (get_party Party2) c);

       let msg2 = ((ζa2, ζb2, ζc2), x', πa, πb);
       return_spmf (map (get_party Party2) x, msg2, c)
    })"
      unfolding pair_spmf_alt_def map_spmf_conv_bind_spmf by simp

    also have break_perms_3:
    " = (do {
       xs  xs;
       let n = length xs;
       x  sequence_spmf (map share_nat xs);

         π  spmf_of_set {π. π permutes {..<n}};
         πa  spmf_of_set {π. π permutes {..<n}};
         πb  spmf_of_set {π. π permutes {..<n}};
         let πc = inv πb  inv πa  π;

         a  sequence_spmf (map share_nat (permute_list πa xs));
         b  sequence_spmf (map share_nat (permute_list (πa  πb) xs));
         c  sequence_spmf (map share_nat (permute_list π xs));

― ‹round 1›
         let x' = map (get_party Party3) x;
         let ζa2 = map2 (-) (map (get_party Party2) a) (permute_list πa (map2 (+) (map (get_party Party2) x) (map (get_party Party3) x)));

― ‹round 2›
         let ζb2 = map2 (-) (map (get_party Party2) b) (permute_list πb (map (get_party Party2) a));

― ‹round 3›
         let ζc2 = (map (get_party Party2) c);

       let msg2 = ((ζa2, ζb2, ζc2), x', πa, πb);
       return_spmf (map (get_party Party2) x, msg2, c)
    })"
      apply (unfold random_perm_right)
      apply (unfold map_spmf_conv_bind_spmf pair_spmf_alt_def)
      by simp

    also have break_seqs_3:
    " = (do {
       xs  xs;
       let n = length xs;

       x2  sequence_spmf (replicate n (spmf_of_set UNIV));
       x3  sequence_spmf (replicate n (spmf_of_set UNIV));
       let x = map3 (λa b c. make_sharing' Party2 Party3 Party1 b c (a - (b + c))) xs x2 x3;


         π  spmf_of_set {π. π permutes {..<n}};
         πa  spmf_of_set {π. π permutes {..<n}};
         πb  spmf_of_set {π. π permutes {..<n}};

         a2  sequence_spmf (replicate n (spmf_of_set UNIV));
         a3  sequence_spmf (replicate n (spmf_of_set UNIV));
         let a = map3 (λa b c. make_sharing' Party2 Party3 Party1 b c (a - (b + c))) (permute_list πa xs) a2 a3;
         b2  sequence_spmf (replicate n (spmf_of_set UNIV));
         b3  sequence_spmf (replicate n (spmf_of_set UNIV));
         let b = map3 (λa b c. make_sharing' Party2 Party3 Party1 b c (a - (b + c))) (permute_list (πa  πb) xs) b2 b3;
         c  sequence_spmf (map share_nat (permute_list π xs));

― ‹round 1›
         let x' = map (get_party Party3) x;
         let ζa2 = map2 (-) (map (get_party Party2) a) (permute_list πa (map2 (+) (map (get_party Party2) x) (map (get_party Party3) x)));

― ‹round 2›
         let ζb2 = map2 (-) (map (get_party Party2) b) (permute_list πb (map (get_party Party2) a));

― ‹round 3›
         let ζc2 = (map (get_party Party2) c);

       let msg2 = ((ζa2, ζb2, ζc2), x', πa, πb);
       return_spmf (map (get_party Party2) x, msg2, c)
    })"
      apply (unfold sequence_share_nat_calc'[of Party2 Party3 Party1, simplified])
      apply (simp add: pair_spmf_alt_def Let_def)
      done

    also have break_seqs_3:
    " = (do {
       xs  xs;
       let n = length xs;

       x2  sequence_spmf (replicate n (spmf_of_set UNIV));
       x3  sequence_spmf (replicate n (spmf_of_set UNIV));

         π  spmf_of_set {π. π permutes {..<n}};
         πa  spmf_of_set {π. π permutes {..<n}};
         πb  spmf_of_set {π. π permutes {..<n}};

         a2::natL list  sequence_spmf (replicate n (spmf_of_set UNIV));
         a3::natL list  sequence_spmf (replicate n (spmf_of_set UNIV));

         b2::natL list  sequence_spmf (replicate n (spmf_of_set UNIV));
         b3::natL list  sequence_spmf (replicate n (spmf_of_set UNIV));

         c  sequence_spmf (map share_nat (permute_list π xs));

― ‹round 1›
         let x' = x3;
         let ζa2 = map2 (-) a2 (permute_list πa (map2 (+) x2 x3));

― ‹round 2›
         let ζb2 = map2 (-) b2 (permute_list πb a2);

― ‹round 3›
         let ζc2 = (map (get_party Party2) c);

       let msg2 = ((ζa2, ζb2, ζc2), x', πa, πb);
       return_spmf (x2, msg2, c)

    })"
    supply [intro!] =
      bind_spmf_cong[OF refl]
      let_cong[OF refl]
      prod.case_cong[OF refl]
      bind_spmf_sequence_map_cong
      bind_spmf_sequence_replicate_cong
      bind_spmf_permutes_cong
      unfolding Let_def
      apply rule+
      apply (auto simp: map2_ignore1 map2_ignore2 comp_def prod.case_distrib bind_spmf_const make_sharing'_sel)
      done

    also have
      " = (do {x  x_dist; y  aby3_shuffleF x; let xr = map (get_party r) x; let yr = map (get_party r) y; msg  S r xr yr; return_spmf (xr, msg, y)})"
    supply [intro!] =
      bind_spmf_cong[OF refl]
      let_cong[OF refl]
      prod.case_cong[OF refl]
      bind_spmf_sequence_map_cong
      bind_spmf_sequence_replicate_cong
      bind_spmf_permutes_cong

      unfolding xs
      unfolding aby3_shuffleF_def
      apply (simp add: bind_spmf_const map_spmf_conv_bind_spmf)
      apply (subst lossless_sequence_spmf[unfolded lossless_spmf_def])
      subgoal by simp
      apply (subst lossless_sequence_spmf[unfolded lossless_spmf_def])
      subgoal by simp
      apply simp

      apply rule

      apply (subst (2) sequence_share_nat_calc'[of Party2 Party3 Party1, simplified])
      apply (subst (2) Let_def)
      apply (simp add: pair_spmf_alt_def)
      apply (subst Let_def)
      unfolding S_def r
      apply (simp add: pair_spmf_alt_def comp_def prod.case_distrib map2_ignore2 make_sharing'_sel)

      apply (rule trans[rotated])
      apply (rule bind_spmf_sequence_replicate_cong)
       apply (rule bind_spmf_sequence_replicate_cong)
       apply (simp add: map2_ignore1 map2_ignore2)
      apply (subst bind_spmf_const)
      apply (subst lossless_sequence_spmf[unfolded lossless_spmf_def])
      subgoal by simp
      apply simp

      apply (subst (2) bind_commute_spmf[where p="sequence_spmf (replicate _ _)"])
      apply (subst bind_commute_spmf[where q="sequence_spmf (map share_nat _)"])
      apply (subst bind_commute_spmf[where q="sequence_spmf (map share_nat _)"])
      apply (subst bind_commute_spmf[where q="sequence_spmf (map share_nat _)"])
      apply (subst bind_commute_spmf[where q="sequence_spmf (map share_nat _)"])
      apply (subst bind_commute_spmf[where q="sequence_spmf (map share_nat _)"])

      apply rule
      apply rule
      apply rule

      unfolding S2_def Let_def
      apply simp
      done

    finally have "?hoisted = " .
  } note simulate_party2 = this


  { assume r: "r = Party3"
    have project_to_Party3:
    "?hoisted = (do {
       xs  xs;
       let n = length xs;
       x  sequence_spmf (map share_nat xs);

         πa  spmf_of_set {π. π permutes {..<n}};
         πb  spmf_of_set {π. π permutes {..<n}};
         πc  spmf_of_set {π. π permutes {..<n}};

         a  sequence_spmf (map share_nat (permute_list πa xs));
         b  sequence_spmf (map share_nat (permute_list (πa  πb) xs));
         c  sequence_spmf (map share_nat (permute_list (πa  πb  πc) xs));

― ‹round 1›
         let x' = map (get_party Party3) x;
         let y' = map (aby3_stack_sharing Party1) x;
         let ζa3 = map (case_prod (-)) (zip (map (get_party Party3) a) (map (get_party Party3) (permute_list πa y')));

― ‹round 2›
         let a' = map (get_party Party1) a;
         let y' = map (aby3_stack_sharing Party2) a;
         let ζb3 = map (case_prod (-)) (zip (map (get_party Party3) b) (map (get_party Party3) (permute_list πb y')));

― ‹round 3›
         let b' = map (get_party Party2) b;
         let y' = map (aby3_stack_sharing Party3) b;
         let ζc3 = map (case_prod (-)) (zip (map (get_party Party3) c) (map (get_party Party3) (permute_list πc y')));

       let msg3 = ((ζa3, ζb3, ζc3), a', πb, πc);
       return_spmf (map (get_party Party3) x, msg3, c)
   })"
      by (simp add: r Let_def get_party_map_sharing2 map_map_prod_zip')

    also have project_to_Party1:
    " = (do {
       xs  xs;
       let n = length xs;
       x  sequence_spmf (map share_nat xs);

         πa  spmf_of_set {π. π permutes {..<n}};
         πb  spmf_of_set {π. π permutes {..<n}};
         πc  spmf_of_set {π. π permutes {..<n}};

         a  sequence_spmf (map share_nat (permute_list πa xs));
         b  sequence_spmf (map share_nat (permute_list (πa  πb) xs));
         c  sequence_spmf (map share_nat (permute_list (πa  πb  πc) xs));

― ‹round 1›
         let x' = map (get_party Party3) x;
         let y' = map (aby3_stack_sharing Party1) x;
         let ζa3 = map2 (-) (map (get_party Party3) a) (permute_list πa (map (get_party Party3) y'));

― ‹round 2›
         let a' = map (get_party Party1) a;
         let y' = map (aby3_stack_sharing Party2) a;
         let ζb3 = map2 (-) (map (get_party Party3) b) (permute_list πb (map (get_party Party3) y'));

― ‹round 3›
         let b' = map (get_party Party2) b;
         let y' = map (aby3_stack_sharing Party3) b;
         let ζc3 = map2 (-) (map (get_party Party3) c) (permute_list πc (map (get_party Party3) y'));

       let msg3 = ((ζa3, ζb3, ζc3), a', πb, πc);
       return_spmf (map (get_party Party3) x, msg3, c)

    })"
    supply [intro!] =
      bind_spmf_cong[OF refl]
      let_cong[OF refl]
      prod.case_cong[OF refl]
      bind_spmf_sequence_map_cong
      bind_spmf_sequence_replicate_cong
      bind_spmf_permutes_cong

      apply rule+
      apply (subst permute_list_map[symmetric])
      subgoal by simp

      apply rule+
      apply (subst permute_list_map[symmetric])
      subgoal by simp

      apply rule+
      apply (subst permute_list_map[symmetric])
      subgoal by simp
      ..

    also have reduce_Lets:
    " = (do {
       xs  xs;
       let n = length xs;
       x  sequence_spmf (map share_nat xs);

         πa  spmf_of_set {π. π permutes {..<n}};
         πb  spmf_of_set {π. π permutes {..<n}};
         πc  spmf_of_set {π. π permutes {..<n}};

         a  sequence_spmf (map share_nat (permute_list πa xs));
         b  sequence_spmf (map share_nat (permute_list (πa  πb) xs));
         c  sequence_spmf (map share_nat (permute_list (πa  πb  πc) xs));

― ‹round 1›
         let ζa3 = map2 (-) (map (get_party Party3) a) (permute_list πa (replicate (length x) 0));

― ‹round 2›
         let a' = map (get_party Party1) a;
         let ζb3 = map2 (-) (map (get_party Party3) b) (permute_list πb (map2 (+) (map (get_party Party3) a) (map (get_party Party1) a)));

― ‹round 3›
         let ζc3 = map2 (-) (map (get_party Party3) c) (permute_list πc (map (get_party Party3) b));

       let msg3 = ((ζa3, ζb3, ζc3), a', πb, πc);
       return_spmf (map (get_party Party3) x, msg3, c)
    })"
      unfolding Let_def
      unfolding aby3_stack_sharing_def
      by (simp add: comp_def make_sharing'_sel map_replicate_const zip_map_map_same[symmetric])

    also have simplify_minus_zero:
    " = (do {
       xs  xs;
       let n = length xs;
       x  sequence_spmf (map share_nat xs);

         πa  spmf_of_set {π. π permutes {..<n}};
         πb  spmf_of_set {π. π permutes {..<n}};
         πc  spmf_of_set {π. π permutes {..<n}};

         a  sequence_spmf (map share_nat (permute_list πa xs));
         b  sequence_spmf (map share_nat (permute_list (πa  πb) xs));
         c  sequence_spmf (map share_nat (permute_list (πa  πb  πc) xs));

― ‹round 1›
         let ζa3 = (map (get_party Party3) a);

― ‹round 2›
         let a' = map (get_party Party1) a;
         let ζb3 = map2 (-) (map (get_party Party3) b) (permute_list πb (map2 (+) (map (get_party Party3) a) (map (get_party Party1) a)));

― ‹round 3›
         let ζc3 = map2 (-) (map (get_party Party3) c) (permute_list πc (map (get_party Party3) b));

       let msg3 = ((ζa3, ζb3, ζc3), a', πb, πc);
       return_spmf (map (get_party Party3) x, msg3, c)
    })"
    supply [intro!] =
      bind_spmf_cong[OF refl]
      let_cong[OF refl]
      prod.case_cong[OF refl]
      bind_spmf_sequence_map_cong
      bind_spmf_sequence_replicate_cong
      bind_spmf_permutes_cong

      apply rule+
      apply (subst permute_list_replicate)
      subgoal by simp

      apply (subst map2_minus_zero)
      subgoal by simp
      subgoal by simp

      ..

    also have break_perms_1:
    " = (do {
       xs  xs;
       let n = length xs;
       x  sequence_spmf (map share_nat xs);

         (πa,πb,πc)  pair_spmf (spmf_of_set {π. π permutes {..<n}}) (pair_spmf (spmf_of_set {π. π permutes {..<n}}) (spmf_of_set {π. π permutes {..<n}}));

         a  sequence_spmf (map share_nat (permute_list πa xs));
         b  sequence_spmf (map share_nat (permute_list (πa  πb) xs));
         c  sequence_spmf (map share_nat (permute_list (πa  πb  πc) xs));

― ‹round 1›
         let ζa3 = (map (get_party Party3) a);

― ‹round 2›
         let a' = map (get_party Party1) a;
         let ζb3 = map2 (-) (map (get_party Party3) b) (permute_list πb (map2 (+) (map (get_party Party3) a) (map (get_party Party1) a)));

― ‹round 3›
         let ζc3 = map2 (-) (map (get_party Party3) c) (permute_list πc (map (get_party Party3) b));

       let msg3 = ((ζa3, ζb3, ζc3), a', πb, πc);
       return_spmf (map (get_party Party3) x, msg3, c)
    })"
      unfolding pair_spmf_alt_def by simp

    also have break_perms_2:
    " = (do {
       xs  xs;
       let n = length xs;
       x  sequence_spmf (map share_nat xs);

         ((πa,πb,πc),π)  map_spmf (λ(πa,πb,πc). ((πa,πb,πc), πa  πb  πc)) (pair_spmf (spmf_of_set {π. π permutes {..<n}}) (pair_spmf (spmf_of_set {π. π permutes {..<n}}) (spmf_of_set {π. π permutes {..<n}})));

         a  sequence_spmf (map share_nat (permute_list πa xs));
         b  sequence_spmf (map share_nat (permute_list (πa  πb) xs));
         c  sequence_spmf (map share_nat (permute_list π xs));

― ‹round 1›
         let ζa3 = (map (get_party Party3) a);

― ‹round 2›
         let a' = map (get_party Party1) a;
         let ζb3 = map2 (-) (map (get_party Party3) b) (permute_list πb (map2 (+) (map (get_party Party3) a) (map (get_party Party1) a)));

― ‹round 3›
         let ζc3 = map2 (-) (map (get_party Party3) c) (permute_list πc (map (get_party Party3) b));

       let msg3 = ((ζa3, ζb3, ζc3), a', πb, πc);
       return_spmf (map (get_party Party3) x, msg3, c)
    })"
      unfolding pair_spmf_alt_def map_spmf_conv_bind_spmf by simp

    also have break_perms_3:
    " = (do {
       xs  xs;
       let n = length xs;
       x  sequence_spmf (map share_nat xs);

         π  spmf_of_set {π. π permutes {..<n}};
         πb  spmf_of_set {π. π permutes {..<n}};
         πc  spmf_of_set {π. π permutes {..<n}};
         let πa = π  inv πc  inv πb;

         a  sequence_spmf (map share_nat (permute_list πa xs));
         b  sequence_spmf (map share_nat (permute_list (πa  πb) xs));
         c  sequence_spmf (map share_nat (permute_list π xs));

― ‹round 1›
         let ζa3 = (map (get_party Party3) a);

― ‹round 2›
         let a' = map (get_party Party1) a;
         let ζb3 = map2 (-) (map (get_party Party3) b) (permute_list πb (map2 (+) (map (get_party Party3) a) (map (get_party Party1) a)));

― ‹round 3›
         let ζc3 = map2 (-) (map (get_party Party3) c) (permute_list πc (map (get_party Party3) b));

       let msg3 = ((ζa3, ζb3, ζc3), a', πb, πc);
       return_spmf (map (get_party Party3) x, msg3, c)
    })"
      apply (unfold random_perm_left)
      apply (unfold map_spmf_conv_bind_spmf pair_spmf_alt_def)
      by (simp add: Let_def)

    also have break_seqs_3:
    " = (do {
       xs  xs;
       let n = length xs;
       x  sequence_spmf (map share_nat xs);

         π  spmf_of_set {π. π permutes {..<n}};
         πb  spmf_of_set {π. π permutes {..<n}};
         πc  spmf_of_set {π. π permutes {..<n}};
         let πa = π  inv πc  inv πb;

         a3  sequence_spmf (replicate n (spmf_of_set UNIV));
         a1  sequence_spmf (replicate n (spmf_of_set UNIV));
         let a = map3 (λa b c. make_sharing' Party3 Party1 Party2 b c (a - (b + c))) (permute_list πa xs) a3 a1;
         b3  sequence_spmf (replicate n (spmf_of_set UNIV));
         b1  sequence_spmf (replicate n (spmf_of_set UNIV));
         let b = map3 (λa b c. make_sharing' Party3 Party1 Party2 b c (a - (b + c))) (permute_list (πa  πb) xs) b3 b1;
         c  sequence_spmf (map share_nat (permute_list π xs));

― ‹round 1›
         let ζa3 = (map (get_party Party3) a);

― ‹round 2›
         let a' = map (get_party Party1) a;
         let ζb3 = map2 (-) (map (get_party Party3) b) (permute_list πb (map2 (+) (map (get_party Party3) a) (map (get_party Party1) a)));

― ‹round 3›
         let ζc3 = map2 (-) (map (get_party Party3) c) (permute_list πc (map (get_party Party3) b));

       let msg3 = ((ζa3, ζb3, ζc3), a', πb, πc);
       return_spmf (map (get_party Party3) x, msg3, c)
    })"
      apply (unfold sequence_share_nat_calc'[of Party3 Party1 Party2, simplified])
      apply (simp add: pair_spmf_alt_def Let_def)
      done

    also have break_seqs_3:
    " = (do {
       xs  xs;
       let n = length xs;

       x  sequence_spmf (map share_nat xs);

         π  spmf_of_set {π. π permutes {..<n}};
         πb  spmf_of_set {π. π permutes {..<n}};
         πc  spmf_of_set {π. π permutes {..<n}};
         let πa = π  inv πc  inv πb;

         a3::natL list  sequence_spmf (replicate n (spmf_of_set UNIV));
         a1::natL list  sequence_spmf (replicate n (spmf_of_set UNIV));

         b3::natL list  sequence_spmf (replicate n (spmf_of_set UNIV));
         b1::natL list  sequence_spmf (replicate n (spmf_of_set UNIV));

         c  sequence_spmf (map share_nat (permute_list π xs));

― ‹round 1›
         let ζa3 = a3;

― ‹round 2›
         let a' = a1;
         let ζb3 = map2 (-) b3 (permute_list πb (map2 (+) a3 a1));

― ‹round 3›
         let ζc3 = map2 (-) (map (get_party Party3) c) (permute_list πc b3);

       let msg3 = ((ζa3, ζb3, ζc3), a', πb, πc);
       return_spmf (map (get_party Party3) x, msg3, c)
    })"
    supply [intro!] =
      bind_spmf_cong[OF refl]
      let_cong[OF refl]
      prod.case_cong[OF refl]
      bind_spmf_sequence_map_cong
      bind_spmf_sequence_replicate_cong
      bind_spmf_permutes_cong
      unfolding Let_def
      apply rule+
      apply (auto simp: map2_ignore1 map2_ignore2 comp_def prod.case_distrib bind_spmf_const make_sharing'_sel)
      done

    also have
      " = (do {x  x_dist; y  aby3_shuffleF x; let xr = map (get_party r) x; let yr = map (get_party r) y; msg  S r xr yr; return_spmf (xr, msg, y)})"
      unfolding xs
      unfolding aby3_shuffleF_def
      apply (subst bind_spmf_const)
      apply (subst lossless_sequence_spmf[unfolded lossless_spmf_def])
      subgoal by simp
      apply (simp add:  map_spmf_conv_bind_spmf)
      apply (subst bind_commute_spmf[where q="sequence_spmf (map share_nat _)"])
      apply (subst bind_commute_spmf[where q="sequence_spmf (map share_nat _)"])
      apply (subst bind_commute_spmf[where q="sequence_spmf (map share_nat _)"])
      apply (subst bind_commute_spmf[where q="sequence_spmf (map share_nat _)"])
      apply (subst bind_commute_spmf[where q="sequence_spmf (map share_nat _)"])

      apply (subst (3) hoist_map_spmf[where s="sequence_spmf (map share_nat _)" and f="map reconstruct"])
      apply (subst map_sequence_share_nat_reconstruct)
      apply (simp add: map_spmf_conv_bind_spmf)
      apply (subst Let_def)

    supply [intro!] =
      bind_spmf_cong[OF refl]
      let_cong[OF refl]
      prod.case_cong[OF refl]
      bind_spmf_sequence_map_cong
      bind_spmf_sequence_replicate_cong
      bind_spmf_permutes_cong

    supply [simp] = finite_permutations

      apply rule
      apply rule
      apply simp
      apply rule
      apply rule
      unfolding S_def S3_def r
      apply (simp add: Let_def)
      done

    finally have "?hoisted = " .
  } note simulate_party3 = this

  show ?thesis
    unfolding hoisted_save
    apply (cases r)
    subgoal using simulate_party1 .
    subgoal using simulate_party2 .
    subgoal using simulate_party3 .
    done
qed


(* Rewrites a set comprehension over two bound variables, {f x y | x y. P x y},
   into the image of the uncurried function (case_prod f) over the set of pairs
   satisfying the uncurried predicate (case_prod P).  This puts the set into
   image form, so that image lemmas such as finite_image_iff and card_image
   apply (as in spmf_of_set_Cons below). *)
lemma Collect_case_prod:
  "{f x y | x y. P x y} = (case_prod f) ` (Collect (case_prod P))"
  by auto

(* The uncurried list constructor λ(n, xs). n # xs is injective on an arbitrary
   set X of (head, tail) pairs: two non-empty lists are equal only if their
   heads and tails are equal. *)
lemma inj_split_Cons': "inj_on (λ(n, xs). n#xs) X"
  by (auto intro!: inj_onI)

(* For a finite set A, the real-valued indicator of A at x equals the sum, over
   all y ∈ A, of the indicator of the singleton {x} at y.  At most one summand
   (the one with y = x) is non-zero, so the sum is 1 exactly when x ∈ A.
   Proved by induction over the finite set A.
   NOTE(review): the premise arrow ⟹ was missing in the statement, which made
   it ill-typed; it is restored here. *)
lemma finite_indicator_eq_sum:
  "finite A ⟹ indicat_real A x = sum (indicat_real {x}) A"
  by (induction rule: finite_induct) (auto simp: indicator_def)

(* Uniform sampling from set_Cons A B agrees with the right-hand side below.
   That side draws a head uniformly from A and, independently, a tail uniformly
   from B, then conses them.  The two draws are combined with map2_spmf;
   pair_spmf_of_set is unfolded in the proof to merge them into one uniform
   draw over pairs.
   No finiteness assumption is made on A or B. *)
lemma spmf_of_set_Cons:
  "spmf_of_set (set_Cons A B) = map2_spmf (#) (spmf_of_set A) (spmf_of_set B)"
  unfolding set_Cons_def pair_spmf_of_set
  (* Split the equality into two parts: equal supports (first subgoal) and
     equal probability of every outcome (second subgoal). *)
  apply (rule spmf_eq_iff_set)
  (* Supports.  After rewriting the comprehension as an image, injectivity of
     consing (inj_split_Cons') reduces finiteness of the image to finiteness
     of the Cartesian product of A and B. *)
  subgoal unfolding Collect_case_prod apply (auto simp: set_spmf_of_set )
     apply (subst (asm) finite_image_iff)
    subgoal by (rule inj_split_Cons')
    subgoal by (auto simp: finite_cartesian_product_iff)
    done
  (* Point probabilities.  The card of the image equals the card of the pair
     set by injectivity (card_image).  The indicator is then rewritten as a sum
     (finite_indicator_eq_sum), which is reindexed along the same injective
     map. *)
  subgoal unfolding Collect_case_prod
    apply (auto simp: spmf_of_set map_spmf_conv_bind_spmf spmf_bind integral_spmf_of_set)
    apply (subst card_image)
    subgoal by (rule inj_split_Cons')
    apply (auto simp: card_eq_0_iff indicator_single_Some)
    apply (subst (asm) finite_indicator_eq_sum)
    subgoal by (simp add: finite_image_iff inj_split_Cons')
    apply (subst (asm) sum.reindex)
    subgoal by (simp add: finite_image_iff inj_split_Cons')
    apply (auto)
    done
  done

(* Sequencing n independent uniform draws from A yields the uniform
   distribution on listset (replicate n A).  By listset_replicate, that is the
   set of lists of length n whose elements all lie in A.
   Proof by induction on n:
   - Base case: the empty sequence is a point distribution, handled by
     spmf_of_set_singleton.
   - Step case: adds one more independent head draw, handled by
     spmf_of_set_Cons. *)
lemma sequence_spmf_replicate:
  "sequence_spmf (replicate n (spmf_of_set A)) = spmf_of_set (listset (replicate n A))"
  apply (induction n)
  subgoal by (auto simp: spmf_of_set_singleton)
  subgoal by (auto simp: spmf_of_set_Cons)
  done

(* Explicit characterisation of listset (replicate n A): exactly the lists of
   length n all of whose elements belong to A.
   Proof by induction on n, unfolding set_Cons_def.  The one goal auto leaves
   open is closed by case analysis on the list x.
   NOTE(review): the connectives ∧ and ⊆ were missing in the set comprehension,
   leaving an ill-formed term; they are restored here. *)
lemma listset_replicate:
  "listset (replicate n A) = {l. length l = n ∧ set l ⊆ A}"
  apply (induction n)
   apply (auto simp: set_Cons_def)
  subgoal for n x
    by (cases x; simp)
  done

(* Fuses a nested map2 into a single map3.  Combining (map2 g x y) with z
   through f is the same as mapping the three-argument function
   λx y z. f (g x y) z over x, y and z together.  map3 itself is defined
   outside this chunk. *)
lemma map2_map2_map3:
  "map2 f (map2 g x y) z = map3 (λx y. f (g x y)) x y z"
  by (auto simp: zip_assoc map_zip_map)

(* Translation invariance of a uniform mask.  Let ζ be a uniformly random natL
   list of length n = length x.  Then the pair (ζ, ζ + x), with elementwise
   addition, has the same distribution as (y - x, y) for a uniformly random
   list y of the same length.
   It is used in symmetric form by the S*_def_simplified lemmas below.
   The image-set comparison needs (ζ + x) - x = ζ and (y - x) + x = y
   elementwise on natL.  natL is defined outside this chunk; presumably it is
   a group under +/- (TODO confirm). *)
lemma inv_add_sequence:
  assumes "n = length x"
  shows "
  map_spmf (λζ::natL list. (ζ, map2 (+) ζ x)) (sequence_spmf (replicate n (spmf_of_set UNIV)))
  =
  map_spmf (λy. (map2 (-) y x, y)) (sequence_spmf (replicate n (spmf_of_set UNIV)))"
  (* Express both sides as uniform distributions over listset. *)
  unfolding sequence_spmf_replicate
  (* Both mapped functions are injective: one component of each pair is the
     sampled list itself.  So map_spmf can be pushed inside spmf_of_set. *)
  apply (subst map_spmf_of_set_inj_on)
  subgoal   unfolding inj_on_def by simp
  apply (subst map_spmf_of_set_inj_on)
  subgoal  unfolding inj_on_def by simp
  (* It remains to show that the two image sets coincide, using the length
     assumption. *)
  apply (rule arg_cong[where f="spmf_of_set"])
  using assms apply (auto simp: image_def listset_replicate map2_map2_map3 zip_same_conv_map map_zip_map2 map2_ignore2)
  done
  
(* Simplified form of the simulator S1; S1_def itself is outside this chunk.
   Of the argument x1, only its length n is used.  The simulator:
   - samples two uniformly random permutations πa, πc of {..<n};
   - samples three uniformly random natL lists ζa1, yb1, yb2 of length n;
   - sets ζc1 = yc1 - permute_list πc (yb1 + yb2), elementwise.
   It returns ((ζa1, yb1, ζc1), yb2, πa, πc), which has the shape of role_msg.
   Proof: unfold S1_def, descend through the shared binds with the congruence
   rules supplied below, then apply inv_add_sequence.
   NOTE(review): the ← bind arrows were missing in the do-block of the
   statement; they are restored here. *)
lemma S1_def_simplified:
    "S1 x1 yc1 = (do {
       let n = length x1;

       πa ← spmf_of_set {π. π permutes {..<n}};
       πc ← spmf_of_set {π. π permutes {..<n}};
       ζa1::natL list ← sequence_spmf (replicate n (spmf_of_set UNIV));
       yb1::natL list ← sequence_spmf (replicate n (spmf_of_set UNIV));
       yb2::natL list ← sequence_spmf (replicate n (spmf_of_set UNIV));

       let ζc1 = map2 (-) (yc1) (permute_list πc (map2 (+) yb1 yb2));
       return_spmf ((ζa1, yb1, ζc1), yb2, πa, πc)
    })"

  unfolding S1_def

    supply [intro!] =
      bind_spmf_cong[OF refl]
      let_cong[OF refl]
      prod.case_cong[OF refl]
      bind_spmf_sequence_map_cong
      bind_spmf_sequence_replicate_cong
      bind_spmf_permutes_cong
      bind_spmf_sequence_map_share_nat_cong

  apply rule
  apply rule
  apply rule
  (* Rewrite with hoist_map_spmf' (defined elsewhere) to expose a map_spmf over
     a uniformly sampled list, so that inv_add_sequence can be applied. *)
  apply (subst hoist_map_spmf'[where s="sequence_spmf _" and f="λx. map2 (-) x _"])
  apply (subst inv_add_sequence[symmetric])
  subgoal by simp
  unfolding map_spmf_conv_bind_spmf
  apply simp
  done

(* Simplified form of the simulator S2; S2_def itself is outside this chunk.
   Of the argument x2, only its length n is used.  The simulator:
   - samples a uniformly random natL list x3 of length n;
   - samples two uniformly random permutations πa, πb of {..<n};
   - samples two uniformly random lists ζa2, ζb2 of length n;
   - forwards yc2 unchanged as the third component of the message triple.
   It returns ((ζa2, ζb2, yc2), x3, πa, πb).
   Proof: unfold S2_def, descend through the shared binds with the congruence
   rules supplied below, then apply inv_add_sequence twice.
   NOTE(review): the ← bind arrows were missing in the do-block of the
   statement; they are restored here. *)
lemma S2_def_simplified:
    "S2 x2 yc2 = (do {
       let n = length x2;

       x3 ← sequence_spmf (replicate n (spmf_of_set UNIV));
       πa ← spmf_of_set {π. π permutes {..<n}};
       πb ← spmf_of_set {π. π permutes {..<n}};
       ζa2::natL list ← sequence_spmf (replicate n (spmf_of_set UNIV));
       ζb2::natL list ← sequence_spmf (replicate n (spmf_of_set UNIV));

       let msg2 = ((ζa2, ζb2, yc2), x3, πa, πb);
       return_spmf msg2
    })"
  unfolding S2_def

    supply [intro!] =
      bind_spmf_cong[OF refl]
      let_cong[OF refl]
      prod.case_cong[OF refl]
      bind_spmf_sequence_map_cong
      bind_spmf_sequence_replicate_cong
      bind_spmf_permutes_cong
      bind_spmf_sequence_map_share_nat_cong

  apply rule
  apply rule
  apply rule
  apply rule
  apply (rule trans)
  apply (rule bind_spmf_sequence_replicate_cong)

  (* First application: rewrite with hoist_map_spmf' (defined elsewhere), then
     apply inv_add_sequence. *)
  apply (subst hoist_map_spmf'[where s="sequence_spmf _" and f="λx. map2 (-) x _"])
  apply (subst inv_add_sequence[symmetric])
  subgoal by simp
   apply (rule refl)
  apply (unfold map_spmf_conv_bind_spmf)
  apply simp
  (* Second application, for the remaining uniformly sampled list. *)
  apply (subst hoist_map_spmf'[where s="sequence_spmf _" and f="λx. map2 (-) x _"])
  apply (subst inv_add_sequence[symmetric])
  subgoal by simp
  apply (unfold map_spmf_conv_bind_spmf)
  apply simp
  done

(* Simplified form of the simulator S3; S3_def itself is outside this chunk.
   Of the argument x3, only its length n is used.  The simulator:
   - samples two uniformly random permutations πb, πc of {..<n};
   - samples three uniformly random natL lists ya3, ya1, ζb3 of length n;
   - sets ζc3 = yc3 - permute_list πc (ζb3 + permute_list πb (ya3 + ya1)),
     elementwise.
   It returns ((ya3, ζb3, ζc3), ya1, πb, πc).
   Proof: unfold S3_def, descend through the shared binds with the congruence
   rules supplied below, then apply inv_add_sequence.
   NOTE(review): the ← bind arrows were missing in the do-block of the
   statement; they are restored here. *)
lemma S3_def_simplified:
    "S3 x3 yc3 = (do {
       let n = length x3;

       πb ← spmf_of_set {π. π permutes {..<n}};
       πc ← spmf_of_set {π. π permutes {..<n}};
       ya3::natL list ← sequence_spmf (replicate n (spmf_of_set UNIV));
       ya1::natL list ← sequence_spmf (replicate n (spmf_of_set UNIV));
       ζb3::natL list ← sequence_spmf (replicate n (spmf_of_set UNIV));

       let ζc3 = map2 (-) yc3 (permute_list πc (map2 (+) ζb3 (permute_list πb (map2 (+) ya3 ya1))));
       return_spmf ((ya3, ζb3, ζc3), ya1, πb, πc)
    })"

  unfolding S3_def

    supply [intro!] =
      bind_spmf_cong[OF refl]
      let_cong[OF refl]
      prod.case_cong[OF refl]
      bind_spmf_sequence_map_cong
      bind_spmf_sequence_replicate_cong
      bind_spmf_permutes_cong
      bind_spmf_sequence_map_share_nat_cong

  apply rule
  apply rule
  apply rule
  apply rule
  apply rule
  (* Rewrite with hoist_map_spmf' (defined elsewhere) to expose a map_spmf over
     a uniformly sampled list, so that inv_add_sequence can be applied. *)
  apply (subst hoist_map_spmf'[where s="sequence_spmf _" and f="λx. map2 (-) x _"])
  apply (subst inv_add_sequence[symmetric])
  subgoal by simp
  apply (unfold map_spmf_conv_bind_spmf)
  apply simp
  done

end