BatchNormTrainingBackprop

BatchNormTrainingBackprop  // Compute the input, gamma, and beta backprop increments.

Description

Computes the input, gamma and beta backprop increments.

Inputs

Name              Element Type   Shape
----------------  -------------  -------------------------
input             real           \((\bullet, C, \ldots)\)
gamma             same as input  \((C)\)
beta              same as input  \((C)\)
mean              same as input  \((C)\)
variance          same as input  \((C)\)
normalized_delta  same as input  same as input

Attributes

Name     Type    Notes
-------  ------  ------------------------------------------------------
epsilon  double  Small bias added to the variance to avoid division by 0.

Outputs

Name         Element Type   Shape
-----------  -------------  -------------
input_delta  same as input  same as input
gamma_delta  same as gamma  \((C)\)
beta_delta   same as beta   \((C)\)

Mathematical Definition

The computation is easiest to describe for a single channel, with the remaining axes flattened into a vector: gamma and beta are then scalars, and input is an \(N\)-element vector.

The step-by-step forward training computation is

\[\begin{split}\mathtt{mean} &= \frac{\sum{\mathtt{input}_i}}{N}\\ \mathtt{centered}_i &= \mathtt{input}_i - \mathtt{mean}\\ \mathtt{square}_i &= \mathtt{centered}_i^2\\ \mathtt{variance} &= \frac{\sum \mathtt{square}_i}{N}\\ \mathtt{invsqrt} &= \frac{1}{\sqrt{\mathtt{variance}+\epsilon}}\\ \mathtt{gmul} &= \texttt{gamma}\cdot \mathtt{invsqrt}\\ \mathtt{normed}_i &= \mathtt{centered}_i\mathtt{gmul}+\texttt{beta}\end{split}\]
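The forward steps above can be sketched for one flattened channel in plain Python (the function name and the return of intermediates are illustrative, not part of the nGraph API):

```python
import math

def bn_forward(x, gamma, beta, eps=1e-5):
    # x is a flat list of the N values from a single channel.
    n = len(x)
    mean = sum(x) / n
    centered = [v - mean for v in x]
    square = [c * c for c in centered]
    variance = sum(square) / n
    invsqrt = 1.0 / math.sqrt(variance + eps)
    gmul = gamma * invsqrt
    normed = [c * gmul + beta for c in centered]
    # Intermediates are returned so a backward pass can reuse them.
    return normed, (centered, variance, invsqrt, gmul)
```

With `gamma = 1`, `beta = 0`, and a negligible epsilon, the output has zero mean and unit variance over the channel.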

We use the notation \(\overline{\texttt{name}}\) for \(\texttt{name_delta}\), and \(\overline{x} \leftarrow y\) to mean that the backprop value for \(\texttt{x_delta}\) is a sum that includes \(y\).

Working backwards through the forward computation:

\[\begin{split}\overline{\texttt{beta}}&\leftarrow \sum\overline{\texttt{normed}}_i\\ \overline{\texttt{gmul}}&\leftarrow \sum \texttt{centered}_i\cdot\overline{\texttt{normed}}_i\\ \overline{\texttt{centered}}_i&\leftarrow\overline{\texttt{normed}}_i\cdot\texttt{gmul}\\ \overline{\texttt{gamma}}&\leftarrow \overline{\texttt{gmul}}\cdot\texttt{invsqrt}\\ \overline{\texttt{invsqrt}}&\leftarrow\texttt{gamma}\cdot\overline{\texttt{gmul}}\\ \overline{\texttt{variance}}&\leftarrow -\frac{\overline{\texttt{invsqrt}}\cdot\texttt{invsqrt}}{2\cdot(\texttt{variance}+\epsilon)}\\ \overline{\texttt{square}}_i&\leftarrow\frac{\overline{\texttt{variance}}}{N}\\ \overline{\texttt{centered}}_i&\leftarrow 2\cdot\texttt{centered}_i\cdot\overline{\texttt{square}}_i\\ \overline{\texttt{input}}_i&\leftarrow\overline{\texttt{centered}}_i\\ \overline{\texttt{mean}}&\leftarrow-\sum\overline{\texttt{centered}}_i\\ \overline{\texttt{input}}_i&\leftarrow\frac{\overline{\texttt{mean}}}{N}\end{split}\]
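The backward accumulation can likewise be sketched for one flattened channel in plain Python (the function name and the recomputation of forward intermediates are illustrative choices, not the nGraph implementation):

```python
import math

def bn_backward(x, gamma, normed_delta, eps=1e-5):
    # Recompute the forward intermediates for a single flattened channel.
    n = len(x)
    mean = sum(x) / n
    centered = [v - mean for v in x]
    variance = sum(c * c for c in centered) / n
    invsqrt = 1.0 / math.sqrt(variance + eps)
    gmul = gamma * invsqrt

    # Work backwards, accumulating each delta as in the derivation.
    beta_delta = sum(normed_delta)
    gmul_delta = sum(c * d for c, d in zip(centered, normed_delta))
    centered_delta = [d * gmul for d in normed_delta]
    gamma_delta = gmul_delta * invsqrt
    invsqrt_delta = gamma * gmul_delta
    variance_delta = -invsqrt_delta * invsqrt / (2.0 * (variance + eps))
    square_delta = variance_delta / n          # same value for every i
    centered_delta = [cd + 2.0 * c * square_delta
                      for cd, c in zip(centered_delta, centered)]
    input_delta = list(centered_delta)
    mean_delta = -sum(centered_delta)
    input_delta = [d + mean_delta / n for d in input_delta]
    return input_delta, gamma_delta, beta_delta
```

A finite-difference check on the forward computation confirms that `input_delta` matches the numerical gradient.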

C++ Interface

class BatchNormTrainingBackprop : public ngraph::op::Op

Public Functions

void validate_and_infer_types()

Throws if the node is invalid.