aboutsummaryrefslogtreecommitdiff
path: root/python/pyarmnn/test/test_const_tensor.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyarmnn/test/test_const_tensor.py')
-rw-r--r--python/pyarmnn/test/test_const_tensor.py54
1 files changed, 39 insertions, 15 deletions
diff --git a/python/pyarmnn/test/test_const_tensor.py b/python/pyarmnn/test/test_const_tensor.py
index fa6327f19c..2358d65918 100644
--- a/python/pyarmnn/test/test_const_tensor.py
+++ b/python/pyarmnn/test/test_const_tensor.py
@@ -6,8 +6,8 @@ import numpy as np
import pyarmnn as ann
-def _get_tensor_info(dt):
- tensor_info = ann.TensorInfo(ann.TensorShape((2, 3)), dt)
+def _get_const_tensor_info(dt):
+ tensor_info = ann.TensorInfo(ann.TensorShape((2, 3)), dt, 0.0, 0, True)
return tensor_info
@@ -23,7 +23,7 @@ def _get_tensor_info(dt):
(ann.DataType_QSymmS16, np.random.randint(1, size=(2, 4)).astype(np.int16))
], ids=['float32', 'float16', 'unsigned int8', 'signed int8', 'signed int8', 'int32', 'int16'])
def test_const_tensor_too_many_elements(dt, data):
- tensor_info = _get_tensor_info(dt)
+ tensor_info = _get_const_tensor_info(dt)
num_bytes = tensor_info.GetNumBytes()
with pytest.raises(ValueError) as err:
@@ -43,7 +43,7 @@ def test_const_tensor_too_many_elements(dt, data):
(ann.DataType_QSymmS16, np.random.randint(1, size=(2, 2)).astype(np.int16))
], ids=['float32', 'float16', 'unsigned int8', 'signed int8', 'signed int8', 'int32', 'int16'])
def test_const_tensor_too_little_elements(dt, data):
- tensor_info = _get_tensor_info(dt)
+ tensor_info = _get_const_tensor_info(dt)
num_bytes = tensor_info.GetNumBytes()
with pytest.raises(ValueError) as err:
@@ -63,7 +63,7 @@ def test_const_tensor_too_little_elements(dt, data):
(ann.DataType_QSymmS16, np.random.randint(1, size=(2, 2, 3, 3)).astype(np.int16))
], ids=['float32', 'float16', 'unsigned int8', 'signed int8', 'signed int8', 'int32', 'int16'])
def test_const_tensor_multi_dimensional_input(dt, data):
- tensor = ann.ConstTensor(ann.TensorInfo(ann.TensorShape((2, 2, 3, 3)), dt), data)
+ tensor = ann.ConstTensor(ann.TensorInfo(ann.TensorShape((2, 2, 3, 3)), dt, 0.0, 0, True), data)
assert data.size == tensor.GetNumElements()
assert data.nbytes == tensor.GetNumBytes()
@@ -72,7 +72,7 @@ def test_const_tensor_multi_dimensional_input(dt, data):
def test_create_const_tensor_from_tensor():
- tensor_info = ann.TensorInfo(ann.TensorShape((2, 3)), ann.DataType_Float32)
+ tensor_info = ann.TensorInfo(ann.TensorShape((2, 3)), ann.DataType_Float32, 0.0, 0, True)
tensor = ann.Tensor(tensor_info)
copied_tensor = ann.ConstTensor(tensor)
@@ -85,7 +85,7 @@ def test_create_const_tensor_from_tensor():
def test_const_tensor_from_tensor_has_memory_area_access_after_deletion_of_original_tensor():
- tensor_info = ann.TensorInfo(ann.TensorShape((2, 3)), ann.DataType_Float32)
+ tensor_info = ann.TensorInfo(ann.TensorShape((2, 3)), ann.DataType_Float32, 0.0, 0, True)
tensor = ann.Tensor(tensor_info)
tensor.get_memory_area()[0] = 100
@@ -125,7 +125,7 @@ def test_create_const_tensor_incorrect_args():
(-1, np.random.randint(1, size=(2, 3)).astype(np.float32)),
], ids=['unknown'])
def test_const_tensor_unsupported_datatype(dt, data):
- tensor_info = _get_tensor_info(dt)
+ tensor_info = _get_const_tensor_info(dt)
with pytest.raises(ValueError) as err:
ann.ConstTensor(tensor_info, data)
@@ -142,7 +142,7 @@ def test_const_tensor_unsupported_datatype(dt, data):
(ann.DataType_QSymmS8, [[1, 1, 1], [1, 1, 1]])
], ids=['float32', 'float16', 'unsigned int8', 'signed int8', 'signed int8'])
def test_const_tensor_incorrect_input_datatype(dt, data):
- tensor_info = _get_tensor_info(dt)
+ tensor_info = _get_const_tensor_info(dt)
with pytest.raises(TypeError) as err:
ann.ConstTensor(tensor_info, data)
@@ -163,7 +163,7 @@ def test_const_tensor_incorrect_input_datatype(dt, data):
class TestNumpyDataTypes:
def test_copy_const_tensor(self, dt, data):
- tensor_info = _get_tensor_info(dt)
+ tensor_info = _get_const_tensor_info(dt)
tensor = ann.ConstTensor(tensor_info, data)
copied_tensor = ann.ConstTensor(tensor)
@@ -175,7 +175,7 @@ class TestNumpyDataTypes:
assert copied_tensor.GetDataType() == tensor.GetDataType()
def test_const_tensor__str__(self, dt, data):
- tensor_info = _get_tensor_info(dt)
+ tensor_info = _get_const_tensor_info(dt)
d_type = tensor_info.GetDataType()
num_dimensions = tensor_info.GetNumDimensions()
num_bytes = tensor_info.GetNumBytes()
@@ -186,7 +186,7 @@ class TestNumpyDataTypes:
"{}, NumElements: {}}}".format(d_type, num_bytes, num_dimensions, num_elements)
def test_const_tensor_with_info(self, dt, data):
- tensor_info = _get_tensor_info(dt)
+ tensor_info = _get_const_tensor_info(dt)
elements = tensor_info.GetNumElements()
num_bytes = tensor_info.GetNumBytes()
d_type = dt
@@ -199,7 +199,7 @@ class TestNumpyDataTypes:
assert d_type == tensor.GetDataType()
def test_immutable_memory(self, dt, data):
- tensor_info = _get_tensor_info(dt)
+ tensor_info = _get_const_tensor_info(dt)
tensor = ann.ConstTensor(tensor_info, data)
@@ -217,7 +217,7 @@ class TestNumpyDataTypes:
ann.DataType_Signed32: np.int32,
ann.DataType_Float16: np.float16}
- tensor_info = _get_tensor_info(dt)
+ tensor_info = _get_const_tensor_info(dt)
tensor = ann.ConstTensor(tensor_info, data)
assert np_data_type_mapping[tensor.GetDataType()] == data.dtype
@@ -242,10 +242,34 @@ def test_numpy_dtype_mismatch_ann_dtype(dt, data):
ann.DataType_Signed32: np.int32,
ann.DataType_Float16: np.float16}
- tensor_info = _get_tensor_info(dt)
+ tensor_info = _get_const_tensor_info(dt)
with pytest.raises(TypeError) as err:
ann.ConstTensor(tensor_info, data)
assert str(err.value) == "Expected data to have type {} for type {} but instead got numpy.{}".format(
np_data_type_mapping[dt], dt, data.dtype)
+
+@pytest.mark.parametrize("dt, data",
+ [
+ (ann.DataType_Float32, np.random.randint(1, size=(2, 3)).astype(np.float32)),
+ (ann.DataType_Float16, np.random.randint(1, size=(2, 3)).astype(np.float16)),
+ (ann.DataType_QAsymmU8, np.random.randint(1, size=(2, 3)).astype(np.uint8)),
+ (ann.DataType_QAsymmS8, np.random.randint(1, size=(2, 3)).astype(np.int8)),
+ (ann.DataType_QSymmS8, np.random.randint(1, size=(2, 3)).astype(np.int8)),
+ (ann.DataType_Signed32, np.random.randint(1, size=(2, 3)).astype(np.int32)),
+ (ann.DataType_QSymmS16, np.random.randint(1, size=(2, 3)).astype(np.int16))
+ ], ids=['float32', 'float16', 'unsigned int8', 'signed int8', 'signed int8', 'int32', 'int16'])
+class TestConstTensorConstructorErrors:
+
+ def test_tensorinfo_isconstant_not_set(self, dt, data):
+ with pytest.raises(ValueError) as err:
+ ann.ConstTensor(ann.TensorInfo(ann.TensorShape((2, 2, 3, 3)), dt, 0.0, 0, False), data)
+
+ assert str(err.value) == "TensorInfo when initializing ConstTensor must be set to constant."
+
+ def test_tensor_tensorinfo_isconstant_not_set(self, dt, data):
+ with pytest.raises(ValueError) as err:
+ ann.ConstTensor(ann.Tensor(ann.TensorInfo(ann.TensorShape((2, 2, 3, 3)), dt, 0.0, 0, False), data))
+
+ assert str(err.value) == "TensorInfo of Tensor when initializing ConstTensor must be set to constant." \ No newline at end of file