aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBenjamin Klimczak <benjamin.klimczak@arm.com>2023-07-12 15:18:26 +0100
committerBenjamin Klimczak <benjamin.klimczak@arm.com>2023-10-11 16:16:32 +0100
commitecc4264b93d4a89fa2cb40518b225d8371b7ffad (patch)
tree47244d2d67ab6c50bfc15eab768252359eae0df6
parentbaaf4de286762c1955c874f78cd802d4703a8ba5 (diff)
downloadmlia-ecc4264b93d4a89fa2cb40518b225d8371b7ffad.tar.gz
Enable rewrites for quantized input models
If the input model for rewriting is quantized: - Record de-quantized TFRecords - enable writing de-quantized calibration data for the training - re-generate augmented training data, if needed - Use quantization-aware training (QAT) to train the replacement models - Check if replacement model is quantized: If source model is quantized, we make sure rewrite's output model is quantized too. Right now, only int8 is supported so raising an error if any other datatype is present in the output. Resolves: MLIA-907, MLIA-908, MLIA-927 Signed-off-by: Benjamin Klimczak <benjamin.klimczak@arm.com> Change-Id: Icb4070a9e6f1fdb5ce36120d73823986e89ac955
-rw-r--r--src/mlia/nn/rewrite/core/extract.py61
-rw-r--r--src/mlia/nn/rewrite/core/graph_edit/cut.py16
-rw-r--r--src/mlia/nn/rewrite/core/graph_edit/record.py45
-rw-r--r--src/mlia/nn/rewrite/core/train.py67
-rw-r--r--src/mlia/nn/tensorflow/config.py101
-rw-r--r--src/mlia/nn/tensorflow/optimizations/quantization.py74
-rw-r--r--src/mlia/nn/tensorflow/utils.py30
-rw-r--r--tests/test_nn_rewrite_core_extract.py38
-rw-r--r--tests/test_nn_rewrite_core_graph_edit_record.py63
-rw-r--r--tests/test_nn_rewrite_core_train.py67
-rw-r--r--tests/test_nn_tensorflow_config.py12
-rw-r--r--tests/test_nn_tensorflow_optimizations_quantization.py53
-rw-r--r--tests/test_nn_tensorflow_utils.py31
13 files changed, 598 insertions, 60 deletions
diff --git a/src/mlia/nn/rewrite/core/extract.py b/src/mlia/nn/rewrite/core/extract.py
index f609955..4fcf735 100644
--- a/src/mlia/nn/rewrite/core/extract.py
+++ b/src/mlia/nn/rewrite/core/extract.py
@@ -2,19 +2,62 @@
# SPDX-License-Identifier: Apache-2.0
"""Extract module."""
# pylint: disable=too-many-arguments, too-many-locals
+from __future__ import annotations
+
import os
+from functools import partial
+from pathlib import Path
import tensorflow as tf
from tensorflow.lite.python.schema_py_generated import SubGraphT
from mlia.nn.rewrite.core.graph_edit.cut import cut_model
+from mlia.nn.rewrite.core.graph_edit.record import dequantized_path
from mlia.nn.rewrite.core.graph_edit.record import record_model
-
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
+def _get_path(
+ ext: str, name: str, dir_path: str | Path, model_is_quantized: bool = False
+) -> Path:
+ """Create a file path for extracted files."""
+ path = Path(dir_path, f"{name}{ext}")
+ return dequantized_path(path) if model_is_quantized else path
+
+
+class TFLitePaths: # pylint: disable=too-few-public-methods
+ """Provide safe access to TensorFlow Lite file paths."""
+
+ _get_path_tflite = partial(_get_path, ".tflite")
+
+ start = partial(_get_path_tflite, "start")
+ replace = partial(_get_path_tflite, "replace")
+ end = partial(_get_path_tflite, "end")
+
+
+class TFRecordPaths: # pylint: disable=too-few-public-methods
+ """Provide safe access to tfrec file paths."""
+
+ _get_path_tfrec = partial(_get_path, ".tfrec")
+
+ input = partial(_get_path_tfrec, "input")
+ output = partial(_get_path_tfrec, "output")
+ end = partial(_get_path_tfrec, "end")
+
+
+class ExtractPaths: # pylint: disable=too-few-public-methods
+ """Get paths to extract files.
+
+ This is meant to be the single source of truth regarding all file names
+ created by the extract() function in an output directory.
+ """
+
+ tflite = TFLitePaths
+ tfrec = TFRecordPaths
+
+
def extract(
output_path: str,
model_file: str,
@@ -26,6 +69,7 @@ def extract(
show_progress: bool = False,
num_procs: int = 1,
num_threads: int = 0,
+ dequantize_output: bool = False,
) -> None:
"""Extract a model after cut and record."""
try:
@@ -33,7 +77,7 @@ def extract(
except FileExistsError:
pass
- start_file = os.path.join(output_path, "start.tflite")
+ start_file = ExtractPaths.tflite.start(output_path)
cut_model(
model_file,
input_names=None,
@@ -42,7 +86,7 @@ def extract(
output_file=start_file,
)
- input_tfrec = os.path.join(output_path, "input.tfrec")
+ input_tfrec = ExtractPaths.tfrec.input(output_path)
record_model(
input_filename,
start_file,
@@ -50,9 +94,10 @@ def extract(
show_progress=show_progress,
num_procs=num_procs,
num_threads=num_threads,
+ dequantize_output=dequantize_output,
)
- replace_file = os.path.join(output_path, "replace.tflite")
+ replace_file = ExtractPaths.tflite.replace(output_path)
cut_model(
model_file,
input_names=input_names,
@@ -61,7 +106,7 @@ def extract(
output_file=replace_file,
)
- end_file = os.path.join(output_path, "end.tflite")
+ end_file = ExtractPaths.tflite.end(output_path)
cut_model(
model_file,
input_names=output_names,
@@ -71,7 +116,7 @@ def extract(
)
if not skip_outputs:
- output_tfrec = os.path.join(output_path, "output.tfrec")
+ output_tfrec = ExtractPaths.tfrec.output(output_path)
record_model(
input_tfrec,
replace_file,
@@ -79,9 +124,10 @@ def extract(
show_progress=show_progress,
num_procs=num_procs,
num_threads=num_threads,
+ dequantize_output=dequantize_output,
)
- end_tfrec = os.path.join(output_path, "end.tfrec")
+ end_tfrec = ExtractPaths.tfrec.end(output_path)
record_model(
output_tfrec,
end_file,
@@ -89,4 +135,5 @@ def extract(
show_progress=show_progress,
num_procs=num_procs,
num_threads=num_threads,
+ dequantize_output=dequantize_output,
)
diff --git a/src/mlia/nn/rewrite/core/graph_edit/cut.py b/src/mlia/nn/rewrite/core/graph_edit/cut.py
index 13a5268..53d5389 100644
--- a/src/mlia/nn/rewrite/core/graph_edit/cut.py
+++ b/src/mlia/nn/rewrite/core/graph_edit/cut.py
@@ -1,9 +1,11 @@
# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Cut module."""
+from __future__ import annotations
+
import os
from collections import defaultdict
-from typing import Optional
+from pathlib import Path
import tensorflow as tf
from tensorflow.lite.python.schema_py_generated import ModelT
@@ -25,8 +27,8 @@ def tensors_by_name(subgraph: SubGraphT, names: list) -> list:
def cut_subgraph(
subgraph: SubGraphT,
- input_tensor_names: Optional[list],
- output_tensor_names: Optional[list],
+ input_tensor_names: list | None,
+ output_tensor_names: list | None,
) -> None:
"""Change the global inputs and outputs of a graph to the provided named tensors."""
if input_tensor_names is not None:
@@ -131,11 +133,11 @@ def filter_relabel(src_subgraph: SubGraphT, relabel_filter: set) -> tuple:
def cut_model(
- model_file: str,
- input_names: Optional[list],
- output_names: Optional[list],
+ model_file: str | Path,
+ input_names: list | None,
+ output_names: list | None,
subgraph_index: int,
- output_file: str,
+ output_file: str | Path,
) -> None:
"""Cut subgraphs and simplify a given model."""
model = load_fb(model_file)
diff --git a/src/mlia/nn/rewrite/core/graph_edit/record.py b/src/mlia/nn/rewrite/core/graph_edit/record.py
index 90f3db8..f85433d 100644
--- a/src/mlia/nn/rewrite/core/graph_edit/record.py
+++ b/src/mlia/nn/rewrite/core/graph_edit/record.py
@@ -6,6 +6,7 @@ from __future__ import annotations
import math
import os
+from contextlib import ExitStack
from pathlib import Path
import tensorflow as tf
@@ -15,11 +16,22 @@ from mlia.nn.rewrite.core.utils.numpy_tfrecord import numpytf_count
from mlia.nn.rewrite.core.utils.numpy_tfrecord import numpytf_read
from mlia.nn.rewrite.core.utils.numpy_tfrecord import NumpyTFWriter
from mlia.nn.rewrite.core.utils.parallel import ParallelTFLiteModel
+from mlia.nn.tensorflow.config import NameToTensorMap
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
+DEQUANT_SUFFIX = "_dequant"
+
+
+def dequantized_path(filename: str | Path) -> Path:
+ """Append the de-quantization suffix to the given filename."""
+ path = Path(filename)
+ path = Path(path.parent, f"{path.stem}{DEQUANT_SUFFIX}{path.suffix}")
+ return path
+
+
def record_model(
input_filename: str | Path,
model_filename: str | Path,
@@ -28,11 +40,14 @@ def record_model(
show_progress: bool = False,
num_procs: int = 1,
num_threads: int = 0,
+ dequantize_output: bool = False,
) -> None:
"""Model recorder.
num_procs: 0 => detect real cores on system
num_threads: 0 => TFLite impl. specific setting, usually 3
+
+ dequantize: True => de-quantize the recorded output before saving
"""
model = ParallelTFLiteModel(model_filename, num_procs, num_threads, batch_size)
if not batch_size:
@@ -51,22 +66,38 @@ def record_model(
dataset = dataset.batch(batch_size, drop_remainder=False)
total = int(math.ceil(total / batch_size))
- with NumpyTFWriter(output_filename) as writer:
- for _, named_x in enumerate(
- track(dataset.as_numpy_iterator(), total=total, disable=not show_progress)
- ):
- named_y = model(named_x)
+ with ExitStack() as stack:
+ writer = stack.enter_context(NumpyTFWriter(output_filename))
+ writer_dequant = None
+ if dequantize_output:
+ dequant_path = dequantized_path(output_filename)
+ writer_dequant = stack.enter_context(NumpyTFWriter(dequant_path))
+
+ def write(writer: NumpyTFWriter, data: NameToTensorMap) -> None:
+ """Write the data using the given NumpyTFWriter instance."""
if batch_size > 1:
for i in range(batch_size):
# Expand the batches and recreate each dict as a
# batch-size 1 item for the tfrec output
recreated_dict = {
k: v[i : i + 1] # noqa: E203
- for k, v in named_y.items()
+ for k, v in data.items()
if i < v.shape[0]
}
if recreated_dict:
writer.write(recreated_dict)
else:
- writer.write(named_y)
+ writer.write(data)
+
+ for _, named_x in enumerate(
+ track(dataset.as_numpy_iterator(), total=total, disable=not show_progress)
+ ):
+ named_y = model(named_x)
+ write(writer, named_y)
+
+ if dequantize_output:
+ assert writer_dequant
+ named_y_dequant = model.dequantize_outputs(named_y)
+ write(writer_dequant, named_y_dequant)
+
model.close()
diff --git a/src/mlia/nn/rewrite/core/train.py b/src/mlia/nn/rewrite/core/train.py
index 82af747..6345f07 100644
--- a/src/mlia/nn/rewrite/core/train.py
+++ b/src/mlia/nn/rewrite/core/train.py
@@ -22,9 +22,11 @@ from typing import Literal
import numpy as np
import tensorflow as tf
+import tensorflow_model_optimization as tfmot
from numpy.random import Generator
from mlia.nn.rewrite.core.extract import extract
+from mlia.nn.rewrite.core.extract import ExtractPaths
from mlia.nn.rewrite.core.graph_edit.diff import diff_stats
from mlia.nn.rewrite.core.graph_edit.join import join_models
from mlia.nn.rewrite.core.graph_edit.record import record_model
@@ -34,6 +36,7 @@ from mlia.nn.rewrite.core.utils.parallel import ParallelTFLiteModel
from mlia.nn.tensorflow.config import TFLiteModel
from mlia.nn.tensorflow.tflite_graph import load_fb
from mlia.nn.tensorflow.tflite_graph import save_fb
+from mlia.nn.tensorflow.utils import get_tflite_converter
from mlia.utils.logging import log_action
@@ -91,6 +94,7 @@ def train(
input_tfrec,
input_tensors,
output_tensors,
+ dequantize_output=True,
)
else:
unmodified_model_dir = None
@@ -106,6 +110,7 @@ def train(
output_tensors,
num_procs=train_params.num_procs,
num_threads=train_params.num_threads,
+ dequantize_output=True,
)
tflite_filenames = train_in_dir(
@@ -160,7 +165,10 @@ def train(
def eval_in_dir(
- target_dir: str, new_part: str, num_procs: int = 1, num_threads: int = 0
+ target_dir: str,
+ new_part: str,
+ num_procs: int = 1,
+ num_threads: int = 0,
) -> tuple:
"""Evaluate a model in a given directory."""
model_input_path = Path(target_dir, "input_orig.tfrec")
@@ -168,12 +176,12 @@ def eval_in_dir(
model_input = (
model_input_path
if model_input_path.exists()
- else Path(target_dir, "input.tfrec")
+ else ExtractPaths.tfrec.input(target_dir, False)
)
output = (
model_output_path
if model_output_path.exists()
- else Path(target_dir, "output.tfrec")
+ else ExtractPaths.tfrec.output(target_dir, False)
)
with tempfile.TemporaryDirectory() as tmp_dir:
@@ -194,8 +202,8 @@ def join_in_dir(model_dir: str, new_part: str, output_model: str) -> None:
"""Join two models in a given directory."""
with tempfile.TemporaryDirectory() as tmp_dir:
new_end = Path(tmp_dir, "new_end.tflite")
- join_models(new_part, Path(model_dir, "end.tflite"), new_end)
- join_models(Path(model_dir, "start.tflite"), new_end, output_model)
+ join_models(new_part, ExtractPaths.tflite.end(model_dir), new_end)
+ join_models(ExtractPaths.tflite.start(model_dir), new_end, output_model)
def _get_io_tensors(model: TFLiteModel) -> tuple[str, str]:
@@ -244,7 +252,9 @@ def set_up_data_pipeline(
input_name, output_name = _get_io_tensors(teacher)
- input_filename = Path(train_dir, "input.tfrec")
+ model_is_quantized = replace.is_tensor_quantized(name=input_name)
+
+ input_filename = ExtractPaths.tfrec.input(train_dir, model_is_quantized)
total = numpytf_count(str(input_filename))
dict_inputs = numpytf_read(str(input_filename))
@@ -264,13 +274,13 @@ def set_up_data_pipeline(
if any(augmentations):
# Map the teacher inputs here because the augmentation stage passes these
# through a TFLite model to get the outputs
- teacher_outputs = numpytf_read(str(Path(teacher_dir, "input.tfrec"))).map(
- lambda d: tf.squeeze(d[input_name], axis=0)
- )
+ teacher_outputs = numpytf_read(
+ str(ExtractPaths.tfrec.input(teacher_dir, model_is_quantized))
+ ).map(lambda d: tf.squeeze(d[input_name], axis=0))
else:
- teacher_outputs = numpytf_read(str(Path(teacher_dir, "output.tfrec"))).map(
- lambda d: tf.squeeze(d[output_name], axis=0)
- )
+ teacher_outputs = numpytf_read(
+ str(ExtractPaths.tfrec.output(teacher_dir, model_is_quantized))
+ ).map(lambda d: tf.squeeze(d[output_name], axis=0))
dataset = tf.data.Dataset.zip((inputs, teacher_outputs))
if epochs > 1:
@@ -285,7 +295,23 @@ def set_up_data_pipeline(
) -> tuple:
"""Return results of train and teach based on augmentations."""
augmented_train = augment_train({input_name: train})[input_name]
- augmented_teach = teacher(augment_teacher({input_name: teach}))[output_name]
+ # If augmentation of the input is enabled, we need to re-generate
+ # the reference output by running the augmented data through the
+ # teacher model.
+ if model_is_quantized:
+ # If the input model is quantized we have to additionally
+ # - quantize the augmented data before running it through the
+ # (quantized) teacher model
+ # - de-quantize the output for the training.
+ augmented_teach = teacher.dequantize_outputs(
+ teacher(
+ teacher.quantize_inputs(augment_teacher({input_name: teach}))
+ )
+ )[output_name]
+ else:
+ augmented_teach = teacher(augment_teacher({input_name: teach}))[
+ output_name
+ ]
return (augmented_train, augmented_teach)
dataset = dataset.map(
@@ -329,15 +355,20 @@ def train_in_dir(
"""
teacher_dir = baseline_dir if baseline_dir else train_dir
teacher = ParallelTFLiteModel(
- f"{teacher_dir}/replace.tflite",
+ ExtractPaths.tflite.replace(teacher_dir),
train_params.num_procs,
train_params.num_threads,
batch_size=train_params.batch_size,
)
- replace = TFLiteModel(f"{train_dir}/replace.tflite")
+ replace = TFLiteModel(ExtractPaths.tflite.replace(train_dir))
input_name, output_name = _get_io_tensors(teacher)
+ model_is_quantized = replace.is_tensor_quantized(name=input_name)
+
+ if model_is_quantized:
+ replace.check_datatypes(np.int8)
+
dataset = set_up_data_pipeline(
teacher,
replace,
@@ -354,6 +385,8 @@ def train_in_dir(
optimizer = tf.keras.optimizers.Nadam(learning_rate=train_params.learning_rate)
loss_fn = tf.keras.losses.MeanSquaredError()
+ if model_is_quantized:
+ model = tfmot.quantization.keras.quantize_model(model)
model.compile(optimizer=optimizer, loss=loss_fn, metrics=["mae"])
logger.info(model.summary())
@@ -432,6 +465,7 @@ def train_in_dir(
replace.shape_from_name[input_name],
output_name,
replace.shape_from_name[output_name],
+ model_is_quantized,
)
output_filenames.append(checkpoint_filename)
@@ -446,6 +480,7 @@ def save_as_tflite(
input_shape: list,
output_name: str,
output_shape: list,
+ model_is_quantized: bool = False,
) -> None:
"""Save Keras model as TFLite file."""
@@ -464,7 +499,7 @@ def save_as_tflite(
keras_model.input.set_shape(orig_shape)
with fixed_input(keras_model, input_shape) as fixed_model:
- converter = tf.lite.TFLiteConverter.from_keras_model(fixed_model)
+ converter = get_tflite_converter(fixed_model, quantized=model_is_quantized)
tflite_model = converter.convert()
with open(filename, "wb") as file:
diff --git a/src/mlia/nn/tensorflow/config.py b/src/mlia/nn/tensorflow/config.py
index c6a7c88..b94350a 100644
--- a/src/mlia/nn/tensorflow/config.py
+++ b/src/mlia/nn/tensorflow/config.py
@@ -7,13 +7,23 @@ import logging
import tempfile
from collections import defaultdict
from pathlib import Path
+from typing import Any
+from typing import Callable
+from typing import cast
+from typing import Dict
+from typing import List
import numpy as np
import tensorflow as tf
from mlia.core.context import Context
+from mlia.nn.tensorflow.optimizations.quantization import dequantize
+from mlia.nn.tensorflow.optimizations.quantization import is_quantized
+from mlia.nn.tensorflow.optimizations.quantization import QuantizationParameters
+from mlia.nn.tensorflow.optimizations.quantization import quantize
from mlia.nn.tensorflow.tflite_graph import load_fb
from mlia.nn.tensorflow.tflite_graph import save_fb
+from mlia.nn.tensorflow.utils import check_tflite_datatypes
from mlia.nn.tensorflow.utils import convert_to_tflite
from mlia.nn.tensorflow.utils import is_keras_model
from mlia.nn.tensorflow.utils import is_saved_model
@@ -71,6 +81,11 @@ class KerasModel(ModelConfiguration):
return self
+TFLiteIODetails = Dict[str, Dict[str, Any]]
+TFLiteIODetailsList = List[TFLiteIODetails]
+NameToTensorMap = Dict[str, np.ndarray]
+
+
class TFLiteModel(ModelConfiguration): # pylint: disable=abstract-method
"""TensorFlow Lite model configuration."""
@@ -119,7 +134,22 @@ class TFLiteModel(ModelConfiguration): # pylint: disable=abstract-method
self.shape_from_name = {d["name"]: d["shape"] for d in details}
self.batch_size = next(iter(self.shape_from_name.values()))[0]
- def __call__(self, named_input: dict) -> dict:
+ # Prepare quantization parameters for input and output
+ def named_quant_params(
+ details: TFLiteIODetailsList,
+ ) -> dict[str, QuantizationParameters]:
+ return {
+ str(detail["name"]): QuantizationParameters(
+ **detail["quantization_parameters"]
+ )
+ for detail in details
+ if TFLiteModel._is_tensor_quantized(detail)
+ }
+
+ self._quant_params_input = named_quant_params(self.input_details)
+ self._quant_params_output = named_quant_params(self.output_details)
+
+ def __call__(self, named_input: dict) -> NameToTensorMap:
"""Execute the model on one or a batch of named inputs \
(a dict of name: numpy array)."""
input_len = next(iter(named_input.values())).shape[0]
@@ -150,11 +180,11 @@ class TFLiteModel(ModelConfiguration): # pylint: disable=abstract-method
)
return {k: np.concatenate(v) for k, v in named_ys.items()}
- def input_tensors(self) -> list:
+ def input_tensors(self) -> list[str]:
"""Return name from input details."""
return [d["name"] for d in self.input_details]
- def output_tensors(self) -> list:
+ def output_tensors(self) -> list[str]:
"""Return name from output details."""
return [d["name"] for d in self.output_details]
@@ -164,6 +194,71 @@ class TFLiteModel(ModelConfiguration): # pylint: disable=abstract-method
"""Convert model to TensorFlow Lite format."""
return self
+ def _tensor_details(
+ self, name: str | None = None, idx: int | None = None
+ ) -> TFLiteIODetails:
+ """Get the details of the tensor by name or index."""
+ if idx is not None:
+ details = self.interpreter.get_tensor_details()[idx]
+ assert details["index"] == idx
+ elif name is not None:
+ for details_ in self.interpreter.get_tensor_details():
+ if name == details_["name"]:
+ details = details_
+ break
+ else:
+ raise NameError(
+ f"Tensor '{name}' not found in model {self.model_path}."
+ )
+ else:
+ raise ValueError("Either tensor name or index needs to be passed.")
+
+ assert isinstance(details, dict)
+ return cast(TFLiteIODetails, details)
+
+ @staticmethod
+ def _is_tensor_quantized(details: TFLiteIODetails) -> bool:
+ """Use tensor details to check if the corresponding tensor is quantized."""
+ quant_params = QuantizationParameters(**details["quantization_parameters"])
+ return is_quantized(quant_params)
+
+ def is_tensor_quantized(
+ self,
+ name: str | None = None,
+ idx: int | None = None,
+ ) -> bool:
+ """Check if the given tensor (identified by name or index) is quantized."""
+ details = self._tensor_details(name, idx)
+ return self._is_tensor_quantized(details)
+
+ def check_datatypes(self, *allowed_types: type) -> None:
+ """Check if the model only has the given allowed datatypes."""
+ check_tflite_datatypes(self.model_path, *allowed_types)
+
+ @staticmethod
+ def _quant_dequant(
+ tensors: NameToTensorMap,
+ quant_params: dict[str, QuantizationParameters],
+ func: Callable,
+ ) -> NameToTensorMap:
+ """Quantize/de-quantize tensor using the given parameters and function."""
+ return {
+ name: (func(tensor, quant_params[name]) if name in quant_params else tensor)
+ for name, tensor in tensors.items()
+ }
+
+ def dequantize_outputs(self, outputs: NameToTensorMap) -> NameToTensorMap:
+ """De-quantize the given model outputs."""
+ dequant_outputs = self._quant_dequant(
+ outputs, self._quant_params_output, dequantize
+ )
+ return dequant_outputs
+
+ def quantize_inputs(self, inputs: NameToTensorMap) -> NameToTensorMap:
+ """Quantize the given model inputs."""
+ quant_inputs = self._quant_dequant(inputs, self._quant_params_input, quantize)
+ return quant_inputs
+
class TfModel(ModelConfiguration): # pylint: disable=abstract-method
"""TensorFlow model configuration.
diff --git a/src/mlia/nn/tensorflow/optimizations/quantization.py b/src/mlia/nn/tensorflow/optimizations/quantization.py
new file mode 100644
index 0000000..4e3c2c2
--- /dev/null
+++ b/src/mlia/nn/tensorflow/optimizations/quantization.py
@@ -0,0 +1,74 @@
+# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Contains functionality for quantization and de-quantization."""
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import cast
+
+import numpy as np
+import tensorflow as tf
+
+
+@dataclass
+class QuantizationParameters:
+ """Collection of TensorFlow Lite quantization parameters.
+
+ Can be directly initialized from TensorFlow Lite tensor details, e.g.
+
+ ```
+ QuantizationParameters(
+ **interpreter.get_input_details()[0]["quantization_parameters"]
+ )
+ ```
+ """
+
+ scales: np.ndarray
+ zero_points: np.ndarray
+ quantized_dimension: int
+
+
+def is_quantized(quant_params: QuantizationParameters) -> bool:
+ """Check if the quantization parameters describe a quantized tensor."""
+ return quant_params.scales.size > 0
+
+
+def dequantize(
+ quantized_tensor: np.ndarray | tf.Tensor, quant_params: QuantizationParameters
+) -> np.ndarray:
+ """De-quantize the input tensor using the given quantization parameters."""
+ assert isinstance(quantized_tensor, (tf.Tensor, np.ndarray))
+ assert (
+ not isinstance(quantized_tensor, tf.Tensor)
+ or quantized_tensor.dtype.is_quantized
+ ) and (
+ not isinstance(quantized_tensor, np.ndarray)
+ or issubclass(quantized_tensor.dtype.type, np.integer)
+ ), (
+ f"Input tensor for de-quantization is of type {quantized_tensor.dtype}, "
+ "but it must be int."
+ )
+
+ dequantized_tensor = np.subtract(
+ quantized_tensor, quant_params.zero_points, dtype=np.float32
+ )
+ dequantized_tensor = np.multiply(
+ dequantized_tensor, quant_params.scales, dtype=np.float32
+ )
+ return dequantized_tensor
+
+
+def quantize(
+ tensor: np.ndarray | tf.Tensor, quant_params: QuantizationParameters
+) -> np.ndarray:
+ """Quantize the given float input tensor to int8."""
+ assert isinstance(tensor, (tf.Tensor, np.ndarray))
+ assert (not isinstance(tensor, tf.Tensor) or tensor.dtype.is_floating) and (
+ not isinstance(tensor, np.ndarray) or issubclass(tensor.dtype.type, np.floating)
+ ), f"Input tensor for quantization is of type {tensor.dtype}, but it must be float."
+
+ quantized_tensor = (tensor / quant_params.scales) + quant_params.zero_points
+ quantized_tensor = np.clip( # type: ignore
+ quantized_tensor, -128, 127, dtype=np.int8, casting="unsafe"
+ )
+ return cast(np.ndarray, quantized_tensor)
diff --git a/src/mlia/nn/tensorflow/utils.py b/src/mlia/nn/tensorflow/utils.py
index 77ac529..b8d45c6 100644
--- a/src/mlia/nn/tensorflow/utils.py
+++ b/src/mlia/nn/tensorflow/utils.py
@@ -123,3 +123,33 @@ def get_tflite_converter(
converter.inference_output_type = tf.int8
return converter
+
+
+def get_tflite_model_type_map(model_filename: str | Path) -> dict[str, type]:
+ """Get type map from tflite model."""
+ model_type_map: dict[str, Any] = {}
+ interpreter = tf.lite.Interpreter(str(model_filename))
+ interpreter.allocate_tensors()
+ input_details = interpreter.get_input_details()
+ model_type_map = {
+ input_detail["name"]: input_detail["dtype"] for input_detail in input_details
+ }
+ return model_type_map
+
+
+def check_tflite_datatypes(model_filename: str | Path, *allowed_types: type) -> None:
+ """Check if the model only has the given allowed datatypes."""
+ type_map = get_tflite_model_type_map(model_filename)
+ types = set(type_map.values())
+ allowed = set(allowed_types)
+ unexpected = types - allowed
+
+ def cls_to_str(types: set[type]) -> list[str]:
+ return [t.__name__ for t in types]
+
+ if len(unexpected) > 0:
+ raise TypeError(
+ f"Model {model_filename} has "
+ f"unexpected data types: {cls_to_str(unexpected)}. "
+ f"Only {cls_to_str(allowed)} are allowed."
+ )
diff --git a/tests/test_nn_rewrite_core_extract.py b/tests/test_nn_rewrite_core_extract.py
new file mode 100644
index 0000000..09eca77
--- /dev/null
+++ b/tests/test_nn_rewrite_core_extract.py
@@ -0,0 +1,38 @@
+# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for module mlia.nn.rewrite.core.extract."""
+from __future__ import annotations
+
+from pathlib import Path
+from typing import Any
+from typing import Iterable
+
+import pytest
+
+from mlia.nn.rewrite.core.extract import ExtractPaths
+from mlia.nn.rewrite.core.graph_edit.record import DEQUANT_SUFFIX
+
+
+@pytest.mark.parametrize("dir_path", ("/dev/null", Path("/dev/null")))
+@pytest.mark.parametrize("model_is_quantized", (False, True))
+@pytest.mark.parametrize(
+ ("obj", "func_names", "suffix"),
+ (
+ (ExtractPaths.tflite, ("start", "replace", "end"), ".tflite"),
+ (ExtractPaths.tfrec, ("input", "output", "end"), ".tfrec"),
+ ),
+)
+def test_extract_paths(
+ dir_path: str | Path,
+ model_is_quantized: bool,
+ obj: Any,
+ func_names: Iterable[str],
+ suffix: str,
+) -> None:
+ """Test class ExtractPaths."""
+ for func_name in func_names:
+ func = getattr(obj, func_name)
+ path = func(dir_path, model_is_quantized)
+ assert path == Path(dir_path, path.relative_to(dir_path))
+ assert path.suffix == suffix
+ assert not model_is_quantized or path.stem.endswith(DEQUANT_SUFFIX)
diff --git a/tests/test_nn_rewrite_core_graph_edit_record.py b/tests/test_nn_rewrite_core_graph_edit_record.py
index 41b9c50..422b53e 100644
--- a/tests/test_nn_rewrite_core_graph_edit_record.py
+++ b/tests/test_nn_rewrite_core_graph_edit_record.py
@@ -3,43 +3,57 @@
"""Tests for module mlia.nn.rewrite.graph_edit.record."""
from pathlib import Path
+import numpy as np
+import pytest
import tensorflow as tf
+from mlia.nn.rewrite.core.extract import ExtractPaths
from mlia.nn.rewrite.core.graph_edit.record import record_model
from mlia.nn.rewrite.core.utils.numpy_tfrecord import numpytf_read
+def data_matches_outputs(
+ name: str,
+ tensor: tf.Tensor,
+ model_outputs: list,
+ dequantized_output: bool,
+) -> bool:
+ """Check that the name and the tensor match any of the model outputs."""
+ for model_output in model_outputs:
+ if model_output["name"] == name:
+ # If the name is a match, tensor shape and type have to match!
+ tensor_shape = tensor.shape.as_list()
+ tensor_type = tensor.dtype.as_numpy_dtype
+ return all(
+ (
+ tensor_shape == model_output["shape"].tolist(),
+ tensor_type == np.float32
+ if dequantized_output
+ else model_output["dtype"],
+ )
+ )
+ return False
+
+
def check_record_model(
test_tflite_model: Path,
tmp_path: Path,
test_tfrecord: Path,
batch_size: int,
+ dequantize_output: bool,
) -> None:
"""Test the function record_model()."""
- output_file = tmp_path / "out.tfrecord"
+ output_file = ExtractPaths.tfrec.output(tmp_path)
record_model(
input_filename=str(test_tfrecord),
model_filename=str(test_tflite_model),
output_filename=str(output_file),
batch_size=batch_size,
+ dequantize_output=dequantize_output,
)
+ output_file = ExtractPaths.tfrec.output(tmp_path, dequantize_output)
assert output_file.is_file()
- def data_matches_outputs(name: str, tensor: tf.Tensor, model_outputs: list) -> bool:
- """Check that the name and the tensor match any of the model outputs."""
- for model_output in model_outputs:
- if model_output["name"] == name:
- # If the name is a match, tensor shape and type have to match!
- tensor_shape = tensor.shape.as_list()
- tensor_type = tensor.dtype.as_numpy_dtype
- return all(
- (
- tensor_shape == model_output["shape"].tolist(),
- tensor_type == model_output["dtype"],
- )
- )
- return False
-
# Now load model and the data and make sure that the written data matches
# any of the model outputs
interpreter = tf.lite.Interpreter(str(test_tflite_model))
@@ -47,4 +61,19 @@ def check_record_model(
dataset = numpytf_read(str(output_file))
for data in dataset:
for name, tensor in data.items():
- assert data_matches_outputs(name, tensor, model_outputs)
+ assert data_matches_outputs(name, tensor, model_outputs, dequantize_output)
+
+
+@pytest.mark.parametrize("batch_size", (None, 1, 2))
+@pytest.mark.parametrize("dequantize_output", (True, False))
+def test_record_model(
+ test_tflite_model: Path,
+ tmp_path: Path,
+ test_tfrecord: Path,
+ batch_size: int,
+ dequantize_output: bool,
+) -> None:
+ """Test the function record_model()."""
+ check_record_model(
+ test_tflite_model, tmp_path, test_tfrecord, batch_size, dequantize_output
+ )
diff --git a/tests/test_nn_rewrite_core_train.py b/tests/test_nn_rewrite_core_train.py
index b001a09..ef52320 100644
--- a/tests/test_nn_rewrite_core_train.py
+++ b/tests/test_nn_rewrite_core_train.py
@@ -1,6 +1,6 @@
# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
-"""Tests for module mlia.nn.rewrite.train."""
+"""Tests for module mlia.nn.rewrite.core.train."""
# pylint: disable=too-many-arguments
from __future__ import annotations
@@ -47,10 +47,11 @@ def check_train(
tfrecord: Path,
train_params: TrainingParameters = TestTrainingParameters(),
use_unmodified_model: bool = False,
+ quantized: bool = False,
) -> None:
"""Test the train() function."""
with TemporaryDirectory() as tmp_dir:
- output_file = Path(tmp_dir, "out.tfrecord")
+ output_file = Path(tmp_dir, "out.tflite")
result = train(
source_model=str(tflite_model),
unmodified_model=str(tflite_model) if use_unmodified_model else None,
@@ -65,6 +66,17 @@ def check_train(
assert all(res >= 0.0 for res in result[0]), f"Results out of bound: {result}"
assert output_file.is_file()
+ if quantized:
+ interpreter = tf.lite.Interpreter(model_path=str(output_file))
+ interpreter.allocate_tensors()
+ # Check that the quantization parameters are non-zero
+ assert all(interpreter.get_output_details()[0]["quantization"])
+ assert all(interpreter.get_input_details()[0]["quantization"])
+ dtypes = []
+ for tensor_detail in interpreter.get_tensor_details():
+ dtypes.append(tensor_detail["dtype"])
+ assert all(np.issubdtype(dtype, np.integer) for dtype in dtypes)
+
@pytest.mark.parametrize(
(
@@ -89,7 +101,7 @@ def check_train(
),
),
)
-def test_train(
+def test_train_fp32(
test_tflite_model_fp32: Path,
test_tfrecord_fp32: Path,
batch_size: int,
@@ -114,6 +126,55 @@ def test_train(
)
+@pytest.mark.parametrize(
+ (
+ "batch_size",
+ "show_progress",
+ "augmentation_preset",
+ "lr_schedule",
+ "use_unmodified_model",
+ "num_procs",
+ ),
+ (
+ (1, False, AUGMENTATION_PRESETS["none"], "cosine", False, 2),
+ (32, True, AUGMENTATION_PRESETS["gaussian"], "late", True, 1),
+ (2, False, AUGMENTATION_PRESETS["mixup"], "constant", True, 0),
+ (
+ 1,
+ False,
+ AUGMENTATION_PRESETS["mix_gaussian_large"],
+ "cosine",
+ False,
+ 2,
+ ),
+ ),
+)
+def test_train_int8(
+ test_tflite_model: Path,
+ test_tfrecord: Path,
+ batch_size: int,
+ show_progress: bool,
+ augmentation_preset: tuple[float | None, float | None],
+ lr_schedule: LearningRateSchedule,
+ use_unmodified_model: bool,
+ num_procs: int,
+) -> None:
+ """Test the train() function with valid parameters."""
+ check_train(
+ tflite_model=test_tflite_model,
+ tfrecord=test_tfrecord,
+ train_params=TestTrainingParameters(
+ batch_size=batch_size,
+ show_progress=show_progress,
+ augmentations=augmentation_preset,
+ learning_rate_schedule=lr_schedule,
+ num_procs=num_procs,
+ ),
+ use_unmodified_model=use_unmodified_model,
+ quantized=True,
+ )
+
+
def test_train_invalid_schedule(
test_tflite_model_fp32: Path,
test_tfrecord_fp32: Path,
diff --git a/tests/test_nn_tensorflow_config.py b/tests/test_nn_tensorflow_config.py
index 48aec0a..fff3857 100644
--- a/tests/test_nn_tensorflow_config.py
+++ b/tests/test_nn_tensorflow_config.py
@@ -111,3 +111,15 @@ def test_tflite_model_call(
for named_input in data.as_numpy_iterator():
res = model(named_input)
assert res
+
+
+def test_tflite_model_is_tensor_quantized(test_tflite_model: Path) -> None:
+ """Test function TFLiteModel.is_tensor_quantized()."""
+ model = TFLiteModel(test_tflite_model)
+ input_details = model.input_details[0]
+ assert model.is_tensor_quantized(name=input_details["name"])
+ assert model.is_tensor_quantized(idx=input_details["index"])
+ with pytest.raises(ValueError):
+ assert model.is_tensor_quantized()
+ with pytest.raises(NameError):
+ assert model.is_tensor_quantized(name="NAME_DOES_NOT_EXIST")
diff --git a/tests/test_nn_tensorflow_optimizations_quantization.py b/tests/test_nn_tensorflow_optimizations_quantization.py
new file mode 100644
index 0000000..5228cec
--- /dev/null
+++ b/tests/test_nn_tensorflow_optimizations_quantization.py
@@ -0,0 +1,53 @@
+# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for module optimizations/quantization."""
+from __future__ import annotations
+
+from itertools import chain
+from pathlib import Path
+from typing import Generator
+
+import numpy as np
+from numpy.core.numeric import isclose
+
+from mlia.nn.tensorflow.config import TFLiteModel
+from mlia.nn.tensorflow.optimizations.quantization import dequantize
+from mlia.nn.tensorflow.optimizations.quantization import is_quantized
+from mlia.nn.tensorflow.optimizations.quantization import QuantizationParameters
+from mlia.nn.tensorflow.optimizations.quantization import quantize
+
+
+def model_io_quant_params(model_path: Path) -> Generator:
+ """Generate QuantizationParameters for all model inputs and outputs."""
+ model = TFLiteModel(model_path=model_path)
+ for details in chain(model.input_details, model.output_details):
+ yield QuantizationParameters(**details["quantization_parameters"])
+
+
+def test_is_quantized(test_tflite_model: Path) -> None:
+ """Test function is_quantized() with a quantized model."""
+ for quant_params in model_io_quant_params(test_tflite_model):
+ assert is_quantized(quant_params)
+
+
+def test_is_not_quantized(test_tflite_model_fp32: Path) -> None:
+ """Test function is_quantized() with an unquantized model."""
+ for quant_params in model_io_quant_params(test_tflite_model_fp32):
+ assert not is_quantized(quant_params)
+
+
+def test_quantize() -> None:
+ """Test function quantize()."""
+ ref_dequant = np.array((0.0, 0.1, 0.2, 0.3))
+ ref_quant = np.array((0, 10, 20, 30), dtype=np.int8)
+ quant_params = QuantizationParameters(
+ scales=np.array([0.01]), zero_points=np.array([0.0]), quantized_dimension=0
+ )
+
+ quant = quantize(ref_dequant, quant_params)
+ assert quant.dtype == np.int8
+ assert np.all(quant == ref_quant)
+
+ dequant = dequantize(quant, quant_params)
+ assert dequant.dtype == np.float32
+ assert np.all(isclose(dequant, ref_dequant, atol=0.03))
diff --git a/tests/test_nn_tensorflow_utils.py b/tests/test_nn_tensorflow_utils.py
index 14b06c4..dab8b4e 100644
--- a/tests/test_nn_tensorflow_utils.py
+++ b/tests/test_nn_tensorflow_utils.py
@@ -1,14 +1,17 @@
# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Test for module utils/test_utils."""
+import re
from pathlib import Path
import numpy as np
import pytest
import tensorflow as tf
+from mlia.nn.tensorflow.utils import check_tflite_datatypes
from mlia.nn.tensorflow.utils import convert_to_tflite
from mlia.nn.tensorflow.utils import get_tf_tensor_shape
+from mlia.nn.tensorflow.utils import get_tflite_model_type_map
from mlia.nn.tensorflow.utils import is_keras_model
from mlia.nn.tensorflow.utils import is_tflite_model
from mlia.nn.tensorflow.utils import representative_dataset
@@ -109,3 +112,31 @@ def test_is_keras_model(model_path: Path, expected_result: bool) -> None:
def test_get_tf_tensor_shape(test_tf_model: Path) -> None:
"""Test get_tf_tensor_shape with test model."""
assert get_tf_tensor_shape(str(test_tf_model)) == [1, 28, 28, 1]
+
+
+def test_tflite_model_type_map(
+ test_tflite_model_fp32: Path, test_tflite_model: Path
+) -> None:
+ """Test the model type map function."""
+ assert get_tflite_model_type_map(test_tflite_model_fp32) == {
+ "serving_default_input:0": np.float32
+ }
+ assert get_tflite_model_type_map(test_tflite_model) == {
+ "serving_default_input:0": np.int8
+ }
+
+
+def test_check_tflite_datatypes(
+ test_tflite_model_fp32: Path, test_tflite_model: Path
+) -> None:
+ """Test the model type map function."""
+ check_tflite_datatypes(test_tflite_model_fp32, np.float32)
+ check_tflite_datatypes(test_tflite_model, np.int8)
+
+ with pytest.raises(
+ Exception,
+ match=re.escape(
+ "unexpected data types: ['float32']. Only ['int8'] are allowed"
+ ),
+ ):
+ check_tflite_datatypes(test_tflite_model_fp32, np.int8)