aboutsummaryrefslogtreecommitdiff
path: root/python/pyarmnn/test/test_quantize_and_dequantize.py
blob: d0c711ac13979c3c2b3e2c89f98bc842bb08cbd4 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
# Copyright © 2019 Arm Ltd. All rights reserved.
# SPDX-License-Identifier: MIT
import pytest
import numpy as np

import pyarmnn as ann

# import generated so we can test for Dequantize_* and Quantize_*
# functions not available in the public API.
import pyarmnn._generated.pyarmnn as gen_ann


@pytest.mark.parametrize('method', [
    'Quantize_uint8_t',
    'Quantize_int16_t',
    'Quantize_int32_t',
    'Dequantize_uint8_t',
    'Dequantize_int16_t',
    'Dequantize_int32_t',
])
def test_quantize_exists(method):
    """Check that the generated bindings expose ``method`` as a callable.

    The Quantize_*/Dequantize_* wrappers are looked up on the generated
    module (``gen_ann``) because they are not part of the public API.
    """
    exported_names = dir(gen_ann)
    # Two separate asserts so a failure shows which condition did not hold.
    assert method in exported_names
    assert callable(getattr(gen_ann, method))


@pytest.mark.parametrize('dt, lower, upper', [('uint8', 0, 255),
                                              ('int16', -32768, 32767),
                                              ('int32', -2147483648, 2147483647)])
def test_quantize_uint8_output(dt, lower, upper):
    """Check that ``ann.quantize`` returns a Python ``int`` within ``dt``'s range.

    Despite the name, this test covers every dtype in the parameter list,
    not only uint8. ``lower`` and ``upper`` are the inclusive bounds of
    ``dt``; they are named so as not to shadow the ``min``/``max`` builtins.
    """
    result = ann.quantize(3.3274056911468506, 0.02620004490017891, 128, dt)
    # `type(...) is int` is stricter than isinstance: it rejects bool and
    # numpy integer scalars.
    assert type(result) is int
    assert lower <= result <= upper


@pytest.mark.parametrize('dt', ['uint8', 'int16', 'int32'])
def test_dequantize_uint8_output(dt):
    """Check that ``ann.dequantize`` returns a Python ``float`` for each dtype.

    The name mentions uint8, but every dtype in the parameter list is tested.
    """
    dequantized = ann.dequantize(3, 0.02620004490017891, 128, dt)
    assert type(dequantized) is float


def test_quantize_unsupported_dtype():
    """Check that quantizing to the unsupported dtype 'int8' raises ValueError."""
    expected_msg = 'Unexpected target datatype int8 given.'
    with pytest.raises(ValueError) as exc_info:
        ann.quantize(3.3274056911468506, 0.02620004490017891, 128, 'int8')
    assert expected_msg in str(exc_info.value)


def test_dequantize_unsupported_dtype():
    """Check that dequantizing from the unsupported dtype 'int8' raises ValueError."""
    expected_msg = 'Unexpected value datatype int8 given.'
    with pytest.raises(ValueError) as exc_info:
        ann.dequantize(3, 0.02620004490017891, 128, 'int8')
    assert expected_msg in str(exc_info.value)


def test_dequantize_value_range():
    """Check that dequantizing a value outside the uint8 range raises ValueError."""
    out_of_range = -1  # below 0, the smallest uint8 value
    with pytest.raises(ValueError) as exc_info:
        ann.dequantize(out_of_range, 0.02620004490017891, 128, 'uint8')
    assert 'Value is not within range of the given datatype uint8' in str(exc_info.value)


@pytest.mark.parametrize('dt, data', [
    # numpy scalar of the matching type, at the top of each range
    ('uint8', np.uint8(255)),
    ('int16', np.int16(32767)),
    ('int32', np.int32(2147483647)),
    # wider numpy types holding a uint8-representable value
    ('uint8', np.int16(255)),
    ('uint8', np.int32(255)),
    # other numpy types holding an int16-representable value
    ('int16', np.uint8(255)),
    ('int16', np.int32(32767)),
    # narrower numpy types holding an int32-representable value
    ('int32', np.uint8(255)),
    ('int32', np.int16(32767)),
])
def test_dequantize_numpy_dt(dt, data):
    """Check that ``ann.dequantize`` accepts numpy integer scalars of any listed type.

    With scale 1 and offset 0, the returned Python ``float`` is expected to
    equal ``data`` converted to ``np.float32``.
    """
    result = ann.dequantize(data, 1, 0, dt)
    assert type(result) is float

    expected = np.float32(data)
    assert expected == result