Theory Pell.Pell_Algorithm
subsection ‹Executable code›
theory Pell_Algorithm
subsubsection ‹Efficient computation of powers by squaring›
text ‹
The following is a tail-recursive implementation of exponentiation by squaring.
It works for any binary operation ‹f› that fulfils ‹f x (f x z) = f (f x x) z›, i.\,e.\
a weak form of associativity.
›
context
  fixes f :: "'a ⇒ 'a ⇒ 'a"
begin

(* Tail-recursive exponentiation by squaring with respect to the fixed binary
   operation f.  The first argument is an accumulator; under the weak
   associativity assumption of efficient_power_correct below,
   efficient_power y x n = (f x ^^ n) y. *)
function efficient_power :: "'a ⇒ 'a ⇒ nat ⇒ 'a" where
  "efficient_power y x 0 = y"
| "efficient_power y x (Suc 0) = f x y"
| "n ≠ 0 ⟹ even n ⟹ efficient_power y x n = efficient_power y (f x x) (n div 2)"
| "n ≠ 1 ⟹ odd n ⟹ efficient_power y x n = efficient_power (f x y) (f x x) (n div 2)"
  by force+

(* In both recursive equations n ≥ 2 (even and ≠ 0, or odd and ≠ 1), so the
   exponent, i.e. the third argument, strictly decreases when halved. *)
termination by (relation "measure (snd ∘ snd)") (auto elim: oddE)

(* A single unconditional equation covering all four defining cases.  The
   defining equations for n ≥ 2 carry premises, so this form is registered
   with [code] for use by the code generator. *)
lemma efficient_power_code [code]:
  "efficient_power y x n =
     (if n = 0 then y
      else if n = 1 then f x y
      else if even n then efficient_power y (f x x) (n div 2)
      else efficient_power (f x y) (f x x) (n div 2))"
  by (induction y x n rule: efficient_power.induct) auto

(* Correctness: if f satisfies f x (f x z) = f (f x x) z, then
   efficient_power y x n applies (f x) exactly n times to y. *)
lemma efficient_power_correct:
  assumes "⋀x z. f x (f x z) = f (f x x) z"
  shows   "efficient_power y x n = (f x ^^ n) y"
proof -
  (* Unfold a two-fold iteration into a nested application. *)
  have [simp]: "f ^^ 2 = (λx. f (f x))" for f :: "'a ⇒ 'a"
    by (simp add: eval_nat_numeral o_def)
  show ?thesis
    by (induction y x n rule: efficient_power.induct)
       (auto elim!: evenE oddE simp: funpow_mult [symmetric] funpow_Suc_right assms
             simp del: funpow.simps(2))
qed

end
subsubsection ‹Multiplication and powers of solutions›
text ‹
We define versions of Pell solution multiplication and exponentiation specialised
to natural numbers, both for efficiency reasons and to circumvent the problem of
generating code for definitions made inside locales.
›
(* Product of two Pell solutions over the naturals, following
   (u + v√d)(s + t√d) = (u s + d v t) + (u t + v s)√d. *)
fun pell_mul_nat :: "nat ⇒ nat × nat ⇒ nat × nat ⇒ nat × nat" where
  "pell_mul_nat d (u, v) (s, t) = (u * s + d * v * t, u * t + v * s)"
(* Inside the pell locale, the executable pell_mul_nat D coincides with the
   locale's multiplication pell.pell_mul D.  As a [simp] rule it rewrites
   occurrences of pell_mul_nat D into pell.pell_mul D. *)
lemma (in pell) pell_mul_nat_correct [simp]: "pell_mul_nat D = pell.pell_mul D"
by (auto simp add: pell_mul_def fun_eq_iff)
(* n-th power of the solution z with respect to pell_mul_nat D, computed by
   squaring via efficient_power with the unit (1, 0) as initial accumulator. *)
definition efficient_pell_power :: "nat ⇒ nat × nat ⇒ nat ⇒ nat × nat" where
"efficient_pell_power D z n = efficient_power (pell_mul_nat D) (1, 0) z n"
(* efficient_pell_power is the n-fold application of (pell_mul_nat D z) to
   (1, 0).  The proof instantiates efficient_power_correct; its weak
   associativity side condition is discharged by algebra_simps. *)
