BatchNormTraining // Compute mean and variance from the input.
input: element type real, shape \((\bullet, C, \ldots)\).
epsilon: Small bias added to variance to avoid division by 0.
The batch_mean and batch_variance outputs are computed per-channel from the input.
The axes of the input fall into two categories: positional and channel, with channel being axis 1. For each position, there are \(C\) channel values, each normalized independently.
Normalization of a channel sample is controlled by two values:
- the batch_mean \(\mu\), and
- the batch_variance \(\sigma^2\);
and by two scaling attributes: \(\gamma\) and \(\beta\).
The values for \(\mu\) and \(\sigma^2\) come from computing the mean and variance of the input.
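As a sketch of the computation this description implies, the per-channel statistics and the normalized value can be written as follows. The index notation is an assumption introduced here: \(p\) ranges over the \(M\) positions (all axes except the channel axis) and \(c\) over the \(C\) channels. The use of the population variance (dividing by \(M\)) is also an assumption, not something this section states.

\[
\mu_c = \frac{1}{M} \sum_{p} \mathtt{input}[p, c],
\qquad
\sigma_c^2 = \frac{1}{M} \sum_{p} \left(\mathtt{input}[p, c] - \mu_c\right)^2
\]

\[
\mathtt{normalized}[p, c] = \gamma_c \, \frac{\mathtt{input}[p, c] - \mu_c}{\sqrt{\sigma_c^2 + \epsilon}} + \beta_c
\]

Here \(\epsilon\) is the epsilon attribute, which keeps the denominator non-zero when a channel has zero variance.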
class BatchNormTraining : public ngraph::op::Op
Batch normalization operation for training.
Subclassed by ngraph::op::gpu::BatchNormTrainingWithStats
Return: const std::string &, a const reference to the node’s type name
BatchNormTraining(const Output<Node> &input, const Output<Node> &gamma, const Output<Node> &beta, double epsilon)
input: Must have rank >= 2, [., C, …]
gamma: gamma scaling for normalized value. [C]
beta: bias added to the scaled normalized value [C]
epsilon: Avoids division by 0 if input has 0 variance
BatchNormTraining(double eps, const Output<Node> &gamma, const Output<Node> &beta, const Output<Node> &input)
In this version of BatchNorm:
MEAN AND VARIANCE: computed directly from the content of ‘input’.
OUTPUT VALUE: A tuple with the following structure:
- The normalization of ‘input’.
- The per-channel means of (pre-normalized) ‘input’.
- The per-channel variances of (pre-normalized) ‘input’.
AUTODIFF SUPPORT: yes: ‘generate_adjoints(…)’ works as expected.
SHAPE DETAILS:
- gamma: must have rank 1, with the same span as input’s channel axis.
- beta: must have rank 1, with the same span as input’s channel axis.
- input: must have rank >= 2. The second dimension represents the channel axis and must have a span of at least 1.
- output (normalization of ‘input’): shall have the same shape as ‘input’.
- output (per-channel means): shall have rank 1, with the same span as input’s channel axis.
- output (per-channel variances): shall have rank 1, with the same span as input’s channel axis.
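The output tuple and shape rules above can be illustrated with a small standalone reference routine. This is a minimal sketch using only the C++ standard library, not nGraph code: `BatchNormResult` and `batch_norm_training_ref` are hypothetical names, the input is assumed to be a dense row-major tensor viewed as (N, C, S) with S the product of all axes after the channel axis, and the variance is assumed to be the population variance.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <stdexcept>
#include <vector>

// Hypothetical container for the three outputs described above.
struct BatchNormResult {
    std::vector<double> normalized; // same shape as the input
    std::vector<double> mean;       // rank 1, length C
    std::vector<double> variance;   // rank 1, length C
};

// Hypothetical reference routine (not part of nGraph).
// `input` is row-major with shape (N, C, S); S == 1 for a rank-2 input.
BatchNormResult batch_norm_training_ref(const std::vector<double>& input,
                                        std::size_t N, std::size_t C, std::size_t S,
                                        const std::vector<double>& gamma,
                                        const std::vector<double>& beta,
                                        double epsilon)
{
    // Shape checks mirroring the rules listed above.
    if (C < 1 || gamma.size() != C || beta.size() != C)
        throw std::invalid_argument("gamma and beta must have length C >= 1");
    if (N == 0 || S == 0 || input.size() != N * C * S)
        throw std::invalid_argument("input size must equal N * C * S");

    BatchNormResult r{std::vector<double>(input.size()),
                      std::vector<double>(C, 0.0),
                      std::vector<double>(C, 0.0)};
    const double count = static_cast<double>(N * S);

    for (std::size_t c = 0; c < C; ++c) {
        // Per-channel mean over every position.
        double sum = 0.0;
        for (std::size_t n = 0; n < N; ++n)
            for (std::size_t s = 0; s < S; ++s)
                sum += input[(n * C + c) * S + s];
        const double mu = sum / count;

        // Per-channel population variance (assumption: divide by count).
        double sq = 0.0;
        for (std::size_t n = 0; n < N; ++n)
            for (std::size_t s = 0; s < S; ++s) {
                const double d = input[(n * C + c) * S + s] - mu;
                sq += d * d;
            }
        const double var = sq / count;

        // Normalize, then scale by gamma and shift by beta.
        const double inv_std = 1.0 / std::sqrt(var + epsilon);
        for (std::size_t n = 0; n < N; ++n)
            for (std::size_t s = 0; s < S; ++s) {
                const std::size_t i = (n * C + c) * S + s;
                r.normalized[i] = gamma[c] * (input[i] - mu) * inv_std + beta[c];
            }
        r.mean[c] = mu;
        r.variance[c] = var;
    }
    return r;
}
```

Calling the routine with a rank-2 input (S equal to 1) gives one mean and one variance per channel, and a normalized tensor with the same number of elements as the input. Mismatched gamma or beta lengths are rejected, in the spirit of the node validation described here.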
Throws if the node is invalid.