OneHot

OneHot  // One-hot expansion

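Description

Expands an integral input tensor arg into a one-hot tensor. The output has one more axis than arg, inserted at position one_hot_axis. Along that axis, the element whose index equals the corresponding value of arg is 1, and every other element is 0.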

Inputs

Name     Element Type        Shape
arg      Any integral type   \((d_1,\dots,d_{m-1},d_{m+1},\dots,d_n)~(n \geq 0)\)

Attributes

Name           Description
shape          The desired output shape, including the new one-hot axis.
one_hot_axis   The index within the output shape of the new one-hot axis.

Outputs

Name     Element Type   Shape
output   Same as arg    Given by the shape attribute
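Taken together, the tables above imply a shape constraint: deleting the one-hot axis from shape must give back the shape of arg. The following is a minimal sketch of that check. It assumes fully static shapes represented as plain lists of dimension sizes. The name one_hot_shapes_consistent is hypothetical and is not part of the nGraph API. The real constructor takes a PartialShape, and this sketch does not model that type's dynamic dimensions.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical helper (not part of nGraph): returns true if deleting the
// one-hot axis from out_shape yields exactly arg_shape.
bool one_hot_shapes_consistent(const std::vector<std::size_t>& arg_shape,
                               const std::vector<std::size_t>& out_shape,
                               std::size_t one_hot_axis)
{
    // The output must have exactly one more axis than the input, and the
    // one-hot axis must be a valid index into the output shape.
    if (out_shape.size() != arg_shape.size() + 1 || one_hot_axis >= out_shape.size())
    {
        return false;
    }
    // Walk the output dimensions, skipping the one-hot axis, and compare the
    // rest in order against the input dimensions.
    for (std::size_t i = 0, j = 0; i < out_shape.size(); ++i)
    {
        if (i == one_hot_axis)
        {
            continue;
        }
        if (out_shape[i] != arg_shape[j++])
        {
            return false;
        }
    }
    return true;
}
```

For example, an arg of shape (2, 3) with one_hot_axis = 1 is consistent with an output shape of (2, 4, 3). Here 4 is the length of the new one-hot axis.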

Mathematical Definition

\[\begin{split}\mathtt{output}_{i_0, \ldots, i_{n-1}} = \begin{cases} 1&\text{if }i_{\mathtt{one\_hot\_axis}} = \mathtt{arg}_{(i : i\ne \mathtt{one\_hot\_axis})}\\ 0&\text{otherwise} \end{cases}\end{split}\]
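In the formula above, the subscript \((i : i \ne \mathtt{one\_hot\_axis})\) denotes the output index with the one-hot coordinate removed, which is exactly an index into arg. The sketch below evaluates the definition on row-major buffers. one_hot_reference is a hypothetical helper, not nGraph's reference kernel. It follows the formula literally, so an input value outside [0, depth) matches no position and leaves its slice all zeros. nGraph's own kernel may treat such values differently.

```cpp
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <vector>

// Hypothetical reference evaluation of the one-hot definition (not nGraph code).
// arg is a row-major buffer whose shape is out_shape with the one-hot axis removed.
template <typename T>
std::vector<T> one_hot_reference(const std::vector<T>& arg,
                                 const std::vector<std::size_t>& out_shape,
                                 std::size_t one_hot_axis)
{
    if (one_hot_axis >= out_shape.size())
    {
        throw std::invalid_argument("one_hot_axis out of range");
    }

    // View the output shape as [outer..., depth, inner...].
    std::size_t outer = 1;
    std::size_t inner = 1;
    for (std::size_t i = 0; i < one_hot_axis; ++i)
    {
        outer *= out_shape[i];
    }
    for (std::size_t i = one_hot_axis + 1; i < out_shape.size(); ++i)
    {
        inner *= out_shape[i];
    }
    const std::size_t depth = out_shape[one_hot_axis];

    if (arg.size() != outer * inner)
    {
        throw std::invalid_argument("arg size does not match out_shape");
    }

    // Every element starts at 0 (the "otherwise" case).
    std::vector<T> output(outer * depth * inner, T(0));
    for (std::size_t o = 0; o < outer; ++o)
    {
        for (std::size_t k = 0; k < inner; ++k)
        {
            const T value = arg[o * inner + k];
            // Set the one position whose one-hot coordinate equals the input value.
            if (value >= T(0) && static_cast<std::size_t>(value) < depth)
            {
                output[(o * depth + static_cast<std::size_t>(value)) * inner + k] = T(1);
            }
        }
    }
    return output;
}
```

For arg = [1, 0, 2] with output shape (3, 3) and one_hot_axis = 1, row r of the result has a 1 in column arg[r], giving [[0, 1, 0], [1, 0, 0], [0, 0, 1]]. Splitting the output shape into the dimensions before, at, and after the one-hot axis maps each arg element directly to its output offset. No full multi-index is constructed.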

C++ Interface

class OneHot : public ngraph::op::Op

One-hot operator.

Public Functions

OneHot(const std::shared_ptr<Node> &arg, const PartialShape &shape, size_t one_hot_axis)

Constructs a one-hot operation.

Parameters
  • arg: Node that produces the input tensor to be one-hot encoded.
  • shape: The shape of the output tensor, including the new one-hot axis.
  • one_hot_axis: The index within the output shape of the new one-hot axis.

size_t get_one_hot_axis() const

Return
The index of the one-hot axis.