OneHot

OneHot  // One-hot expansion

Description

Inputs

| Name | Element Type | Shape |
| --- | --- | --- |
| `arg` | Any | \((d_1,\dots,d_{m-1},d_{m+1},\dots,d_n)~(n \geq 0)\) |

Attributes

| Name | Description |
| --- | --- |
| `shape` | The desired output shape, including the new one-hot axis. |
| `one_hot_axis` | The index within the output shape of the new one-hot axis. |

Outputs

| Name | Element Type | Shape |
| --- | --- | --- |
| `output` | Same as `arg` | Given by the `shape` attribute, i.e. \((d_1,\dots,d_n)\) |

Mathematical Definition

\[\begin{split}\mathtt{output}_{i_0, \ldots, i_{n-1}} = \begin{cases} 1&\text{if }i_{\mathtt{one\_hot\_axis}} = \mathtt{arg}_{(i : i\ne \mathtt{one\_hot\_axis})}\\ 0&\text{otherwise} \end{cases}\end{split}\]
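To make the definition concrete, here is a minimal C++ sketch of a reference evaluator, assuming dense row-major storage and a fully static output shape. The function name `one_hot_reference` is hypothetical and is not part of nGraph. Because the definition leaves non-integral or out-of-range inputs undefined, this sketch simply chooses to throw for them.

```cpp
#include <cmath>
#include <cstddef>
#include <stdexcept>
#include <vector>

// Hypothetical reference kernel illustrating the OneHot definition
// (not nGraph's own implementation). Tensors are dense and row-major;
// `arg` has the shape of `out_shape` with axis `one_hot_axis` removed.
template <typename T>
std::vector<T> one_hot_reference(const std::vector<T>& arg,
                                 const std::vector<std::size_t>& out_shape,
                                 std::size_t one_hot_axis)
{
    if (one_hot_axis >= out_shape.size())
        throw std::invalid_argument("one_hot_axis out of range");

    // Split the output shape around the one-hot axis:
    // outer = product of dims before it, depth = its size, inner = product after it.
    std::size_t outer = 1, inner = 1;
    for (std::size_t k = 0; k < one_hot_axis; ++k) outer *= out_shape[k];
    for (std::size_t k = one_hot_axis + 1; k < out_shape.size(); ++k) inner *= out_shape[k];
    const std::size_t depth = out_shape[one_hot_axis];

    if (arg.size() != outer * inner)
        throw std::invalid_argument("arg size does not match shape without one_hot_axis");

    std::vector<T> out(outer * depth * inner, T(0));
    for (std::size_t o = 0; o < outer; ++o) {
        for (std::size_t i = 0; i < inner; ++i) {
            const T v = arg[o * inner + i];
            const double dv = static_cast<double>(v);
            // Undefined in the spec; this sketch rejects non-integral or out-of-range values.
            if (dv != std::floor(dv) || dv < 0.0 || dv >= static_cast<double>(depth))
                throw std::out_of_range("non-integral or out-of-bounds one-hot index");
            const std::size_t h = static_cast<std::size_t>(v);
            // Set the single 1 at position (o, h, i) of the outer x depth x inner output.
            out[(o * depth + h) * inner + i] = T(1);
        }
    }
    return out;
}
```

For example, `one_hot_reference<int>({2, 0, 1}, {3, 3}, 1)` yields the row-major matrix `{0, 0, 1, 1, 0, 0, 0, 1, 0}`: each input element selects the column that holds the 1 in its row.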

C++ Interface

class OneHot : public ngraph::op::Op

One-hot operator.

## Parameters

| | Description |
| --- | --- |
| `shape` | The desired output shape, including the new one-hot axis. |
| `one_hot_axis` | The index within the output shape of the new one-hot axis. |

## Inputs

| | Type | Description |
| --- | --- | --- |
| `arg` | \(E[d_1,\dots,d_{m-1},d_{m+1},\dots,d_n]~(n \geq 0)\) | A tensor of any shape and any element type. |

## Output

| Type | Description |
| --- | --- |
| \(E[d_1,\dots,d_n]\) | The tensor \(T'\), where \(T'[i_1,\dots,i_{m-1},i_m,i_{m+1},\dots,i_n] = 1\) if \(T[i_1,\dots,i_{m-1},i_{m+1},\dots,i_n] = i_m\), else \(0\). However, \(T'\) is undefined if any non-integral or out-of-bounds value is detected in the input tensor. |
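As a small worked instance of the rule above, assume index values count from 0, so valid entries on the one-hot axis lie in \([0, d_m)\), and number rows and columns from 0. Take a rank-1 input \(T = (2, 0, 1)\) with `shape` \(= (3, 3)\). With `one_hot_axis` = 1, row \(r\) of \(T'\) has its 1 in column \(T[r]\); with `one_hot_axis` = 0 the roles of the two axes swap, giving the transpose:

\[
\mathtt{one\_hot\_axis} = 1:\quad
T' = \begin{pmatrix} 0 & 0 & 1 \\ 1 & 0 & 0 \\ 0 & 1 & 0 \end{pmatrix},
\qquad
\mathtt{one\_hot\_axis} = 0:\quad
T' = \begin{pmatrix} 0 & 1 & 0 \\ 0 & 0 & 1 \\ 1 & 0 & 0 \end{pmatrix}
\]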

Public Functions

OneHot(const std::shared_ptr<Node> &arg, const PartialShape &shape, size_t one_hot_axis)

Constructs a one-hot operation.

Parameters
  • arg: Node that produces the input tensor to be one-hot encoded.
  • shape: The shape of the output tensor, including the new one-hot axis.
  • one_hot_axis: The index within the output shape of the new one-hot axis.
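The three constructor arguments must agree with one another: `one_hot_axis` has to index an existing dimension of `shape`, and removing that dimension from `shape` should give the shape of `arg`. The following standalone sketch expresses that consistency check, assuming fully static shapes stored as `std::vector<std::size_t>`. The alias `StaticShape` and the helper `one_hot_shapes_consistent` are hypothetical; this is not nGraph's validation code, and the real constructor takes a `PartialShape`, which may contain dynamic dimensions that this sketch ignores.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical stand-in for a fully static shape.
using StaticShape = std::vector<std::size_t>;

// Hypothetical helper (not part of nGraph): returns true when `arg_shape`
// equals `shape` with the entry at `one_hot_axis` removed.
bool one_hot_shapes_consistent(const StaticShape& arg_shape,
                               const StaticShape& shape,
                               std::size_t one_hot_axis)
{
    // The one-hot axis must index an existing dimension of the output shape.
    if (one_hot_axis >= shape.size())
        return false;

    // Drop the one-hot dimension to recover the expected input shape.
    StaticShape expected(shape);
    expected.erase(expected.begin() + static_cast<std::ptrdiff_t>(one_hot_axis));
    return expected == arg_shape;
}
```

For example, an input of shape (4, 2) with `shape` = (4, 10, 2) and `one_hot_axis` = 1 passes the check, whereas `one_hot_axis` = 0 would instead require an input of shape (10, 2).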

size_t get_one_hot_axis() const

Return
The index of the one-hot axis.