-rw-r--r--  src/mlia/nn/rewrite/core/extract.py                     |  61
-rw-r--r--  src/mlia/nn/rewrite/core/graph_edit/cut.py              |  16
-rw-r--r--  src/mlia/nn/rewrite/core/graph_edit/record.py           |  45
-rw-r--r--  src/mlia/nn/rewrite/core/train.py                       |  67
-rw-r--r--  src/mlia/nn/tensorflow/config.py                        | 101
-rw-r--r--  src/mlia/nn/tensorflow/optimizations/quantization.py    |  74
-rw-r--r--  src/mlia/nn/tensorflow/utils.py                         |  30
-rw-r--r--  tests/test_nn_rewrite_core_extract.py                   |  38
-rw-r--r--  tests/test_nn_rewrite_core_graph_edit_record.py         |  63
-rw-r--r--  tests/test_nn_rewrite_core_train.py                     |  67
-rw-r--r--  tests/test_nn_tensorflow_config.py                      |  12
-rw-r--r--  tests/test_nn_tensorflow_optimizations_quantization.py  |  53
-rw-r--r--  tests/test_nn_tensorflow_utils.py                       |  31
13 files changed, 598 insertions(+), 60 deletions(-)
diff --git a/src/mlia/nn/rewrite/core/extract.py b/src/mlia/nn/rewrite/core/extract.py
index f609955..4fcf735 100644
--- a/src/mlia/nn/rewrite/core/extract.py
+++ b/src/mlia/nn/rewrite/core/extract.py
@@ -2,19 +2,62 @@
# SPDX-License-Identifier: Apache-2.0
"""Extract module."""
# pylint: disable=too-many-arguments, too-many-locals
+from __future__ import annotations
+
import os
+from functools import partial
+from pathlib import Path
import tensorflow as tf
from tensorflow.lite.python.schema_py_generated import SubGraphT
from mlia.nn.rewrite.core.graph_edit.cut import cut_model
+from mlia.nn.rewrite.core.graph_edit.record import dequantized_path
from mlia.nn.rewrite.core.graph_edit.record import record_model
-
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
+def _get_path(
+ ext: str, name: str, dir_path: str | Path, model_is_quantized: bool = False
+) -> Path:
+ """Create a file path for extracted files."""
+ path = Path(dir_path, f"{name}{ext}")
+ return dequantized_path(path) if model_is_quantized else path
+
+
+class TFLitePaths: # pylint: disable=too-few-public-methods
+ """Provide safe access to TensorFlow Lite file paths."""
+
+ _get_path_tflite = partial(_get_path, ".tflite")
+
+ start = partial(_get_path_tflite, "start")
+ replace = partial(_get_path_tflite, "replace")
+ end = partial(_get_path_tflite, "end")
+
+
+class TFRecordPaths: # pylint: disable=too-few-public-methods
+ """Provide safe access to tfrec file paths."""
+
+ _get_path_tfrec = partial(_get_path, ".tfrec")
+
+ input = partial(_get_path_tfrec, "input")
+ output = partial(_get_path_tfrec, "output")
+ end = partial(_get_path_tfrec, "end")
+
+
+class ExtractPaths: # pylint: disable=too-few-public-methods
+ """Get paths to extract files.
+
+ This is meant to be the single source of truth regarding all file names
+ created by the extract() function in an output directory.
+ """
+
+ tflite = TFLitePaths
+ tfrec = TFRecordPaths
+
+
def extract(
output_path: str,
model_file: str,
@@ -26,6 +69,7 @@ def extract(
show_progress: bool = False,
num_procs: int = 1,
num_threads: int = 0,
+ dequantize_output: bool = False,
) -> None:
"""Extract a model after cut and record."""
try:
@@ -33,7 +77,7 @@ def extract(
except FileExistsError:
pass
- start_file = os.path.join(output_path, "start.tflite")
+ start_file = ExtractPaths.tflite.start(output_path)
cut_model(
model_file,
input_names=None,
@@ -42,7 +86,7 @@ def extract(
output_file=start_file,
)
- input_tfrec = os.path.join(output_path, "input.tfrec")
+ input_tfrec = ExtractPaths.tfrec.input(output_path)
record_model(
input_filename,
start_file,
@@ -50,9 +94,10 @@ def extract(
show_progress=show_progress,
num_procs=num_procs,
num_threads=num_threads,
+ dequantize_output=dequantize_output,
)
- replace_file = os.path.join(output_path, "replace.tflite")
+ replace_file = ExtractPaths.tflite.replace(output_path)
cut_model(
model_file,
input_names=input_names,
@@ -61,7 +106,7 @@ def extract(
output_file=replace_file,
)
- end_file = os.path.join(output_path, "end.tflite")
+ end_file = ExtractPaths.tflite.end(output_path)
cut_model(
model_file,
input_names=output_names,
@@ -71,7 +116,7 @@ def extract(
)
if not skip_outputs:
- output_tfrec = os.path.join(output_path, "output.tfrec")
+ output_tfrec = ExtractPaths.tfrec.output(output_path)
record_model(
input_tfrec,
replace_file,
@@ -79,9 +124,10 @@ def extract(
show_progress=show_progress,
num_procs=num_procs,
num_threads=num_threads,
+ dequantize_output=dequantize_output,
)
- end_tfrec = os.path.join(output_path, "end.tfrec")
+ end_tfrec = ExtractPaths.tfrec.end(output_path)
record_model(
output_tfrec,
end_file,
@@ -89,4 +135,5 @@ def extract(
show_progress=show_progress,
num_procs=num_procs,
num_threads=num_threads,
+ dequantize_output=dequantize_output,
)
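
Usage sketch (not part of the patch) of how the new ExtractPaths helpers resolve file names, assuming an output directory "out" (the directory name is illustrative):

    from pathlib import Path

    from mlia.nn.rewrite.core.extract import ExtractPaths

    # Float model: plain file names inside the output directory.
    assert ExtractPaths.tflite.start("out") == Path("out", "start.tflite")
    # Quantized model: dequantized_path() appends "_dequant" to the stem.
    assert ExtractPaths.tfrec.output("out", True) == Path("out", "output_dequant.tfrec")

Because every caller goes through these partials, renaming an extract artifact now only requires touching this one module.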
diff --git a/src/mlia/nn/rewrite/core/graph_edit/cut.py b/src/mlia/nn/rewrite/core/graph_edit/cut.py
index 13a5268..53d5389 100644
--- a/src/mlia/nn/rewrite/core/graph_edit/cut.py
+++ b/src/mlia/nn/rewrite/core/graph_edit/cut.py
@@ -1,9 +1,11 @@
# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Cut module."""
+from __future__ import annotations
+
import os
from collections import defaultdict
-from typing import Optional
+from pathlib import Path
import tensorflow as tf
from tensorflow.lite.python.schema_py_generated import ModelT
@@ -25,8 +27,8 @@ def tensors_by_name(subgraph: SubGraphT, names: list) -> list:
def cut_subgraph(
subgraph: SubGraphT,
- input_tensor_names: Optional[list],
- output_tensor_names: Optional[list],
+ input_tensor_names: list | None,
+ output_tensor_names: list | None,
) -> None:
"""Change the global inputs and outputs of a graph to the provided named tensors."""
if input_tensor_names is not None:
@@ -131,11 +133,11 @@ def filter_relabel(src_subgraph: SubGraphT, relabel_filter: set) -> tuple:
def cut_model(
- model_file: str,
- input_names: Optional[list],
- output_names: Optional[list],
+ model_file: str | Path,
+ input_names: list | None,
+ output_names: list | None,
subgraph_index: int,
- output_file: str,
+ output_file: str | Path,
) -> None:
"""Cut subgraphs and simplify a given model."""
model = load_fb(model_file)
diff --git a/src/mlia/nn/rewrite/core/graph_edit/record.py b/src/mlia/nn/rewrite/core/graph_edit/record.py
index 90f3db8..f85433d 100644
--- a/src/mlia/nn/rewrite/core/graph_edit/record.py
+++ b/src/mlia/nn/rewrite/core/graph_edit/record.py
@@ -6,6 +6,7 @@ from __future__ import annotations
import math
import os
+from contextlib import ExitStack
from pathlib import Path
import tensorflow as tf
@@ -15,11 +16,22 @@ from mlia.nn.rewrite.core.utils.numpy_tfrecord import numpytf_count
from mlia.nn.rewrite.core.utils.numpy_tfrecord import numpytf_read
from mlia.nn.rewrite.core.utils.numpy_tfrecord import NumpyTFWriter
from mlia.nn.rewrite.core.utils.parallel import ParallelTFLiteModel
+from mlia.nn.tensorflow.config import NameToTensorMap
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
+DEQUANT_SUFFIX = "_dequant"
+
+
+def dequantized_path(filename: str | Path) -> Path:
+ """Append the de-quantization suffix to the given filename."""
+ path = Path(filename)
+ path = Path(path.parent, f"{path.stem}{DEQUANT_SUFFIX}{path.suffix}")
+ return path
+
+
def record_model(
input_filename: str | Path,
model_filename: str | Path,
@@ -28,11 +40,14 @@ def record_model(
show_progress: bool = False,
num_procs: int = 1,
num_threads: int = 0,
+ dequantize_output: bool = False,
) -> None:
"""Model recorder.
num_procs: 0 => detect real cores on system
num_threads: 0 => TFLite impl. specific setting, usually 3
+
+ dequantize_output: True => de-quantize the recorded output before saving
"""
model = ParallelTFLiteModel(model_filename, num_procs, num_threads, batch_size)
if not batch_size:
@@ -51,22 +66,38 @@ def record_model(
dataset = dataset.batch(batch_size, drop_remainder=False)
total = int(math.ceil(total / batch_size))
- with NumpyTFWriter(output_filename) as writer:
- for _, named_x in enumerate(
- track(dataset.as_numpy_iterator(), total=total, disable=not show_progress)
- ):
- named_y = model(named_x)
+ with ExitStack() as stack:
+ writer = stack.enter_context(NumpyTFWriter(output_filename))
+ writer_dequant = None
+ if dequantize_output:
+ dequant_path = dequantized_path(output_filename)
+ writer_dequant = stack.enter_context(NumpyTFWriter(dequant_path))
+
+ def write(writer: NumpyTFWriter, data: NameToTensorMap) -> None:
+ """Write the data using the given NumpyTFWriter instance."""
if batch_size > 1:
for i in range(batch_size):
# Expand the batches and recreate each dict as a
# batch-size 1 item for the tfrec output
recreated_dict = {
k: v[i : i + 1] # noqa: E203
- for k, v in named_y.items()
+ for k, v in data.items()
if i < v.shape[0]
}
if recreated_dict:
writer.write(recreated_dict)
else:
- writer.write(named_y)
+ writer.write(data)
+
+ for _, named_x in enumerate(
+ track(dataset.as_numpy_iterator(), total=total, disable=not show_progress)
+ ):
+ named_y = model(named_x)
+ write(writer, named_y)
+
+ if dequantize_output:
+ assert writer_dequant
+ named_y_dequant = model.dequantize_outputs(named_y)
+ write(writer_dequant, named_y_dequant)
+
model.close()
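
Usage sketch (not part of the patch): with dequantize_output=True the recorder writes the raw outputs and a de-quantized float32 copy side by side; the file names below are illustrative:

    from mlia.nn.rewrite.core.graph_edit.record import dequantized_path, record_model

    # Produces "out/output.tfrec" (raw model outputs) plus
    # "out/output_dequant.tfrec" (outputs converted back to float32).
    record_model("input.tfrec", "model.tflite", "out/output.tfrec", dequantize_output=True)
    assert dequantized_path("out/output.tfrec").name == "output_dequant.tfrec"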
diff --git a/src/mlia/nn/rewrite/core/train.py b/src/mlia/nn/rewrite/core/train.py
index 82af747..6345f07 100644
--- a/src/mlia/nn/rewrite/core/train.py
+++ b/src/mlia/nn/rewrite/core/train.py
@@ -22,9 +22,11 @@ from typing import Literal
import numpy as np
import tensorflow as tf
+import tensorflow_model_optimization as tfmot
from numpy.random import Generator
from mlia.nn.rewrite.core.extract import extract
+from mlia.nn.rewrite.core.extract import ExtractPaths
from mlia.nn.rewrite.core.graph_edit.diff import diff_stats
from mlia.nn.rewrite.core.graph_edit.join import join_models
from mlia.nn.rewrite.core.graph_edit.record import record_model
@@ -34,6 +36,7 @@ from mlia.nn.rewrite.core.utils.parallel import ParallelTFLiteModel
from mlia.nn.tensorflow.config import TFLiteModel
from mlia.nn.tensorflow.tflite_graph import load_fb
from mlia.nn.tensorflow.tflite_graph import save_fb
+from mlia.nn.tensorflow.utils import get_tflite_converter
from mlia.utils.logging import log_action
@@ -91,6 +94,7 @@ def train(
input_tfrec,
input_tensors,
output_tensors,
+ dequantize_output=True,
)
else:
unmodified_model_dir = None
@@ -106,6 +110,7 @@ def train(
output_tensors,
num_procs=train_params.num_procs,
num_threads=train_params.num_threads,
+ dequantize_output=True,
)
tflite_filenames = train_in_dir(
@@ -160,7 +165,10 @@ def train(
def eval_in_dir(
- target_dir: str, new_part: str, num_procs: int = 1, num_threads: int = 0
+ target_dir: str,
+ new_part: str,
+ num_procs: int = 1,
+ num_threads: int = 0,
) -> tuple:
"""Evaluate a model in a given directory."""
model_input_path = Path(target_dir, "input_orig.tfrec")
@@ -168,12 +176,12 @@ def eval_in_dir(
model_input = (
model_input_path
if model_input_path.exists()
- else Path(target_dir, "input.tfrec")
+ else ExtractPaths.tfrec.input(target_dir, False)
)
output = (
model_output_path
if model_output_path.exists()
- else Path(target_dir, "output.tfrec")
+ else ExtractPaths.tfrec.output(target_dir, False)
)
with tempfile.TemporaryDirectory() as tmp_dir:
@@ -194,8 +202,8 @@ def join_in_dir(model_dir: str, new_part: str, output_model: str) -> None:
"""Join two models in a given directory."""
with tempfile.TemporaryDirectory() as tmp_dir:
new_end = Path(tmp_dir, "new_end.tflite")
- join_models(new_part, Path(model_dir, "end.tflite"), new_end)
- join_models(Path(model_dir, "start.tflite"), new_end, output_model)
+ join_models(new_part, ExtractPaths.tflite.end(model_dir), new_end)
+ join_models(ExtractPaths.tflite.start(model_dir), new_end, output_model)
def _get_io_tensors(model: TFLiteModel) -> tuple[str, str]:
@@ -244,7 +252,9 @@ def set_up_data_pipeline(
input_name, output_name = _get_io_tensors(teacher)
- input_filename = Path(train_dir, "input.tfrec")
+ model_is_quantized = replace.is_tensor_quantized(name=input_name)
+
+ input_filename = ExtractPaths.tfrec.input(train_dir, model_is_quantized)
total = numpytf_count(str(input_filename))
dict_inputs = numpytf_read(str(input_filename))
@@ -264,13 +274,13 @@ def set_up_data_pipeline(
if any(augmentations):
# Map the teacher inputs here because the augmentation stage passes these
# through a TFLite model to get the outputs
- teacher_outputs = numpytf_read(str(Path(teacher_dir, "input.tfrec"))).map(
- lambda d: tf.squeeze(d[input_name], axis=0)
- )
+ teacher_outputs = numpytf_read(
+ str(ExtractPaths.tfrec.input(teacher_dir, model_is_quantized))
+ ).map(lambda d: tf.squeeze(d[input_name], axis=0))
else:
- teacher_outputs = numpytf_read(str(Path(teacher_dir, "output.tfrec"))).map(
- lambda d: tf.squeeze(d[output_name], axis=0)
- )
+ teacher_outputs = numpytf_read(
+ str(ExtractPaths.tfrec.output(teacher_dir, model_is_quantized))
+ ).map(lambda d: tf.squeeze(d[output_name], axis=0))
dataset = tf.data.Dataset.zip((inputs, teacher_outputs))
if epochs > 1:
@@ -285,7 +295,23 @@ def set_up_data_pipeline(
) -> tuple:
"""Return results of train and teach based on augmentations."""
augmented_train = augment_train({input_name: train})[input_name]
- augmented_teach = teacher(augment_teacher({input_name: teach}))[output_name]
+ # If augmentation of the input is enabled, we need to re-generate
+ # the reference output by running the augmented data through the
+ # teacher model.
+ if model_is_quantized:
+ # If the input model is quantized we have to additionally
+ # - quantize the augmented data before running it through the
+ # (quantized) teacher model
+ # - de-quantize the output for the training.
+ augmented_teach = teacher.dequantize_outputs(
+ teacher(
+ teacher.quantize_inputs(augment_teacher({input_name: teach}))
+ )
+ )[output_name]
+ else:
+ augmented_teach = teacher(augment_teacher({input_name: teach}))[
+ output_name
+ ]
return (augmented_train, augmented_teach)
dataset = dataset.map(
@@ -329,15 +355,20 @@ def train_in_dir(
"""
teacher_dir = baseline_dir if baseline_dir else train_dir
teacher = ParallelTFLiteModel(
- f"{teacher_dir}/replace.tflite",
+ ExtractPaths.tflite.replace(teacher_dir),
train_params.num_procs,
train_params.num_threads,
batch_size=train_params.batch_size,
)
- replace = TFLiteModel(f"{train_dir}/replace.tflite")
+ replace = TFLiteModel(ExtractPaths.tflite.replace(train_dir))
input_name, output_name = _get_io_tensors(teacher)
+ model_is_quantized = replace.is_tensor_quantized(name=input_name)
+
+ if model_is_quantized:
+ replace.check_datatypes(np.int8)
+
dataset = set_up_data_pipeline(
teacher,
replace,
@@ -354,6 +385,8 @@ def train_in_dir(
optimizer = tf.keras.optimizers.Nadam(learning_rate=train_params.learning_rate)
loss_fn = tf.keras.losses.MeanSquaredError()
+ if model_is_quantized:
+ model = tfmot.quantization.keras.quantize_model(model)
model.compile(optimizer=optimizer, loss=loss_fn, metrics=["mae"])
logger.info(model.summary())
@@ -432,6 +465,7 @@ def train_in_dir(
replace.shape_from_name[input_name],
output_name,
replace.shape_from_name[output_name],
+ model_is_quantized,
)
output_filenames.append(checkpoint_filename)
@@ -446,6 +480,7 @@ def save_as_tflite(
input_shape: list,
output_name: str,
output_shape: list,
+ model_is_quantized: bool = False,
) -> None:
"""Save Keras model as TFLite file."""
@@ -464,7 +499,7 @@ def save_as_tflite(
keras_model.input.set_shape(orig_shape)
with fixed_input(keras_model, input_shape) as fixed_model:
- converter = tf.lite.TFLiteConverter.from_keras_model(fixed_model)
+ converter = get_tflite_converter(fixed_model, quantized=model_is_quantized)
tflite_model = converter.convert()
with open(filename, "wb") as file:
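
Sketch of the new quantization-aware training step (not part of the patch): when the input model is quantized, the Keras replacement is wrapped with tfmot before compilation, roughly as below; the toy Sequential model is illustrative only:

    import tensorflow as tf
    import tensorflow_model_optimization as tfmot

    model = tf.keras.Sequential([tf.keras.layers.Dense(8, input_shape=(4,))])
    # Insert fake-quantization nodes so training matches int8 inference behaviour.
    model = tfmot.quantization.keras.quantize_model(model)
    model.compile(
        optimizer=tf.keras.optimizers.Nadam(learning_rate=1e-3),
        loss=tf.keras.losses.MeanSquaredError(),
        metrics=["mae"],
    )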
diff --git a/src/mlia/nn/tensorflow/config.py b/src/mlia/nn/tensorflow/config.py
index c6a7c88..b94350a 100644
--- a/src/mlia/nn/tensorflow/config.py
+++ b/src/mlia/nn/tensorflow/config.py
@@ -7,13 +7,23 @@ import logging
import tempfile
from collections import defaultdict
from pathlib import Path
+from typing import Any
+from typing import Callable
+from typing import cast
+from typing import Dict
+from typing import List
import numpy as np
import tensorflow as tf
from mlia.core.context import Context
+from mlia.nn.tensorflow.optimizations.quantization import dequantize
+from mlia.nn.tensorflow.optimizations.quantization import is_quantized
+from mlia.nn.tensorflow.optimizations.quantization import QuantizationParameters
+from mlia.nn.tensorflow.optimizations.quantization import quantize
from mlia.nn.tensorflow.tflite_graph import load_fb
from mlia.nn.tensorflow.tflite_graph import save_fb
+from mlia.nn.tensorflow.utils import check_tflite_datatypes
from mlia.nn.tensorflow.utils import convert_to_tflite
from mlia.nn.tensorflow.utils import is_keras_model
from mlia.nn.tensorflow.utils import is_saved_model
@@ -71,6 +81,11 @@ class KerasModel(ModelConfiguration):
return self
+TFLiteIODetails = Dict[str, Dict[str, Any]]
+TFLiteIODetailsList = List[TFLiteIODetails]
+NameToTensorMap = Dict[str, np.ndarray]
+
+
class TFLiteModel(ModelConfiguration): # pylint: disable=abstract-method
"""TensorFlow Lite model configuration."""
@@ -119,7 +134,22 @@ class TFLiteModel(ModelConfiguration): # pylint: disable=abstract-method
self.shape_from_name = {d["name"]: d["shape"] for d in details}
self.batch_size = next(iter(self.shape_from_name.values()))[0]
- def __call__(self, named_input: dict) -> dict:
+ # Prepare quantization parameters for input and output
+ def named_quant_params(
+ details: TFLiteIODetailsList,
+ ) -> dict[str, QuantizationParameters]:
+ return {
+ str(detail["name"]): QuantizationParameters(
+ **detail["quantization_parameters"]
+ )
+ for detail in details
+ if TFLiteModel._is_tensor_quantized(detail)
+ }
+
+ self._quant_params_input = named_quant_params(self.input_details)
+ self._quant_params_output = named_quant_params(self.output_details)
+
+ def __call__(self, named_input: dict) -> NameToTensorMap:
"""Execute the model on one or a batch of named inputs \
(a dict of name: numpy array)."""
input_len = next(iter(named_input.values())).shape[0]
@@ -150,11 +180,11 @@ class TFLiteModel(ModelConfiguration): # pylint: disable=abstract-method
)
return {k: np.concatenate(v) for k, v in named_ys.items()}
- def input_tensors(self) -> list:
+ def input_tensors(self) -> list[str]:
"""Return name from input details."""
return [d["name"] for d in self.input_details]
- def output_tensors(self) -> list:
+ def output_tensors(self) -> list[str]:
"""Return name from output details."""
return [d["name"] for d in self.output_details]
@@ -164,6 +194,71 @@ class TFLiteModel(ModelConfiguration): # pylint: disable=abstract-method
"""Convert model to TensorFlow Lite format."""
return self
+ def _tensor_details(
+ self, name: str | None = None, idx: int | None = None
+ ) -> TFLiteIODetails:
+ """Get the details of the tensor by name or index."""
+ if idx is not None:
+ details = self.interpreter.get_tensor_details()[idx]
+ assert details["index"] == idx
+ elif name is not None:
+ for details_ in self.interpreter.get_tensor_details():
+ if name == details_["name"]:
+ details = details_
+ break
+ else:
+ raise NameError(
+ f"Tensor '{name}' not found in model {self.model_path}."
+ )
+ else:
+ raise ValueError("Either tensor name or index needs to be passed.")
+
+ assert isinstance(details, dict)
+ return cast(TFLiteIODetails, details)
+
+ @staticmethod
+ def _is_tensor_quantized(details: TFLiteIODetails) -> bool:
+ """Use tensor details to check if the corresponding tensor is quantized."""
+ quant_params = QuantizationParameters(**details["quantization_parameters"])
+ return is_quantized(quant_params)
+
+ def is_tensor_quantized(
+ self,
+ name: str | None = None,
+ idx: int | None = None,
+ ) -> bool:
+ """Check if the given tensor (identified by name or index) is quantized."""
+ details = self._tensor_details(name, idx)
+ return self._is_tensor_quantized(details)
+
+ def check_datatypes(self, *allowed_types: type) -> None:
+ """Check if the model only has the given allowed datatypes."""
+ check_tflite_datatypes(self.model_path, *allowed_types)
+
+ @staticmethod
+ def _quant_dequant(
+ tensors: NameToTensorMap,
+ quant_params: dict[str, QuantizationParameters],
+ func: Callable,
+ ) -> NameToTensorMap:
+ """Quantize/de-quantize tensor using the given parameters and function."""
+ return {
+ name: (func(tensor, quant_params[name]) if name in quant_params else tensor)
+ for name, tensor in tensors.items()
+ }
+
+ def dequantize_outputs(self, outputs: NameToTensorMap) -> NameToTensorMap:
+ """De-quantize the given model outputs."""
+ dequant_outputs = self._quant_dequant(
+ outputs, self._quant_params_output, dequantize
+ )
+ return dequant_outputs
+
+ def quantize_inputs(self, inputs: NameToTensorMap) -> NameToTensorMap:
+ """Quantize the given model inputs."""
+ quant_inputs = self._quant_dequant(inputs, self._quant_params_input, quantize)
+ return quant_inputs
+
class TfModel(ModelConfiguration): # pylint: disable=abstract-method
"""TensorFlow model configuration.
diff --git a/src/mlia/nn/tensorflow/optimizations/quantization.py b/src/mlia/nn/tensorflow/optimizations/quantization.py
new file mode 100644
index 0000000..4e3c2c2
--- /dev/null
+++ b/src/mlia/nn/tensorflow/optimizations/quantization.py
@@ -0,0 +1,74 @@
+# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Contains functionality for quantization and de-quantization."""
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import cast
+
+import numpy as np
+import tensorflow as tf
+
+
+@dataclass
+class QuantizationParameters:
+ """Collection of TensorFlow Lite quantization parameters.
+
+ Can be directly initialized from TensorFlow Lite tensor details, e.g.
+
+ ```
+ QuantizationParameters(
+ **interpreter.get_input_details()[0]["quantization_parameters"]
+ )
+ ```
+ """
+
+ scales: np.ndarray
+ zero_points: np.ndarray
+ quantized_dimension: int
+
+
+def is_quantized(quant_params: QuantizationParameters) -> bool:
+ """Check if the quantization parameters describe a quantized tensor."""
+ return quant_params.scales.size > 0
+
+
+def dequantize(
+ quantized_tensor: np.ndarray | tf.Tensor, quant_params: QuantizationParameters
+) -> np.ndarray:
+ """De-quantize the input tensor using the given quantization parameters."""
+ assert isinstance(quantized_tensor, (tf.Tensor, np.ndarray))
+ assert (
+ not isinstance(quantized_tensor, tf.Tensor)
+ or quantized_tensor.dtype.is_quantized
+ ) and (
+ not isinstance(quantized_tensor, np.ndarray)
+ or issubclass(quantized_tensor.dtype.type, np.integer)
+ ), (
+ f"Input tensor for de-quantization is of type {quantized_tensor.dtype}, "
+ "but it must be int."
+ )
+
+ dequantized_tensor = np.subtract(
+ quantized_tensor, quant_params.zero_points, dtype=np.float32
+ )
+ dequantized_tensor = np.multiply(
+ dequantized_tensor, quant_params.scales, dtype=np.float32
+ )
+ return dequantized_tensor
+
+
+def quantize(
+ tensor: np.ndarray | tf.Tensor, quant_params: QuantizationParameters
+) -> np.ndarray:
+ """Quantize the given float input tensor to int8."""
+ assert isinstance(tensor, (tf.Tensor, np.ndarray))
+ assert (not isinstance(tensor, tf.Tensor) or tensor.dtype.is_floating) and (
+ not isinstance(tensor, np.ndarray) or issubclass(tensor.dtype.type, np.floating)
+ ), f"Input tensor for quantization is of type {tensor.dtype}, but it must be float."
+
+ quantized_tensor = (tensor / quant_params.scales) + quant_params.zero_points
+ quantized_tensor = np.clip( # type: ignore
+ quantized_tensor, -128, 127, dtype=np.int8, casting="unsafe"
+ )
+ return cast(np.ndarray, quantized_tensor)
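
Worked example (not part of the patch) of the affine scheme implemented above, q = clip(x / scale + zero_point) and x = (q - zero_point) * scale; the parameter values are chosen for illustration:

    import numpy as np

    from mlia.nn.tensorflow.optimizations.quantization import (
        QuantizationParameters,
        dequantize,
        quantize,
    )

    params = QuantizationParameters(
        scales=np.array([0.5]), zero_points=np.array([10]), quantized_dimension=0
    )
    q = quantize(np.array([1.0], dtype=np.float32), params)  # 1.0 / 0.5 + 10 -> 12
    x = dequantize(q, params)                                # (12 - 10) * 0.5 -> 1.0
    assert q.dtype == np.int8 and np.isclose(x[0], 1.0)

Note that quantize() clips to [-128, 127] with an unsafe cast rather than rounding, so non-integral values truncate toward zero.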
diff --git a/src/mlia/nn/tensorflow/utils.py b/src/mlia/nn/tensorflow/utils.py
index 77ac529..b8d45c6 100644
--- a/src/mlia/nn/tensorflow/utils.py
+++ b/src/mlia/nn/tensorflow/utils.py
@@ -123,3 +123,33 @@ def get_tflite_converter(
converter.inference_output_type = tf.int8
return converter
+
+
+def get_tflite_model_type_map(model_filename: str | Path) -> dict[str, type]:
+ """Get type map from tflite model."""
+ model_type_map: dict[str, Any] = {}
+ interpreter = tf.lite.Interpreter(str(model_filename))
+ interpreter.allocate_tensors()
+ input_details = interpreter.get_input_details()
+ model_type_map = {
+ input_detail["name"]: input_detail["dtype"] for input_detail in input_details
+ }
+ return model_type_map
+
+
+def check_tflite_datatypes(model_filename: str | Path, *allowed_types: type) -> None:
+ """Check if the model only has the given allowed datatypes."""
+ type_map = get_tflite_model_type_map(model_filename)
+ types = set(type_map.values())
+ allowed = set(allowed_types)
+ unexpected = types - allowed
+
+ def cls_to_str(types: set[type]) -> list[str]:
+ return [t.__name__ for t in types]
+
+ if len(unexpected) > 0:
+ raise TypeError(
+ f"Model {model_filename} has "
+ f"unexpected data types: {cls_to_str(unexpected)}. "
+ f"Only {cls_to_str(allowed)} are allowed."
+ )
diff --git a/tests/test_nn_rewrite_core_extract.py b/tests/test_nn_rewrite_core_extract.py
new file mode 100644
index 0000000..09eca77
--- /dev/null
+++ b/tests/test_nn_rewrite_core_extract.py
@@ -0,0 +1,38 @@
+# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for module mlia.nn.rewrite.core.extract."""
+from __future__ import annotations
+
+from pathlib import Path
+from typing import Any
+from typing import Iterable
+
+import pytest
+
+from mlia.nn.rewrite.core.extract import ExtractPaths
+from mlia.nn.rewrite.core.graph_edit.record import DEQUANT_SUFFIX
+
+
+@pytest.mark.parametrize("dir_path", ("/dev/null", Path("/dev/null")))
+@pytest.mark.parametrize("model_is_quantized", (False, True))
+@pytest.mark.parametrize(
+ ("obj", "func_names", "suffix"),
+ (
+ (ExtractPaths.tflite, ("start", "replace", "end"), ".tflite"),
+ (ExtractPaths.tfrec, ("input", "output", "end"), ".tfrec"),
+ ),
+)
+def test_extract_paths(
+ dir_path: str | Path,
+ model_is_quantized: bool,
+ obj: Any,
+ func_names: Iterable[str],
+ suffix: str,
+) -> None:
+ """Test class ExtractPaths."""
+ for func_name in func_names:
+ func = getattr(obj, func_name)
+ path = func(dir_path, model_is_quantized)
+ assert path == Path(dir_path, path.relative_to(dir_path))
+ assert path.suffix == suffix
+ assert not model_is_quantized or path.stem.endswith(DEQUANT_SUFFIX)
diff --git a/tests/test_nn_rewrite_core_graph_edit_record.py b/tests/test_nn_rewrite_core_graph_edit_record.py
index 41b9c50..422b53e 100644
--- a/tests/test_nn_rewrite_core_graph_edit_record.py
+++ b/tests/test_nn_rewrite_core_graph_edit_record.py
@@ -3,43 +3,57 @@
"""Tests for module mlia.nn.rewrite.graph_edit.record."""
from pathlib import Path
+import numpy as np
+import pytest
import tensorflow as tf
+from mlia.nn.rewrite.core.extract import ExtractPaths
from mlia.nn.rewrite.core.graph_edit.record import record_model
from mlia.nn.rewrite.core.utils.numpy_tfrecord import numpytf_read
+def data_matches_outputs(
+ name: str,
+ tensor: tf.Tensor,
+ model_outputs: list,
+ dequantized_output: bool,
+) -> bool:
+ """Check that the name and the tensor match any of the model outputs."""
+ for model_output in model_outputs:
+ if model_output["name"] == name:
+ # If the name is a match, tensor shape and type have to match!
+ tensor_shape = tensor.shape.as_list()
+ tensor_type = tensor.dtype.as_numpy_dtype
+ return all(
+ (
+ tensor_shape == model_output["shape"].tolist(),
+ tensor_type == np.float32
+ if dequantized_output
+ else model_output["dtype"],
+ )
+ )
+ return False
+
+
def check_record_model(
test_tflite_model: Path,
tmp_path: Path,
test_tfrecord: Path,
batch_size: int,
+ dequantize_output: bool,
) -> None:
"""Test the function record_model()."""
- output_file = tmp_path / "out.tfrecord"
+ output_file = ExtractPaths.tfrec.output(tmp_path)
record_model(
input_filename=str(test_tfrecord),
model_filename=str(test_tflite_model),
output_filename=str(output_file),
batch_size=batch_size,
+ dequantize_output=dequantize_output,
)
+ output_file = ExtractPaths.tfrec.output(tmp_path, dequantize_output)
assert output_file.is_file()
- def data_matches_outputs(name: str, tensor: tf.Tensor, model_outputs: list) -> bool:
- """Check that the name and the tensor match any of the model outputs."""
- for model_output in model_outputs:
- if model_output["name"] == name:
- # If the name is a match, tensor shape and type have to match!
- tensor_shape = tensor.shape.as_list()
- tensor_type = tensor.dtype.as_numpy_dtype
- return all(
- (
- tensor_shape == model_output["shape"].tolist(),
- tensor_type == model_output["dtype"],
- )
- )
- return False
-
# Now load model and the data and make sure that the written data matches
# any of the model outputs
interpreter = tf.lite.Interpreter(str(test_tflite_model))
@@ -47,4 +61,19 @@ def check_record_model(
dataset = numpytf_read(str(output_file))
for data in dataset:
for name, tensor in data.items():
- assert data_matches_outputs(name, tensor, model_outputs)
+ assert data_matches_outputs(name, tensor, model_outputs, dequantize_output)
+
+
+@pytest.mark.parametrize("batch_size", (None, 1, 2))
+@pytest.mark.parametrize("dequantize_output", (True, False))
+def test_record_model(
+ test_tflite_model: Path,
+ tmp_path: Path,
+ test_tfrecord: Path,
+ batch_size: int,
+ dequantize_output: bool,
+) -> None:
+ """Test the function record_model()."""
+ check_record_model(
+ test_tflite_model, tmp_path, test_tfrecord, batch_size, dequantize_output
+ )
diff --git a/tests/test_nn_rewrite_core_train.py b/tests/test_nn_rewrite_core_train.py
index b001a09..ef52320 100644
--- a/tests/test_nn_rewrite_core_train.py
+++ b/tests/test_nn_rewrite_core_train.py
@@ -1,6 +1,6 @@
# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
-"""Tests for module mlia.nn.rewrite.train."""
+"""Tests for module mlia.nn.rewrite.core.train."""
# pylint: disable=too-many-arguments
from __future__ import annotations
@@ -47,10 +47,11 @@ def check_train(
tfrecord: Path,
train_params: TrainingParameters = TestTrainingParameters(),
use_unmodified_model: bool = False,
+ quantized: bool = False,
) -> None:
"""Test the train() function."""
with TemporaryDirectory() as tmp_dir:
- output_file = Path(tmp_dir, "out.tfrecord")
+ output_file = Path(tmp_dir, "out.tflite")
result = train(
source_model=str(tflite_model),
unmodified_model=str(tflite_model) if use_unmodified_model else None,
@@ -65,6 +66,17 @@ def check_train(
assert all(res >= 0.0 for res in result[0]), f"Results out of bound: {result}"
assert output_file.is_file()
+ if quantized:
+ interpreter = tf.lite.Interpreter(model_path=str(output_file))
+ interpreter.allocate_tensors()
+ # Check that the quantization parameters are non-zero
+ assert all(interpreter.get_output_details()[0]["quantization"])
+ assert all(interpreter.get_input_details()[0]["quantization"])
+ dtypes = []
+ for tensor_detail in interpreter.get_tensor_details():
+ dtypes.append(tensor_detail["dtype"])
+ assert all(np.issubdtype(dtype, np.integer) for dtype in dtypes)
+
@pytest.mark.parametrize(
(
@@ -89,7 +101,7 @@ def check_train(
),
),
)
-def test_train(
+def test_train_fp32(
test_tflite_model_fp32: Path,
test_tfrecord_fp32: Path,
batch_size: int,
@@ -114,6 +126,55 @@ def test_train(
)
+@pytest.mark.parametrize(
+ (
+ "batch_size",
+ "show_progress",
+ "augmentation_preset",
+ "lr_schedule",
+ "use_unmodified_model",
+ "num_procs",
+ ),
+ (
+ (1, False, AUGMENTATION_PRESETS["none"], "cosine", False, 2),
+ (32, True, AUGMENTATION_PRESETS["gaussian"], "late", True, 1),
+ (2, False, AUGMENTATION_PRESETS["mixup"], "constant", True, 0),
+ (
+ 1,
+ False,
+ AUGMENTATION_PRESETS["mix_gaussian_large"],
+ "cosine",
+ False,
+ 2,
+ ),
+ ),
+)
+def test_train_int8(
+ test_tflite_model: Path,
+ test_tfrecord: Path,
+ batch_size: int,
+ show_progress: bool,
+ augmentation_preset: tuple[float | None, float | None],
+ lr_schedule: LearningRateSchedule,
+ use_unmodified_model: bool,
+ num_procs: int,
+) -> None:
+ """Test the train() function with valid parameters."""
+ check_train(
+ tflite_model=test_tflite_model,
+ tfrecord=test_tfrecord,
+ train_params=TestTrainingParameters(
+ batch_size=batch_size,
+ show_progress=show_progress,
+ augmentations=augmentation_preset,
+ learning_rate_schedule=lr_schedule,
+ num_procs=num_procs,
+ ),
+ use_unmodified_model=use_unmodified_model,
+ quantized=True,
+ )
+
+
def test_train_invalid_schedule(
test_tflite_model_fp32: Path,
test_tfrecord_fp32: Path,
diff --git a/tests/test_nn_tensorflow_config.py b/tests/test_nn_tensorflow_config.py
index 48aec0a..fff3857 100644
--- a/tests/test_nn_tensorflow_config.py
+++ b/tests/test_nn_tensorflow_config.py
@@ -111,3 +111,15 @@ def test_tflite_model_call(
for named_input in data.as_numpy_iterator():
res = model(named_input)
assert res
+
+
+def test_tflite_model_is_tensor_quantized(test_tflite_model: Path) -> None:
+ """Test function TFLiteModel.is_tensor_quantized()."""
+ model = TFLiteModel(test_tflite_model)
+ input_details = model.input_details[0]
+ assert model.is_tensor_quantized(name=input_details["name"])
+ assert model.is_tensor_quantized(idx=input_details["index"])
+ with pytest.raises(ValueError):
+ assert model.is_tensor_quantized()
+ with pytest.raises(NameError):
+ assert model.is_tensor_quantized(name="NAME_DOES_NOT_EXIST")
diff --git a/tests/test_nn_tensorflow_optimizations_quantization.py b/tests/test_nn_tensorflow_optimizations_quantization.py
new file mode 100644
index 0000000..5228cec
--- /dev/null
+++ b/tests/test_nn_tensorflow_optimizations_quantization.py
@@ -0,0 +1,53 @@
+# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for module optimizations/quantization."""
+from __future__ import annotations
+
+from itertools import chain
+from pathlib import Path
+from typing import Generator
+
+import numpy as np
+from numpy.core.numeric import isclose
+
+from mlia.nn.tensorflow.config import TFLiteModel
+from mlia.nn.tensorflow.optimizations.quantization import dequantize
+from mlia.nn.tensorflow.optimizations.quantization import is_quantized
+from mlia.nn.tensorflow.optimizations.quantization import QuantizationParameters
+from mlia.nn.tensorflow.optimizations.quantization import quantize
+
+
+def model_io_quant_params(model_path: Path) -> Generator:
+ """Generate QuantizationParameters for all model inputs and outputs."""
+ model = TFLiteModel(model_path=model_path)
+ for details in chain(model.input_details, model.output_details):
+ yield QuantizationParameters(**details["quantization_parameters"])
+
+
+def test_is_quantized(test_tflite_model: Path) -> None:
+ """Test function is_quantized() with a quantized model."""
+ for quant_params in model_io_quant_params(test_tflite_model):
+ assert is_quantized(quant_params)
+
+
+def test_is_not_quantized(test_tflite_model_fp32: Path) -> None:
+ """Test function is_quantized() with an unquantized model."""
+ for quant_params in model_io_quant_params(test_tflite_model_fp32):
+ assert not is_quantized(quant_params)
+
+
+def test_quantize() -> None:
+ """Test function quantize()."""
+ ref_dequant = np.array((0.0, 0.1, 0.2, 0.3))
+ ref_quant = np.array((0, 10, 20, 30), dtype=np.int8)
+ quant_params = QuantizationParameters(
+ scales=np.array([0.01]), zero_points=np.array([0.0]), quantized_dimension=0
+ )
+
+ quant = quantize(ref_dequant, quant_params)
+ assert quant.dtype == np.int8
+ assert np.all(quant == ref_quant)
+
+ dequant = dequantize(quant, quant_params)
+ assert dequant.dtype == np.float32
+ assert np.all(isclose(dequant, ref_dequant, atol=0.03))
diff --git a/tests/test_nn_tensorflow_utils.py b/tests/test_nn_tensorflow_utils.py
index 14b06c4..dab8b4e 100644
--- a/tests/test_nn_tensorflow_utils.py
+++ b/tests/test_nn_tensorflow_utils.py
@@ -1,14 +1,17 @@
# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Test for module utils/test_utils."""
+import re
from pathlib import Path
import numpy as np
import pytest
import tensorflow as tf
+from mlia.nn.tensorflow.utils import check_tflite_datatypes
from mlia.nn.tensorflow.utils import convert_to_tflite
from mlia.nn.tensorflow.utils import get_tf_tensor_shape
+from mlia.nn.tensorflow.utils import get_tflite_model_type_map
from mlia.nn.tensorflow.utils import is_keras_model
from mlia.nn.tensorflow.utils import is_tflite_model
from mlia.nn.tensorflow.utils import representative_dataset
@@ -109,3 +112,31 @@ def test_is_keras_model(model_path: Path, expected_result: bool) -> None:
def test_get_tf_tensor_shape(test_tf_model: Path) -> None:
"""Test get_tf_tensor_shape with test model."""
assert get_tf_tensor_shape(str(test_tf_model)) == [1, 28, 28, 1]
+
+
+def test_tflite_model_type_map(
+ test_tflite_model_fp32: Path, test_tflite_model: Path
+) -> None:
+ """Test the model type map function."""
+ assert get_tflite_model_type_map(test_tflite_model_fp32) == {
+ "serving_default_input:0": np.float32
+ }
+ assert get_tflite_model_type_map(test_tflite_model) == {
+ "serving_default_input:0": np.int8
+ }
+
+
+def test_check_tflite_datatypes(
+ test_tflite_model_fp32: Path, test_tflite_model: Path
+) -> None:
+ """Test the model type map function."""
+ check_tflite_datatypes(test_tflite_model_fp32, np.float32)
+ check_tflite_datatypes(test_tflite_model, np.int8)
+
+ with pytest.raises(
+ Exception,
+ match=re.escape(
+ "unexpected data types: ['float32']. Only ['int8'] are allowed"
+ ),
+ ):
+ check_tflite_datatypes(test_tflite_model_fp32, np.int8)