Diffstat (limited to 'src/mlia/nn/tensorflow')
-rw-r--r-- | src/mlia/nn/tensorflow/config.py | 101
-rw-r--r-- | src/mlia/nn/tensorflow/optimizations/quantization.py | 74
-rw-r--r-- | src/mlia/nn/tensorflow/utils.py | 30
3 files changed, 202 insertions, 3 deletions
diff --git a/src/mlia/nn/tensorflow/config.py b/src/mlia/nn/tensorflow/config.py
index c6a7c88..b94350a 100644
--- a/src/mlia/nn/tensorflow/config.py
+++ b/src/mlia/nn/tensorflow/config.py
@@ -7,13 +7,23 @@ import logging
 import tempfile
 from collections import defaultdict
 from pathlib import Path
+from typing import Any
+from typing import Callable
+from typing import cast
+from typing import Dict
+from typing import List
 
 import numpy as np
 import tensorflow as tf
 
 from mlia.core.context import Context
+from mlia.nn.tensorflow.optimizations.quantization import dequantize
+from mlia.nn.tensorflow.optimizations.quantization import is_quantized
+from mlia.nn.tensorflow.optimizations.quantization import QuantizationParameters
+from mlia.nn.tensorflow.optimizations.quantization import quantize
 from mlia.nn.tensorflow.tflite_graph import load_fb
 from mlia.nn.tensorflow.tflite_graph import save_fb
+from mlia.nn.tensorflow.utils import check_tflite_datatypes
 from mlia.nn.tensorflow.utils import convert_to_tflite
 from mlia.nn.tensorflow.utils import is_keras_model
 from mlia.nn.tensorflow.utils import is_saved_model
@@ -71,6 +81,11 @@ class KerasModel(ModelConfiguration):
         return self
 
 
+TFLiteIODetails = Dict[str, Dict[str, Any]]
+TFLiteIODetailsList = List[TFLiteIODetails]
+NameToTensorMap = Dict[str, np.ndarray]
+
+
 class TFLiteModel(ModelConfiguration):  # pylint: disable=abstract-method
     """TensorFlow Lite model configuration."""
 
@@ -119,7 +134,22 @@ class TFLiteModel(ModelConfiguration):  # pylint: disable=abstract-method
         self.shape_from_name = {d["name"]: d["shape"] for d in details}
         self.batch_size = next(iter(self.shape_from_name.values()))[0]
 
-    def __call__(self, named_input: dict) -> dict:
+        # Prepare quantization parameters for input and output
+        def named_quant_params(
+            details: TFLiteIODetailsList,
+        ) -> dict[str, QuantizationParameters]:
+            return {
+                str(detail["name"]): QuantizationParameters(
+                    **detail["quantization_parameters"]
+                )
+                for detail in details
+                if TFLiteModel._is_tensor_quantized(detail)
+            }
+
+        self._quant_params_input = named_quant_params(self.input_details)
+        self._quant_params_output = named_quant_params(self.output_details)
+
+    def __call__(self, named_input: dict) -> NameToTensorMap:
         """Execute the model on one or a batch of named inputs \
             (a dict of name: numpy array)."""
         input_len = next(iter(named_input.values())).shape[0]
@@ -150,11 +180,11 @@ class TFLiteModel(ModelConfiguration):  # pylint: disable=abstract-method
             )
         return {k: np.concatenate(v) for k, v in named_ys.items()}
 
-    def input_tensors(self) -> list:
+    def input_tensors(self) -> list[str]:
         """Return name from input details."""
         return [d["name"] for d in self.input_details]
 
-    def output_tensors(self) -> list:
+    def output_tensors(self) -> list[str]:
         """Return name from output details."""
         return [d["name"] for d in self.output_details]
 
@@ -164,6 +194,71 @@ class TFLiteModel(ModelConfiguration):  # pylint: disable=abstract-method
         """Convert model to TensorFlow Lite format."""
         return self
 
+    def _tensor_details(
+        self, name: str | None = None, idx: int | None = None
+    ) -> TFLiteIODetails:
+        """Get the details of the tensor by name or index."""
+        if idx is not None:
+            details = self.interpreter.get_tensor_details()[idx]
+            assert details["index"] == idx
+        elif name is not None:
+            for details_ in self.interpreter.get_tensor_details():
+                if name == details_["name"]:
+                    details = details_
+                    break
+            else:
+                raise NameError(
+                    f"Tensor '{name}' not found in model {self.model_path}."
+                )
+        else:
+            raise ValueError("Either tensor name or index needs to be passed.")
+
+        assert isinstance(details, dict)
+        return cast(TFLiteIODetails, details)
+
+    @staticmethod
+    def _is_tensor_quantized(details: TFLiteIODetails) -> bool:
+        """Use tensor details to check if the corresponding tensor is quantized."""
+        quant_params = QuantizationParameters(**details["quantization_parameters"])
+        return is_quantized(quant_params)
+
+    def is_tensor_quantized(
+        self,
+        name: str | None = None,
+        idx: int | None = None,
+    ) -> bool:
+        """Check if the given tensor (identified by name or index) is quantized."""
+        details = self._tensor_details(name, idx)
+        return self._is_tensor_quantized(details)
+
+    def check_datatypes(self, *allowed_types: type) -> None:
+        """Check if the model only has the given allowed datatypes."""
+        check_tflite_datatypes(self.model_path, *allowed_types)
+
+    @staticmethod
+    def _quant_dequant(
+        tensors: NameToTensorMap,
+        quant_params: dict[str, QuantizationParameters],
+        func: Callable,
+    ) -> NameToTensorMap:
+        """Quantize/de-quantize tensor using the given parameters and function."""
+        return {
+            name: (func(tensor, quant_params[name]) if name in quant_params else tensor)
+            for name, tensor in tensors.items()
+        }
+
+    def dequantize_outputs(self, outputs: NameToTensorMap) -> NameToTensorMap:
+        """De-quantize the given model outputs."""
+        dequant_outputs = self._quant_dequant(
+            outputs, self._quant_params_output, dequantize
+        )
+        return dequant_outputs
+
+    def quantize_inputs(self, inputs: NameToTensorMap) -> NameToTensorMap:
+        """Quantize the given model inputs."""
+        quant_inputs = self._quant_dequant(inputs, self._quant_params_input, quantize)
+        return quant_inputs
+
 
 class TfModel(ModelConfiguration):  # pylint: disable=abstract-method
     """TensorFlow model configuration.
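Taken together, these additions let a caller feed float data to a quantized model and get float results back: quantize_inputs applies the stored input parameters before inference, and dequantize_outputs reverses the output parameters afterwards. A minimal usage sketch, assuming a fully int8-quantized model; the model path, input name, and shape below are hypothetical placeholders:

```python
import numpy as np

from mlia.nn.tensorflow.config import TFLiteModel

model = TFLiteModel("model_int8.tflite")  # hypothetical model path

# Float32 inputs keyed by tensor name, as expected by TFLiteModel.__call__.
inputs = {"serving_default_input:0": np.zeros((1, 28, 28, 1), dtype=np.float32)}

outputs = model(model.quantize_inputs(inputs))  # run on int8 data
float_outputs = model.dequantize_outputs(outputs)  # back to float32
```

Tensors without quantization parameters pass through _quant_dequant unchanged, so the same calls are safe on a float model.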
diff --git a/src/mlia/nn/tensorflow/optimizations/quantization.py b/src/mlia/nn/tensorflow/optimizations/quantization.py
new file mode 100644
index 0000000..4e3c2c2
--- /dev/null
+++ b/src/mlia/nn/tensorflow/optimizations/quantization.py
@@ -0,0 +1,74 @@
+# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Contains functionality for quantization and de-quantization."""
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import cast
+
+import numpy as np
+import tensorflow as tf
+
+
+@dataclass
+class QuantizationParameters:
+    """Collection of TensorFlow Lite quantization parameters.
+
+    Can be directly initialized from TensorFlow Lite tensor details, e.g.
+
+    ```
+    QuantizationParameters(
+        **interpreter.get_input_details()[0]["quantization_parameters"]
+    )
+    ```
+    """
+
+    scales: np.ndarray
+    zero_points: np.ndarray
+    quantized_dimension: int
+
+
+def is_quantized(quant_params: QuantizationParameters) -> bool:
+    """Check if the quantization parameters describe a quantized tensor."""
+    return quant_params.scales.size > 0
+
+
+def dequantize(
+    quantized_tensor: np.ndarray | tf.Tensor, quant_params: QuantizationParameters
+) -> np.ndarray:
+    """De-quantize the input tensor using the given quantization parameters."""
+    assert isinstance(quantized_tensor, (tf.Tensor, np.ndarray))
+    assert (
+        not isinstance(quantized_tensor, tf.Tensor)
+        or quantized_tensor.dtype.is_quantized
+    ) and (
+        not isinstance(quantized_tensor, np.ndarray)
+        or issubclass(quantized_tensor.dtype.type, np.integer)
+    ), (
+        f"Input tensor for de-quantization is of type {quantized_tensor.dtype}, "
+        "but it must be int."
+    )
+
+    dequantized_tensor = np.subtract(
+        quantized_tensor, quant_params.zero_points, dtype=np.float32
+    )
+    dequantized_tensor = np.multiply(
+        dequantized_tensor, quant_params.scales, dtype=np.float32
+    )
+    return dequantized_tensor
+
+
+def quantize(
+    tensor: np.ndarray | tf.Tensor, quant_params: QuantizationParameters
+) -> np.ndarray:
+    """Quantize the given float input tensor to int8."""
+    assert isinstance(tensor, (tf.Tensor, np.ndarray))
+    assert (not isinstance(tensor, tf.Tensor) or tensor.dtype.is_floating) and (
+        not isinstance(tensor, np.ndarray) or issubclass(tensor.dtype.type, np.floating)
+    ), f"Input tensor for quantization is of type {tensor.dtype}, but it must be float."
+
+    quantized_tensor = (tensor / quant_params.scales) + quant_params.zero_points
+    quantized_tensor = np.clip(  # type: ignore
+        quantized_tensor, -128, 127, dtype=np.int8, casting="unsafe"
+    )
+    return cast(np.ndarray, quantized_tensor)
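In effect, dequantize computes (q - zero_point) * scale in float32, while quantize computes x / scale + zero_point clipped to the int8 range [-128, 127]; note the clip uses casting="unsafe", so values are truncated rather than rounded. A small round-trip sketch with hand-picked parameters chosen so the results are exact; the values are illustrative, not from a real model:

```python
import numpy as np

from mlia.nn.tensorflow.optimizations.quantization import dequantize
from mlia.nn.tensorflow.optimizations.quantization import QuantizationParameters
from mlia.nn.tensorflow.optimizations.quantization import quantize

# Per-tensor parameters: scale 0.5, zero point 10.
params = QuantizationParameters(
    scales=np.array([0.5]),
    zero_points=np.array([10]),
    quantized_dimension=0,
)

tensor = np.array([-2.0, 0.0, 1.5], dtype=np.float32)
quantized = quantize(tensor, params)  # (x / 0.5) + 10 -> int8 [6, 10, 13]
restored = dequantize(quantized, params)  # (q - 10) * 0.5 -> [-2.0, 0.0, 1.5]
assert np.allclose(restored, tensor)
```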
diff --git a/src/mlia/nn/tensorflow/utils.py b/src/mlia/nn/tensorflow/utils.py
index 77ac529..b8d45c6 100644
--- a/src/mlia/nn/tensorflow/utils.py
+++ b/src/mlia/nn/tensorflow/utils.py
@@ -123,3 +123,33 @@ def get_tflite_converter(
         converter.inference_output_type = tf.int8
 
     return converter
+
+
+def get_tflite_model_type_map(model_filename: str | Path) -> dict[str, type]:
+    """Get type map from tflite model."""
+    model_type_map: dict[str, Any] = {}
+    interpreter = tf.lite.Interpreter(str(model_filename))
+    interpreter.allocate_tensors()
+    input_details = interpreter.get_input_details()
+    model_type_map = {
+        input_detail["name"]: input_detail["dtype"] for input_detail in input_details
+    }
+    return model_type_map
+
+
+def check_tflite_datatypes(model_filename: str | Path, *allowed_types: type) -> None:
+    """Check if the model only has the given allowed datatypes."""
+    type_map = get_tflite_model_type_map(model_filename)
+    types = set(type_map.values())
+    allowed = set(allowed_types)
+    unexpected = types - allowed
+
+    def cls_to_str(types: set[type]) -> list[str]:
+        return [t.__name__ for t in types]
+
+    if len(unexpected) > 0:
+        raise TypeError(
+            f"Model {model_filename} has "
+            f"unexpected data types: {cls_to_str(unexpected)}. "
+            f"Only {cls_to_str(allowed)} are allowed."
+        )
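Since get_tflite_model_type_map builds its map from the interpreter's input details, check_tflite_datatypes effectively validates the model's input tensor types. A brief usage sketch; the file name is a hypothetical placeholder:

```python
import numpy as np

from mlia.nn.tensorflow.utils import check_tflite_datatypes

# Passes silently when every input tensor is int8; otherwise raises TypeError,
# e.g. "Model model.tflite has unexpected data types: ['float32'].
# Only ['int8'] are allowed."
check_tflite_datatypes("model.tflite", np.int8)
```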