aboutsummaryrefslogtreecommitdiff
path: root/python/pyarmnn/test/test_quantize_and_dequantize.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyarmnn/test/test_quantize_and_dequantize.py')
-rw-r--r--python/pyarmnn/test/test_quantize_and_dequantize.py79
1 files changed, 79 insertions, 0 deletions
diff --git a/python/pyarmnn/test/test_quantize_and_dequantize.py b/python/pyarmnn/test/test_quantize_and_dequantize.py
new file mode 100644
index 0000000000..d0c711ac13
--- /dev/null
+++ b/python/pyarmnn/test/test_quantize_and_dequantize.py
@@ -0,0 +1,79 @@
+# Copyright © 2019 Arm Ltd. All rights reserved.
+# SPDX-License-Identifier: MIT
+import pytest
+import numpy as np
+
+import pyarmnn as ann
+
+# import generated so we can test for Dequantize_* and Quantize_*
+# functions not available in the public API.
+import pyarmnn._generated.pyarmnn as gen_ann
+
+
+@pytest.mark.parametrize('method', ['Quantize_uint8_t',
+ 'Quantize_int16_t',
+ 'Quantize_int32_t',
+ 'Dequantize_uint8_t',
+ 'Dequantize_int16_t',
+ 'Dequantize_int32_t'])
+def test_quantize_exists(method):
+ assert method in dir(gen_ann) and callable(getattr(gen_ann, method))
+
+
+@pytest.mark.parametrize('dt, min, max', [('uint8', 0, 255),
+ ('int16', -32768, 32767),
+ ('int32', -2147483648, 2147483647)])
+def test_quantize_uint8_output(dt, min, max):
+ result = ann.quantize(3.3274056911468506, 0.02620004490017891, 128, dt)
+ assert type(result) is int and min <= result <= max
+
+
+@pytest.mark.parametrize('dt', ['uint8',
+ 'int16',
+ 'int32'])
+def test_dequantize_uint8_output(dt):
+ result = ann.dequantize(3, 0.02620004490017891, 128, dt)
+ assert type(result) is float
+
+
+def test_quantize_unsupported_dtype():
+ with pytest.raises(ValueError) as err:
+ ann.quantize(3.3274056911468506, 0.02620004490017891, 128, 'int8')
+
+ assert 'Unexpected target datatype int8 given.' in str(err.value)
+
+
+def test_dequantize_unsupported_dtype():
+ with pytest.raises(ValueError) as err:
+ ann.dequantize(3, 0.02620004490017891, 128, 'int8')
+
+ assert 'Unexpected value datatype int8 given.' in str(err.value)
+
+
+def test_dequantize_value_range():
+ with pytest.raises(ValueError) as err:
+ ann.dequantize(-1, 0.02620004490017891, 128, 'uint8')
+
+ assert 'Value is not within range of the given datatype uint8' in str(err.value)
+
+
+@pytest.mark.parametrize('dt, data', [('uint8', np.uint8(255)),
+ ('int16', np.int16(32767)),
+ ('int32', np.int32(2147483647)),
+
+ ('uint8', np.int16(255)),
+ ('uint8', np.int32(255)),
+
+ ('int16', np.uint8(255)),
+ ('int16', np.int32(32767)),
+
+ ('int32', np.uint8(255)),
+ ('int32', np.int16(32767))
+
+ ])
+def test_dequantize_numpy_dt(dt, data):
+ result = ann.dequantize(data, 1, 0, dt)
+
+ assert type(result) is float
+
+ assert np.float32(data) == result