lemma efficient_pell_power_correct [simp]:
"efficient_pell_power D z n = (pell_mul_nat D z ^^ n) (1, 0)"
unfolding efficient_pell_power_def
by (intro efficient_power_correct) (auto simp: algebra_simps)
subsubsection ‹Finding the fundamental solution›
text ‹
In the following, we set up a very simple algorithm for computing the fundamental
solution ‹(x, y)›. We try increasing values for ‹y› until $1 + Dy^2$ is a perfect
square, which we check using an efficient square-detection algorithm. This is efficient
enough to work on some interesting small examples.
Much better algorithms (typically based on the continued fraction expansion of $\sqrt{D}$)
are available, but they are also considerably more complicated.
›
(* For a perfect square n, squaring the integer square root Discrete.sqrt n
   gives back n exactly. *)
lemma Discrete_sqrt_square_is_square:
assumes "is_square n"
shows "Discrete.sqrt n ^ 2 = n"
using assms unfolding is_nth_power_def by force
(* One step of the search for the fundamental solution.
   The state Inl (y, y') carries a candidate y together with y' = 1 + D y².
   - If get_nat_sqrt y' yields Some x (presumably exactly when y' is a perfect
     square; see get_nat_sqrt_def), the search stops with Inr (x, y).
   - Otherwise it moves on to y + 1.  Since y' + D (2y + 1) = 1 + D (y + 1)²,
     the second component stays consistent.
   No Inr branch is given.  The while loops that use this step only apply it
   to states satisfying isl, so Inr states are never passed in. *)
