Diffstat (limited to 'python/pyarmnn/src')
-rw-r--r--  python/pyarmnn/src/pyarmnn/__init__.py  138
-rw-r--r--  python/pyarmnn/src/pyarmnn/_generated/__init__.py  2
-rw-r--r--  python/pyarmnn/src/pyarmnn/_quantization/__init__.py  4
-rw-r--r--  python/pyarmnn/src/pyarmnn/_quantization/quantize_and_dequantize.py  70
-rw-r--r--  python/pyarmnn/src/pyarmnn/_tensor/__init__.py  6
-rw-r--r--  python/pyarmnn/src/pyarmnn/_tensor/const_tensor.py  159
-rw-r--r--  python/pyarmnn/src/pyarmnn/_tensor/tensor.py  119
-rw-r--r--  python/pyarmnn/src/pyarmnn/_tensor/workload_tensors.py  123
-rw-r--r--  python/pyarmnn/src/pyarmnn/_utilities/__init__.py  4
-rw-r--r--  python/pyarmnn/src/pyarmnn/_utilities/profiling_helper.py  95
-rw-r--r--  python/pyarmnn/src/pyarmnn/_version.py  26
-rw-r--r--  python/pyarmnn/src/pyarmnn/swig/armnn.i  27
-rw-r--r--  python/pyarmnn/src/pyarmnn/swig/armnn_caffeparser.i  103
-rw-r--r--  python/pyarmnn/src/pyarmnn/swig/armnn_onnxparser.i  96
-rw-r--r--  python/pyarmnn/src/pyarmnn/swig/armnn_tfliteparser.i  132
-rw-r--r--  python/pyarmnn/src/pyarmnn/swig/armnn_tfparser.i  102
-rw-r--r--  python/pyarmnn/src/pyarmnn/swig/armnn_version.i  58
-rw-r--r--  python/pyarmnn/src/pyarmnn/swig/modules/armnn_backend.i  66
-rw-r--r--  python/pyarmnn/src/pyarmnn/swig/modules/armnn_descriptors.i  1000
-rw-r--r--  python/pyarmnn/src/pyarmnn/swig/modules/armnn_lstmparam.i  97
-rw-r--r--  python/pyarmnn/src/pyarmnn/swig/modules/armnn_network.i  1159
-rw-r--r--  python/pyarmnn/src/pyarmnn/swig/modules/armnn_profiler.i  82
-rw-r--r--  python/pyarmnn/src/pyarmnn/swig/modules/armnn_runtime.i  254
-rw-r--r--  python/pyarmnn/src/pyarmnn/swig/modules/armnn_tensor.i  313
-rw-r--r--  python/pyarmnn/src/pyarmnn/swig/modules/armnn_types.i  136
-rw-r--r--  python/pyarmnn/src/pyarmnn/swig/modules/armnn_types_utils.i  26
-rw-r--r--  python/pyarmnn/src/pyarmnn/swig/standard_header.i  53
-rw-r--r--  python/pyarmnn/src/pyarmnn/swig/typemaps/network_optimize.i  41
-rw-r--r--  python/pyarmnn/src/pyarmnn/swig/typemaps/permutation_vector.i  52
-rw-r--r--  python/pyarmnn/src/pyarmnn/swig/typemaps/tensor_memory.i  52
-rw-r--r--  python/pyarmnn/src/pyarmnn/swig/typemaps/tensor_shape.i  51
-rw-r--r--  python/pyarmnn/src/pyarmnn/swig/typemaps/vectors.i  235
32 files changed, 4881 insertions, 0 deletions
diff --git a/python/pyarmnn/src/pyarmnn/__init__.py b/python/pyarmnn/src/pyarmnn/__init__.py
new file mode 100644
index 0000000000..c451479614
--- /dev/null
+++ b/python/pyarmnn/src/pyarmnn/__init__.py
@@ -0,0 +1,138 @@
+# Copyright © 2019 Arm Ltd. All rights reserved.
+# SPDX-License-Identifier: MIT
+import inspect
+import sys
+import logging
+
+from ._generated.pyarmnn_version import GetVersion, GetMajorVersion, GetMinorVersion
+
+# Parsers
+try:
+ from ._generated.pyarmnn_caffeparser import ICaffeParser
+except ImportError as err:
+ logger = logging.getLogger(__name__)
+ message = "Your ArmNN library instance does not support Caffe models parser functionality. "
+ logger.warning(message + "Skipped ICaffeParser import.")
+ logger.debug(str(err))
+
+
+ def ICaffeParser():
+ raise RuntimeError(message)
+
+try:
+ from ._generated.pyarmnn_onnxparser import IOnnxParser
+except ImportError as err:
+ logger = logging.getLogger(__name__)
+ message = "Your ArmNN library instance does not support Onnx models parser functionality. "
+ logger.warning(message + "Skipped IOnnxParser import.")
+ logger.debug(str(err))
+
+
+ def IOnnxParser():
+ raise RuntimeError(message)
+
+try:
+ from ._generated.pyarmnn_tfparser import ITfParser
+except ImportError as err:
+ logger = logging.getLogger(__name__)
+ message = "Your ArmNN library instance does not support TF models parser functionality. "
+ logger.warning(message + "Skipped ITfParser import.")
+ logger.debug(str(err))
+
+
+ def ITfParser():
+ raise RuntimeError(message)
+
+try:
+ from ._generated.pyarmnn_tfliteparser import ITfLiteParser
+except ImportError as err:
+ logger = logging.getLogger(__name__)
+ message = "Your ArmNN library instance does not support TF lite models parser functionality. "
+ logger.warning(message + "Skipped ITfLiteParser import.")
+ logger.debug(str(err))
+
+
+ def ITfLiteParser():
+ raise RuntimeError(message)
+
+# Network
+from ._generated.pyarmnn import Optimize, OptimizerOptions, IOptimizedNetwork, IInputSlot, \
+ IOutputSlot, IConnectableLayer, INetwork
+
+# Backend
+from ._generated.pyarmnn import BackendId
+from ._generated.pyarmnn import IDeviceSpec
+
+# Tensors
+from ._generated.pyarmnn import TensorInfo, TensorShape
+
+# Runtime
+from ._generated.pyarmnn import IRuntime, CreationOptions, INetworkProperties
+
+# Profiler
+from ._generated.pyarmnn import IProfiler
+
+# Types
+from ._generated.pyarmnn import DataType_Float32, DataType_QuantisedAsymm8, DataType_Signed32, \
+ DataType_QuantisedSymm16, DataType_Float16, DataType_QuantizedSymm8PerAxis, \
+ DataType_QuantisedSymm8, DataType_Boolean
+from ._generated.pyarmnn import DataLayout_NCHW, DataLayout_NHWC
+
+from ._generated.pyarmnn import ActivationFunction_Abs, ActivationFunction_BoundedReLu, ActivationFunction_LeakyReLu, \
+ ActivationFunction_Linear, ActivationFunction_ReLu, ActivationFunction_Sigmoid, ActivationFunction_SoftReLu, \
+ ActivationFunction_Sqrt, ActivationFunction_Square, ActivationFunction_TanH, ActivationDescriptor
+from ._generated.pyarmnn import ArgMinMaxFunction_Max, ArgMinMaxFunction_Min, ArgMinMaxDescriptor
+from ._generated.pyarmnn import BatchNormalizationDescriptor, BatchToSpaceNdDescriptor
+from ._generated.pyarmnn import ComparisonDescriptor, ComparisonOperation_Equal, ComparisonOperation_Greater, \
+ ComparisonOperation_GreaterOrEqual, ComparisonOperation_Less, \
+ ComparisonOperation_LessOrEqual, ComparisonOperation_NotEqual
+from ._generated.pyarmnn import Convolution2dDescriptor, DepthToSpaceDescriptor, DepthwiseConvolution2dDescriptor, \
+ DetectionPostProcessDescriptor, FakeQuantizationDescriptor, FullyConnectedDescriptor, \
+ InstanceNormalizationDescriptor, LstmDescriptor, L2NormalizationDescriptor, MeanDescriptor
+from ._generated.pyarmnn import NormalizationAlgorithmChannel_Across, NormalizationAlgorithmChannel_Within, \
+ NormalizationAlgorithmMethod_LocalBrightness, NormalizationAlgorithmMethod_LocalContrast, NormalizationDescriptor
+from ._generated.pyarmnn import PadDescriptor
+from ._generated.pyarmnn import PermutationVector, PermuteDescriptor
+from ._generated.pyarmnn import OutputShapeRounding_Ceiling, OutputShapeRounding_Floor, \
+ PaddingMethod_Exclude, PaddingMethod_IgnoreValue, PoolingAlgorithm_Average, PoolingAlgorithm_L2, \
+ PoolingAlgorithm_Max, Pooling2dDescriptor
+from ._generated.pyarmnn import ResizeMethod_Bilinear, ResizeMethod_NearestNeighbor, ResizeDescriptor, \
+ ReshapeDescriptor, SliceDescriptor, SpaceToBatchNdDescriptor, SpaceToDepthDescriptor, StandInDescriptor, \
+ StackDescriptor, StridedSliceDescriptor, SoftmaxDescriptor, TransposeConvolution2dDescriptor, \
+ SplitterDescriptor
+from ._generated.pyarmnn import ConcatDescriptor, CreateDescriptorForConcatenation
+
+from ._generated.pyarmnn import LstmInputParams
+
+# Public API
+# Quantization
+from ._quantization.quantize_and_dequantize import quantize, dequantize
+
+# Tensor
+from ._tensor.tensor import Tensor
+from ._tensor.const_tensor import ConstTensor
+from ._tensor.workload_tensors import make_input_tensors, make_output_tensors, workload_tensors_to_ndarray
+
+# Utilities
+from ._utilities.profiling_helper import ProfilerData, get_profiling_data
+
+from ._version import __version__, __arm_ml_version__
+
+ARMNN_VERSION = GetVersion()
+
+
+def __check_version():
+ from ._version import check_armnn_version
+ check_armnn_version(ARMNN_VERSION)
+
+
+__check_version()
+
+__all__ = []
+
+__private_api_names = ['__check_version']
+
+for name, obj in inspect.getmembers(sys.modules[__name__]):
+ if inspect.isclass(obj) or inspect.isfunction(obj):
+ if name not in __private_api_names:
+ __all__.append(name)
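The guarded imports above make a missing parser fail at first use rather than at import time. A minimal sketch of that behaviour, assuming an Arm NN build without the Caffe parser:

>>> import pyarmnn as ann          # logs "Skipped ICaffeParser import." as a warning
>>> ann.ARMNN_VERSION              # version string reported by the installed native library
>>> parser = ann.ICaffeParser()    # the stub raises RuntimeError only when called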
diff --git a/python/pyarmnn/src/pyarmnn/_generated/__init__.py b/python/pyarmnn/src/pyarmnn/_generated/__init__.py
new file mode 100644
index 0000000000..18b11630d1
--- /dev/null
+++ b/python/pyarmnn/src/pyarmnn/_generated/__init__.py
@@ -0,0 +1,2 @@
+# Copyright © 2019 Arm Ltd. All rights reserved.
+# SPDX-License-Identifier: MIT
diff --git a/python/pyarmnn/src/pyarmnn/_quantization/__init__.py b/python/pyarmnn/src/pyarmnn/_quantization/__init__.py
new file mode 100644
index 0000000000..fd9bbf1db7
--- /dev/null
+++ b/python/pyarmnn/src/pyarmnn/_quantization/__init__.py
@@ -0,0 +1,4 @@
+# Copyright © 2019 Arm Ltd. All rights reserved.
+# SPDX-License-Identifier: MIT
+
+from .quantize_and_dequantize import quantize, dequantize
diff --git a/python/pyarmnn/src/pyarmnn/_quantization/quantize_and_dequantize.py b/python/pyarmnn/src/pyarmnn/_quantization/quantize_and_dequantize.py
new file mode 100644
index 0000000000..7f06b43bc8
--- /dev/null
+++ b/python/pyarmnn/src/pyarmnn/_quantization/quantize_and_dequantize.py
@@ -0,0 +1,70 @@
+# Copyright © 2019 Arm Ltd. All rights reserved.
+# SPDX-License-Identifier: MIT
+
+from .._generated.pyarmnn import Quantize_uint8_t, Quantize_int16_t, Quantize_int32_t, \
+ Dequantize_uint8_t, Dequantize_int16_t, Dequantize_int32_t
+
+__dtype_to_quantize_function = {
+ 'uint8': Quantize_uint8_t,
+ 'int16': Quantize_int16_t,
+ 'int32': Quantize_int32_t
+ }
+
+__dtype_to_dequantize_function = {
+ 'uint8': ((0, 255), Dequantize_uint8_t),
+ 'int16': ((-32768, 32767), Dequantize_int16_t),
+ 'int32': ((-2147483648, 2147483647), Dequantize_int32_t)
+ }
+
+
+def quantize(value: float, scale: float, offset: int, target_dtype: str) -> int:
+ """Quantize given value to the given target datatype using Arm NN.
+
+ This function can be used to convert a 32-bit floating point value into 16/32-bit
+ integer or 8-bit unsigned integer values.
+
+ Args:
+ value (float): The value to be quantized.
+ scale (float): A numeric constant that the value is multiplied by.
+ offset (int): A 'zero-point' used to 'shift' the integer range.
+ target_dtype (str): The target data type. Supported values: 'uint8', 'int16', 'int32'.
+
+ Returns:
+ int: A quantized 8-bit unsigned integer value or 16/32-bit integer value.
+ """
+
+ if target_dtype not in __dtype_to_quantize_function:
+ raise ValueError("""Unexpected target datatype {} given.
+ Armnn currently supports quantization to {} values.""".format(target_dtype, list(__dtype_to_quantize_function.keys())))
+
+ return __dtype_to_quantize_function[target_dtype](float(value), scale, offset)
+
+
+def dequantize(value: int, scale: float, offset: float, from_dtype: str) -> float:
+ """Dequantize given value from the given datatype using Armnn.
+
+ This function can be used to convert an 8-bit unsigned integer value or 16/32-bit
+ integer value into a 32-bit floating point value. Typically used when decoding an
+ output value from an output tensor on a quantized model.
+
+ Args:
+ value (int): The value to be dequantized. Value could be numpy numeric data type.
+ scale (float): A numeric constant that the value is multiplied by.
+ offset (float): A 'zero-point' used to 'shift' the integer range.
+ from_dtype (str): The data type 'value' represents. Supported values: 'uint8', 'int16', 'int32'.
+
+ Returns:
+ float: A dequantized 32-bit floating-point value.
+ """
+
+ # specifies which function to use with given datatype and the value range for that data type.
+ if from_dtype not in __dtype_to_dequantize_function:
+ raise ValueError("""Unexpected value datatype {} given.
+ Armnn currently supports dequantization from {} values.""".format(from_dtype, list(__dtype_to_dequantize_function.keys())))
+
+ input_range = __dtype_to_dequantize_function[from_dtype][0]
+
+ if not input_range[0] <= value <= input_range[1]:
+ raise ValueError('Value is not within range of the given datatype {}'.format(from_dtype))
+
+ return __dtype_to_dequantize_function[from_dtype][1](int(value), scale, offset)
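For reference, a sketch of the round trip the two functions above implement; the scale and offset values are illustrative only:

>>> import pyarmnn as ann
>>> q = ann.quantize(0.5, scale=0.00390625, offset=0, target_dtype='uint8')  # -> 128 here
>>> ann.dequantize(q, scale=0.00390625, offset=0, from_dtype='uint8')        # gives back 0.5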
diff --git a/python/pyarmnn/src/pyarmnn/_tensor/__init__.py b/python/pyarmnn/src/pyarmnn/_tensor/__init__.py
new file mode 100644
index 0000000000..0c928785b4
--- /dev/null
+++ b/python/pyarmnn/src/pyarmnn/_tensor/__init__.py
@@ -0,0 +1,6 @@
+# Copyright © 2019 Arm Ltd. All rights reserved.
+# SPDX-License-Identifier: MIT
+
+from .const_tensor import ConstTensor
+from .tensor import Tensor
+from .workload_tensors import make_input_tensors, make_output_tensors, workload_tensors_to_ndarray
diff --git a/python/pyarmnn/src/pyarmnn/_tensor/const_tensor.py b/python/pyarmnn/src/pyarmnn/_tensor/const_tensor.py
new file mode 100644
index 0000000000..9735d7a63b
--- /dev/null
+++ b/python/pyarmnn/src/pyarmnn/_tensor/const_tensor.py
@@ -0,0 +1,159 @@
+# Copyright © 2019 Arm Ltd. All rights reserved.
+# SPDX-License-Identifier: MIT
+import numpy as np
+
+from .._generated.pyarmnn import DataType_QuantisedAsymm8, DataType_QuantisedSymm16, DataType_Signed32, \
+ DataType_Float32, DataType_Float16
+from .._generated.pyarmnn import ConstTensor as AnnConstTensor, TensorInfo, Tensor
+
+
+class ConstTensor(AnnConstTensor):
+ """Creates a PyArmNN ConstTensor object.
+
+ A ConstTensor is a Tensor with an immutable data store. Typically, a ConstTensor
+ is used to input data into a network when running inference.
+
+ This class overrides the swig-generated ConstTensor class. The aim is
+ to provide an easy-to-use public API for ConstTensor objects.
+
+ """
+
+ def __init__(self, *args):
+ """
+ Supported tensor data types:
+ DataType_QuantisedAsymm8,
+ DataType_QuantisedSymm16,
+ DataType_Signed32,
+ DataType_Float32,
+ DataType_Float16
+
+ Examples:
+ Create empty ConstTensor
+ >>> import pyarmnn as ann
+ >>> ann.ConstTensor()
+
+ Create ConstTensor given tensor info and input data
+ >>> input_data = ... # numpy array
+ >>> ann.ConstTensor(ann.TensorInfo(...), input_data)
+
+ Create ConstTensor from another ConstTensor i.e. copy ConstTensor
+ >>> ann.ConstTensor(ann.ConstTensor())
+
+ Create ConstTensor from tensor
+ >>> ann.ConstTensor(ann.Tensor())
+
+ Args:
+ tensor (Tensor, optional): Create a ConstTensor from a Tensor.
+ const_tensor (ConstTensor, optional): Create a ConstTensor from a ConstTensor i.e. copy.
+ tensor_info (TensorInfo, optional): Tensor information.
+ input_data (ndarray): Numpy array. The numpy array will be transformed to a
+ buffer according to type returned by `TensorInfo.GetDataType`.
+ Input data values type must correspond to data type returned by
+ `TensorInfo.GetDataType`.
+
+ Raises:
+ TypeError: Unsupported input data type.
+ ValueError: Unsupported tensor data type or incorrect input data size.
+ """
+ self.__memory_area = None
+
+ # TensorInfo as first argument and numpy array as second
+ if len(args) > 1 and isinstance(args[0], TensorInfo):
+ if isinstance(args[1], np.ndarray):
+ self.__create_memory_area(args[0].GetDataType(), args[0].GetNumBytes(), args[0].GetNumElements(),
+ args[1])
+ super().__init__(args[0], self.__memory_area.data)
+ else:
+ raise TypeError('Data must be provided as a numpy array.')
+
+ # copy constructor - reference to memory area is passed from copied const
+ # tensor and armnn's copy constructor is called
+ elif len(args) > 0 and isinstance(args[0], (ConstTensor, Tensor)):
+ self.__memory_area = args[0].get_memory_area()
+ super().__init__(args[0])
+
+ # empty tensor
+ elif len(args) == 0:
+ super().__init__()
+
+ else:
+ raise ValueError('Incorrect number of arguments or type of arguments provided to create Const Tensor.')
+
+ def __copy__(self) -> 'ConstTensor':
+ """ Make copy of a const tensor.
+
+ Make const tensor copyable using the python copy operation.
+
+ Note:
+ The tensor memory area is NOT copied. Instead, the new tensor maintains a
+ reference to the same memory area as the old tensor.
+
+ Example:
+ Copy empty tensor
+ >>> from copy import copy
+ >>> import pyarmnn as ann
+ >>> tensor = ann.ConstTensor()
+ >>> copied_tensor = copy(tensor)
+
+ Returns:
+ ConstTensor: a copy of the tensor object provided.
+
+ """
+ return ConstTensor(self)
+
+ @staticmethod
+ def __check_size(data: np.ndarray, num_bytes: int, num_elements: int):
+ """ Check the size of the input data against the number of bytes provided by tensor info.
+
+ Args:
+ data (ndarray): Input data.
+ num_bytes (int): Number of bytes required by tensor info.
+ num_elements (int): Number of elements required by tensor info.
+
+ Raises:
+ ValueError: number of bytes in input data does not match tensor info.
+
+ """
+ size_in_bytes = data.nbytes
+ elements = data.size
+
+ if size_in_bytes != num_bytes:
+ raise ValueError(
+ "ConstTensor requires {} bytes, {} provided. "
+ "Is your input array data type ({}) aligned with TensorInfo?".format(num_bytes, size_in_bytes,
+ data.dtype))
+ elif elements != num_elements:
+ raise ValueError("ConstTensor requires {} elements, {} provided.".format(num_elements, elements))
+
+ def __create_memory_area(self, data_type: int, num_bytes: int, num_elements: int, data: np.ndarray):
+ """ Create the memory area used by the tensor to output its results.
+
+ Args:
+ data_type (int): The type of data that will be stored in the memory area.
+ See DataType_*.
+ num_bytes (int): Determines the size of the memory area that will be created.
+ num_elements (int): Determines number of elements in memory area.
+ data (ndarray): Input data as numpy array.
+
+ """
+ np_data_type_mapping = {DataType_QuantisedAsymm8: np.uint8,
+ DataType_Float32: np.float32,
+ DataType_QuantisedSymm16: np.int16,
+ DataType_Signed32: np.int32,
+ DataType_Float16: np.float16}
+
+ if data_type not in np_data_type_mapping:
+ raise ValueError("The data type provided for this Tensor is not supported: {}".format(data_type))
+
+ self.__check_size(data, num_bytes, num_elements)
+ self.__memory_area = data
+ self.__memory_area.flags.writeable = False
+
+ def get_memory_area(self) -> np.ndarray:
+ """ Get values that are stored by the tensor.
+
+ Returns:
+ ndarray: Tensor data (as numpy array).
+
+ """
+ return self.__memory_area
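A short usage sketch for the ConstTensor wrapper; the TensorInfo constructor arguments are an assumption based on the armnn_tensor.i bindings, which are not part of this section:

>>> import numpy as np
>>> import pyarmnn as ann
>>> info = ann.TensorInfo(ann.TensorShape((1, 2, 2, 1)), ann.DataType_Float32)  # assumed ctor
>>> data = np.ones((1, 2, 2, 1), dtype=np.float32)  # dtype must match the TensorInfo data type
>>> const_tensor = ann.ConstTensor(info, data)
>>> const_tensor.get_memory_area().flags.writeable  # the wrapper freezes the input array
False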
diff --git a/python/pyarmnn/src/pyarmnn/_tensor/tensor.py b/python/pyarmnn/src/pyarmnn/_tensor/tensor.py
new file mode 100644
index 0000000000..5906b6bae6
--- /dev/null
+++ b/python/pyarmnn/src/pyarmnn/_tensor/tensor.py
@@ -0,0 +1,119 @@
+# Copyright © 2019 Arm Ltd. All rights reserved.
+# SPDX-License-Identifier: MIT
+import numpy as np
+
+from .._generated.pyarmnn import Tensor as annTensor, TensorInfo, DataType_QuantisedAsymm8, \
+ DataType_Float32, DataType_QuantisedSymm16, DataType_Signed32, DataType_Float16
+
+
+class Tensor(annTensor):
+ """pyArmnn Tensor object
+
+ This class overrides the swig generated Tensor class. The aim of
+ this is to create an easy to use public api for the Tensor object.
+
+ Memory is allocated and managed by this class, avoiding the need to manage
+ a separate memory area for the tensor compared to the swig generated api.
+
+ """
+
+ def __init__(self, *args):
+ """ Create Tensor object.
+
+ Supported tensor data types:
+ DataType_QuantisedAsymm8,
+ DataType_QuantisedSymm16,
+ DataType_Signed32,
+ DataType_Float32,
+ DataType_Float16
+
+ Examples:
+ Create an empty tensor
+ >>> import pyarmnn as ann
+ >>> ann.Tensor()
+
+ Create tensor given tensor information
+ >>> ann.Tensor(ann.TensorInfo(...))
+
+ Create tensor from another tensor i.e. copy a tensor
+ >>> ann.Tensor(ann.Tensor())
+
+ Args:
+ tensor(Tensor, optional): Create Tensor from a Tensor i.e. copy.
+ tensor_info (TensorInfo, optional): Tensor information.
+
+ Raises:
+ TypeError: unsupported input data type.
+ ValueError: appropriate constructor could not be found with provided arguments.
+
+ """
+ self.__memory_area = None
+
+ # TensorInfo as first argument, we need to create memory area manually
+ if len(args) > 0 and isinstance(args[0], TensorInfo):
+ self.__create_memory_area(args[0].GetDataType(), args[0].GetNumElements())
+ super().__init__(args[0], self.__memory_area.data)
+
+ # copy constructor - reference to memory area is passed from copied tensor
+ # and armnn's copy constructor is called
+ elif len(args) > 0 and isinstance(args[0], Tensor):
+ self.__memory_area = args[0].get_memory_area()
+ super().__init__(args[0])
+
+ # empty constructor
+ elif len(args) == 0:
+ super().__init__()
+
+ else:
+ raise ValueError('Incorrect number of arguments or type of arguments provided to create Tensor.')
+
+ def __copy__(self) -> 'Tensor':
+ """ Make copy of a tensor.
+
+ Make tensor copyable using the python copy operation.
+
+ Note:
+ The tensor memory area is NOT copied. Instead, the new tensor maintains a
+ reference to the same memory area as the old tensor.
+
+ Example:
+ Copy empty tensor
+ >>> from copy import copy
+ >>> import pyarmnn as ann
+ >>> tensor = ann.Tensor()
+ >>> copied_tensor = copy(tensor)
+
+ Returns:
+ Tensor: a copy of the tensor object provided.
+
+ """
+ return Tensor(self)
+
+ def __create_memory_area(self, data_type: int, num_elements: int):
+ """ Create the memory area used by the tensor to output its results.
+
+ Args:
+ data_type (int): The type of data that will be stored in the memory area.
+ See DataType_*.
+ num_elements (int): Determines the size of the memory area that will be created.
+
+ """
+ np_data_type_mapping = {DataType_QuantisedAsymm8: np.uint8,
+ DataType_Float32: np.float32,
+ DataType_QuantisedSymm16: np.int16,
+ DataType_Signed32: np.int32,
+ DataType_Float16: np.float16}
+
+ if data_type not in np_data_type_mapping:
+ raise ValueError("The data type provided for this Tensor is not supported.")
+
+ self.__memory_area = np.empty(shape=(num_elements,), dtype=np_data_type_mapping[data_type])
+
+ def get_memory_area(self) -> np.ndarray:
+ """ Get values that are stored by the tensor.
+
+ Returns:
+ ndarray : Tensor data (as numpy array).
+
+ """
+ return self.__memory_area
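And the mutable Tensor counterpart, typically used for outputs (TensorInfo construction assumed as above):

>>> out_info = ann.TensorInfo(ann.TensorShape((1, 10)), ann.DataType_Float32)  # assumed ctor
>>> out_tensor = ann.Tensor(out_info)          # the wrapper allocates the backing numpy array
>>> out_tensor.get_memory_area().shape         # flat array with GetNumElements() entries
(10,)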
diff --git a/python/pyarmnn/src/pyarmnn/_tensor/workload_tensors.py b/python/pyarmnn/src/pyarmnn/_tensor/workload_tensors.py
new file mode 100644
index 0000000000..e345a1a5d4
--- /dev/null
+++ b/python/pyarmnn/src/pyarmnn/_tensor/workload_tensors.py
@@ -0,0 +1,123 @@
+# Copyright © 2019 Arm Ltd. All rights reserved.
+# SPDX-License-Identifier: MIT
+"""
+This file contains functions relating to WorkloadTensors.
+WorkloadTensors are the inputTensors and outputTensors that are consumed by IRuntime.EnqueueWorkload.
+"""
+from typing import Union, List, Tuple
+
+import numpy as np
+
+from .tensor import Tensor
+from .const_tensor import ConstTensor
+
+
+def make_input_tensors(inputs_binding_info: List[Tuple],
+ input_data: List[np.ndarray]) -> List[Tuple[int, ConstTensor]]:
+ """Returns `inputTensors` to be used with `IRuntime.EnqueueWorkload`.
+
+ This is the primary function to call when you want to produce `inputTensors` for `IRuntime.EnqueueWorkload`.
+ The output is a list of tuples containing ConstTensors with a corresponding input tensor id.
+ The output should be used directly with `IRuntime.EnqueueWorkload`.
+ This function works for single or multiple input data and binding information.
+
+ Examples:
+ Creating inputTensors.
+ >>> import pyarmnn as ann
+ >>> import numpy as np
+ >>>
+ >>> parser = ann.ITfLiteParser()
+ >>> ...
+ >>> example_image = np.array(...)
+ >>> input_binding_info = parser.GetNetworkInputBindingInfo(...)
+ >>>
+ >>> input_tensors = ann.make_input_tensors([input_binding_info], [example_image])
+
+ Args:
+ inputs_binding_info (list of tuples): (int, `TensorInfo`) Binding information for input tensors obtained from `GetNetworkInputBindingInfo`.
+ input_data (list of ndarray): Tensor data to be used for inference.
+
+ Returns:
+ list: `inputTensors` - A list of tuples (`int` , `ConstTensor`).
+
+
+ Raises:
+ ValueError: If the lengths of `inputs_binding_info` and `input_data` are not the same.
+ """
+ if len(inputs_binding_info) != len(input_data):
+ raise ValueError("Length of 'inputs_binding_info' does not match length of 'input_data'")
+
+ input_tensors = []
+
+ for in_bind_info, in_data in zip(inputs_binding_info, input_data):
+ in_tensor_id = in_bind_info[0]
+ in_tensor_info = in_bind_info[1]
+ input_tensors.append((in_tensor_id, ConstTensor(in_tensor_info, in_data)))
+
+ return input_tensors
+
+
+def make_output_tensors(outputs_binding_info: List[Tuple]) -> List[Tuple[int, Tensor]]:
+ """Returns `outputTensors` to be used with `IRuntime.EnqueueWorkload`.
+
+ This is the primary function to call when you want to produce `outputTensors` for `IRuntime.EnqueueWorkload`.
+ The output is a list of tuples containing Tensors with a corresponding output tensor id.
+ The output should be used directly with `IRuntime.EnqueueWorkload`.
+
+ Examples:
+ Creating outputTensors.
+ >>> import pyarmnn as ann
+ >>>
+ >>> parser = ann.ITfLiteParser()
+ >>> ...
+ >>> output_binding_info = parser.GetNetworkOutputBindingInfo(...)
+ >>>
+ >>> output_tensors = ann.make_output_tensors([output_binding_info])
+
+ Args:
+ outputs_binding_info (list of tuples): (int, `TensorInfo`) Binding information for output tensors obtained from `GetNetworkOutputBindingInfo`.
+
+ Returns:
+ list: `outputTensors` - A list of tuples (`int`, `Tensor`).
+ """
+ output_tensors = []
+
+ for out_bind_info in outputs_binding_info:
+ out_tensor_id = out_bind_info[0]
+ out_tensor_info = out_bind_info[1]
+ output_tensors.append((out_tensor_id, Tensor(out_tensor_info)))
+
+ return output_tensors
+
+
+def workload_tensors_to_ndarray(workload_tensors: List[Tuple[int, Union[Tensor, ConstTensor]]]) -> List[np.ndarray]:
+ """Returns a list of the underlying tensor data as ndarrays from `inputTensors` or `outputTensors`.
+
+ We refer to `inputTensors` and `outputTensors` as workload tensors because
+ they are used with `IRuntime.EnqueueWorkload`.
+ Although this function can be used on either `inputTensors` or `outputTensors`, the main use of this function
+ is to collect results from `outputTensors` after `IRuntime.EnqueueWorkload` has been called.
+
+ Examples:
+ Getting results after inference.
+ >>> import pyarmnn as ann
+ >>>
+ >>> ...
+ >>> runtime = ann.IRuntime(...)
+ >>> ...
+ >>> runtime.EnqueueWorkload(net_id, input_tensors, output_tensors)
+ >>>
+ >>> inference_results = ann.workload_tensors_to_ndarray(output_tensors)
+
+ Args:
+ workload_tensors (inputTensors or outputTensors): `inputTensors` or `outputTensors` to get data from.
+
+ Returns:
+ list: List of `ndarrays` for the underlying tensor data from given `inputTensors` or `outputTensors`.
+ """
+ arrays = []
+ for index, (_, tensor) in enumerate(workload_tensors):
+ arrays.append(tensor.get_memory_area())
+ print("Workload tensor {} shape: {}".format(index, tensor.GetShape()))
+
+ return arrays
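Putting the three helpers together, a sketch of a typical inference call; parser, runtime and network setup are elided, and the binding-info values come from parser calls like those bound later in this patch:

>>> input_tensors = ann.make_input_tensors([input_binding_info], [input_data])
>>> output_tensors = ann.make_output_tensors([output_binding_info])
>>> runtime.EnqueueWorkload(net_id, input_tensors, output_tensors)
>>> results = ann.workload_tensors_to_ndarray(output_tensors)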
diff --git a/python/pyarmnn/src/pyarmnn/_utilities/__init__.py b/python/pyarmnn/src/pyarmnn/_utilities/__init__.py
new file mode 100644
index 0000000000..e60fae0880
--- /dev/null
+++ b/python/pyarmnn/src/pyarmnn/_utilities/__init__.py
@@ -0,0 +1,4 @@
+# Copyright © 2019 Arm Ltd. All rights reserved.
+# SPDX-License-Identifier: MIT
+
+from .profiling_helper import ProfilerData, get_profiling_data
diff --git a/python/pyarmnn/src/pyarmnn/_utilities/profiling_helper.py b/python/pyarmnn/src/pyarmnn/_utilities/profiling_helper.py
new file mode 100644
index 0000000000..d10c28915e
--- /dev/null
+++ b/python/pyarmnn/src/pyarmnn/_utilities/profiling_helper.py
@@ -0,0 +1,95 @@
+# Copyright © 2019 Arm Ltd. All rights reserved.
+# SPDX-License-Identifier: MIT
+import json
+from collections import namedtuple
+
+ProfilerData = namedtuple('ProfilerData', ['inference_data', 'per_workload_execution_data'])
+ProfilerData.__doc__ = """Container to hold the profiling inference data, and the profiling data per workload.
+
+Contains:
+ inference_data (dict): holds end-to-end inference performance data. Keys:
+ 'time_unit' - timer units.
+ 'execution_time' - list of total inference execution times for each inference run.
+ per_workload_execution_data (dict): holds per-operation performance data, keyed by operation name.
+ Each operation has:
+ 'time_unit' - timer units.
+ 'execution_time' - list of total execution times for each inference run.
+ 'backend' - backend used for this operation.
+
+Example:
+
+ >>> data = get_profiling_data(profiler)
+ >>> print(data)
+ >>> ProfilerData(inference_data={'time_unit': 'us',
+ 'execution_time': [8901372.972]},
+ per_workload_execution_data={'CopyMemGeneric_Execute_#3': {'time_unit': 'us',
+ 'execution_time': [28.941],
+ 'backend': 'Unknown'},
+ 'RefConvolution2dWorkload_Execute_#5': {'time_unit': 'us',
+ 'execution_time': [126838.071],
+ 'backend': 'CpuRef'},
+ 'RefDepthwiseConvolution2dWorkload_Execute_#6': {'time_unit': 'us',
+ 'execution_time': [49886.208],
+ 'backend': 'CpuRef'}
+ ...etc
+ }
+ )
+"""
+
+
+def get_profiling_data(profiler: 'IProfiler') -> ProfilerData:
+ """Reads IProfiler object passed in, extracts the relevant data
+ and returns it in a ProfilerData container.
+
+ Args:
+ profiler (IProfiler): The IProfiler object to be parsed.
+
+ Returns:
+ ProfilerData: A container containing the relevant data extracted from the Profiler output.
+ """
+
+ top_level_dict = json.loads(profiler.as_json())
+ armnn_data = top_level_dict["ArmNN"]
+ inference_measurements = armnn_data["inference_measurements_#1"]
+ execution_data = inference_measurements["Execute_#2"]
+
+ workload_data = {}
+ inference_data = {}
+ for exec_key, exec_value in execution_data.items():
+ # Check all items with a type.
+ if "type" in exec_value and exec_value["type"] == "Event":
+ for event_key, event_value in exec_value.items():
+ if event_key.startswith("Wall clock time_#") and event_value["type"] == "Measurement":
+ time_data = __get_wall_clock_times__(event_value)
+ time_data["backend"] = __get_backend(exec_key)
+ workload_data[exec_key] = time_data
+ # This is the total inference time map
+ if exec_key.startswith("Wall clock time_#") and exec_value["type"] == "Measurement":
+ time_data = __get_wall_clock_times__(exec_value)
+ inference_data.update(time_data)
+ return ProfilerData(inference_data=inference_data, per_workload_execution_data=workload_data)
+
+
+def __get_wall_clock_times__(wall_clock_item):
+ execution_times = wall_clock_item["raw"]
+ time_data = {}
+ raw_data = []
+ for time in execution_times:
+ raw_data.append(time)
+ time_data["time_unit"] = wall_clock_item["unit"]
+ time_data["execution_time"] = raw_data
+ return time_data
+
+
+def __get_backend(exec_key):
+ if "ref" in exec_key.lower():
+ return "CpuRef"
+ elif "neon" in exec_key.lower():
+ return "CpuAcc"
+ elif "cl" in exec_key.lower():
+ return "GpuAcc"
+ elif "npu" in exec_key.lower():
+ return "NpuAcc"
+ else:
+ return "Unknown"
+
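A sketch of driving the helper above; it relies only on the profiler exposing as_json(), and how the IProfiler instance is obtained (bound in armnn_profiler.i) is assumed here:

>>> data = ann.get_profiling_data(profiler)            # profiler: an IProfiler instance
>>> data.inference_data['time_unit']                   # e.g. 'us'
>>> list(data.per_workload_execution_data)[:2]         # per-operation entries, keyed by name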
diff --git a/python/pyarmnn/src/pyarmnn/_version.py b/python/pyarmnn/src/pyarmnn/_version.py
new file mode 100644
index 0000000000..2bcb888819
--- /dev/null
+++ b/python/pyarmnn/src/pyarmnn/_version.py
@@ -0,0 +1,26 @@
+# Copyright © 2019 Arm Ltd. All rights reserved.
+# SPDX-License-Identifier: MIT
+import os
+
+version_info = (19, 11, 0)
+
+__dev_version_env = os.getenv("PYARMNN_DEV_VER", "")
+
+if __dev_version_env:
+ __dev_version = "dev0"
+ try:
+ __dev_version = "dev{}".format(int(__dev_version_env))
+ except ValueError:
+ __dev_version = str(__dev_version_env)
+
+ version_info = (*version_info, __dev_version)
+
+__version__ = '.'.join(str(c) for c in version_info)
+__arm_ml_version__ = '2{:03d}{:02d}{:02d}'.format(version_info[0], version_info[1], version_info[2])
+
+
+def check_armnn_version(installed_armnn_version, expected_armnn_version=__arm_ml_version__):
+ expected_armnn_version = expected_armnn_version[:-2] # cut off the two patch-version digits
+ installed_armnn_version = installed_armnn_version[:-2] # cut off the two patch-version digits
+ assert expected_armnn_version == installed_armnn_version, \
+ "Expected ArmNN version is {} but installed ArmNN version is {}".format(expected_armnn_version, installed_armnn_version)
diff --git a/python/pyarmnn/src/pyarmnn/swig/armnn.i b/python/pyarmnn/src/pyarmnn/swig/armnn.i
new file mode 100644
index 0000000000..48e0f2edbb
--- /dev/null
+++ b/python/pyarmnn/src/pyarmnn/swig/armnn.i
@@ -0,0 +1,27 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+%module pyarmnn
+%{
+#define SWIG_FILE_WITH_INIT
+#include "armnn/Types.hpp"
+%}
+
+//typemap definitions and other common stuff
+%include "standard_header.i"
+
+//armnn api submodules
+%include "modules/armnn_backend.i"
+%include "modules/armnn_types.i"
+%include "modules/armnn_descriptors.i"
+%include "modules/armnn_lstmparam.i"
+%include "modules/armnn_network.i"
+%include "modules/armnn_profiler.i"
+%include "modules/armnn_runtime.i"
+%include "modules/armnn_tensor.i"
+%include "modules/armnn_types_utils.i"
+
+// Clear exception typemap.
+%exception;
+
diff --git a/python/pyarmnn/src/pyarmnn/swig/armnn_caffeparser.i b/python/pyarmnn/src/pyarmnn/swig/armnn_caffeparser.i
new file mode 100644
index 0000000000..fa1a71fd9f
--- /dev/null
+++ b/python/pyarmnn/src/pyarmnn/swig/armnn_caffeparser.i
@@ -0,0 +1,103 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+%module pyarmnn_caffeparser
+%{
+#define SWIG_FILE_WITH_INIT
+#include "armnnCaffeParser/ICaffeParser.hpp"
+#include "armnn/INetwork.hpp"
+%}
+
+//typemap definitions and other common stuff
+%include "standard_header.i"
+
+namespace std {
+ %template(BindingPointInfo) pair<int, armnn::TensorInfo>;
+ %template(MapStringTensorShape) map<std::string, armnn::TensorShape>;
+ %template(StringVector) vector<string>;
+}
+
+namespace armnnCaffeParser
+{
+
+%feature("docstring",
+"
+Interface for creating a parser object using Caffe (http://caffe.berkeleyvision.org/) caffemodel files.
+
+Parsers are used to automatically construct Arm NN graphs from model files.
+
+") ICaffeParser;
+
+%nodefaultctor ICaffeParser;
+class ICaffeParser
+{
+public:
+ // Documentation
+ %feature("docstring",
+ "
+ Retrieve binding info (layer id and tensor info) for the network input identified by the given layer name.
+
+ Args:
+ name (str): Name of the input.
+
+ Returns:
+ tuple: (`int`, `TensorInfo`)
+ ") GetNetworkInputBindingInfo;
+
+ %feature("docstring",
+ "
+ Retrieve binding info (layer id and tensor info) for the network output identified by the given layer name.
+
+ Args:
+ name (str): Name of the output.
+
+ Returns:
+ tuple: (`int`, `TensorInfo`)
+ ") GetNetworkOutputBindingInfo;
+
+ std::pair<int, armnn::TensorInfo> GetNetworkInputBindingInfo(const std::string& name);
+ std::pair<int, armnn::TensorInfo> GetNetworkOutputBindingInfo(const std::string& name);
+};
+
+%extend ICaffeParser {
+ // This does not replace the default constructor of the Arm NN class. It tells swig to create a custom
+ // __init__ method for the ICaffeParser python object that uses the static factory method to do the job.
+
+ ICaffeParser() {
+ return armnnCaffeParser::ICaffeParser::CreateRaw();
+ }
+
+ // The following does not replace the real destructor of the Arm NN class.
+ // It creates a function that is called when the swig object goes out of scope to clean up resources,
+ // so the user doesn't need to call ICaffeParser::Destroy themselves.
+ // `$self` is a pointer to the extracted ArmNN ICaffeParser object.
+
+ ~ICaffeParser() {
+ armnnCaffeParser::ICaffeParser::Destroy($self);
+ }
+
+ %feature("docstring",
+ "
+ Create the network from a Caffe caffemodel binary file on disk.
+
+ Args:
+ graphFile (str): Path to the caffe model file to be parsed.
+ inputShapes (dict): A dict mapping input names to `TensorShape` information for the network.
+ requestedOutputs (list of str): A list of the output tensor names.
+
+ Returns:
+ INetwork: INetwork object for the parsed Caffe model.
+ ") CreateNetworkFromBinaryFile;
+
+ %newobject CreateNetworkFromBinaryFile;
+ armnn::INetwork* CreateNetworkFromBinaryFile(const char* graphFile,
+ const std::map<std::string, armnn::TensorShape>& inputShapes,
+ const std::vector<std::string>& requestedOutputs) {
+ return $self->CreateNetworkFromBinaryFile(graphFile, inputShapes, requestedOutputs).release();
+ }
+}
+}
+
+// Clear exception typemap.
+%exception;
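From Python, the generated ICaffeParser binding is used roughly as follows; the model path, input shape and layer names are placeholders, and passing a plain dict for the inputShapes map assumes swig's usual std::map wrapping:

>>> parser = ann.ICaffeParser()
>>> network = parser.CreateNetworkFromBinaryFile('model.caffemodel',
...                                              {'data': ann.TensorShape((1, 3, 224, 224))},
...                                              ['prob'])
>>> input_binding = parser.GetNetworkInputBindingInfo('data')    # (layer id, TensorInfo)
>>> output_binding = parser.GetNetworkOutputBindingInfo('prob')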
diff --git a/python/pyarmnn/src/pyarmnn/swig/armnn_onnxparser.i b/python/pyarmnn/src/pyarmnn/swig/armnn_onnxparser.i
new file mode 100644
index 0000000000..e72a425374
--- /dev/null
+++ b/python/pyarmnn/src/pyarmnn/swig/armnn_onnxparser.i
@@ -0,0 +1,96 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+%module pyarmnn_onnxparser
+%{
+#define SWIG_FILE_WITH_INIT
+#include "armnnOnnxParser/IOnnxParser.hpp"
+#include "armnn/INetwork.hpp"
+%}
+
+//typemap definitions and other common stuff
+%include "standard_header.i"
+
+namespace std {
+ %template(BindingPointInfo) pair<int, armnn::TensorInfo>;
+ %template(MapStringTensorShape) map<std::string, armnn::TensorShape>;
+ %template(StringVector) vector<string>;
+}
+
+namespace armnnOnnxParser
+{
+%feature("docstring",
+"
+Interface for creating a parser object using ONNX (https://onnx.ai/) onnx files.
+
+Parsers are used to automatically construct Arm NN graphs from model files.
+
+") IOnnxParser;
+
+%nodefaultctor IOnnxParser;
+class IOnnxParser
+{
+public:
+ %feature("docstring",
+ "
+ Retrieve binding info (layer id and tensor info) for the network input identified by the given layer name.
+
+ Args:
+ name (string): Name of the input node.
+
+ Returns:
+ tuple: (`int`, `TensorInfo`)
+ ") GetNetworkInputBindingInfo;
+ std::pair<int, armnn::TensorInfo> GetNetworkInputBindingInfo(const std::string& name);
+
+ %feature("docstring",
+ "
+ Retrieve binding info (layer id and tensor info) for the network output identified by the given layer name.
+
+ Args:
+ name (string): Name of the output node.
+
+ Returns:
+ tuple: (`int`, `TensorInfo`)
+ ") GetNetworkOutputBindingInfo;
+ std::pair<int, armnn::TensorInfo> GetNetworkOutputBindingInfo(const std::string& name);
+};
+
+%extend IOnnxParser {
+ // This does not replace the default constructor of the Arm NN class. It tells swig to create a custom
+ // __init__ method for the IOnnxParser python object that uses the static factory method to do the job.
+ IOnnxParser() {
+ return armnnOnnxParser::IOnnxParser::CreateRaw();
+ }
+
+ // The following does not replace the real destructor of the Arm NN class.
+ // It creates a function that is called when the swig object goes out of scope to clean up resources,
+ // so the user doesn't need to call IOnnxParser::Destroy themselves.
+ // `$self` is a pointer to the extracted ArmNN IOnnxParser object.
+ ~IOnnxParser() {
+ armnnOnnxParser::IOnnxParser::Destroy($self);
+ }
+
+ %feature("docstring",
+ "
+ Create the network from a binary file on disk.
+
+ Args:
+ graphFile (str): Path to the onnx model to be parsed.
+
+ Returns:
+ INetwork: Parsed network.
+
+ Raises:
+ RuntimeError: If model file was not found.
+ ") CreateNetworkFromBinaryFile;
+ %newobject CreateNetworkFromBinaryFile;
+ armnn::INetwork* CreateNetworkFromBinaryFile(const char* graphFile) {
+ return $self->CreateNetworkFromBinaryFile(graphFile).release();
+ }
+}
+
+}
+// Clear exception typemap.
+%exception;
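The ONNX parser follows the same pattern but needs only the model path; file and node names below are placeholders:

>>> parser = ann.IOnnxParser()
>>> network = parser.CreateNetworkFromBinaryFile('model.onnx')
>>> input_binding = parser.GetNetworkInputBindingInfo('input')
>>> output_binding = parser.GetNetworkOutputBindingInfo('output')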
diff --git a/python/pyarmnn/src/pyarmnn/swig/armnn_tfliteparser.i b/python/pyarmnn/src/pyarmnn/swig/armnn_tfliteparser.i
new file mode 100644
index 0000000000..fbe5fd7720
--- /dev/null
+++ b/python/pyarmnn/src/pyarmnn/swig/armnn_tfliteparser.i
@@ -0,0 +1,132 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+%module pyarmnn_tfliteparser
+%{
+#include "armnnTfLiteParser/ITfLiteParser.hpp"
+#include "armnn/Types.hpp"
+#include "armnn/INetwork.hpp"
+%}
+
+//typemap definitions and other common stuff
+%include "standard_header.i"
+
+namespace std {
+ %template(BindingPointInfo) pair<int, armnn::TensorInfo>;
+ %template(MapStringTensorShape) map<std::string, armnn::TensorShape>;
+ %template(StringVector) vector<string>;
+}
+
+namespace armnnTfLiteParser
+{
+%feature("docstring",
+"
+Interface for creating a parser object using TfLite (https://www.tensorflow.org/lite) tflite files.
+
+Parsers are used to automatically construct Arm NN graphs from model files.
+
+") ITfLiteParser;
+%nodefaultctor ITfLiteParser;
+class ITfLiteParser
+{
+public:
+ %feature("docstring",
+ "
+ Retrieve binding info (layer id and tensor info) for the network input identified by the given layer name and subgraph id.
+ Args:
+ subgraphId (int): The subgraph id.
+ name (str): Name of the input.
+
+ Returns:
+ tuple: (`int`, `TensorInfo`).
+ ") GetNetworkInputBindingInfo;
+ std::pair<int, armnn::TensorInfo> GetNetworkInputBindingInfo(size_t subgraphId, const std::string& name);
+
+ %feature("docstring",
+ "
+ Retrieve binding info (layer id and tensor info) for the network output identified by the given layer name and subgraph id.
+
+ Args:
+ subgraphId (int): The subgraph id.
+ name (str): Name of the output.
+
+ Returns:
+ tuple: (`int`, `TensorInfo`).
+ ") GetNetworkOutputBindingInfo;
+ std::pair<int, armnn::TensorInfo> GetNetworkOutputBindingInfo(size_t subgraphId, const std::string& name);
+
+ %feature("docstring",
+ "
+ Return the number of subgraphs in the parsed model.
+ Returns:
+ int: The number of subgraphs.
+ ") GetSubgraphCount;
+ size_t GetSubgraphCount();
+
+ %feature("docstring",
+ "
+ Return the input tensor names for a given subgraph.
+
+ Args:
+ subgraphId (int): The subgraph id.
+
+ Returns:
+ list: A list of the input tensor names for the given model.
+ ") GetSubgraphInputTensorNames;
+ std::vector<std::string> GetSubgraphInputTensorNames(size_t subgraphId);
+
+ %feature("docstring",
+ "
+ Return the output tensor names for a given subgraph.
+
+ Args:
+ subgraphId (int): The subgraph id.
+
+ Returns:
+ list: A list of the output tensor names for the given model.
+ ") GetSubgraphOutputTensorNames;
+ std::vector<std::string> GetSubgraphOutputTensorNames(size_t subgraphId);
+};
+
+%extend ITfLiteParser {
+// This does not replace the default constructor of the Arm NN class. It tells swig to create a custom
+// __init__ method for the ITfLiteParser python object that uses the static factory method to do the job.
+
+ ITfLiteParser() {
+ return armnnTfLiteParser::ITfLiteParser::CreateRaw();
+ }
+
+// The following does not replace the real destructor of the Arm NN class.
+// It creates a function that is called when the swig object goes out of scope to clean up resources,
+// so the user doesn't need to call ITfLiteParser::Destroy themselves.
+// `$self` is a pointer to the extracted ArmNN ITfLiteParser object.
+
+ ~ITfLiteParser() {
+ armnnTfLiteParser::ITfLiteParser::Destroy($self);
+ }
+
+ %feature("docstring",
+ "
+ Create the network from a flatbuffers binary file.
+
+ Args:
+ graphFile (str): Path to the tflite model to be parsed.
+
+ Returns:
+ INetwork: Parsed network.
+
+ Raises:
+ RuntimeError: If model file was not found.
+ ") CreateNetworkFromBinaryFile;
+
+ %newobject CreateNetworkFromBinaryFile;
+ armnn::INetwork* CreateNetworkFromBinaryFile(const char* graphFile) {
+ return $self->CreateNetworkFromBinaryFile(graphFile).release();
+ }
+
+}
+
+}
+// Clear exception typemap.
+%exception;
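The TfLite parser adds the subgraph dimension; a sketch with a placeholder model path, querying the tensor names instead of hard-coding them:

>>> parser = ann.ITfLiteParser()
>>> network = parser.CreateNetworkFromBinaryFile('model.tflite')
>>> graph_id = 0                                           # first (often the only) subgraph
>>> input_names = parser.GetSubgraphInputTensorNames(graph_id)
>>> input_binding = parser.GetNetworkInputBindingInfo(graph_id, input_names[0])
>>> output_names = parser.GetSubgraphOutputTensorNames(graph_id)
>>> output_binding = parser.GetNetworkOutputBindingInfo(graph_id, output_names[0])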
diff --git a/python/pyarmnn/src/pyarmnn/swig/armnn_tfparser.i b/python/pyarmnn/src/pyarmnn/swig/armnn_tfparser.i
new file mode 100644
index 0000000000..3438492d26
--- /dev/null
+++ b/python/pyarmnn/src/pyarmnn/swig/armnn_tfparser.i
@@ -0,0 +1,102 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+%module pyarmnn_tfparser
+%{
+#define SWIG_FILE_WITH_INIT
+#include "armnnTfParser/ITfParser.hpp"
+#include "armnn/INetwork.hpp"
+%}
+
+//typemap definitions and other common stuff
+%include "standard_header.i"
+
+namespace std {
+ %template(BindingPointInfo) pair<int, armnn::TensorInfo>;
+ %template(MapStringTensorShape) map<std::string, armnn::TensorShape>;
+ %template(StringVector) vector<string>;
+}
+
+namespace armnnTfParser
+{
+%feature("docstring",
+"
+Interface for creating a parser object using TensorFlow (https://www.tensorflow.org/) frozen pb files.
+
+Parsers are used to automatically construct Arm NN graphs from model files.
+
+") ITfParser;
+%nodefaultctor ITfParser;
+class ITfParser
+{
+public:
+ %feature("docstring",
+ "
+ Retrieve binding info (layer id and tensor info) for the network input identified by the given layer name.
+
+ Args:
+ name (str): Name of the input.
+
+ Returns:
+ tuple: (`int`, `TensorInfo`).
+ ") GetNetworkInputBindingInfo;
+ std::pair<int, armnn::TensorInfo> GetNetworkInputBindingInfo(const std::string& name);
+
+ %feature("docstring",
+ "
+ Retrieve binding info (layer id and tensor info) for the network output identified by the given layer name.
+
+ Args:
+ name (str): Name of the output.
+
+ Returns:
+ tuple: (`int`, `TensorInfo`).
+ ") GetNetworkOutputBindingInfo;
+ std::pair<int, armnn::TensorInfo> GetNetworkOutputBindingInfo(const std::string& name);
+};
+
+%extend ITfParser {
+ // This does not replace the default constructor of the Arm NN class. It tells swig to create a custom
+ // __init__ method for the ITfParser python object that uses the static factory method to do the job.
+
+ ITfParser() {
+ return armnnTfParser::ITfParser::CreateRaw();
+ }
+
+ // The following does not replace the real destructor of the Arm NN class.
+ // It creates a function that is called when the swig object goes out of scope to clean up resources,
+ // so the user doesn't need to call ITfParser::Destroy themselves.
+ // `$self` is a pointer to the extracted ArmNN ITfParser object.
+
+ ~ITfParser() {
+ armnnTfParser::ITfParser::Destroy($self);
+ }
+
+ %feature("docstring",
+ "
+ Create the network from a frozen TensorFlow protobuf (.pb) file.
+
+ Args:
+ graphFile (str): Path to the tf model to be parsed.
+ inputShapes (dict): A dict mapping input names to `TensorShape` values.
+ requestedOutputs (list of str): A list of the output tensor names.
+
+ Returns:
+ INetwork: Parsed network.
+
+ Raises:
+ RuntimeError: If model file was not found.
+ ") CreateNetworkFromBinaryFile;
+ %newobject CreateNetworkFromBinaryFile;
+ armnn::INetwork* CreateNetworkFromBinaryFile(const char* graphFile,
+ const std::map<std::string, armnn::TensorShape>& inputShapes,
+ const std::vector<std::string>& requestedOutputs) {
+ return $self->CreateNetworkFromBinaryFile(graphFile, inputShapes, requestedOutputs).release();
+ }
+
+}
+
+}
+// Clear exception typemap.
+%exception;
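The TensorFlow parser mirrors the Caffe one: a frozen graph plus explicit input shapes and requested outputs; all names below are placeholders:

>>> parser = ann.ITfParser()
>>> network = parser.CreateNetworkFromBinaryFile('model.pb',
...                                              {'input': ann.TensorShape((1, 224, 224, 3))},
...                                              ['output'])
>>> input_binding = parser.GetNetworkInputBindingInfo('input')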
diff --git a/python/pyarmnn/src/pyarmnn/swig/armnn_version.i b/python/pyarmnn/src/pyarmnn/swig/armnn_version.i
new file mode 100644
index 0000000000..b8e760d435
--- /dev/null
+++ b/python/pyarmnn/src/pyarmnn/swig/armnn_version.i
@@ -0,0 +1,58 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+%module pyarmnn_version
+
+%include "std_string.i"
+
+%{
+#define SWIG_FILE_WITH_INIT
+#include "armnn/Version.hpp"
+%}
+
+%{
+ std::string GetVersion()
+ {
+ return ARMNN_VERSION;
+ };
+
+ std::string GetMajorVersion()
+ {
+ return STRINGIFY_VALUE(ARMNN_MAJOR_VERSION);
+ };
+
+ std::string GetMinorVersion()
+ {
+ return STRINGIFY_VALUE(ARMNN_MINOR_VERSION);
+ };
+%}
+%feature("docstring",
+"
+ Returns Arm NN library full version: MAJOR + MINOR + INCREMENTAL.
+
+ Returns:
+ str: Full version of Arm NN installed.
+
+") GetVersion;
+std::string GetVersion();
+
+%feature("docstring",
+"
+ Returns Arm NN library major version. The year of the release.
+
+ Returns:
+ str: Major version of Arm NN installed.
+
+") GetMajorVersion;
+std::string GetMajorVersion();
+
+%feature("docstring",
+"
+ Returns Arm NN library minor version. Month of the year of the release.
+
+ Returns:
+ str: Minor version of Arm NN installed.
+
+") GetMinorVersion;
+std::string GetMinorVersion();
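From Python these three functions surface as module-level calls; the values shown are what the 19.11 release would report:

>>> import pyarmnn as ann
>>> ann.GetVersion()        # e.g. '20191100'
>>> ann.GetMajorVersion()   # release year, e.g. '19'
>>> ann.GetMinorVersion()   # release month, e.g. '11'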
diff --git a/python/pyarmnn/src/pyarmnn/swig/modules/armnn_backend.i b/python/pyarmnn/src/pyarmnn/swig/modules/armnn_backend.i
new file mode 100644
index 0000000000..4d13150a19
--- /dev/null
+++ b/python/pyarmnn/src/pyarmnn/swig/modules/armnn_backend.i
@@ -0,0 +1,66 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+%{
+#include "armnn/BackendId.hpp"
+%}
+
+namespace std {
+ %template(BackendIdVector) vector<armnn::BackendId>;
+ %template(BackendIdSet) unordered_set<armnn::BackendId>;
+}
+
+namespace armnn
+{
+
+class BackendId
+{
+public:
+ %feature("docstring",
+ "
+ Creates backend id instance.
+ Supported backend ids: 'CpuRef', 'CpuAcc', 'GpuAcc', 'NpuAcc'.
+
+ Args:
+ id (str): Computation backend identification.
+ ") BackendId;
+
+ BackendId(const std::string& id);
+
+ %feature("docstring",
+ "
+ Checks if the backend is the CPU reference implementation.
+ Returns:
+ bool: True if the backend is the CPU reference implementation ('CpuRef'), False otherwise.
+
+ ") IsCpuRef;
+ bool IsCpuRef();
+
+ %feature("docstring",
+ "
+ Returns backend identification.
+
+ >>> backendId = BackendId('CpuRef')
+ >>> assert 'CpuRef' == str(backendId)
+ >>> assert 'CpuRef' == backendId.Get()
+
+ Returns:
+ str: Backend identification.
+
+ ") Get;
+ const std::string& Get();
+};
+
+%extend BackendId {
+
+ std::string __str__() {
+ return $self->Get();
+ }
+
+}
+
+using BackendIdVector = std::vector<armnn::BackendId>;
+using BackendIdSet = std::unordered_set<armnn::BackendId>;
+}
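A short sketch mirroring the BackendId docstrings above; backend ids are plain wrapped strings:

>>> backend = ann.BackendId('CpuRef')
>>> backend.IsCpuRef()
True
>>> str(backend)
'CpuRef'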
diff --git a/python/pyarmnn/src/pyarmnn/swig/modules/armnn_descriptors.i b/python/pyarmnn/src/pyarmnn/swig/modules/armnn_descriptors.i
new file mode 100644
index 0000000000..eb2c8f6278
--- /dev/null
+++ b/python/pyarmnn/src/pyarmnn/swig/modules/armnn_descriptors.i
@@ -0,0 +1,1000 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+%{
+#include "armnn/Descriptors.hpp"
+#include "armnn/Types.hpp"
+%}
+
+namespace std {
+ %template() vector<unsigned int>;
+ %template() vector<int>;
+ %template() vector<pair<unsigned int, unsigned int>>;
+ %template(TensorShapeVector) vector<armnn::TensorShape>;
+}
+
+%include "typemaps/vectors.i"
+
+%typemap(out) const uint32_t*
+%{
+{
+ auto len = arg1->GetNumViews();
+ $result = PyList_New(len);
+ if (!$result) {
+ Py_XDECREF($result);
+ return PyErr_NoMemory();
+ }
+ for (unsigned int i = 0; i < len; ++i) {
+
+ PyList_SetItem($result, i, PyLong_FromUnsignedLong($1[i]));
+ }
+}
+%}
+
+namespace armnn
+{
+
+%list_to_vector( std::vector<unsigned int> );
+%list_to_vector( std::vector<int> );
+%list_to_vector( std::vector<std::pair<unsigned int, unsigned int>> );
+
+%feature("docstring",
+ "
+ A configuration for the Activation layer. See `INetwork.AddActivationLayer()`.
+
+ Contains:
+ m_Function (ActivationFunction): The activation function to use
+ (Sigmoid, TanH, Linear, ReLu, BoundedReLu, SoftReLu, LeakyReLu, Abs, Sqrt, Square).
+ Default: ActivationFunction_Sigmoid.
+ m_A (float): Alpha upper bound value used by the activation functions (BoundedReLu, Linear, TanH). Default: 0.
+ m_B (float): Beta lower bound value used by the activation functions (BoundedReLu, Linear, TanH). Default: 0.
+
+ ") ActivationDescriptor;
+struct ActivationDescriptor
+{
+ ActivationDescriptor();
+
+ ActivationFunction m_Function;
+ float m_A;
+ float m_B;
+
+ bool operator ==(const ActivationDescriptor &rhs) const;
+};
+
+
+%feature("docstring",
+ "
+ A descriptor for the ArgMinMax layer. See `INetwork.AddArgMinMaxLayer()`.
+
+ Contains:
+ m_Function (int): Specifies whether the function finds the Min or the Max (ArgMinMaxFunction_Min or ArgMinMaxFunction_Max).
+ Default: ArgMinMaxFunction_Min.
+ m_Axis (int): Axis to reduce across the input tensor. Default: -1.
+
+ ") ArgMinMaxDescriptor;
+struct ArgMinMaxDescriptor
+{
+ ArgMinMaxDescriptor();
+
+ ArgMinMaxFunction m_Function;
+ int m_Axis;
+
+ bool operator ==(const ArgMinMaxDescriptor &rhs) const;
+};
+
+%feature("docstring",
+ "
+ A descriptor for the BatchNormalization layer. See `INetwork.AddBatchNormalizationLayer()`.
+
+ Contains:
+ m_Eps (float): Value to add to the variance. Used to avoid dividing by zero. Default: 0.0001f.
+ m_DataLayout (int): The data layout to be used (DataLayout_NCHW, DataLayout_NHWC). Default: DataLayout_NCHW.
+
+ ") BatchNormalizationDescriptor;
+struct BatchNormalizationDescriptor
+{
+ BatchNormalizationDescriptor();
+
+ float m_Eps;
+ DataLayout m_DataLayout;
+
+ bool operator ==(const BatchNormalizationDescriptor& rhs) const;
+};
+
+%feature("docstring",
+ "
+ A descriptor for the BatchToSpaceNd layer. See `INetwork.AddBatchToSpaceNdLayer()`.
+
+ Contains:
+ m_BlockShape (list of int): Block shape values. Default: (1, 1). Underlying C++ type is unsigned int.
+
+ m_Crops (list of tuple): The values to crop from the input dimension. Default: [(0, 0), (0, 0)].
+
+ m_DataLayout (int): The data layout to be used (DataLayout_NCHW, DataLayout_NHWC). Default: DataLayout_NCHW.
+
+ ") BatchToSpaceNdDescriptor;
+struct BatchToSpaceNdDescriptor
+{
+ BatchToSpaceNdDescriptor();
+ BatchToSpaceNdDescriptor(std::vector<unsigned int> blockShape,
+ std::vector<std::pair<unsigned int, unsigned int>> crops);
+
+ std::vector<unsigned int> m_BlockShape;
+ std::vector<std::pair<unsigned int, unsigned int>> m_Crops;
+ DataLayout m_DataLayout;
+
+ bool operator ==(const BatchToSpaceNdDescriptor& rhs) const;
+};
+
+%feature("docstring",
+ "
+ A descriptor for the Comparison layer. See `INetwork.AddComparisonLayer()`.
+
+ Contains:
+ m_Operation (ComparisonOperation): Specifies the comparison operation to execute.
+ ") ComparisonDescriptor;
+struct ComparisonDescriptor
+{
+ ComparisonDescriptor();
+
+ ComparisonDescriptor(ComparisonOperation operation);
+
+ bool operator ==(const ComparisonDescriptor &rhs) const;
+
+ armnn::ComparisonOperation m_Operation;
+};
+
+%feature("docstring",
+ "
+ Creates a configuration/descriptor for a Concatenation layer. See `INetwork.AddConcatLayer()`.
+ The number of views must equal the number of inputs, and their order must match, e.g. the first view corresponds to the first input, the second view to the second input, etc.
+
+ Contains:
+ numViews (int): Number of views; the value must equal the number of inputs of the layer.
+ numDimensions (int): Number of dimensions. Default value is 4.
+
+ ") ConcatDescriptor;
+struct ConcatDescriptor
+{
+ ConcatDescriptor();
+
+ ConcatDescriptor(uint32_t numViews, uint32_t numDimensions = 4);
+
+ %feature("docstring",
+ "
+ Get the number of views.
+ Returns:
+ int: Number of views.
+ ") GetNumViews;
+ uint32_t GetNumViews() const;
+
+ %feature("docstring",
+ "
+ Get the number of dimensions.
+ Returns:
+ int: Number of dimensions.
+ ") GetNumDimensions;
+ uint32_t GetNumDimensions() const;
+
+ %feature("docstring",
+ "
+ Get the origin of the view with the given index.
+
+ Each view matches the input order, e.g. the first view corresponds to the first input, the second view to the second input, etc.
+
+ Args:
+ idx (int): Index to get view from.
+
+ Returns:
+ list: View origin (shape) specified by the int value `idx` as a list of ints.
+ ") GetViewOrigin;
+
+ const uint32_t* GetViewOrigin(uint32_t idx) const;
+
+ %feature("docstring",
+ "
+ Set the concatenation dimension.
+ Args:
+ concatAxis (int): Concatenation axis index.
+ ") SetConcatAxis;
+ void SetConcatAxis(unsigned int concatAxis);
+
+ %feature("docstring",
+ "
+ Get the concatenation dimension.
+ Returns:
+ int: Concatenation axis index.
+ ") GetConcatAxis;
+ unsigned int GetConcatAxis() const;
+
+ bool operator ==(const ConcatDescriptor& rhs) const;
+};
+%extend ConcatDescriptor{
+ %feature("docstring",
+ "
+ Set the coordinates of a specific origin view input.
+
+ Args:
+ view (int): Origin view index.
+ coord (int): Coordinate of the origin view to set.
+ value (int): Value to set.
+ Raises:
+ RuntimeError: If the `view` is greater than or equal to GetNumViews().
+ RuntimeError: If the `coord` is greater than or equal to GetNumDimensions().
+ ") SetViewOriginCoord;
+ void SetViewOriginCoord(uint32_t view, uint32_t coord, uint32_t value) {
+ armnn::Status status = $self->SetViewOriginCoord(view, coord, value);
+ if(status == armnn::Status::Failure)
+ {
+ throw armnn::Exception("Failed to set view origin coordinates.");
+ }
+ };
+}
+
+%feature("docstring",
+ "
+ A descriptor for the Convolution2d layer. See `INetwork.AddConvolution2dLayer()`.
+
+ Contains:
+ m_PadLeft (int): Underlying C++ data type is `uint32_t`. Padding left value in the width dimension. Default: 0.
+ m_PadRight (int): Underlying C++ data type is `uint32_t`. Padding right value in the width dimension. Default: 0.
+ m_PadTop (int): Underlying C++ data type is `uint32_t`. Padding top value in the height dimension. Default: 0.
+ m_PadBottom (int): Underlying C++ data type is `uint32_t`. Padding bottom value in the height dimension. Default: 0.
+ m_StrideX (int): Underlying C++ data type is `uint32_t`. Stride value when proceeding through input for the width dimension. Default: 0.
+ m_StrideY (int): Underlying C++ data type is `uint32_t`. Stride value when proceeding through input for the height dimension. Default: 0.
+ m_DilationX (int): Underlying C++ data type is `uint32_t`. Dilation along x axis. Default: 1.
+ m_DilationY (int): Underlying C++ data type is `uint32_t`. Dilation along y axis. Default: 1.
+ m_BiasEnabled (bool): Enable/disable bias. Default: false.
+ m_DataLayout (int): The data layout to be used (DataLayout_NCHW, DataLayout_NHWC). Default: DataLayout_NCHW.
+
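+    A minimal configuration sketch (assuming `import pyarmnn as ann`):
+
+    >>> conv_desc = ann.Convolution2dDescriptor()
+    >>> conv_desc.m_StrideX = 1
+    >>> conv_desc.m_StrideY = 1
+    >>> conv_desc.m_BiasEnabled = True
+    >>> conv_desc.m_DataLayout = ann.DataLayout_NHWC
+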
+ ") Convolution2dDescriptor;
+struct Convolution2dDescriptor
+{
+ Convolution2dDescriptor();
+
+ uint32_t m_PadLeft;
+ uint32_t m_PadRight;
+ uint32_t m_PadTop;
+ uint32_t m_PadBottom;
+ uint32_t m_StrideX;
+ uint32_t m_StrideY;
+ uint32_t m_DilationX;
+ uint32_t m_DilationY;
+ bool m_BiasEnabled;
+ DataLayout m_DataLayout;
+
+ bool operator ==(const Convolution2dDescriptor& rhs) const;
+};
+
+
+%feature("docstring",
+ "
+ A descriptor for the DepthToSpace layer. See `INetwork.AddDepthToSpaceLayer()`.
+
+ Contains:
+ m_BlockSize (int): Underlying C++ type is `unsigned int`. Scalar specifying the input block size. It must be >= 1. Default: 1.
+ m_DataLayout (int): The data layout to be used (DataLayout_NCHW, DataLayout_NHWC). Default: DataLayout_NHWC.
+
+ ") DepthToSpaceDescriptor;
+struct DepthToSpaceDescriptor
+{
+ DepthToSpaceDescriptor();
+ DepthToSpaceDescriptor(unsigned int blockSize, DataLayout dataLayout);
+
+ unsigned int m_BlockSize;
+ DataLayout m_DataLayout;
+};
+
+
+%feature("docstring",
+ "
+ A descriptor for the DepthwiseConvolution2d layer. See `INetwork.AddDepthwiseConvolution2dLayer()`.
+
+ Contains:
+ m_PadLeft (int): Underlying C++ data type is `uint32_t`. Padding left value in the width dimension. Default: 0.
+ m_PadRight (int): Underlying C++ data type is `uint32_t`. Padding right value in the width dimension. Default: 0.
+ m_PadTop (int): Underlying C++ data type is `uint32_t`. Padding top value in the height dimension. Default: 0.
+ m_PadBottom (int): Underlying C++ data type is `uint32_t`. Padding bottom value in the height dimension. Default: 0.
+ m_StrideX (int): Underlying C++ data type is `uint32_t`. Stride value when proceeding through input for the width dimension. Default: 0.
+ m_StrideY (int): Underlying C++ data type is `uint32_t`. Stride value when proceeding through input for the height dimension. Default: 0.
+ m_DilationX (int): Underlying C++ data type is `uint32_t`. Dilation along x axis. Default: 1.
+ m_DilationY (int): Underlying C++ data type is `uint32_t`. Dilation along y axis. Default: 1.
+ m_BiasEnabled (bool): Enable/disable bias. Default: false.
+ m_DataLayout (int): The data layout to be used (DataLayout_NCHW, DataLayout_NHWC). Default: DataLayout_NCHW.
+
+ ") DepthwiseConvolution2dDescriptor;
+struct DepthwiseConvolution2dDescriptor
+{
+ DepthwiseConvolution2dDescriptor();
+
+ uint32_t m_PadLeft;
+ uint32_t m_PadRight;
+ uint32_t m_PadTop;
+ uint32_t m_PadBottom;
+ uint32_t m_StrideX;
+ uint32_t m_StrideY;
+ uint32_t m_DilationX;
+ uint32_t m_DilationY;
+ bool m_BiasEnabled;
+ DataLayout m_DataLayout;
+
+ bool operator ==(const DepthwiseConvolution2dDescriptor& rhs) const;
+};
+
+%feature("docstring",
+ "
+ A descriptor for the DetectionPostProcess layer. See `INetwork.AddDetectionPostProcessLayer()`.
+
+ This layer is a custom layer used to process the output from SSD MobilenetV1.
+
+ Contains:
+        m_MaxDetections (int): Underlying C++ data type is `uint32_t`. Maximum number of detections. Default: 0.
+        m_MaxClassesPerDetection (int): Underlying C++ data type is `uint32_t`. Maximum number of classes per detection, used in Fast NMS. Default: 1.
+        m_DetectionsPerClass (int): Underlying C++ data type is `uint32_t`. Detections per class, used in Regular NMS. Default: 1.
+        m_NmsScoreThreshold (float): Non-maximum suppression score threshold. Default: 0.
+        m_NmsIouThreshold (float): Intersection over union threshold. Default: 0.
+        m_NumClasses (int): Underlying C++ data type is `uint32_t`. Number of classes. Default: 0.
+        m_UseRegularNms (bool): Use regular non-maximum suppression. Default: false.
+        m_ScaleX (float): Center size encoding scale x. Default: 0.
+        m_ScaleY (float): Center size encoding scale y. Default: 0.
+        m_ScaleW (float): Center size encoding scale width. Default: 0.
+        m_ScaleH (float): Center size encoding scale height. Default: 0.
+
+ ") DetectionPostProcessDescriptor;
+struct DetectionPostProcessDescriptor
+{
+ DetectionPostProcessDescriptor();
+
+ uint32_t m_MaxDetections;
+ uint32_t m_MaxClassesPerDetection;
+ uint32_t m_DetectionsPerClass;
+ float m_NmsScoreThreshold;
+ float m_NmsIouThreshold;
+ uint32_t m_NumClasses;
+ bool m_UseRegularNms;
+ float m_ScaleX;
+ float m_ScaleY;
+ float m_ScaleW;
+ float m_ScaleH;
+
+ bool operator ==(const DetectionPostProcessDescriptor& rhs) const;
+};
+
+%feature("docstring",
+ "
+    A descriptor for the FakeQuantization layer.
+
+ Contains:
+ m_Min (float): Minimum value for quantization range. Default: -6.0.
+ m_Max (float): Maximum value for quantization range. Default: 6.0.
+
+ ") FakeQuantizationDescriptor;
+struct FakeQuantizationDescriptor
+{
+ FakeQuantizationDescriptor();
+
+ float m_Min;
+ float m_Max;
+
+ bool operator ==(const FakeQuantizationDescriptor& rhs) const;
+};
+
+%feature("docstring",
+ "
+ A descriptor for the FullyConnected layer. See `INetwork.AddFullyConnectedLayer()`.
+
+ Contains:
+ m_BiasEnabled (bool): Enable/disable bias. Default: false.
+ m_TransposeWeightMatrix (bool): Enable/disable transpose weight matrix. Default: false.
+
+ ") FullyConnectedDescriptor;
+struct FullyConnectedDescriptor
+{
+ FullyConnectedDescriptor();
+
+ bool m_BiasEnabled;
+ bool m_TransposeWeightMatrix;
+
+ bool operator ==(const FullyConnectedDescriptor& rhs) const;
+};
+
+%feature("docstring",
+ "
+ A descriptor for InstanceNormalization layer. See `INetwork.AddInstanceNormalizationLayer()`.
+
+ Contains:
+        m_Gamma (float): Gamma, the scale scalar value applied for the normalized tensor. Default: 1.0.
+        m_Beta (float): Beta, the offset scalar value applied for the normalized tensor. Default: 0.0.
+        m_Eps (float): Epsilon, small scalar value added to variance to avoid dividing by zero. Default: 1e-12f.
+ m_DataLayout (int): The data layout to be used (DataLayout_NCHW, DataLayout_NHWC). Default: DataLayout_NCHW.
+
+ ") InstanceNormalizationDescriptor;
+struct InstanceNormalizationDescriptor
+{
+ InstanceNormalizationDescriptor();
+
+ float m_Gamma;
+ float m_Beta;
+ float m_Eps;
+ DataLayout m_DataLayout;
+
+ bool operator ==(const InstanceNormalizationDescriptor& rhs) const;
+};
+
+%feature("docstring",
+ "
+ A descriptor for the LSTM layer. See `INetwork.AddLstmLayer()`.
+
+ Contains:
+ m_ActivationFunc (int): Underlying C++ data type is `uint32_t`. The activation function to use. 0: None, 1: Relu, 3: Relu6, 4: Tanh, 6: Sigmoid.
+ Default: 1.
+ m_ClippingThresCell (float): Clipping threshold value for the cell state. Default: 0.0.
+ m_ClippingThresProj (float): Clipping threshold value for the projection. Default: 0.0.
+ m_CifgEnabled (bool): Enable/disable cifg (coupled input & forget gate). Default: true.
+ m_PeepholeEnabled (bool): Enable/disable peephole. Default: false.
+ m_ProjectionEnabled (bool): Enable/disable the projection layer. Default: false.
+ m_LayerNormEnabled (bool): Enable/disable layer normalization. Default: false.
+
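+    For example, a CIFG cell with Tanh activation and cell-state clipping at 1.0 (assuming `import pyarmnn as ann`):
+
+    >>> lstm_desc = ann.LstmDescriptor()
+    >>> lstm_desc.m_ActivationFunc = 4
+    >>> lstm_desc.m_ClippingThresCell = 1.0
+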
+ ") LstmDescriptor;
+struct LstmDescriptor
+{
+ LstmDescriptor();
+
+ uint32_t m_ActivationFunc;
+ float m_ClippingThresCell;
+ float m_ClippingThresProj;
+ bool m_CifgEnabled;
+ bool m_PeepholeEnabled;
+ bool m_ProjectionEnabled;
+ bool m_LayerNormEnabled;
+
+ bool operator ==(const LstmDescriptor& rhs) const;
+};
+
+%feature("docstring",
+ "
+ A Descriptor for the L2Normalization layer. See `INetwork.AddL2NormalizationLayer()`.
+
+ Contains:
+        m_Eps (float): Used to avoid dividing by zero. Default: 1e-12f.
+ m_DataLayout (int): The data layout to be used (DataLayout_NCHW, DataLayout_NHWC). Default: DataLayout_NCHW.
+
+ ") L2NormalizationDescriptor;
+struct L2NormalizationDescriptor
+{
+ L2NormalizationDescriptor();
+
+ float m_Eps;
+ DataLayout m_DataLayout;
+
+ bool operator ==(const L2NormalizationDescriptor& rhs) const;
+};
+
+%feature("docstring",
+ "
+ A descriptor for the Mean layer. See `INetwork.AddMeanLayer()`.
+
+ Contains:
+        m_Axis (list of int): Underlying C++ data type is std::vector<unsigned int>. Values for the dimensions to reduce.
+ m_KeepDims (bool): Enable/disable keep dimensions. If true, then the reduced dimensions that are of length 1 are kept. Default: False.
+
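+    For example, averaging over the spatial dimensions (H, W) of an NHWC tensor without keeping them (assuming `import pyarmnn as ann`):
+
+    >>> mean_desc = ann.MeanDescriptor([1, 2], False)
+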
+ ") MeanDescriptor;
+struct MeanDescriptor
+{
+ MeanDescriptor();
+ MeanDescriptor(const std::vector<unsigned int>& axis, bool keepDims);
+
+ std::vector<unsigned int> m_Axis;
+ bool m_KeepDims;
+
+ bool operator ==(const MeanDescriptor& rhs) const;
+};
+
+%feature("docstring",
+ "
+ A descriptor for the Normalization layer. See `INetwork.AddNormalizationLayer()`.
+
+ Contains:
+        m_NormChannelType (int): Normalization channel algorithm to use (NormalizationAlgorithmChannel_Across, NormalizationAlgorithmChannel_Within).
+                                 Default: NormalizationAlgorithmChannel_Across.
+ m_NormMethodType (int): Normalization method algorithm to use (NormalizationAlgorithmMethod_LocalBrightness, NormalizationAlgorithmMethod_LocalContrast).
+ Default: NormalizationAlgorithmMethod_LocalBrightness.
+ m_NormSize (int): Underlying C++ data type is `uint32_t`. Depth radius value. Default: 0.
+ m_Alpha (float): Alpha value for the normalization equation. Default: 0.0.
+ m_Beta (float): Beta value for the normalization equation. Default: 0.0.
+ m_K (float): Kappa value used for the across channel normalization equation. Default: 0.0.
+ m_DataLayout (int): The data layout to be used (DataLayout_NCHW, DataLayout_NHWC). Default: DataLayout_NCHW.
+
+ ") NormalizationDescriptor;
+struct NormalizationDescriptor
+{
+ NormalizationDescriptor();
+
+ NormalizationAlgorithmChannel m_NormChannelType;
+ NormalizationAlgorithmMethod m_NormMethodType;
+ uint32_t m_NormSize;
+ float m_Alpha;
+ float m_Beta;
+ float m_K;
+ DataLayout m_DataLayout;
+
+ bool operator ==(const NormalizationDescriptor& rhs) const;
+};
+
+%feature("docstring",
+ "
+ A descriptor for the Pad layer. See `INetwork.AddPadLayer()`.
+
+ Contains:
+        m_PadList (list of tuple): Specifies the padding for each input dimension.
+ The first tuple value is the number of values to add before the tensor in the dimension.
+ The second tuple value is the number of values to add after the tensor in the dimension.
+ The number of pairs should match the number of dimensions in the input tensor.
+        m_PadValue (float): Optional value to use for padding. Default: 0.
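+
+    For example, padding one element with zeros before and after each dimension of a 2D tensor (assuming `import pyarmnn as ann`):
+
+    >>> pad_desc = ann.PadDescriptor([(1, 1), (1, 1)], 0.0)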
+
+ ") PadDescriptor;
+struct PadDescriptor
+{
+ PadDescriptor();
+ PadDescriptor(const std::vector<std::pair<unsigned int, unsigned int>>& padList, const float& padValue = 0);
+
+ std::vector<std::pair<unsigned int, unsigned int>> m_PadList;
+ float m_PadValue;
+
+ bool operator ==(const PadDescriptor& rhs) const;
+};
+
+%feature("docstring",
+ "
+ A descriptor for the Permute layer. See `INetwork.AddPermuteLayer()`.
+
+ Contains:
+ m_DimMappings (PermutationVector): Indicates how to translate tensor elements from a given source into the target destination,
+ when source and target potentially have different memory layouts e.g. {0U, 3U, 1U, 2U}.
+
+ ") PermuteDescriptor;
+struct PermuteDescriptor
+{
+ PermuteDescriptor();
+ PermuteDescriptor(const PermutationVector& dimMappings);
+
+ PermutationVector m_DimMappings;
+
+ bool operator ==(const PermuteDescriptor &rhs) const;
+};
+
+%feature("docstring",
+ "
+ A descriptor for the Pooling2d layer. See `INetwork.AddPooling2dLayer()`.
+
+ Contains:
+ m_PoolType (int): The pooling algorithm to use (`PoolingAlgorithm_Max`, `PoolingAlgorithm_Average`, `PoolingAlgorithm_L2`). Default: `PoolingAlgorithm_Max`.
+ m_PadLeft (int): Underlying C++ data type is `uint32_t`. Padding left value in the width dimension. Default: 0.
+ m_PadRight (int): Underlying C++ data type is `uint32_t`. Padding right value in the width dimension. Default: 0.
+ m_PadTop (int): Underlying C++ data type is `uint32_t`. Padding top value in the height dimension. Default: 0.
+ m_PadBottom (int): Underlying C++ data type is `uint32_t`. Padding bottom value in the height dimension. Default: 0.
+ m_PoolWidth (int): Underlying C++ data type is `uint32_t`. Pooling width value. Default: 0.
+ m_PoolHeight (int): Underlying C++ data type is `uint32_t`. Pooling height value. Default: 0.
+ m_StrideX (int): Underlying C++ data type is `uint32_t`. Stride value when proceeding through input for the width dimension. Default: 0.
+ m_StrideY (int): Underlying C++ data type is `uint32_t`. Stride value when proceeding through input for the height dimension. Default: 0.
+ m_OutputShapeRounding (int): The rounding method for the output shape. (OutputShapeRounding_Floor, OutputShapeRounding_Ceiling).
+ Default: OutputShapeRounding_Floor.
+ m_PaddingMethod (int): The padding method to be used. (PaddingMethod_Exclude, PaddingMethod_IgnoreValue).
+ Default: PaddingMethod_Exclude.
+ m_DataLayout (int): The data layout to be used (DataLayout_NCHW, DataLayout_NHWC). Default: DataLayout_NCHW.
+
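+    A minimal 2x2 max pooling sketch (assuming `import pyarmnn as ann`):
+
+    >>> pool_desc = ann.Pooling2dDescriptor()
+    >>> pool_desc.m_PoolType = ann.PoolingAlgorithm_Max
+    >>> pool_desc.m_PoolWidth = pool_desc.m_PoolHeight = 2
+    >>> pool_desc.m_StrideX = pool_desc.m_StrideY = 2
+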
+ ") Pooling2dDescriptor;
+struct Pooling2dDescriptor
+{
+ Pooling2dDescriptor();
+
+ PoolingAlgorithm m_PoolType;
+ uint32_t m_PadLeft;
+ uint32_t m_PadRight;
+ uint32_t m_PadTop;
+ uint32_t m_PadBottom;
+ uint32_t m_PoolWidth;
+ uint32_t m_PoolHeight;
+ uint32_t m_StrideX;
+ uint32_t m_StrideY;
+ OutputShapeRounding m_OutputShapeRounding;
+ PaddingMethod m_PaddingMethod;
+ DataLayout m_DataLayout;
+
+ bool operator ==(const Pooling2dDescriptor& rhs) const;
+};
+
+%feature("docstring",
+ "
+ A descriptor for the Reshape layer. See `INetwork.AddReshapeLayer()`.
+
+ Contains:
+ m_TargetShape (TensorShape): Target shape value.
+
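+    For example (assuming `import pyarmnn as ann`):
+
+    >>> reshape_desc = ann.ReshapeDescriptor(ann.TensorShape((1, 4)))
+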
+ ") ReshapeDescriptor;
+struct ReshapeDescriptor
+{
+ ReshapeDescriptor();
+ ReshapeDescriptor(const armnn::TensorShape& shape);
+
+ armnn::TensorShape m_TargetShape;
+
+ bool operator ==(const ReshapeDescriptor& rhs) const;
+};
+
+%feature("docstring",
+ "
+ A descriptor for the Resize layer. See `INetwork.AddResizeLayer()`.
+
+ Contains:
+ m_TargetWidth (int): Underlying C++ data type is `uint32_t`. Target width value. Default: 0.
+ m_TargetHeight (int): Underlying C++ data type is `uint32_t`. Target height value. Default: 0.
+ m_Method (int): The Interpolation method to use (ResizeMethod_Bilinear, ResizeMethod_NearestNeighbor).
+ Default: ResizeMethod_NearestNeighbor.
+ m_DataLayout (int): The data layout to be used (DataLayout_NCHW, DataLayout_NHWC). Default: DataLayout_NCHW.
+
+ ") ResizeDescriptor;
+struct ResizeDescriptor
+{
+ ResizeDescriptor();
+
+ uint32_t m_TargetWidth;
+ uint32_t m_TargetHeight;
+ ResizeMethod m_Method;
+ DataLayout m_DataLayout;
+
+ bool operator ==(const ResizeDescriptor& rhs) const;
+};
+
+%feature("docstring",
+ "
+ A descriptor for the Slice layer. See `INetwork.AddSliceLayer()`.
+
+ Contains:
+ m_Begin (list of int): Underlying C++ data type is std::vector<unsigned int>. Beginning indices of the slice in each dimension.
+ m_Size (list of int): Underlying C++ data type is std::vector<unsigned int>. Size of the slice in each dimension.
+
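+    For example, a slice starting at [0, 1] and spanning 2 elements in each dimension (assuming `import pyarmnn as ann`):
+
+    >>> slice_desc = ann.SliceDescriptor([0, 1], [2, 2])
+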
+ ") SliceDescriptor;
+struct SliceDescriptor
+{
+ SliceDescriptor();
+ SliceDescriptor(const std::vector<unsigned int>& begin, const std::vector<unsigned int>& size);
+
+ std::vector<unsigned int> m_Begin;
+ std::vector<unsigned int> m_Size;
+
+ bool operator ==(const SliceDescriptor& rhs) const;
+};
+
+%feature("docstring",
+ "
+ A descriptor for the Space To Batch N-dimensions layer. See `INetwork.AddSpaceToBatchNdLayer()`.
+
+ Contains:
+ m_BlockShape (list of int): Underlying C++ data type is std::vector<unsigned int>. Block shape values. Default: [1, 1].
+        m_PadList (list of tuple): Specifies the padding values for the input dimension:
+ [heightPad - (top, bottom) widthPad - (left, right)].
+ Default: [(0, 0), (0, 0)].
+ m_DataLayout (int): The data layout to be used (DataLayout_NCHW, DataLayout_NHWC). Default: DataLayout_NCHW.
+ ") SpaceToBatchNdDescriptor;
+struct SpaceToBatchNdDescriptor
+{
+ SpaceToBatchNdDescriptor();
+ SpaceToBatchNdDescriptor(const std::vector<unsigned int>& blockShape,
+ const std::vector<std::pair<unsigned int, unsigned int>>& padList);
+
+ std::vector<unsigned int> m_BlockShape;
+ std::vector<std::pair<unsigned int, unsigned int>> m_PadList;
+ DataLayout m_DataLayout;
+
+ bool operator ==(const SpaceToBatchNdDescriptor& rhs) const;
+};
+
+%feature("docstring",
+ "
+ A descriptor for the SpaceToDepth layer. See `INetwork.AddSpaceToDepthLayer()`.
+
+ Contains:
+ m_BlockSize (int): Underlying C++ type is `unsigned int`. Scalar specifying the input block size. It must be >= 1. Default: 1.
+ m_DataLayout (int): The data layout to be used (DataLayout_NCHW, DataLayout_NHWC). Default: DataLayout_NHWC.
+
+ ") SpaceToDepthDescriptor;
+struct SpaceToDepthDescriptor
+{
+ SpaceToDepthDescriptor();
+ SpaceToDepthDescriptor(unsigned int blockSize, DataLayout dataLayout);
+
+ unsigned int m_BlockSize;
+ DataLayout m_DataLayout;
+
+ bool operator ==(const SpaceToDepthDescriptor& rhs) const;
+};
+
+%feature("docstring",
+ "
+ A descriptor for a Splitter layer. See `INetwork.AddSplitterLayer()`.
+
+    Contains:
+ numViews (int): Number of views, the value must be equal to the number of outputs of a layer.
+ numDimensions (int): Number of dimensions. Default value is 4.
+
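+    For example, splitting one 4D input into two views (assuming `import pyarmnn as ann`):
+
+    >>> splitter_desc = ann.SplitterDescriptor(2, 4)
+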
+ ") SplitterDescriptor;
+struct SplitterDescriptor
+{
+    SplitterDescriptor();
+
+    SplitterDescriptor(uint32_t numViews, uint32_t numDimensions = 4);
+
+ %feature("docstring",
+ "
+ Get the number of views.
+ Returns:
+            int: Number of views.
+ ") GetNumViews;
+ uint32_t GetNumViews() const;
+
+ %feature("docstring",
+ "
+ Get the number of dimensions.
+
+ Returns:
+ int: Number of dimensions.
+
+ ") GetNumDimensions;
+ uint32_t GetNumDimensions() const;
+
+ %feature("docstring",
+ "
+        Get the output view origin (shape) by index.
+
+        Each view matches the output order, e.g. the first view corresponds to the first output, the second view to the second output, etc.
+ Args:
+ idx (int): Index.
+ Returns:
+ list: View origin (shape) as a list of ints.
+ ") GetViewOrigin;
+
+ const uint32_t* GetViewOrigin(uint32_t idx) const;
+
+ %feature("docstring",
+ "
+ Get the view sizes by index.
+ Args:
+ idx (int): Index.
+ Returns:
+ list: Sizes for the specified index as a list of ints.
+ ") GetViewSizes;
+ const uint32_t* GetViewSizes(uint32_t idx) const;
+
+
+ %feature("docstring",
+ "
+ Get the view origins that describe how the splitting process is configured.
+
+        The number of views is the number of outputs, and their order matches.
+ Returns:
+ OriginsDescriptor: A descriptor for the origins view.
+ ") GetOrigins;
+ const ConcatDescriptor GetOrigins() const;
+
+ bool operator ==(const SplitterDescriptor& rhs) const;
+};
+
+%extend SplitterDescriptor{
+ %feature("docstring",
+ "
+ Set the value of a specific origin view input coordinate.
+
+        Args:
+ view (int): Origin view index.
+ coord (int): Coordinate of the origin view to set.
+ value (int): Value to set.
+ Raises:
+ RuntimeError: If the `view` is greater than or equal to GetNumViews().
+ If the `coord` is greater than or equal to GetNumDimensions().
+ ") SetViewOriginCoord;
+ void SetViewOriginCoord(uint32_t view, uint32_t coord, uint32_t value) {
+ armnn::Status status = $self->SetViewOriginCoord(view, coord, value);
+ if(status == armnn::Status::Failure)
+ {
+ throw armnn::Exception("Failed to set view origin coordinates.");
+ }
+ };
+
+ %feature("docstring",
+ "
+ Set the size of the views.
+
+ Args:
+ view (int): View index.
+ coord (int): Coordinate of the origin view to set.
+ value (int): Value to set.
+ Raises:
+ RuntimeError: If the `view` is greater than or equal to GetNumViews().
+ If the `coord` is greater than or equal to GetNumDimensions().
+ ") SetViewSize;
+ void SetViewSize(uint32_t view, uint32_t coord, uint32_t value) {
+ armnn::Status status = $self->SetViewSize(view, coord, value);
+ if(status == armnn::Status::Failure)
+ {
+ throw armnn::Exception("Failed to set view size.");
+ }
+ }
+}
+
+%feature("docstring",
+ "
+ A descriptor for the Stack layer. See `INetwork.AddStackLayer()`.
+
+ Contains:
+ m_Axis (int): Underlying C++ type is `unsigned int`. 0-based axis along which to stack the input tensors. Default: 0.
+        m_NumInputs (int): Underlying C++ type is `uint32_t`. Number of input tensors. Default: 0.
+ m_InputShape (TensorShape): Required shape of all input tensors.
+
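+    For example, stacking two 3x4 tensors along a new leading axis (assuming `import pyarmnn as ann`):
+
+    >>> stack_desc = ann.StackDescriptor(0, 2, ann.TensorShape((3, 4)))
+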
+ ") StackDescriptor;
+struct StackDescriptor
+{
+ StackDescriptor();
+ StackDescriptor(uint32_t axis, uint32_t numInputs, const armnn::TensorShape& inputShape);
+
+ uint32_t m_Axis;
+ uint32_t m_NumInputs;
+ armnn::TensorShape m_InputShape;
+
+ bool operator ==(const StackDescriptor& rhs) const;
+};
+
+%feature("docstring",
+ "
+ A descriptor for the StandIn layer. See `INetwork.AddStandInLayer()`.
+
+ Contains:
+ m_NumInputs (int): Underlying C++ type is `unsigned int`. Number of input tensors. Default: 0.
+ m_NumOutputs (int): Underlying C++ type is `unsigned int`. Number of output tensors. Default: 0.
+
+ ") StandInDescriptor;
+struct StandInDescriptor
+{
+ StandInDescriptor();
+
+ StandInDescriptor(uint32_t numInputs, uint32_t numOutputs);
+
+ uint32_t m_NumInputs = 0;
+ uint32_t m_NumOutputs = 0;
+};
+
+%feature("docstring",
+ "
+ A descriptor for the StridedSlice layer. See `INetwork.AddStridedSliceLayer()`.
+
+    Contains:
+        m_Begin (list of int): Underlying C++ data type is `std::vector<int>`. Begin values for the input that will be sliced.
+        m_End (list of int): Underlying C++ data type is `std::vector<int>`. End values for the input that will be sliced.
+        m_Stride (list of int): Underlying C++ data type is `std::vector<int>`. Stride values for the input that will be sliced.
+        m_BeginMask (int): Underlying C++ data type is `int32_t`. Begin mask value. If set, then the begin is disregarded and
+                           the fullest range is used for the dimension. Default: 0.
+        m_EndMask (int): Underlying C++ data type is `int32_t`. End mask value. If set, then the end is disregarded and
+                         the fullest range is used for the dimension. Default: 0.
+        m_ShrinkAxisMask (int): Underlying C++ data type is `int32_t`. Shrink axis mask value. If set, the nth specification shrinks the dimensionality by 1. Default: 0.
+        m_EllipsisMask (int): Underlying C++ data type is `int32_t`. Ellipsis mask value. Default: 0.
+        m_NewAxisMask (int): Underlying C++ data type is `int32_t`. New axis mask value. If set, the begin, end and stride are disregarded and
+                             a new 1-dimension is inserted at this location of the output tensor. Default: 0.
+        m_DataLayout (int): The data layout to be used (DataLayout_NCHW, DataLayout_NHWC). Default: DataLayout_NCHW.
+
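+    For example, taking elements 0 and 1 of the first axis, similar to the Python slice `input[0:2:1]` (assuming `import pyarmnn as ann`):
+
+    >>> strided_desc = ann.StridedSliceDescriptor([0], [2], [1])
+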
+ ") StridedSliceDescriptor;
+struct StridedSliceDescriptor
+{
+ StridedSliceDescriptor();
+ StridedSliceDescriptor(const std::vector<int> begin,
+ const std::vector<int> end,
+ const std::vector<int> stride);
+
+ int GetStartForAxis(const armnn::TensorShape& inputShape, unsigned int axis) const;
+ int GetStopForAxis(const armnn::TensorShape& inputShape, unsigned int axis, int startForAxis) const;
+
+ std::vector<int> m_Begin;
+ std::vector<int> m_End;
+ std::vector<int> m_Stride;
+
+ int32_t m_BeginMask;
+ int32_t m_EndMask;
+ int32_t m_ShrinkAxisMask;
+ int32_t m_EllipsisMask;
+ int32_t m_NewAxisMask;
+ DataLayout m_DataLayout;
+
+ bool operator ==(const StridedSliceDescriptor& rhs) const;
+};
+
+%feature("docstring",
+ "
+ A descriptor for the Softmax layer. See `INetwork.AddSoftmaxLayer()`.
+
+ Contains:
+ m_Beta (float): Exponentiation value.
+ m_Axis (int): Scalar, defaulted to the last index (-1), specifying the dimension the activation will be performed on.
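+
+    For example (assuming `import pyarmnn as ann`):
+
+    >>> softmax_desc = ann.SoftmaxDescriptor()
+    >>> softmax_desc.m_Beta = 1.0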
+ ") SoftmaxDescriptor;
+struct SoftmaxDescriptor
+{
+ SoftmaxDescriptor();
+
+ float m_Beta;
+ int m_Axis;
+
+ bool operator ==(const SoftmaxDescriptor& rhs) const;
+};
+
+
+%feature("docstring",
+ "
+ A descriptor for the TransposeConvolution2d layer. See `INetwork.AddTransposeConvolution2dLayer()`.
+
+ Contains:
+ m_PadLeft (int): Underlying C++ data type is `uint32_t`. Padding left value in the width dimension. Default: 0.
+ m_PadRight (int): Underlying C++ data type is `uint32_t`. Padding right value in the width dimension. Default: 0.
+ m_PadTop (int): Underlying C++ data type is `uint32_t`. Padding top value in the height dimension. Default: 0.
+ m_PadBottom (int): Underlying C++ data type is `uint32_t`. Padding bottom value in the height dimension. Default: 0.
+ m_StrideX (int): Underlying C++ data type is `uint32_t`. Stride value when proceeding through input for the width dimension. Default: 0.
+ m_StrideY (int): Underlying C++ data type is `uint32_t`. Stride value when proceeding through input for the height dimension. Default: 0.
+ m_BiasEnabled (bool): Enable/disable bias. Default: false.
+ m_DataLayout (int): The data layout to be used (DataLayout_NCHW, DataLayout_NHWC). Default: DataLayout_NCHW.
+
+ ") TransposeConvolution2dDescriptor;
+struct TransposeConvolution2dDescriptor
+{
+ TransposeConvolution2dDescriptor();
+
+ uint32_t m_PadLeft;
+ uint32_t m_PadRight;
+ uint32_t m_PadTop;
+ uint32_t m_PadBottom;
+ uint32_t m_StrideX;
+ uint32_t m_StrideY;
+ bool m_BiasEnabled;
+ DataLayout m_DataLayout;
+
+ bool operator ==(const TransposeConvolution2dDescriptor& rhs) const;
+};
+
+
+using ConcatDescriptor = OriginsDescriptor;
+using LogSoftmaxDescriptor = SoftmaxDescriptor;
+using SplitterDescriptor = ViewsDescriptor;
+
+%list_to_vector_clear(std::vector<unsigned int>);
+%list_to_vector_clear(std::vector<int>);
+%list_to_vector_clear(std::vector<std::pair<unsigned int, unsigned int>>);
+}
+
+%{
+ armnn::ConcatDescriptor CreateDescriptorForConcatenation(std::vector<armnn::TensorShape> shapes,
+ unsigned int concatenationDimension)
+ {
+ return armnn::CreateDescriptorForConcatenation(shapes.begin(), shapes.end(), concatenationDimension);
+ };
+%}
+
+%feature("docstring",
+ "
+ Create a descriptor for Concatenation layer.
+ Args:
+ shapes (list of TensorShape): Input shapes.
+ concatenationDimension (unsigned int): Concatenation axis.
+
+ Returns:
+ ConcatDescriptor: A descriptor object for a concatenation layer.
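+
+    A usage sketch (assuming `import pyarmnn as ann`):
+
+    >>> shapes = [ann.TensorShape((1, 2)), ann.TensorShape((1, 2))]
+    >>> concat_desc = ann.CreateDescriptorForConcatenation(shapes, 0)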
+ ") CreateDescriptorForConcatenation;
+armnn::ConcatDescriptor CreateDescriptorForConcatenation(std::vector<armnn::TensorShape> shapes,
+ unsigned int concatenationDimension);
+
+%typemap(out) const uint32_t*;
diff --git a/python/pyarmnn/src/pyarmnn/swig/modules/armnn_lstmparam.i b/python/pyarmnn/src/pyarmnn/swig/modules/armnn_lstmparam.i
new file mode 100644
index 0000000000..a0e993c7ac
--- /dev/null
+++ b/python/pyarmnn/src/pyarmnn/swig/modules/armnn_lstmparam.i
@@ -0,0 +1,97 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+%{
+#include "armnn/LstmParams.hpp"
+%}
+
+namespace armnn
+{
+
+%feature("docstring",
+ "
+ Long Short-Term Memory layer input parameters.
+
+ See `INetwork.AddLstmLayer()`.
+ Operation described by the following equations:
+
+ \[i_t=\sigma(W_{xi}x_t+W_{hi}h_{t-1}+W_{ci}C_{t-1}+b_i) \\\\
+ f_t=\sigma(W_{xf}x_t+W_{hf}h_{t-1}+W_{cf}C_{t-1}+b_f) \\\\
+ C_t=clip(f_t \odot C_{t-1} + i_t \odot g(W_{xc}x_t+W_{hc}h_{t-1}+b_c),\ t_{cell}) \\\\
+ o_t = \sigma(W_{xo}x_t+W_{ho}h_{t-1}+W_{co}C_t+b_o) \\\\
+ h_t = clip(W_{proj}(o_t \odot g(C_t))+b_{proj},\ t_{proj})\ if\ there\ is\ a\ projection; \\\\
+ h_t = o_t \odot g(C_t)\ otherwise. \]
+ Where:
+ \(x_t\) - input;
+ \(i_t\) - input gate;
+ \(f_t\) - forget gate;
+ \(C_t\) - cell state;
+ \(o_t\) - output;
+ \(h_t\) - output state;
+ \(\sigma\) - logistic sigmoid function;
+ \(g\) - cell input and cell output activation function, see `LstmDescriptor.m_ActivationFunc`;
+ \(t_{cell}\) - threshold for clipping the cell state, see `LstmDescriptor.m_ClippingThresCell`;
+ \(t_{proj}\) - threshold for clipping the projected output, see `LstmDescriptor.m_ClippingThresProj`;
+
+ Contains:
+ m_InputToInputWeights (ConstTensor): \(W_{xi}\), input-to-input weight matrix.
+ m_InputToForgetWeights (ConstTensor): \(W_{xf}\), input-to-forget weight matrix.
+ m_InputToCellWeights (ConstTensor): \(W_{xc}\), input-to-cell weight matrix.
+ m_InputToOutputWeights (ConstTensor): \(W_{xo}\), input-to-output weight matrix.
+
+ m_RecurrentToInputWeights (ConstTensor): \(W_{hi}\), recurrent-to-input weight matrix.
+ m_RecurrentToForgetWeights (ConstTensor): \(W_{hf}\), recurrent-to-forget weight matrix.
+ m_RecurrentToCellWeights (ConstTensor): \(W_{hc}\), recurrent-to-cell weight matrix.
+ m_RecurrentToOutputWeights (ConstTensor): \(W_{ho}\), recurrent-to-output weight matrix.
+
+ m_CellToInputWeights (ConstTensor): \(W_{ci}\), cell-to-input weight matrix. Has effect if `LstmDescriptor.m_PeepholeEnabled`.
+ m_CellToForgetWeights (ConstTensor): \(W_{cf}\), cell-to-forget weight matrix. Has effect if `LstmDescriptor.m_PeepholeEnabled`.
+ m_CellToOutputWeights (ConstTensor): \(W_{co}\), cell-to-output weight matrix. Has effect if `LstmDescriptor.m_PeepholeEnabled`.
+
+ m_InputGateBias (ConstTensor): \(b_i\), input gate bias.
+ m_ForgetGateBias (ConstTensor): \(b_f\), forget gate bias.
+ m_CellBias (ConstTensor): \(b_c\), cell bias.
+ m_OutputGateBias (ConstTensor): \(b_o\), output gate bias.
+
+ m_ProjectionWeights (ConstTensor): \(W_{proj}\), projection weight matrix.
+ Has effect if `LstmDescriptor.m_ProjectionEnabled` is set to True.
+ m_ProjectionBias (ConstTensor): \(b_{proj}\), projection bias.
+ Has effect if `LstmDescriptor.m_ProjectionEnabled` is set to True.
+ m_InputLayerNormWeights (ConstTensor): normalisation weights for input,
+ has effect if `LstmDescriptor.m_LayerNormEnabled` set to True.
+ m_ForgetLayerNormWeights (ConstTensor): normalisation weights for forget gate,
+ has effect if `LstmDescriptor.m_LayerNormEnabled` set to True.
+ m_CellLayerNormWeights (ConstTensor): normalisation weights for current cell,
+ has effect if `LstmDescriptor.m_LayerNormEnabled` set to True.
+ m_OutputLayerNormWeights (ConstTensor): normalisation weights for output gate,
+ has effect if `LstmDescriptor.m_LayerNormEnabled` set to True.
+
+ ") LstmInputParams;
+struct LstmInputParams
+{
+ LstmInputParams();
+
+ const armnn::ConstTensor* m_InputToInputWeights;
+ const armnn::ConstTensor* m_InputToForgetWeights;
+ const armnn::ConstTensor* m_InputToCellWeights;
+ const armnn::ConstTensor* m_InputToOutputWeights;
+ const armnn::ConstTensor* m_RecurrentToInputWeights;
+ const armnn::ConstTensor* m_RecurrentToForgetWeights;
+ const armnn::ConstTensor* m_RecurrentToCellWeights;
+ const armnn::ConstTensor* m_RecurrentToOutputWeights;
+ const armnn::ConstTensor* m_CellToInputWeights;
+ const armnn::ConstTensor* m_CellToForgetWeights;
+ const armnn::ConstTensor* m_CellToOutputWeights;
+ const armnn::ConstTensor* m_InputGateBias;
+ const armnn::ConstTensor* m_ForgetGateBias;
+ const armnn::ConstTensor* m_CellBias;
+ const armnn::ConstTensor* m_OutputGateBias;
+ const armnn::ConstTensor* m_ProjectionWeights;
+ const armnn::ConstTensor* m_ProjectionBias;
+ const armnn::ConstTensor* m_InputLayerNormWeights;
+ const armnn::ConstTensor* m_ForgetLayerNormWeights;
+ const armnn::ConstTensor* m_CellLayerNormWeights;
+ const armnn::ConstTensor* m_OutputLayerNormWeights;
+};
+}
diff --git a/python/pyarmnn/src/pyarmnn/swig/modules/armnn_network.i b/python/pyarmnn/src/pyarmnn/swig/modules/armnn_network.i
new file mode 100644
index 0000000000..90454858da
--- /dev/null
+++ b/python/pyarmnn/src/pyarmnn/swig/modules/armnn_network.i
@@ -0,0 +1,1159 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+%{
+#include "armnn/INetwork.hpp"
+#include "armnn/BackendId.hpp"
+#include "armnn/Types.hpp"
+#include "armnn/Optional.hpp"
+#include <fstream>
+%}
+
+%include <typemaps/network_optimize.i>
+
+namespace armnn
+{
+%feature("docstring",
+"
+Struct for holding options relating to the Arm NN optimizer. See `Optimize`.
+
+Contains:
+    m_Debug (bool): Add debug data to the network for easier troubleshooting. Default: false.
+    m_ReduceFp32ToFp16 (bool): Reduce all Fp32 operators in the model to Fp16 for faster processing. Default: false.
+
+") OptimizerOptions;
+struct OptimizerOptions
+{
+ OptimizerOptions();
+
+ OptimizerOptions(bool reduceFp32ToFp16, bool debug);
+
+ bool m_ReduceFp32ToFp16;
+ bool m_Debug;
+};
+
+%feature("docstring",
+"
+An input connection slot for a layer. Slot lifecycle is managed by the layer.
+
+The input slot can be connected to an output slot of the preceding layer in the graph.
+Only one connection to the input slot is allowed.
+
+") IInputSlot;
+%nodefaultctor IInputSlot;
+%nodefaultdtor IInputSlot;
+class IInputSlot
+{
+public:
+ %feature("docstring",
+ "
+    Returns the output slot of the preceding layer that is connected to this input slot.
+
+ Returns:
+ IOutputSlot: Borrowed reference to an output connection slot for a preceding layer.
+
+ ") GetConnection;
+
+ armnn::IOutputSlot* GetConnection();
+};
+
+%feature("docstring",
+"
+An output connection slot for a layer. Slot lifecycle is managed by the layer.
+
+The output slot may be connected to 1 or more input slots of subsequent layers in the graph.
+") IOutputSlot;
+%nodefaultctor IOutputSlot;
+%nodefaultdtor IOutputSlot;
+class IOutputSlot
+{
+public:
+
+ %feature("docstring",
+ "
+ Returns the total number of connected input slots.
+
+ The same result could be obtained by calling `len()`:
+
+ >>> output_slot = ...
+ >>> size = len(output_slot)
+ >>> assert size == output_slot.GetNumConnections()
+
+ Returns:
+ int: Number of connected input slots.
+ ") GetNumConnections;
+ unsigned int GetNumConnections();
+
+
+ %feature("docstring",
+ "
+ Retrieves connected input slot by index.
+
+ The same result could be obtained by using square brackets:
+
+ >>> output_slot = ...
+ >>> connected_input_slot = output_slot[0]
+
+ Args:
+ index (int): Slot index.
+
+ Returns:
+ IInputSlot: Borrowed reference to connected input slot with given index.
+
+ Raises:
+ RuntimeError: If index out of bounds.
+ ") GetConnection;
+ armnn::IInputSlot* GetConnection(unsigned int index);
+
+ %feature("docstring",
+ "
+ Sets tensor info for output slot.
+ Operation does not change TensorInfo ownership.
+ Args:
+ tensorInfo (TensorInfo): Output tensor info.
+
+ ") SetTensorInfo;
+ void SetTensorInfo(const armnn::TensorInfo& tensorInfo);
+
+ %feature("docstring",
+ "
+ Gets tensor info for output slot.
+
+    Returns:
+        TensorInfo: Output tensor info.
+
+ ") GetTensorInfo;
+ const armnn::TensorInfo& GetTensorInfo();
+
+ %feature("docstring",
+ "
+ Checks if tensor info was set previously.
+
+ Returns:
+        bool: True if output tensor info was set, False otherwise.
+
+ ") IsTensorInfoSet;
+ bool IsTensorInfoSet();
+
+ %feature("docstring",
+ "
+ Connects this output slot with given input slot.
+ Input slot is updated with this output connection.
+
+ Args:
+        destination (IInputSlot): Input slot to connect to.
+
+ Returns:
+ int: Total number of connections.
+
+ Raises:
+ RuntimeError: If input slot was already connected.
+
+ ") Connect;
+ int Connect(IInputSlot& destination);
+
+ %feature("docstring",
+ "
+ Disconnects this output slot from given input slot.
+
+ Args:
+ slot (IInputSlot): Input slot to disconnect from.
+
+ ") Disconnect;
+ void Disconnect(IInputSlot& slot);
+
+ %feature("docstring",
+ "
+ Calculates the index of this slot for the layer.
+
+ Returns:
+ int: Slot index.
+
+ ") CalculateIndexOnOwner;
+ unsigned int CalculateIndexOnOwner();
+
+ %feature("docstring",
+ "
+    Returns the unique id of the layer that owns this slot. Same value as `IConnectableLayer.GetGuid()`.
+
+ Returns:
+ int: Layer id.
+
+ ") GetOwningLayerGuid;
+ unsigned int GetOwningLayerGuid();
+
+};
+
+%extend IOutputSlot {
+
+ armnn::IInputSlot* __getitem__(unsigned int index) {
+ return $self->GetConnection(index);
+ }
+
+ unsigned int __len__() const {
+ return $self->GetNumConnections();
+ }
+
+}
+
+%feature("docstring",
+"
+Interface for a layer that is connectable to other layers via `IInputSlot` and `IOutputSlot`.
+The object implementing this interface is returned by `INetwork` when calling `add*Layer` methods.
+
+") IConnectableLayer;
+%nodefaultctor IConnectableLayer;
+%nodefaultdtor IConnectableLayer;
+class IConnectableLayer
+{
+public:
+ %feature("docstring",
+ "
+    Returns the name of the layer. The name attribute is optional, so `None` may be returned.
+
+ Returns:
+ str: Layer name or `None`.
+
+ ") GetName;
+ const char* GetName();
+
+ %feature("docstring",
+ "
+ Gets the number of input slots for the layer.
+
+ Returns:
+ int: Number of input slots.
+
+ ") GetNumInputSlots;
+ unsigned int GetNumInputSlots();
+
+ %feature("docstring",
+ "
+ Gets the number of output slots for the layer.
+
+ Returns:
+ int: Number of output slots.
+
+ ") GetNumOutputSlots;
+ unsigned int GetNumOutputSlots();
+
+ %feature("docstring",
+ "
+ Gets the input slot by index.
+
+ Args:
+ index (int): Slot index.
+
+ Returns:
+ IInputSlot: Borrowed reference to input slot.
+
+ ") GetInputSlot;
+ armnn::IInputSlot& GetInputSlot(unsigned int index);
+
+ %feature("docstring",
+ "
+ Gets the output slot by index.
+
+ Args:
+ index (int): Slot index.
+
+ Returns:
+ IOutputSlot: Borrowed reference to output slot.
+
+ ") GetOutputSlot;
+ armnn::IOutputSlot& GetOutputSlot(unsigned int index);
+
+
+ %feature("docstring",
+ "
+ Gets the unique layer id (within one process).
+ Guid is generated and assigned automatically when the layer is created.
+
+ Returns:
+ int: The unique layer id.
+
+ ") GetGuid;
+ unsigned int GetGuid();
+};
+
+%feature("docstring",
+ "
+ Interface for a network object. Network objects contain the whole computation graph, made up of different layers connected together.
+
+ INetwork objects can be constructed manually or obtained by using parsers. INetwork objects are used to create optimized networks, see `Optimize`.
+
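+    A minimal graph-building sketch (assuming `import pyarmnn as ann`):
+
+    >>> network = ann.INetwork()
+    >>> input_layer = network.AddInputLayer(0)
+    >>> softmax_layer = network.AddSoftmaxLayer(ann.SoftmaxDescriptor())
+    >>> output_layer = network.AddOutputLayer(0)
+    >>> input_layer.GetOutputSlot(0).Connect(softmax_layer.GetInputSlot(0))
+    >>> softmax_layer.GetOutputSlot(0).Connect(output_layer.GetInputSlot(0))
+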
+ ") INetwork;
+%nodefaultctor INetwork;
+%nodefaultdtor INetwork;
+class INetwork
+{
+public:
+
+ %feature("docstring",
+ "
+ Adds an input layer to the network. Input layers are placed at the start of a network and used for feeding input data during inference.
+
+ Args:
+            id (int): User generated id to uniquely identify a particular input. The same id needs to be specified
+                      when passing the inputs to `IRuntime.EnqueueWorkload()`.
+ name (str): Optional name for the layer.
+
+ Returns:
+ IConnectableLayer: Interface for configuring the layer.
+ ") AddInputLayer;
+ armnn::IConnectableLayer* AddInputLayer(int id, const char* name = nullptr);
+
+ %feature("docstring",
+ "
+ Adds an addition layer to the network.
+
+ Args:
+ name (str): Optional name for the layer.
+
+ Returns:
+ IConnectableLayer: Interface for configuring the layer.
+ ") AddAdditionLayer;
+ armnn::IConnectableLayer* AddAdditionLayer(const char* name = nullptr);
+
+ %feature("docstring",
+ "
+        Adds an output layer to the network. Output layers are the final layers of a network, used to retrieve output data after inference.
+
+ Args:
+            id (int): User generated id to uniquely identify a particular output. The same id needs to be specified
+                      when passing the output tensors to `IRuntime.EnqueueWorkload()`.
+ name (str): Optional name for the layer.
+
+ Returns:
+ IConnectableLayer: Interface for configuring the layer.
+ ") AddOutputLayer;
+ armnn::IConnectableLayer* AddOutputLayer(int id, const char* name = nullptr);
+
+
+ %feature("docstring",
+ "
+        Adds an Absolute layer to the network. Calculates the absolute value of its input.
+
+ Args:
+ name (str): Optional name for the layer.
+
+ Returns:
+ IConnectableLayer: Interface for configuring the layer.
+ ") AddAbsLayer;
+ armnn::IConnectableLayer* AddAbsLayer(const char* name = nullptr);
+
+
+ %feature("docstring",
+ "
+ Adds an Activation layer to the network. Type of activation is decided by activationDescriptor.
+
+ Args:
+ activationDescriptor (ActivationDescriptor): ActivationDescriptor to configure the activation.
+ name (str): Optional name for the layer.
+
+ Returns:
+ IConnectableLayer: Interface for configuring the layer.
+ ") AddActivationLayer;
+ armnn::IConnectableLayer* AddActivationLayer(const ActivationDescriptor& activationDescriptor,
+ const char* name = nullptr);
+
+
+ %feature("docstring",
+ "
+ Adds an ArgMinMax layer to the network.
+
+ Args:
+ desc (ArgMinMaxDescriptor): Parameters for the ArgMinMax layer.
+ name (str): Optional name for the layer.
+
+ Returns:
+ IConnectableLayer: Interface for configuring the layer.
+ ") AddArgMinMaxLayer;
+ armnn::IConnectableLayer* AddArgMinMaxLayer(const armnn::ArgMinMaxDescriptor& desc,
+ const char* name = nullptr);
+
+
+ %feature("docstring",
+ "
+ Adds a Batch Normalization layer to the network.
+
+        Args:
+            desc (BatchNormalizationDescriptor): Parameters to configure the batch normalization layer.
+            mean (ConstTensor): Pre-calculated mean for each channel.
+ variance (ConstTensor): Pre-calculated variance for each channel.
+ beta (ConstTensor): Per-channel additive factor.
+ gamma (ConstTensor): Per-channel multiplicative factor.
+ name (str): Optional name for the layer.
+
+ Returns:
+ IConnectableLayer: Interface for configuring the layer.
+ ") AddBatchNormalizationLayer;
+ armnn::IConnectableLayer* AddBatchNormalizationLayer(const armnn::BatchNormalizationDescriptor& desc,
+ const armnn::ConstTensor& mean,
+ const armnn::ConstTensor& variance,
+ const armnn::ConstTensor& beta,
+ const armnn::ConstTensor& gamma,
+ const char* name = nullptr);
+
+
+ %feature("docstring",
+ "
+ Adds a Batch To Space ND layer to the network.
+
+ Args:
+ batchToSpaceNdDescriptor (BatchToSpaceNdDescriptor): Configuration parameters for the layer.
+ name (str): Optional name for the layer.
+
+ Returns:
+ IConnectableLayer: Interface for configuring the layer.
+ ") AddBatchToSpaceNdLayer;
+ armnn::IConnectableLayer* AddBatchToSpaceNdLayer(const armnn::BatchToSpaceNdDescriptor& batchToSpaceNdDescriptor,
+ const char* name = nullptr);
+
+ %feature("docstring",
+ "
+ Adds a Comparison layer to the network.
+
+ Args:
+ comparisonDescriptor (ComparisonDescriptor): Configuration parameters for the layer.
+ name (str): Optional name for the layer.
+
+ Returns:
+ IConnectableLayer: Interface for configuring the layer.
+ ") AddComparisonLayer;
+ armnn::IConnectableLayer* AddComparisonLayer(const armnn::ComparisonDescriptor& comparisonDescriptor,
+ const char* name = nullptr);
+
+ %feature("docstring",
+ "
+ Adds a Concatenation layer to the network.
+
+ Args:
+ concatDescriptor (ConcatDescriptor): Parameters to configure the Concatenation layer.
+ name (str): Optional name for the layer.
+
+ Returns:
+ IConnectableLayer: Interface for configuring the layer.
+ ") AddConcatLayer;
+ armnn::IConnectableLayer* AddConcatLayer(const armnn::ConcatDescriptor& concatDescriptor,
+ const char* name = nullptr);
+
+
+ %feature("docstring",
+ "
+ Adds a layer with no inputs and a single output, which always corresponds to the passed in constant tensor.
+
+ Args:
+ input (ConstTensor): Tensor to be provided as the only output of the layer. The layer will maintain
+ its own copy of the tensor data, meaning the memory referenced by input can
+ be freed or reused after this function is called.
+ name (str): Optional name for the layer.
+
+ Returns:
+ IConnectableLayer: Interface for configuring the layer.
+ ") AddConstantLayer;
+ armnn::IConnectableLayer* AddConstantLayer(const armnn::ConstTensor& input,
+ const char* name = nullptr);
+
+ %feature("docstring",
+ "
+ Adds a Depth To Space layer to the network.
+
+ Args:
+ depthToSpaceDescriptor (DepthToSpaceDescriptor): Parameters for the depth to space operation.
+ name (str): Optional name for the layer.
+
+ Returns:
+ IConnectableLayer: Interface for configuring the layer.
+ ") AddDepthToSpaceLayer;
+ armnn::IConnectableLayer* AddDepthToSpaceLayer(const armnn::DepthToSpaceDescriptor& depthToSpaceDescriptor,
+ const char* name = nullptr);
+
+ %feature("docstring",
+ "
+ Adds a Dequantize layer to the network.
+
+ Args:
+ name (str): Optional name for the layer.
+
+ Returns:
+ IConnectableLayer: Interface for configuring the layer.
+ ") AddDequantizeLayer;
+ armnn::IConnectableLayer* AddDequantizeLayer(const char* name = nullptr);
+
+
+ %feature("docstring",
+ "
+ Adds a Detection PostProcess layer to the network. Detection PostProcess is a custom layer for SSD MobilenetV1.
+
+ Args:
+ descriptor (DetectionPostProcessDescriptor): Description of the Detection PostProcess layer.
+ anchors (ConstTensor): Tensor for anchors.
+ name (str): Optional name for the layer.
+
+ Returns:
+ IConnectableLayer: Interface for configuring the layer.
+ ") AddDetectionPostProcessLayer;
+ armnn::IConnectableLayer* AddDetectionPostProcessLayer(
+ const armnn::DetectionPostProcessDescriptor& descriptor,
+ const armnn::ConstTensor& anchors,
+ const char* name = nullptr);
+
+
+ %feature("docstring",
+ "
+ Adds a Division layer to the network.
+
+ Args:
+ name (str): Optional name for the layer.
+
+ Returns:
+ IConnectableLayer: Interface for configuring the layer.
+ ") AddDivisionLayer;
+ armnn::IConnectableLayer* AddDivisionLayer(const char* name = nullptr);
+
+ %feature("docstring",
+ "
+ Adds a Floor layer to the network.
+
+ Args:
+ name (str): Optional name for the layer.
+
+ Returns:
+ IConnectableLayer: Interface for configuring the layer.
+ ") AddFloorLayer;
+ armnn::IConnectableLayer* AddFloorLayer(const char* name = nullptr);
+
+ %feature("docstring",
+ "
+        Adds a Gather layer to the network.
+
+ Args:
+ name (str): Optional name for the layer.
+
+ Returns:
+ IConnectableLayer: Interface for configuring the layer.
+ ") AddGatherLayer;
+ armnn::IConnectableLayer* AddGatherLayer(const char* name = nullptr);
+
+ %feature("docstring",
+ "
+ Adds an Instance Normalization layer to the network.
+
+ Args:
+ desc (InstanceNormalizationDescriptor): Parameters for the instance normalization operation.
+ name (str): Optional name for the layer.
+
+ Returns:
+ IConnectableLayer: Interface for configuring the layer.
+ ") AddInstanceNormalizationLayer;
+ armnn::IConnectableLayer* AddInstanceNormalizationLayer(const armnn::InstanceNormalizationDescriptor& desc,
+ const char* name = nullptr);
+
+ %feature("docstring",
+ "
+ Adds a Log Softmax layer to the network.
+
+ Args:
+            logSoftmaxDescriptor (LogSoftmaxDescriptor): Parameters to configure the log softmax layer.
+ name (str): Optional name for the layer.
+
+ Returns:
+ IConnectableLayer: Interface for configuring the layer.
+ ") AddLogSoftmaxLayer;
+ armnn::IConnectableLayer* AddLogSoftmaxLayer(const armnn::LogSoftmaxDescriptor& logSoftmaxDescriptor,
+ const char* name = nullptr);
+
+ %feature("docstring",
+ "
+ Adds an L2 Normalization layer to the network.
+ Normalization is performed along dimension 1, but requires a 4d input.
+
+ Args:
+ desc (L2NormalizationDescriptor): Parameters for the L2 normalization operation.
+ name (str): Optional name for the layer.
+
+ Returns:
+ IConnectableLayer: Interface for configuring the layer.
+ ") AddL2NormalizationLayer;
+ armnn::IConnectableLayer* AddL2NormalizationLayer(const armnn::L2NormalizationDescriptor& desc,
+ const char* name = nullptr);
+
+ %feature("docstring",
+ "
+        Adds a Long Short-Term Memory layer to the network.
+
+ Args:
+ descriptor (LstmDescriptor): Parameters for the Lstm operation.
+ params (LstmInputParams): Weights and biases for the LSTM cell.
+ name (str): Optional name for the layer.
+
+ Returns:
+ IConnectableLayer: Interface for configuring the layer.
+ ") AddLstmLayer;
+ armnn::IConnectableLayer* AddLstmLayer(const armnn::LstmDescriptor& descriptor,
+ const armnn::LstmInputParams& params,
+ const char* name = nullptr);
+
+ %feature("docstring",
+ "
+        Adds a Maximum layer to the network.
+
+ Args:
+ name (str): Optional name for the layer.
+
+ Returns:
+ IConnectableLayer: Interface for configuring the layer.
+ ") AddMaximumLayer;
+ armnn::IConnectableLayer* AddMaximumLayer(const char* name = nullptr);
+
+ %feature("docstring",
+ "
+ Adds a Mean layer to the network.
+
+ Args:
+            meanDescriptor (MeanDescriptor): Parameters for the mean operation.
+ name (str): Optional name for the layer.
+
+ Returns:
+ IConnectableLayer: Interface for configuring the layer.
+ ") AddMeanLayer;
+ armnn::IConnectableLayer* AddMeanLayer(const armnn::MeanDescriptor& meanDescriptor, const char* name = nullptr);
+
+ %feature("docstring",
+ "
+ Adds a Merge layer to the network.
+
+ Args:
+ name (str): Optional name for the layer.
+
+ Returns:
+ IConnectableLayer: Interface for configuring the layer.
+ ") AddMergeLayer;
+ armnn::IConnectableLayer* AddMergeLayer(const char* name = nullptr);
+
+ %feature("docstring",
+ "
+ Adds a Minimum layer to the network.
+
+ Args:
+ name (str): Optional name for the layer.
+
+ Returns:
+ IConnectableLayer: Interface for configuring the layer.
+ ") AddMinimumLayer;
+ armnn::IConnectableLayer* AddMinimumLayer(const char* name = nullptr);
+
+ %feature("docstring",
+ "
+ Adds a Multiplication layer to the network.
+
+ Args:
+ name (str): Optional name for the layer.
+
+ Returns:
+ IConnectableLayer: Interface for configuring the layer.
+ ") AddMultiplicationLayer;
+ armnn::IConnectableLayer* AddMultiplicationLayer(const char* name = nullptr);
+
+ %feature("docstring",
+ "
+ Adds a Normalization layer to the network.
+
+ Args:
+ normalizationDescriptor (NormalizationDescriptor): Parameters to configure the normalization.
+ name (str): Optional name for the layer.
+
+ Returns:
+ IConnectableLayer: Interface for configuring the layer.
+ ") AddNormalizationLayer;
+ armnn::IConnectableLayer* AddNormalizationLayer(const armnn::NormalizationDescriptor& normalizationDescriptor,
+ const char* name = nullptr);
+
+ %feature("docstring",
+ "
+ Adds a Pad layer to the network.
+
+ Args:
+ padDescriptor (PadDescriptor): Padding configuration for the layer. See `PadDescriptor` for more details.
+ name (str): Optional name for the layer.
+
+ Returns:
+ IConnectableLayer: Interface for configuring the layer.
+ ") AddPadLayer;
+ armnn::IConnectableLayer* AddPadLayer(const armnn::PadDescriptor& padDescriptor,
+ const char* name = nullptr);
+
+ %feature("docstring",
+ "
+ Adds a Permute layer to the network.
+
+ Args:
+ permuteDescriptor (PermuteDescriptor): Configuration of the permutation layer.
+ name (str): Optional name for the layer.
+
+ Returns:
+ IConnectableLayer: Interface for configuring the layer.
+ ") AddPermuteLayer;
+ armnn::IConnectableLayer* AddPermuteLayer(const armnn::PermuteDescriptor& permuteDescriptor,
+ const char* name = nullptr);
+
+ %feature("docstring",
+ "
+ Adds a Pooling layer to the network. Type of pooling is decided by the configuration.
+
+ Args:
+ pooling2dDescriptor (Pooling2dDescriptor): Configuration for the pooling layer.
+ name (str): Optional name for the layer.
+
+ Returns:
+ IConnectableLayer: Interface for configuring the layer.
+ ") AddPooling2dLayer;
+ armnn::IConnectableLayer* AddPooling2dLayer(const armnn::Pooling2dDescriptor& pooling2dDescriptor,
+ const char* name = nullptr);
+
+ %feature("docstring",
+ "
+ Adds a PReLU layer to the network.
+
+ Args:
+ name (str): Optional name for the layer.
+
+ Returns:
+ IConnectableLayer: Interface for configuring the layer.
+ ") AddPreluLayer;
+ armnn::IConnectableLayer* AddPreluLayer(const char* name = nullptr);
+
+ %feature("docstring",
+ "
+ Adds a Quantize layer to the network.
+
+ Args:
+ name (str): Optional name for the layer.
+
+ Returns:
+ IConnectableLayer: Interface for configuring the layer.
+ ") AddQuantizeLayer;
+ armnn::IConnectableLayer* AddQuantizeLayer(const char* name = nullptr);
+
+ %feature("docstring",
+ "
+ Adds a Quantized Long Short-Term Memory layer to the network.
+
+ Args:
+ params (QuantizedLstmInputParams): The weights and biases for the Quantized LSTM cell.
+ name (str): Optional name for the layer.
+
+ Returns:
+ IConnectableLayer: Interface for configuring the layer.
+ ") AddQuantizedLstmLayer;
+ armnn::IConnectableLayer* AddQuantizedLstmLayer(const armnn::QuantizedLstmInputParams& params,
+ const char* name = nullptr);
+
+ %feature("docstring",
+ "
+ Adds a Reshape layer to the network.
+
+ Args:
+ reshapeDescriptor (ReshapeDescriptor): Parameters for the reshape operation.
+ name (str): Optional name for the layer.
+
+ Returns:
+ IConnectableLayer: Interface for configuring the layer.
+ ") AddReshapeLayer;
+ armnn::IConnectableLayer* AddReshapeLayer(const armnn::ReshapeDescriptor& reshapeDescriptor,
+ const char* name = nullptr);
+
+ %feature("docstring",
+ "
+ Adds a Resize layer to the network.
+
+ Args:
+ resizeDescriptor (ResizeDescriptor): Configuration for the resize layer.
+ name (str): Optional name for the layer.
+
+ Returns:
+ IConnectableLayer: Interface for configuring the layer.
+ ") AddResizeLayer;
+ armnn::IConnectableLayer* AddResizeLayer(const armnn::ResizeDescriptor& resizeDescriptor,
+ const char* name = nullptr);
+
+ %feature("docstring",
+ "
+        Adds a Reciprocal of Square Root (Rsqrt) layer to the network.
+
+ Args:
+ name (str): Optional name for the layer.
+
+ Returns:
+ IConnectableLayer: Interface for configuring the layer.
+ ") AddRsqrtLayer;
+ armnn::IConnectableLayer* AddRsqrtLayer(const char* name = nullptr);
+
+ %feature("docstring",
+ "
+ Adds a Slice layer to the network.
+
+ Args:
+ sliceDescriptor (SliceDescriptor): Descriptor to configure the slice operation.
+ name (str): Optional name for the layer.
+
+ Returns:
+ IConnectableLayer: Interface for configuring the layer.
+ ") AddSliceLayer;
+ armnn::IConnectableLayer* AddSliceLayer(const armnn::SliceDescriptor& sliceDescriptor,
+ const char* name = nullptr);
+
+ %feature("docstring",
+ "
+ Adds a Softmax layer to the network.
+
+ If the data type is `DataType_QuantisedAsymm8`, then the output quantization parameters
+ must have a scale of 1/256 and an offset of 0.
+
+ Args:
+ softmaxDescriptor (SoftmaxDescriptor): Configuration for the softmax layer.
+ name (str): Optional name for the layer.
+
+ Returns:
+ IConnectableLayer: Interface for configuring the layer.
+ ") AddSoftmaxLayer;
+ armnn::IConnectableLayer* AddSoftmaxLayer(const armnn::SoftmaxDescriptor& softmaxDescriptor,
+ const char* name = nullptr);
+
+ %feature("docstring",
+ "
+ Adds a Space To Batch layer to the network.
+
+ Args:
+ spaceToBatchNdDescriptor (SpaceToBatchNdDescriptor): Configuration for the space to batch layer.
+ name (str): Optional name for the layer.
+
+ Returns:
+ IConnectableLayer: Interface for configuring the layer.
+ ") AddSpaceToBatchNdLayer;
+ armnn::IConnectableLayer* AddSpaceToBatchNdLayer(const armnn::SpaceToBatchNdDescriptor& spaceToBatchNdDescriptor,
+ const char* name = nullptr);
+
+ %feature("docstring",
+ "
+ Adds a space to depth layer to the network.
+
+ Args:
+ spaceToDepthDescriptor (SpaceToDepthDescriptor): Parameters for the space to depth operation.
+ name (str): Optional name for the layer.
+
+ Returns:
+ IConnectableLayer: Interface for configuring the layer.
+ ") AddSpaceToDepthLayer;
+ armnn::IConnectableLayer* AddSpaceToDepthLayer(const armnn::SpaceToDepthDescriptor& spaceToDepthDescriptor,
+ const char* name = nullptr);
+
+ %feature("docstring",
+ "
+ Adds a Splitter layer to the network.
+
+ Args:
+ splitterDescriptor (SplitterDescriptor): Parameters to configure the splitter layer.
+ name (str): Optional name for the layer.
+
+ Returns:
+ IConnectableLayer: Interface for configuring the layer.
+ ") AddSplitterLayer;
+ armnn::IConnectableLayer* AddSplitterLayer(const armnn::SplitterDescriptor& splitterDescriptor,
+ const char* name = nullptr);
+
+ %feature("docstring",
+ "
+ Adds a Stack layer to the network.
+
+ Args:
+ descriptor (StackDescriptor): Descriptor to configure the stack layer.
+ name (str): Optional name for the layer.
+
+ Returns:
+ IConnectableLayer: Interface for configuring the layer.
+ ") AddStackLayer;
+ armnn::IConnectableLayer* AddStackLayer(const armnn::StackDescriptor& descriptor,
+ const char* name = nullptr);
+
+ %feature("docstring",
+ "
+ Adds a StandIn layer to the network.
+
+ Args:
+ descriptor (StandInDescriptor): Parameters to configure the standIn layer.
+ name (str): Optional name for the layer.
+
+ Returns:
+ IConnectableLayer: Interface for configuring the layer.
+ ") AddStandInLayer;
+ armnn::IConnectableLayer* AddStandInLayer(const armnn::StandInDescriptor& descriptor,
+ const char* name = nullptr);
+
+ %feature("docstring",
+ "
+ Adds a Strided Slice layer to the network.
+
+ Args:
+ stridedSliceDescriptor (StridedSliceDescriptor): Parameters for the strided slice operation.
+ name (str): Optional name for the layer.
+
+ Returns:
+ IConnectableLayer: Interface for configuring the layer.
+ ") AddStridedSliceLayer;
+ armnn::IConnectableLayer* AddStridedSliceLayer(const armnn::StridedSliceDescriptor& stridedSliceDescriptor,
+ const char* name = nullptr);
+
+ %feature("docstring",
+ "
+ Adds a Subtraction layer to the network.
+
+ Args:
+ name (str): Optional name for the layer.
+
+ Returns:
+ IConnectableLayer: Interface for configuring the layer.
+ ") AddSubtractionLayer;
+ armnn::IConnectableLayer* AddSubtractionLayer(const char* name = nullptr);
+
+ %feature("docstring",
+ "
+ Adds a Switch layer to the network.
+
+ Args:
+ name (str): Optional name for the layer.
+
+ Returns:
+ IConnectableLayer: Interface for configuring the layer.
+ ") AddSwitchLayer;
+ armnn::IConnectableLayer* AddSwitchLayer(const char* name = nullptr);
+
+};
+
+%extend INetwork {
+
+ INetwork() {
+ return armnn::INetwork::CreateRaw();
+ }
+
+ ~INetwork() {
+ armnn::INetwork::Destroy($self);
+ }
+
+ %feature("docstring",
+ "
+ Adds a Fully Connected layer to the network. Also known as a Linear or Dense layer.
+
+ Args:
+ fullyConnectedDescriptor (FullyConnectedDescriptor): Description of the fully connected layer.
+ weights (ConstTensor): Tensor for the weights data.
+ biases (ConstTensor): Optional tensor for the bias data.
+ name (str): Optional name for the layer.
+
+ Returns:
+ IConnectableLayer: Interface for configuring the layer.
+ ") AddFullyConnectedLayer;
+ armnn::IConnectableLayer* AddFullyConnectedLayer(const armnn::FullyConnectedDescriptor& fullyConnectedDescriptor,
+ const armnn::ConstTensor& weights,
+ armnn::ConstTensor* biases = nullptr,
+ const char* name = nullptr) {
+
+ if (biases) {
+ return $self->AddFullyConnectedLayer(fullyConnectedDescriptor, weights,
+ armnn::Optional<armnn::ConstTensor>(*biases), name);
+ } else {
+ return $self->AddFullyConnectedLayer(fullyConnectedDescriptor, weights,
+ armnn::Optional<armnn::ConstTensor>(), name);
+ }
+
+ }
+
+ %feature("docstring",
+ "
+ Adds a 2D Transpose Convolution layer to the network.
+
+ Args:
+ descriptor (TransposeConvolution2dDescriptor): Descriptor containing all parameters to configure this layer.
+ weights (ConstTensor): Tensor for the weights data.
+ biases (ConstTensor): Optional tensor for the bias data.
+ name (str): Optional name for the layer.
+
+ Returns:
+ IConnectableLayer: Interface for configuring the layer.
+ ") AddTransposeConvolution2dLayer;
+ armnn::IConnectableLayer* AddTransposeConvolution2dLayer(const armnn::TransposeConvolution2dDescriptor& descriptor,
+ const armnn::ConstTensor& weights,
+ armnn::ConstTensor* biases = nullptr,
+ const char* name = nullptr){
+
+ if (biases) {
+ return $self->AddTransposeConvolution2dLayer(descriptor, weights,
+ armnn::Optional<armnn::ConstTensor>(*biases), name);
+ } else {
+ return $self->AddTransposeConvolution2dLayer(descriptor, weights,
+ armnn::Optional<armnn::ConstTensor>(), name);
+ }
+ }
+
+
+ %feature("docstring",
+ "
+ Adds a 2D Convolution layer to the network.
+
+ Args:
+ convolution2dDescriptor (Convolution2dDescriptor): Description of the 2D convolution layer.
+ weights (ConstTensor): Tensor for the weights data.
+ biases (ConstTensor): Optional tensor for the bias data. If specified, must match the output tensor shape.
+ name (str): Optional name for the layer.
+
+ Returns:
+ IConnectableLayer: Interface for configuring the layer.
+ ") AddConvolution2dLayer;
+ armnn::IConnectableLayer* AddConvolution2dLayer(const armnn::Convolution2dDescriptor& convolution2dDescriptor,
+ const armnn::ConstTensor& weights,
+ armnn::ConstTensor* biases = nullptr,
+ const char* name = nullptr) {
+
+ if (biases) {
+ return $self->AddConvolution2dLayer(convolution2dDescriptor, weights,
+ armnn::Optional<armnn::ConstTensor>(*biases), name);
+ } else {
+ return $self->AddConvolution2dLayer(convolution2dDescriptor, weights,
+ armnn::Optional<armnn::ConstTensor>(), name);
+ }
+ }
+
+ %feature("docstring",
+ "
+ Adds a 2D Depthwise Convolution layer to the network.
+
+ Args:
+ convolution2dDescriptor (DepthwiseConvolution2dDescriptor): Description of the 2D depthwise convolution layer.
+ weights (ConstTensor): Tensor for the weights. Expected format: [channelMultiplier, inputChannels, height, width].
+ biases (ConstTensor): Optional tensor for the bias data. If specified, must match the output tensor shape.
+ name (str): Optional name for the layer.
+
+ Returns:
+ IConnectableLayer: Interface for configuring the layer.
+ ") AddDepthwiseConvolution2dLayer;
+
+ armnn::IConnectableLayer* AddDepthwiseConvolution2dLayer(
+ const armnn::DepthwiseConvolution2dDescriptor& convolution2dDescriptor,
+ const armnn::ConstTensor& weights,
+ const armnn::ConstTensor* biases = nullptr,
+ const char* name = nullptr) {
+
+ if (biases) {
+ return $self->AddDepthwiseConvolution2dLayer(convolution2dDescriptor, weights,
+ armnn::Optional<armnn::ConstTensor>(*biases), name);
+ } else {
+ return $self->AddDepthwiseConvolution2dLayer(convolution2dDescriptor, weights,
+ armnn::Optional<armnn::ConstTensor>(), name);
+ }
+ }
+}
+
+%feature("docstring",
+ "
+ Interface class for an optimzied network object. Optimized networks are obtained after running `Optimize` on
+ an `INetwork` object.
+ Optimized networks are passed to `EnqueueWorkload`.
+
+ Args:
+ convolution2dDescriptor (DepthwiseConvolution2dDescriptor): Description of the 2D depthwise convolution layer.
+ weights (ConstTensor): Tensor for the weights. Expected format: [channelMultiplier, inputChannels, height, width].
+ biases (ConstTensor): Optional tensor for the bias data. If specified, must match the output tensor shape.
+ name (str): Optional name for the layer.
+
+ Returns:
+ IConnectableLayer: Interface for configuring the layer.
+ ") IOptimizedNetwork;
+%nodefaultctor IOptimizedNetwork;
+%nodefaultdtor IOptimizedNetwork;
+class IOptimizedNetwork
+{
+};
+
+%extend IOptimizedNetwork {
+
+ ~IOptimizedNetwork() {
+ armnn::IOptimizedNetwork::Destroy($self);
+ }
+
+ %feature("docstring",
+ "
+ Saves optimized network graph as dot file.
+
+ Args:
+ fileName (str): File path to save to.
+ Raises:
+ RuntimeError: If serialization fails.
+ ") SerializeToDot;
+
+ void SerializeToDot(const std::string& fileName) {
+ std::ofstream dot;
+ dot.open(fileName);
+ if(!dot.is_open())
+ {
+ throw armnn::Exception("Failed to open dot file");
+ } else {
+ armnn::Status status = $self->SerializeToDot(dot);
+ dot.close();
+ if(status == armnn::Status::Failure)
+ {
+ throw armnn::Exception("Failed to serialize to dot");
+ }
+ }
+ };
+}
+}
+
+%{
+ std::pair<armnn::IOptimizedNetwork*, std::vector<std::string>> Optimize(const armnn::INetwork* network,
+ const std::vector<armnn::BackendId>& backendPreferences,
+ const armnn::IDeviceSpec& deviceSpec,
+ const armnn::OptimizerOptions& options = armnn::OptimizerOptions())
+ {
+ std::vector<std::string> errorMessages;
+ armnn::IOptimizedNetwork* optimizedNetwork = armnn::Optimize(*network, backendPreferences, deviceSpec,
+ options, armnn::Optional<std::vector<std::string> &>(errorMessages)).release();
+
+ if(!optimizedNetwork)
+ {
+ std::string errorString;
+
+ for (auto error : errorMessages) {
+ errorString.append(error);
+ }
+
+ throw armnn::Exception(errorString);
+ }
+
+ return std::make_pair(optimizedNetwork, errorMessages);
+ };
+%}
+
+%feature("docstring",
+ "
+ Create an optimized version of the given network.
+ Args:
+ network (INetwork): INetwork description of the network to be optimized.
+ backendPreferences (list): The choice of the backend ordered by user preferences. See `BackendId`.
+ deviceSpec (IDeviceSpec): DeviceSpec object as queried from the runtime. See `IRuntime.GetDeviceSpec`.
+ options (OptimizerOptions): Object with optimizer configuration options.
+
+ Returns:
+ tuple: (`IOptimizedNetwork`, a tuple of failure or warning messages).
+
+ Raises:
+ RuntimeError: If the optimization process fails.
+ ") Optimize;
+
+%optimize_typemap_out;
+std::pair<armnn::IOptimizedNetwork*, std::vector<std::string>> Optimize(const armnn::INetwork* network,
+ const std::vector<armnn::BackendId>& backendPreferences,
+ const armnn::IDeviceSpec& deviceSpec,
+ const armnn::OptimizerOptions& options = OptimizerOptions());
+%clear_optimize_typemap_out;
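To see how the wrapped network-building and Optimize calls compose from Python, here is a minimal sketch. It assumes `import pyarmnn as ann`, the `AddInputLayer` wrapper declared earlier in this interface, and the descriptor bindings from armnn_descriptors.i; names such as "conv" are illustrative only.

    import numpy as np
    import pyarmnn as ann

    network = ann.INetwork()
    network.AddInputLayer(0, "input")

    # Biases may be omitted (None), which takes the empty armnn::Optional branch
    # of the %extend wrapper above.
    weights_info = ann.TensorInfo(ann.TensorShape((1, 1, 3, 3)), ann.DataType_Float32)
    weights = ann.ConstTensor(weights_info, np.ones((1, 1, 3, 3), dtype=np.float32))
    network.AddConvolution2dLayer(ann.Convolution2dDescriptor(), weights, None, "conv")

    # Optimize returns (IOptimizedNetwork, tuple of messages) via %optimize_typemap_out;
    # a failed optimization raises RuntimeError instead.
    runtime = ann.IRuntime(ann.CreationOptions())
    preferred_backends = [ann.BackendId('CpuRef')]
    opt_network, messages = ann.Optimize(network, preferred_backends,
                                         runtime.GetDeviceSpec(), ann.OptimizerOptions())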
diff --git a/python/pyarmnn/src/pyarmnn/swig/modules/armnn_profiler.i b/python/pyarmnn/src/pyarmnn/swig/modules/armnn_profiler.i
new file mode 100644
index 0000000000..929a7a0006
--- /dev/null
+++ b/python/pyarmnn/src/pyarmnn/swig/modules/armnn_profiler.i
@@ -0,0 +1,82 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+%{
+#include "armnn/IProfiler.hpp"
+%}
+
+namespace armnn
+{
+
+%feature("docstring",
+"
+Interface for profiling Arm NN. See `IRuntime.GetProfiler`.
+
+IProfiler object allows you to enable profiling and get various profiling results.
+
+") IProfiler;
+%nodefaultctor IProfiler;
+%nodefaultdtor IProfiler;
+class IProfiler
+{
+public:
+
+ %feature("docstring",
+ "
+ Sets the profiler to start/stop profiling.
+
+ Args:
+ enableProfiling (bool): Flag to enable/disable profiling.
+
+ ") EnableProfiling;
+
+ void EnableProfiling(bool enableProfiling);
+
+ %feature("docstring",
+ "
+ Checks if profiling is enabled.
+
+ Returns:
+ bool: If profiling is enabled or not.
+
+ ") IsProfilingEnabled;
+
+ bool IsProfilingEnabled();
+};
+
+%extend IProfiler {
+
+ %feature("docstring",
+ "
+ Gets the string value of the profiling events analysis log.
+
+ Returns:
+ str: The profiling events analysis log.
+
+ ") event_log;
+
+ std::string event_log()
+ {
+ std::ostringstream oss;
+ $self->AnalyzeEventsAndWriteResults(oss);
+ return oss.str();
+ }
+
+ %feature("docstring",
+ "
+ Gets the profiling log as a JSON formatted string.
+
+ Returns:
+ str: Profiling log as JSON formatted string.
+
+ ") as_json;
+
+ std::string as_json()
+ {
+ std::ostringstream oss;
+ $self->Print(oss);
+ return oss.str();
+ }
+}
+}
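A short usage sketch for the profiler wrappers above, assuming a runtime and a loaded network id obtained as in armnn_runtime.i:

    # Profiling must be switched on before the workload is enqueued.
    profiler = runtime.GetProfiler(net_id)
    profiler.EnableProfiling(True)

    runtime.EnqueueWorkload(net_id, input_tensors, output_tensors)

    assert profiler.IsProfilingEnabled()
    print(profiler.event_log())      # human-readable events analysis
    json_text = profiler.as_json()   # JSON string, e.g. for json.loads()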
diff --git a/python/pyarmnn/src/pyarmnn/swig/modules/armnn_runtime.i b/python/pyarmnn/src/pyarmnn/swig/modules/armnn_runtime.i
new file mode 100644
index 0000000000..bbeda51d89
--- /dev/null
+++ b/python/pyarmnn/src/pyarmnn/swig/modules/armnn_runtime.i
@@ -0,0 +1,254 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+%{
+#include "armnn/IRuntime.hpp"
+#include <iostream>
+#include <ostream>
+#include <sstream>
+%}
+
+namespace std {
+ %template() pair<int, string>;
+ %template(IntPair) pair<int, int>;
+ %template(ConstTensorPair) pair<int, armnn::ConstTensor>;
+ %template(TensorPair) pair<int, armnn::Tensor>;
+
+ %template(InputTensorsVector) vector<pair<int, armnn::ConstTensor>>;
+ %template(OutputTensorsVector) vector<pair<int, armnn::Tensor>>;
+}
+
+%include <std_shared_ptr.i>
+
+%shared_ptr(IGpuAccTunedParameters);
+
+#pragma SWIG nowarn=SWIGWARN_PARSE_NESTED_CLASS
+
+%{
+typedef armnn::IRuntime::CreationOptions CreationOptions;
+%}
+
+struct CreationOptions
+{
+ %feature("docstring",
+ "
+ Structure for holding creation options. For the majority of cases it is fine to leave the values at their defaults.
+
+ Contains:
+ m_GpuAccTunedParameters (IGpuAccTunedParameters): If set, uses the GpuAcc tuned parameters from the given object
+ when executing GPU workloads. It will also be updated with new
+ tuned parameters if it is configured to do so.
+
+ m_EnableGpuProfiling (bool): Setting this flag will allow the user to obtain GPU profiling information from
+ the runtime.
+
+ m_DynamicBackendsPath (string): Setting this value will override the paths set by the DYNAMIC_BACKEND_PATHS
+ compiler directive. Only a single path is allowed for the override.
+
+ ") CreationOptions;
+
+ CreationOptions();
+ std::shared_ptr<armnn::IGpuAccTunedParameters> m_GpuAccTunedParameters;
+ bool m_EnableGpuProfiling;
+ std::string m_DynamicBackendsPath;
+};
+
+namespace armnn
+{
+
+struct INetworkProperties
+{
+ %feature("docstring",
+ "
+ Structure for holding network properties.
+
+ Contains:
+ m_ImportEnabled (bool): Enable memory import.
+
+ m_ExportEnabled (bool): Enable memory export.
+
+ ") INetworkProperties;
+ INetworkProperties(bool importEnabled = false, bool exportEnabled = false);
+
+ const bool m_ImportEnabled;
+ const bool m_ExportEnabled;
+};
+
+%feature("docstring",
+"
+Interface for runtime objects.
+
+Runtime objects are responsible for performing inference on an `IOptimizedNetwork`.
+
+Args:
+ options (CreationOptions): CreationOptions data struct.
+
+") IRuntime;
+%nodefaultctor IRuntime;
+class IRuntime
+{
+public:
+
+ %ignore
+ armnn::IRuntime::UnloadNetwork(NetworkId networkId);
+
+ %ignore
+ armnn::IRuntime::EnqueueWorkload(NetworkId networkId,
+ const std::vector<std::pair<int, armnn::ConstTensor>>& inputTensors,
+ const std::vector<std::pair<int, armnn::Tensor>>& outputTensors);
+
+ %feature("docstring",
+ "
+ Get information relating to a network's input tensor.
+
+ Args:
+ networkId (int): Unique ID of the network being run.
+ layerId (int): Unique ID of the input layer.
+
+ Returns:
+ TensorInfo: Information relating to the input tensor of a network.
+ ") GetInputTensorInfo;
+ armnn::TensorInfo GetInputTensorInfo(int networkId, int layerId);
+
+ %feature("docstring",
+ "
+ Get information relating to a network's output tensor.
+
+ Args:
+ networkId (int): Unique ID of the network being run.
+ layerId (int): Unique ID of the output layer.
+
+ Returns:
+ TensorInfo: Information relating to the output tensor of a network.
+ ") GetOutputTensorInfo;
+ armnn::TensorInfo GetOutputTensorInfo(int networkId, int layerId);
+
+ %feature("docstring",
+ "
+ Get information relating to the supported compute backends on the current device.
+
+ Returns:
+ IDeviceSpec: Device spec information detailing all supported backends on current platform.
+ ") GetDeviceSpec;
+ const IDeviceSpec& GetDeviceSpec();
+};
+
+%extend IRuntime {
+ //tell python to disown the IOptimizedNetwork pointer
+ //because IRuntime takes ownership
+ %typemap(in) armnn::IOptimizedNetwork* {
+ if (!SWIG_IsOK(SWIG_ConvertPtr($input, (void **) &$1, $1_descriptor, SWIG_POINTER_DISOWN))) {
+ SWIG_exception_fail(SWIG_TypeError, "in method '$symname', argument 2 of type armnn::IOptimizedNetwork*");
+ }
+ }
+
+ %feature("docstring",
+ "
+ Loads a complete network into the IRuntime.
+ The runtime takes ownership of the network once passed in.
+ Args:
+ network (IOptimizedNetwork): An optimized network to load into the IRuntime.
+ networkProperties (INetworkProperties): Properties that allow the user to opt in to import/export behavior. Default: None.
+ Returns:
+ tuple: (int, str) Network id and non-fatal failure or warning messages.
+ Raises:
+ RuntimeError: If the process fails.
+ ") LoadNetwork;
+
+ std::pair<int, std::string> LoadNetwork(armnn::IOptimizedNetwork* network,
+ const INetworkProperties* networkProperties = nullptr)
+ {
+ armnn::IOptimizedNetworkPtr netPtr(network, &armnn::IOptimizedNetwork::Destroy);
+ armnn::NetworkId networkIdOut;
+ std::string errorString;
+ armnn::Status status;
+
+ if (networkProperties) {
+ status = $self->LoadNetwork(networkIdOut, std::move(netPtr), errorString, *networkProperties);
+ } else {
+ status = $self->LoadNetwork(networkIdOut, std::move(netPtr), errorString);
+ }
+
+ if(status == armnn::Status::Failure)
+ {
+ throw armnn::Exception(errorString);
+ }
+
+ auto net_id_int = static_cast<int>(networkIdOut);
+ return std::make_pair(net_id_int, errorString);
+ };
+
+ %typemap(in) armnn::IOptimizedNetwork*;
+ %feature("docstring",
+ "
+ Calling this function will perform an inference on your network.
+
+ Args:
+ networkId (int): Unique ID of the network to run.
+ inputTensors (list): A list of tuples (int, ConstTensor), see `make_input_tensors`.
+ outputTensors (list): A list of tuples (int, Tensor), see `make_output_tensors`.
+
+ ") EnqueueWorkload;
+ void EnqueueWorkload(int networkId, const std::vector<std::pair<int, armnn::ConstTensor>>& inputTensors,
+ const std::vector<std::pair<int, armnn::Tensor>>& outputTensors) {
+ armnn::Status status = $self->EnqueueWorkload(networkId, inputTensors, outputTensors);
+
+ if(status == armnn::Status::Failure)
+ {
+ throw armnn::Exception("Failed to enqueue workload for network.");
+ }
+ };
+
+ %feature("docstring",
+ "
+ Unload a currently loaded network from the runtime.
+
+ Args:
+ networkId (int): Unique ID of the network to unload.
+
+ ") UnloadNetwork;
+ void UnloadNetwork(int networkId) {
+ armnn::Status status = $self->UnloadNetwork(networkId);
+ if(status == armnn::Status::Failure)
+ {
+ throw armnn::Exception("Failed to unload network.");
+ }
+ };
+
+ %feature("docstring",
+ "
+ Returns the IProfiler instance registered against the working thread, and stored on the loaded network.
+ Be aware that if the runtime has unloaded the network, or if the runtime is destroyed,
+ the IProfiler instance will also be destroyed; using it after that will cause a segmentation fault.
+
+ Args:
+ networkId (int): The ID of the loaded network you want to profile.
+
+ Returns:
+ IProfiler: IProfiler instance the given loaded network has stored.
+
+ Raises:
+ RuntimeError: If no profiler is found.
+ ") GetProfiler;
+
+ armnn::IProfiler* GetProfiler(int networkId) {
+ std::shared_ptr<armnn::IProfiler> profiler = $self->GetProfiler(networkId);
+ if (nullptr == profiler) {
+ throw armnn::Exception("Failed to get profiler");
+ }
+ return profiler.get();
+ };
+
+ ~IRuntime() {
+ armnn::IRuntime::Destroy($self);
+ }
+
+ IRuntime(const CreationOptions& options) {
+ return armnn::IRuntime::CreateRaw(options);
+ }
+
+}
+
+}
+
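End to end, the runtime wrappers above are used roughly as follows. This is a sketch that assumes an `opt_network` from `Optimize` and the `make_input_tensors`/`make_output_tensors` helpers from pyarmnn's workload_tensors.py; `input_binding_info`/`output_binding_info` would come from one of the parsers.

    runtime = ann.IRuntime(ann.CreationOptions())

    # LoadNetwork takes ownership of the optimized network (note the
    # SWIG_POINTER_DISOWN typemap) and returns (network id, warning string).
    net_id, _ = runtime.LoadNetwork(opt_network)

    input_tensors = ann.make_input_tensors([input_binding_info], [input_data])
    output_tensors = ann.make_output_tensors([output_binding_info])
    runtime.EnqueueWorkload(net_id, input_tensors, output_tensors)

    runtime.UnloadNetwork(net_id)   # raises RuntimeError on failure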
diff --git a/python/pyarmnn/src/pyarmnn/swig/modules/armnn_tensor.i b/python/pyarmnn/src/pyarmnn/swig/modules/armnn_tensor.i
new file mode 100644
index 0000000000..efa9a16352
--- /dev/null
+++ b/python/pyarmnn/src/pyarmnn/swig/modules/armnn_tensor.i
@@ -0,0 +1,313 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+%{
+#include "armnn/Tensor.hpp"
+%}
+
+%include <typemaps/tensor_memory.i>
+%include <typemaps/tensor_shape.i>
+
+namespace armnn
+{
+
+%feature("docstring",
+"
+Class for holding the shape information of an Arm NN tensor.
+
+") TensorShape;
+class TensorShape
+{
+public:
+ %tensor_shape_typemap(unsigned int numDimensions, const unsigned int* dimensionSizes);
+ TensorShape(unsigned int numDimensions, const unsigned int* dimensionSizes);
+ %clear_tensor_shape_typemap(unsigned int numDimensions, const unsigned int* dimensionSizes);
+
+ %feature("docstring",
+ "
+ Returns the number of dimensions in this TensorShape.
+
+ Returns:
+ int: The number of dimensions in this TensorShape.
+
+ ") GetNumDimensions;
+ unsigned int GetNumDimensions() const;
+
+ %feature("docstring",
+ "
+ Returns the total number of elements for a tensor with this TensorShape.
+
+ Returns:
+ int: The total number of elements for a tensor with this TensorShape.
+
+ ") GetNumElements;
+ unsigned int GetNumElements() const;
+
+};
+
+%extend TensorShape {
+
+ unsigned int __getitem__(unsigned int i) const {
+ return $self->operator[](i);
+ }
+ void __setitem__(unsigned int i, unsigned int val) {
+ $self->operator[](i) = val;
+ }
+
+ std::string __str__() {
+ std::string dim = "NumDimensions: " + std::to_string($self->GetNumDimensions());
+ std::string elm = "NumElements: " + std::to_string($self->GetNumElements());
+
+ std::string shapeStr = "TensorShape{Shape(";
+
+ auto numDimensions = $self->GetNumDimensions();
+ for (unsigned int i = 0; i < numDimensions; i++) {
+ shapeStr += std::to_string($self->operator[](i));
+
+ // append a separator after every dimension except the last
+ if (i != numDimensions - 1) {
+ shapeStr += ", ";
+ }
+ }
+ shapeStr = shapeStr + "), " + dim + ", " + elm + "}";
+ return shapeStr;
+ }
+
+}
+
+
+%feature("docstring",
+"
+Class for holding the tensor information of an Arm NN tensor such as quantization, datatype, shape etc.
+
+") TensorInfo;
+class TensorInfo
+{
+public:
+ TensorInfo();
+
+ TensorInfo(const TensorInfo& other);
+
+ TensorInfo(const TensorShape& shape, DataType dataType,
+ float quantizationScale = 0.0f, int32_t quantizationOffset = 0);
+
+ %feature("docstring",
+ "
+ Get the tensor shape.
+
+ Returns:
+ TensorShape: Current shape of the tensor.
+
+ ") GetShape;
+ TensorShape& GetShape();
+
+ %feature("docstring",
+ "
+ Set the tensor shape. The new shape must have the same number of elements as the current tensor.
+
+ Args:
+ newShape (TensorShape): New tensor shape to reshape to.
+
+ ") SetShape;
+ void SetShape(const TensorShape& newShape);
+
+ %feature("docstring",
+ "
+ Returns the number of dimensions in this Tensor.
+
+ Returns:
+ int: The number of dimensions in this Tensor.
+
+ ") GetNumDimensions;
+ unsigned int GetNumDimensions() const;
+
+ %feature("docstring",
+ "
+ Returns the total number of elements for this Tensor.
+
+ Returns:
+ int: The total number of elements for this Tensor.
+
+ ") GetNumElements;
+ unsigned int GetNumElements() const;
+
+ %feature("docstring",
+ "
+ Get the tensor datatype.
+
+ Returns:
+ DataType: Current tensor DataType.
+
+ ") GetDataType;
+ DataType GetDataType() const;
+
+ %feature("docstring",
+ "
+ Set the tensor datatype.
+
+ Args:
+ type (DataType): DataType to set the tensor to.
+
+ ") SetDataType;
+ void SetDataType(DataType type);
+
+ %feature("docstring",
+ "
+ Get the value of the tensor's quantization scale.
+
+ Returns:
+ float: Tensor quantization scale value.
+
+ ") GetQuantizationScale;
+ float GetQuantizationScale() const;
+
+ %feature("docstring",
+ "
+ Get the value of the tensor's quantization offset.
+
+ Returns:
+ int: Tensor quantization offset value.
+
+ ") GetQuantizationOffset;
+ int32_t GetQuantizationOffset() const;
+
+ %feature("docstring",
+ "
+ Set the value of the tensor's quantization scale.
+
+ Args:
+ scale (float): Scale value to set.
+
+ ") SetQuantizationScale;
+ void SetQuantizationScale(float scale);
+
+ %feature("docstring",
+ "
+ Set the value of the tensor's quantization offset.
+
+ Args:
+ offset (int): Offset value to set.
+
+ ") SetQuantizationOffset;
+ void SetQuantizationOffset(int32_t offset);
+
+ %feature("docstring",
+ "
+ Returns true if the tensor is a quantized data type.
+
+ Returns:
+ bool: True if the tensor is a quantized data type.
+
+ ") IsQuantized;
+ bool IsQuantized() const;
+
+
+
+ %feature("docstring",
+ "
+ Check that the types are the same and, if quantized, that the quantization parameters are the same.
+
+ Returns:
+ bool: True if matched, else False.
+
+ ") IsTypeSpaceMatch;
+ bool IsTypeSpaceMatch(const TensorInfo& other) const;
+
+ %feature("docstring",
+ "
+ Get the number of bytes needed for this tensor.
+
+ Returns:
+ int: Number of bytes consumed by this tensor.
+
+ ") GetNumBytes;
+ unsigned int GetNumBytes() const;
+
+};
+
+%extend TensorInfo {
+
+ std::string __str__() {
+ const std::string tmp = "TensorInfo{DataType: " + std::to_string(static_cast<int>($self->GetDataType()))
+ + ", IsQuantized: " + std::to_string($self->IsQuantized())
+ + ", QuantizationScale: " + std::to_string( $self->GetQuantizationScale())
+ + ", QuantizationOffset: " + std::to_string($self->GetQuantizationOffset())
+ + ", NumDimensions: " + std::to_string($self->GetNumDimensions())
+ + ", NumElements: " + std::to_string($self->GetNumElements()) + "}";
+ return tmp;
+ }
+
+}
+
+class Tensor
+{
+public:
+ ~Tensor();
+ Tensor();
+ Tensor(const Tensor& other);
+
+ %mutable_memory(void* memory);
+ Tensor(const TensorInfo& info, void* memory);
+ %clear_mutable_memory(void* memory);
+
+ const TensorInfo& GetInfo() const;
+ const TensorShape& GetShape() const;
+
+ DataType GetDataType() const;
+ unsigned int GetNumDimensions() const;
+ unsigned int GetNumBytes() const;
+ unsigned int GetNumElements() const;
+
+ /* we want to disable getting the memory area from here - forcing use of get_memory_area() in public api.
+ void* GetMemoryArea() const;*/
+};
+
+%extend Tensor {
+
+ std::string __str__() {
+ const std::string tmp = "Tensor{DataType: " + std::to_string(static_cast<int>($self->GetDataType()))
+ + ", NumBytes: " + std::to_string($self->GetNumBytes())
+ + ", NumDimensions: " + std::to_string( $self->GetNumDimensions())
+ + ", NumElements: " + std::to_string($self->GetNumElements()) + "}";
+ return tmp;
+ }
+}
+
+class ConstTensor
+{
+public:
+ ~ConstTensor();
+ ConstTensor();
+ ConstTensor(const Tensor& other);
+ ConstTensor(const ConstTensor& other);
+
+ %const_memory(const void* memory);
+ ConstTensor(const TensorInfo& info, const void* memory);
+ %clear_const_memory(const void* memory);
+
+ const TensorInfo& GetInfo() const;
+ const TensorShape& GetShape() const;
+
+ DataType GetDataType() const;
+ unsigned int GetNumDimensions() const;
+ unsigned int GetNumBytes() const;
+ unsigned int GetNumElements() const;
+
+ /* we want to disable getting the memory area from here - forcing use of get_memory_area() in public api.
+ void* GetMemoryArea() const;*/
+};
+
+%extend ConstTensor {
+
+ std::string __str__() {
+ const std::string tmp = "ConstTensor{DataType: " + std::to_string(static_cast<int>($self->GetDataType()))
+ + ", NumBytes: " + std::to_string($self->GetNumBytes())
+ + ", NumDimensions: " + std::to_string( $self->GetNumDimensions())
+ + ", NumElements: " + std::to_string($self->GetNumElements()) + "}";
+ return tmp;
+ }
+}
+
+}
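The tensor classes above are typically exercised with NumPy buffers; a sketch, assuming `import pyarmnn as ann` and `import numpy as np`:

    shape = ann.TensorShape((1, 2, 2))                   # built from a tuple, see tensor_shape.i
    info = ann.TensorInfo(shape, ann.DataType_Float32)
    print(info.GetNumElements(), info.GetNumBytes())     # 4 16

    data = np.arange(4, dtype=np.float32).reshape(1, 2, 2)
    const_tensor = ann.ConstTensor(info, data)           # read-only view of the buffer

    out = np.zeros((1, 2, 2), dtype=np.float32)
    tensor = ann.Tensor(info, out)                       # writable view, filled by EnqueueWorkload
    print(tensor)                                        # Tensor{DataType: ..., NumBytes: 16, ...}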
diff --git a/python/pyarmnn/src/pyarmnn/swig/modules/armnn_types.i b/python/pyarmnn/src/pyarmnn/swig/modules/armnn_types.i
new file mode 100644
index 0000000000..50afda9fd3
--- /dev/null
+++ b/python/pyarmnn/src/pyarmnn/swig/modules/armnn_types.i
@@ -0,0 +1,136 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+%{
+#include "armnn/Types.hpp"
+%}
+
+%include <typemaps/permutation_vector.i>
+
+
+namespace armnn
+{
+
+%feature("docstring",
+"
+Vector used to permute a tensor.
+
+For a 4-d tensor laid out in memory with the format (Batch Element, Height, Width, Channels),
+which is to be passed as an input to Arm NN, each source dimension is mapped to the corresponding
+Arm NN dimension. The Batch dimension remains the same (0 -> 0). The source Height dimension is mapped
+to the location of the Arm NN Height dimension (1 -> 2). Similar arguments are made for the Width and
+Channels (2 -> 3 and 3 -> 1). This will lead to m_DimMappings pointing to the following array:
+[ 0, 2, 3, 1 ].
+
+Note that the mapping should be reversed if considering the case of Arm NN 4-d outputs (Batch Element,
+Channels, Height, Width) being written to a destination with the format mentioned above. We now have
+0 -> 0, 2 -> 1, 3 -> 2, 1 -> 3, which, when reordered, lead to the following m_DimMappings contents:
+[ 0, 3, 1, 2 ].
+
+Args:
+ dimMappings (list): Indicates how to translate tensor elements from a given source into the target destination,
+ when source and target potentially have different memory layouts.
+") PermutationVector;
+
+class PermutationVector
+{
+public:
+ using ValueType = unsigned int;
+ using SizeType = unsigned int;
+
+ %permutation_vector_typemap(const ValueType *dimMappings, SizeType numDimMappings);
+ PermutationVector(const ValueType *dimMappings, SizeType numDimMappings);
+ %clear_permutation_vector_typemap(const ValueType *dimMappings, SizeType numDimMappings);
+
+
+ %feature("docstring",
+ "
+ Get the PermutationVector size.
+
+ Returns:
+ SizeType: Current size of the PermutationVector.
+
+ ") GetSize;
+ SizeType GetSize();
+
+ %feature("docstring",
+ "
+ Checks if the specified permutation vector is the inverse of this one.
+
+ Returns:
+ bool: True if the specified permutation vector is the inverse of this one.
+
+ ") IsInverse;
+ bool IsInverse(const PermutationVector& other);
+};
+
+%extend PermutationVector {
+
+ unsigned int __getitem__(unsigned int i) const {
+ return $self->operator[](i);
+ }
+
+ bool __eq__(PermutationVector other) {
+ int size = $self->GetSize();
+ int otherSize = other.GetSize();
+ if(size != otherSize)
+ {
+ return false;
+ }
+ // element-wise comparison; only report equality once every mapping has matched
+ for(int i = 0; i < size; ++i){
+ if($self->operator[](i) != other[i])
+ {
+ return false;
+ }
+ }
+ return true;
+ }
+}
+
+}
+%feature("docstring",
+"
+Interface for device specifications. Its main use is to query which compute backends the device in use supports.
+") IDeviceSpec;
+
+
+%feature("docstring",
+"
+Returns the backends supported by this compute device.
+
+Returns:
+ set: The backends supported by this device.
+
+") GetSupportedBackends;
+
+%ignore ProfilingGuid;
+%ignore PermutationVector;
+%include "armnn/Types.hpp"
+
+%extend armnn::IDeviceSpec {
+
+
+ std::string __str__() {
+
+ std::string deviceStr = "IDeviceSpec { supportedBackends: [";
+
+ auto backends = $self->GetSupportedBackends();
+ auto remaining = backends.size();
+ for (auto it = backends.begin(); it != backends.end(); ++it) {
+ deviceStr += it->Get();
+
+ // append a separator after every backend except the last
+ if (--remaining > 0) {
+ deviceStr += ", ";
+ }
+ }
+ deviceStr = deviceStr + "]}";
+
+ return deviceStr;
+ }
+
+}
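A sketch of the PermutationVector and IDeviceSpec wrappers in use; the two mappings are the NHWC example and its inverse from the docstring above:

    perm = ann.PermutationVector((0, 2, 3, 1))
    print(perm.GetSize())                                         # 4
    print(perm.IsInverse(ann.PermutationVector((0, 3, 1, 2))))    # True
    print(perm == ann.PermutationVector((0, 2, 3, 1)))            # True, element-wise __eq__

    # IDeviceSpec instances come from the runtime; __str__ lists supported backends.
    print(runtime.GetDeviceSpec())           # IDeviceSpec { supportedBackends: [...]}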
diff --git a/python/pyarmnn/src/pyarmnn/swig/modules/armnn_types_utils.i b/python/pyarmnn/src/pyarmnn/swig/modules/armnn_types_utils.i
new file mode 100644
index 0000000000..c11d9927c9
--- /dev/null
+++ b/python/pyarmnn/src/pyarmnn/swig/modules/armnn_types_utils.i
@@ -0,0 +1,26 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+%{
+#include "armnn/TypesUtils.hpp"
+%}
+
+namespace armnn
+{
+ constexpr unsigned int GetDataTypeSize(DataType dataType);
+
+ constexpr const char* GetDataTypeName(DataType dataType);
+
+ template<typename QuantizedType>
+ QuantizedType Quantize(float value, float scale, int32_t offset);
+ %template(Quantize_uint8_t) Quantize<uint8_t>;
+ %template(Quantize_int16_t) Quantize<int16_t>;
+ %template(Quantize_int32_t) Quantize<int32_t>;
+
+ template <typename QuantizedType>
+ float Dequantize(QuantizedType value, float scale, int32_t offset);
+ %template(Dequantize_uint8_t) Dequantize<uint8_t>;
+ %template(Dequantize_int16_t) Dequantize<int16_t>;
+ %template(Dequantize_int32_t) Dequantize<int32_t>;
+}
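The %template lines expose one Quantize/Dequantize function per target type; a quick sketch of the round trip:

    q = ann.Quantize_uint8_t(0.5, 0.02, 128)    # round(0.5 / 0.02) + 128 = 153
    r = ann.Dequantize_uint8_t(q, 0.02, 128)    # (153 - 128) * 0.02 = 0.5

    print(ann.GetDataTypeName(ann.DataType_Float32))    # Float32
    print(ann.GetDataTypeSize(ann.DataType_Float32))    # 4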
diff --git a/python/pyarmnn/src/pyarmnn/swig/standard_header.i b/python/pyarmnn/src/pyarmnn/swig/standard_header.i
new file mode 100644
index 0000000000..c412dc3bff
--- /dev/null
+++ b/python/pyarmnn/src/pyarmnn/swig/standard_header.i
@@ -0,0 +1,53 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+%include "stl.i"
+%include "cstring.i"
+%include "std_string.i"
+%include "std_vector.i"
+%include "std_unordered_set.i"
+%include "std_pair.i"
+%include "stdint.i"
+%include "carrays.i"
+%include "exception.i"
+%include "typemaps.i"
+%include "std_iostream.i"
+
+%ignore *::operator=;
+%ignore *::operator[];
+
+
+// Define exception typemap to wrap armnn exception into python exception.
+
+%exception{
+ try {
+ $action
+ } catch (armnn::Exception &e) {
+ SWIG_exception(SWIG_RuntimeError, const_cast<char*>(e.what()));
+ }
+};
+
+%exception __getitem__ {
+ try {
+ $action
+ } catch (armnn::InvalidArgumentException &e) {
+ SWIG_exception(SWIG_ValueError, const_cast<char*>(e.what()));
+ } catch (const std::out_of_range &e) {
+ SWIG_exception(SWIG_IndexError, const_cast<char*>(e.what()));
+ } catch (const std::exception &e) {
+ SWIG_exception(SWIG_RuntimeError, const_cast<char*>(e.what()));
+ }
+};
+
+%exception __setitem__ {
+ try {
+ $action
+ } catch (armnn::InvalidArgumentException &e) {
+ SWIG_exception(SWIG_ValueError, const_cast<char*>(e.what()));
+ } catch (const std::out_of_range &e) {
+ SWIG_exception(SWIG_IndexError, const_cast<char*>(e.what()));
+ } catch (const std::exception &e) {
+ SWIG_exception(SWIG_RuntimeError, const_cast<char*>(e.what()));
+ }
+};
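These handlers determine which Python exception a caller sees: armnn::Exception surfaces as RuntimeError, while the indexing dunders additionally map armnn::InvalidArgumentException to ValueError and std::out_of_range to IndexError. A sketch:

    try:
        runtime.UnloadNetwork(12345)             # unknown id -> armnn::Exception
    except RuntimeError as e:
        print(e)

    shape = ann.TensorShape((1, 2))
    try:
        shape[5] = 1                             # out-of-range __setitem__
    except (ValueError, IndexError, RuntimeError) as e:
        print(e)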
diff --git a/python/pyarmnn/src/pyarmnn/swig/typemaps/network_optimize.i b/python/pyarmnn/src/pyarmnn/swig/typemaps/network_optimize.i
new file mode 100644
index 0000000000..05df82bdd1
--- /dev/null
+++ b/python/pyarmnn/src/pyarmnn/swig/typemaps/network_optimize.i
@@ -0,0 +1,41 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+%define %optimize_typemap_out
+ %typemap(out) (std::pair<armnn::IOptimizedNetwork*, std::vector<std::string>>) {
+ PyObject * network = SWIG_NewPointerObj(SWIG_as_voidptr($1.first), SWIGTYPE_p_armnn__IOptimizedNetwork, SWIG_POINTER_OWN);
+ $result = PyTuple_New(2);
+
+ // Convert vector to fixed-size tuple
+ std::vector<std::string> strings = $1.second;
+ Py_ssize_t size = strings.size();
+
+ // New reference. Need to Py_DECREF
+ PyObject* errMsgTuple = PyTuple_New(size);
+
+ if (!errMsgTuple) {
+ Py_XDECREF(errMsgTuple);
+ return PyErr_NoMemory();
+ }
+
+ for (Py_ssize_t i = 0; i < size; i++) {
+ // New reference. Need to Py_DECREF
+ PyObject *string = PyString_FromString(strings[i].c_str());
+
+ if (!string) {
+ Py_XDECREF(string);
+ return PyErr_NoMemory();
+ }
+ PyTuple_SetItem(errMsgTuple, i, string);
+ }
+
+ // Create result tuple
+ PyTuple_SetItem($result, 0, network);
+ PyTuple_SetItem($result, 1, errMsgTuple);
+ }
+%enddef
+
+%define %clear_optimize_typemap_out
+ %typemap(out) (std::pair<armnn::IOptimizedNetwork*, std::vector<std::string>>)
+%enddef
diff --git a/python/pyarmnn/src/pyarmnn/swig/typemaps/permutation_vector.i b/python/pyarmnn/src/pyarmnn/swig/typemaps/permutation_vector.i
new file mode 100644
index 0000000000..daa9663fb1
--- /dev/null
+++ b/python/pyarmnn/src/pyarmnn/swig/typemaps/permutation_vector.i
@@ -0,0 +1,52 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+%define %permutation_vector_typemap(TYPE1, TYPE2)
+ %typemap(in) (TYPE1, TYPE2) {
+ if (PyTuple_Check($input)) {
+ PyObject* seq = $input;
+
+ $2 = PySequence_Fast_GET_SIZE(seq);
+ $1 = (unsigned int*)PyMem_RawMalloc($2*sizeof(unsigned int));
+
+
+ if(!$1) {
+ PyErr_NoMemory();
+ SWIG_fail;
+ }
+ int size = (int)$2;
+ for(int i=0; i < size; i++) {
+ PyObject *longItem;
+ // Borrowed reference. No need to Py_DECREF
+ PyObject *item = PySequence_Fast_GET_ITEM(seq, i);
+ if(!item) {
+ PyErr_SetString(PyExc_TypeError, "Failed to read data from tuple");
+ SWIG_fail;
+ }
+ // New reference. Need to Py_DECREF
+ longItem = PyNumber_Long(item);
+ if(!longItem) {
+ Py_XDECREF(longItem);
+ PyErr_SetString(PyExc_TypeError, "All elements must be numbers");
+ SWIG_fail;
+ }
+ $1[i] = (unsigned int)PyLong_AsUnsignedLong(longItem);
+ Py_XDECREF(longItem);
+ }
+
+ } else {
+ PyErr_SetString(PyExc_TypeError, "Argument is not a tuple");
+ SWIG_fail;
+ }
+ }
+
+ %typemap(freearg) (TYPE1, TYPE2) {
+ PyMem_RawFree($1);
+ }
+%enddef
+
+%define %clear_permutation_vector_typemap(TYPE1, TYPE2)
+ %typemap(in) (TYPE1, TYPE2);
+ %typemap(freearg) (TYPE1, TYPE2);
+%enddef
diff --git a/python/pyarmnn/src/pyarmnn/swig/typemaps/tensor_memory.i b/python/pyarmnn/src/pyarmnn/swig/typemaps/tensor_memory.i
new file mode 100644
index 0000000000..de38a63b97
--- /dev/null
+++ b/python/pyarmnn/src/pyarmnn/swig/typemaps/tensor_memory.i
@@ -0,0 +1,52 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+%define %mutable_memory(TYPEMAP)
+ %typemap(in) (TYPEMAP) {
+ int res; void *buf = 0;
+ Py_buffer view;
+ res = PyObject_GetBuffer($input, &view, PyBUF_WRITABLE);
+ if (res < 0) {
+ // 'view' is undefined when PyObject_GetBuffer fails, so bail out before reading it
+ PyErr_Clear();
+ %argument_fail(res, "(TYPEMAP)", $symname, $argnum);
+ }
+ buf = view.buf;
+ PyBuffer_Release(&view);
+ $1 = buf;
+ }
+
+ %typemap(typecheck) (TYPEMAP) {
+ $1 = PyObject_CheckBuffer($input) || PyTuple_Check($input) ? 1 : 0;
+ }
+%enddef
+
+%define %clear_mutable_memory(TYPEMAP)
+ %typemap(in) (TYPEMAP);
+ %typemap(typecheck) (TYPEMAP);
+%enddef
+
+%define %const_memory(TYPEMAP)
+ %typemap(in) (TYPEMAP) {
+ int res; void *buf = 0;
+ Py_buffer view;
+ res = PyObject_GetBuffer($input, &view, PyBUF_CONTIG_RO);
+ if (res < 0) {
+ // 'view' is undefined when PyObject_GetBuffer fails, so bail out before reading it
+ PyErr_Clear();
+ %argument_fail(res, "(TYPEMAP)", $symname, $argnum);
+ }
+ buf = view.buf;
+ PyBuffer_Release(&view);
+ $1 = buf;
+ }
+
+ %typemap(typecheck) (TYPEMAP) {
+ $1 = PyObject_CheckBuffer($input) || PyTuple_Check($input) ? 1 : 0;
+ }
+%enddef
+
+%define %clear_const_memory(TYPEMAP)
+ %typemap(in) (TYPEMAP);
+ %typemap(typecheck) (TYPEMAP);
+%enddef
+
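Both macros accept any object implementing the buffer protocol; they differ only in the access they request. A sketch of the practical consequence:

    data = np.arange(4, dtype=np.float32)
    info = ann.TensorInfo(ann.TensorShape((4,)), ann.DataType_Float32)

    ann.ConstTensor(info, data)     # ok: PyBUF_CONTIG_RO, read access suffices
    ann.Tensor(info, data)          # ok: ndarray buffers are writable

    ro = bytes(16)                  # bytes exposes a read-only buffer
    ann.ConstTensor(info, ro)       # ok
    # ann.Tensor(info, ro)          # fails: PyBUF_WRITABLE rejects read-only buffers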
diff --git a/python/pyarmnn/src/pyarmnn/swig/typemaps/tensor_shape.i b/python/pyarmnn/src/pyarmnn/swig/typemaps/tensor_shape.i
new file mode 100644
index 0000000000..3e7c98f4c6
--- /dev/null
+++ b/python/pyarmnn/src/pyarmnn/swig/typemaps/tensor_shape.i
@@ -0,0 +1,51 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+%define %tensor_shape_typemap(TYPE1, TYPE2)
+ %typemap(in) (TYPE1, TYPE2) {
+ if (PyTuple_Check($input)) {
+ PyObject* seq = $input;
+
+ $1 = PySequence_Fast_GET_SIZE(seq);
+ $2 = (unsigned int*)PyMem_RawMalloc($1*sizeof(unsigned int));
+
+ if(!$2) {
+ PyErr_NoMemory();
+ SWIG_fail;
+ }
+ int size = (int)$1;
+ for(int i=0; i < size; i++) {
+ PyObject *longItem;
+ // Borrowed reference. No need to Py_DECREF
+ PyObject *item = PySequence_Fast_GET_ITEM(seq, i);
+ if(!item) {
+ PyErr_SetString(PyExc_TypeError, "Failed to read data from tuple");
+ SWIG_fail;
+ }
+ // New reference. Need to Py_DECREF
+ longItem = PyNumber_Long(item);
+ if(!longItem) {
+ Py_XDECREF(longItem);
+ PyErr_SetString(PyExc_TypeError, "All elements must be numbers");
+ SWIG_fail;
+ }
+ $2[i] = (unsigned int)PyLong_AsUnsignedLong(longItem);
+ Py_XDECREF(longItem);
+ }
+
+ } else {
+ PyErr_SetString(PyExc_TypeError, "Argument is not a tuple");
+ SWIG_fail;
+ }
+ }
+
+ %typemap(freearg) (TYPE1, TYPE2) {
+ PyMem_RawFree($2);
+ }
+%enddef
+
+%define %clear_tensor_shape_typemap(TYPE1, TYPE2)
+ %typemap(in) (TYPE1, TYPE2);
+ %typemap(freearg) (TYPE1, TYPE2);
+%enddef
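Note the in-typemap insists on a tuple rather than any sequence:

    ann.TensorShape((1, 2, 2))      # ok, dimensions unpacked from the tuple
    ann.TensorShape([1, 2, 2])      # TypeError: Argument is not a tuple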
diff --git a/python/pyarmnn/src/pyarmnn/swig/typemaps/vectors.i b/python/pyarmnn/src/pyarmnn/swig/typemaps/vectors.i
new file mode 100644
index 0000000000..1566bb0c3b
--- /dev/null
+++ b/python/pyarmnn/src/pyarmnn/swig/typemaps/vectors.i
@@ -0,0 +1,235 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+%inline %{
+//-------------------------from_python_to_cpp-----------------------------
+ int from_python_to_cpp(PyObject *obj, long* val) {
+ return SWIG_AsVal_long(obj, val);
+ }
+
+ int from_python_to_cpp(PyObject *obj, int* val) {
+ return SWIG_AsVal_int(obj, val);
+ }
+
+ int from_python_to_cpp(PyObject *obj, unsigned int* val) {
+ return SWIG_AsVal_unsigned_SS_int(obj, val);
+ }
+
+ int from_python_to_cpp(PyObject *obj, unsigned short* val) {
+ return SWIG_AsVal_unsigned_SS_short(obj, val);
+ }
+
+ int from_python_to_cpp(PyObject *obj, float* val) {
+ return SWIG_AsVal_float(obj, val);
+ }
+
+ int from_python_to_cpp(PyObject *obj, double* val) {
+ return SWIG_AsVal_double(obj, val);
+ }
+#ifdef SWIG_LONG_LONG_AVAILABLE
+ int from_python_to_cpp(PyObject *obj, unsigned long long* val) {
+ return SWIG_AsVal_unsigned_SS_long_SS_long(obj, val);
+ }
+
+ int from_python_to_cpp(PyObject *obj, long long* val) {
+ return SWIG_AsVal_long_SS_long(obj, val);
+ }
+#endif
+
+ int from_python_to_cpp(PyObject *obj, unsigned long* val) {
+ return SWIG_AsVal_unsigned_SS_long(obj, val);
+ }
+
+ int from_python_to_cpp(PyObject *obj, short* val) {
+ return SWIG_AsVal_short(obj, val);
+ }
+//-------------------------from_cpp_to_python-----------------------------
+ PyObject* from_cpp_to_python(long& val){
+ return PyLong_FromLong(val);
+ }
+
+ PyObject* from_cpp_to_python(unsigned long& val){
+ return PyLong_FromUnsignedLong(val);
+ }
+#ifdef SWIG_LONG_LONG_AVAILABLE
+ PyObject* from_cpp_to_python(long long& val){
+ return PyLong_FromLongLong(val);
+ }
+
+ PyObject* from_cpp_to_python(unsigned long long& val){
+ return PyLong_FromUnsignedLongLong(val);
+ }
+#endif
+
+ PyObject* from_cpp_to_python(int& val){
+ return PyLong_FromLong(static_cast<long>(val));
+ }
+
+ PyObject* from_cpp_to_python(unsigned int& val){
+ return PyLong_FromUnsignedLong(static_cast<unsigned long>(val));
+ }
+
+ PyObject* from_cpp_to_python(unsigned short& val){
+ return PyLong_FromUnsignedLong(static_cast<unsigned long>(val));
+ }
+
+ PyObject* from_cpp_to_python(float& val){
+ return PyFloat_FromDouble(static_cast<double>(val));
+ }
+
+ PyObject* from_cpp_to_python(double& val){
+ return PyFloat_FromDouble(val);
+ }
+
+ template<class U, class V>
+ PyObject* from_cpp_to_python(std::pair<U, V>& pair){
+
+ PyObject* first = from_cpp_to_python(pair.first);
+ PyObject* second = from_cpp_to_python(pair.second);
+
+ PyObject* localTuple = PyTuple_New(2);
+
+ if (!localTuple) {
+ Py_XDECREF(localTuple);
+ return PyErr_NoMemory();
+ }
+
+ PyTuple_SetItem(localTuple, 0, first);
+ PyTuple_SetItem(localTuple, 1, second);
+
+ return localTuple;
+ }
+
+ template<class K, class V>
+ static int from_python_to_cpp(PyObject* tuple, std::pair<K,V>* out) {
+
+ if (PyTuple_Check(tuple)) {
+
+ auto size = PyTuple_Size(tuple);
+
+ if (size != 2) {
+ return SWIG_ValueError;
+ }
+
+ PyObject* firstPy = PyTuple_GetItem(tuple, 0);
+ PyObject* secondPy = PyTuple_GetItem(tuple, 1);
+
+ if (!SWIG_IsOK(from_python_to_cpp(firstPy, &out->first))) {
+ return SWIG_TypeError;
+ }
+
+ if (!SWIG_IsOK(from_python_to_cpp(secondPy, &out->second))) {
+ return SWIG_TypeError;
+ }
+
+ } else {
+ return SWIG_TypeError;
+ }
+
+ return SWIG_OK;
+ }
+//---------------std::vector <-> python list ---------------------
+ template<class T>
+ static PyObject* from_vector_to_python(std::vector<T>* input) {
+ Py_ssize_t size = input->size();
+ PyObject* localList = PyList_New(size);
+
+ if (!localList) {
+ Py_XDECREF(localList);
+ return PyErr_NoMemory();
+ }
+
+ for(Py_ssize_t i = 0; i < size; ++i) {
+
+ PyObject* obj = from_cpp_to_python(input->at(i));
+
+ PyList_SET_ITEM(localList, i, obj);
+ }
+ return localList;
+ }
+
+ template<class T>
+ int from_python_to_vector(PyObject* seq, std::vector<T>& out) {
+ Py_ssize_t size = PySequence_Fast_GET_SIZE(seq);
+
+ for(Py_ssize_t i=0; i < size; i++) {
+ PyObject *item = PySequence_Fast_GET_ITEM(seq, i);
+ if(!item) {
+ PyErr_SetString(PyExc_TypeError, "Failed to read data from given sequence");
+
+ return SWIG_NullReferenceError;
+ }
+
+ T element;
+ int res = from_python_to_cpp(item, &element);
+ if (!SWIG_IsOK(res)) {
+ PyObject* itemRepr = PyObject_Repr(item);
+ PyObject* itemStrObj = PyUnicode_AsEncodedString(itemRepr, "utf-8", "replace");
+ const char* itemStr = PyBytes_AS_STRING(itemStrObj);
+
+ auto pythonType = Py_TYPE(item)->tp_name;
+
+ PyErr_Format(PyExc_TypeError, "Failed to convert python input value %s of type '%s' to C type '%s'", itemStr, pythonType, typeid(T).name());
+ Py_XDECREF(itemStrObj);
+ Py_XDECREF(itemRepr);
+ // 'seq' is a borrowed reference owned by the caller; do not Py_DECREF it here
+ return SWIG_TypeError;
+ }
+ out.push_back(element);
+ }
+ return SWIG_OK;
+ }
+
+%}
+
+%define %list_to_vector(TYPEMAP...)
+
+// this typemap works for struct argument set
+ %typemap(in) TYPEMAP* (TYPEMAP tmp) {
+ if (PySequence_Check($input)) {
+
+ if (from_python_to_vector($input, tmp) < 0) {
+ SWIG_fail;
+ }
+
+ $1 = &tmp;
+
+ } else {
+ PyErr_SetString(PyExc_TypeError, "Argument value object does not provide sequence protocol, implement __getitem__() method.");
+ SWIG_fail;
+ }
+ }
+
+// this typemap works for constructor
+ %typemap(in) TYPEMAP {
+ if (PySequence_Check($input)) {
+ if (from_python_to_vector($input, $1) < 0){
+ SWIG_fail;
+ }
+ } else {
+ PyErr_SetString(PyExc_TypeError, "Argument value object does not provide sequence protocol, implement __getitem__() method.");
+ SWIG_fail;
+ }
+ }
+
+// this typemap works for struct argument get
+
+ %typemap(out) TYPEMAP* {
+ $result = from_vector_to_python($1);
+ }
+
+// this typemap works for overloaded methods and ctors
+ %typemap(typecheck) (TYPEMAP) {
+ $1 = PySequence_Check($input) ? 1 : 0;
+ }
+
+%enddef
+
+%define %list_to_vector_clear(TYPEMAP...)
+ %typemap(in) (TYPEMAP);
+ %typemap(in) TYPEMAP* (TYPEMAP tmp);
+ %typemap(typecheck) (TYPEMAP);
+ %typemap(out) TYPEMAP*;
+%enddef
+
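These macros back the std::vector-valued descriptor fields wrapped in armnn_descriptors.i. Assuming SpaceToBatchNdDescriptor's m_BlockShape (a std::vector<unsigned int>) is wrapped with %list_to_vector, a sketch:

    desc = ann.SpaceToBatchNdDescriptor()
    desc.m_BlockShape = [2, 2]       # struct-set typemap converts the list
    print(desc.m_BlockShape)         # [2, 2], rebuilt as a Python list on the way out
    # desc.m_BlockShape = 42         # TypeError: no sequence protocol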