(* Theory Smolka_AI *)

theory Smolka_AI
  imports Main
begin

section ‹AI-Authored proof formalization of the roundtrip algorithm›

text ‹
  Formalization of the Smolka-Blanchette printing--parsing roundtrip algorithm for Isabelle.

  The algorithm takes a fully-typed term and selects a locally minimal and complete
  set of type annotations so the term can be unambiguously re-parsed via
  Hindley--Milner type inference.
›

subsection ‹Types and Type Substitutions›

text ‹Types are built from type variables and type constructors applied to a list of
  argument types. The arity of a constructor is not enforced by the datatype; it is
  implicit in the length of the argument list.›

datatype ty =
  TVar string
| TCons string "ty list"

text ‹We define the function type constructor as a distinguished binary constructor,
  written infix as ‹τ1 → τ2› (right-associative, priority 65).›

definition fun_ty :: "ty ⇒ ty ⇒ ty" (infixr "→" 65) where
  "fun_ty τ1 τ2 = TCons ''fun'' [τ1, τ2]"

text ‹Type variables occurring in a type.›

fun tvars_ty :: "ty ⇒ string set" where
  "tvars_ty (TVar v) = {v}"
| "tvars_ty (TCons _ ts) = ⋃(set (map tvars_ty ts))"

text ‹Definition (Type Substitution). A type substitution is a function
  from variable names to types. It extends homomorphically to types:
  variables are replaced by their image, constructors are kept.›

fun subst_ty :: "(string ⇒ ty) ⇒ ty ⇒ ty" where
  "subst_ty σ (TVar v) = σ v"
| "subst_ty σ (TCons k ts) = TCons k (map (subst_ty σ) ts)"

text ‹The arrow type ‹τ → τ› is strictly larger than ‹τ›,
  hence ‹τ → τ ≠ τ›. This is used in the minimality proof.
  The proof uses the built-in ‹size› function from the datatype package.›

lemma arrow_neq_self: "τ → τ ≠ τ"
proof
  assume "τ → τ = τ"
  then have "size (τ → τ) = size τ" by simp
  then show False by (simp add: fun_ty_def)
qed

text ‹Lemma (Uniqueness of Type Matching). If two substitutions agree
  when applied to a type, they agree on all type variables of that type.›

lemma unique_type_match:
  "subst_ty σ1 τ = subst_ty σ2 τ ⟹ α ∈ tvars_ty τ ⟹ σ1 α = σ2 α"
  by (induction τ) auto

text ‹Type variables of a substituted type: they are exactly the type variables
  of the images \<^term>‹σ v› of the variables \<^term>‹v ∈ tvars_ty τ›.›

lemma tvars_subst_ty: "tvars_ty (subst_ty σ τ) = ⋃(tvars_ty ` σ ` tvars_ty τ)"
  by (induction τ) auto

text ‹A substitution that fixes every variable of a type leaves it unchanged;
  substitutions that agree on the variables of a type give the same result;
  and substitution composes.›

lemma subst_ty_id:
  "(⋀v. v ∈ tvars_ty τ ⟹ σ v = TVar v) ⟹ subst_ty σ τ = τ"
  by (induction τ) (auto intro: map_idI)

lemma subst_ty_agree:
  "(⋀v. v ∈ tvars_ty τ ⟹ σ1 v = σ2 v) ⟹ subst_ty σ1 τ = subst_ty σ2 τ"
  by (induction τ) auto

lemma subst_ty_compose:
  "subst_ty σ1 (subst_ty σ2 τ) = subst_ty (λv. subst_ty σ1 (σ2 v)) τ"
  by (induction τ) auto

subsection ‹Contexts›

text ‹A context is a partial function (a map) from variable names to types.
  The type does not enforce that its domain is finite.›

type_synonym ctx = "string ⇀ ty"

text ‹Type variables of a context: those occurring in any type in its range.›

definition tvars_ctx :: "ctx ⇒ string set" where
  "tvars_ctx Γ = ⋃(tvars_ty ` ran Γ)"

text ‹Applying a type substitution to a context, pointwise on the types it maps to.›

definition subst_ctx :: "(string ⇒ ty) ⇒ ctx ⇒ ctx" where
  "subst_ctx σ Γ = map_option (subst_ty σ) ∘ Γ"

lemma subst_ctx_dom [simp]: "dom (subst_ctx σ Γ) = dom Γ"
  by (auto simp: subst_ctx_def dom_def)

lemma subst_ctx_app: "x ∈ dom Γ ⟹ subst_ctx σ Γ x = Some (subst_ty σ (the (Γ x)))"
  by (auto simp: subst_ctx_def dom_def)

lemma subst_ctx_update:
  "subst_ctx σ (Γ(x ↦ τ)) = (subst_ctx σ Γ)(x ↦ subst_ty σ τ)"
  by (auto simp: subst_ctx_def)

lemma subst_ctx_id:
  "(⋀v. v ∈ tvars_ctx Γ ⟹ σ v = TVar v) ⟹ subst_ctx σ Γ = Γ"
  using ranI by (fastforce intro!: map_option_idI subst_ty_id simp: subst_ctx_def tvars_ctx_def)

subsection ‹Terms›

subsubsection ‹Raw Terms›

text ‹Raw terms include type constraints and binder constraints as annotations.›

datatype raw_term =
  RConst string
| RVar string
| RAbs string raw_term
| RAbsT string ty raw_term
| RApp raw_term raw_term
| RConstrain raw_term ty

text ‹The function ‹strip› removes all constraints from a raw term: type constraints
  (‹RConstrain›) are dropped, and constrained binders (‹RAbsT›) become plain binders (‹RAbs›).›

fun strip :: "raw_term ⇒ raw_term" where
  "strip (RConst c) = RConst c"
| "strip (RVar x) = RVar x"
| "strip (RAbs x t) = RAbs x (strip t)"
| "strip (RAbsT x _ t) = RAbs x (strip t)"
| "strip (RApp t1 t2) = RApp (strip t1) (strip t2)"
| "strip (RConstrain t _) = strip t"

text ‹A term is constraint-free if it contains no ‹RAbsT› and no ‹RConstrain› node.
  This is meant to characterise the terms on which ‹strip› acts as the identity.›

fun constraint_free :: "raw_term ⇒ bool" where
  "constraint_free (RConst _) = True"
| "constraint_free (RVar _) = True"
| "constraint_free (RAbs _ t) = constraint_free t"
| "constraint_free (RAbsT _ _ _) = False"
| "constraint_free (RApp t1 t2) = (constraint_free t1 ∧ constraint_free t2)"
| "constraint_free (RConstrain _ _) = False"

lemma strip_idem [simp]: "strip (strip t) = strip t"
  by (induction t) auto

lemma strip_constraint_free: "constraint_free (strip t)"
  by (induction t) auto

subsubsection ‹Annotated Terms›

text ‹Definition (Annotated Terms). Every node carries a type annotation.
  An abstraction ‹AAbs x τ a τ'› carries two: the binder type ‹τ› and the
  type ‹τ'› of the whole abstraction.›

datatype aterm =
  AConst string ty
| AVar string ty
| AAbs string ty aterm ty
| AApp aterm aterm ty

text ‹Definition (Type of an Annotated Term). Reads the outermost type.›

fun typeof_aterm :: "aterm ⇒ ty" where
  "typeof_aterm (AConst _ τ) = τ"
| "typeof_aterm (AVar _ τ) = τ"
| "typeof_aterm (AAbs _ _ _ τ') = τ'"
| "typeof_aterm (AApp _ _ τ) = τ"

text ‹Definition (Erasure). Strips all types from an annotated term,
  producing a constraint-free raw term.›

fun erase :: "aterm ⇒ raw_term" where
  "erase (AConst c _) = RConst c"
| "erase (AVar x _) = RVar x"
| "erase (AAbs x _ a _) = RAbs x (erase a)"
| "erase (AApp a1 a2 _) = RApp (erase a1) (erase a2)"

lemma erase_constraint_free: "constraint_free (erase a)"
  by (induction a) auto

lemma strip_erase [simp]: "strip (erase a) = erase a"
  by (induction a) auto

text ‹Definition (Type Substitution on Annotated Terms).›

fun subst_aterm :: "(string ⇒ ty) ⇒ aterm ⇒ aterm" where
  "subst_aterm σ (AConst c τ) = AConst c (subst_ty σ τ)"
| "subst_aterm σ (AVar x τ) = AVar x (subst_ty σ τ)"
| "subst_aterm σ (AAbs x τ a τ') = AAbs x (subst_ty σ τ) (subst_aterm σ a) (subst_ty σ τ')"
| "subst_aterm σ (AApp a1 a2 τ) = AApp (subst_aterm σ a1) (subst_aterm σ a2) (subst_ty σ τ)"

text ‹Key property: erasure is invariant under type substitution.›

lemma erase_subst [simp]: "erase (subst_aterm σ a) = erase a"
  by (induction a) auto

text ‹Key property: \<^term>‹typeof_aterm› commutes with type substitution.›

lemma typeof_subst [simp]: "typeof_aterm (subst_aterm σ a) = subst_ty σ (typeof_aterm a)"
  by (cases a) auto

text ‹Type variables of an annotated term.›

fun tvars_aterm :: "aterm ⇒ string set" where
  "tvars_aterm (AConst _ τ) = tvars_ty τ"
| "tvars_aterm (AVar _ τ) = tvars_ty τ"
| "tvars_aterm (AAbs _ τ a τ') = tvars_ty τ ∪ tvars_aterm a ∪ tvars_ty τ'"
| "tvars_aterm (AApp a1 a2 τ) = tvars_aterm a1 ∪ tvars_aterm a2 ∪ tvars_ty τ"

text ‹If two substitutions agree on the type variables of a term, they give the same result.›

lemma subst_aterm_agree:
  "(⋀v. v ∈ tvars_aterm a ⟹ σ1 v = σ2 v) ⟹ subst_aterm σ1 a = subst_aterm σ2 a"
  by (induction a) (auto intro: subst_ty_agree)

text ‹Converse: if two substitutions give the same annotated term, they agree on the
  type variables of that term. In other words, an instance determines the substitution
  on the variables of the term it was instantiated from.›

lemma subst_aterm_injective:
  "subst_aterm σ1 a = subst_aterm σ2 a ⟹ α ∈ tvars_aterm a ⟹ σ1 α = σ2 α"
  using unique_type_match by (induction a arbitrary: α) auto

lemma subst_ty_TVar [simp]: "subst_ty TVar τ = τ"
  by (induction τ) (auto simp: map_idI)

lemma subst_aterm_id:
  "(⋀v. v ∈ tvars_aterm a ⟹ σ v = TVar v) ⟹ subst_aterm σ a = a"
  by (induction a) (auto intro: subst_ty_id)

subsubsection ‹Well-Typedness›

text ‹We parameterize the development by a function giving the declared type scheme
  of each constant. This is the type scheme that must be instantiated at each use site.›

text ‹An annotated term is well-typed in a context if each node satisfies its typing
  constraint:
  ▪ a constant's annotation is an instance of its declared scheme;
  ▪ a variable's annotation is its type in the context;
  ▪ an abstraction is annotated with ‹τ → typeof_aterm a›, where ‹τ› is the binder type
    and the body is well-typed in the extended context;
  ▪ in an application, the function part has type ‹typeof_aterm a2 → τ›.›

inductive well_typed :: "(string ⇒ ty) ⇒ ctx ⇒ aterm ⇒ bool" where
  wt_const: "∃ρ. τ = subst_ty ρ (const_type c) ⟹
    well_typed const_type Γ (AConst c τ)"
| wt_var: "Γ x = Some τ ⟹
    well_typed const_type Γ (AVar x τ)"
| wt_abs: "well_typed const_type (Γ(x ↦ τ)) a ⟹
    τ' = τ → typeof_aterm a ⟹
    well_typed const_type Γ (AAbs x τ a τ')"
| wt_app: "well_typed const_type Γ a1 ⟹
    well_typed const_type Γ a2 ⟹
    typeof_aterm a1 = typeof_aterm a2 → τ ⟹
    well_typed const_type Γ (AApp a1 a2 τ)"

text ‹Lemma (Substitution Preserves Well-Typedness). If \<^term>‹a› is well-typed in \<^term>‹Γ›,
  then \<^term>‹subst_aterm σ a› is well-typed in \<^term>‹subst_ctx σ Γ›.›

lemma subst_preserves_wt:
  "well_typed ct Γ a ⟹ well_typed ct (subst_ctx σ Γ) (subst_aterm σ a)"
proof (induction rule: well_typed.induct)
  case (wt_const τ ct c Γ)
  then obtain ρ where "τ = subst_ty ρ (ct c)" by auto
  then have "subst_ty σ τ = subst_ty (λv. subst_ty σ (ρ v)) (ct c)"
    by (simp add: subst_ty_compose)
  then show ?case by (auto intro: well_typed.wt_const)
next
  case (wt_var Γ x τ ct)
  then show ?case by (auto simp: subst_ctx_def intro: well_typed.wt_var)
next
  case (wt_abs ct Γ x τ a τ')
  then have ih: "well_typed ct ((subst_ctx σ Γ)(x ↦ subst_ty σ τ)) (subst_aterm σ a)"
    using subst_ctx_update by metis
  from wt_abs.hyps(2) have "subst_ty σ τ' = subst_ty σ τ → typeof_aterm (subst_aterm σ a)"
    by (simp add: fun_ty_def)
  with ih show ?case by (auto intro: well_typed.wt_abs)
next
  case (wt_app ct Γ a1 a2 τ)
  from wt_app.hyps(3) have
    "typeof_aterm (subst_aterm σ a1) = typeof_aterm (subst_aterm σ a2) → subst_ty σ τ"
    by (simp add: fun_ty_def)
  with wt_app.IH show ?case by (auto intro: well_typed.wt_app)
qed

text ‹Corollary: if \<^term>‹σ› is the identity on \<^term>‹tvars_ctx Γ›, then \<^term>‹subst_ctx σ Γ = Γ›,
  so \<^term>‹subst_aterm σ a› is well-typed in \<^term>‹Γ›.›

corollary subst_wt_identity_ctx:
  "well_typed ct Γ a ⟹ (⋀v. v ∈ tvars_ctx Γ ⟹ σ v = TVar v) ⟹
   well_typed ct Γ (subst_aterm σ a)"
  using subst_preserves_wt[of ct Γ a σ] subst_ctx_id[of Γ σ] by simp

subsection ‹Positions and Post-Order Enumeration›

text ‹Definition (Post-Order Enumeration). For a fully annotated term \<^term>‹a›,
  the function ‹enum_aterm a› returns a list of triples \<^term>‹(p, s, τ)›
  where \<^term>‹p› is a position number (in post-order), \<^term>‹s› is the raw subterm,
  and \<^term>‹τ› is the type at that position. Children are numbered before their parent.
  An abstraction ‹AAbs x τ a1 τ'› additionally contributes a binder position (entry
  ‹RVar x› with the binder type ‹τ›), numbered before its body.

  The helper ‹shift_enum k L› adds \<^term>‹k› to all position numbers in \<^term>‹L›.›

definition shift_enum :: "nat ⇒ (nat × raw_term × ty) list ⇒ (nat × raw_term × ty) list" where
  "shift_enum k L = map (λ(p, s, τ). (k + p, s, τ)) L"

fun enum_aterm :: "aterm ⇒ (nat × raw_term × ty) list" where
  "enum_aterm (AConst c τ) = [(0, RConst c, τ)]"
| "enum_aterm (AVar x τ) = [(0, RVar x, τ)]"
| "enum_aterm (AAbs x τ a1 τ') =
    (let L = enum_aterm a1; n = length L in
     [(0, RVar x, τ)] @ shift_enum 1 L @ [(1 + n, RAbs x (erase a1), τ')])"
| "enum_aterm (AApp a1 a2 τ) =
    (let L1 = enum_aterm a1; n1 = length L1;
         L2 = enum_aterm a2; n2 = length L2 in
     L1 @ shift_enum n1 L2 @ [(n1 + n2, RApp (erase a1) (erase a2), τ)])"

text ‹The set of positions.›

definition pos_set :: "aterm ⇒ nat set" where
  "pos_set a = fst ` set (enum_aterm a)"

text ‹The type at a given position in the enumeration. This is a definite description,
  so it is only meaningful for positions in \<^term>‹pos_set a›.›

definition type_at_pos :: "aterm ⇒ nat ⇒ ty" where
  "type_at_pos a p = (THE τ. ∃s. (p, s, τ) ∈ set (enum_aterm a))"

text ‹The subterm at a given position in the enumeration. This is likewise a definite
  description, meaningful only for positions in \<^term>‹pos_set a›.›

definition subterm_at_pos :: "aterm ⇒ nat ⇒ raw_term" where
  "subterm_at_pos a p = (THE s. ∃τ. (p, s, τ) ∈ set (enum_aterm a))"

subsubsection ‹Basic Properties of Enumeration›

lemma length_shift_enum [simp]: "length (shift_enum k L) = length L"
  by (simp add: shift_enum_def)

lemma set_shift_enum: "set (shift_enum k L) = (λ(p, s, τ). (k + p, s, τ)) ` set L"
  by (auto simp: shift_enum_def)

lemma in_shift_enum_iff:
  "(p, s, τ) ∈ set (shift_enum k L) ⟷ (∃p'. p = k + p' ∧ (p', s, τ) ∈ set L)"
  by (auto simp: shift_enum_def image_iff split: prod.splits)

lemma shift_enum_shift_enum:
  "shift_enum k1 (shift_enum k2 L) = shift_enum (k1 + k2) L"
  by (induction L) (auto simp: shift_enum_def)

text ‹The length of the enumeration equals the number of nodes in the term.
  Each node produces one entry, and an abstraction additionally produces a binder
  position, so an abstraction counts twice.›

fun aterm_node_count :: "aterm ⇒ nat" where
  "aterm_node_count (AConst _ _) = 1"
| "aterm_node_count (AVar _ _) = 1"
| "aterm_node_count (AAbs _ _ a _) = 2 + aterm_node_count a"
| "aterm_node_count (AApp a1 a2 _) = 1 + aterm_node_count a1 + aterm_node_count a2"

lemma length_enum_aterm: "length (enum_aterm a) = aterm_node_count a"
  by (induction a) (auto simp: Let_def)

text ‹Every annotated term has at least one position.›

lemma enum_aterm_nonempty: "enum_aterm a ≠ []"
  by (cases a) (auto simp: Let_def)

lemma aterm_node_count_pos: "aterm_node_count a ≥ 1"
  by (cases a) auto

subsubsection ‹Substitution Distributes to Positions›

text ‹Lemma (Substitution Distributes to Positions).
  The enumeration of \<^term>‹subst_aterm σ a› is obtained from the enumeration
  of \<^term>‹a› by applying \<^term>‹σ› to each type, leaving position numbers and
  subterms unchanged.

  We first define the operation that applies a substitution to the types in an
  enumeration list.›

definition map_enum_ty ::
  "(string ⇒ ty) ⇒ (nat × raw_term × ty) list ⇒ (nat × raw_term × ty) list" where
  "map_enum_ty σ L = map (λ(p, s, τ). (p, s, subst_ty σ τ)) L"

lemma length_map_enum_ty [simp]: "length (map_enum_ty σ L) = length L"
  by (simp add: map_enum_ty_def)

lemma map_enum_ty_shift: "map_enum_ty σ (shift_enum k L) = shift_enum k (map_enum_ty σ L)"
  by (induction L) (auto simp: map_enum_ty_def shift_enum_def)

lemma map_enum_ty_append: "map_enum_ty σ (L1 @ L2) = map_enum_ty σ L1 @ map_enum_ty σ L2"
  by (simp add: map_enum_ty_def)

text ‹The main lemma: the enumeration commutes with type substitution.›

lemma enum_subst_aterm:
  "enum_aterm (subst_aterm σ a) = map_enum_ty σ (enum_aterm a)"
  by (induction a) (auto simp: map_enum_ty_def shift_enum_def Let_def split: prod.splits)

corollary type_at_pos_subst_iff:
  "(p, s, τ') ∈ set (enum_aterm (subst_aterm σ a)) ⟷
   (∃τ. (p, s, τ) ∈ set (enum_aterm a) ∧ τ' = subst_ty σ τ)"
  by (auto simp: enum_subst_aterm map_enum_ty_def image_iff split: prod.splits)

corollary type_at_pos_subst:
  "(p, s, τ) ∈ set (enum_aterm a) ⟹ (p, s, subst_ty σ τ) ∈ set (enum_aterm (subst_aterm σ a))"
  using type_at_pos_subst_iff by auto

text ‹The node count is invariant under type substitution.›

lemma aterm_node_count_subst [simp]:
  "aterm_node_count (subst_aterm σ a) = aterm_node_count a"
  by (induction a) auto

text ‹The position set is invariant under type substitution,
  and more generally, is the same for any two terms with the same erasure.
  We prove the substitution case first.›

lemma pos_set_subst [simp]: "pos_set (subst_aterm σ a) = pos_set a"
  unfolding pos_set_def
  by (force simp: map_enum_ty_def enum_subst_aterm)

text ‹All type variables of \<^term>‹a› appear in some type annotation at a position.›

lemma tvars_aterm_subset_enum:
  "α ∈ tvars_aterm a ⟹ ∃p s τ. (p, s, τ) ∈ set (enum_aterm a) ∧ α ∈ tvars_ty τ"
  by (induction a) (auto simp: Let_def in_shift_enum_iff)

text ‹Converse: every type at an enumeration position contributes to the term's type variables.›

lemma enum_tvars_subset_aterm:
  "(p, s, τ) ∈ set (enum_aterm a) ⟹ tvars_ty τ ⊆ tvars_aterm a"
  by (induction a arbitrary: p s τ) (fastforce simp: Let_def in_shift_enum_iff)+

subsubsection ‹Distinctness and Range of Position Numbers›

text ‹Position numbers in the enumeration form a contiguous range starting from 0.
  This gives distinctness of positions as an immediate corollary.›

lemma map_fst_shift_enum: "map fst (shift_enum k L) = map ((+) k) (map fst L)"
  by (induction L) (auto simp: shift_enum_def)

lemma map_fst_enum_aterm: "map fst (enum_aterm a) = [0..<aterm_node_count a]"
  using upt_add_eq_append[symmetric] map_add_upt
  by (induction a)
    (auto simp: upt_conv_Cons Let_def map_fst_shift_enum length_enum_aterm add.commute)

text ‹Corollary: position numbers are distinct.›

lemma distinct_enum_fst: "distinct (map fst (enum_aterm a))"
  by (simp add: map_fst_enum_aterm)

text ‹Corollary: the position set is exactly the range \<^term>‹{0..<aterm_node_count a}›.›

lemma pos_set_range: "pos_set a = {0..<aterm_node_count a}"
  unfolding pos_set_def
  by (metis map_fst_enum_aterm set_map set_upt)

text ‹A useful corollary: if \<^term>‹(p, s, τ)› and \<^term>‹(p, s', τ')› are both
  in the enumeration, then \<^term>‹s = s'› and \<^term>‹τ = τ'› (positions are unique keys).›

lemma enum_aterm_unique:
  assumes "(p, s, τ) ∈ set (enum_aterm a)" "(p, s', τ') ∈ set (enum_aterm a)"
  shows "s = s'" "τ = τ'"
  using assms eq_key_imp_eq_value distinct_enum_fst
  by fastforce+

text ‹Note: The coverage lemma is proved below in the
  ‹annotation_problem› locale as ‹coverage_initial›. The locale
  instantiation connecting the reverse-greedy algorithm to ‹annotation_selection›
  is also done below via the ‹rg› interpretation.›

subsection ‹The Annotation Algorithm›

subsubsection ‹Type Inference Assumption›

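text ‹Assumption (Type Inference). We work in a locale that fixes:
  - \<^term>‹const_type›: the type scheme of each constant
  - \<^term>‹Γ›: a fixed context
  - \<^term>‹a›: a fully annotated, well-typed term in \<^term>‹Γ›
  - \<^term>‹a_star›: the principal typing (generalized term)
  - \<^term>‹σ›: the matching substitution, with \<^term>‹subst_aterm σ a_star = a›

  The locale also assumes that \<^term>‹a_star› is well-typed in \<^term>‹Γ›, has the same
  erasure as \<^term>‹a›, and that \<^term>‹σ› is the identity on \<^term>‹tvars_ctx Γ›.›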

locale annotation_problem =
  fixes const_type :: "string ⇒ ty"
    and Γ :: ctx
    and a :: aterm
    and a_star :: aterm
    and σ :: "string ⇒ ty"
  assumes a_wt: "well_typed const_type Γ a"
    and a_star_wt: "well_typed const_type Γ a_star"
    and same_erasure: "erase a_star = erase a"
    and matching: "subst_aterm σ a_star = a"
    and freshness: "⋀α. α ∈ tvars_ctx Γ ⟹ σ α = TVar α"
begin

text ‹Definition (Inference Variables). The type variables introduced by
  inference that are not present in the context.›

definition V :: "string set" where
  "V = tvars_aterm a_star - tvars_ctx Γ"

text ‹Definition (Key). The inference variables visible at a position.
  We define ‹key p› as the set of inference variables that appear in any type
  annotation at position \<^term>‹p› in the enumeration of \<^term>‹a_star›, intersected with \<^term>‹V›.
  Quantifying over enumeration entries, rather than using the definite description
  ‹type_at_pos›, means this definition does not depend on the uniqueness of position
  numbers (which is established separately by ‹enum_aterm_unique›).›

definition key :: "nat ⇒ string set" where
  "key p = {α. ∃s τ. (p, s, τ) ∈ set (enum_aterm a_star) ∧ α ∈ tvars_ty τ} ∩ V"

text ‹\<^term>‹σ› is the identity on \<^term>‹tvars_ctx Γ›. Hence, among the type variables of
  \<^term>‹a_star›, it can only move those in \<^term>‹V›.›

lemma sigma_id_on_ctx: "α ∈ tvars_ctx Γ ⟹ σ α = TVar α"
  using freshness by simp

text ‹Positions of \<^term>‹a_star› and \<^term>‹a› coincide.›

lemma pos_set_eq: "pos_set a = pos_set a_star"
  using matching pos_set_subst[of σ a_star] by simp

text ‹Lemma (Coverage). Every inference variable appears in the key of some position.
  This follows directly from the fact that every type variable of \<^term>‹a_star› appears at some
  position in the enumeration, and \<^term>‹V› is a subset of \<^term>‹tvars_aterm a_star›.›

lemma coverage_initial:
  assumes "α ∈ V"
  shows "∃p ∈ pos_set a_star. α ∈ key p"
  using assms tvars_aterm_subset_enum unfolding pos_set_def key_def V_def by force

text ‹Consistency of a term \<^term>‹a'› with annotations at a set of positions \<^term>‹P'›:
  for every position \<^term>‹p ∈ P'›, the type at position \<^term>‹p› in \<^term>‹a'› agrees with
  the type at position \<^term>‹p› in \<^term>‹a›, i.e. with \<^term>‹subst_ty σ τ›, where
  \<^term>‹τ› is the type at position \<^term>‹p› in \<^term>‹a_star›.

  This follows the paper's definition. Since \<^term>‹erase a' = erase a_star›, the
  position sets coincide, and consistency means that for each \<^term>‹p ∈ P'› the
  type at position \<^term>‹p› in \<^term>‹a'› equals \<^term>‹subst_ty σ τ›, where
  \<^term>‹(p, s, τ) ∈ set (enum_aterm a_star)›.›

definition consistent_with :: "aterm ⇒ nat set ⇒ bool" where
  "consistent_with a' P' ⟷
    ∀p s τ. p ∈ P' ⟶ (p, s, τ) ∈ set (enum_aterm a_star) ⟶
      (∃τ'. (p, s, τ') ∈ set (enum_aterm a') ∧ τ' = subst_ty σ τ)"

lemma consistent_withI:
  "(⋀p s τ. p ∈ P' ⟹ (p, s, τ) ∈ set (enum_aterm a_star) ⟹
      ∃τ'. (p, s, τ') ∈ set (enum_aterm a') ∧ τ' = subst_ty σ τ)
    ⟹ consistent_with a' P'"
  unfolding consistent_with_def by auto

lemma consistent_withE:
  assumes "consistent_with a' P'" "p ∈ P'" "(p, s, τ) ∈ set (enum_aterm a_star)"
  obtains τ' where "(p, s, τ') ∈ set (enum_aterm a')" "τ' = subst_ty σ τ"
  using assms unfolding consistent_with_def by auto

lemma consistent_with_mono:
  "consistent_with a' P' ⟹ Q ⊆ P' ⟹ consistent_with a' Q"
  unfolding consistent_with_def by auto

text ‹When \<^term>‹a'› is a substitution instance of \<^term>‹a_star›, i.e.
  \<^term>‹a' = subst_aterm σ' a_star›, consistency at \<^term>‹P'› reduces to
  type agreement: \<^term>‹subst_ty σ' τ = subst_ty σ τ› for all positions
  \<^term>‹p ∈ P'›. This is the form used in the completeness proof.›

lemma consistent_with_substI:
  assumes "⋀p s τ. p ∈ P' ⟹ (p, s, τ) ∈ set (enum_aterm a_star) ⟹
             subst_ty σ' τ = subst_ty σ τ"
  shows "consistent_with (subst_aterm σ' a_star) P'"
proof (rule consistent_withI)
  fix p s τ
  assume "p ∈ P'" "(p, s, τ) ∈ set (enum_aterm a_star)"
  then have "subst_ty σ' τ = subst_ty σ τ" using assms by auto
  moreover have "(p, s, subst_ty σ' τ) ∈ set (enum_aterm (subst_aterm σ' a_star))"
    using ‹(p, s, τ) ∈ set (enum_aterm a_star)› type_at_pos_subst by auto
  ultimately show "∃τ'. (p, s, τ') ∈ set (enum_aterm (subst_aterm σ' a_star)) ∧ τ' = subst_ty σ τ"
    by auto
qed

lemma consistent_with_substD:
  assumes "consistent_with (subst_aterm σ' a_star) P'"
    "p ∈ P'" "(p, s, τ) ∈ set (enum_aterm a_star)"
  shows "subst_ty σ' τ = subst_ty σ τ"
  using consistent_withE[OF assms] assms enum_aterm_unique type_at_pos_subst
  by metis

end

text ‹An annotation selection extends the annotation problem with a set of positions
  \<^term>‹P›. The set is finite, since it is contained in \<^term>‹pos_set a_star›. It covers all
  inference variables, and each position has a witness variable: a variable in its key
  that does not appear in the key of any other position in \<^term>‹P›.
  We separate the proof of the main theorems from the algorithm that produces \<^term>‹P›.›

locale annotation_selection = annotation_problem +
  fixes P :: "nat set"
  assumes P_subset: "P ⊆ pos_set a_star"
    and coverage: "⋃(key ` P) = V"
    and witness: "⋀p. p ∈ P ⟹ ∃α ∈ key p. ∀p' ∈ P. p' ≠ p ⟶ α ∉ key p'"
begin

text ‹Every inference variable appears in the key of some kept position.›

lemma coverage_mem: "α ∈ V ⟹ ∃p ∈ P. α ∈ key p"
  using coverage by blast

subsubsection ‹Annotations determine the substitution›

lemma sigma_agree_on_V:
  assumes agreement: "⋀p s τ. p ∈ P ⟹ (p, s, τ) ∈ set (enum_aterm a_star) ⟹ subst_ty σ' τ = subst_ty σ τ"
    and "α ∈ V"
  shows "σ' α = σ α"
proof -
  obtain p where "p ∈ P" "α ∈ key p" using coverage_mem ‹α ∈ V› by blast
  from ‹α ∈ key p› obtain s τ where
    mem: "(p, s, τ) ∈ set (enum_aterm a_star)" and var: "α ∈ tvars_ty τ"
    unfolding key_def by auto
  from ‹p ∈ P› mem have "subst_ty σ' τ = subst_ty σ τ" using agreement by auto
  with var show "σ' α = σ α" using unique_type_match by blast
qed

text ‹Main lemma: a substitution \<^term>‹σ'› that agrees with \<^term>‹σ› on the types at the
  selected positions \<^term>‹P› and is the identity on \<^term>‹tvars_ctx Γ› instantiates
  \<^term>‹a_star› to exactly \<^term>‹a›.›

lemma annotations_determine_subst:
  assumes agreement: "⋀p s τ. p ∈ P ⟹ (p, s, τ) ∈ set (enum_aterm a_star) ⟹
                    subst_ty σ' τ = subst_ty σ τ"
    and fresh': "⋀α. α ∈ tvars_ctx Γ ⟹ σ' α = TVar α"
  shows "subst_aterm σ' a_star = a"
proof -
  text ‹We show \<^term>‹σ' = σ› on \<^term>‹tvars_aterm a_star›, then conclude
    by @{thm subst_aterm_agree} and @{thm matching}.›
  have agree: "σ' α = σ α" if "α ∈ tvars_aterm a_star" for α
  proof (cases "α ∈ V")
    case True
    then show ?thesis
      using sigma_agree_on_V[OF agreement] by simp
  next
    case False
    with that have "α ∈ tvars_ctx Γ" unfolding V_def by auto
    then have "σ' α = TVar α" using fresh' by auto
    moreover have "σ α = TVar α" using ‹α ∈ tvars_ctx Γ› freshness by auto
    ultimately show ?thesis by simp
  qed
  have "subst_aterm σ' a_star = subst_aterm σ a_star"
    by (intro subst_aterm_agree agree)
  also have "... = a" using matching by simp
  finally show ?thesis .
qed

subsubsection ‹Completeness›

text ‹Theorem (Completeness). Any well-typed \<^term>‹a'› with
  \<^term>‹erase a' = erase a› that is consistent with the constraints at \<^term>‹P› satisfies ‹a' = a›.

  Principality is stated as a hypothesis about \<^term>‹a'›: it is a substitution instance of
  \<^term>‹a_star›, via a substitution that is the identity on the context's type variables.
  The well-typedness and erasure hypotheses motivate this assumption. The proof itself
  uses only principality and consistency.›

theorem completeness:
  assumes a'_wt: "well_typed const_type Γ a'"
    and a'_erase: "erase a' = erase a"
    and principality: "∃σ'. subst_aterm σ' a_star = a'
                           ∧ (∀α ∈ tvars_ctx Γ. σ' α = TVar α)"
    and consist: "consistent_with a' P"
  shows "a' = a"
proof -
  from principality obtain σ' where
    inst: "subst_aterm σ' a_star = a'" and
    fresh': "∀α ∈ tvars_ctx Γ. σ' α = TVar α"
    by blast
  have "a' = subst_aterm σ' a_star" using inst by simp
  also have "... = a"
  proof (rule annotations_determine_subst)
    fix p s τ
    assume "p ∈ P" "(p, s, τ) ∈ set (enum_aterm a_star)"
    then show "subst_ty σ' τ = subst_ty σ τ"
      using consistent_with_substD[OF consist[folded inst] ‹p ∈ P› ‹(p, s, τ) ∈ set (enum_aterm a_star)›] by auto
  next
    fix α
    assume "α ∈ tvars_ctx Γ"
    then show "σ' α = TVar α" using fresh' by auto
  qed
  finally show "a' = a" .
qed

subsubsection ‹Local Minimality›

text ‹Theorem (Local Minimality). For every \<^term>‹p ∈ P›, removing the annotation
  at position \<^term>‹p› makes the typing non-unique: there exists \<^term>‹a'› different from \<^term>‹a›
  that is well-typed in ‹Γ›, satisfies \<^term>‹erase a' = erase a›, and is consistent with
  the annotations at \<^term>‹P - {p}›.›

theorem local_minimality:
  assumes "p ∈ P"
  shows "∃a'. a' ≠ a
    ∧ well_typed const_type Γ a'
    ∧ erase a' = erase a
    ∧ consistent_with a' (P - {p})"
proof -
  text ‹Step 1: Obtain witness variable.›
  from witness[OF ‹p ∈ P›] obtain α where
    α_in_key: "α ∈ key p" and
    α_unique: "∀p' ∈ P. p' ≠ p ⟶ α ∉ key p'"
    by blast

  text ‹Step 2: Define the altered substitution.›
  define σ_star where "σ_star = (λv. if v = α then σ α → σ α else σ v)"

  text ‹Step 3: Define ‹a'› and show it satisfies all four properties.›
  define a' where "a' = subst_aterm σ_star a_star"

  text ‹Property 1: \<^term>‹a'› differs from \<^term>‹a›.›
  have σ_star_diff: "σ_star α ≠ σ α"
    unfolding σ_star_def using arrow_neq_self by simp

  have "α ∈ V" using α_in_key unfolding key_def by auto
  then have "α ∈ tvars_aterm a_star" unfolding V_def by auto

  have "a' ≠ a"
  proof
    assume "a' = a"
    then have "subst_aterm σ_star a_star = subst_aterm σ a_star" using matching a'_def by simp
    then have "σ_star α = σ α"
      using ‹α ∈ tvars_aterm a_star› subst_aterm_injective by blast
    with σ_star_diff show False by contradiction
  qed

  text ‹Property 2: \<^term>‹a'› is well-typed.›
  have fresh_star: "α' ∈ tvars_ctx Γ ⟹ σ_star α' = TVar α'" for α'
  proof -
    assume "α' ∈ tvars_ctx Γ"
    then have "α' ∉ V" unfolding V_def by auto
    then have "α' ≠ α" using ‹α ∈ V› by auto
    then have "σ_star α' = σ α'" unfolding σ_star_def by simp
    also have "... = TVar α'" using ‹α' ∈ tvars_ctx Γ› freshness by auto
    finally show "σ_star α' = TVar α'" .
  qed
  have "well_typed const_type Γ a'"
    unfolding a'_def using subst_wt_identity_ctx[OF a_star_wt fresh_star] .

  text ‹Property 3: \<^term>‹erase a' = erase a›.›
  have "erase a' = erase a"
    unfolding a'_def using erase_subst same_erasure by simp

  text ‹Property 4: \<^term>‹a'› is consistent with \<^term>‹P - {p}›.›
  have consistent_minus_p: "consistent_with a' (P - {p})"
    unfolding a'_def
  proof (rule consistent_with_substI)
    fix p' s τ
    assume "p' ∈ P - {p}" "(p', s, τ) ∈ set (enum_aterm a_star)"
    then have "p' ∈ P" "p' ≠ p" by auto
    text ‹Step 1: \<^term>‹α› is not in \<^term>‹tvars_ty τ›.›
    have "α ∉ tvars_ty τ"
    proof
      assume "α ∈ tvars_ty τ"
      then have "α ∈ key p'"
        using ‹(p', s, τ) ∈ set (enum_aterm a_star)› ‹α ∈ V›
        unfolding key_def by auto
      with α_unique ‹p' ∈ P› ‹p' ≠ p› show False by auto
    qed
    text ‹Step 2: \<^term>‹σ_star› and \<^term>‹σ› agree on \<^term>‹tvars_ty τ›.›
    have "σ_star β = σ β" if "β ∈ tvars_ty τ" for β
    proof -
      from that ‹α ∉ tvars_ty τ› have "β ≠ α" by auto
      then show ?thesis unfolding σ_star_def by simp
    qed
    then show "subst_ty σ_star τ = subst_ty σ τ"
      by (rule subst_ty_agree)
  qed

  from ‹a' ≠ a› ‹well_typed const_type Γ a'› ‹erase a' = erase a› consistent_minus_p
  show ?thesis by blast
qed

end

subsection ‹Annotation Insertion›

text ‹Definition (Annotation Insertion).
  The inputs are the matching substitution \<^term>‹σ›, the generalized term \<^term>‹a_star›,
  the selected positions \<^term>‹P›, and a starting position counter \<^term>‹k›.
  The function ‹ins› walks the annotated term, numbering positions in the same order as
  ‹enum_aterm›, and rebuilds its erasure. At every position in \<^term>‹P› it inserts the type
  recorded at that position, instantiated by \<^term>‹σ›:
  ▪ a selected binder position turns ‹RAbs› into ‹RAbsT›;
  ▪ any other selected position wraps the subterm in ‹RConstrain›.

  Returns a pair (annotated raw term, number of positions traversed).›

fun ins :: "(string ⇒ ty) ⇒ aterm ⇒ nat set ⇒ nat ⇒ raw_term × nat" where
  "ins σ (AConst c τ) P k =
    (if k ∈ P then RConstrain (RConst c) (subst_ty σ τ) else RConst c, 1)"
| "ins σ (AVar x τ) P k =
    (if k ∈ P then RConstrain (RVar x) (subst_ty σ τ) else RVar x, 1)"
| "ins σ (AAbs x τ a1 τ') P k =
    (let (t1', n1) = ins σ a1 P (k + 1);
         t_binder = (if k ∈ P then RAbsT x (subst_ty σ τ) t1'
                     else RAbs x t1');
         t' = (if k + 1 + n1 ∈ P then RConstrain t_binder (subst_ty σ τ')
               else t_binder)
     in (t', n1 + 2))"
| "ins σ (AApp a1 a2 τ) P k =
    (let (t1', n1) = ins σ a1 P k;
         (t2', n2) = ins σ a2 P (k + n1);
         t_app = RApp t1' t2';
         t' = (if k + n1 + n2 ∈ P then RConstrain t_app (subst_ty σ τ)
               else t_app)
     in (t', n1 + n2 + 1))"

text ‹The top-level annotation function; numbering starts at position 0.›

definition annotate :: "(string ⇒ ty) ⇒ aterm ⇒ nat set ⇒ raw_term" where
  "annotate σ a_star P = fst (ins σ a_star P 0)"

text ‹The number of positions traversed equals the node count.›

lemma ins_count: "snd (ins σ a P k) = aterm_node_count a"
  by (induction a arbitrary: k) (auto simp: Let_def case_prod_unfold)

text ‹Stripping constraints from the output of ‹ins› recovers the erasure.›

lemma strip_ins: "strip (fst (ins σ a P k)) = erase a"
  by (induction a arbitrary: k) (auto simp: Let_def case_prod_unfold)

text ‹Corollary: ‹annotate› preserves the raw term under stripping.›

corollary strip_annotate: "strip (annotate σ a P) = erase a"
  unfolding annotate_def using strip_ins by simp

subsection ‹The Reverse-Greedy Algorithm›

text ‹Definition (Reverse-Greedy Selection).
  The algorithm processes candidate positions in decreasing cost order.
  A position is dropped if every variable in its key has count ‹> 1› (i.e.,
  is covered by at least one other remaining candidate); otherwise it is kept.

  We model the algorithm as a fold over the candidate list. The fold does not sort:
  the caller is expected to supply the list sorted by decreasing cost.
  The state is a count function tracking how many undropped candidates cover each variable.

  The fold processes from head (highest cost) to tail (lowest cost).
  When a position is kept, the count is unchanged.
  When a position is dropped, the count is decremented for each variable in its key.›

text ‹One step: a position with key \<^term>‹K› is kept iff some variable in \<^term>‹K› has count ‹≤ 1›.
  In particular, a position with an empty key is always dropped.›

definition rg_keep :: "(string ⇒ nat) ⇒ string set ⇒ bool" where
  "rg_keep cnt K = (∃α ∈ K. cnt α ≤ 1)"

text ‹The fold: processes a list of (position, key) pairs. Returns (kept set, final count).
  In the drop branch, the final ‹let› simply returns the result of the recursive call
  on the decremented count.›

fun rg_fold :: "(nat × string set) list ⇒ (string ⇒ nat) ⇒ nat set × (string ⇒ nat)" where
  "rg_fold [] cnt = ({}, cnt)"
| "rg_fold ((p, K) # rest) cnt =
    (if rg_keep cnt K then
       let (P', cnt') = rg_fold rest cnt in (insert p P', cnt')
     else
       let cnt' = (λα. if α ∈ K then cnt α - 1 else cnt α) in
       let (P', cnt'') = rg_fold rest cnt' in (P', cnt''))"

text ‹Initialize the count of each variable to the number of candidates whose key contains it.›

definition init_count :: "(nat × string set) list ⇒ string ⇒ nat" where
  "init_count cands α = length (filter (λ(_, K). α ∈ K) cands)"

text ‹The full algorithm.›

definition reverse_greedy :: "(nat × string set) list ⇒ nat set" where
  "reverse_greedy cands = fst (rg_fold cands (init_count cands))"

subsubsection ‹Basic Properties of the Fold›

text ‹The kept set is a subset of the candidates.›

lemma rg_fold_subset: "fst (rg_fold cands cnt)  fst ` set cands"
  by (induction cands arbitrary: cnt) (fastforce split: if_splits simp: case_prod_unfold Let_def)+

text ‹If a variable is not in any candidate's key, the count is unchanged by the fold.›

(* A drop step only modifies the counts of variables in the dropped key, so a
   variable occurring in no key keeps its initial count. Proof by induction
   on the list with the count generalized. *)
lemma rg_fold_cnt_unchanged:
  "(p, K)  set cands. α  K  snd (rg_fold cands cnt) α = cnt α"
  by (induction cands arbitrary: cnt) (auto simp: case_prod_unfold Let_def)

text ‹The count never increases during the fold.›

(* Induction on the list with the count generalized. The Nil case is closed
   by simp in the final qed. In the Cons case the induction hypothesis for the
   tail is chained (order.trans) with the effect of the single head step. *)
lemma rg_fold_cnt_mono: "snd (rg_fold cands cnt) α  cnt α"
proof (induction cands arbitrary: cnt) 
  case (Cons pc rest)
  thus ?case
    using order.trans[OF Cons]
    by (cases pc) (auto simp: Let_def case_prod_unfold)
qed simp

text ‹If no kept position has ‹α› in its key, then the final count for ‹α› is ‹0›
  (assuming the initial count equals the number of candidates covering ‹α›).›

(* Every candidate covering α is dropped (no_kept), and each such drop
   decrements α's count once; cnt_eq says the count starts at exactly the
   number of such candidates, so it ends at 0. Used by contradiction in the
   coverage proof. Proof by induction on the list with the count generalized. *)
lemma rg_fold_no_kept_zero:
  assumes no_kept: "p K. (p, K)  set cands  α  K  p  fst (rg_fold cands cnt)"
    and cnt_eq: "cnt α = length (filter (λ(_, K). α  K) cands)"
  shows "snd (rg_fold cands cnt) α = 0"
  using assms
  by (induction cands arbitrary: cnt) (auto simp: case_prod_unfold Let_def)

subsubsection ‹Coverage›

text ‹Lemma (Reverse-Greedy Preserves Full Coverage).
  The key invariant: the fold maintains ‹cnt α ≥ 1› for every variable ‹α›
  that appears in some candidate's key. This is because:
  - At a keep step: the count is unchanged (it is passed to the recursive call as-is).
  - At a drop step: we only drop when all counts in the key of ‹p› are ‹> 1›, so after
    decrementing they are ‹≥ 1›. For variables not in the key of ‹p›, the count is unchanged.

  At termination, ‹cnt α ≥ 1› means at least one undropped candidate covers
  ‹α›. Since no candidates remain to be processed, every undropped candidate
  is a kept position. Hence ‹α› is covered by a kept position.

  The formal proof combines this invariant with lemma ‹rg_fold_no_kept_zero›: if no kept
  position covered ‹α›, the final count for ‹α› would be ‹0›.›

text ‹The key invariant: if all counts are at least 1, they remain at least 1 after the fold.
  This is because a position is only dropped when all variables in its key have count ‹> 1›.›

(* If every count is at least 1 before the fold, every count is still at
   least 1 afterwards. A keep step leaves the count unchanged. A drop step
   happens only when all key variables have count above 1 (rg_keep fails), so
   decrementing them leaves them at 1 or more. Proof by induction on the list
   with the count generalized; the Nil case is closed by simp in the final qed. *)
lemma rg_fold_preserves_ge1:
  assumes "α. cnt α  1"
  shows "snd (rg_fold cands cnt) α  1"
  using assms
proof (induction cands arbitrary: cnt)
  case (Cons pc rest)
  thus ?case
    by (cases pc) (fastforce simp: case_prod_unfold Let_def rg_keep_def intro!: Cons[simplified])
qed simp

text ‹The fold result depends only on the count values for variables in candidate keys.›

(* Two count functions that agree on every variable occurring in some key
   give the same kept set and the same final counts on those variables.
   The two conclusions are proved together as one conjunction ("both") so that
   a single induction yields a strong enough hypothesis for the tail. *)
lemma rg_fold_cnt_agree:
  assumes "α. ((p, K)  set cands. α  K)  cnt1 α = cnt2 α"
  shows "fst (rg_fold cands cnt1) = fst (rg_fold cands cnt2)"
    and "α. ((p, K)  set cands. α  K)  
           snd (rg_fold cands cnt1) α = snd (rg_fold cands cnt2) α"
proof -
  have both: "fst (rg_fold cands cnt1) = fst (rg_fold cands cnt2) 
    (α. ((p, K)  set cands. α  K)  
           snd (rg_fold cands cnt1) α = snd (rg_fold cands cnt2) α)"
    using assms
  proof (induction cands arbitrary: cnt1 cnt2)
    case Nil then show ?case by simp
next
  case (Cons pc rest)
  obtain p K where [simp]: "pc = (p, K)" by (cases pc)
  (* The counts agree on the head key, so both runs take the same keep/drop decision. *)
  have key_eq: "α  K. cnt1 α = cnt2 α"
    using Cons.prems by auto
  then have keep_eq: "rg_keep cnt1 K = rg_keep cnt2 K"
    unfolding rg_keep_def by auto
  show ?case
  proof (cases "rg_keep cnt1 K")
    case True
    (* Keep case: the tail is processed with the unchanged counts cnt1 / cnt2. *)
    have rest_agree: "α. ((p, K)  set rest. α  K)  cnt1 α = cnt2 α"
      using Cons.prems by auto
    from Cons.IH[OF rest_agree] have IH: 
      "fst (rg_fold rest cnt1) = fst (rg_fold rest cnt2)"
      "α. ((p, K)  set rest. α  K)  
             snd (rg_fold rest cnt1) α = snd (rg_fold rest cnt2) α"
      by auto
    have snd_eq: "β. ((p, K)  set (pc # rest). β  K)  
           snd (rg_fold (pc # rest) cnt1) β = snd (rg_fold (pc # rest) cnt2) β"
    proof (intro allI impI)
      fix β assume ex: "(p, K)  set (pc # rest). β  K"
      obtain P1 c1 P2 c2 where 
        r1: "rg_fold rest cnt1 = (P1, c1)" and r2: "rg_fold rest cnt2 = (P2, c2)" 
        by (cases "rg_fold rest cnt1"; cases "rg_fold rest cnt2")
      from IH(1) r1 r2 have "P1 = P2" by simp
      have fold1: "rg_fold (pc # rest) cnt1 = (insert p P1, c1)"
        using rg_keep cnt1 K r1 by (simp add: Let_def)
      have fold2: "rg_fold (pc # rest) cnt2 = (insert p P2, c2)"
        using rg_keep cnt1 K keep_eq r2 by (simp add: Let_def)
      show "snd (rg_fold (pc # rest) cnt1) β = snd (rg_fold (pc # rest) cnt2) β"
      (* Either beta occurs in a tail key (use the IH), or it occurs only in the
         head key, in which case the tail leaves its count untouched. *)
      proof (cases "(p, K)  set rest. β  K")
        case True
        from IH(2) True have "c1 β = c2 β" using r1 r2 by auto
        then show ?thesis using fold1 fold2 by simp
      next
        case False
        then have nk1: "(p, K)  set rest. β  K" by auto
        have "snd (rg_fold rest cnt1) β = cnt1 β" using rg_fold_cnt_unchanged[OF nk1] by simp
        then have "c1 β = cnt1 β" using r1 by simp
        moreover have "snd (rg_fold rest cnt2) β = cnt2 β" using rg_fold_cnt_unchanged[OF nk1] by simp
        then have "c2 β = cnt2 β" using r2 by simp

        moreover from ex False have "β  K" by auto
        then have "cnt1 β = cnt2 β" using key_eq by auto
        ultimately show ?thesis using fold1 fold2 by simp
      qed
    qed
    from IH(1) snd_eq True keep_eq
    show ?thesis by (auto simp: Let_def split: prod.splits)

  next
    case False
    (* Drop case: both runs decrement the head key and continue with cnt1' / cnt2'. *)
    define cnt1' where "cnt1' = (λα. if α  K then cnt1 α - 1 else cnt1 α)"
    define cnt2' where "cnt2' = (λα. if α  K then cnt2 α - 1 else cnt2 α)"
    have rest_agree: "α. ((p, K)  set rest. α  K)  cnt1' α = cnt2' α"
      using Cons.prems key_eq unfolding cnt1'_def cnt2'_def by auto
    from Cons.IH[OF rest_agree] have IH:
      "fst (rg_fold rest cnt1') = fst (rg_fold rest cnt2')"
      "α. ((p, K)  set rest. α  K)  
             snd (rg_fold rest cnt1') α = snd (rg_fold rest cnt2') α"
      by auto
    have snd_eq: "α. ((p, K)  set (pc # rest). α  K)  
           snd (rg_fold (pc # rest) cnt1) α = snd (rg_fold (pc # rest) cnt2) α"
    proof (intro allI impI)
      fix β assume "(p, K)  set (pc # rest). β  K"
      then have "β  K  ((p, K)  set rest. β  K)" by auto
      then show "snd (rg_fold (pc # rest) cnt1) β = snd (rg_fold (pc # rest) cnt2) β"
      proof
        assume "β  K"
        show ?thesis
        proof (cases "(p, K)  set rest. β  K")
          case True
          then show ?thesis using IH(2) False keep_eq cnt1'_def cnt2'_def
            by (simp add: Let_def rg_keep_def split: prod.splits)
        next
          (* Named Falser so that False still refers to the outer drop-case fact. *)
          case Falser: False
          then have nk2: "(p, K)  set rest. β  K" by auto
          have "snd (rg_fold rest cnt1') β = cnt1' β"
            using rg_fold_cnt_unchanged[OF nk2] by simp
          moreover have "snd (rg_fold rest cnt2') β = cnt2' β"
            using rg_fold_cnt_unchanged[OF nk2] by simp
          moreover have "cnt1' β = cnt2' β" 
            using key_eq β  K unfolding cnt1'_def cnt2'_def by auto
          ultimately show ?thesis using False keep_eq cnt1'_def cnt2'_def
            by (simp add: Let_def rg_keep_def split: prod.splits)

        qed
      next
        assume "(p, K)  set rest. β  K"
        then show ?thesis using IH(2) False keep_eq cnt1'_def cnt2'_def
          by (simp add: Let_def rg_keep_def split: prod.splits)
      qed
    qed
    from IH(1) snd_eq False keep_eq cnt1'_def cnt2'_def
    show ?thesis by (simp add: Let_def rg_keep_def split: prod.splits)
  qed
  qed
  (* Split the conjunction back into the two stated conclusions. *)
  from both show "fst (rg_fold cands cnt1) = fst (rg_fold cands cnt2)" by simp
  from both show "α. ((p, K)  set cands. α  K)  
           snd (rg_fold cands cnt1) α = snd (rg_fold cands cnt2) α" by simp
qed

text ‹Corollary: the fold preserves count ‹≥ 1› for variables in candidate keys.
  We prove this by using @{thm rg_fold_cnt_agree} to relate the fold with the given count
  (in the application, ‹init_count›) to the fold with a count that is ‹≥ 1› everywhere.›

(* Counts that are at least 1 on every key variable stay at least 1 on key
   variables after the fold. Proof: pad the count with 1 outside the key
   variables (cnt') so that rg_fold_preserves_ge1 applies to it, and use
   rg_fold_cnt_agree to show that the padding does not change the final count
   of the key variable alpha. *)
lemma rg_fold_preserves_ge1_on_keys:
  assumes "α. ((p, K)  set cands. α  K)  cnt α  1"
    and "((p, K)  set cands. α  K)"
  shows "snd (rg_fold cands cnt) α  1"
proof -
  define cnt' where "cnt' β = (if (p, K)  set cands. β  K then cnt β else 1)" for β
  have ge1: "β. cnt' β  1" unfolding cnt'_def using assms(1) by auto
  have agree: "β. ((p, K)  set cands. β  K)  cnt β = cnt' β"
    unfolding cnt'_def by auto
  (* The padded and the original count agree on key variables, hence so do the results. *)
  from rg_fold_cnt_agree(2)[OF agree] assms(2)
  have "snd (rg_fold cands cnt) α = snd (rg_fold cands cnt') α" by auto
  also have "  1" using rg_fold_preserves_ge1 ge1 by auto
  finally show ?thesis .
qed

subsubsection ‹Witness Property›

text ‹Lemma (Witness Variable). For every kept position ‹p›, there exists a variable
  ‹α› in its key such that ‹α› does not appear in the key of any other kept position.

  The proof proceeds by induction on the candidate list, using a generalized invariant
  that tracks an extra count (representing kept positions from the prefix that are no
  longer in the candidate list). The key insight: in the keep case, the extra count
  being zero for the witness variable automatically excludes it from the current head's
  key, since the extra count absorbs the head's contribution.›

text ‹Auxiliary generalized lemma: the count function may over-count by an extra amount.
  The conclusion finds a witness ‹α› whose extra contribution is zero, meaning ‹α› only
  appears in candidates from the current list (not in any prior kept position).›

(* Generalized witness lemma, proved by induction on the candidate list with
   cnt, extra, p and K generalized. The proof is split along two questions:
   is the target p the head of the list or in the tail, and is the head kept
   or dropped? When the head is kept and p is in the tail, the head's
   contribution is moved into the extra count (extra'); when the head is
   dropped, the decremented count cnt' is used and extra is left unchanged. *)
lemma rg_fold_witness_aux:
  assumes dist: "distinct (map fst cands)"
    and cnt_eq: "α. cnt α = length (filter (λ(_, K). α  K) cands) + extra α"
    and mem: "(p, K)  set cands"
    and kept: "p  fst (rg_fold cands cnt)"
  shows "α  K. extra α = 0 
    (p' K'. (p', K')  set cands  p'  fst (rg_fold cands cnt)  p'  p  α  K')"
  using assms
proof (induction cands arbitrary: cnt extra p K)
  case Nil
  then show ?case by simp
next
  case (Cons pc rest)
  obtain ph Kh where pc_def: "pc = (ph, Kh)" by (cases pc)
  from Cons.prems(1) pc_def have dist_rest: "distinct (map fst rest)"
    and ph_notin: "ph  fst ` set rest" by auto
  show ?case
  (* Case split: the target p is the head position ph, or it lies in the tail. *)
  proof (cases "p = ph")
    case True
    text ‹Target is the head. First establish termK = Kh using distinctness.›
    from Cons.prems(3) pc_def p = ph have "(ph, K)  set ((ph, Kh) # rest)" by simp
    then have "K = Kh  (ph, K)  set rest" by auto
    with ph_notin have K_eq: "K = Kh" by (auto simp: image_iff)
    show ?thesis
    proof (cases "rg_keep cnt Kh")
      case True
      text ‹Head is kept:›
      (* The variable with count at most 1 that justified keeping the head is
         the witness: its count leaves no room for tail occurrences or extra. *)
      from True obtain α where α_in: "α  Kh" and cnt_le: "cnt α  1"
        unfolding rg_keep_def by auto
      from Cons.prems(2) pc_def α_in have
        cnt_eq_α: "cnt α = 1 + length (filter (λ(_, K). α  K) rest) + extra α"
        by simp
      from cnt_le cnt_eq_α have rest_empty: "length (filter (λ(_, K). α  K) rest) = 0"
        and extra_zero: "extra α = 0" by arith+
      from rest_empty have no_α_rest: "(p', K')  set rest. α  K'"
        by (auto simp: filter_empty_conv)
      text ‹Show the witness property for termα.›
      have wit: "p' K'. (p', K')  set (pc # rest)  
        p'  fst (rg_fold (pc # rest) cnt)  p'  p  α  K'"
      proof (intro allI impI)
        fix p' K' 
        assume p'_in: "(p', K')  set (pc # rest)"
          and p'_kept: "p'  fst (rg_fold (pc # rest) cnt)"
          and p'_ne: "p'  p"
        from p'_ne p = ph have "p'  ph" by simp
        from p'_in pc_def p'  ph have "(p', K')  set rest" by auto
        with no_α_rest show "α  K'" by auto
      qed
      from α_in K_eq extra_zero wit show ?thesis 
        by (intro bexI[of _ α]) auto
    next
      case False
      text ‹Head is dropped but termp = ph, so termp must be in the result of the recursive call.
        This leads to a contradiction since termph cannot be in termfst of termrest's result.›
      define cnt' where "cnt' = (λα. if α  Kh then cnt α - 1 else cnt α)"
      from False pc_def have fold_eq: "rg_fold (pc # rest) cnt = 
        (let (P', cnt'') = rg_fold rest cnt' in (P', cnt''))"
        by (simp add: Let_def cnt'_def)
      obtain P' cnt'' where rec: "rg_fold rest cnt' = (P', cnt'')" 
        by (cases "rg_fold rest cnt'")
      from fold_eq rec have "fst (rg_fold (pc # rest) cnt) = P'" by (simp add: Let_def)
      with Cons.prems(4) p = ph have "ph  P'" by simp
      from rg_fold_subset[of rest cnt'] rec have "P'  fst ` set rest" by simp
      with ph  P' ph_notin show ?thesis by auto
    qed
  next
    case False
    text ‹Target termp is in the tail.›
    from Cons.prems(3) pc_def p  ph have p_in_rest: "(p, K)  set rest" by auto
    show ?thesis
    proof (cases "rg_keep cnt Kh")
      case True
      text ‹Head is kept. The recursive call uses the same termcnt.
        The extra count for rest includes the head's contribution.›
      obtain P' cnt_out where rec: "rg_fold rest cnt = (P', cnt_out)" 
        by (cases "rg_fold rest cnt")
      from True pc_def rec have fold_eq: 
        "fst (rg_fold (pc # rest) cnt) = insert ph P'"
        by (simp add: Let_def)
      from Cons.prems(4) fold_eq p  ph have p_kept_rest: "p  P'" by simp
      with rec have p_in_fold: "p  fst (rg_fold rest cnt)" by simp
      text ‹Define the extra count for rest:›
      (* extra' adds 1 for each variable of the kept head key, so that cnt
         satisfies the cnt_eq premise of the induction hypothesis for rest. *)
      define extra' where "extra' α = (if α  Kh then 1 else 0) + extra α" for α
      have cnt_rest: "α. cnt α = length (filter (λ(_, K). α  K) rest) + extra' α"
      proof
        fix α
        from Cons.prems(2) pc_def have 
          "cnt α = length (filter (λ(_, K). α  K) (pc # rest)) + extra α" by simp
        then show "cnt α = length (filter (λ(_, K). α  K) rest) + extra' α"
          unfolding extra'_def pc_def by simp
      qed
      text ‹Apply IH on rest.›
      from Cons.IH[OF dist_rest cnt_rest p_in_rest p_in_fold]
      obtain α where α_in: "α  K" 
        and extra'_zero: "extra' α = 0"
        and wit_rest: "p' K'. (p', K')  set rest  
          p'  fst (rg_fold rest cnt)  p'  p  α  K'"
        by auto
      text ‹From termextra' α = 0, deduce termα  Kh and termextra α = 0.›
      from extra'_zero extra'_def have α_notin_Kh: "α  Kh" and extra_zero: "extra α = 0"
        by (auto split: if_splits)
      text ‹Extend the witness property to the full list.›
      have wit_full: "p' K'. (p', K')  set (pc # rest)  
        p'  fst (rg_fold (pc # rest) cnt)  p'  p  α  K'"
      proof (intro allI impI)
        fix p' K' 
        assume p'_in: "(p', K')  set (pc # rest)"
          and p'_kept: "p'  fst (rg_fold (pc # rest) cnt)"
          and p'_ne: "p'  p"
        show "α  K'"
          using α_notin_Kh p'_in pc_def ph_notin fold_eq rec p'_kept 
          by (fastforce simp: image_iff p'_ne wit_rest)
      qed
      from α_in extra_zero wit_full show ?thesis by (intro bexI[of _ α]) auto
    next
      case False
      text ‹Head is dropped. Decrement count for variables in termKh.›
      define cnt' where "cnt' = (λα. if α  Kh then cnt α - 1 else cnt α)"
      obtain P' cnt'' where rec: "rg_fold rest cnt' = (P', cnt'')" 
        by (cases "rg_fold rest cnt'")
      from False pc_def cnt'_def rec have fold_eq:
        "fst (rg_fold (pc # rest) cnt) = P'"
        by (simp add: Let_def rg_keep_def split: prod.splits)
      from Cons.prems(4) fold_eq have p_kept_rest: "p  P'" by simp
      with rec have p_in_fold: "p  fst (rg_fold rest cnt')" by simp
      text ‹Compute the termcnt' equation for rest.›
      (* For head-key variables the count is at least 1 (the head itself is
         counted), so the natural-number decrement removes exactly the head. *)
      have cnt_rest: "α. cnt' α = length (filter (λ(_, K). α  K) rest) + extra α"
      proof
        fix α
        from Cons.prems(2) pc_def have 
          cnt_full: "cnt α = (if α  Kh then 1 else 0) + length (filter (λ(_, K). α  K) rest) + extra α"
          by simp
        show "cnt' α = length (filter (λ(_, K). α  K) rest) + extra α"
        proof (cases "α  Kh")
          case True
          then have "cnt α  1" using cnt_full by simp
          with True show ?thesis unfolding cnt'_def using cnt_full by simp
        next
          case Falseα: False
          then show ?thesis unfolding cnt'_def using cnt_full by simp
        qed
      qed
      text ‹Apply IH on rest with extra unchanged.›
      from Cons.IH[OF dist_rest cnt_rest p_in_rest p_in_fold]
      obtain α where α_in: "α  K"
        and extra_zero: "extra α = 0"
        and wit_rest: "p' K'. (p', K')  set rest  
          p'  fst (rg_fold rest cnt')  p'  p  α  K'"
        by auto
      text ‹Extend witness property to full list. The head was dropped, so termph  P'.›
      from rg_fold_subset[of rest cnt'] rec have "P'  fst ` set rest" by simp
      with ph_notin have ph_notin_P: "ph  P'" by auto
      have wit_full: "p' K'. (p', K')  set (pc # rest)  
        p'  fst (rg_fold (pc # rest) cnt)  p'  p  α  K'"
      proof (intro allI impI)
        fix p' K' 
        assume p'_in: "(p', K')  set (pc # rest)"
          and p'_kept: "p'  fst (rg_fold (pc # rest) cnt)"
          and p'_ne: "p'  p"
        from p'_kept fold_eq have "p'  P'" by simp
        with ph_notin_P have "p'  ph" by auto
        from p'_in pc_def p'  ph have "(p', K')  set rest" by auto
        from p'  P' rec have "p'  fst (rg_fold rest cnt')" by simp
        from wit_rest (p', K')  set rest this p'_ne
        show "α  K'" by auto
      qed
      from α_in extra_zero wit_full show ?thesis by (intro bexI[of _ α]) auto
    qed
  qed
qed

text ‹The main witness lemma: the specialization of the auxiliary lemma with zero extra count.›

(* Derived from rg_fold_witness_aux. The cnt_eq premise is left open (the "_"
   in OF) and is discharged by fastforce from assms(2), presumably by
   instantiating the extra count with the constant-0 function. *)
lemma rg_fold_witness:
  assumes "distinct (map fst cands)"
    and "α. cnt α = length (filter (λ(_, K). α  K) cands)"
    and "(p, K)  set cands" and "p  fst (rg_fold cands cnt)"
  shows "α  K. p'  fst (rg_fold cands cnt). p'  p 
    (K'. (p', K')  set cands  α  K')"
  using rg_fold_witness_aux[OF assms(1) _ assms(3,4)] assms(2)
  by fastforce

subsubsection ‹Connecting the Algorithm to the Locale›

text ‹We now connect the reverse-greedy algorithm to the locale-based proofs.
  The candidate list is constructed from the enumeration of ‹a_star›. Positions with
  empty keys are not filtered out; they stay in the list but are never kept.
  The reverse-greedy produces a set ‹P› satisfying the coverage and witness properties.›

context annotation_problem
begin

text ‹Candidate list: positions paired with their keys, derived from the post-order
  enumeration of ‹a_star›.
  Candidates are processed in ‹enum_aterm› order (generalizing to other orders is straightforward).
  Positions with empty keys are included,
  but the reverse-greedy always drops them (the keep condition is vacuously false).›

(* One candidate per enumerated entry (p, s, tau) of a_star: the position p
   together with the type variables of tau restricted to V. *)
definition candidates :: "(nat × string set) list" where
  "candidates = map (λ(p, s, τ). (p, tvars_ty τ  V)) (enum_aterm a_star)"

text ‹The candidate positions have distinct first components.›

lemma candidates_distinct: "distinct (map fst candidates)"
  unfolding candidates_def by (simp add: comp_def case_prod_beta distinct_enum_fst)

text ‹For every ‹α ∈ V›, the initial count is at least 1.›

(* coverage_initial gives a position whose key contains alpha; its candidate
   entry is then counted by init_count. *)
lemma candidates_count_ge1:
  assumes "α  V"
  shows "init_count candidates α  1"
proof -
  from coverage_initial[OF assms] obtain p where
    "p  pos_set a_star" "α  key p" by auto
  from α  key p obtain s τ where
    mem: "(p, s, τ)  set (enum_aterm a_star)" and tv: "α  tvars_ty τ" and "α  V"
    unfolding key_def by auto
  then have mem_cand: "(p, tvars_ty τ  V)  set candidates"
    unfolding candidates_def by force
  have "α  tvars_ty τ  V" using tv assms by auto
  with mem_cand have "(p, tvars_ty τ  V)  set (filter (λ(_, K). α  K) candidates)"
    by auto
  then have "0 < length (filter (λ(_, K). α  K) candidates)"
    by (rule length_pos_if_in_set)
  then have "length (filter (λ(_, K). α  K) candidates)  1" by linarith
  then show ?thesis unfolding init_count_def by simp
qed

text ‹The kept set from the reverse-greedy is a subset of candidate positions,
  which are positions in ‹pos_set a_star›.›

lemma rg_result_subset: "reverse_greedy candidates  pos_set a_star"
proof -
  have "reverse_greedy candidates  fst ` set candidates"
    unfolding reverse_greedy_def using rg_fold_subset by auto
  also have "  pos_set a_star"
    unfolding candidates_def pos_set_def by force
  finally show ?thesis .
qed

text ‹Key in the candidate list matches key in the locale.›

lemma candidates_key_eq:
  assumes "(p, K)  set candidates"
  shows "K = key p"
  using assms
  unfolding key_def candidates_def
  by (fastforce dest: enum_aterm_unique)

text ‹Coverage property: the kept positions cover all inference variables.›

(* One inclusion: every key is contained in V by key_def. Other inclusion:
   the final count of alpha is at least 1 (rg_fold_preserves_ge1_on_keys), and
   by rg_fold_no_kept_zero it would be 0 if no kept candidate covered alpha. *)
lemma rg_coverage: "(key ` reverse_greedy candidates) = V"
proof (intro equalityI subsetI)
  fix α assume "α  (key ` reverse_greedy candidates)"
  then obtain p where "p  reverse_greedy candidates" "α  key p" by auto
  then show "α  V" unfolding key_def by auto
next
  fix α assume α_V: "α  V"
  text ‹By @{thm rg_fold_preserves_ge1_on_keys}, the final count for ‹α› is ‹≥ 1›.
    This means at least one candidate covering ‹α› was not dropped.›
  from coverage_initial[OF α_V] obtain p0 where
    "p0  pos_set a_star" "α  key p0" by auto
  from α  key p0 obtain s0 τ0 where
    mem0: "(p0, s0, τ0)  set (enum_aterm a_star)" and "α  tvars_ty τ0" "α  V"
    unfolding key_def by auto
  then have "(p0, tvars_ty τ0  V)  set candidates"
    unfolding candidates_def by force
  then have ex_cand: "(p, K)  set candidates. α  K"
    using α  tvars_ty τ0 α  V by (intro bexI[of _ "(p0, tvars_ty τ0  V)"]) auto
  have cnt_ge1: "β. ((p, K)  set candidates. β  K)  init_count candidates β  1"
  proof (intro allI impI)
    fix β assume "(p, K)  set candidates. β  K"
    then obtain p1 K1 where "(p1, K1)  set candidates" "β  K1" by auto
    from candidates_key_eq[OF this(1)] β  K1 have "β  key p1" by auto
    then have "β  V" unfolding key_def by auto
    then show "init_count candidates β  1" using candidates_count_ge1 by auto
  qed
  from rg_fold_preserves_ge1_on_keys[OF cnt_ge1 ex_cand]
  have final_ge1: "snd (rg_fold candidates (init_count candidates)) α  1" .
  text ‹A final count ‹≥ 1› means some candidate covering ‹α› was kept: otherwise,
    by lemma ‹rg_fold_no_kept_zero›, the final count would be ‹0›.›

  have "(pk, Kk)  set candidates. α  Kk  pk  fst (rg_fold candidates (init_count candidates))"
  proof (rule ccontr)
    assume "¬ ((pk, Kk)  set candidates. α  Kk  pk  fst (rg_fold candidates (init_count candidates)))"
    then have no_kept: "(pk, Kk)  set candidates. α  Kk  pk  fst (rg_fold candidates (init_count candidates))"
      by auto
    text ‹If no kept position covers ‹α›, the count must drop to ‹0›:›
    have nk: "p K. (p, K)  set candidates  α  K 
        p  fst (rg_fold candidates (init_count candidates))"
      using no_kept by auto
    have "snd (rg_fold candidates (init_count candidates)) α = 0"
      by (rule rg_fold_no_kept_zero[OF nk]) (simp add: init_count_def)
    with final_ge1 show False by simp
  qed
  then obtain pk Kk where pk_mem: "(pk, Kk)  set candidates"
    and "α  Kk" and pk_kept: "pk  fst (rg_fold candidates (init_count candidates))"
    by auto
  from pk_kept have "pk  reverse_greedy candidates" unfolding reverse_greedy_def by simp
  from candidates_key_eq[OF pk_mem] α  Kk have "α  key pk" by auto
  from pk  reverse_greedy candidates α  key pk
  show "α  (key ` reverse_greedy candidates)" by auto
qed

text ‹Witness property: each kept position has a witness variable.›

(* Transfers rg_fold_witness from candidate keys to locale keys via
   candidates_key_eq. *)
lemma rg_witness:
  assumes "p  reverse_greedy candidates"
  shows "α  key p. p'  reverse_greedy candidates. p'  p  α  key p'"
proof -
  from assms have p_in: "p  fst (rg_fold candidates (init_count candidates))"
    unfolding reverse_greedy_def by simp
  have "p  fst ` set candidates"
    using p_in rg_fold_subset[of candidates "init_count candidates"] by auto
  then obtain K where pK1: "(p, K)  set candidates"
    by (force simp: candidates_def)
  note pK = pK1 p_in

  from rg_fold_witness[OF candidates_distinct _ pK]
  have wit: "α  K. p'  fst (rg_fold candidates (init_count candidates)). p'  p 
    (K'. (p', K')  set candidates  α  K')"
    by (simp add: init_count_def)
  then obtain α where "α  K" and
    excl: "p'  fst (rg_fold candidates (init_count candidates)). p'  p 
      (K'. (p', K')  set candidates  α  K')" by auto
  from candidates_key_eq[OF pK(1)] α  K have "α  key p" by simp
  moreover have "p'  reverse_greedy candidates. p'  p  α  key p'"
  proof (intro ballI impI)
    fix p' assume "p'  reverse_greedy candidates" "p'  p"
    then have p'_in: "p'  fst (rg_fold candidates (init_count candidates))"
      unfolding reverse_greedy_def by simp
    from rg_fold_subset have "p'  fst ` set candidates"
      using p'_in by (metis in_mono)
    then obtain K' where "(p', K')  set candidates" by auto
    from excl p'_in p'  p this have "α  K'" by auto
    from candidates_key_eq[OF (p', K')  set candidates] this
    show "α  key p'" by simp
  qed
  ultimately show ?thesis by auto
qed

text ‹Instantiate locale ‹annotation_selection› with ‹P = reverse_greedy candidates›.
  This makes the abstract completeness and local minimality theorems available as
  concrete theorems about the reverse-greedy algorithm's output.›

(* The three locale obligations are discharged by rg_result_subset,
   rg_coverage and rg_witness. *)
sublocale annotation_selection const_type Γ a a_star σ "reverse_greedy candidates"
proof (unfold_locales)
  show "reverse_greedy candidates  pos_set a_star"
    by (rule rg_result_subset)
  show "(key ` reverse_greedy candidates) = V"
    by (rule rg_coverage)
  fix p assume "p  reverse_greedy candidates"
  then show "α  key p. p'  reverse_greedy candidates. p'  p  α  key p'"
    by (rule rg_witness)
qed

text ‹The output of the full algorithm: ‹t_out› is the annotated raw term produced
  by running the reverse-greedy algorithm to select positions and then inserting
  annotations. This is the ‹t_out› from the paper's Theorem 1.›

definition t_out :: raw_term where
  "t_out = annotate σ a_star (reverse_greedy candidates)"

text ‹Key property: stripping the annotations from ‹t_out› recovers the erasure of ‹a›.›

lemma strip_t_out: "strip t_out = erase a"
  unfolding t_out_def using strip_annotate same_erasure by simp

text ‹Theorem (Completeness), stated in terms of ‹t_out›:›
theorem completeness_t_out:
  assumes "well_typed const_type Γ a'"
    and "erase a' = strip t_out"
    and "σ'. subst_aterm σ' a_star = a'  (α  tvars_ctx Γ. σ' α = TVar α)"
    and "consistent_with a' (reverse_greedy candidates)"
  shows "a' = a"
  using assms strip_t_out completeness by auto

text ‹Theorem (Local Minimality) for the reverse-greedy selection: it is inherited through
  the ‹annotation_selection› interpretation above; the diagnostic command below only
  displays it as available in this context.›

thm local_minimality

end

end