diff options
author | Richard Burton <richard.burton@arm.com> | 2020-04-08 16:39:05 +0100 |
---|---|---|
committer | Jim Flynn <jim.flynn@arm.com> | 2020-04-10 16:11:09 +0000 |
commit | dc0c6ed9f8b993e63f492f203d7d7080ab4c835c (patch) | |
tree | ea8541990b13ebf1a038009aa6b8b4b1ea8c3f55 /python/pyarmnn/test/test_quantize_and_dequantize.py | |
parent | fe5a24beeef6e9a41366e694f41093565e748048 (diff) | |
download | armnn-dc0c6ed9f8b993e63f492f203d7d7080ab4c835c.tar.gz |
Add PyArmNN to work with ArmNN API of 20.02
* Add Swig rules for generating python wrapper
* Add documentation
* Add tests and testing data
Change-Id: If48eda08931514fa21e72214dfead2835f07237c
Signed-off-by: Richard Burton <richard.burton@arm.com>
Signed-off-by: Derek Lamberti <derek.lamberti@arm.com>
Diffstat (limited to 'python/pyarmnn/test/test_quantize_and_dequantize.py')
-rw-r--r-- | python/pyarmnn/test/test_quantize_and_dequantize.py | 91 |
1 files changed, 91 insertions, 0 deletions
diff --git a/python/pyarmnn/test/test_quantize_and_dequantize.py b/python/pyarmnn/test/test_quantize_and_dequantize.py new file mode 100644 index 0000000000..08fea39eda --- /dev/null +++ b/python/pyarmnn/test/test_quantize_and_dequantize.py @@ -0,0 +1,91 @@ +# Copyright © 2020 Arm Ltd. All rights reserved. +# SPDX-License-Identifier: MIT +import pytest +import numpy as np + +import pyarmnn as ann + +# import generated so we can test for Dequantize_* and Quantize_* +# functions not available in the public API. +import pyarmnn._generated.pyarmnn as gen_ann + + +@pytest.mark.parametrize('method', ['Quantize_int8_t', + 'Quantize_uint8_t', + 'Quantize_int16_t', + 'Quantize_int32_t', + 'Dequantize_int8_t', + 'Dequantize_uint8_t', + 'Dequantize_int16_t', + 'Dequantize_int32_t']) +def test_quantize_exists(method): + assert method in dir(gen_ann) and callable(getattr(gen_ann, method)) + + +@pytest.mark.parametrize('dt, min, max', [('uint8', 0, 255), + ('int8', -128, 127), + ('int16', -32768, 32767), + ('int32', -2147483648, 2147483647)]) +def test_quantize_uint8_output(dt, min, max): + result = ann.quantize(3.3274056911468506, 0.02620004490017891, 128, dt) + assert type(result) is int and min <= result <= max + + +@pytest.mark.parametrize('dt', ['uint8', + 'int8', + 'int16', + 'int32']) +def test_dequantize_uint8_output(dt): + result = ann.dequantize(3, 0.02620004490017891, 128, dt) + assert type(result) is float + + +def test_quantize_unsupported_dtype(): + with pytest.raises(ValueError) as err: + ann.quantize(3.3274056911468506, 0.02620004490017891, 128, 'uint16') + + assert 'Unexpected target datatype uint16 given.' in str(err.value) + + +def test_dequantize_unsupported_dtype(): + with pytest.raises(ValueError) as err: + ann.dequantize(3, 0.02620004490017891, 128, 'uint16') + + assert 'Unexpected value datatype uint16 given.' in str(err.value) + + +def test_dequantize_value_range(): + with pytest.raises(ValueError) as err: + ann.dequantize(-1, 0.02620004490017891, 128, 'uint8') + + assert 'Value is not within range of the given datatype uint8' in str(err.value) + + +@pytest.mark.parametrize('dt, data', [('uint8', np.uint8(255)), + ('int8', np.int8(127)), + ('int16', np.int16(32767)), + ('int32', np.int32(2147483647)), + + ('uint8', np.int8(127)), + ('uint8', np.int16(255)), + ('uint8', np.int32(255)), + + ('int8', np.uint8(127)), + ('int8', np.int16(127)), + ('int8', np.int32(127)), + + ('int16', np.int8(127)), + ('int16', np.uint8(255)), + ('int16', np.int32(32767)), + + ('int32', np.uint8(255)), + ('int16', np.int8(127)), + ('int32', np.int16(32767)) + + ]) +def test_dequantize_numpy_dt(dt, data): + result = ann.dequantize(data, 1, 0, dt) + + assert type(result) is float + + assert np.float32(data) == result |