Diffstat (limited to 'src/mlia/nn/tensorflow')
-rw-r--r--  src/mlia/nn/tensorflow/config.py                        101
-rw-r--r--  src/mlia/nn/tensorflow/optimizations/quantization.py     74
-rw-r--r--  src/mlia/nn/tensorflow/utils.py                          30
3 files changed, 202 insertions(+), 3 deletions(-)
diff --git a/src/mlia/nn/tensorflow/config.py b/src/mlia/nn/tensorflow/config.py
index c6a7c88..b94350a 100644
--- a/src/mlia/nn/tensorflow/config.py
+++ b/src/mlia/nn/tensorflow/config.py
@@ -7,13 +7,23 @@ import logging
import tempfile
from collections import defaultdict
from pathlib import Path
+from typing import Any
+from typing import Callable
+from typing import cast
+from typing import Dict
+from typing import List
import numpy as np
import tensorflow as tf
from mlia.core.context import Context
+from mlia.nn.tensorflow.optimizations.quantization import dequantize
+from mlia.nn.tensorflow.optimizations.quantization import is_quantized
+from mlia.nn.tensorflow.optimizations.quantization import QuantizationParameters
+from mlia.nn.tensorflow.optimizations.quantization import quantize
from mlia.nn.tensorflow.tflite_graph import load_fb
from mlia.nn.tensorflow.tflite_graph import save_fb
+from mlia.nn.tensorflow.utils import check_tflite_datatypes
from mlia.nn.tensorflow.utils import convert_to_tflite
from mlia.nn.tensorflow.utils import is_keras_model
from mlia.nn.tensorflow.utils import is_saved_model
@@ -71,6 +81,11 @@ class KerasModel(ModelConfiguration):
return self
+TFLiteIODetails = Dict[str, Dict[str, Any]]
+TFLiteIODetailsList = List[TFLiteIODetails]
+NameToTensorMap = Dict[str, np.ndarray]
+
+
class TFLiteModel(ModelConfiguration): # pylint: disable=abstract-method
"""TensorFlow Lite model configuration."""
@@ -119,7 +134,22 @@ class TFLiteModel(ModelConfiguration): # pylint: disable=abstract-method
self.shape_from_name = {d["name"]: d["shape"] for d in details}
self.batch_size = next(iter(self.shape_from_name.values()))[0]
- def __call__(self, named_input: dict) -> dict:
+ # Prepare quantization parameters for input and output
+ def named_quant_params(
+ details: TFLiteIODetailsList,
+ ) -> dict[str, QuantizationParameters]:
+ return {
+ str(detail["name"]): QuantizationParameters(
+ **detail["quantization_parameters"]
+ )
+ for detail in details
+ if TFLiteModel._is_tensor_quantized(detail)
+ }
+
+ self._quant_params_input = named_quant_params(self.input_details)
+ self._quant_params_output = named_quant_params(self.output_details)
+
+ def __call__(self, named_input: dict) -> NameToTensorMap:
"""Execute the model on one or a batch of named inputs \
(a dict of name: numpy array)."""
input_len = next(iter(named_input.values())).shape[0]
@@ -150,11 +180,11 @@ class TFLiteModel(ModelConfiguration): # pylint: disable=abstract-method
)
return {k: np.concatenate(v) for k, v in named_ys.items()}
- def input_tensors(self) -> list:
+ def input_tensors(self) -> list[str]:
"""Return name from input details."""
return [d["name"] for d in self.input_details]
- def output_tensors(self) -> list:
+ def output_tensors(self) -> list[str]:
"""Return name from output details."""
return [d["name"] for d in self.output_details]
@@ -164,6 +194,71 @@ class TFLiteModel(ModelConfiguration): # pylint: disable=abstract-method
"""Convert model to TensorFlow Lite format."""
return self
+ def _tensor_details(
+ self, name: str | None = None, idx: int | None = None
+ ) -> TFLiteIODetails:
+ """Get the details of the tensor by name or index."""
+ if idx is not None:
+ details = self.interpreter.get_tensor_details()[idx]
+ assert details["index"] == idx
+ elif name is not None:
+ for details_ in self.interpreter.get_tensor_details():
+ if name == details_["name"]:
+ details = details_
+ break
+ else:
+ raise NameError(
+ f"Tensor '{name}' not found in model {self.model_path}."
+ )
+ else:
+ raise ValueError("Either tensor name or index needs to be passed.")
+
+ assert isinstance(details, dict)
+ return cast(TFLiteIODetails, details)
+
+ @staticmethod
+ def _is_tensor_quantized(details: TFLiteIODetails) -> bool:
+ """Use tensor details to check if the corresponding tensor is quantized."""
+ quant_params = QuantizationParameters(**details["quantization_parameters"])
+ return is_quantized(quant_params)
+
+ def is_tensor_quantized(
+ self,
+ name: str | None = None,
+ idx: int | None = None,
+ ) -> bool:
+ """Check if the given tensor (identified by name or index) is quantized."""
+ details = self._tensor_details(name, idx)
+ return self._is_tensor_quantized(details)
+
+ def check_datatypes(self, *allowed_types: type) -> None:
+ """Check if the model only has the given allowed datatypes."""
+ check_tflite_datatypes(self.model_path, *allowed_types)
+
+ @staticmethod
+ def _quant_dequant(
+ tensors: NameToTensorMap,
+ quant_params: dict[str, QuantizationParameters],
+ func: Callable,
+ ) -> NameToTensorMap:
+ """Quantize/de-quantize tensor using the given parameters and function."""
+ return {
+ name: (func(tensor, quant_params[name]) if name in quant_params else tensor)
+ for name, tensor in tensors.items()
+ }
+
+ def dequantize_outputs(self, outputs: NameToTensorMap) -> NameToTensorMap:
+ """De-quantize the given model outputs."""
+ dequant_outputs = self._quant_dequant(
+ outputs, self._quant_params_output, dequantize
+ )
+ return dequant_outputs
+
+ def quantize_inputs(self, inputs: NameToTensorMap) -> NameToTensorMap:
+ """Quantize the given model inputs."""
+ quant_inputs = self._quant_dequant(inputs, self._quant_params_input, quantize)
+ return quant_inputs
+
class TfModel(ModelConfiguration): # pylint: disable=abstract-method
"""TensorFlow model configuration.
diff --git a/src/mlia/nn/tensorflow/optimizations/quantization.py b/src/mlia/nn/tensorflow/optimizations/quantization.py
new file mode 100644
index 0000000..4e3c2c2
--- /dev/null
+++ b/src/mlia/nn/tensorflow/optimizations/quantization.py
@@ -0,0 +1,74 @@
+# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Contains functionality for quantization and de-quantization."""
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import cast
+
+import numpy as np
+import tensorflow as tf
+
+
+@dataclass
+class QuantizationParameters:
+ """Collection of TensorFlow Lite quantization parameters.
+
+ Can be directly initialized from TensorFlow Lite tensor details, e.g.
+
+ ```
+ QuantizationParameters(
+ **interpreter.get_input_details()[0]["quantization_parameters"]
+ )
+ ```
+ """
+
+ scales: np.ndarray
+ zero_points: np.ndarray
+ quantized_dimension: int
+
+
+def is_quantized(quant_params: QuantizationParameters) -> bool:
+ """Check if the quantization parameters describe a quantized tensor."""
+ return quant_params.scales.size > 0
+
+
+def dequantize(
+ quantized_tensor: np.ndarray | tf.Tensor, quant_params: QuantizationParameters
+) -> np.ndarray:
+ """De-quantize the input tensor using the given quantization parameters."""
+ assert isinstance(quantized_tensor, (tf.Tensor, np.ndarray))
+ assert (
+ not isinstance(quantized_tensor, tf.Tensor)
+ or quantized_tensor.dtype.is_quantized
+ ) and (
+ not isinstance(quantized_tensor, np.ndarray)
+ or issubclass(quantized_tensor.dtype.type, np.integer)
+ ), (
+ f"Input tensor for de-quantization is of type {quantized_tensor.dtype}, "
+ "but it must be int."
+ )
+
+ dequantized_tensor = np.subtract(
+ quantized_tensor, quant_params.zero_points, dtype=np.float32
+ )
+ dequantized_tensor = np.multiply(
+ dequantized_tensor, quant_params.scales, dtype=np.float32
+ )
+ return dequantized_tensor
+
+
+def quantize(
+ tensor: np.ndarray | tf.Tensor, quant_params: QuantizationParameters
+) -> np.ndarray:
+ """Quantize the given float input tensor to int8."""
+ assert isinstance(tensor, (tf.Tensor, np.ndarray))
+ assert (not isinstance(tensor, tf.Tensor) or tensor.dtype.is_floating) and (
+ not isinstance(tensor, np.ndarray) or issubclass(tensor.dtype.type, np.floating)
+ ), f"Input tensor for quantization is of type {tensor.dtype}, but it must be float."
+
+ quantized_tensor = (tensor / quant_params.scales) + quant_params.zero_points
+ quantized_tensor = np.clip( # type: ignore
+ quantized_tensor, -128, 127, dtype=np.int8, casting="unsafe"
+ )
+ return cast(np.ndarray, quantized_tensor)
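To make the conversion concrete, here is a small round trip through the two helpers with made-up per-tensor parameters (scale 0.5, zero point 10):

```
import numpy as np

from mlia.nn.tensorflow.optimizations.quantization import dequantize
from mlia.nn.tensorflow.optimizations.quantization import QuantizationParameters
from mlia.nn.tensorflow.optimizations.quantization import quantize

# Made-up parameters, for illustration only.
params = QuantizationParameters(
    scales=np.array([0.5], dtype=np.float32),
    zero_points=np.array([10], dtype=np.int32),
    quantized_dimension=0,
)

x = np.array([-1.0, 0.0, 2.5], dtype=np.float32)
q = quantize(x, params)  # (x / 0.5) + 10 -> [8, 10, 15] as int8
x_back = dequantize(q, params)  # (q - 10) * 0.5 -> [-1.0, 0.0, 2.5]
```

Note that `quantize` clips to the signed 8-bit range [-128, 127] with unsafe casting, so values are truncated rather than rounded; the exact recovery above only happens because the floats sit on the quantization grid.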
diff --git a/src/mlia/nn/tensorflow/utils.py b/src/mlia/nn/tensorflow/utils.py
index 77ac529..b8d45c6 100644
--- a/src/mlia/nn/tensorflow/utils.py
+++ b/src/mlia/nn/tensorflow/utils.py
@@ -123,3 +123,33 @@ def get_tflite_converter(
converter.inference_output_type = tf.int8
return converter
+
+
+def get_tflite_model_type_map(model_filename: str | Path) -> dict[str, type]:
+    """Get a map from input tensor names to dtypes for the given TFLite model."""
+    interpreter = tf.lite.Interpreter(str(model_filename))
+    interpreter.allocate_tensors()
+
+    input_details = interpreter.get_input_details()
+    model_type_map: dict[str, type] = {
+        input_detail["name"]: input_detail["dtype"] for input_detail in input_details
+    }
+    return model_type_map
+
+
+def check_tflite_datatypes(model_filename: str | Path, *allowed_types: type) -> None:
+ """Check if the model only has the given allowed datatypes."""
+ type_map = get_tflite_model_type_map(model_filename)
+ types = set(type_map.values())
+ allowed = set(allowed_types)
+ unexpected = types - allowed
+
+ def cls_to_str(types: set[type]) -> list[str]:
+ return [t.__name__ for t in types]
+
+ if len(unexpected) > 0:
+ raise TypeError(
+ f"Model {model_filename} has "
+ f"unexpected data types: {cls_to_str(unexpected)}. "
+ f"Only {cls_to_str(allowed)} are allowed."
+ )
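
A short sketch of how these two helpers fit together (the model path is hypothetical):

```
import numpy as np

from mlia.nn.tensorflow.utils import check_tflite_datatypes
from mlia.nn.tensorflow.utils import get_tflite_model_type_map

# Hypothetical model path, for illustration only.
type_map = get_tflite_model_type_map("model_int8.tflite")
print(type_map)  # e.g. {"input_0": <class 'numpy.int8'>}

# Passes silently for an all-INT8 model, raises TypeError otherwise.
check_tflite_datatypes("model_int8.tflite", np.int8)
```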