Broadcast  // Operation that produces a tensor based on arg's axes


Operation whose output tensor has the same value at every position along the axes that are not present in the arg tensor.


Inputs

Name | Element Type | Shape
arg | Any | Any


Attributes

Name | Type | Notes
shape | Shape | The shape of the output.
broadcast_axes | AxisSet | Axis positions in shape that are broadcast.


Outputs

Name | Element Type | Shape
output | Same as arg | Same as shape

The shape of arg must equal shape with the dimensions at the positions in broadcast_axes removed.

For example, if arg is \([a, b, c]\) then

\[\begin{split}\mathtt{Broadcast(arg, Shape\{2, 3\}, AxisSet\{0\})} &= \begin{bmatrix} a & b & c\\ a & b & c \end{bmatrix}\\ \mathtt{Broadcast(arg, Shape\{3, 2\}, AxisSet\{1\})} &= \begin{bmatrix} a & a\\ b & b\\ c & c \end{bmatrix}\end{split}\]

Mathematical Definition

For a coordinate \(C\) of the output, let \(p(C)\) be the coordinate obtained by removing the entries of \(C\) at the positions in broadcast_axes. For example, if \(\mathtt{broadcast\_axes}=\{1,3\}\) then \(p([d_0, d_1, d_2, d_3, d_4]) = [d_0, d_2, d_4]\). Then

\[\mathtt{output}_C = \mathtt{arg}_{p(C)}.\]
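The definition above can be sketched as a small reference implementation. This is a standalone illustration, not nGraph code: `Shape`, `AxisSet` and `Coordinate` are plain standard-library stand-ins for the nGraph types, tensors are assumed to be stored as flat row-major `float` vectors, and the function names `project`, `linear_index` and `broadcast_reference` are hypothetical.

```cpp
#include <cstddef>
#include <set>
#include <vector>

// Stand-ins for the nGraph types of the same name (assumption for this sketch).
using Shape = std::vector<std::size_t>;
using AxisSet = std::set<std::size_t>;
using Coordinate = std::vector<std::size_t>;

// p(C): drop the entries of c whose positions are in broadcast_axes.
Coordinate project(const Coordinate& c, const AxisSet& broadcast_axes)
{
    Coordinate result;
    for (std::size_t axis = 0; axis < c.size(); ++axis)
    {
        if (broadcast_axes.count(axis) == 0)
        {
            result.push_back(c[axis]);
        }
    }
    return result;
}

// Row-major linear index of coordinate c in a tensor of the given shape.
std::size_t linear_index(const Coordinate& c, const Shape& shape)
{
    std::size_t index = 0;
    for (std::size_t axis = 0; axis < shape.size(); ++axis)
    {
        index = index * shape[axis] + c[axis];
    }
    return index;
}

// output[C] = arg[p(C)] for every coordinate C of the output shape.
// Assumes arg holds exactly the elements of a tensor whose shape is
// `shape` with the broadcast axes removed.
std::vector<float> broadcast_reference(const std::vector<float>& arg,
                                       const Shape& shape,
                                       const AxisSet& broadcast_axes)
{
    const Shape arg_shape = project(shape, broadcast_axes);
    std::size_t total = 1;
    for (std::size_t dim : shape)
    {
        total *= dim;
    }

    std::vector<float> output(total);
    Coordinate c(shape.size(), 0);
    for (std::size_t i = 0; i < total; ++i)
    {
        output[i] = arg[linear_index(project(c, broadcast_axes), arg_shape)];
        // Advance c to the next coordinate in row-major order.
        for (std::size_t axis = shape.size(); axis-- > 0;)
        {
            if (++c[axis] < shape[axis])
            {
                break;
            }
            c[axis] = 0;
        }
    }
    return output;
}
```

With arg holding `{1, 2, 3}`, calling `broadcast_reference(arg, Shape{2, 3}, AxisSet{0})` yields the row-major data `{1, 2, 3, 1, 2, 3}`, and `broadcast_reference(arg, Shape{3, 2}, AxisSet{1})` yields `{1, 1, 2, 2, 3, 3}`, matching the two matrices in the example.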


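Backprop

With \(\Delta\) denoting the adjoint of output, the adjoint of arg is obtained by summing \(\Delta\) over the broadcast axes:

\[\overline{\mathtt{arg}} \leftarrow \mathtt{Sum}(\Delta, \mathtt{broadcast\_axes}).\]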
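The summation above can be illustrated with a short standalone sketch, reading \(\Delta\) as the adjoint of output. It is not nGraph code: `Shape` and `AxisSet` are standard-library stand-ins, `broadcast_backprop_reference` is a hypothetical name, and delta is assumed to be a flat row-major `float` vector with exactly as many elements as `shape` describes.

```cpp
#include <cstddef>
#include <set>
#include <vector>

// Stand-ins for the nGraph types of the same name (assumption for this sketch).
using Shape = std::vector<std::size_t>;
using AxisSet = std::set<std::size_t>;

// Sums delta (row-major, laid out according to `shape`) over broadcast_axes.
// The result has the shape of arg: `shape` with the broadcast axes removed.
std::vector<float> broadcast_backprop_reference(const std::vector<float>& delta,
                                                const Shape& shape,
                                                const AxisSet& broadcast_axes)
{
    std::size_t total = 1;     // number of elements in delta
    std::size_t arg_size = 1;  // number of elements in the adjoint of arg
    for (std::size_t axis = 0; axis < shape.size(); ++axis)
    {
        total *= shape[axis];
        if (broadcast_axes.count(axis) == 0)
        {
            arg_size *= shape[axis];
        }
    }

    std::vector<float> arg_adjoint(arg_size, 0.0f);
    std::vector<std::size_t> c(shape.size(), 0);
    for (std::size_t i = 0; i < total; ++i)
    {
        // Row-major index into arg, skipping the broadcast axes of c.
        std::size_t arg_index = 0;
        for (std::size_t axis = 0; axis < shape.size(); ++axis)
        {
            if (broadcast_axes.count(axis) == 0)
            {
                arg_index = arg_index * shape[axis] + c[axis];
            }
        }
        arg_adjoint[arg_index] += delta[i];

        // Advance c to the next coordinate in row-major order.
        for (std::size_t axis = shape.size(); axis-- > 0;)
        {
            if (++c[axis] < shape[axis])
            {
                break;
            }
            c[axis] = 0;
        }
    }
    return arg_adjoint;
}
```

For a delta of `{1, 2, 3, 4, 5, 6}` with `Shape{2, 3}` and `AxisSet{0}`, the two rows are added together and the result is `{5, 7, 9}`. Every output element that was copied from the same arg element contributes its gradient to that element.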

C++ Interface

class Broadcast : public ngraph::op::Op

Operation which “adds” axes to an input tensor, replicating elements from the input as needed along the new axes.

Subclassed by ngraph::op::BroadcastLike

Public Functions

const std::string &description() const

Gets the string name for the type of the node, such as Add or Multiply. This is the class name; it must not contain spaces because it is used for codegen.

Returns a const reference to the node’s type name.


Broadcast(const Output<Node> &arg, const Shape &shape, const AxisSet &broadcast_axes)

Constructs a broadcast operation.

  • arg: Node that produces the input tensor to be broadcast.
  • shape: The shape of the output tensor.
  • broadcast_axes: The axis positions (0-based) in the result that are being broadcast. The remaining axes in shape must be the same as the shape of arg.

void validate_and_infer_types()

Throws if the node is invalid.
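The shape constraint stated for the constructor (the axes of shape that are not in broadcast_axes must match the shape of arg) suggests the kind of check such a validation step could perform. The sketch below is an illustration only and does not reproduce the nGraph implementation: `check_broadcast_shapes` is a hypothetical helper, `std::invalid_argument` is a stand-in for whatever exception nGraph actually throws, and the range check on the axes is an added assumption.

```cpp
#include <cstddef>
#include <set>
#include <stdexcept>
#include <vector>

// Stand-ins for the nGraph types of the same name (assumption for this sketch).
using Shape = std::vector<std::size_t>;
using AxisSet = std::set<std::size_t>;

// Throws std::invalid_argument unless arg_shape equals `shape` with the
// dimensions at the positions in broadcast_axes removed.
void check_broadcast_shapes(const Shape& arg_shape,
                            const Shape& shape,
                            const AxisSet& broadcast_axes)
{
    // Assumed extra check: every broadcast axis must index into `shape`.
    for (std::size_t axis : broadcast_axes)
    {
        if (axis >= shape.size())
        {
            throw std::invalid_argument("broadcast axis is out of range for the output shape");
        }
    }

    // Collect the dimensions of `shape` that are not broadcast.
    Shape remaining;
    for (std::size_t axis = 0; axis < shape.size(); ++axis)
    {
        if (broadcast_axes.count(axis) == 0)
        {
            remaining.push_back(shape[axis]);
        }
    }

    if (remaining != arg_shape)
    {
        throw std::invalid_argument("arg shape does not match shape with broadcast axes removed");
    }
}
```

Under these assumptions, `check_broadcast_shapes(Shape{3}, Shape{2, 3}, AxisSet{0})` returns normally, while passing `AxisSet{1}` with the same shapes throws, because removing axis 1 from `Shape{2, 3}` leaves `Shape{2}` rather than `Shape{3}`.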

const AxisSet &get_broadcast_axes() const

Returns a set containing the indices of the broadcast axes (0-based).