aboutsummaryrefslogtreecommitdiff
path: root/python/pyarmnn/src/pyarmnn/_tensor/const_tensor.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyarmnn/src/pyarmnn/_tensor/const_tensor.py')
-rw-r--r--python/pyarmnn/src/pyarmnn/_tensor/const_tensor.py159
1 files changed, 159 insertions, 0 deletions
diff --git a/python/pyarmnn/src/pyarmnn/_tensor/const_tensor.py b/python/pyarmnn/src/pyarmnn/_tensor/const_tensor.py
new file mode 100644
index 0000000000..9735d7a63b
--- /dev/null
+++ b/python/pyarmnn/src/pyarmnn/_tensor/const_tensor.py
@@ -0,0 +1,159 @@
+# Copyright © 2019 Arm Ltd. All rights reserved.
+# SPDX-License-Identifier: MIT
+import numpy as np
+
+from .._generated.pyarmnn import DataType_QuantisedAsymm8, DataType_QuantisedSymm16, DataType_Signed32, \
+ DataType_Float32, DataType_Float16
+from .._generated.pyarmnn import ConstTensor as AnnConstTensor, TensorInfo, Tensor
+
+
+class ConstTensor(AnnConstTensor):
+ """Creates a PyArmNN ConstTensor object.
+
+ A ConstTensor is a Tensor with an immutable data store. Typically, a ConstTensor
+ is used to input data into a network when running inference.
+
+ This class overrides the swig generated Tensor class. The aim of
+ this is to have an easy to use public API for the ConstTensor objects.
+
+ """
+
+ def __init__(self, *args):
+ """
+ Supported tensor data types:
+ DataType_QuantisedAsymm8,
+ DataType_QuantisedSymm16,
+ DataType_Signed32,
+ DataType_Float32,
+ DataType_Float16
+
+ Examples:
+ Create empty ConstTensor
+ >>> import pyarmnn as ann
+ >>> ann.ConstTensor()
+
+ Create ConstTensor given tensor info and input data
+ >>> input_data = ... # numpy array
+ >>> ann.ConstTensor(ann.TensorInfo(...), input_data)
+
+ Create ConstTensor from another ConstTensor i.e. copy ConstTensor
+ >>> ann.ConstTensor(ann.ConstTensor())
+
+ Create ConstTensor from tensor
+ >>> ann.ConstTensor(ann.Tensor())
+
+ Args:
+ tensor (Tensor, optional): Create a ConstTensor from a Tensor.
+ const_tensor (ConstTensor, optional): Create a ConstTensor from a ConstTensor i.e. copy.
+ tensor_info (TensorInfo, optional): Tensor information.
+ input_data (ndarray): Numpy array. The numpy array will be transformed to a
+ buffer according to type returned by `TensorInfo.GetDataType`.
+ Input data values type must correspond to data type returned by
+ `TensorInfo.GetDataType`.
+
+ Raises:
+ TypeError: Unsupported input data type.
+ ValueError: Unsupported tensor data type and incorrect input data size.
+ """
+ self.__memory_area = None
+
+ # TensorInfo as first argument and numpy array as second
+ if len(args) > 1 and isinstance(args[0], TensorInfo):
+ if isinstance(args[1], np.ndarray):
+ self.__create_memory_area(args[0].GetDataType(), args[0].GetNumBytes(), args[0].GetNumElements(),
+ args[1])
+ super().__init__(args[0], self.__memory_area.data)
+ else:
+ raise TypeError('Data must be provided as a numpy array.')
+
+ # copy constructor - reference to memory area is passed from copied const
+ # tensor and armnn's copy constructor is called
+ elif len(args) > 0 and isinstance(args[0], (ConstTensor, Tensor)):
+ self.__memory_area = args[0].get_memory_area()
+ super().__init__(args[0])
+
+ # empty tensor
+ elif len(args) == 0:
+ super().__init__()
+
+ else:
+ raise ValueError('Incorrect number of arguments or type of arguments provided to create Const Tensor.')
+
+ def __copy__(self) -> 'ConstTensor':
+ """ Make copy of a const tensor.
+
+ Make const tensor copyable using the python copy operation.
+
+ Note:
+ The tensor memory area is NOT copied. Instead, the new tensor maintains a
+ reference to the same memory area as the old tensor.
+
+ Example:
+ Copy empty tensor
+ >>> from copy import copy
+ >>> import pyarmnn as ann
+ >>> tensor = ann.ConstTensor()
+ >>> copied_tensor = copy(tensor)
+
+ Returns:
+ Tensor: a copy of the tensor object provided.
+
+ """
+ return ConstTensor(self)
+
+ @staticmethod
+ def __check_size(data: np.ndarray, num_bytes: int, num_elements: int):
+ """ Check the size of the input data against the number of bytes provided by tensor info.
+
+ Args:
+ data (ndarray): Input data.
+ num_bytes (int): Number of bytes required by tensor info.
+ num_elements: Number of elements required by tensor info.
+
+ Raises:
+ ValueError: number of bytes in input data does not match tensor info.
+
+ """
+ size_in_bytes = data.nbytes
+ elements = data.size
+
+ if size_in_bytes != num_bytes:
+ raise ValueError(
+ "ConstTensor requires {} bytes, {} provided. "
+ "Is your input array data type ({}) aligned with TensorInfo?".format(num_bytes, size_in_bytes,
+ data.dtype))
+ elif elements != num_elements:
+ raise ValueError("ConstTensor requires {} elements, {} provided.".format(num_elements, elements))
+
+ def __create_memory_area(self, data_type: int, num_bytes: int, num_elements: int, data: np.ndarray):
+ """ Create the memory area used by the tensor to output its results.
+
+ Args:
+ data_type (int): The type of data that will be stored in the memory area.
+ See DataType_*.
+ num_bytes (int): Determines the size of the memory area that will be created.
+ num_elements (int): Determines number of elements in memory area.
+ data (ndarray): Input data as numpy array.
+
+ """
+ np_data_type_mapping = {DataType_QuantisedAsymm8: np.uint8,
+ DataType_Float32: np.float32,
+ DataType_QuantisedSymm16: np.int16,
+ DataType_Signed32: np.int32,
+ DataType_Float16: np.float16}
+
+ if data_type not in np_data_type_mapping:
+ raise ValueError("The data type provided for this Tensor is not supported: {}".format(data_type))
+
+ self.__check_size(data, num_bytes, num_elements)
+ self.__memory_area = data
+ self.__memory_area.flags.writeable = False
+
+ def get_memory_area(self) -> np.ndarray:
+ """ Get values that are stored by the tensor.
+
+ Returns:
+ ndarray: Tensor data (as numpy array).
+
+ """
+ return self.__memory_area