1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
|
# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Contains functionality for quantization and de-quantization."""
from __future__ import annotations
from dataclasses import dataclass
from typing import cast
import numpy as np
import tensorflow as tf
@dataclass
class QuantizationParameters:
    """Bundle of the quantization parameters attached to a TensorFlow Lite tensor.

    The field names mirror the keys of the ``"quantization_parameters"`` entry
    in TensorFlow Lite tensor details, so an instance can be built by unpacking
    that dictionary, e.g.

    ```
    QuantizationParameters(
        **interpreter.get_input_details()[0]["quantization_parameters"]
    )
    ```
    """

    # Scale factors; an empty array means the tensor is not quantized.
    scales: np.ndarray
    # Zero points, i.e. the offsets paired with the scale factors.
    zero_points: np.ndarray
    # Index of the quantized dimension as reported in the tensor details.
    quantized_dimension: int
def is_quantized(quant_params: QuantizationParameters) -> bool:
    """Tell whether the given parameters belong to a quantized tensor.

    The tensor counts as quantized when ``quant_params.scales`` holds at least
    one element; an empty ``scales`` array yields ``False``.
    """
    scale_count = quant_params.scales.size
    return scale_count > 0
def dequantize(
    quantized_tensor: np.ndarray | tf.Tensor, quant_params: QuantizationParameters
) -> np.ndarray:
    """De-quantize the input tensor using the given quantization parameters.

    Computes ``(quantized_tensor - zero_points) * scales`` in float32.

    Args:
        quantized_tensor: Integer tensor to convert. NumPy arrays must have an
            integer dtype; TensorFlow tensors must have either a plain integer
            dtype (e.g. ``tf.int8``) or a quantized dtype (e.g. ``tf.qint8``).
        quant_params: Supplies ``zero_points`` and ``scales``. Both are
            combined with the tensor using ordinary NumPy broadcasting;
            ``quantized_dimension`` is not consulted here.

    Returns:
        The de-quantized values as a float32 array.

    Raises:
        AssertionError: If the input is neither a ``tf.Tensor`` nor a
            ``np.ndarray``, or if its dtype is not an integer type.
    """
    assert isinstance(quantized_tensor, (tf.Tensor, np.ndarray))
    # ``DType.is_quantized`` is only true for the q* dtypes, so plain integer
    # tensors such as tf.int8 must be admitted through ``is_integer``.
    assert (
        (quantized_tensor.dtype.is_integer or quantized_tensor.dtype.is_quantized)
        if isinstance(quantized_tensor, tf.Tensor)
        else issubclass(quantized_tensor.dtype.type, np.integer)
    ), (
        f"Input tensor for de-quantization is of type {quantized_tensor.dtype}, "
        "but it must be int."
    )

    # Both steps are forced to float32 so the integer inputs cannot overflow.
    shifted = np.subtract(quantized_tensor, quant_params.zero_points, dtype=np.float32)
    return np.multiply(shifted, quant_params.scales, dtype=np.float32)
def quantize(
    tensor: np.ndarray | tf.Tensor, quant_params: QuantizationParameters
) -> np.ndarray:
    """Quantize the given float input tensor to int8.

    Computes ``tensor / scales + zero_points``, rounds the result to the
    nearest integer (``np.rint``, ties go to the even value), saturates it to
    the int8 range [-128, 127] and finally casts to ``np.int8``.

    Args:
        tensor: Floating-point tensor to quantize, either a ``np.ndarray`` or
            a ``tf.Tensor``.
        quant_params: Supplies ``scales`` and ``zero_points``. Both are
            combined with the tensor using ordinary broadcasting;
            ``quantized_dimension`` is not consulted here.

    Returns:
        The quantized values as an int8 array.

    Raises:
        AssertionError: If the input is neither a ``tf.Tensor`` nor a
            ``np.ndarray``, or if its dtype is not a floating-point type.
    """
    assert isinstance(tensor, (tf.Tensor, np.ndarray))
    assert (not isinstance(tensor, tf.Tensor) or tensor.dtype.is_floating) and (
        not isinstance(tensor, np.ndarray) or issubclass(tensor.dtype.type, np.floating)
    ), f"Input tensor for quantization is of type {tensor.dtype}, but it must be float."

    scaled = (tensor / quant_params.scales) + quant_params.zero_points
    # Round and saturate while the data is still floating point. Passing
    # dtype=np.int8 to np.clip would cast the inputs to int8 *before* clipping,
    # so out-of-range values would wrap around instead of saturating, and the
    # cast itself would truncate toward zero instead of rounding.
    saturated = np.clip(np.rint(scaled), -128, 127)
    return cast(np.ndarray, saturated.astype(np.int8))
|