BatchNormTrainingBackprop

BatchNormTrainingBackprop  // Compute input, gamma and beta backprop from the normalized delta.


Description

Computes the backprop increments (deltas) for input, gamma and beta, given the delta of the normalized output.

Inputs

| Name | Element Type | Shape |
| --- | --- | --- |
| input | real | $$(\bullet, C, \ldots)$$ |
| gamma | same as input | $$(C)$$ |
| beta | same as input | $$(C)$$ |
| mean | same as input | $$(C)$$ |
| variance | same as input | $$(C)$$ |
| normalized_delta | same as input | same as input |

Attributes

| Name | Type | Notes |
| --- | --- | --- |
| epsilon | double | Small bias added to variance to avoid division by 0. |

Outputs

| Name | Element Type | Shape |
| --- | --- | --- |
| input_delta | same as input | same as input |
| gamma_delta | same as gamma | $$(C)$$ |
| beta_delta | same as beta | $$(C)$$ |

Mathematical Definition

To simplify, consider a single channel and flatten the remaining axes into a vector, so that gamma, beta, mean and variance are scalars and input is an $$N$$-element vector, where $$N$$ is the product of the sizes of all non-channel axes.
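As a sketch of the flattening step, the C++ below collects one channel's elements into an $$N$$-element vector. It assumes a dense, row-major tensor whose axis 0 is the batch axis and axis 1 is the channel axis (as the shape $$(\bullet, C, \ldots)$$ suggests), with all trailing axes multiplied together into `spatial`. `gather_channel` is a hypothetical helper for illustration, not part of nGraph.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper (not nGraph API): gather every element of channel `c`
// from a dense, row-major tensor of shape (batch, channels, spatial...),
// where `spatial` is the product of all axes after the channel axis.
// Returns the N-element vector used in the derivation, N = batch * spatial.
std::vector<double> gather_channel(const std::vector<double>& data,
                                   std::size_t batch, std::size_t channels,
                                   std::size_t spatial, std::size_t c) {
    assert(data.size() == batch * channels * spatial);
    assert(c < channels);
    std::vector<double> out;
    out.reserve(batch * spatial);
    for (std::size_t b = 0; b < batch; ++b) {
        // Offset of element (b, c, 0, ..., 0) in row-major order.
        const std::size_t base = (b * channels + c) * spatial;
        for (std::size_t s = 0; s < spatial; ++s) {
            out.push_back(data[base + s]);
        }
    }
    return out;
}
```

For a tensor of shape (2, 3, 2) holding the values 0 through 11, channel 1 gathers the elements 2, 3, 8 and 9.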

The step-by-step forward training computation is:

$\begin{split}\mathtt{mean} &= \frac{\sum{\mathtt{input}_i}}{N}\\ \mathtt{centered}_i &= \mathtt{input}_i - \mathtt{mean}\\ \mathtt{square}_i &= \mathtt{centered}_i^2\\ \mathtt{variance} &= \frac{\sum \mathtt{square}_i}{N}\\ \mathtt{invsqrt} &= \frac{1}{\sqrt{\mathtt{variance}+\epsilon}}\\ \mathtt{gmul} &= \texttt{gamma}\cdot \mathtt{invsqrt}\\ \mathtt{normed}_i &= \mathtt{centered}_i\mathtt{gmul}+\texttt{beta}\end{split}$
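These forward steps translate directly into a single-channel C++ sketch. `ForwardTrace` and `batch_norm_forward` are illustrative names, not nGraph API; each field holds the quantity of the same name in the equations, so the intermediates are available to a backward pass.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Illustrative container (not nGraph API) for the single-channel forward
// intermediates, named after the quantities in the equations above.
struct ForwardTrace {
    double mean = 0.0;
    std::vector<double> centered;
    double variance = 0.0;
    double invsqrt = 0.0;
    double gmul = 0.0;
    std::vector<double> normed;
};

ForwardTrace batch_norm_forward(const std::vector<double>& input,
                                double gamma, double beta, double epsilon) {
    assert(!input.empty());
    const double n = static_cast<double>(input.size());
    ForwardTrace t;
    for (double x : input) t.mean += x;
    t.mean /= n;                                         // mean
    for (double x : input) t.centered.push_back(x - t.mean);  // centered_i
    for (double c : t.centered) t.variance += c * c;     // sum of square_i
    t.variance /= n;                                     // variance
    t.invsqrt = 1.0 / std::sqrt(t.variance + epsilon);   // invsqrt
    t.gmul = gamma * t.invsqrt;                          // gmul
    for (double c : t.centered) t.normed.push_back(c * t.gmul + beta);  // normed_i
    return t;
}
```

With input $$(1, 2, 3, 4)$$ the sketch gives a mean of 2.5 and a variance of 1.25; with $$\epsilon = 0$$ the outputs then have mean beta and variance gamma squared.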

We use the notation $$\overline{\texttt{name}}$$ for $$\texttt{name_delta}$$, and write $$\overline{x} \leftarrow y$$ to mean that the backprop value $$\texttt{x_delta}$$ is a sum that includes the term $$y$$.

Working backwards through the forward steps:

$\begin{split}\overline{\texttt{beta}}&\leftarrow \sum\overline{\texttt{normed}}_i\\ \overline{\texttt{gmul}}&\leftarrow \sum \overline{\texttt{normed}}_i\cdot\texttt{centered}_i\\ \overline{\texttt{centered}}_i&\leftarrow\overline{\texttt{normed}}_i\cdot\texttt{gmul}\\ \overline{\texttt{gamma}}&\leftarrow \overline{\texttt{gmul}}\cdot\texttt{invsqrt}\\ \overline{\texttt{invsqrt}}&\leftarrow\texttt{gamma}\cdot\overline{\texttt{gmul}}\\ \overline{\texttt{variance}}&\leftarrow -\frac{\overline{\texttt{invsqrt}}\cdot\texttt{invsqrt}}{2\cdot(\texttt{variance}+\epsilon)}\\ \overline{\texttt{square}}_i&\leftarrow\frac{\overline{\texttt{variance}}}{N}\\ \overline{\texttt{centered}}_i&\leftarrow 2\cdot\texttt{centered}_i\cdot\overline{\texttt{square}}_i\\ \overline{\texttt{input}}_i&\leftarrow\overline{\texttt{centered}}_i\\ \overline{\texttt{mean}}&\leftarrow-\sum\overline{\texttt{centered}}_i\\ \overline{\texttt{input}}_i&\leftarrow\frac{\overline{\texttt{mean}}}{N}\end{split}$
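A single-channel C++ sketch of the backward pass, applying the accumulations in order. `BackpropResult` and `batch_norm_backprop` are hypothetical names, not nGraph API. The sketch assumes `mean` and `variance` are the batch statistics of `input` from the forward pass, since the derivation differentiates through them; `beta` is not a parameter because its value does not affect any of the three deltas.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical result type (not nGraph API) holding the three output deltas.
struct BackpropResult {
    std::vector<double> input_delta;
    double gamma_delta = 0.0;
    double beta_delta = 0.0;
};

// Single-channel backprop; assumes mean/variance are the batch statistics of input.
BackpropResult batch_norm_backprop(const std::vector<double>& input,
                                   double gamma, double mean, double variance,
                                   double epsilon,
                                   const std::vector<double>& normalized_delta) {
    assert(!input.empty() && input.size() == normalized_delta.size());
    const std::size_t n = input.size();
    const double invsqrt = 1.0 / std::sqrt(variance + epsilon);
    const double gmul = gamma * invsqrt;

    BackpropResult r;
    double gmul_delta = 0.0;
    std::vector<double> centered(n), centered_delta(n);
    for (std::size_t i = 0; i < n; ++i) {
        centered[i] = input[i] - mean;
        r.beta_delta += normalized_delta[i];              // beta_bar <- sum normed_bar_i
        gmul_delta += normalized_delta[i] * centered[i];  // gmul_bar <- sum normed_bar_i * centered_i
        centered_delta[i] = normalized_delta[i] * gmul;   // centered_bar_i <- normed_bar_i * gmul
    }
    r.gamma_delta = gmul_delta * invsqrt;                 // gamma_bar
    const double invsqrt_delta = gamma * gmul_delta;      // invsqrt_bar
    const double variance_delta =
        -invsqrt_delta * invsqrt / (2.0 * (variance + epsilon));
    // square_bar_i is the same value for every i.
    const double square_delta = variance_delta / static_cast<double>(n);

    double mean_delta = 0.0;
    for (std::size_t i = 0; i < n; ++i) {
        centered_delta[i] += 2.0 * centered[i] * square_delta;  // variance path
        mean_delta -= centered_delta[i];                  // mean_bar <- -sum centered_bar_i
    }
    r.input_delta.resize(n);
    for (std::size_t i = 0; i < n; ++i) {
        // Direct path through centered_i plus the path through mean.
        r.input_delta[i] = centered_delta[i] + mean_delta / static_cast<double>(n);
    }
    return r;
}
```

Because the $$\texttt{centered}_i$$ sum to zero, the variance-path terms cancel in $$\overline{\texttt{mean}}$$, so in exact arithmetic `mean_delta` equals $$-\texttt{gmul}\sum\overline{\texttt{normed}}_i$$. The sketch still accumulates both terms so that it follows the equations literally.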

C++ Interface

class BatchNormTrainingBackprop : public ngraph::op::Op

Public Functions

void validate_and_infer_types()

Throws if the node is invalid.
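The shape rules in the Inputs and Outputs tables can be expressed as a standalone check. This is a hypothetical sketch, not the actual `validate_and_infer_types()` implementation: it uses a plain `std::vector<std::size_t>` in place of nGraph's shape type, requires only that input have rank at least 2, and skips element-type checks.

```cpp
#include <cstddef>
#include <stdexcept>
#include <vector>

using Shape = std::vector<std::size_t>;

// Hypothetical standalone check mirroring the Inputs/Outputs tables.
// Throws std::invalid_argument on a mismatch; otherwise returns the output
// shapes in the order input_delta, gamma_delta, beta_delta.
std::vector<Shape> check_batch_norm_backprop_shapes(
    const Shape& input, const Shape& gamma, const Shape& beta,
    const Shape& mean, const Shape& variance, const Shape& delta) {
    if (input.size() < 2) {
        throw std::invalid_argument("input must have rank >= 2: (batch, C, ...)");
    }
    const Shape channel{input[1]};  // the (C) shape
    for (const Shape* s : {&gamma, &beta, &mean, &variance}) {
        if (*s != channel) {
            throw std::invalid_argument("gamma, beta, mean and variance must have shape (C)");
        }
    }
    if (delta != input) {
        throw std::invalid_argument("normalized_delta must have the same shape as input");
    }
    return {input, channel, channel};
}
```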