Python API

This section contains the Python API component of the nGraph Compiler stack. The Python API exposes nGraph™ C++ operations to Python users. For a quick start, an example of API usage follows.

Note that the output of print(model) may vary. Node names such as Multiply_14 carry a numeric suffix that depends on the number of nodes and the steps used to build the graph, so the same model can print differently. For example, some configurations build the same model with fewer nodes, and it prints as <Multiply: 'Multiply_12' ([2, 2])>. Differently configured networks should produce the same result for simple calculations. More complex computations may show minor differences in numerical precision.
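A minimal sketch of why the suffix changes. It assumes that names are built from the op type plus a counter shared by every node created so far. ToyNode and build_model are hypothetical illustrations and are not part of the ngraph package.

```python
import itertools

# Hypothetical toy node class, NOT the nGraph API. It only mimics how a name
# such as 'Multiply_14' could combine an op type with a shared instance counter.
_instance_counter = itertools.count()


class ToyNode:
    def __init__(self, op_type):
        self.op_type = op_type
        # Every node takes the next counter value, whatever its type.
        self.name = f"{op_type}_{next(_instance_counter)}"

    def __repr__(self):
        return f"<{self.op_type}: '{self.name}'>"


def build_model():
    # (A + B) * C creates five nodes: three parameters, one Add, one Multiply.
    for op_type in ("Parameter", "Parameter", "Parameter", "Add"):
        ToyNode(op_type)
    return ToyNode("Multiply")


first = build_model()
second = build_model()
print(first)   # <Multiply: 'Multiply_4'>
print(second)  # <Multiply: 'Multiply_9'>
```

Building the identical model twice yields different names, because the counter keeps advancing across models.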

Basic example

```python
import numpy as np
import ngraph as ng

# Declare three 2x2 float32 input parameters.
A = ng.parameter(shape=[2, 2], name='A', dtype=np.float32)
B = ng.parameter(shape=[2, 2], name='B', dtype=np.float32)
C = ng.parameter(shape=[2, 2], name='C', dtype=np.float32)
# >>> print(A)
# <Parameter: 'A' ([2, 2], float)>

# Build the graph; nothing is computed yet.
model = (A + B) * C
# >>> print(model)
# <Multiply: 'Multiply_14' ([2, 2])>

# Create a runtime for the CPU backend.
runtime = ng.runtime(backend_name='CPU')
# >>> print(runtime)
# <Runtime: Backend='CPU'>

# Compile the graph into a callable; arguments follow the parameter order A, B, C.
computation = runtime.computation(model, A, B, C)
# >>> print(computation)
# <Computation: Multiply_14(A, B, C)>

value_a = np.array([[1, 2], [3, 4]], dtype=np.float32)
value_b = np.array([[5, 6], [7, 8]], dtype=np.float32)
value_c = np.array([[9, 10], [11, 12]], dtype=np.float32)

# Run the computation on concrete NumPy inputs.
result = computation(value_a, value_b, value_c)
# >>> print(result)
# [[ 54.  80.]
#  [110. 144.]]
```
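Because more complex computations may differ slightly in precision, a backend result is best compared with a reference using a tolerance rather than exact equality. The sketch below evaluates the same (A + B) * C element-wise in plain NumPy and compares with numpy.allclose. The helper matches_reference and its tolerance defaults are hypothetical choices, not part of the ngraph package.

```python
import numpy as np

# Same inputs as the nGraph example.
value_a = np.array([[1, 2], [3, 4]], dtype=np.float32)
value_b = np.array([[5, 6], [7, 8]], dtype=np.float32)
value_c = np.array([[9, 10], [11, 12]], dtype=np.float32)

# Plain NumPy evaluation of model = (A + B) * C (element-wise).
expected = (value_a + value_b) * value_c
print(expected)
# [[ 54.  80.]
#  [110. 144.]]


def matches_reference(result, reference, rtol=1e-5, atol=1e-6):
    """Hypothetical helper: compare a result with a reference within a tolerance."""
    return np.allclose(result, reference, rtol=rtol, atol=atol)


# Stand-in for a result computed at a different precision (float64 here).
other_precision = (value_a.astype(np.float64) + value_b) * value_c
print(matches_reference(other_precision, expected))  # True
```

In practice, the result returned by computation(value_a, value_b, value_c) would be passed as the first argument to matches_reference.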