In this lecture we will derive an intriguing way to speed up gradient descent using Chebyshev polynomials. In doing so, we'll also fill in the details for a blog post I wrote four and a half years ago called zen of gradient descent, just in case you were still waiting on that.
Overview: we first analyze gradient descent on quadratics, then introduce Chebyshev polynomials, and finally derive an accelerated method from the Chebyshev recurrence.
Bibliographic note: Linear Algebra and its Applications by Peter D. Lax has a fantastic exposition of this material in Chapter 17. It's generally a fabulous linear algebra text that I highly recommend.
%matplotlib inline
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from plotters import kwargs, setup_layout
setup_layout()
Consider minimizing the objective function
$$f(x) = \frac12 x^\top A x - b^\top x$$
over all of $\mathbb{R}^n,$ where $A\in\mathbb{R}^{n\times n}$ is a symmetric positive definite matrix. We might for instance want to do this when we solve linear equations corresponding to a Laplacian operator. Note that
$$\nabla f(x) = Ax-b\qquad\text{and}\qquad \nabla^2 f(x) = A.$$
As we can see, the gradient vanishes when $Ax=b.$ We denote the condition number of the matrix by $\kappa=\beta/\alpha,$ where $\beta=\lambda_1(A)$ is the largest and $\alpha=\lambda_n(A)>0$ the smallest eigenvalue of $A.$ In particular, we know that the objective function is $\beta$-smooth and $\alpha$-strongly convex.
def quadratic(A, b, x):
    """Quadratic defined by A and b at x."""
    return 0.5 * x.dot(A.dot(x)) - b.dot(x)
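As a quick sanity check, let's build a small symmetric positive definite matrix with a prescribed spectrum and confirm that the gradient of this quadratic is $Ax-b$ by comparing against a finite-difference estimate.
# Sanity check: construct a symmetric PD matrix with a prescribed spectrum
# and compare the gradient Ax - b against a finite-difference estimate.
np.random.seed(0)
n = 5
Q, _ = np.linalg.qr(np.random.randn(n, n))       # random orthogonal basis
A = Q.dot(np.diag(np.linspace(1.0, 10.0, n))).dot(Q.T)
b = np.random.randn(n)
x = np.random.randn(n)

grad = A.dot(x) - b
eps = 1e-6
fd_grad = np.array([(quadratic(A, b, x + eps*e) - quadratic(A, b, x - eps*e)) / (2*eps)
                    for e in np.eye(n)])
print(np.allclose(grad, fd_grad, atol=1e-5))     # should print True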
We already saw in Lecture 3 that gradient descent achieves a linear convergence rate of the form $\exp(-t\alpha/\beta)$ for all $\alpha$-strongly convex and $\beta$-smooth functions. Let's rederive this result for quadratics, where it's almost trivial.
Let $x^*$ be the unique solution of the linear system $Ax=b$ and put
$$ e_t = x_t-x^* $$
where $x_{t+1}=x_t - \eta_t (Ax_t-b)$ is defined recursively starting from some $x_0,$ and $\eta_t$ is a step size we'll determine shortly.
Note that $x^*$ satisfies (for any $t$)
$$ x^* = (I-\eta_t A)x^* +\eta_t b. $$
Hence,
$$ \begin{align} e_{t+1} & = x_{t+1}-x^* \\ & = (I-\eta_t A)x_t +\eta_t b - ( (I-\eta_t A)x^* +\eta_t b ) \\ & = (I-\eta_t A)e_t. \end{align} $$
The above calculation gives us $e_t = p_t(A)e_0,$ where $p_t$ is the polynomial
$$ p_t(a) = \prod_{i=1}^t (1-\eta_i a)\,.\qquad(*) $$
We can upper bound the norm of the error term as
$$ \|e_t\|\le \|p_t(A)\|\cdot\|e_0\|\,. $$
Since $A$ is a symmetric matrix with eigenvalues in $[\alpha,\beta],$ it's not hard to justify that
$$ \|p_t(A)\|\le \max_{\alpha\le a\le\beta} \left|p_t(a)\right|\,. $$
This leads to an intriguing problem: among all polynomials of degree $t$ that satisfy $p_t(0)=1,$ we're looking for one whose magnitude is as small as possible on the interval $[\alpha,\beta].$
A naive solution is to choose a uniform step size $\eta_t=2/(\alpha+\beta)$ in the expression $(*).$ This rescales the eigenvalues of $A$ just in the right way so that a simple calculation gives the bound:
$$ \|e_t\|\le \left(1-\frac1{\kappa}\right)^t\|e_0\| \le \exp\left(-\alpha t/\beta\right)\|x_0-x^*\|\,. $$
This is exactly the rate we proved in Lecture 3 for any smooth and strongly convex function. So, no surprise—yet!
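To see this rate in action, here's a small experiment: we run gradient descent with the constant step size $2/(\alpha+\beta)$ on a random quadratic with eigenvalues in $[1,10]$ and plot the distance to the optimum against the $\exp(-\alpha t/\beta)$ envelope.
# Gradient descent with constant step size 2/(alpha+beta) on a quadratic
# with spectrum in [1, 10], i.e. kappa = 10.
np.random.seed(1)
n, alpha, beta = 50, 1.0, 10.0
Q, _ = np.linalg.qr(np.random.randn(n, n))
A = Q.dot(np.diag(np.linspace(alpha, beta, n))).dot(Q.T)
b = np.random.randn(n)
x_star = np.linalg.solve(A, b)

x, step, errors = np.zeros(n), 2.0/(alpha+beta), []
for t in range(100):
    errors.append(np.linalg.norm(x - x_star))
    x = x - step * (A.dot(x) - b)

plt.figure(figsize=(14, 8))
plt.title('Gradient descent on a quadratic, kappa = 10')
plt.yscale('log')
plt.plot(errors, label='distance to optimum', **kwargs)
plt.plot([errors[0] * np.exp(-alpha*t/beta) for t in range(100)],
         label='exp(-alpha t/beta) envelope', **kwargs)
plt.legend();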
Let's look at this polynomial a bit closer. In the example below we choose $\alpha=1$ and $\beta=10$ so that $\kappa=10.$ The relevant interval is therefore $[1,10].$
def p(k, x, alpha=1.0, beta=10.0):
    """Error polynomial of gradient descent with step size 2/(alpha+beta)."""
    return np.power(1.0 - 2.0*x/(alpha+beta), k)
alpha, beta = 1.0, 10.0
xs = np.linspace(0, beta, 100)
plt.figure(figsize=(14, 8))
plt.title('Naive polynomial')
plt.plot(xs, p(3, xs), 'g-', label='degree 3', **kwargs)
plt.plot(xs, [p(3, alpha)]*len(xs), 'g--', label='max value')
plt.plot(xs, p(6, xs), 'b-', label='degree 6', **kwargs)
plt.plot(xs, [p(6, alpha)]*len(xs), 'b--', label='max value')
plt.legend();
Doubling the degree roughly halves the maximum absolute value that the polynomial attains in the interval $[\alpha, \beta].$
Can we do better than this? Surprisingly, the answer is yes!
Chebyshev polynomials turn out to give an optimal answer to the question that we asked. Suitably rescaled, they minimize the absolute value in a desired interval $[\alpha, \beta]$ while satisfying the normalization constraint of having value $1$ at the origin.
Before we do the rescaling, we will first look at the standard Chebyshev polynomials ("of the first kind") on the interval $[-1,1].$ These polynomials have several natural equivalent definitions.
Definition. We define the Chebyshev polynomial $T_k$ recursively as follows:
$$ \begin{align} T_0(a) &= 1\\ T_1(a) &= a\\ T_k(a) &=2aT_{k-1}(a)-T_{k-2}(a),\qquad k\ge 2.\\ \end{align} $$
def T(k, a):
    """Chebyshev polynomial of degree k."""
    if k <= 1:
        return a**k
    else:
        return 2.0*a*T(k-1, a) - T(k-2, a)
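As a quick consistency check, we can compare this recursive implementation against NumPy's Chebyshev series evaluator: the coefficient vector with a single one in position $k$ selects exactly $T_k.$ The naive recursion does redundant work for large degrees, but it's perfectly fine for the small degrees we use here.
# Compare the recursive implementation with numpy's Chebyshev evaluator.
from numpy.polynomial import chebyshev as cheb

grid = np.linspace(-1, 1, 7)
print(all(np.allclose(T(k, grid), cheb.chebval(grid, [0]*k + [1]))
          for k in range(6)))    # should print True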
Here are the first 5 Chebyshev polynomials.
xs = np.linspace(-1, 1, 100)
plt.figure(figsize=(14, 10))
_ = [plt.plot(xs, T(k, xs), **kwargs) for k in range(0, 5)]
So far so good. The magic happens when we rescale these polynomials so as to satisfy our requirements.
Recall that the eigenvalues of the matrix we consider are in the interval $[\alpha, \beta].$ We need to rescale the Chebyshev polynomials so that they're supported on this interval and still attain value $1$ at the origin. This is accomplished by the polynomial
$$ P_k(a) = T_k\left(\frac{\beta+\alpha-2a}{\beta-\alpha}\right)\cdot T_k\left(\frac{\beta+\alpha}{\beta-\alpha}\right)^{-1}\,. $$
def P(k, a, alpha=1, beta=10.0):
    """Rescaled Chebyshev polynomial."""
    assert beta > alpha
    normalization = T(k, (beta+alpha)/(beta-alpha))
    return T(k, (beta+alpha-2*a)/(beta-alpha))/normalization
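Two properties worth checking numerically: the normalization $P_k(0)=1,$ and the fact that on $[\alpha,\beta]$ the maximum of $|P_k|$ is attained at the left endpoint $a=\alpha$ (we'll use this fact below).
# Check the normalization P_k(0) = 1 and that on [alpha, beta] the maximum
# of |P_k| is attained at a = alpha.
grid = np.linspace(1.0, 10.0, 1001)
print(all(np.isclose(P(k, 0.0), 1.0) for k in range(1, 8)))
print(all(np.max(np.abs(P(k, grid))) <= abs(P(k, 1.0)) + 1e-12
          for k in range(1, 8)))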
alpha, beta = 1.0, 10.0
xs = np.linspace(0, beta, 100)
plt.figure(figsize=(14, 8))
plt.title('Rescaled Chebyshev')
plt.plot(xs, P(3, xs), 'g-', label='degree 3', **kwargs)
plt.plot(xs, [P(3, alpha)]*len(xs), 'g--', label='max value')
plt.plot(xs, P(6, xs), 'b-', label='degree 6', **kwargs)
plt.plot(xs, [P(6, alpha)]*len(xs), 'b--', label='max value')
plt.legend();
We see that doubling the degree has a much more dramatic effect on the magnitude of the polynomial in the interval $[\alpha, \beta].$
Let's compare this beautiful Chebyshev polynomial side by side with the naive polynomial we saw earlier.
plt.figure(figsize=(14, 8))
plt.title('Rescaled Chebyshev vs naive polynomial')
plt.plot(xs, P(6, xs), 'g-', label='deg-6 chebyshev', **kwargs)
plt.plot(xs, [P(6, alpha)]*len(xs), 'g--', label='max value')
plt.plot(xs, p(6, xs), 'b-', label='deg-6 naive', **kwargs)
plt.plot(xs, [p(6, alpha)]*len(xs), 'b--', label='max value')
plt.legend();
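To put numbers on the picture, here are the maximum absolute values the two degree-6 polynomials attain on $[\alpha, \beta].$
# Maximum absolute values on [alpha, beta]; both are attained at a = alpha.
print('naive, degree 6:    ', p(6, alpha))
print('chebyshev, degree 6:', P(6, alpha))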
That looks promising. But what exactly is the error bound that comes out of it and how does it lead to an iterative algorithm? These are important questions that we'll answer next.
The Chebyshev polynomial leads to an accelerated version of gradient descent. Before we describe the iterative process, let's first see what error bound comes out of the Chebyshev polynomial.
So, just how large is the polynomial in the interval $[\alpha, \beta]$? Note that for $a\in[\alpha,\beta]$ the argument $\frac{\beta+\alpha-2a}{\beta-\alpha}$ lies in $[-1,1],$ where $|T_k|\le 1,$ and it equals $1$ at $a=\alpha,$ where $T_k$ takes the value $1.$ Hence the maximum is attained at $\alpha.$ Plugging this into the definition of the rescaled Chebyshev polynomial, we get the upper bound for any $a\in[\alpha, \beta],$
$$ |P_k(a)| \le P_k(\alpha)= T_k\left(\frac{\beta+\alpha}{\beta-\alpha}\right)^{-1}. $$
Recalling the condition number $\kappa=\beta/\alpha,$ we have
$$ \frac{\beta+\alpha}{\beta-\alpha} =\frac{\kappa+1}{\kappa-1}=1+\epsilon, $$ where $\epsilon=\frac{2}{\kappa-1}\approx 2/\kappa$ for large $\kappa.$
We therefore only need to understand the value of $T_k$ at $1+\epsilon.$
A well-known fact about the Chebyshev polynomials comes in handy.
Fact. For $x>1,$ we have $T_k(x)=\cosh(k\cdot\mathrm{arccosh}(x)).$
Recall that $\cosh(x) = (e^x+e^{-x})/2$ and $\mathrm{arccosh}(x) = \ln(x+\sqrt{x^2-1}).$ How do you pronounce any of these?
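We can also verify the fact numerically for a few degrees.
# Numerical check of T_k(x) = cosh(k * arccosh(x)) for x >= 1.
grid = np.linspace(1.0, 2.0, 9)
print(all(np.allclose(T(k, grid), np.cosh(k * np.arccosh(grid)))
          for k in range(1, 8)))    # should print True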
Let's illustrate these functions in the relevant regime.
vs = np.linspace(1., 1.5, 100)
plt.figure(figsize=(14, 7))
plt.subplot(121, title="arccosh")
plt.plot(vs, np.arccosh(vs), **kwargs)
plt.subplot(122, title="cosh")
vs = np.linspace(0, 6, 100)
plt.plot(vs, np.cosh(vs), **kwargs);
Lemma. Assume $\epsilon<1/2.$ Then
$$ T_k(1+\epsilon)\ge \frac{\big(1+\sqrt{\epsilon}\big)^k}2 $$
Proof. Let $\phi=\mathrm{arccosh}(1+\epsilon)$ and note
$$ e^{\phi} = 1+\epsilon+\sqrt{2\epsilon+\epsilon^2} \ge 1+\sqrt{\epsilon}\,. $$
Therefore,
$$ T_k(1+\epsilon)=\frac{(e^{\phi})^k+(e^{\phi})^{-k}}{2} \ge \frac{\big(1+\sqrt{\epsilon}\big)^k}2\,. \qquad\qquad\square $$
Let's illustrate the lemma.
plt.figure(figsize=(14, 7))
epsilons = np.linspace(0., 0.1, 100)
plt.xlabel(r'$\epsilon$')
plt.plot(epsilons, T(10, 1+epsilons),
         label=r'$T_{10}(1+\epsilon)$', **kwargs)
plt.plot(epsilons, (1+np.sqrt(epsilons))**10/2,
         label=r'$(1+\sqrt{\epsilon})^{10}/2$', **kwargs)
plt.legend();
The reciprocal is what we needed to upper bound the error of our algorithm, so we have
$$ T_k\left(\frac{\beta+\alpha}{\beta-\alpha}\right)^{-1} \le 2\big(1+\sqrt{\epsilon}\big)^{-k}\,. $$
This establishes that the Chebyshev polynomial achieves the error bound
$$ \|e_k\|\le 2\big(1+\sqrt{\epsilon}\big)^{-k}\|e_0\| $$
Recalling that $\epsilon\approx 2/\kappa,$ this means that for large $\kappa,$ we get quadratic savings in the degree we need before the error drops off exponentially. Concretely, the error decays at the rate $\big(1+\sqrt{\epsilon}\big)^{-k}\approx e^{-k\sqrt{2/\kappa}},$ so a degree of order $\sqrt{\kappa}$ suffices, compared with order $\kappa$ for the naive polynomial.
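A quick back-of-the-envelope computation makes the savings concrete. Using the maxima we derived, namely $\left(\frac{\kappa-1}{\kappa+1}\right)^k$ for the naive polynomial and $\cosh\left(k\cdot\mathrm{arccosh}\frac{\kappa+1}{\kappa-1}\right)^{-1}$ for the rescaled Chebyshev polynomial, we can compute the smallest degree at which each drops below a target error.
# Smallest degree at which the maximum on [alpha, beta] drops below a target,
# using the closed-form maxima ((kappa-1)/(kappa+1))**k for the naive
# polynomial and 1/cosh(k * arccosh((kappa+1)/(kappa-1))) for Chebyshev.
def degrees_needed(kappa, target):
    ratio = (kappa - 1.0) / (kappa + 1.0)
    naive = int(np.ceil(np.log(target) / np.log(ratio)))
    cheby = int(np.ceil(np.arccosh(1.0/target) / np.arccosh(1.0/ratio)))
    return naive, cheby

for kappa in [10, 100, 1000]:
    print(kappa, degrees_needed(kappa, 1e-3))
The Chebyshev degree grows roughly like $\sqrt{\kappa},$ while the naive degree grows linearly in $\kappa.$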
Due to the recursive definition of the Chebyshev polynomial, we directly get an iterative algorithm out of it. Transferring the recursive definition to our rescaled Chebyshev polynomial, we have
$$ P_{k+1}(a) = (\eta_k a + \gamma_k)P_k(a) + \mu_k P_{k-1}(a), $$
where we can work out the coefficients $\eta_k,\gamma_k,\mu_k$ from the recurrence definition. Since $P_k(0)=1,$ we must have $\gamma_k+\mu_k=1.$ This leads to a simple update rule for our iterates:
$$ \begin{align} x_{k+1} &= (\eta_k A + (1-\mu_k)I)x_k + \mu_k x_{k-1}-\eta_k b\\ &= x_k + \eta_k\nabla f(x_k) - \mu_k (x_k - x_{k-1}) \end{align} $$
We can easily compute this by keeping around not just the last iterate, but also the previous one.
We see that the update rule above is actually very similar to plain gradient descent except for the additional term $-\mu_k(x_k - x_{k-1}).$
Working out the coefficients from the Chebyshev recurrence shows that $\eta_k<0$ and $\mu_k<0,$ so the first two terms amount to an ordinary gradient step, and the last term can be interpreted as a momentum term, pushing the algorithm in the direction of where it was headed before.
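Here is a minimal sketch of the resulting Chebyshev iteration. Applying the recurrence $T_{k+1}(x)=2xT_k(x)-T_{k-1}(x)$ at the point $\xi_0=\frac{\beta+\alpha}{\beta-\alpha}$ gives the coefficients $\eta_k=-\frac{4}{\beta-\alpha}\cdot\frac{T_k(\xi_0)}{T_{k+1}(\xi_0)}$ and $\mu_k=-\frac{T_{k-1}(\xi_0)}{T_{k+1}(\xi_0)},$ with $\gamma_k=1-\mu_k$ as required.
def chebyshev_iteration(A, b, x0, alpha, beta, num_steps):
    """Minimize 0.5 x'Ax - b'x using the Chebyshev recurrence.

    Keeps the last two iterates together with T_{k-1}(xi0) and T_k(xi0),
    where xi0 = (beta+alpha)/(beta-alpha).
    """
    xi0 = (beta + alpha) / (beta - alpha)
    t_prev, t_curr = 1.0, xi0                          # T_0(xi0), T_1(xi0)
    x_prev = x0
    # The first step is a plain gradient step with step size 2/(alpha+beta).
    x_curr = x0 - (2.0 / (alpha + beta)) * (A.dot(x0) - b)
    iterates = [x_prev, x_curr]
    for _ in range(1, num_steps):
        t_next = 2.0 * xi0 * t_curr - t_prev
        eta = -(4.0 / (beta - alpha)) * t_curr / t_next    # negative
        mu = -t_prev / t_next                              # negative
        # x_{k+1} = x_k + eta_k grad f(x_k) - mu_k (x_k - x_{k-1})
        x_next = x_curr + eta * (A.dot(x_curr) - b) - mu * (x_curr - x_prev)
        x_prev, x_curr = x_curr, x_next
        t_prev, t_curr = t_curr, t_next
        iterates.append(x_curr)
    return iterates
Comparing it against plain gradient descent on the same kind of random quadratic illustrates the speedup:
np.random.seed(2)
n, alpha, beta = 50, 1.0, 10.0
Q, _ = np.linalg.qr(np.random.randn(n, n))
A = Q.dot(np.diag(np.linspace(alpha, beta, n))).dot(Q.T)
b = np.random.randn(n)
x_star = np.linalg.solve(A, b)
x0 = np.zeros(n)

cheby_errors = [np.linalg.norm(x - x_star)
                for x in chebyshev_iteration(A, b, x0, alpha, beta, 40)]

gd_errors, x = [], x0
for _ in range(len(cheby_errors)):
    gd_errors.append(np.linalg.norm(x - x_star))
    x = x - (2.0 / (alpha + beta)) * (A.dot(x) - b)

plt.figure(figsize=(14, 8))
plt.title('Gradient descent vs Chebyshev iteration')
plt.yscale('log')
plt.plot(gd_errors, label='gradient descent', **kwargs)
plt.plot(cheby_errors, label='chebyshev iteration', **kwargs)
plt.legend();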
In the next lecture, we'll dig deeper into momentum and see how to generalize the result for quadratics to general convex functions.