aboutsummaryrefslogtreecommitdiff
path: root/python/pyarmnn/src/pyarmnn/_tensor
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyarmnn/src/pyarmnn/_tensor')
-rw-r--r--python/pyarmnn/src/pyarmnn/_tensor/__init__.py6
-rw-r--r--python/pyarmnn/src/pyarmnn/_tensor/const_tensor.py159
-rw-r--r--python/pyarmnn/src/pyarmnn/_tensor/tensor.py119
-rw-r--r--python/pyarmnn/src/pyarmnn/_tensor/workload_tensors.py123
4 files changed, 407 insertions, 0 deletions
diff --git a/python/pyarmnn/src/pyarmnn/_tensor/__init__.py b/python/pyarmnn/src/pyarmnn/_tensor/__init__.py
new file mode 100644
index 0000000000..0c928785b4
--- /dev/null
+++ b/python/pyarmnn/src/pyarmnn/_tensor/__init__.py
@@ -0,0 +1,6 @@
+# Copyright © 2019 Arm Ltd. All rights reserved.
+# SPDX-License-Identifier: MIT
+
+from .const_tensor import ConstTensor
+from .tensor import Tensor
+from .workload_tensors import make_input_tensors, make_output_tensors, workload_tensors_to_ndarray
diff --git a/python/pyarmnn/src/pyarmnn/_tensor/const_tensor.py b/python/pyarmnn/src/pyarmnn/_tensor/const_tensor.py
new file mode 100644
index 0000000000..9735d7a63b
--- /dev/null
+++ b/python/pyarmnn/src/pyarmnn/_tensor/const_tensor.py
@@ -0,0 +1,159 @@
+# Copyright © 2019 Arm Ltd. All rights reserved.
+# SPDX-License-Identifier: MIT
+import numpy as np
+
+from .._generated.pyarmnn import DataType_QuantisedAsymm8, DataType_QuantisedSymm16, DataType_Signed32, \
+ DataType_Float32, DataType_Float16
+from .._generated.pyarmnn import ConstTensor as AnnConstTensor, TensorInfo, Tensor
+
+
+class ConstTensor(AnnConstTensor):
+ """Creates a PyArmNN ConstTensor object.
+
+ A ConstTensor is a Tensor with an immutable data store. Typically, a ConstTensor
+ is used to input data into a network when running inference.
+
+ This class overrides the swig generated Tensor class. The aim of
+ this is to have an easy to use public API for the ConstTensor objects.
+
+ """
+
+ def __init__(self, *args):
+ """
+ Supported tensor data types:
+ DataType_QuantisedAsymm8,
+ DataType_QuantisedSymm16,
+ DataType_Signed32,
+ DataType_Float32,
+ DataType_Float16
+
+ Examples:
+ Create empty ConstTensor
+ >>> import pyarmnn as ann
+ >>> ann.ConstTensor()
+
+ Create ConstTensor given tensor info and input data
+ >>> input_data = ... # numpy array
+ >>> ann.ConstTensor(ann.TensorInfo(...), input_data)
+
+ Create ConstTensor from another ConstTensor i.e. copy ConstTensor
+ >>> ann.ConstTensor(ann.ConstTensor())
+
+ Create ConstTensor from tensor
+ >>> ann.ConstTensor(ann.Tensor())
+
+ Args:
+ tensor (Tensor, optional): Create a ConstTensor from a Tensor.
+ const_tensor (ConstTensor, optional): Create a ConstTensor from a ConstTensor i.e. copy.
+ tensor_info (TensorInfo, optional): Tensor information.
+ input_data (ndarray): Numpy array. The numpy array will be transformed to a
+ buffer according to type returned by `TensorInfo.GetDataType`.
+ Input data values type must correspond to data type returned by
+ `TensorInfo.GetDataType`.
+
+ Raises:
+ TypeError: Unsupported input data type.
+ ValueError: Unsupported tensor data type and incorrect input data size.
+ """
+ self.__memory_area = None
+
+ # TensorInfo as first argument and numpy array as second
+ if len(args) > 1 and isinstance(args[0], TensorInfo):
+ if isinstance(args[1], np.ndarray):
+ self.__create_memory_area(args[0].GetDataType(), args[0].GetNumBytes(), args[0].GetNumElements(),
+ args[1])
+ super().__init__(args[0], self.__memory_area.data)
+ else:
+ raise TypeError('Data must be provided as a numpy array.')
+
+ # copy constructor - reference to memory area is passed from copied const
+ # tensor and armnn's copy constructor is called
+ elif len(args) > 0 and isinstance(args[0], (ConstTensor, Tensor)):
+ self.__memory_area = args[0].get_memory_area()
+ super().__init__(args[0])
+
+ # empty tensor
+ elif len(args) == 0:
+ super().__init__()
+
+ else:
+ raise ValueError('Incorrect number of arguments or type of arguments provided to create Const Tensor.')
+
+ def __copy__(self) -> 'ConstTensor':
+ """ Make copy of a const tensor.
+
+ Make const tensor copyable using the python copy operation.
+
+ Note:
+ The tensor memory area is NOT copied. Instead, the new tensor maintains a
+ reference to the same memory area as the old tensor.
+
+ Example:
+ Copy empty tensor
+ >>> from copy import copy
+ >>> import pyarmnn as ann
+ >>> tensor = ann.ConstTensor()
+ >>> copied_tensor = copy(tensor)
+
+ Returns:
+ Tensor: a copy of the tensor object provided.
+
+ """
+ return ConstTensor(self)
+
+ @staticmethod
+ def __check_size(data: np.ndarray, num_bytes: int, num_elements: int):
+ """ Check the size of the input data against the number of bytes provided by tensor info.
+
+ Args:
+ data (ndarray): Input data.
+ num_bytes (int): Number of bytes required by tensor info.
+ num_elements: Number of elements required by tensor info.
+
+ Raises:
+ ValueError: number of bytes in input data does not match tensor info.
+
+ """
+ size_in_bytes = data.nbytes
+ elements = data.size
+
+ if size_in_bytes != num_bytes:
+ raise ValueError(
+ "ConstTensor requires {} bytes, {} provided. "
+ "Is your input array data type ({}) aligned with TensorInfo?".format(num_bytes, size_in_bytes,
+ data.dtype))
+ elif elements != num_elements:
+ raise ValueError("ConstTensor requires {} elements, {} provided.".format(num_elements, elements))
+
+ def __create_memory_area(self, data_type: int, num_bytes: int, num_elements: int, data: np.ndarray):
+ """ Create the memory area used by the tensor to output its results.
+
+ Args:
+ data_type (int): The type of data that will be stored in the memory area.
+ See DataType_*.
+ num_bytes (int): Determines the size of the memory area that will be created.
+ num_elements (int): Determines number of elements in memory area.
+ data (ndarray): Input data as numpy array.
+
+ """
+ np_data_type_mapping = {DataType_QuantisedAsymm8: np.uint8,
+ DataType_Float32: np.float32,
+ DataType_QuantisedSymm16: np.int16,
+ DataType_Signed32: np.int32,
+ DataType_Float16: np.float16}
+
+ if data_type not in np_data_type_mapping:
+ raise ValueError("The data type provided for this Tensor is not supported: {}".format(data_type))
+
+ self.__check_size(data, num_bytes, num_elements)
+ self.__memory_area = data
+ self.__memory_area.flags.writeable = False
+
+ def get_memory_area(self) -> np.ndarray:
+ """ Get values that are stored by the tensor.
+
+ Returns:
+ ndarray: Tensor data (as numpy array).
+
+ """
+ return self.__memory_area
diff --git a/python/pyarmnn/src/pyarmnn/_tensor/tensor.py b/python/pyarmnn/src/pyarmnn/_tensor/tensor.py
new file mode 100644
index 0000000000..5906b6bae6
--- /dev/null
+++ b/python/pyarmnn/src/pyarmnn/_tensor/tensor.py
@@ -0,0 +1,119 @@
+# Copyright © 2019 Arm Ltd. All rights reserved.
+# SPDX-License-Identifier: MIT
+import numpy as np
+
+from .._generated.pyarmnn import Tensor as annTensor, TensorInfo, DataType_QuantisedAsymm8, \
+ DataType_Float32, DataType_QuantisedSymm16, DataType_Signed32, DataType_Float16
+
+
+class Tensor(annTensor):
+ """pyArmnn Tensor object
+
+ This class overrides the swig generated Tensor class. The aim of
+ this is to create an easy to use public api for the Tensor object.
+
+ Memory is allocated and managed by this class, avoiding the need to manage
+ a separate memory area for the tensor compared to the swig generated api.
+
+ """
+
+ def __init__(self, *args):
+ """ Create Tensor object.
+
+ Supported tensor data types:
+ DataType_QuantisedAsymm8,
+ DataType_QuantisedSymm16,
+ DataType_Signed32,
+ DataType_Float32,
+ DataType_Float16
+
+ Examples:
+ Create an empty tensor
+ >>> import pyarmnn as ann
+ >>> ann.Tensor()
+
+ Create tensor given tensor information
+ >>> ann.Tensor(ann.TensorInfo(...))
+
+ Create tensor from another tensor i.e. copy a tensor
+ >>> ann.Tensor(ann.Tensor())
+
+ Args:
+ tensor(Tensor, optional): Create Tensor from a Tensor i.e. copy.
+ tensor_info (TensorInfo, optional): Tensor information.
+
+ Raises:
+ TypeError: unsupported input data type.
+ ValueError: appropriate constructor could not be found with provided arguments.
+
+ """
+ self.__memory_area = None
+
+ # TensorInfo as first argument, we need to create memory area manually
+ if len(args) > 0 and isinstance(args[0], TensorInfo):
+ self.__create_memory_area(args[0].GetDataType(), args[0].GetNumElements())
+ super().__init__(args[0], self.__memory_area.data)
+
+ # copy constructor - reference to memory area is passed from copied tensor
+ # and armnn's copy constructor is called
+ elif len(args) > 0 and isinstance(args[0], Tensor):
+ self.__memory_area = args[0].get_memory_area()
+ super().__init__(args[0])
+
+ # empty constructor
+ elif len(args) == 0:
+ super().__init__()
+
+ else:
+ raise ValueError('Incorrect number of arguments or type of arguments provided to create Tensor.')
+
+ def __copy__(self) -> 'Tensor':
+ """ Make copy of a tensor.
+
+ Make tensor copyable using the python copy operation.
+
+ Note:
+ The tensor memory area is NOT copied. Instead, the new tensor maintains a
+ reference to the same memory area as the old tensor.
+
+ Example:
+ Copy empty tensor
+ >>> from copy import copy
+ >>> import pyarmnn as ann
+ >>> tensor = ann.Tensor()
+ >>> copied_tensor = copy(tensor)
+
+ Returns:
+ Tensor: a copy of the tensor object provided.
+
+ """
+ return Tensor(self)
+
+ def __create_memory_area(self, data_type: int, num_elements: int):
+ """ Create the memory area used by the tensor to output its results.
+
+ Args:
+ data_type (int): The type of data that will be stored in the memory area.
+ See DataType_*.
+ num_elements (int): Determines the size of the memory area that will be created.
+
+ """
+ np_data_type_mapping = {DataType_QuantisedAsymm8: np.uint8,
+ DataType_Float32: np.float32,
+ DataType_QuantisedSymm16: np.int16,
+ DataType_Signed32: np.int32,
+ DataType_Float16: np.float16}
+
+ if data_type not in np_data_type_mapping:
+ raise ValueError("The data type provided for this Tensor is not supported.")
+
+ self.__memory_area = np.empty(shape=(num_elements,), dtype=np_data_type_mapping[data_type])
+
+ def get_memory_area(self) -> np.ndarray:
+ """ Get values that are stored by the tensor.
+
+ Returns:
+ ndarray : Tensor data (as numpy array).
+
+ """
+ return self.__memory_area
diff --git a/python/pyarmnn/src/pyarmnn/_tensor/workload_tensors.py b/python/pyarmnn/src/pyarmnn/_tensor/workload_tensors.py
new file mode 100644
index 0000000000..e345a1a5d4
--- /dev/null
+++ b/python/pyarmnn/src/pyarmnn/_tensor/workload_tensors.py
@@ -0,0 +1,123 @@
+# Copyright © 2019 Arm Ltd. All rights reserved.
+# SPDX-License-Identifier: MIT
+"""
+This file contains functions relating to WorkloadTensors.
+WorkloadTensors are the inputTensors and outputTensors that are consumed by IRuntime.EnqueueWorkload.
+"""
+from typing import Union, List, Tuple
+
+import numpy as np
+
+from .tensor import Tensor
+from .const_tensor import ConstTensor
+
+
+def make_input_tensors(inputs_binding_info: List[Tuple],
+ input_data: List[np.ndarray]) -> List[Tuple[int, ConstTensor]]:
+ """Returns `inputTensors` to be used with `IRuntime.EnqueueWorkload`.
+
+ This is the primary function to call when you want to produce `inputTensors` for `IRuntime.EnqueueWorkload`.
+ The output is a list of tuples containing ConstTensors with a corresponding input tensor id.
+ The output should be used directly with `IRuntime.EnqueueWorkload`.
+ This function works for single or multiple input data and binding information.
+
+ Examples:
+ Creating inputTensors.
+ >>> import pyarmnn as ann
+ >>> import numpy as np
+ >>>
+ >>> parser = ann.ITfLiteParser()
+ >>> ...
+ >>> example_image = np.array(...)
+ >>> input_binding_info = parser.GetNetworkInputBindingInfo(...)
+ >>>
+ >>> input_tensors = ann.make_input_tensors([input_binding_info], [example_image])
+
+ Args:
+ inputs_binding_info (list of tuples): (int, `TensorInfo`) Binding information for input tensors obtained from `GetNetworkInputBindingInfo`.
+ input_data (ndarray): Tensor data to be used for inference.
+
+ Returns:
+ list: `inputTensors` - A list of tuples (`int` , `ConstTensor`).
+
+
+ Raises:
+ ValueError: If length of `inputs_binding_info` and `input_data` are not the same.
+ """
+ if len(inputs_binding_info) != len(input_data):
+ raise ValueError("Length of 'inputs_binding_info' does not match length of 'input_data'")
+
+ input_tensors = []
+
+ for in_bind_info, in_data in zip(inputs_binding_info, input_data):
+ in_tensor_id = in_bind_info[0]
+ in_tensor_info = in_bind_info[1]
+ input_tensors.append((in_tensor_id, ConstTensor(in_tensor_info, in_data)))
+
+ return input_tensors
+
+
+def make_output_tensors(outputs_binding_info: List[Tuple]) -> List[Tuple[int, Tensor]]:
+ """Returns `outputTensors` to be used with `IRuntime.EnqueueWorkload`.
+
+ This is the primary function to call when you want to produce `outputTensors` for `IRuntime.EnqueueWorkload`.
+ The output is a list of tuples containing Tensors with a corresponding output tensor id.
+ The output should be used directly with `IRuntime.EnqueueWorkload`.
+
+ Examples:
+ Creating outputTensors.
+ >>> import pyarmnn as ann
+ >>>
+ >>> parser = ann.ITfLiteParser()
+ >>> ...
+ >>> output_binding_info = parser.GetNetworkOutputBindingInfo(...)
+ >>>
+ >>> output_tensors = ann.make_output_tensors([output_binding_info])
+
+ Args:
+ outputs_binding_info (list of tuples): (int, `TensorInfo`) Binding information for output tensors obtained from `GetNetworkOutputBindingInfo`.
+
+ Returns:
+ list: `outputTensors` - A list of tuples (`int`, `Tensor`).
+ """
+ output_tensors = []
+
+ for out_bind_info in outputs_binding_info:
+ out_tensor_id = out_bind_info[0]
+ out_tensor_info = out_bind_info[1]
+ output_tensors.append((out_tensor_id, Tensor(out_tensor_info)))
+
+ return output_tensors
+
+
+def workload_tensors_to_ndarray(workload_tensors: List[Tuple[int, Union[Tensor, ConstTensor]]]) -> List[np.ndarray]:
+ """Returns a list of the underlying tensor data as ndarrays from `inputTensors` or `outputTensors`.
+
+ We refer to `inputTensors` and `outputTensors` as workload tensors because
+ they are used with `IRuntime.EnqueueWorkload`.
+ Although this function can be used on either `inputTensors` or `outputTensors` the main use of this function
+ is to collect results from `outputTensors` after `IRuntime.EnqueueWorkload` has been called.
+
+ Examples:
+ Getting results after inference.
+ >>> import pyarmnn as ann
+ >>>
+ >>> ...
+ >>> runtime = ann.IRuntime(...)
+ >>> ...
+ >>> runtime.EnqueueWorkload(net_id, input_tensors, output_tensors)
+ >>>
+ >>> inference_results = tensors_to_ndarray(output_tensors)
+
+ Args:
+ workload_tensors (inputTensors or outputTensors): `inputTensors` or `outputTensors` to get data from.
+
+ Returns:
+ list: List of `ndarrays` for the underlying tensor data from given `inputTensors` or `outputTensors`.
+ """
+ arrays = []
+ for index, (_, tensor) in enumerate(workload_tensors):
+ arrays.append(tensor.get_memory_area())
+ print("Workload tensor {} shape: {}".format(index, tensor.GetShape()))
+
+ return arrays