aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/nn/tensorflow/config.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/mlia/nn/tensorflow/config.py')
-rw-r--r--src/mlia/nn/tensorflow/config.py101
1 files changed, 98 insertions, 3 deletions
diff --git a/src/mlia/nn/tensorflow/config.py b/src/mlia/nn/tensorflow/config.py
index c6a7c88..b94350a 100644
--- a/src/mlia/nn/tensorflow/config.py
+++ b/src/mlia/nn/tensorflow/config.py
@@ -7,13 +7,23 @@ import logging
import tempfile
from collections import defaultdict
from pathlib import Path
+from typing import Any
+from typing import Callable
+from typing import cast
+from typing import Dict
+from typing import List
import numpy as np
import tensorflow as tf
from mlia.core.context import Context
+from mlia.nn.tensorflow.optimizations.quantization import dequantize
+from mlia.nn.tensorflow.optimizations.quantization import is_quantized
+from mlia.nn.tensorflow.optimizations.quantization import QuantizationParameters
+from mlia.nn.tensorflow.optimizations.quantization import quantize
from mlia.nn.tensorflow.tflite_graph import load_fb
from mlia.nn.tensorflow.tflite_graph import save_fb
+from mlia.nn.tensorflow.utils import check_tflite_datatypes
from mlia.nn.tensorflow.utils import convert_to_tflite
from mlia.nn.tensorflow.utils import is_keras_model
from mlia.nn.tensorflow.utils import is_saved_model
@@ -71,6 +81,11 @@ class KerasModel(ModelConfiguration):
return self
+TFLiteIODetails = Dict[str, Dict[str, Any]]
+TFLiteIODetailsList = List[TFLiteIODetails]
+NameToTensorMap = Dict[str, np.ndarray]
+
+
class TFLiteModel(ModelConfiguration): # pylint: disable=abstract-method
"""TensorFlow Lite model configuration."""
@@ -119,7 +134,22 @@ class TFLiteModel(ModelConfiguration): # pylint: disable=abstract-method
self.shape_from_name = {d["name"]: d["shape"] for d in details}
self.batch_size = next(iter(self.shape_from_name.values()))[0]
- def __call__(self, named_input: dict) -> dict:
+ # Prepare quantization parameters for input and output
+ def named_quant_params(
+ details: TFLiteIODetailsList,
+ ) -> dict[str, QuantizationParameters]:
+ return {
+ str(detail["name"]): QuantizationParameters(
+ **detail["quantization_parameters"]
+ )
+ for detail in details
+ if TFLiteModel._is_tensor_quantized(detail)
+ }
+
+ self._quant_params_input = named_quant_params(self.input_details)
+ self._quant_params_output = named_quant_params(self.output_details)
+
+ def __call__(self, named_input: dict) -> NameToTensorMap:
"""Execute the model on one or a batch of named inputs \
(a dict of name: numpy array)."""
input_len = next(iter(named_input.values())).shape[0]
@@ -150,11 +180,11 @@ class TFLiteModel(ModelConfiguration): # pylint: disable=abstract-method
)
return {k: np.concatenate(v) for k, v in named_ys.items()}
- def input_tensors(self) -> list:
+ def input_tensors(self) -> list[str]:
"""Return name from input details."""
return [d["name"] for d in self.input_details]
- def output_tensors(self) -> list:
+ def output_tensors(self) -> list[str]:
"""Return name from output details."""
return [d["name"] for d in self.output_details]
@@ -164,6 +194,71 @@ class TFLiteModel(ModelConfiguration): # pylint: disable=abstract-method
"""Convert model to TensorFlow Lite format."""
return self
+ def _tensor_details(
+ self, name: str | None = None, idx: int | None = None
+ ) -> TFLiteIODetails:
+ """Get the details of the tensor by name or index."""
+ if idx is not None:
+ details = self.interpreter.get_tensor_details()[idx]
+ assert details["index"] == idx
+ elif name is not None:
+ for details_ in self.interpreter.get_tensor_details():
+ if name == details_["name"]:
+ details = details_
+ break
+ else:
+ raise NameError(
+ f"Tensor '{name}' not found in model {self.model_path}."
+ )
+ else:
+ raise ValueError("Either tensor name or index needs to be passed.")
+
+ assert isinstance(details, dict)
+ return cast(TFLiteIODetails, details)
+
+ @staticmethod
+ def _is_tensor_quantized(details: TFLiteIODetails) -> bool:
+ """Use tensor details to check if the corresponding tensor is quantized."""
+ quant_params = QuantizationParameters(**details["quantization_parameters"])
+ return is_quantized(quant_params)
+
+ def is_tensor_quantized(
+ self,
+ name: str | None = None,
+ idx: int | None = None,
+ ) -> bool:
+ """Check if the given tensor (identified by name or index) is quantized."""
+ details = self._tensor_details(name, idx)
+ return self._is_tensor_quantized(details)
+
+ def check_datatypes(self, *allowed_types: type) -> None:
+ """Check if the model only has the given allowed datatypes."""
+ check_tflite_datatypes(self.model_path, *allowed_types)
+
+ @staticmethod
+ def _quant_dequant(
+ tensors: NameToTensorMap,
+ quant_params: dict[str, QuantizationParameters],
+ func: Callable,
+ ) -> NameToTensorMap:
+ """Quantize/de-quantize tensor using the given parameters and function."""
+ return {
+ name: (func(tensor, quant_params[name]) if name in quant_params else tensor)
+ for name, tensor in tensors.items()
+ }
+
+ def dequantize_outputs(self, outputs: NameToTensorMap) -> NameToTensorMap:
+ """De-quantize the given model outputs."""
+ dequant_outputs = self._quant_dequant(
+ outputs, self._quant_params_output, dequantize
+ )
+ return dequant_outputs
+
+ def quantize_inputs(self, inputs: NameToTensorMap) -> NameToTensorMap:
+ """Quantize the given model inputs."""
+ quant_inputs = self._quant_dequant(inputs, self._quant_params_input, quantize)
+ return quant_inputs
+
class TfModel(ModelConfiguration): # pylint: disable=abstract-method
"""TensorFlow model configuration.