definition find_fund_sol_step :: "nat ⇒ nat × nat + nat × nat ⇒ _" where
"find_fund_sol_step D = (λInl (y, y') ⇒
(case get_nat_sqrt y' of
Some x ⇒ Inr (x, y)
| None ⇒ Inl (y + 1, y' + D * (2 * y + 1))))"
(* Executable computation of the fundamental solution (x, y) of x² - D y² = 1.
   The search starts at y = 1, with state Inl (1, 1 + D), and iterates
   find_fund_sol_step until an Inr result is produced.  For square D the
   search is not run; the dummy value (0, 0) is returned instead.  See
   find_fund_sol_correct for the precise statement. *)
definition find_fund_sol :: "nat ⇒ nat × nat" where
  "find_fund_sol D =
     (if square_test D then
        (0, 0)
      else
        sum.projr (while sum.isl (find_fund_sol_step D) (Inl (1, 1 + D))))"
(* Correctness of the search loop.  For non-square D, running
   find_fund_sol_step from Inl (1, 1 + D) until an Inr state is reached
   yields exactly pell.fund_sol D.

   Proof outline:
   - while_rule is applied with the invariant R.  An Inl (m, m') state has
     m > 0, m' = m² D + 1, and no y in {0<..<m} for which y² D + 1 is a
     square.  An Inr state satisfies Q.
   - Termination uses rel.  It is contained in rel', which is shown to be
     well-founded via the measure Suc (snd fund_sol) - m.
   - The resulting (x, y) is a nontrivial solution with
     snd fund_sol ≤ y ≤ snd fund_sol.
   - solutions_linorder_strict then fixes the x-component as well. *)
lemma fund_sol_code:
  assumes "¬is_square (D :: nat)"
  shows   "pell.fund_sol D = sum.projr (while isl (find_fund_sol_step D) (Inl (Suc 0, Suc D)))"
proof -
  from assms interpret pell D by unfold_locales
  note [simp] = find_fund_sol_step_def
  define f where "f = find_fund_sol_step D"
  (* P y: y is positive and y² D + 1 is a perfect square. *)
  define P :: "nat ⇒ bool" where "P = (λy. y > 0 ∧ is_square (y^2 * D + 1))"
  (* Q (x, y): y is the least y satisfying P, and x = Discrete.sqrt (y² D + 1). *)
  define Q :: "nat × nat ⇒ bool" where
    "Q = (λ(x,y). P y ∧ (∀y'∈{0<..<y}. ¬P y') ∧ x = Discrete.sqrt (y^2 * D + 1))"
  (* Loop invariant for while_rule. *)
  define R :: "nat × nat + nat × nat ⇒ bool"
    where "R = (λs. case s of
                 Inl (m, m') ⇒ m > 0 ∧ (m' = m^2 * D + 1) ∧ (∀y∈{0<..<m}. ¬is_square (y^2 * D + 1))
               | Inr x ⇒ Q x)"
  (* Step relation used for termination.  An Inl-to-Inl step must increase m
     while staying ≤ snd fund_sol.  The final Inl-to-Inr step must start from
     m ≤ snd fund_sol. *)
  define rel :: "((nat × nat + nat × nat) × (nat × nat + nat × nat)) set"
    where "rel = {(A,B). (case (A, B) of
                            (Inl (m, _), Inl (m', _)) ⇒ m' > 0 ∧ m > m' ∧ m ≤ snd fund_sol
                          | (Inr _, Inl (m', _)) ⇒ m' ≤ snd fund_sol
                          | _ ⇒ False) ∧ A = f B}"
  obtain x y where xy: "sum.projr (while isl f (Inl (Suc 0, Suc D))) = (x, y)"
    by (cases "sum.projr (while isl f (Inl (Suc 0, Suc D)))")

  (* If y² D + 1 is not a square, then y cannot be the y-component of fund_sol. *)
  have neq_fund_solI: "y ≠ snd fund_sol" if "¬ is_square (Suc (y⇧2 * D))" for y
  proof
    assume "y = snd fund_sol"
    with fund_sol_is_nontriv_solution have "Suc (y⇧2 * D) = fst fund_sol ^ 2"
      by (simp add: nontriv_solution_def case_prod_unfold)
    hence "is_square (Suc (y⇧2 * D))" by simp
    with that show False by contradiction
  qed

  (* From any state satisfying the Inl part of R, the loop ends in an Inr
     state satisfying Q. *)
  have "case_sum (λ_. False) Q (while sum.isl f (Inl (m, m^2 * D + 1)))"
    if "∀y∈{0<..<m}. ¬is_square (y^2 * D + 1)" "m > 0" for m
  proof (rule while_rule[where b = sum.isl])
    show "R (Inl (m, m⇧2 * D + 1))"
      using that by (auto simp: R_def)
  next
    fix s assume "R s" "isl s"
    thus "R (f s)"
      by (auto simp: not_less_less_Suc_eq Q_def P_def R_def f_def get_nat_sqrt_def
                     power2_eq_square algebra_simps split: sum.splits prod.splits)
  next
    fix s assume "R s" "¬isl s"
    thus "case s of Inl _ ⇒ False | Inr x ⇒ Q x"
      by (auto simp: R_def split: sum.splits)
  next
    fix s assume s: "R s" "isl s"
    show "(f s, s) ∈ rel"
    proof (cases s)
      case [simp]: (Inl s')
      obtain a b where [simp]: "s' = (a, b)" by (cases s')
      from s have *: "a > 0" "b = Suc (a⇧2 * D)" "⋀y. y ∈ {0<..<a} ⟹ ¬ is_square (Suc (y⇧2 * D))"
        by (auto simp: R_def)
      (* If the search continues past a, then a is still below snd fund_sol. *)
      have "a < snd fund_sol" if **: "¬ is_square (Suc (a⇧2 * D))"
      proof -
        from neq_fund_solI have "y' ≠ snd fund_sol" if "y' ∈ {0<..<Suc a}" for y'
          using * ** that by (cases "y' = a") auto
        moreover have "snd fund_sol ≠ 0" using fund_sol_is_nontriv_solution
          by (intro notI, cases fund_sol) (auto simp: nontriv_solution_altdef)
        ultimately have "∀y'≤a. y' ≠ snd fund_sol" by (auto simp: less_Suc_eq_le)
        thus "snd fund_sol > a" by (cases "a < snd fund_sol") (auto simp: not_less)
      qed
      (* The current candidate a never overshoots snd fund_sol. *)
      moreover have "a ≤ snd fund_sol"
      proof -
        have "∀y'∈{0<..<a}. y' ≠ snd fund_sol" using neq_fund_solI *
          by (auto simp: less_Suc_eq_le)
        moreover have "snd fund_sol ≠ 0" using fund_sol_is_nontriv_solution
          by (intro notI, cases fund_sol) (auto simp: nontriv_solution_altdef)
        ultimately have "∀y'<a. y' ≠ snd fund_sol" by (auto simp: less_Suc_eq_le)
        thus "snd fund_sol ≥ a" by (cases "a ≤ snd fund_sol") (auto simp: not_less)
      qed
      ultimately show ?thesis using *
        by (auto simp: f_def get_nat_sqrt_def rel_def)
    qed (insert s, auto)
  next
    (* rel is contained in rel', which is well-founded by a measure argument. *)
    define rel'
      where "rel' = {(y, x). (case x of Inl (m, _) ⇒ m ≤ snd fund_sol | Inr _ ⇒ False) ∧ y = f x}"
    have "wf rel'" unfolding rel'_def
      by (rule wf_if_measure[where f = "λz. case z of Inl (m, _) ⇒ Suc (snd fund_sol) - m | _ ⇒ 0"])
         (auto split: prod.splits sum.splits simp: f_def get_nat_sqrt_def)
    moreover have "rel ⊆ rel'"
    proof safe
      fix w z assume "(w, z) ∈ rel"
      thus "(w, z) ∈ rel'" by (cases w; cases z) (auto simp: rel_def rel'_def)
    qed
    ultimately show "wf rel" by (rule wf_subset)
  qed
  (* Instantiate the loop result at the actual start state, where m = 1. *)
  from this[of 1] and xy have *: "Q (x, y)"
    by (auto split: sum.splits)
  from * have "is_square (Suc (y⇧2 * D))" by (simp add: Q_def P_def)
  with * have "x⇧2 = Suc (y⇧2 * D)" "y > 0"
    by (auto simp: Q_def P_def Discrete_sqrt_square_is_square)
  hence "nontriv_solution (x, y)"
    by (auto simp: nontriv_solution_def)
  from this have "snd fund_sol ≤ snd (x, y)"
    by (rule fund_sol_minimal'')
  moreover have "snd fund_sol ≥ y"
  proof -
    from * have "(∀y'∈{0<..<y}. ¬ is_square (Suc (y'⇧2 * D)))"
      by (simp add: Q_def P_def)
    with neq_fund_solI have "(∀y'∈{0<..<y}. y' ≠ snd fund_sol)"
      by auto
    moreover have "snd fund_sol ≠ 0"
      using fund_sol_is_nontriv_solution
      by (cases fund_sol) (auto intro!: Nat.gr0I simp: nontriv_solution_altdef)
    ultimately have "(∀y'<y. y' ≠ snd fund_sol)" by auto
    thus "snd fund_sol ≥ y" by (cases "snd fund_sol ≥ y") (auto simp: not_less)
  qed
  ultimately have "snd fund_sol = y" by simp
  with solutions_linorder_strict[of x y "fst fund_sol" "snd fund_sol"]
       fund_sol_is_nontriv_solution ‹nontriv_solution (x, y)›
  have "fst fund_sol = x" by (cases fund_sol) (auto simp: nontriv_solution_altdef)
  with ‹snd fund_sol = y› have "fund_sol = (x, y)"
    by (cases fund_sol) simp
  with xy show ?thesis by (simp add: f_def)
qed
(* find_fund_sol returns (0, 0) for square D and the locale's fundamental
   solution otherwise.  The proof combines fund_sol_code with
   square_test_correct to replace the executable square_test by is_square. *)
lemma find_fund_sol_correct: "find_fund_sol D = (if is_square D then (0, 0) else pell.fund_sol D)"
by (simp add: find_fund_sol_def fund_sol_code square_test_correct)
subsubsection ‹The infinite list of all solutions›
(* Infinite stream of solutions: (1, 0), z, then further products with z,
   obtained by repeatedly multiplying by the fundamental solution
   z = find_fund_sol D via siterate. *)
definition pell_solutions :: "nat ⇒ (nat × nat) stream" where
"pell_solutions D = (let z = find_fund_sol D in siterate (pell_mul_nat D z) (1, 0))"
(* Within the pell locale, the n-th element of pell_solutions D is the
   locale's nth_solution n.
   NOTE(review): this lemma is unnamed and therefore cannot be referenced
   elsewhere; consider giving it a name. *)
lemma (in pell) "snth (pell_solutions D) n = nth_solution n"
by (simp add: pell_solutions_def Let_def find_fund_sol_correct nonsquare_D nth_solution_def
pell_power_def pell_mul_commutes[of _ fund_sol])
subsubsection ‹Computing the $n$-th solution›
(* Computes the n-th solution directly, raising the fundamental solution to
   the n-th power with efficient_pell_power.  Returns (0, 0) for square D.
   NOTE(review): this repeats the search loop of find_fund_sol inline and
   tests is_square rather than square_test.  find_nth_solution_correct
   relies on fund_sol_code and nonsquare_D matching this exact form, so
   change both together if this is ever unified. *)
definition find_nth_solution :: "nat ⇒ nat ⇒ nat × nat" where
"find_nth_solution D n =
(if is_square D then (0, 0) else
let z = sum.projr (while isl (find_fund_sol_step D) (Inl (Suc 0, Suc D)))
in efficient_pell_power D z n)"
(* Within the pell locale, where D is not a square (nonsquare_D),
   find_nth_solution agrees with the locale's nth_solution. *)
lemma (in pell) find_nth_solution_correct: "find_nth_solution D n = nth_solution n"
by (simp add: find_nth_solution_def nonsquare_D nth_solution_def fund_sol_code
pell_power_def pell_mul_commutes[of _ "projr _"])