author     Benjamin Klimczak <benjamin.klimczak@arm.com>  2023-07-12 15:18:26 +0100
committer  Benjamin Klimczak <benjamin.klimczak@arm.com>  2023-10-11 16:16:32 +0100
commit     ecc4264b93d4a89fa2cb40518b225d8371b7ffad (patch)
tree       47244d2d67ab6c50bfc15eab768252359eae0df6 /src/mlia/nn/tensorflow/config.py
parent     baaf4de286762c1955c874f78cd802d4703a8ba5 (diff)
download   mlia-ecc4264b93d4a89fa2cb40518b225d8371b7ffad.tar.gz
Enable rewrites for quantized input models
If the input model for rewriting is quantized:

- Record de-quantized TFRecords
- Enable writing de-quantized calibration data for the training
- Re-generate augmented training data, if needed
- Use quantization-aware training (QAT) to train the replacement models
- Check if the replacement model is quantized: if the source model is
  quantized, we make sure the rewrite's output model is quantized too.
  Currently only int8 is supported, so an error is raised if any other
  datatype is present in the output.

Resolves: MLIA-907, MLIA-908, MLIA-927

Signed-off-by: Benjamin Klimczak <benjamin.klimczak@arm.com>
Change-Id: Icb4070a9e6f1fdb5ce36120d73823986e89ac955
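As a sketch of the int8-only constraint described above: the patch adds a
check_datatypes method to TFLiteModel (see the diff below), which could be
used roughly as follows. The model path here is hypothetical.

    import numpy as np

    from mlia.nn.tensorflow.config import TFLiteModel

    # Hypothetical path to a rewrite's output model.
    model = TFLiteModel("replacement_model.tflite")

    # Per the commit message, this raises an error if any datatype
    # other than int8 is present in the model.
    model.check_datatypes(np.int8)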
Diffstat (limited to 'src/mlia/nn/tensorflow/config.py')
-rw-r--r--  src/mlia/nn/tensorflow/config.py  101
1 file changed, 98 insertions(+), 3 deletions(-)
diff --git a/src/mlia/nn/tensorflow/config.py b/src/mlia/nn/tensorflow/config.py
index c6a7c88..b94350a 100644
--- a/src/mlia/nn/tensorflow/config.py
+++ b/src/mlia/nn/tensorflow/config.py
@@ -7,13 +7,23 @@ import logging
import tempfile
from collections import defaultdict
from pathlib import Path
+from typing import Any
+from typing import Callable
+from typing import cast
+from typing import Dict
+from typing import List
import numpy as np
import tensorflow as tf
from mlia.core.context import Context
+from mlia.nn.tensorflow.optimizations.quantization import dequantize
+from mlia.nn.tensorflow.optimizations.quantization import is_quantized
+from mlia.nn.tensorflow.optimizations.quantization import QuantizationParameters
+from mlia.nn.tensorflow.optimizations.quantization import quantize
from mlia.nn.tensorflow.tflite_graph import load_fb
from mlia.nn.tensorflow.tflite_graph import save_fb
+from mlia.nn.tensorflow.utils import check_tflite_datatypes
from mlia.nn.tensorflow.utils import convert_to_tflite
from mlia.nn.tensorflow.utils import is_keras_model
from mlia.nn.tensorflow.utils import is_saved_model
@@ -71,6 +81,11 @@ class KerasModel(ModelConfiguration):
return self
+TFLiteIODetails = Dict[str, Dict[str, Any]]
+TFLiteIODetailsList = List[TFLiteIODetails]
+NameToTensorMap = Dict[str, np.ndarray]
+
+
class TFLiteModel(ModelConfiguration): # pylint: disable=abstract-method
"""TensorFlow Lite model configuration."""
@@ -119,7 +134,22 @@ class TFLiteModel(ModelConfiguration): # pylint: disable=abstract-method
self.shape_from_name = {d["name"]: d["shape"] for d in details}
self.batch_size = next(iter(self.shape_from_name.values()))[0]
- def __call__(self, named_input: dict) -> dict:
+ # Prepare quantization parameters for input and output
+ def named_quant_params(
+ details: TFLiteIODetailsList,
+ ) -> dict[str, QuantizationParameters]:
+ return {
+ str(detail["name"]): QuantizationParameters(
+ **detail["quantization_parameters"]
+ )
+ for detail in details
+ if TFLiteModel._is_tensor_quantized(detail)
+ }
+
+ self._quant_params_input = named_quant_params(self.input_details)
+ self._quant_params_output = named_quant_params(self.output_details)
+
+ def __call__(self, named_input: dict) -> NameToTensorMap:
"""Execute the model on one or a batch of named inputs \
(a dict of name: numpy array)."""
input_len = next(iter(named_input.values())).shape[0]
@@ -150,11 +180,11 @@ class TFLiteModel(ModelConfiguration): # pylint: disable=abstract-method
)
return {k: np.concatenate(v) for k, v in named_ys.items()}
- def input_tensors(self) -> list:
+ def input_tensors(self) -> list[str]:
"""Return name from input details."""
return [d["name"] for d in self.input_details]
- def output_tensors(self) -> list:
+ def output_tensors(self) -> list[str]:
"""Return name from output details."""
return [d["name"] for d in self.output_details]
@@ -164,6 +194,71 @@ class TFLiteModel(ModelConfiguration): # pylint: disable=abstract-method
"""Convert model to TensorFlow Lite format."""
return self
+ def _tensor_details(
+ self, name: str | None = None, idx: int | None = None
+ ) -> TFLiteIODetails:
+ """Get the details of the tensor by name or index."""
+ if idx is not None:
+ details = self.interpreter.get_tensor_details()[idx]
+ assert details["index"] == idx
+ elif name is not None:
+ for details_ in self.interpreter.get_tensor_details():
+ if name == details_["name"]:
+ details = details_
+ break
+ else:
+ raise NameError(
+ f"Tensor '{name}' not found in model {self.model_path}."
+ )
+ else:
+ raise ValueError("Either tensor name or index needs to be passed.")
+
+ assert isinstance(details, dict)
+ return cast(TFLiteIODetails, details)
+
+ @staticmethod
+ def _is_tensor_quantized(details: TFLiteIODetails) -> bool:
+ """Use tensor details to check if the corresponding tensor is quantized."""
+ quant_params = QuantizationParameters(**details["quantization_parameters"])
+ return is_quantized(quant_params)
+
+ def is_tensor_quantized(
+ self,
+ name: str | None = None,
+ idx: int | None = None,
+ ) -> bool:
+ """Check if the given tensor (identified by name or index) is quantized."""
+ details = self._tensor_details(name, idx)
+ return self._is_tensor_quantized(details)
+
+ def check_datatypes(self, *allowed_types: type) -> None:
+ """Check if the model only has the given allowed datatypes."""
+ check_tflite_datatypes(self.model_path, *allowed_types)
+
+ @staticmethod
+ def _quant_dequant(
+ tensors: NameToTensorMap,
+ quant_params: dict[str, QuantizationParameters],
+ func: Callable,
+ ) -> NameToTensorMap:
+ """Quantize/de-quantize tensor using the given parameters and function."""
+ return {
+ name: (func(tensor, quant_params[name]) if name in quant_params else tensor)
+ for name, tensor in tensors.items()
+ }
+
+ def dequantize_outputs(self, outputs: NameToTensorMap) -> NameToTensorMap:
+ """De-quantize the given model outputs."""
+ dequant_outputs = self._quant_dequant(
+ outputs, self._quant_params_output, dequantize
+ )
+ return dequant_outputs
+
+ def quantize_inputs(self, inputs: NameToTensorMap) -> NameToTensorMap:
+ """Quantize the given model inputs."""
+ quant_inputs = self._quant_dequant(inputs, self._quant_params_input, quantize)
+ return quant_inputs
+
class TfModel(ModelConfiguration): # pylint: disable=abstract-method
"""TensorFlow model configuration.