import numpy as np
from numpy.linalg import norm
import matplotlib.pyplot as plt
n = 10000      # sample size
p = 15000      # dimension
s = 1000       # sparsity
T = 100        # number of iterations
eta = (1 + np.sqrt(p/n))**-2  # learning rate
L = 1/eta      # smoothness of the loss
lam = 0.01     # tuning parameter
sigma = 1.5    # noise level

# true regression vector
beta = 15 * np.hstack([np.ones(s), np.zeros(p-s)]) / np.sqrt(n)

# data generation process
rng = np.random.default_rng(2025)
X = rng.standard_normal(size=(n, p))  # feature matrix
y = X @ beta + sigma * rng.standard_normal(size=(n, ))  # signal + noise
Estimating the generalization performance of an iterative algorithm along its trajectory
In a linear model \(y_i = x_i^T\beta + \epsilon_i\) with true regression vector \(\beta\in\mathbb R^p\), noise \(\epsilon_i\sim N(0,\sigma^2)\) and feature matrix \(X\in \mathbb R^{n\times p}\) with rows \((x_i)_{i\in [n]}\), we are interested in estimating the generalization error of an iterative algorithm of the form \[\begin{equation} \hat b^t = g_t \Big( \hat b^{t-1},~\hat b^{t-2},~ \frac{X^T(y - X\hat b^{t-1})}{n},~ \frac{X^T(y - X\hat b^{t-2})}{n} \Big) \end{equation}\] for some function \(g_t\) of the two previous iterates and their gradients (here, with respect to the square loss). The iterate and gradient at iteration \(t-2\) are included to allow for accelerated methods that require momentum.
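To make the template concrete, here is a minimal sketch (my illustration, not taken from Bellec and Tan 2024) of a heavy-ball style update written as such a function \(g_t\): it combines the two previous iterates with the gradient terms passed to it. The proximal gradient iterates studied below have the same structure, with no momentum term and a soft-thresholding applied on top.
# Hypothetical heavy-ball style g_t, only to illustrate the template above;
# grad_prev and grad_prev2 stand for X.T @ (y - X @ b) / n at the two previous iterates.
def g_t(b_prev, b_prev2, grad_prev, grad_prev2, step=0.1, momentum=0.9):
    # gradient step on the previous iterate plus a momentum correction
    # (grad_prev2 is unused here but available for other accelerated schemes)
    return b_prev + step * grad_prev + momentum * (b_prev - b_prev2)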
The following example illustrates the estimator proposed in (Bellec and Tan 2024).
Proximal Gradient Descent
In the simple example below, we focus on the proximal gradient descent (Parikh, Boyd, et al. 2014) iterates \[\begin{equation} \hat b^t = g \Big( \hat b^{t-1} + \frac{\eta X^T(y - X\hat b^{t-1})}{n} \Big) \end{equation}\] where \(\eta>0\) is a learning rate parameter, and the nonlinear function \(g\) is the proximal operator of the L1 norm with parameter \(\eta \lambda\), namely, the soft-thresholding \[\begin{equation} g(u) = \text{soft}(u; v), \text{ where } v=\lambda\eta \text{ and } \text{soft}(u; v) = \text{sign}(u)(|u| - v)_+. \end{equation}\]
def soft(x, t):
    return np.sign(x) * np.clip(np.abs(x) - t, 0, None)

plt.plot(np.linspace(-3, 3), soft(np.linspace(-3, 3), 2))
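As a quick sanity check (an addition, not part of the original code), soft(u; v) should coincide with the minimizer of the one-dimensional proximal objective \(\tfrac12 (z-u)^2 + v|z|\); a crude grid search with arbitrary test values confirms this.
u, v = 1.7, 0.5                         # arbitrary test values
z = np.linspace(-5, 5, 100001)          # fine grid for the scalar prox problem
print(z[np.argmin(0.5*(z - u)**2 + v*np.abs(z))], soft(u, v))  # both close to u - v = 1.2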
Hutchinson’s trace approximation
We will use the following trick to efficiently estimate the trace of a large matrix: if \(R = (r_{ik})_{i\in[n],k\in[m]}\) has iid symmetric \(\pm 1/\sqrt m\) entries, then for \(Q\in \mathbb R^{n\times n}\) the Hutchinson approximation \(\text{trace}(Q)\approx\text{trace}(R^TQR)\) holds with \(m\) of constant order, provided \(\|Q\|_F \ll \text{trace}(Q)\).
m = 10
r = rng.choice([-1.0, 1.0], size=(n, m)) / np.sqrt(m)
For instance, let us verify this approximation with a random matrix:
G = rng.normal(size=(n, n))
temporary_matrix = G.T @ np.diag(np.linspace(1, 10, num=n)) @ G
np.trace(temporary_matrix), np.trace(r.T @ temporary_matrix @ r)
(550122897.8034426, 550825944.5616682)
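This agreement is not an accident: for this matrix the Frobenius norm is a small fraction of the trace, which can be checked directly (this check is my addition; it uses the norm imported at the top, which returns the Frobenius norm of a 2-D array by default).
norm(temporary_matrix) / np.trace(temporary_matrix)   # should be well below 1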
Estimating the generalization error along the algorithm trajectory
We are now ready to compute the iterates of the proximal gradient algorithm, as well as the estimate of its generalization error from (Bellec and Tan 2024).
# initialization of arrays
F = np.zeros((n, T))                      # residuals y - X @ b^t
H = np.zeros((p, T), dtype=np.float16)    # error vectors b^t - beta
bt = np.zeros_like(beta)                  # iterate, initialized at zero
F[:, 0] = y - X @ bt
H[:, 0] = bt - beta
A = np.zeros((T, T))                      # memory matrix
XTr = X.T @ r                             # r has iid Rademacher entries

for t in range(1, T):
    print(t, end="..")
    # proximal gradient step: gradient update followed by soft-thresholding
    bt = soft(bt + eta * X.T @ (y - X @ bt) / n, lam * eta)
    # support of the current iterate (diagonal of the Jacobian of the soft-thresholding)
    Dt = np.zeros(p, dtype=bool)
    Dt[np.nonzero(bt)[0]] = 1
    Dt = Dt.reshape((p, 1))
    if t >= 1:
        Rt = Dt * XTr if t == 1 else np.hstack([Dt * (Rt - (eta/n) * X.T @ (X @ Rt)), Dt * XTr])
        # Hutchinson estimates of the traces filling row t of the memory matrix
        A[t, :t] = np.trace((eta * XTr.T @ Rt / n).reshape((m, t, m)),
                            axis1=0, axis2=2)
    F[:, t] = y - X @ bt
    H[:, t] = bt - beta

M = np.linalg.solve(np.eye(T) - A/n, F.T).T
generalization_error = H.T @ H + np.ones((T, T)) * sigma**2
generalization_error_estimate = M.T @ M / n
1..2..3..4..5..6..7..8..9..10..11..12..13..14..15..16..17..18..19..20..21..22..23..24..25..26..27..28..29..30..31..32..33..34..35..36..37..38..39..40..41..42..43..44..45..46..47..48..49..50..51..52..53..54..55..56..57..58..59..60..61..62..63..64..65..66..67..68..69..70..71..72..73..74..75..76..77..78..79..80..81..82..83..84..85..86..87..88..89..90..91..92..93..94..95..96..97..98..99..
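A useful reference point before tabulating the results (my addition): at \(t=0\) the iterate is zero, so the true generalization error equals \(\|\beta\|^2 + \sigma^2 = 24.75\) exactly, which the first entry of generalization_error should reproduce up to the float16 storage of H.
norm(beta)**2 + sigma**2, generalization_error[0, 0]   # both approximately 24.75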
import pandas as pd
df = pd.DataFrame(np.column_stack([
    np.diag(generalization_error_estimate),
    np.diag(generalization_error),
]))
df.columns = ('estimate of the generalization error', 'generalization error')
df.plot()
df.round(2)
|     | estimate of the generalization error | generalization error |
|-----|--------------------------------------|----------------------|
| 0   | 25.18 | 24.75 |
| 1   | 18.60 | 18.08 |
| 2   | 16.06 | 15.48 |
| 3   | 14.53 | 13.96 |
| 4   | 13.42 | 12.88 |
| ... | ...   | ...   |
| 95  | 3.94  | 3.83  |
| 96  | 3.94  | 3.83  |
| 97  | 3.92  | 3.82  |
| 98  | 3.92  | 3.82  |
| 99  | 3.91  | 3.82  |

100 rows × 2 columns
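One natural use of the estimated curve (a usage illustration on my part, not a procedure taken from the paper) is to select an iteration along the trajectory by early stopping where the estimated error is smallest.
# hypothetical early-stopping rule that only uses the data-driven estimate
t_best = int(np.argmin(np.diag(generalization_error_estimate)))
print("iteration with the smallest estimated generalization error:", t_best)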