diff options
Diffstat (limited to 'src/mlia/nn/tensorflow/config.py')
-rw-r--r-- | src/mlia/nn/tensorflow/config.py | 101 |
1 files changed, 98 insertions, 3 deletions
diff --git a/src/mlia/nn/tensorflow/config.py b/src/mlia/nn/tensorflow/config.py index c6a7c88..b94350a 100644 --- a/src/mlia/nn/tensorflow/config.py +++ b/src/mlia/nn/tensorflow/config.py @@ -7,13 +7,23 @@ import logging import tempfile from collections import defaultdict from pathlib import Path +from typing import Any +from typing import Callable +from typing import cast +from typing import Dict +from typing import List import numpy as np import tensorflow as tf from mlia.core.context import Context +from mlia.nn.tensorflow.optimizations.quantization import dequantize +from mlia.nn.tensorflow.optimizations.quantization import is_quantized +from mlia.nn.tensorflow.optimizations.quantization import QuantizationParameters +from mlia.nn.tensorflow.optimizations.quantization import quantize from mlia.nn.tensorflow.tflite_graph import load_fb from mlia.nn.tensorflow.tflite_graph import save_fb +from mlia.nn.tensorflow.utils import check_tflite_datatypes from mlia.nn.tensorflow.utils import convert_to_tflite from mlia.nn.tensorflow.utils import is_keras_model from mlia.nn.tensorflow.utils import is_saved_model @@ -71,6 +81,11 @@ class KerasModel(ModelConfiguration): return self +TFLiteIODetails = Dict[str, Dict[str, Any]] +TFLiteIODetailsList = List[TFLiteIODetails] +NameToTensorMap = Dict[str, np.ndarray] + + class TFLiteModel(ModelConfiguration): # pylint: disable=abstract-method """TensorFlow Lite model configuration.""" @@ -119,7 +134,22 @@ class TFLiteModel(ModelConfiguration): # pylint: disable=abstract-method self.shape_from_name = {d["name"]: d["shape"] for d in details} self.batch_size = next(iter(self.shape_from_name.values()))[0] - def __call__(self, named_input: dict) -> dict: + # Prepare quantization parameters for input and output + def named_quant_params( + details: TFLiteIODetailsList, + ) -> dict[str, QuantizationParameters]: + return { + str(detail["name"]): QuantizationParameters( + **detail["quantization_parameters"] + ) + for detail in details + if TFLiteModel._is_tensor_quantized(detail) + } + + self._quant_params_input = named_quant_params(self.input_details) + self._quant_params_output = named_quant_params(self.output_details) + + def __call__(self, named_input: dict) -> NameToTensorMap: """Execute the model on one or a batch of named inputs \ (a dict of name: numpy array).""" input_len = next(iter(named_input.values())).shape[0] @@ -150,11 +180,11 @@ class TFLiteModel(ModelConfiguration): # pylint: disable=abstract-method ) return {k: np.concatenate(v) for k, v in named_ys.items()} - def input_tensors(self) -> list: + def input_tensors(self) -> list[str]: """Return name from input details.""" return [d["name"] for d in self.input_details] - def output_tensors(self) -> list: + def output_tensors(self) -> list[str]: """Return name from output details.""" return [d["name"] for d in self.output_details] @@ -164,6 +194,71 @@ class TFLiteModel(ModelConfiguration): # pylint: disable=abstract-method """Convert model to TensorFlow Lite format.""" return self + def _tensor_details( + self, name: str | None = None, idx: int | None = None + ) -> TFLiteIODetails: + """Get the details of the tensor by name or index.""" + if idx is not None: + details = self.interpreter.get_tensor_details()[idx] + assert details["index"] == idx + elif name is not None: + for details_ in self.interpreter.get_tensor_details(): + if name == details_["name"]: + details = details_ + break + else: + raise NameError( + f"Tensor '{name}' not found in model {self.model_path}." + ) + else: + raise ValueError("Either tensor name or index needs to be passed.") + + assert isinstance(details, dict) + return cast(TFLiteIODetails, details) + + @staticmethod + def _is_tensor_quantized(details: TFLiteIODetails) -> bool: + """Use tensor details to check if the corresponding tensor is quantized.""" + quant_params = QuantizationParameters(**details["quantization_parameters"]) + return is_quantized(quant_params) + + def is_tensor_quantized( + self, + name: str | None = None, + idx: int | None = None, + ) -> bool: + """Check if the given tensor (identified by name or index) is quantized.""" + details = self._tensor_details(name, idx) + return self._is_tensor_quantized(details) + + def check_datatypes(self, *allowed_types: type) -> None: + """Check if the model only has the given allowed datatypes.""" + check_tflite_datatypes(self.model_path, *allowed_types) + + @staticmethod + def _quant_dequant( + tensors: NameToTensorMap, + quant_params: dict[str, QuantizationParameters], + func: Callable, + ) -> NameToTensorMap: + """Quantize/de-quantize tensor using the given parameters and function.""" + return { + name: (func(tensor, quant_params[name]) if name in quant_params else tensor) + for name, tensor in tensors.items() + } + + def dequantize_outputs(self, outputs: NameToTensorMap) -> NameToTensorMap: + """De-quantize the given model outputs.""" + dequant_outputs = self._quant_dequant( + outputs, self._quant_params_output, dequantize + ) + return dequant_outputs + + def quantize_inputs(self, inputs: NameToTensorMap) -> NameToTensorMap: + """Quantize the given model inputs.""" + quant_inputs = self._quant_dequant(inputs, self._quant_params_input, quantize) + return quant_inputs + class TfModel(ModelConfiguration): # pylint: disable=abstract-method """TensorFlow model configuration. |