Deep learning theory at ICML 2020
This post is intended as a quick summary of the advances in theoretical deep learning presented at ICML 2020. I encourage you to check out the papers (linked in the titles) that interest you most. In this post, I discuss four of my favorites from the conference:
- Rethinking bias-variance trade-off for generalization of neural networks
- Dynamics of Deep Neural Networks and Neural Tangent Hierarchy
- Proving the lottery ticket hypothesis: Pruning is all you need
- Linear Mode Connectivity and the Lottery Ticket Hypothesis
Rethinking bias-variance trade-off for generalization of neural networks
Neural networks are heavily overparameterized models, and if we go by classical theory, such models shouldn't generalize. This paper is a step towards demystifying the generalization puzzle in deep learning. Concretely, it studies the bias-variance curves of neural networks with a varying number of parameters and shows that the variance behaves differently from what classical theory suggests.
Classical idea: Monotonic bias, Monotonic variance
We know that $\textrm{Risk}$ can be decomposed into $\textrm{Bias}^2$ and $\textrm{Variance}$. Classical theory tells us that $\textrm{Bias}^2$ decreases monotonically with model complexity, while $\textrm{Variance}$ increases monotonically. $\textrm{Risk}$ initially decreases as the model's capacity increases, but then starts increasing in the overfitting regime, where the increase in $\textrm{Variance}$ dominates the decrease in $\textrm{Bias}^2$. The result is a U-shaped $\textrm{Risk}$ curve.
But this U-shaped risk curve isn't seen in deep learning: increasing the width or depth of a neural network usually decreases the risk. One explanation for this behavior is double descent, which extends the classical theory to overparameterized models like neural networks. Double descent introduces a new regime beyond the overfitting regime, called the interpolating regime, where risk decreases with model complexity. Thus, double descent predicts a peak in risk at the transition between the overfitting and the interpolating regimes.
With real-world data, we typically observe monotonically decreasing risk (no peak) or only a small bump (a short overfitting regime). However, when label noise is injected into the data, double descent curves are clearly visible.
Proposed idea: Unimodal variance (along with Monotonic bias)
The authors are interested in the behavior of the bias and variance curves in deep learning. They propose that the variance curve is actually unimodal rather than monotonic, while the bias curve decreases monotonically. Unimodal variance means that the variance increases, peaks, and then decreases with model complexity. Depending on the relative magnitudes of $\textrm{Bias}^2$ and $\textrm{Variance}$, one of the following $\textrm{Risk}$ curves can be observed:
The unimodal behavior of variance explains why we see double descent. The risk curve depicted in Case 2 in the figure above exhibits double descent behavior.
By computing the variance of models of varying capacity across multiple computer vision datasets, the authors show empirically that the variance is indeed unimodal. However, they don't explain why the variance behaves this way.
Computing the bias and variance: Devil is in the details
For the MSE loss, $\textrm{Risk} = \textrm{Bias}^2 + \textrm{Variance}$. See the paper for the exact bias-variance decompositions for the MSE and cross-entropy losses.
One way to estimate the $\textrm{Variance}$ is:
- Split the training dataset $D$ into two halves, $D_1$ and $D_2$.
- Train classifiers $f_1$ and $f_2$ on $D_1$ and $D_2$, respectively.
- An unbiased estimator of the $\textrm{Variance}$ at a test point $x$ is $\frac{1}{2} (f_1(x) - f_2(x))^2$.
- Average this estimate over the entire test set.
- Repeat the above steps for different splits of $D$ and average the results.
- $\textrm{Bias}^2$ is then obtained by subtracting the $\textrm{Variance}$ from the $\textrm{Risk}$.
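The estimation procedure above can be sketched in Python with NumPy. This is a minimal sketch under assumptions of my own, not the authors' code: a closed-form ridge regression on polynomial features (the hypothetical helpers `fit_ridge` and `poly_features`) stands in for training a neural network, the loss is MSE, and the outputs are scalars.

```python
import numpy as np


def poly_features(x, degree):
    # Hypothetical model family: features 1, x, ..., x^degree.
    return np.vander(x, degree + 1, increasing=True)


def fit_ridge(X, y, lam=1e-3):
    # Closed-form ridge regression; stands in for "train a classifier".
    d = X.shape[1]
    return np.linalg.solve(X.T @ X + lam * np.eye(d), X.T @ y)


def estimate_bias_variance(x_train, y_train, x_test, y_test, degree,
                           n_splits=20, seed=0):
    """Split-half estimate of (Risk, Bias^2, Variance) under the MSE loss."""
    rng = np.random.default_rng(seed)
    n = len(x_train)
    X_test = poly_features(x_test, degree)
    risks, variances = [], []
    for _ in range(n_splits):
        # Random split of D into two disjoint halves D_1 and D_2.
        perm = rng.permutation(n)
        half1, half2 = perm[: n // 2], perm[n // 2:]
        w1 = fit_ridge(poly_features(x_train[half1], degree), y_train[half1])
        w2 = fit_ridge(poly_features(x_train[half2], degree), y_train[half2])
        p1, p2 = X_test @ w1, X_test @ w2
        # 0.5 * (f_1(x) - f_2(x))^2, averaged over the test set.
        variances.append(np.mean(0.5 * (p1 - p2) ** 2))
        # Risk of a model trained on n/2 points: mean test MSE of f_1 and f_2.
        risks.append(0.5 * (np.mean((p1 - y_test) ** 2)
                            + np.mean((p2 - y_test) ** 2)))
    risk = float(np.mean(risks))
    variance = float(np.mean(variances))
    # Bias^2 is whatever part of the risk the variance does not explain.
    return risk, risk - variance, variance
```

Sweeping `degree` and recording the three returned values gives a toy bias-variance curve. Since the estimated $\textrm{Bias}^2$ is a difference of two noisy estimates, it can come out slightly negative on small test sets.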
Unimodal variance as an explanation for double descent behavior
Double descent is clearly visible when label noise is injected into real-world data; otherwise, it can be very small and easy to miss. The figure below shows that double descent becomes more and more prominent as the label noise increases.
The authors show that with increasing label noise, the $\textrm{Variance}$ peaks at a higher value. This raises the peak of the risk curve, which makes the double descent behavior more prominent.
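One common way to inject label noise into a classification dataset is to replace a fixed fraction of the labels with a different, randomly chosen class. The helper below is a hypothetical NumPy sketch of that idea; the exact corruption scheme used in the paper may differ.

```python
import numpy as np


def inject_label_noise(labels, noise_frac, num_classes, seed=0):
    """Return a copy of `labels` in which a `noise_frac` fraction is corrupted.

    Each corrupted label is moved to a different class chosen uniformly at
    random, so exactly round(noise_frac * n) labels change.
    """
    rng = np.random.default_rng(seed)
    noisy = np.asarray(labels).copy()
    n = noisy.shape[0]
    n_noisy = int(round(noise_frac * n))
    idx = rng.choice(n, size=n_noisy, replace=False)
    # An offset in [1, num_classes - 1] guarantees the new class differs.
    offsets = rng.integers(1, num_classes, size=n_noisy)
    noisy[idx] = (noisy[idx] + offsets) % num_classes
    return noisy
```

Sweeping `noise_frac` and re-estimating the bias and variance for each value is one way to probe the trend described above.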
Random design v/s Fixed design
Random design
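All the experiments and the observed behavior are for the random design setting, in which the expectations in $\textrm{Bias}^2$ and $\textrm{Variance}$ are taken over different training sets, $\mathcal{T}$. This is the usual setting in machine learning. For the MSE loss, the decomposition reads
$\displaystyle \mathbb{E}_{x, y}\mathbb{E}_{\mathcal{T}}\big[(y - f(x, \mathcal{T}))^2\big] = \mathbb{E}_{x, y}\big[(y - \bar f(x))^2\big] + \mathbb{E}_{x}\mathbb{E}_{\mathcal{T}}\big[(f(x, \mathcal{T}) - \bar f(x))^2\big]$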
where $\bar f(x) = \mathbb{E}_{\mathcal{T}}[f(x, \mathcal{T})]$.
The first term is $\textrm{Bias}^2$ while the second one is $\textrm{Variance}$.
Fixed design
But theoretical analysis is usually done in the fixed design setting, where the covariates $x_i$ (training inputs) are fixed and the randomness comes only from the labels, $y_i \sim \mathbb{P}[Y \vert X = x_i]$. Typically, $y_i = f_0(x_i) + \epsilon_i$, where $\epsilon_i \sim \mathcal{N}(0, \sigma_i^2)$.
In the fixed design setting, the bias is usually larger and the variance smaller (I don't fully understand why). Moreover, monotonic bias and unimodal variance don't necessarily hold in this setting. Refer to Mei & Montanari, 2019 for more.
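To make the distinction concrete, here is a small NumPy sketch of how a training set is drawn in each setting. The ground-truth function `f0` (here $\sin(3x)$), the uniform input distribution, and the noise levels are illustrative assumptions.

```python
import numpy as np


def f0(x):
    # Hypothetical ground-truth regression function.
    return np.sin(3 * x)


def sample_random_design(n, sigma, rng):
    # Random design: both the inputs x_i and the labels y_i are redrawn.
    x = rng.uniform(-1.0, 1.0, size=n)
    y = f0(x) + rng.normal(0.0, sigma, size=n)
    return x, y


def sample_fixed_design(x_fixed, sigma, rng):
    # Fixed design: the inputs stay the same; only the label noise is redrawn.
    # `sigma` may be a scalar or an array of per-point sigma_i.
    y = f0(x_fixed) + rng.normal(0.0, sigma, size=x_fixed.shape[0])
    return x_fixed, y
```

Taking expectations "over training sets" means averaging over repeated calls to one of these two functions, and that choice is exactly where the two settings' notions of bias and variance part ways.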
Miscellaneous
Increased bias explains why the risk curve shifts upwards for out-of-distribution samples.
Dynamics of Deep Neural Networks and Neural Tangent Hierarchy
The motivation behind this paper is to understand the optimization process in neural networks. Neural networks are trained with loss functions that are highly non-convex with respect to their parameters. How is it that Stochastic Gradient Descent (SGD) efficiently converges to solutions that generalize well? One way to approach this question is to study the dynamics of neural networks, i.e. how the network changes during training, which is what this paper does.
Neural Tangent Kernel
Neural networks are complex objects and often assumptions need to be made for mathematical convenience. Here, neural networks are defined in a slightly modified fashion.
Neural network: $\displaystyle \quad f(x, \theta) = a^T x^{(H)}, \quad x^{(l)} = \frac{1}{\sqrt m} \sigma(W^{(l)} x^{(l-1)}), \quad l = 1, \cdots, H$, with $x^{(0)} = x$ and $\theta = [\textrm{vec}(W^{(1)}), \cdots, \textrm{vec}(W^{(H)}), a]$.
We consider the MSE loss: $\displaystyle L(\theta) = \frac{1}{2n}\sum_{i=1}^{n} (f(x_i, \theta) - y_i)^2$
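A minimal NumPy sketch of this network and loss. The choices here are mine, for illustration: $\tanh$ as the (smooth) activation $\sigma$, standard Gaussian initialization of all weights, $W^{(1)} \in \mathbb{R}^{m \times d}$, and $W^{(l)} \in \mathbb{R}^{m \times m}$ for $l \ge 2$.

```python
import numpy as np


def init_params(d, m, H, seed=0):
    # Gaussian initialization (an assumption) of W^(1), ..., W^(H) and a.
    rng = np.random.default_rng(seed)
    Ws = [rng.standard_normal((m, d))]
    Ws += [rng.standard_normal((m, m)) for _ in range(H - 1)]
    a = rng.standard_normal(m)
    return Ws, a


def forward(x, Ws, a, sigma=np.tanh):
    # x^(l) = sigma(W^(l) x^(l-1)) / sqrt(m), starting from x^(0) = x.
    m = a.shape[0]
    h = x
    for W in Ws:
        h = sigma(W @ h) / np.sqrt(m)
    # f(x, theta) = a^T x^(H), a scalar.
    return a @ h


def mse_loss(X, y, Ws, a):
    # L(theta) = 1/(2n) * sum_i (f(x_i, theta) - y_i)^2
    preds = np.array([forward(x, Ws, a) for x in X])
    return 0.5 * np.mean((preds - y) ** 2)
```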
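Under gradient flow (the continuous-time limit of gradient descent), the parameters evolve along the negative gradient of the loss:
$\displaystyle \partial_t \theta_t = -\nabla_\theta L(\theta_t) = -\frac{1}{n}\sum_{i=1}^{n} \nabla_\theta f(x_i, \theta_t)\,(f(x_i, \theta_t) - y_i)$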
Recent works study the training dynamics in the trajectory (function) space instead of the parameter space, since the trajectory space is much lower-dimensional ($\mathbb R^n$, one coordinate per training example) and easier to interpret.
Trajectory space: $(f(x_1, \theta_t), f(x_2, \theta_t), \cdots, f(x_n, \theta_t))$
Parameter space: $\theta_t = (W^{(1)}_t, W^{(2)}_t, \cdots, W^{(H)}_t, a_t)$
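The dynamics in the trajectory space can now be worked out with the chain rule:
$\displaystyle \partial_t f(x_i, \theta_t) = \langle \nabla_\theta f(x_i, \theta_t), \partial_t \theta_t \rangle = -\frac{1}{n}\sum_{j=1}^{n} K^{(2)}_t(x_i, x_j)\,(f(x_j, \theta_t) - y_j)$
where
$\displaystyle K^{(2)}_t(x_i, x_j) = \langle \nabla_\theta f(x_i, \theta_t), \nabla_\theta f(x_j, \theta_t) \rangle$
is the Neural Tangent Kernel (NTK).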
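For intuition, here is a NumPy sketch that computes the empirical NTK Gram matrix for the one-hidden-layer case ($H = 1$, $\sigma = \tanh$, both assumptions) by stacking per-example gradients into a Jacobian $J$, so that $K = J J^T$. It also checks the trajectory-space equation numerically: after one small gradient descent step with learning rate `eta`, the change in the outputs should be close to $-\frac{\eta}{n} K (f - y)$. All helper names are hypothetical.

```python
import numpy as np


def f(x, W, a):
    # One hidden layer: f(x) = a^T tanh(W x) / sqrt(m).
    return a @ np.tanh(W @ x) / np.sqrt(a.shape[0])


def grad_f(x, W, a):
    # Gradient of f w.r.t. theta = [vec(W), a], with vec(W) row-major.
    m = a.shape[0]
    s = np.tanh(W @ x)
    ds = 1.0 - s ** 2  # derivative of tanh
    grad_W = np.outer(a * ds, x) / np.sqrt(m)
    grad_a = s / np.sqrt(m)
    return np.concatenate([grad_W.ravel(), grad_a])


def jacobian(X, W, a):
    # Row i is the gradient of f(x_i) w.r.t. all parameters.
    return np.stack([grad_f(x, W, a) for x in X])


def empirical_ntk(X, W, a):
    # K[i, j] = <grad f(x_i), grad f(x_j)>.
    J = jacobian(X, W, a)
    return J @ J.T


def linearization_error(X, y, W, a, eta=1e-5):
    """Relative gap between a real GD step and the NTK prediction."""
    n, (m, d) = X.shape[0], W.shape
    J = jacobian(X, W, a)
    K = J @ J.T
    f_before = np.array([f(x, W, a) for x in X])
    residual = f_before - y
    # One gradient descent step on L = 1/(2n) sum (f - y)^2.
    step = -(eta / n) * (J.T @ residual)
    W_new = W + step[: m * d].reshape(m, d)
    a_new = a + step[m * d:]
    f_after = np.array([f(x, W_new, a_new) for x in X])
    predicted = -(eta / n) * (K @ residual)
    return (np.linalg.norm((f_after - f_before) - predicted)
            / np.linalg.norm(predicted))
```

The gap shrinks as `eta` shrinks, because the NTK equation is the first-order (continuous-time) description of the step.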
The following theorems are known for the Neural Tangent Kernel:
Theorem by Jacot et al. '18: $K^{(2)}_t(\cdot, \cdot) = K^{(2)}_0(\cdot, \cdot)$ in the limit $m \to \infty$.
Theorem by Du et al. '18 (a; b): $K^{(2)}_t(\cdot, \cdot) \approx K^{(2)}_0(\cdot, \cdot)$ when $m > n^4$.
In the infinite width case, the kernel stays constant, so the training dynamics become a linear ODE that can be solved analytically, and an infinitely wide neural network behaves like a kernel regression model.
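With the kernel frozen at $K = K^{(2)}_0$, the trajectory-space equation becomes the linear ODE $\partial_t (f_t - y) = -\frac{1}{n} K (f_t - y)$, whose solution on the training points is $f_t = y + e^{-tK/n}(f_0 - y)$. The sketch below evaluates this via an eigendecomposition of $K$. It only covers training-set outputs (predictions at new inputs, which give the kernel regression form, are omitted), and `K`, `f0`, `y` are whatever arrays you supply.

```python
import numpy as np


def frozen_ntk_outputs(K, f0, y, t):
    """Training-set outputs at time t when the NTK stays equal to K.

    Solves d/dt (f - y) = -(1/n) K (f - y), i.e.
    f_t = y + expm(-t K / n) (f_0 - y), using K = Q diag(lam) Q^T.
    """
    n = y.shape[0]
    lam, Q = np.linalg.eigh(K)
    decay = np.exp(-t * lam / n)
    return y + Q @ (decay * (Q.T @ (f0 - y)))
```

When $K$ is positive definite, every eigenmode of the residual decays exponentially, and the slowest rate is set by the smallest eigenvalue of $K$.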
HOWEVER,
Kernel regression with the NTK performs worse than trained deep networks in practice (Arora et al. '19: Convolutional NTKs), although some improvements have been made (Enhanced CNTKs).
The authors say that,
It is possible to show that the class of finite width neural networks is more expressive than the limiting NTK. It has been constructed in (Ghorbani et al., 2019; Yehudai & Shamir, 2019; Allen-Zhu & Li, 2019) that there are simple functions that can be efficiently learnt by finite width neural networks, but not the kernel regression using the limiting NTK.
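The observed gap stems from the fact that the NTK varies over time in the finite width case. Applying the chain rule once more gives the dynamics of the NTK:
$\displaystyle \partial_t K^{(2)}_t(x_{i_1}, x_{i_2}) = -\frac{1}{n}\sum_{j=1}^{n} K^{(3)}_t(x_{i_1}, x_{i_2}, x_j)\,(f(x_j, \theta_t) - y_j)$
where
$\displaystyle K^{(3)}_t(x_{i_1}, x_{i_2}, x_j) = \langle \nabla_\theta K^{(2)}_t(x_{i_1}, x_{i_2}), \nabla_\theta f(x_j, \theta_t) \rangle$
If we keep going, for every $r \ge 2$,
$\displaystyle \partial_t K^{(r)}_t(x_{i_1}, \cdots, x_{i_r}) = -\frac{1}{n}\sum_{j=1}^{n} K^{(r+1)}_t(x_{i_1}, \cdots, x_{i_r}, x_j)\,(f(x_j, \theta_t) - y_j)$, with $K^{(r+1)}_t(x_{i_1}, \cdots, x_{i_r}, x_j) = \langle \nabla_\theta K^{(r)}_t(x_{i_1}, \cdots, x_{i_r}), \nabla_\theta f(x_j, \theta_t) \rangle$,
and we obtain an infinite sequence of coupled equations: the Neural Tangent Hierarchy (NTH).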
Truncated Neural Tangent Hierarchy
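Let's truncate this hierarchy at level $p$, i.e. freeze the $p$-th order kernel. The truncated Neural Tangent Hierarchy is the following system:
- $\displaystyle \partial_t \tilde f(x_i, \theta_t) = - \frac{1}{n} \sum_{j=1}^{n} \tilde K^{(2)}_t(x_i, x_j)(\tilde f(x_j, \theta_t) - y_j)$
- For $2 \le r \le p-1$: $\displaystyle \partial_t \tilde{K}^{(r)}_t (x_{i_1}, x_{i_2}, \cdots, x_{i_r}) = - \frac{1}{n} \sum_{j=1}^{n} \tilde{K}^{(r+1)}_t(x_{i_1}, x_{i_2}, \cdots, x_{i_r}, x_j)(\tilde f(x_j, \theta_t) - y_j)$
- $\displaystyle \partial_t \tilde K^{(p)}_t(x_{i_1}, x_{i_2}, \cdots, x_{i_p}) = 0$
with initial conditions $\tilde f(x_j, \theta_0) = f(x_j, \theta_0)$ and $\tilde K^{(r)}_0 = K^{(r)}_0$ for $2 \le r \le p$.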
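To see how the truncated hierarchy is used, here is a forward-Euler sketch of the $p = 3$ case restricted to the training points: $\tilde f$ is driven by $\tilde K^{(2)}$, $\tilde K^{(2)}$ is driven by $\tilde K^{(3)}$, and $\tilde K^{(3)}$ is frozen. The initial tensors `K2_0` (shape `(n, n)`) and `K3_0` (shape `(n, n, n)`) are assumed to be given; computing $K^{(3)}_0$ from an actual network is not shown, and the step count is an arbitrary choice.

```python
import numpy as np


def integrate_truncated_nth_p3(f0, y, K2_0, K3_0, t_end, num_steps=10_000):
    """Forward-Euler solution of the truncated NTH at p = 3 on n training points.

    d f_i / dt        = -(1/n) sum_j K2[i, j] (f_j - y_j)
    d K2[i1, i2] / dt = -(1/n) sum_j K3[i1, i2, j] (f_j - y_j)
    d K3 / dt         = 0
    """
    n = y.shape[0]
    dt = t_end / num_steps
    f = np.array(f0, dtype=float)
    K2 = np.array(K2_0, dtype=float)
    for _ in range(num_steps):
        residual = f - y
        df = -(K2 @ residual) / n
        # Contract the last index j of K3 with the residual.
        dK2 = -np.einsum("abj,j->ab", K3_0, residual) / n
        f = f + dt * df
        K2 = K2 + dt * dK2
    return f, K2
```

Setting `K3_0` to zeros reduces this to the $p = 2$ (frozen NTK) dynamics, which is a convenient sanity check.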
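Note that for $p=2$, the truncated hierarchy keeps the NTK fixed at its initial value, $\tilde K^{(2)}_t = K^{(2)}_0$, which is exactly the regime described by the NTK theorems of Jacot et al. and Du et al.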
Main theorem of the paper
Let $p^* \ge 2$ and let $\tilde f$ be the solution to the truncated Neural Tangent Hierarchy at $p^*$;
we have,
under some (minor) assumptions on the data.
The truncated Neural Tangent Hierarchy at $p$ approximates the dynamics of the finite width neural network up to some time $t$. The approximation becomes more accurate, and stays valid for longer, as the width $m$ and the truncation level $p$ increase.
The conjecture:
This slide from the author’s presentation states their conjecture: Truncated NTH generalizes better with increasing $p$.
Summary
For finite width neural networks, the gradient dynamics are captured by an infinite hierarchy of recursive differential equations. Truncating at $p=2$ freezes the kernel, which recovers the Neural Tangent Kernel dynamics of the infinite width limit. However, to represent the training dynamics of finite width networks accurately over the entire duration of training, we need the whole infinite hierarchy. For feasibility, truncating the hierarchy at $p$ equations lets us approximate the training dynamics up to a certain time $t$. Increasing the width $m$ and/or the truncation level $p$ extends how long this approximation remains accurate.
Proving the lottery ticket hypothesis: Pruning is all you need
As the title suggests, this paper provides a proof of the Lottery Ticket Hypothesis (LTH). Although most of the paper is devoted to the proof, I won't say much about it. The LTH is fascinating, and I try to highlight the significance of actually proving it. But first, I define some terms used in the context of this paper.
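Neural net: $\displaystyle G(x) = W_L\, \sigma(W_{L-1}\, \sigma(\cdots \sigma(W_1 x)))$
Subnetwork: $\displaystyle \tilde G(x) = (B_L \odot W_L)\, \sigma((B_{L-1} \odot W_{L-1})\, \sigma(\cdots \sigma((B_1 \odot W_1) x)))$
Here, $B_{l}$ is a binary matrix of the same size as $W_{l}$, $\odot$ denotes the elementwise product, and $\sigma$ denotes the ReLU activation function.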
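The network and subnetwork definitions translate directly into code. The NumPy sketch below (helper names are my own) evaluates a bias-free ReLU network given a list of weight matrices, and a subnetwork given binary masks $B_l$ of matching shapes, by running the same network on the masked weights $B_l \odot W_l$.

```python
import numpy as np


def relu(z):
    return np.maximum(z, 0.0)


def network(x, weights):
    # W_L relu(W_{L-1} ... relu(W_1 x)); no activation after the last layer.
    h = x
    for W in weights[:-1]:
        h = relu(W @ h)
    return weights[-1] @ h


def subnetwork(x, weights, masks):
    # Same architecture with every W_l replaced by B_l * W_l (elementwise).
    return network(x, [B * W for B, W in zip(masks, weights)])
```

Pruning, in this view, is nothing more than choosing the entries of the masks; the weights themselves are left untouched.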
(Weak) Lottery Ticket Hypothesis: Consider a randomly initialized network, $N$. $N$ contains a subnetwork $n$ such that, when $n$ is trained in isolation, it achieves the same performance as the trained $N$ using at most the same number of iterations.
But training a small network of the same size as $n$ from a fresh random initialization is like entering the lottery with just one ticket: there's a very small chance (close to $0$) that it has the right initialization of parameters to achieve the same performance. A bigger network $N$ contains (exponentially) many small subnetworks $n$, i.e. many random initializations or lottery tickets, which makes it very likely that you achieve good performance when it is trained.
(Strong) Lottery Ticket Hypothesis: Let $F$ be a fixed target network and let $N$ be a network obtained by overparameterizing $F$. When $N$ is randomly initialized, it contains a subnetwork, $n$, that performs as well as the target network $F$.
The reason this is the Strong LTH is that this version states that you don’t need to train the subnetwork at all. However, it’s important to understand that it does not make any explicit claims on the size of the bigger network, $N$.
Moreover, I found the nomenclature a bit confusing. In mathematics, if the stronger version of a statement holds, the weaker version usually follows from it. For the LTH this is not obvious, and it may not even be true.
Denote the width of the network $F$ by $w$ and its depth by $d$. The authors prove the strong LTH for $N$ with width $W = \textrm{polynomial}(w, d)$ and depth $D = 2d$. Their proof relies on pruning: essentially, you can prune $N$ to obtain a subnetwork $n$ that approximates $F$.
So, if I know that a 10-layer network with width 100 would be sufficient for achieving good performance on a task, can I randomly initialize a network $N$ with depth $2 \times 10 = 20$ and width $\textrm{polynomial}(100, 10)$ and prune it to obtain the desired subnetwork? Well, you could. But pruning $N$ would be computationally as hard as training it.
Linear Mode Connectivity and the Lottery Ticket Hypothesis
Consider a randomly initialized network. Start the optimization with different SGD seeds, i.e. a different shuffling of training dataset with possibly different data augmentation (angle of rotation, flips, etc.). How are the end results of the optimization processes (the trained weights) related to each other? Before answering the question, let’s define some terms.
Linear mode connectivity
Two different solutions, $W_1$ and $W_2$, are said to be linear mode connected if the error does not increase along the linear path between them.
Concretely, let $W_\alpha = \alpha W_1 + (1-\alpha) W_2$ where $\alpha \in [0, 1]$. Let $\mathcal{E}(W)$ denote the (train/test) loss where $W$ represents the weights of the network.
The instability index $\mathcal{N}$ is the height of the error barrier along this path: $\displaystyle \mathcal{N} = \sup_{\alpha \in [0, 1]} \mathcal{E}(W_\alpha) - \frac{1}{2}\big(\mathcal{E}(W_1) + \mathcal{E}(W_2)\big)$, i.e. the maximum error along the path minus the average error of the two endpoints.
$W_1$ and $W_2$ are said to be linear mode connected when the instability index is close to zero, i.e. $\mathcal{N} \approx 0$.
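A minimal sketch of the instability computation, assuming the weights are flattened into NumPy arrays and that `error_fn` (a hypothetical callable) returns the train or test error of a network with the given weights. The supremum over $\alpha$ is approximated on a uniform grid of `num_points` values that includes both endpoints.

```python
import numpy as np


def instability(error_fn, W1, W2, num_points=51):
    """Error barrier height on the linear path between W1 and W2."""
    alphas = np.linspace(0.0, 1.0, num_points)
    # Error at each interpolated point W_alpha = alpha W1 + (1 - alpha) W2.
    path_errors = np.array([error_fn(a * W1 + (1 - a) * W2) for a in alphas])
    endpoint_mean = 0.5 * (error_fn(W1) + error_fn(W2))
    return path_errors.max() - endpoint_mean
```

With a convex error function and endpoints of equal error, the index is zero; an error surface with a barrier between the two solutions gives a positive value. In practice each grid point costs a full evaluation pass over the dataset.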
A network is said to be SGD noise stable if two different runs of SGD (different SGD noises) result in solutions that are linear mode connected.
The authors show that small networks like LeNet are SGD noise stable at initialization. Varying SGD noises with the same initialization results in solutions that are linear mode connected. Larger networks aren’t necessarily noise stable at initialization.
However, these bigger networks quickly become stable to SGD noise once they’re trained for a few thousand iterations. For ResNet-20 on CIFAR-10, stability occurs at 3% of the total training time, while ResNet-50 on ImageNet becomes stable at 20% of the training process.
Stability and Pruning LTH Subnetworks
In their previous work, the authors proposed a method called Iterative Magnitude Pruning (IMP) for finding the subnetwork from the (weak) LTH that achieves the same performance as the larger network. Interestingly, they find that IMP subnetworks only train to full accuracy when they are stable to SGD noise. Please refer to the paper for more details regarding this observation.
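For reference, here is a rough NumPy sketch of an IMP loop: train, prune the smallest-magnitude fraction of the surviving weights, rewind the survivors to stored weights `w_rewind` (the initialization, or an early-training checkpoint in the rewinding variant), and repeat. This is my own simplification, not the authors' implementation: `train_fn` is a hypothetical stand-in for a full training run, pruning is global over one flattened array, and the 20% per-round rate is an illustrative default.

```python
import numpy as np


def prune_by_magnitude(weights, mask, prune_frac):
    """Zero out the smallest-|w| `prune_frac` fraction of still-unpruned entries."""
    alive = np.flatnonzero(mask)
    n_prune = int(round(prune_frac * alive.size))
    new_mask = mask.ravel().copy()
    if n_prune > 0:
        order = np.argsort(np.abs(weights.ravel()[alive]))
        new_mask[alive[order[:n_prune]]] = 0
    return new_mask.reshape(mask.shape)


def iterative_magnitude_pruning(w_rewind, train_fn, rounds, prune_frac=0.2):
    """Return the binary mask found by `rounds` rounds of IMP with rewinding."""
    mask = np.ones_like(w_rewind)
    for _ in range(rounds):
        # Train the current subnetwork starting from the rewound weights.
        w_trained = train_fn(w_rewind * mask, mask)
        # Prune based on the magnitudes of the trained weights.
        mask = prune_by_magnitude(w_trained, mask, prune_frac)
    return mask
```

After $k$ rounds, roughly a $0.8^k$ fraction of the weights survives. In the paper's framing, the subnetwork given by the final mask and the rewound weights is what gets tested for stability to SGD noise.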