Initialization

[math] \newcommand{\ul}{\mathbf} \newcommand{\symbf}{\bm} \newcommand\subsetap{\mathrel{\overset{\makebox[0pt]{\mbox{\normalfont\tiny\sffamily ap.}}}{\rule{0pt}{.8ex}\smash{\subset}}}} \newcommand{\rident}[1]{\mathrm{#1}} \newcommand{\iident}[1]{\mathit{#1}} \newcommand{\wip}{\emoji{construction}} \newcommand{\pointright}{\emoji{backhand-index-pointing-right-light-skin-tone}} \DeclareMathOperator*{\argmax}{arg\,max} \DeclareMathOperator*{\argmin}{arg\,min} \DeclareMathOperator*{\Arg}{Arg} \DeclareMathOperator*{\Var}{Var} \DeclareMathOperator*{\dom}{dom} \DeclareMathOperator{\Div}{div} \DeclareMathOperator{\morph}{\scalebox{0.7}{\ensuremath\square}} \DeclareMathOperator*{\esssup}{ess\,sup} \DeclareMathOperator{\Int}{Int} \DeclareMathOperator{\Cl}{Cl} \DeclareMathOperator{\id}{id} \DeclareMathOperator{\diam}{diam} \DeclareMathOperator{\supp}{supp} \DeclareMathOperator{\arctantwo}{arctan2} \DeclareMathOperator{\relu}{ReLU} \newcommand{\mathds}{\mathbb}[/math]


For deep networks, the initial values of the parameters make a significant difference to how well the SGD algorithms function, not least because initial parameters that are too small or too large in magnitude cause vanishing or exploding gradients and thereby immediate problems for training. In this section we will develop some stochastic initialization schemes that provide a functional starting point for a network to train from. We generally use stochastic methods for initialization because the initial parameter values within a layer need to be sufficiently different from each other. If the values were all the same, their gradients would also be the same, so the values would never diverge from each other and the network would be locked in an untrainable state. The most straightforward way of achieving this is drawing values from a probability distribution, which is the current common practice. The distributions used are usually either the normal [math]\mathcal{N}(0,\sigma^2)[/math] or the uniform [math]\rident{Unif}[-a,a][/math], where [math]\sigma^2[/math] or [math]a[/math] needs to be picked so as to avoid training issues such as vanishing/exploding gradients. Each group of parameters (the linear coefficients and the biases in each layer) typically gets assigned its own distribution. We will look at how these distributions are currently chosen.
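To see the symmetry problem concretely, here is a minimal PyTorch sketch (the layer sizes and the constant value are arbitrary choices for illustration): when every weight in a layer is initialized to the same value, all neurons compute the same output and consequently receive the same gradient.

```python
import torch

# A single linear layer where every weight gets the same initial value.
torch.manual_seed(0)
layer = torch.nn.Linear(4, 3)
torch.nn.init.constant_(layer.weight, 0.5)
torch.nn.init.constant_(layer.bias, 0.0)

x = torch.randn(8, 4)              # a random batch of inputs
loss = layer(x).sigmoid().sum()
loss.backward()

# All rows of the weight gradient are identical: the three neurons get
# the same update and will never become different from each other.
print(layer.weight.grad)
```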

[Deterministic initialization scheme]

If you recall, in the tutorial notebook 4_FunctionApproximationIn1D.ipynb we developed a deterministic initialization scheme that outperformed the default stochastic scheme. We did have to use our insight about the problem as a whole (beyond what the data provided) to do it. It is quite conceivable that for a given application a deterministic scheme can be developed that gives a much better starting point for training than the current crop of stochastic schemes.

[On the importance of good initialization]

A good initialization scheme allows the network to reach a higher performance level in less time. This can have important consequences for large production networks. One such network is GPT-3 (see [1]), a natural language processing model with 175 billion parameters whose training has reportedly cost 4 million dollars. Better initialization schemes can therefore be of substantial economic value.


Stochastic Initialization

To develop a suitable probability distribution to initialize parameters with, we start by looking at the input to a layer of neurons as a vector-valued random variable [math]X \in \mathbb{R}^n[/math]. The output of the layer is then a random variable [math]Y \in \mathbb{R}^m[/math] given by

[[math]] \begin{equation*} Y = \sigma \left( A X +\boldsymbol{b} \right) , \end{equation*} [[/math]]

for some constant matrix [math]A \in \mathbb{R}^{m \times n}[/math], bias vector [math]\boldsymbol{b} \in \mathbb{R}^m[/math] and activation function [math]\sigma[/math]. Alternatively, written out per component:

[[math]] \begin{equation} \label{eq:Y_i} Y_i = \sigma \left( \sum_{j=1}^n A_{ij} X_j + b_i \right) . \end{equation} [[/math]]

The idea is then to initialize [math]A[/math] and [math]\boldsymbol{b}[/math] so that the variance of the signal does not change too much from layer to layer:

[[math]] \begin{equation*} \sum_{i=1}^m \Var(Y_i)^2 \approx \sum_{j=1}^n \Var(X_j)^2 . \end{equation*} [[/math]]

Controlling the [math]L^2[/math] norm of the signal variances is not necessarily the only possibility here, but it is the choice we will proceed with. Calculating variances of functions of random variables is difficult in general. The schemes we will be looking at depend on the following approximation.

Lemma

Let [math]X[/math] be a real-valued random variable and let [math]f:\mathbb{R}\to\mathbb{R}[/math] be a differentiable function, then

[[math]] \begin{equation*} \Var\left( f(X) \right) \approx f' \left( \mathbb{E}[X] \right)^2 \cdot \Var(X), \end{equation*} [[/math]]
assuming the variance of [math]X[/math] is finite and [math]f'[/math] is differentiable.


Proof

Let [math]m := \mathbb{E}[X][/math] and approximate [math]f[/math] by its linearization [math]f(X) \approx f(m) + f'(m) (X - m)[/math]. Then we find

[[math]] \begin{align*} \Var\left( f(X) \right) &= \mathbb{E}\left[ \left( f(X) - \mathbb{E}[f(X)] \right)^2 \right] \\ &\approx \mathbb{E}\left[ \left( f(m) + f'(m) (X - m) - \mathbb{E}[f(m) + f'(m) (X - m)] \right)^2 \right] \\ &= \mathbb{E}\left[ \left( f(m) + f'(m) (X - m) - f(m) - f'(m) \,\mathbb{E}[(X - m)] \right)^2 \right] \\ &= \mathbb{E}\left[ \left( f'(m) (X - m) \right)^2 \right] \\ &= f'(m)^2 \, \mathbb{E}\left[ (X-m)^2 \right] \\ &= f' \left( \mathbb{E}[X] \right)^2 \cdot \Var(X). \end{align*} [[/math]]
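As a quick numerical sanity check on the lemma, the following sketch (assuming [math]f=\tanh[/math] and a normally distributed [math]X[/math], both arbitrary choices for illustration) compares a Monte Carlo estimate of [math]\Var(f(X))[/math] with the linearized approximation:

```python
import numpy as np

rng = np.random.default_rng(0)

# X ~ N(m, s^2) with small variance, f = tanh.
m, s = 0.5, 0.1
X = rng.normal(m, s, size=1_000_000)

empirical = np.tanh(X).var()
# Lemma: Var(f(X)) is approximately f'(E[X])^2 * Var(X),
# with tanh'(x) = 1 - tanh(x)^2.
linearized = (1 - np.tanh(m) ** 2) ** 2 * s ** 2

print(empirical, linearized)  # the two agree to within a few percent
```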

Intuitively, the approximation in the lemma can be read as: if the variance of [math]X[/math] is small and [math]f'[/math] is reasonably bounded, then the variance of [math]f(X)[/math] will also be small. Applying the lemma to \ref{eq:Y_i} yields

[[math]] \begin{equation*} \Var(Y_i) \approx \sigma' \left( \sum_{j=1}^n A_{i j} \, \mathbb{E}[X_j] + b_i \right)^2 \ \left( \sum_{j=1}^n A_{i j}^2 \ \Var(X_j) \right) . \end{equation*} [[/math]]

To make progress let us assume [math]\mathbb{E}[X_1]=\ldots=\mathbb{E}[X_n][/math] and [math]\Var(X_1)=\ldots=\Var(X_n)[/math], then

[[math]] \begin{equation*} \Var(Y_i) \approx \underbrace{ \sigma' \left( \mathbb{E}[X_1] \sum_{j=1}^n A_{i j} + b_i \right)^2 \ \left( \sum_{j=1}^n A_{i j}^2 \right)}_{\text{ideally}\ \approx 1} \ \Var(X_1) . \end{equation*} [[/math]]

When the bracketed term is approximately 1 then the variances of the outputs [math]Y_i[/math] are about the same as those of the inputs [math]X_j[/math]. Let us now turn the [math]A_{i j}[/math]'s and [math]b_i[/math]'s into random variables: [math]A_{i j} \sim \mu_1[/math] and [math]b_i \sim \mu_2[/math] for all [math]i[/math] and [math]j[/math] in their respective ranges and where [math]\mu_1[/math] and [math]\mu_2[/math] are some choice of scalar probability distributions. Ideally we would choose [math]\mu_1[/math] and [math]\mu_2[/math] so that

[[math]] \begin{equation} \label{eq:E=1} \mathbb{E}\left[ \sigma' \left( \mathbb{E}[X_1] \sum_{j=1}^n A_{i j} + b_i \right)^2 \ \left( \sum_{j=1}^n A_{i j}^2 \right) \right] = 1. \end{equation} [[/math]]

This expression allows us to put a condition on our choice of probability distributions that ensures that the variances of the signals between layers stay under control (at least at the start of training). The following examples show how \ref{eq:E=1} is utilized.
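Condition \ref{eq:E=1} can also be checked numerically for a proposed pair [math](\mu_1,\mu_2)[/math]. The sketch below (the function name lhs_estimate and the sampling setup are ours, for illustration) estimates the left hand side by Monte Carlo, using as a preview the distributions that the first example below will derive:

```python
import numpy as np

rng = np.random.default_rng(0)

def lhs_estimate(sigma_prime, draw_A, draw_b, n, mean_x, trials=200_000):
    """Monte Carlo estimate of
    E[ sigma'(E[X_1] * sum_j A_ij + b_i)^2 * sum_j A_ij^2 ]."""
    A = draw_A((trials, n))          # one row of coefficients per trial
    b = draw_b(trials)               # one bias per trial
    return np.mean(sigma_prime(mean_x * A.sum(axis=1) + b) ** 2
                   * (A ** 2).sum(axis=1))

# Sigmoid example below: A_ij ~ N(0, 16/n), b_i = 0, balanced inputs.
n = 100
d_sigmoid = lambda x: np.exp(-x) / (1 + np.exp(-x)) ** 2
print(lhs_estimate(d_sigmoid,
                   lambda shape: rng.normal(0.0, np.sqrt(16 / n), size=shape),
                   lambda k: np.zeros(k),
                   n, mean_x=0.0))   # prints a value close to 1
```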

Example [Sigmoid with balanced inputs] As before, assume [math]\mathbb{E}[X_1]=\ldots=\mathbb{E}[X_n][/math] and [math]\Var(X_1)=\ldots=\Var(X_n)[/math]; our goal is to choose one probability distribution for the linear coefficients and one for the biases. Say [math]\sigma[/math] is the sigmoid activation function and [math]\mathbb{E}[X_i]=0[/math], i.e. we have balanced inputs. Additionally we want balanced parameter initialization, i.e. [math]\mathbb{E}[b_i]=0[/math] and [math]\mathbb{E}[A_{i j}]=0[/math] for all [math]i,j[/math] in their respective ranges. Since [math]\mathbb{E}[X_i]=0[/math], the left hand side of \ref{eq:E=1} reduces to:

[[math]] \begin{equation*} \mathbb{E}\left[ \sigma' \left( b_i \right)^2 \ \left( \sum_{j=1}^n A_{i j}^2 \right) \right] = \mathbb{E}\left[ \sigma' \left( b_i \right)^2 \right] \, \mathbb{E}\left[ \sum_{j=1}^n A_{i j}^2 \right], \end{equation*} [[/math]]

since the [math]b_i[/math]'s and [math]A_{i j}[/math]'s are independent. We know that [math]0 \lt \sigma'(x) \leq \frac{1}{4}[/math] and that the maximum is achieved at [math]x=0[/math]. Hence, for the first factor not to become too small, we need the variance of [math]b_i[/math] to be small, since we have already decided on setting [math]\mathbb{E}[b_i]=0[/math]. Of course the smallest possible variance is zero, so let us be uncompromising and fix [math]b_i=0[/math] for all [math]i[/math]. The previous expression thus becomes

[[math]] \begin{equation*} \sigma'(0)^2 \ n \, \mathbb{E} \left[ A_{i j}^2 \right] = \frac{n}{16} \, \mathbb{E} \left[ (A_{i j} - 0)^2 \right] = \frac{n}{16} \, \mathbb{E} \left[ (A_{i j} - \mathbb{E}[A_{i j}])^2 \right] = \frac{n}{16} \, \Var(A_{i j}), \end{equation*} [[/math]]

which equals 1 if

[[math]] \begin{equation*} \Var(A_{i j}) = \frac{16}{n}. \end{equation*} [[/math]]

So we could choose our probability distribution for the linear coefficients to be the normal distribution [math]\mathcal{N}(0,\frac{16}{n})[/math] or the uniform distribution [math]\rident{Unif}\left[ -\frac{4\sqrt{3}}{\sqrt{n}}, \frac{4\sqrt{3}}{\sqrt{n}}\right][/math] (recall that [math]\rident{Unif}[-a,a][/math] has variance [math]\frac{a^2}{3}[/math]). The type of distribution is in fact free as long as the expected value is zero and the variance is [math]\frac{16}{n}[/math], but in practice you will usually only encounter normal or uniform distributions. In any case this choice of expected values and variances provides some assurance that the signals will not explode or die out as they travel through the network (at least at the start of training).
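The following sketch (sizes arbitrary) pushes a batch of balanced, small-variance inputs through a sigmoid layer initialized this way; in the small-variance regime where the lemma applies, the output variance approximately matches the input variance.

```python
import numpy as np

rng = np.random.default_rng(1)
n, m, batch = 256, 256, 100_000

sigmoid = lambda x: 1.0 / (1.0 + np.exp(-x))

X = rng.normal(0.0, 0.1, size=(batch, n))           # balanced inputs, small variance
A = rng.normal(0.0, np.sqrt(16 / n), size=(m, n))   # coefficients ~ N(0, 16/n)
Y = sigmoid(X @ A.T)                                # biases fixed at zero

# Both print roughly 0.01; the approximation is only first order,
# so expect a deviation of a few percent.
print(X.var(axis=0).mean(), Y.var(axis=0).mean())
```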

Example [ReLU with balanced inputs] Again assume [math]\mathbb{E}[X_1]=\ldots=\mathbb{E}[X_n][/math] and [math]\Var(X_1)=\ldots=\Var(X_n)[/math]. This time we use the ReLU activation function and initialize both the linear coefficients and the biases with a uniform distribution [math]\rident{Unif}[-a,a][/math]; our goal is to choose [math]a \gt 0[/math] in a suitable manner. Assume again that the inputs are balanced, i.e. [math]\mathbb{E}[X_i]=0[/math]; then the expression from \ref{eq:E=1} simplifies to

[[math]] \begin{align*} \mathbb{E} \left[ \relu'\left( b_i\right)^2 \ \left( \sum_{j=1}^n A_{i j}^2 \right) \right] &= \mathbb{E} \left[ \mathbb{1}_{b_i \gt 0} \right] \ \mathbb{E}\left[ \left( \sum_{j=1}^n A_{i j}^2 \right) \right] \\ &= \mathbb{P} \left( b_i \gt 0 \right) \ n \, \Var(A_{i j}) \\ &= \frac{n}{2} \, \Var(A_{i j}) \\ &= \frac{n a^2}{6}, \end{align*} [[/math]]

which equals 1 if [math]a=\sqrt{\frac{6}{n}}[/math] so we would draw our [math]A_{i j}[/math]'s and [math]b_i[/math]'s from [math]\rident{Unif}\left[ - \sqrt{\frac{6}{n}}, \sqrt{\frac{6}{n}} \right][/math].
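Again we can check this empirically (sizes arbitrary; as with the sigmoid example, the claim holds in the small-variance regime of the lemma, and here only on average over the layer):

```python
import numpy as np

rng = np.random.default_rng(2)
n, m, batch = 256, 256, 100_000
a = np.sqrt(6 / n)

X = rng.normal(0.0, 0.01, size=(batch, n))       # balanced inputs, small variance
A = rng.uniform(-a, a, size=(m, n))              # coefficients ~ Unif[-a, a]
b = rng.uniform(-a, a, size=m)                   # biases ~ Unif[-a, a]
Y = np.maximum(X @ A.T + b, 0.0)                 # ReLU layer

# Individual output variances are near 0 or near 2*Var(X) depending on
# the sign of b_i; averaged over the layer the input variance is
# roughly preserved.
print(X.var(axis=0).mean(), Y.var(axis=0).mean())
```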


Xavier Initialization

The initialization schemes from the previous section focused on controlling the variance of the signals going forward through the network. While this helps in controlling the vanishing/exploding gradient problem, we can also look at the gradients directly as they backpropagate through the network; this is the approach taken by [2]. The first author's name is Xavier Glorot, and for that reason the scheme we will be seeing is commonly referred to as Glorot or Xavier initialization (as it is in PyTorch, for example). The idea is to treat the partial derivatives of the loss function with respect to the linear coefficients and biases as random variables as well. Consider a setting with a linear activation function and no bias:

[[math]] \begin{equation*} Y_i = \sum_{j=1}^n A_{i j} X_j, \end{equation*} [[/math]]

where we assume all inputs [math]X_j[/math] are distributed i.i.d. with zero mean, and we initialize the coefficients [math]A_{i j}[/math] with zero mean as well. Since the [math]A_{i j}[/math]'s and [math]X_j[/math]'s are independent, the variance distributes over the sum. Moreover, we have that

[[math]] \begin{equation*} \Var(A_{i j} X_j) = \mathbb{E}[X_j]^2 \Var(A_{i j}) + \mathbb{E}[A_{i j}]^2 \Var(X_j) + \Var(A_{i j}) \Var(X_j) = \Var(A_{i j}) \Var(X_j) \end{equation*} [[/math]]

since [math]\mathbb{E}[X_j]=\mathbb{E}[A_{i j}]=0[/math]. So we work out that

[[math]] \begin{equation*} \Var(Y_i) = \sum_{j=1}^n \Var(A_{i j}) \Var(X_j) = n \Var(A_{i j}) \Var(X_j). \end{equation*} [[/math]]

Hence for forward signal propagation we have [math]\Var(Y_i)=\Var(X_j)[/math] if

[[math]] \begin{equation} \label{eq:var_forward} \Var(A_{i j})=\frac{1}{n}. \end{equation} [[/math]]


But we can look at the backward gradient propagation as well. Let [math]\ell[/math] be a loss function at the end of the network; then we can treat the partial derivatives of [math]\ell[/math] with respect to the inputs and outputs as random variables too. Call these random variables [math]\frac{\partial \ell}{\partial X_j}[/math] and [math]\frac{\partial \ell}{\partial Y_i}[/math]; applying the chain rule gives us

[[math]] \begin{equation*} \frac{\partial \ell}{\partial X_j} = \sum_{i=1}^m \frac{\partial \ell}{\partial Y_i} \frac{\partial Y_i}{\partial X_j} = \sum_{i=1}^m \frac{\partial \ell}{\partial Y_i} A_{i j} . \end{equation*} [[/math]]

Now we make the same assumption about the backward gradients as we did about the forward signals, namely that the partial derivatives [math]\frac{\partial \ell}{\partial Y_i}[/math] are i.i.d. with zero mean. Then we can do the same calculation as before and find

[[math]] \begin{equation*} \Var\left( \frac{\partial \ell}{\partial X_j} \right) = m \Var\left( A_{i j} \right) \Var \left( \frac{\partial \ell}{\partial Y_i} \right). \end{equation*} [[/math]]

Hence if we want to have [math]\Var\left( \frac{\partial \ell}{\partial X_j} \right)=\Var \left( \frac{\partial \ell}{\partial Y_i} \right)[/math] for backward gradient propagation we need to set

[[math]] \begin{equation} \label{eq:var_backward} \Var(A_{i j}) = \frac{1}{m}. \end{equation} [[/math]]

Now unless [math]n=m[/math] we cannot satisfy \ref{eq:var_forward} and \ref{eq:var_backward} at the same time, but we can compromise and set

[[math]] \begin{equation} \label{eq:var_forward_backward} \Var(A_{i j}) = \frac{2}{n+m}. \end{equation} [[/math]]

Under this choice we can use the normal distribution [math]\mathcal{N}(0,\frac{2}{n+m})[/math] or the uniform distribution [math]\rident{Unif}\left[ -\sqrt{\frac{6}{n+m}}, \sqrt{\frac{6}{n+m}} \right][/math] to draw our coefficients [math]A_{i j}[/math] from.
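A short PyTorch sketch (dimensions arbitrary, deliberately with [math]n \neq m[/math]) that measures both the forward and the backward variance ratio for a linear layer initialized with [math]\Var(A_{i j}) = \frac{2}{n+m}[/math]:

```python
import torch

torch.manual_seed(0)
n, m, batch = 300, 100, 50_000

# Xavier compromise: A_ij ~ N(0, 2/(n+m)), linear activation, no bias.
A = torch.randn(m, n) * (2 / (n + m)) ** 0.5

X = torch.randn(batch, n, requires_grad=True)   # i.i.d. zero-mean, unit-variance inputs
Y = X @ A.T

# Forward ratio Var(Y)/Var(X): predicted n * Var(A) = 2n/(n+m) = 1.5 here.
print(Y.var().item())

# Backward: inject i.i.d. unit-variance gradients at the output.
Y.backward(torch.randn(batch, m))
# Backward ratio: predicted m * Var(A) = 2m/(n+m) = 0.5 here.
print(X.grad.var().item())
```

Neither ratio is exactly 1 unless [math]n=m[/math], but both stay moderate, which is exactly the point of the compromise.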

Of course in reality we never use the linear activation function. The original Xavier initialization scheme has been extended to cover specific activation functions. For the ReLU, for example, [3] arrives at

[[math]] \begin{equation*} \Var(A_{i j}) = \frac{4}{n+m}. \end{equation*} [[/math]]

This intuitively makes sense: since the ReLU is zero on half its domain, the variance of the coefficients needs to be increased to keep the variances of the signals/gradients constant.

Other choices of activation function lead to other multipliers being introduced to the same basic formula \ref{eq:var_forward_backward}:

[[math]] \begin{equation*} \Var(A_{i j}) = \alpha^2 \frac{2}{n+m}, \end{equation*} [[/math]]

where [math]\alpha[/math] is called the gain and depends on the choice of activation function.

See torch.nn.init.xavier_uniform_ and torch.nn.init.xavier_normal_ for PyTorch's implementation of these initialization schemes.
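A brief usage sketch (layer sizes arbitrary); torch.nn.init.calculate_gain supplies the gain [math]\alpha[/math] for common activation functions:

```python
import torch

layer = torch.nn.Linear(128, 64)

gain = torch.nn.init.calculate_gain('relu')     # sqrt(2) for ReLU
torch.nn.init.xavier_uniform_(layer.weight, gain=gain)
torch.nn.init.zeros_(layer.bias)

# Resulting weight variance: gain^2 * 2/(n+m) = 4/(128+64) here,
# matching the ReLU formula above.
print(layer.weight.var().item())
```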

Those are a lot of Assumptions

In the last two sections we made a lot of assumptions to arrive at simple formulas, and some of these assumptions are even verifiably incorrect in the networks we employ. In spite of the coarse and inelegant way these initialization schemes were derived, they are widely used for the simple reason that they work. They do not totally solve the vanishing/exploding gradient problem, but they still significantly improve the performance of the gradient descent algorithms.

General references

Smets, Bart M. N. (2024). "Mathematics of Neural Networks". arXiv:2403.04807 [cs.LG].

References

  1. Brown, Tom B.; et al. (2020). "Language Models are Few-Shot Learners". arXiv:2005.14165 [cs.CL].
  2. Glorot, Xavier; Bengio, Yoshua (2010). "Understanding the difficulty of training deep feedforward neural networks". Proceedings of the 13th International Conference on Artificial Intelligence and Statistics (AISTATS). PMLR 9:249-256.
  3. He, Kaiming; Zhang, Xiangyu; Ren, Shaoqing; Sun, Jian (2015). "Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification". Proceedings of the IEEE International Conference on Computer Vision (ICCV). arXiv:1502.01852.