diff options
Diffstat (limited to 'python/pyarmnn/test/test_const_tensor.py')
-rw-r--r-- | python/pyarmnn/test/test_const_tensor.py | 54 |
1 files changed, 39 insertions, 15 deletions
diff --git a/python/pyarmnn/test/test_const_tensor.py b/python/pyarmnn/test/test_const_tensor.py index fa6327f19c..2358d65918 100644 --- a/python/pyarmnn/test/test_const_tensor.py +++ b/python/pyarmnn/test/test_const_tensor.py @@ -6,8 +6,8 @@ import numpy as np import pyarmnn as ann -def _get_tensor_info(dt): - tensor_info = ann.TensorInfo(ann.TensorShape((2, 3)), dt) +def _get_const_tensor_info(dt): + tensor_info = ann.TensorInfo(ann.TensorShape((2, 3)), dt, 0.0, 0, True) return tensor_info @@ -23,7 +23,7 @@ def _get_tensor_info(dt): (ann.DataType_QSymmS16, np.random.randint(1, size=(2, 4)).astype(np.int16)) ], ids=['float32', 'float16', 'unsigned int8', 'signed int8', 'signed int8', 'int32', 'int16']) def test_const_tensor_too_many_elements(dt, data): - tensor_info = _get_tensor_info(dt) + tensor_info = _get_const_tensor_info(dt) num_bytes = tensor_info.GetNumBytes() with pytest.raises(ValueError) as err: @@ -43,7 +43,7 @@ def test_const_tensor_too_many_elements(dt, data): (ann.DataType_QSymmS16, np.random.randint(1, size=(2, 2)).astype(np.int16)) ], ids=['float32', 'float16', 'unsigned int8', 'signed int8', 'signed int8', 'int32', 'int16']) def test_const_tensor_too_little_elements(dt, data): - tensor_info = _get_tensor_info(dt) + tensor_info = _get_const_tensor_info(dt) num_bytes = tensor_info.GetNumBytes() with pytest.raises(ValueError) as err: @@ -63,7 +63,7 @@ def test_const_tensor_too_little_elements(dt, data): (ann.DataType_QSymmS16, np.random.randint(1, size=(2, 2, 3, 3)).astype(np.int16)) ], ids=['float32', 'float16', 'unsigned int8', 'signed int8', 'signed int8', 'int32', 'int16']) def test_const_tensor_multi_dimensional_input(dt, data): - tensor = ann.ConstTensor(ann.TensorInfo(ann.TensorShape((2, 2, 3, 3)), dt), data) + tensor = ann.ConstTensor(ann.TensorInfo(ann.TensorShape((2, 2, 3, 3)), dt, 0.0, 0, True), data) assert data.size == tensor.GetNumElements() assert data.nbytes == tensor.GetNumBytes() @@ -72,7 +72,7 @@ def test_const_tensor_multi_dimensional_input(dt, data): def test_create_const_tensor_from_tensor(): - tensor_info = ann.TensorInfo(ann.TensorShape((2, 3)), ann.DataType_Float32) + tensor_info = ann.TensorInfo(ann.TensorShape((2, 3)), ann.DataType_Float32, 0.0, 0, True) tensor = ann.Tensor(tensor_info) copied_tensor = ann.ConstTensor(tensor) @@ -85,7 +85,7 @@ def test_create_const_tensor_from_tensor(): def test_const_tensor_from_tensor_has_memory_area_access_after_deletion_of_original_tensor(): - tensor_info = ann.TensorInfo(ann.TensorShape((2, 3)), ann.DataType_Float32) + tensor_info = ann.TensorInfo(ann.TensorShape((2, 3)), ann.DataType_Float32, 0.0, 0, True) tensor = ann.Tensor(tensor_info) tensor.get_memory_area()[0] = 100 @@ -125,7 +125,7 @@ def test_create_const_tensor_incorrect_args(): (-1, np.random.randint(1, size=(2, 3)).astype(np.float32)), ], ids=['unknown']) def test_const_tensor_unsupported_datatype(dt, data): - tensor_info = _get_tensor_info(dt) + tensor_info = _get_const_tensor_info(dt) with pytest.raises(ValueError) as err: ann.ConstTensor(tensor_info, data) @@ -142,7 +142,7 @@ def test_const_tensor_unsupported_datatype(dt, data): (ann.DataType_QSymmS8, [[1, 1, 1], [1, 1, 1]]) ], ids=['float32', 'float16', 'unsigned int8', 'signed int8', 'signed int8']) def test_const_tensor_incorrect_input_datatype(dt, data): - tensor_info = _get_tensor_info(dt) + tensor_info = _get_const_tensor_info(dt) with pytest.raises(TypeError) as err: ann.ConstTensor(tensor_info, data) @@ -163,7 +163,7 @@ def test_const_tensor_incorrect_input_datatype(dt, data): class TestNumpyDataTypes: def test_copy_const_tensor(self, dt, data): - tensor_info = _get_tensor_info(dt) + tensor_info = _get_const_tensor_info(dt) tensor = ann.ConstTensor(tensor_info, data) copied_tensor = ann.ConstTensor(tensor) @@ -175,7 +175,7 @@ class TestNumpyDataTypes: assert copied_tensor.GetDataType() == tensor.GetDataType() def test_const_tensor__str__(self, dt, data): - tensor_info = _get_tensor_info(dt) + tensor_info = _get_const_tensor_info(dt) d_type = tensor_info.GetDataType() num_dimensions = tensor_info.GetNumDimensions() num_bytes = tensor_info.GetNumBytes() @@ -186,7 +186,7 @@ class TestNumpyDataTypes: "{}, NumElements: {}}}".format(d_type, num_bytes, num_dimensions, num_elements) def test_const_tensor_with_info(self, dt, data): - tensor_info = _get_tensor_info(dt) + tensor_info = _get_const_tensor_info(dt) elements = tensor_info.GetNumElements() num_bytes = tensor_info.GetNumBytes() d_type = dt @@ -199,7 +199,7 @@ class TestNumpyDataTypes: assert d_type == tensor.GetDataType() def test_immutable_memory(self, dt, data): - tensor_info = _get_tensor_info(dt) + tensor_info = _get_const_tensor_info(dt) tensor = ann.ConstTensor(tensor_info, data) @@ -217,7 +217,7 @@ class TestNumpyDataTypes: ann.DataType_Signed32: np.int32, ann.DataType_Float16: np.float16} - tensor_info = _get_tensor_info(dt) + tensor_info = _get_const_tensor_info(dt) tensor = ann.ConstTensor(tensor_info, data) assert np_data_type_mapping[tensor.GetDataType()] == data.dtype @@ -242,10 +242,34 @@ def test_numpy_dtype_mismatch_ann_dtype(dt, data): ann.DataType_Signed32: np.int32, ann.DataType_Float16: np.float16} - tensor_info = _get_tensor_info(dt) + tensor_info = _get_const_tensor_info(dt) with pytest.raises(TypeError) as err: ann.ConstTensor(tensor_info, data) assert str(err.value) == "Expected data to have type {} for type {} but instead got numpy.{}".format( np_data_type_mapping[dt], dt, data.dtype) + +@pytest.mark.parametrize("dt, data", + [ + (ann.DataType_Float32, np.random.randint(1, size=(2, 3)).astype(np.float32)), + (ann.DataType_Float16, np.random.randint(1, size=(2, 3)).astype(np.float16)), + (ann.DataType_QAsymmU8, np.random.randint(1, size=(2, 3)).astype(np.uint8)), + (ann.DataType_QAsymmS8, np.random.randint(1, size=(2, 3)).astype(np.int8)), + (ann.DataType_QSymmS8, np.random.randint(1, size=(2, 3)).astype(np.int8)), + (ann.DataType_Signed32, np.random.randint(1, size=(2, 3)).astype(np.int32)), + (ann.DataType_QSymmS16, np.random.randint(1, size=(2, 3)).astype(np.int16)) + ], ids=['float32', 'float16', 'unsigned int8', 'signed int8', 'signed int8', 'int32', 'int16']) +class TestConstTensorConstructorErrors: + + def test_tensorinfo_isconstant_not_set(self, dt, data): + with pytest.raises(ValueError) as err: + ann.ConstTensor(ann.TensorInfo(ann.TensorShape((2, 2, 3, 3)), dt, 0.0, 0, False), data) + + assert str(err.value) == "TensorInfo when initializing ConstTensor must be set to constant." + + def test_tensor_tensorinfo_isconstant_not_set(self, dt, data): + with pytest.raises(ValueError) as err: + ann.ConstTensor(ann.Tensor(ann.TensorInfo(ann.TensorShape((2, 2, 3, 3)), dt, 0.0, 0, False), data)) + + assert str(err.value) == "TensorInfo of Tensor when initializing ConstTensor must be set to constant."
\ No newline at end of file |