aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/mlia/nn/rewrite/core/graph_edit/cut.py8
-rw-r--r--src/mlia/nn/rewrite/core/graph_edit/join.py10
-rw-r--r--src/mlia/nn/rewrite/core/rewrite.py77
-rw-r--r--src/mlia/nn/rewrite/core/train.py306
-rw-r--r--src/mlia/nn/rewrite/core/utils/numpy_tfrecord.py138
-rw-r--r--src/mlia/nn/rewrite/core/utils/parallel.py4
-rw-r--r--src/mlia/nn/rewrite/core/utils/utils.py32
-rw-r--r--src/mlia/nn/rewrite/library/fc_layer.py14
-rw-r--r--src/mlia/nn/select.py15
-rw-r--r--src/mlia/nn/tensorflow/config.py100
-rw-r--r--src/mlia/nn/tensorflow/tflite_graph.py27
-rw-r--r--src/mlia/target/ethos_u/data_collection.py8
-rw-r--r--tests/conftest.py39
-rw-r--r--tests/test_nn_rewrite_core_graph_edit_join.py6
-rw-r--r--tests/test_nn_rewrite_core_graph_edit_record.py4
-rw-r--r--tests/test_nn_rewrite_core_rewrite.py6
-rw-r--r--tests/test_nn_rewrite_core_train.py76
-rw-r--r--tests/test_nn_rewrite_core_utils.py33
-rw-r--r--tests/test_nn_rewrite_core_utils_numpy_tfrecord.py25
-rw-r--r--tests/test_nn_tensorflow_config.py40
-rw-r--r--tests/test_nn_tensorflow_tflite_graph.py30
-rw-r--r--tests/utils/rewrite.py18
22 files changed, 591 insertions, 425 deletions
diff --git a/src/mlia/nn/rewrite/core/graph_edit/cut.py b/src/mlia/nn/rewrite/core/graph_edit/cut.py
index 2707eb1..13a5268 100644
--- a/src/mlia/nn/rewrite/core/graph_edit/cut.py
+++ b/src/mlia/nn/rewrite/core/graph_edit/cut.py
@@ -9,8 +9,8 @@ import tensorflow as tf
from tensorflow.lite.python.schema_py_generated import ModelT
from tensorflow.lite.python.schema_py_generated import SubGraphT
-from mlia.nn.rewrite.core.utils.utils import load
-from mlia.nn.rewrite.core.utils.utils import save
+from mlia.nn.tensorflow.tflite_graph import load_fb
+from mlia.nn.tensorflow.tflite_graph import save_fb
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
@@ -138,8 +138,8 @@ def cut_model(
output_file: str,
) -> None:
"""Cut subgraphs and simplify a given model."""
- model = load(model_file)
+ model = load_fb(model_file)
subgraph = model.subgraphs[subgraph_index]
cut_subgraph(subgraph, input_names, output_names)
simplify(model)
- save(model, output_file)
+ save_fb(model, output_file)
diff --git a/src/mlia/nn/rewrite/core/graph_edit/join.py b/src/mlia/nn/rewrite/core/graph_edit/join.py
index 2530ec8..70109d8 100644
--- a/src/mlia/nn/rewrite/core/graph_edit/join.py
+++ b/src/mlia/nn/rewrite/core/graph_edit/join.py
@@ -11,8 +11,8 @@ from tensorflow.lite.python.schema_py_generated import ModelT
from tensorflow.lite.python.schema_py_generated import OperatorCodeT
from tensorflow.lite.python.schema_py_generated import SubGraphT
-from mlia.nn.rewrite.core.utils.utils import load
-from mlia.nn.rewrite.core.utils.utils import save
+from mlia.nn.tensorflow.tflite_graph import load_fb
+from mlia.nn.tensorflow.tflite_graph import save_fb
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
@@ -26,12 +26,12 @@ def join_models(
subgraph_dst: int = 0,
) -> None:
"""Join two models and save the result into a given model file path."""
- src_model = load(input_src)
- dst_model = load(input_dst)
+ src_model = load_fb(input_src)
+ dst_model = load_fb(input_dst)
src_subgraph = src_model.subgraphs[subgraph_src]
dst_subgraph = dst_model.subgraphs[subgraph_dst]
join_subgraphs(src_model, src_subgraph, dst_model, dst_subgraph)
- save(dst_model, output_file)
+ save_fb(dst_model, output_file)
def join_subgraphs(
diff --git a/src/mlia/nn/rewrite/core/rewrite.py b/src/mlia/nn/rewrite/core/rewrite.py
index 0d182df..6b27984 100644
--- a/src/mlia/nn/rewrite/core/rewrite.py
+++ b/src/mlia/nn/rewrite/core/rewrite.py
@@ -4,6 +4,7 @@
from __future__ import annotations
import importlib
+import logging
import tempfile
from dataclasses import dataclass
from pathlib import Path
@@ -12,13 +13,14 @@ from typing import Any
from mlia.core.errors import ConfigurationError
from mlia.nn.common import Optimizer
from mlia.nn.common import OptimizerConfiguration
-from mlia.nn.rewrite.core.train import eval_in_dir
-from mlia.nn.rewrite.core.train import join_in_dir
from mlia.nn.rewrite.core.train import train
-from mlia.nn.rewrite.core.train import train_in_dir
+from mlia.nn.rewrite.core.train import TrainingParameters
from mlia.nn.tensorflow.config import TFLiteModel
+logger = logging.getLogger(__name__)
+
+
@dataclass
class RewriteConfiguration(OptimizerConfiguration):
"""Rewrite configuration."""
@@ -26,6 +28,7 @@ class RewriteConfiguration(OptimizerConfiguration):
optimization_target: str
layers_to_optimize: list[str] | None = None
dataset: Path | None = None
+ train_params: TrainingParameters = TrainingParameters()
def __str__(self) -> str:
"""Return string representation of the configuration."""
@@ -40,8 +43,8 @@ class Rewriter(Optimizer):
):
"""Init Rewriter instance."""
self.model = TFLiteModel(tflite_model_path)
+ self.model_path = tflite_model_path
self.optimizer_configuration = optimizer_configuration
- self.train_dir = ""
def apply_optimization(self) -> None:
"""Apply the rewrite flow."""
@@ -61,50 +64,36 @@ class Rewriter(Optimizer):
replace_fn = get_function(replace_function)
- augmentation_preset = (None, None)
use_unmodified_model = True
tflite_model = self.model.model_path
tfrecord = str(self.optimizer_configuration.dataset)
- with tempfile.TemporaryDirectory() as tmp_dir:
- tmp_output = Path(tmp_dir, "output.tflite")
-
- if self.train_dir:
- tmp_new = Path(tmp_dir, "new.tflite")
- new_part = train_in_dir(
- train_dir=self.train_dir,
- baseline_dir=None,
- output_filename=tmp_new,
- replace_fn=replace_fn,
- augmentations=augmentation_preset,
- steps=32,
- learning_rate=1e-3,
- batch_size=1,
- verbose=True,
- show_progress=True,
- )
- eval_in_dir(self.train_dir, new_part[0])
- join_in_dir(self.train_dir, new_part[0], str(tmp_output))
- else:
- if not self.optimizer_configuration.layers_to_optimize:
- raise ConfigurationError(
- "Input and output tensor names need to be set for rewrite."
- )
- train(
- source_model=tflite_model,
- unmodified_model=tflite_model if use_unmodified_model else None,
- output_model=str(tmp_output),
- input_tfrec=str(tfrecord),
- replace_fn=replace_fn,
- input_tensors=[self.optimizer_configuration.layers_to_optimize[0]],
- output_tensors=[self.optimizer_configuration.layers_to_optimize[1]],
- augment=augmentation_preset,
- steps=32,
- learning_rate=1e-3,
- batch_size=1,
- verbose=True,
- show_progress=True,
- )
+ tmp_dir = tempfile.mkdtemp()
+ tmp_output = Path(tmp_dir, "output.tflite")
+
+ if not self.optimizer_configuration.layers_to_optimize:
+ raise ConfigurationError(
+ "Input and output tensor names need to be set for rewrite."
+ )
+ result = train(
+ source_model=tflite_model,
+ unmodified_model=tflite_model if use_unmodified_model else None,
+ output_model=str(tmp_output),
+ input_tfrec=str(tfrecord),
+ replace_fn=replace_fn,
+ input_tensors=[self.optimizer_configuration.layers_to_optimize[0]],
+ output_tensors=[self.optimizer_configuration.layers_to_optimize[1]],
+ train_params=self.optimizer_configuration.train_params,
+ )
+
+ self.model = TFLiteModel(tmp_output)
+
+ if result:
+ stats_as_str = ", ".join(str(stats) for stats in result)
+ logger.info(
+ "The MAE and NRMSE between original and replacement [%s]",
+ stats_as_str,
+ )
def get_model(self) -> TFLiteModel:
"""Return optimized model."""
diff --git a/src/mlia/nn/rewrite/core/train.py b/src/mlia/nn/rewrite/core/train.py
index c8497a4..42bf653 100644
--- a/src/mlia/nn/rewrite/core/train.py
+++ b/src/mlia/nn/rewrite/core/train.py
@@ -1,8 +1,8 @@
# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Sequential trainer."""
-# pylint: disable=too-many-arguments, too-many-instance-attributes,
-# pylint: disable=too-many-locals, too-many-branches, too-many-statements
+# pylint: disable=too-many-locals
+# pylint: disable=too-many-statements
from __future__ import annotations
import logging
@@ -10,10 +10,13 @@ import math
import os
import tempfile
from collections import defaultdict
+from contextlib import contextmanager
+from dataclasses import dataclass
from pathlib import Path
from typing import Any
from typing import Callable
from typing import cast
+from typing import Generator as GeneratorType
from typing import get_args
from typing import Literal
@@ -27,10 +30,10 @@ from mlia.nn.rewrite.core.graph_edit.join import join_models
from mlia.nn.rewrite.core.graph_edit.record import record_model
from mlia.nn.rewrite.core.utils.numpy_tfrecord import numpytf_count
from mlia.nn.rewrite.core.utils.numpy_tfrecord import numpytf_read
-from mlia.nn.rewrite.core.utils.numpy_tfrecord import TFLiteModel
from mlia.nn.rewrite.core.utils.parallel import ParallelTFLiteModel
-from mlia.nn.rewrite.core.utils.utils import load
-from mlia.nn.rewrite.core.utils.utils import save
+from mlia.nn.tensorflow.config import TFLiteModel
+from mlia.nn.tensorflow.tflite_graph import load_fb
+from mlia.nn.tensorflow.tflite_graph import save_fb
from mlia.utils.logging import log_action
@@ -38,7 +41,7 @@ os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
logger = logging.getLogger(__name__)
-augmentation_presets = {
+AUGMENTATION_PRESETS = {
"none": (None, None),
"gaussian": (None, 1.0),
"mixup": (1.0, None),
@@ -51,6 +54,21 @@ LearningRateSchedule = Literal["cosine", "late", "constant"]
LEARNING_RATE_SCHEDULES = get_args(LearningRateSchedule)
+@dataclass
+class TrainingParameters:
+ """Define default parameters for the training."""
+
+ augmentations: tuple[float | None, float | None] = AUGMENTATION_PRESETS["gaussian"]
+ batch_size: int = 32
+ steps: int = 48000
+ learning_rate: float = 1e-3
+ learning_rate_schedule: LearningRateSchedule = "cosine"
+ num_procs: int = 1
+ num_threads: int = 0
+ show_progress: bool = True
+ checkpoint_at: list | None = None
+
+
def train(
source_model: str,
unmodified_model: Any,
@@ -59,16 +77,7 @@ def train(
replace_fn: Callable,
input_tensors: list,
output_tensors: list,
- augment: tuple[float | None, float | None],
- steps: int,
- learning_rate: float,
- batch_size: int,
- verbose: bool,
- show_progress: bool,
- learning_rate_schedule: LearningRateSchedule = "cosine",
- checkpoint_at: list | None = None,
- num_procs: int = 1,
- num_threads: int = 0,
+ train_params: TrainingParameters = TrainingParameters(),
) -> Any:
"""Extract and train a model, and return the results."""
if unmodified_model:
@@ -95,29 +104,27 @@ def train(
input_tfrec,
input_tensors,
output_tensors,
- num_procs=num_procs,
- num_threads=num_threads,
+ num_procs=train_params.num_procs,
+ num_threads=train_params.num_threads,
)
tflite_filenames = train_in_dir(
- train_dir,
- unmodified_model_dir_path,
- Path(train_dir, "new.tflite"),
- replace_fn,
- augment,
- steps,
- learning_rate,
- batch_size,
- checkpoint_at=checkpoint_at,
- verbose=verbose,
- show_progress=show_progress,
- num_procs=num_procs,
- num_threads=num_threads,
- schedule=learning_rate_schedule,
+ train_dir=train_dir,
+ baseline_dir=unmodified_model_dir_path,
+ output_filename=Path(train_dir, "new.tflite"),
+ replace_fn=replace_fn,
+ train_params=train_params,
)
for i, filename in enumerate(tflite_filenames):
- results.append(eval_in_dir(train_dir, filename, num_procs, num_threads))
+ results.append(
+ eval_in_dir(
+ train_dir,
+ filename,
+ train_params.num_procs,
+ train_params.num_threads,
+ )
+ )
if output_model:
if i + 1 < len(tflite_filenames):
@@ -133,7 +140,7 @@ def train(
cast(tempfile.TemporaryDirectory, unmodified_model_dir).cleanup()
return (
- results if checkpoint_at else results[0]
+ results if train_params.checkpoint_at else results[0]
) # only return a list if multiple checkpoints are asked for
@@ -176,46 +183,24 @@ def join_in_dir(model_dir: str, new_part: str, output_model: str) -> None:
join_models(Path(model_dir, "start.tflite"), new_end, output_model)
-def train_in_dir(
- train_dir: str,
- baseline_dir: Any,
- output_filename: Path,
- replace_fn: Callable,
- augmentations: tuple[float | None, float | None],
- steps: int,
- learning_rate: float = 1e-3,
- batch_size: int = 32,
- checkpoint_at: list | None = None,
- schedule: str = "cosine",
- verbose: bool = False,
- show_progress: bool = False,
- num_procs: int = 0,
- num_threads: int = 1,
-) -> list:
- """Train a replacement for replace.tflite using the input.tfrec \
- and output.tfrec in train_dir.
-
- If baseline_dir is provided, train the replacement to match baseline
- outputs for train_dir inputs. Result saved as new.tflite in train_dir.
- """
- teacher_dir = baseline_dir if baseline_dir else train_dir
- teacher = ParallelTFLiteModel(
- f"{teacher_dir}/replace.tflite", num_procs, num_threads, batch_size=batch_size
- )
- replace = TFLiteModel(f"{train_dir}/replace.tflite")
+def _get_io_tensors(model: TFLiteModel) -> tuple[str, str]:
assert (
- len(teacher.input_tensors()) == 1
+ len(model.input_tensors()) == 1
), f"Can only train replacements with a single input tensor right now, \
- found {teacher.input_tensors()}"
+ found {model.input_tensors()}"
assert (
- len(teacher.output_tensors()) == 1
+ len(model.output_tensors()) == 1
), f"Can only train replacements with a single output tensor right now, \
- found {teacher.output_tensors()}"
+ found {model.output_tensors()}"
+
+ input_name = model.input_tensors()[0]
+ output_name = model.output_tensors()[0]
+ return (input_name, output_name)
- input_name = teacher.input_tensors()[0]
- output_name = teacher.output_tensors()[0]
+def _check_model_compatibility(teacher: TFLiteModel, replace: TFLiteModel) -> None:
+ """Assert that teacher and replaced sub-graph are compatible."""
assert len(teacher.shape_from_name) == len(
replace.shape_from_name
), f"Baseline and train models must have the same number of inputs and outputs. \
@@ -230,10 +215,37 @@ def train_in_dir(
subgraph being replaced. Teacher: {teacher.shape_from_name}\n \
Train dir: {replace.shape_from_name}"
+
+def set_up_data_pipeline(
+ teacher: TFLiteModel,
+ replace: TFLiteModel,
+ train_dir: str,
+ augmentations: tuple[float | None, float | None],
+ steps: int,
+ batch_size: int = 32,
+) -> tf.data.Dataset:
+ """Create a data pipeline for training of the replacement model."""
+ _check_model_compatibility(teacher, replace)
+
+ input_name, output_name = _get_io_tensors(teacher)
+
input_filename = Path(train_dir, "input.tfrec")
total = numpytf_count(str(input_filename))
dict_inputs = numpytf_read(str(input_filename))
+
inputs = dict_inputs.map(lambda d: tf.squeeze(d[input_name], axis=0))
+
+ steps_per_epoch = math.ceil(total / batch_size)
+ epochs = int(math.ceil(steps / steps_per_epoch))
+ logger.info(
+ "Training on %d items for %d steps (%d epochs with batch size %d)",
+ total,
+ epochs * steps_per_epoch,
+ epochs,
+ batch_size,
+ )
+
+ teacher_dir = Path(teacher.model_path).parent
if any(augmentations):
# Map the teacher inputs here because the augmentation stage passes these
# through a TFLite model to get the outputs
@@ -245,17 +257,6 @@ def train_in_dir(
lambda d: tf.squeeze(d[output_name], axis=0)
)
- steps_per_epoch = math.ceil(total / batch_size)
- epochs = int(math.ceil(steps / steps_per_epoch))
- if verbose:
- logger.info(
- "Training on %d items for %d steps (%d epochs with batch size %d)",
- total,
- epochs * steps_per_epoch,
- epochs,
- batch_size,
- )
-
dataset = tf.data.Dataset.zip((inputs, teacher_outputs))
if epochs > 1:
dataset = dataset.cache()
@@ -268,10 +269,9 @@ def train_in_dir(
train: Any, teach: Any # pylint: disable=redefined-outer-name
) -> tuple:
"""Return results of train and teach based on augmentations."""
- return (
- augment_train({input_name: train})[input_name],
- teacher(augment_teacher({input_name: teach}))[output_name],
- )
+ augmented_train = augment_train({input_name: train})[input_name]
+ augmented_teach = teacher(augment_teacher({input_name: teach}))[output_name]
+ return (augmented_train, augmented_teach)
dataset = dataset.map(
lambda augment_train, augment_teach: tf.py_function(
@@ -281,18 +281,67 @@ def train_in_dir(
)
)
+ # Restore data shapes of the dataset as they are set to unknown per default
+ # and get lost during augmentation with tf.py_function.
+ shape_in, shape_out = (
+ teacher.shape_from_name[name].tolist() for name in (input_name, output_name)
+ )
+ for shape in (shape_in, shape_out):
+ shape[0] = None # set dynamic batch size
+
+ def restore_shapes(input_: Any, output: Any) -> tuple[Any, Any]:
+ input_.set_shape(shape_in)
+ output.set_shape(shape_out)
+ return input_, output
+
+ dataset = dataset.map(restore_shapes)
dataset = dataset.prefetch(tf.data.AUTOTUNE)
+ return dataset
+
+
+def train_in_dir(
+ train_dir: str,
+ baseline_dir: Any,
+ output_filename: Path,
+ replace_fn: Callable,
+ train_params: TrainingParameters = TrainingParameters(),
+) -> list[str]:
+ """Train a replacement for replace.tflite using the input.tfrec \
+ and output.tfrec in train_dir.
+
+ If baseline_dir is provided, train the replacement to match baseline
+ outputs for train_dir inputs. Result saved as new.tflite in train_dir.
+ """
+ teacher_dir = baseline_dir if baseline_dir else train_dir
+ teacher = ParallelTFLiteModel(
+ f"{teacher_dir}/replace.tflite",
+ train_params.num_procs,
+ train_params.num_threads,
+ batch_size=train_params.batch_size,
+ )
+ replace = TFLiteModel(f"{train_dir}/replace.tflite")
+
+ input_name, output_name = _get_io_tensors(teacher)
+
+ dataset = set_up_data_pipeline(
+ teacher,
+ replace,
+ train_dir,
+ augmentations=train_params.augmentations,
+ steps=train_params.steps,
+ batch_size=train_params.batch_size,
+ )
input_shape = teacher.shape_from_name[input_name][1:]
output_shape = teacher.shape_from_name[output_name][1:]
+
model = replace_fn(input_shape, output_shape)
- optimizer = tf.keras.optimizers.Nadam(learning_rate=learning_rate)
+ optimizer = tf.keras.optimizers.Nadam(learning_rate=train_params.learning_rate)
loss_fn = tf.keras.losses.MeanSquaredError()
- model.compile(optimizer=optimizer, loss=loss_fn)
+ model.compile(optimizer=optimizer, loss=loss_fn, metrics=["mae"])
- if verbose:
- model.summary()
+ logger.info(model.summary())
steps_so_far = 0
@@ -302,7 +351,9 @@ def train_in_dir(
"""Cosine decay from learning rate at start of the run to zero at the end."""
current_step = epoch_step + steps_so_far
cd_learning_rate = (
- learning_rate * (math.cos(math.pi * current_step / steps) + 1) / 2.0
+ train_params.learning_rate
+ * (math.cos(math.pi * current_step / train_params.steps) + 1)
+ / 2.0
)
tf.keras.backend.set_value(optimizer.learning_rate, cd_learning_rate)
@@ -311,28 +362,29 @@ def train_in_dir(
) -> None:
"""Constant until the last 20% of the run, then linear decay to zero."""
current_step = epoch_step + steps_so_far
- steps_remaining = steps - current_step
- decay_length = steps // 5
+ steps_remaining = train_params.steps - current_step
+ decay_length = train_params.steps // 5
decay_fraction = min(steps_remaining, decay_length) / decay_length
- ld_learning_rate = learning_rate * decay_fraction
+ ld_learning_rate = train_params.learning_rate * decay_fraction
tf.keras.backend.set_value(optimizer.learning_rate, ld_learning_rate)
- if schedule == "cosine":
+ assert train_params.learning_rate_schedule in LEARNING_RATE_SCHEDULES, (
+ f'Learning rate schedule "{train_params.learning_rate_schedule}" '
+ f"not implemented - expected one of {LEARNING_RATE_SCHEDULES}."
+ )
+ if train_params.learning_rate_schedule == "cosine":
callbacks = [tf.keras.callbacks.LambdaCallback(on_batch_begin=cosine_decay)]
- elif schedule == "late":
+ elif train_params.learning_rate_schedule == "late":
callbacks = [tf.keras.callbacks.LambdaCallback(on_batch_begin=late_decay)]
- elif schedule == "constant":
+ elif train_params.learning_rate_schedule == "constant":
callbacks = []
- else:
- assert schedule not in LEARNING_RATE_SCHEDULES
- raise ValueError(
- f'Learning rate schedule "{schedule}" not implemented - '
- f"expected one of {LEARNING_RATE_SCHEDULES}."
- )
output_filenames = []
- checkpoints = (checkpoint_at if checkpoint_at else []) + [steps]
- while steps_so_far < steps:
+ checkpoints = (train_params.checkpoint_at if train_params.checkpoint_at else []) + [
+ train_params.steps
+ ]
+
+ while steps_so_far < train_params.steps:
steps_to_train = checkpoints.pop(0) - steps_so_far
lr_start = optimizer.learning_rate.numpy()
model.fit(
@@ -340,7 +392,7 @@ def train_in_dir(
epochs=1,
steps_per_epoch=steps_to_train,
callbacks=callbacks,
- verbose=show_progress,
+ verbose=train_params.show_progress,
)
steps_so_far += steps_to_train
logger.info(
@@ -350,12 +402,14 @@ def train_in_dir(
steps_to_train,
)
- if steps_so_far < steps:
+ if steps_so_far < train_params.steps:
filename, ext = Path(output_filename).parts[1:]
checkpoint_filename = filename + (f"_@{steps_so_far}") + ext
else:
checkpoint_filename = str(output_filename)
- with log_action(f"{steps_so_far}/{steps}: Saved as {checkpoint_filename}"):
+ with log_action(
+ f"{steps_so_far}/{train_params.steps}: Saved as {checkpoint_filename}"
+ ):
save_as_tflite(
model,
checkpoint_filename,
@@ -379,14 +433,30 @@ def save_as_tflite(
output_shape: list,
) -> None:
"""Save Keras model as TFLite file."""
- converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
+
+ @contextmanager
+ def fixed_input(keras_model: tf.keras.Model, tmp_shape: list) -> GeneratorType:
+ """Fix the input shape of the Keras model temporarily.
+
+ This avoids artifacts during conversion to TensorFlow Lite.
+ """
+ orig_shape = keras_model.input.shape
+ keras_model.input.set_shape(tf.TensorShape(tmp_shape))
+ try:
+ yield keras_model
+ finally:
+ # Restore original shape to not interfere with further training
+ keras_model.input.set_shape(orig_shape)
+
+ with fixed_input(keras_model, input_shape) as fixed_model:
+ converter = tf.lite.TFLiteConverter.from_keras_model(fixed_model)
tflite_model = converter.convert()
with open(filename, "wb") as file:
file.write(tflite_model)
# Now fix the shapes and names to match those we expect
- flatbuffer = load(filename)
+ flatbuffer = load_fb(filename)
i = flatbuffer.subgraphs[0].inputs[0]
flatbuffer.subgraphs[0].tensors[i].shape = np.array(input_shape, dtype=np.int32)
flatbuffer.subgraphs[0].tensors[i].name = input_name.encode("utf-8")
@@ -395,11 +465,11 @@ def save_as_tflite(
output_shape, dtype=np.int32
)
flatbuffer.subgraphs[0].tensors[output].name = output_name.encode("utf-8")
- save(flatbuffer, filename)
+ save_fb(flatbuffer, filename)
def augment_fn_twins(
- inputs: dict, augmentations: tuple[float | None, float | None]
+ inputs: tf.data.Dataset, augmentations: tuple[float | None, float | None]
) -> Any:
"""Return a pair of twinned augmentation functions with the same sequence \
of random numbers."""
@@ -415,6 +485,11 @@ def augment_fn(
inputs: Any, augmentations: tuple[float | None, float | None], rng: Generator
) -> Any:
"""Augmentation module."""
+ assert len(augmentations) == 2, (
+ f"Unexpected number of augmentation parameters: {len(augmentations)} "
+ "(must be 2)"
+ )
+
mixup_strength, gaussian_strength = augmentations
augments = []
@@ -449,17 +524,16 @@ def augment_fn(
augments.append(gaussian_strength_augment)
- if len(augments) == 0: # pylint: disable=no-else-return
+ if len(augments) == 0:
return lambda x: x
- elif len(augments) == 1:
+ if len(augments) == 1:
return augments[0]
- elif len(augments) == 2:
+ if len(augments) == 2:
return lambda x: augments[1](augments[0](x))
- else:
- assert (
- False
- ), f"Unexpected number of augmentation \
- functions ({len(augments)})"
+
+ raise RuntimeError(
+ "Unexpected number of augmentation functions (must be <=2): " f"{len(augments)}"
+ )
def mixup(rng: Generator, batch: Any, beta_range: tuple = (0.0, 1.0)) -> Any:
diff --git a/src/mlia/nn/rewrite/core/utils/numpy_tfrecord.py b/src/mlia/nn/rewrite/core/utils/numpy_tfrecord.py
index 9229810..38ac1ed 100644
--- a/src/mlia/nn/rewrite/core/utils/numpy_tfrecord.py
+++ b/src/mlia/nn/rewrite/core/utils/numpy_tfrecord.py
@@ -6,55 +6,56 @@ from __future__ import annotations
import json
import os
import random
-import tempfile
-from collections import defaultdict
+from functools import lru_cache
from pathlib import Path
from typing import Any
from typing import Callable
-import numpy as np
import tensorflow as tf
-from tensorflow.lite.python import interpreter as interpreter_wrapper
-from mlia.nn.rewrite.core.utils.utils import load
-from mlia.nn.rewrite.core.utils.utils import save
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
-def make_decode_fn(filename: str) -> Callable:
- """Make decode filename."""
+def decode_fn(record_bytes: Any, type_map: dict) -> dict:
+ """Decode the given bytes into a name-tensor dict assuming the given type."""
+ parse_dict = {
+ name: tf.io.FixedLenFeature([], tf.string) for name in type_map.keys()
+ }
+ example = tf.io.parse_single_example(record_bytes, parse_dict)
+ features = {
+ n: tf.io.parse_tensor(example[n], tf.as_dtype(t)) for n, t in type_map.items()
+ }
+ return features
- def decode_fn(record_bytes: Any, type_map: dict) -> dict:
- parse_dict = {
- name: tf.io.FixedLenFeature([], tf.string) for name in type_map.keys()
- }
- example = tf.io.parse_single_example(record_bytes, parse_dict)
- features = {
- n: tf.io.parse_tensor(example[n], tf.as_dtype(t))
- for n, t in type_map.items()
- }
- return features
+def make_decode_fn(filename: str, model_filename: str | Path | None = None) -> Callable:
+ """Make decode filename."""
meta_filename = filename + ".meta"
- with open(meta_filename, encoding="utf-8") as file:
- type_map = json.load(file)["type_map"]
+ try:
+ with open(meta_filename, encoding="utf-8") as file:
+ type_map = json.load(file)["type_map"]
return lambda record_bytes: decode_fn(record_bytes, type_map)
def numpytf_read(filename: str | Path) -> Any:
"""Read TFRecord dataset."""
- decode_fn = make_decode_fn(str(filename))
+ decode = make_decode_fn(str(filename))
dataset = tf.data.TFRecordDataset(str(filename))
- return dataset.map(decode_fn)
+ return dataset.map(decode)
-def numpytf_count(filename: str | Path) -> Any:
+@lru_cache
+def numpytf_count(filename: str | Path) -> int:
"""Return count from TFRecord file."""
meta_filename = f"{filename}.meta"
- with open(meta_filename, encoding="utf-8") as file:
- return json.load(file)["count"]
+ try:
+ with open(meta_filename, encoding="utf-8") as file:
+ return int(json.load(file)["count"])
+ except FileNotFoundError:
+ raw_dataset = tf.data.TFRecordDataset(filename)
+ return sum(1 for _ in raw_dataset)
class NumpyTFWriter:
@@ -101,93 +102,6 @@ class NumpyTFWriter:
self.writer.close()
-class TFLiteModel:
- """A representation of a TFLite Model."""
-
- def __init__(
- self,
- filename: str,
- batch_size: int | None = None,
- num_threads: int | None = None,
- ) -> None:
- """Initiate a TFLite Model."""
- if not num_threads:
- num_threads = None
- if not batch_size:
- self.interpreter = interpreter_wrapper.Interpreter(
- model_path=filename, num_threads=num_threads
- )
- else: # if a batch size is specified, modify the TFLite model to use this size
- with tempfile.TemporaryDirectory() as tmp:
- flatbuffer = load(filename)
- for subgraph in flatbuffer.subgraphs:
- for tensor in list(subgraph.inputs) + list(subgraph.outputs):
- subgraph.tensors[tensor].shape = np.array(
- [batch_size] + list(subgraph.tensors[tensor].shape[1:]),
- dtype=np.int32,
- )
- tempname = os.path.join(tmp, "rewrite_tmp.tflite")
- save(flatbuffer, tempname)
- self.interpreter = interpreter_wrapper.Interpreter(
- model_path=tempname, num_threads=num_threads
- )
-
- try:
- self.interpreter.allocate_tensors()
- except RuntimeError:
- self.interpreter = interpreter_wrapper.Interpreter(
- model_path=filename, num_threads=num_threads
- )
- self.interpreter.allocate_tensors()
-
- # Get input and output tensors.
- self.input_details = self.interpreter.get_input_details()
- self.output_details = self.interpreter.get_output_details()
- details = list(self.input_details) + list(self.output_details)
- self.handle_from_name = {d["name"]: d["index"] for d in details}
- self.shape_from_name = {d["name"]: d["shape"] for d in details}
- self.batch_size = next(iter(self.shape_from_name.values()))[0]
-
- def __call__(self, named_input: dict) -> dict:
- """Execute the model on one or a batch of named inputs \
- (a dict of name: numpy array)."""
- input_len = next(iter(named_input.values())).shape[0]
- full_steps = input_len // self.batch_size
- remainder = input_len % self.batch_size
-
- named_ys = defaultdict(list)
- for i in range(full_steps):
- for name, x_batch in named_input.items():
- x_tensor = x_batch[i : i + self.batch_size] # noqa: E203
- self.interpreter.set_tensor(self.handle_from_name[name], x_tensor)
- self.interpreter.invoke()
- for output_detail in self.output_details:
- named_ys[output_detail["name"]].append(
- self.interpreter.get_tensor(output_detail["index"])
- )
- if remainder:
- for name, x_batch in named_input.items():
- x_tensor = np.zeros( # pylint: disable=invalid-name
- self.shape_from_name[name]
- ).astype(x_batch.dtype)
- x_tensor[:remainder] = x_batch[-remainder:]
- self.interpreter.set_tensor(self.handle_from_name[name], x_tensor)
- self.interpreter.invoke()
- for output_detail in self.output_details:
- named_ys[output_detail["name"]].append(
- self.interpreter.get_tensor(output_detail["index"])[:remainder]
- )
- return {k: np.concatenate(v) for k, v in named_ys.items()}
-
- def input_tensors(self) -> list:
- """Return name from input details."""
- return [d["name"] for d in self.input_details]
-
- def output_tensors(self) -> list:
- """Return name from output details."""
- return [d["name"] for d in self.output_details]
-
-
def sample_tfrec(input_file: str, k: int, output_file: str) -> None:
"""Count, read and write TFRecord input and output data."""
total = numpytf_count(input_file)
diff --git a/src/mlia/nn/rewrite/core/utils/parallel.py b/src/mlia/nn/rewrite/core/utils/parallel.py
index d930a1e..b7b390d 100644
--- a/src/mlia/nn/rewrite/core/utils/parallel.py
+++ b/src/mlia/nn/rewrite/core/utils/parallel.py
@@ -15,14 +15,14 @@ from typing import Any
import numpy as np
import tensorflow as tf
-from mlia.nn.rewrite.core.utils.numpy_tfrecord import TFLiteModel
+from mlia.nn.tensorflow.config import TFLiteModel
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
logger = logging.getLogger(__name__)
-class ParallelTFLiteModel(TFLiteModel):
+class ParallelTFLiteModel(TFLiteModel): # pylint: disable=abstract-method
"""A parallel version of a TFLiteModel.
num_procs: 0 => detect real cores on system
diff --git a/src/mlia/nn/rewrite/core/utils/utils.py b/src/mlia/nn/rewrite/core/utils/utils.py
deleted file mode 100644
index ddf0cc2..0000000
--- a/src/mlia/nn/rewrite/core/utils/utils.py
+++ /dev/null
@@ -1,32 +0,0 @@
-# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
-# SPDX-License-Identifier: Apache-2.0
-"""Model and file system utilites."""
-from __future__ import annotations
-
-from pathlib import Path
-
-import flatbuffers
-from tensorflow.lite.python.schema_py_generated import Model
-from tensorflow.lite.python.schema_py_generated import ModelT
-
-
-def load(input_tflite_file: str | Path) -> ModelT:
- """Load a flatbuffer model from file."""
- if not Path(input_tflite_file).exists():
- raise FileNotFoundError(f"TFLite file not found at {input_tflite_file}\n")
- with open(input_tflite_file, "rb") as file_handle:
- file_data = bytearray(file_handle.read())
- model_obj = Model.GetRootAsModel(file_data, 0)
- model = ModelT.InitFromObj(model_obj)
- return model
-
-
-def save(model: ModelT, output_tflite_file: str | Path) -> None:
- """Save a flatbuffer model to a given file."""
- builder = flatbuffers.Builder(1024) # Initial size of the buffer, which
- # will grow automatically if needed
- model_offset = model.Pack(builder)
- builder.Finish(model_offset, file_identifier=b"TFL3")
- model_data = builder.Output()
- with open(output_tflite_file, "wb") as out_file:
- out_file.write(model_data)
diff --git a/src/mlia/nn/rewrite/library/fc_layer.py b/src/mlia/nn/rewrite/library/fc_layer.py
index 8704154..2480500 100644
--- a/src/mlia/nn/rewrite/library/fc_layer.py
+++ b/src/mlia/nn/rewrite/library/fc_layer.py
@@ -7,12 +7,12 @@ import tensorflow as tf
def get_keras_model(input_shape: Any, output_shape: Any) -> tf.keras.Model:
- """Generate tflite model for rewrite."""
- input_tensor = tf.keras.layers.Input(
- shape=input_shape, name="MbileNet/avg_pool/AvgPool"
+ """Generate TensorFlow Lite model for rewrite."""
+ model = tf.keras.Sequential(
+ (
+ tf.keras.layers.InputLayer(input_shape=input_shape),
+ tf.keras.layers.Reshape([-1]),
+ tf.keras.layers.Dense(output_shape),
+ )
)
- output_tensor = tf.keras.layers.Dense(output_shape, name="MobileNet/fc1/BiasAdd")(
- input_tensor
- )
- model = tf.keras.Model(input_tensor, output_tensor)
return model
diff --git a/src/mlia/nn/select.py b/src/mlia/nn/select.py
index 5a7f289..983426b 100644
--- a/src/mlia/nn/select.py
+++ b/src/mlia/nn/select.py
@@ -17,6 +17,7 @@ from mlia.nn.common import Optimizer
from mlia.nn.common import OptimizerConfiguration
from mlia.nn.rewrite.core.rewrite import RewriteConfiguration
from mlia.nn.rewrite.core.rewrite import Rewriter
+from mlia.nn.rewrite.core.rewrite import TrainingParameters
from mlia.nn.tensorflow.config import KerasModel
from mlia.nn.tensorflow.config import TFLiteModel
from mlia.nn.tensorflow.optimizations.clustering import Clusterer
@@ -164,6 +165,15 @@ def _get_optimizer(
return MultiStageOptimizer(model, optimizer_configs)
+def _get_rewrite_train_params() -> TrainingParameters:
+ """Get the rewrite TrainingParameters.
+
+ Return the default constructed TrainingParameters() per default, but can be
+ overwritten in the unit tests.
+ """
+ return TrainingParameters()
+
+
def _get_optimizer_configuration(
optimization_type: str,
optimization_target: int | float | str,
@@ -190,7 +200,10 @@ def _get_optimizer_configuration(
if opt_type == "rewrite":
if isinstance(optimization_target, str):
return RewriteConfiguration(
- str(optimization_target), layers_to_optimize, dataset
+ optimization_target=str(optimization_target),
+ layers_to_optimize=layers_to_optimize,
+ dataset=dataset,
+ train_params=_get_rewrite_train_params(),
)
raise ConfigurationError(
diff --git a/src/mlia/nn/tensorflow/config.py b/src/mlia/nn/tensorflow/config.py
index d7d430f..c6a7c88 100644
--- a/src/mlia/nn/tensorflow/config.py
+++ b/src/mlia/nn/tensorflow/config.py
@@ -4,13 +4,16 @@
from __future__ import annotations
import logging
+import tempfile
+from collections import defaultdict
from pathlib import Path
-from typing import cast
-from typing import List
+import numpy as np
import tensorflow as tf
from mlia.core.context import Context
+from mlia.nn.tensorflow.tflite_graph import load_fb
+from mlia.nn.tensorflow.tflite_graph import save_fb
from mlia.nn.tensorflow.utils import convert_to_tflite
from mlia.nn.tensorflow.utils import is_keras_model
from mlia.nn.tensorflow.utils import is_saved_model
@@ -71,10 +74,89 @@ class KerasModel(ModelConfiguration):
class TFLiteModel(ModelConfiguration): # pylint: disable=abstract-method
"""TensorFlow Lite model configuration."""
- def input_details(self) -> list[dict]:
- """Get model's input details."""
- interpreter = tf.lite.Interpreter(model_path=self.model_path)
- return cast(List[dict], interpreter.get_input_details())
+ def __init__(
+ self,
+ model_path: str | Path,
+ batch_size: int | None = None,
+ num_threads: int | None = None,
+ ) -> None:
+ """Initiate a TFLite Model."""
+ super().__init__(model_path)
+ if not num_threads:
+ num_threads = None
+ if not batch_size:
+ self.interpreter = tf.lite.Interpreter(
+ model_path=self.model_path, num_threads=num_threads
+ )
+ else: # if a batch size is specified, modify the TFLite model to use this size
+ with tempfile.TemporaryDirectory() as tmp:
+ flatbuffer = load_fb(self.model_path)
+ for subgraph in flatbuffer.subgraphs:
+ for tensor in list(subgraph.inputs) + list(subgraph.outputs):
+ subgraph.tensors[tensor].shape = np.array(
+ [batch_size] + list(subgraph.tensors[tensor].shape[1:]),
+ dtype=np.int32,
+ )
+ tempname = Path(tmp, "rewrite_tmp.tflite")
+ save_fb(flatbuffer, tempname)
+ self.interpreter = tf.lite.Interpreter(
+ model_path=str(tempname), num_threads=num_threads
+ )
+
+ try:
+ self.interpreter.allocate_tensors()
+ except RuntimeError:
+ self.interpreter = tf.lite.Interpreter(
+ model_path=self.model_path, num_threads=num_threads
+ )
+ self.interpreter.allocate_tensors()
+
+ # Get input and output tensors.
+ self.input_details = self.interpreter.get_input_details()
+ self.output_details = self.interpreter.get_output_details()
+ details = list(self.input_details) + list(self.output_details)
+ self.handle_from_name = {d["name"]: d["index"] for d in details}
+ self.shape_from_name = {d["name"]: d["shape"] for d in details}
+ self.batch_size = next(iter(self.shape_from_name.values()))[0]
+
+ def __call__(self, named_input: dict) -> dict:
+ """Execute the model on one or a batch of named inputs \
+ (a dict of name: numpy array)."""
+ input_len = next(iter(named_input.values())).shape[0]
+ full_steps = input_len // self.batch_size
+ remainder = input_len % self.batch_size
+
+ named_ys = defaultdict(list)
+ for i in range(full_steps):
+ for name, x_batch in named_input.items():
+ x_tensor = x_batch[i : i + self.batch_size] # noqa: E203
+ self.interpreter.set_tensor(self.handle_from_name[name], x_tensor)
+ self.interpreter.invoke()
+ for output_detail in self.output_details:
+ named_ys[output_detail["name"]].append(
+ self.interpreter.get_tensor(output_detail["index"])
+ )
+ if remainder:
+ for name, x_batch in named_input.items():
+ x_tensor = np.zeros( # pylint: disable=invalid-name
+ self.shape_from_name[name]
+ ).astype(x_batch.dtype)
+ x_tensor[:remainder] = x_batch[-remainder:]
+ self.interpreter.set_tensor(self.handle_from_name[name], x_tensor)
+ self.interpreter.invoke()
+ for output_detail in self.output_details:
+ named_ys[output_detail["name"]].append(
+ self.interpreter.get_tensor(output_detail["index"])[:remainder]
+ )
+ return {k: np.concatenate(v) for k, v in named_ys.items()}
+
+ def input_tensors(self) -> list:
+ """Return name from input details."""
+ return [d["name"] for d in self.input_details]
+
+ def output_tensors(self) -> list:
+ """Return name from output details."""
+ return [d["name"] for d in self.output_details]
def convert_to_tflite(
self, tflite_model_path: str | Path, quantized: bool = False
@@ -118,10 +200,10 @@ def get_model(model: str | Path) -> ModelConfiguration:
def get_tflite_model(model: str | Path, ctx: Context) -> TFLiteModel:
"""Convert input model to TensorFlow Lite and returns TFLiteModel object."""
- tflite_model_path = ctx.get_model_path("converted_model.tflite")
- converted_model = get_model(model)
+ dst_model_path = ctx.get_model_path("converted_model.tflite")
+ src_model = get_model(model)
- return converted_model.convert_to_tflite(tflite_model_path, True)
+ return src_model.convert_to_tflite(dst_model_path, quantized=True)
def get_keras_model(model: str | Path, ctx: Context) -> KerasModel:
diff --git a/src/mlia/nn/tensorflow/tflite_graph.py b/src/mlia/nn/tensorflow/tflite_graph.py
index 4f5e85f..7ca9337 100644
--- a/src/mlia/nn/tensorflow/tflite_graph.py
+++ b/src/mlia/nn/tensorflow/tflite_graph.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Utilities for TensorFlow Lite graphs."""
from __future__ import annotations
@@ -10,7 +10,10 @@ from pathlib import Path
from typing import Any
from typing import cast
+import flatbuffers
from tensorflow.lite.python import schema_py_generated as schema_fb
+from tensorflow.lite.python.schema_py_generated import Model
+from tensorflow.lite.python.schema_py_generated import ModelT
from tensorflow.lite.tools import visualize
@@ -137,3 +140,25 @@ def parse_subgraphs(tflite_file: Path) -> list[list[Op]]:
]
return graphs
+
+
+def load_fb(input_tflite_file: str | Path) -> ModelT:
+ """Load a flatbuffer model from file."""
+ if not Path(input_tflite_file).exists():
+ raise FileNotFoundError(f"TFLite file not found at {input_tflite_file}\n")
+ with open(input_tflite_file, "rb") as file_handle:
+ file_data = bytearray(file_handle.read())
+ model_obj = Model.GetRootAsModel(file_data, 0)
+ model = ModelT.InitFromObj(model_obj)
+ return model
+
+
+def save_fb(model: ModelT, output_tflite_file: str | Path) -> None:
+ """Save a flatbuffer model to a given file."""
+ builder = flatbuffers.Builder(1024) # Initial size of the buffer, which
+ # will grow automatically if needed
+ model_offset = model.Pack(builder)
+ builder.Finish(model_offset, file_identifier=b"TFL3")
+ model_data = builder.Output()
+ with open(output_tflite_file, "wb") as out_file:
+ out_file.write(model_data)
diff --git a/src/mlia/target/ethos_u/data_collection.py b/src/mlia/target/ethos_u/data_collection.py
index ba8b0fe..4ea6120 100644
--- a/src/mlia/target/ethos_u/data_collection.py
+++ b/src/mlia/target/ethos_u/data_collection.py
@@ -106,15 +106,14 @@ class OptimizeModel:
self.context = context
self.opt_settings = opt_settings
- def __call__(self, keras_model: KerasModel) -> Any:
+ def __call__(self, model: KerasModel | TFLiteModel) -> Any:
"""Run optimization."""
- optimizer = get_optimizer(keras_model, self.opt_settings)
+ optimizer = get_optimizer(model, self.opt_settings)
opts_as_str = ", ".join(str(opt) for opt in self.opt_settings)
logger.info("Applying model optimizations - [%s]", opts_as_str)
optimizer.apply_optimization()
-
- model = optimizer.get_model()
+ model = optimizer.get_model() # type: ignore
if isinstance(model, Path):
return model
@@ -178,6 +177,7 @@ class EthosUOptimizationPerformance(ContextAwareDataCollector):
self.target,
self.backends,
)
+
original_metrics, *optimized_metrics = estimate_performance(
model, estimator, optimizers # type: ignore
)
diff --git a/tests/conftest.py b/tests/conftest.py
index c42b8cb..bb2423f 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -5,6 +5,7 @@ import shutil
from pathlib import Path
from typing import Callable
from typing import Generator
+from unittest.mock import MagicMock
import numpy as np
import pytest
@@ -17,6 +18,7 @@ from mlia.nn.tensorflow.utils import convert_to_tflite
from mlia.nn.tensorflow.utils import save_keras_model
from mlia.nn.tensorflow.utils import save_tflite_model
from mlia.target.ethos_u.config import EthosUConfiguration
+from tests.utils.rewrite import TestTrainingParameters
@pytest.fixture(scope="session", name="test_resources_path")
@@ -168,16 +170,12 @@ def _write_tfrecord(
writer.write({input_name: data_generator()})
-@pytest.fixture(scope="session", name="test_tfrecord")
-def fixture_test_tfrecord(
- tmp_path_factory: pytest.TempPathFactory,
+def create_tfrecord(
+ tmp_path_factory: pytest.TempPathFactory, random_data: Callable
) -> Generator[Path, None, None]:
"""Create a tfrecord with random data matching fixture 'test_tflite_model'."""
tmp_path = tmp_path_factory.mktemp("tfrecords")
- tfrecord_file = tmp_path / "test_int8.tfrecord"
-
- def random_data() -> np.ndarray:
- return np.random.randint(low=-127, high=128, size=(1, 28, 28, 1), dtype=np.int8)
+ tfrecord_file = tmp_path / "test.tfrecord"
_write_tfrecord(tfrecord_file, random_data)
@@ -186,19 +184,36 @@ def fixture_test_tfrecord(
shutil.rmtree(tmp_path)
+@pytest.fixture(scope="session", name="test_tfrecord")
+def fixture_test_tfrecord(
+ tmp_path_factory: pytest.TempPathFactory,
+) -> Generator[Path, None, None]:
+ """Create a tfrecord with random data matching fixture 'test_tflite_model'."""
+
+ def random_data() -> np.ndarray:
+ return np.random.randint(low=-127, high=128, size=(1, 28, 28, 1), dtype=np.int8)
+
+ yield from create_tfrecord(tmp_path_factory, random_data)
+
+
@pytest.fixture(scope="session", name="test_tfrecord_fp32")
def fixture_test_tfrecord_fp32(
tmp_path_factory: pytest.TempPathFactory,
) -> Generator[Path, None, None]:
"""Create tfrecord with random data matching fixture 'test_tflite_model_fp32'."""
- tmp_path = tmp_path_factory.mktemp("tfrecords")
- tfrecord_file = tmp_path / "test_fp32.tfrecord"
def random_data() -> np.ndarray:
return np.random.rand(1, 28, 28, 1).astype(np.float32)
- _write_tfrecord(tfrecord_file, random_data)
+ yield from create_tfrecord(tmp_path_factory, random_data)
- yield tfrecord_file
- shutil.rmtree(tmp_path)
+@pytest.fixture(scope="session", autouse=True)
+def set_training_steps() -> Generator[None, None, None]:
+ """Speed up tests by using TestTrainingParameters."""
+ with pytest.MonkeyPatch.context() as monkeypatch:
+ monkeypatch.setattr(
+ "mlia.nn.select._get_rewrite_train_params",
+ MagicMock(return_value=TestTrainingParameters()),
+ )
+ yield
diff --git a/tests/test_nn_rewrite_core_graph_edit_join.py b/tests/test_nn_rewrite_core_graph_edit_join.py
index cb3e4e2..0cb121e 100644
--- a/tests/test_nn_rewrite_core_graph_edit_join.py
+++ b/tests/test_nn_rewrite_core_graph_edit_join.py
@@ -10,7 +10,7 @@ import pytest
from mlia.nn.rewrite.core.graph_edit.cut import cut_model
from mlia.nn.rewrite.core.graph_edit.join import append_relabel
from mlia.nn.rewrite.core.graph_edit.join import join_models
-from mlia.nn.rewrite.core.utils.utils import load
+from mlia.nn.tensorflow.tflite_graph import load_fb
from tests.utils.rewrite import models_are_equal
@@ -49,8 +49,8 @@ def test_join_model(test_tflite_model: Path, tmp_path: Path) -> None:
)
assert joined_file.is_file()
- orig_model = load(str(test_tflite_model))
- joined_model = load(str(joined_file))
+ orig_model = load_fb(str(test_tflite_model))
+ joined_model = load_fb(str(joined_file))
assert models_are_equal(orig_model, joined_model)
diff --git a/tests/test_nn_rewrite_core_graph_edit_record.py b/tests/test_nn_rewrite_core_graph_edit_record.py
index cd728af..41b9c50 100644
--- a/tests/test_nn_rewrite_core_graph_edit_record.py
+++ b/tests/test_nn_rewrite_core_graph_edit_record.py
@@ -3,15 +3,13 @@
"""Tests for module mlia.nn.rewrite.graph_edit.record."""
from pathlib import Path
-import pytest
import tensorflow as tf
from mlia.nn.rewrite.core.graph_edit.record import record_model
from mlia.nn.rewrite.core.utils.numpy_tfrecord import numpytf_read
-@pytest.mark.parametrize("batch_size", (None, 1, 2))
-def test_record_model(
+def check_record_model(
test_tflite_model: Path,
tmp_path: Path,
test_tfrecord: Path,
diff --git a/tests/test_nn_rewrite_core_rewrite.py b/tests/test_nn_rewrite_core_rewrite.py
index b98971e..2542db2 100644
--- a/tests/test_nn_rewrite_core_rewrite.py
+++ b/tests/test_nn_rewrite_core_rewrite.py
@@ -12,6 +12,7 @@ import pytest
from mlia.nn.rewrite.core.rewrite import RewriteConfiguration
from mlia.nn.rewrite.core.rewrite import Rewriter
from mlia.nn.tensorflow.config import TFLiteModel
+from tests.utils.rewrite import TestTrainingParameters
@pytest.mark.parametrize(
@@ -32,12 +33,14 @@ def test_rewrite_configuration(
None,
)
+ assert config_obj.optimization_target in str(config_obj)
+
rewriter_obj = Rewriter(test_tflite_model_fp32, config_obj)
assert rewriter_obj.optimizer_configuration.optimization_target == rewrite_name
assert isinstance(rewriter_obj, Rewriter)
-def test_rewriter(
+def test_rewriting_optimizer(
test_tflite_model_fp32: Path,
test_tfrecord_fp32: Path,
) -> None:
@@ -46,6 +49,7 @@ def test_rewriter(
"fully_connected",
["sequential/flatten/Reshape", "StatefulPartitionedCall:0"],
test_tfrecord_fp32,
+ train_params=TestTrainingParameters(),
)
test_obj = Rewriter(test_tflite_model_fp32, config_obj)
diff --git a/tests/test_nn_rewrite_core_train.py b/tests/test_nn_rewrite_core_train.py
index 3c2ef3e..4493671 100644
--- a/tests/test_nn_rewrite_core_train.py
+++ b/tests/test_nn_rewrite_core_train.py
@@ -4,6 +4,7 @@
# pylint: disable=too-many-arguments
from __future__ import annotations
+from contextlib import ExitStack as does_not_raise
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Any
@@ -12,10 +13,13 @@ import numpy as np
import pytest
import tensorflow as tf
-from mlia.nn.rewrite.core.train import augmentation_presets
+from mlia.nn.rewrite.core.train import augment_fn_twins
+from mlia.nn.rewrite.core.train import AUGMENTATION_PRESETS
from mlia.nn.rewrite.core.train import LearningRateSchedule
from mlia.nn.rewrite.core.train import mixup
from mlia.nn.rewrite.core.train import train
+from mlia.nn.rewrite.core.train import TrainingParameters
+from tests.utils.rewrite import TestTrainingParameters
def replace_fully_connected_with_conv(
@@ -41,15 +45,8 @@ def replace_fully_connected_with_conv(
def check_train(
tflite_model: Path,
tfrecord: Path,
- batch_size: int = 1,
- verbose: bool = False,
- show_progress: bool = False,
- augmentation_preset: tuple[float | None, float | None] = augmentation_presets[
- "none"
- ],
- lr_schedule: LearningRateSchedule = "cosine",
+ train_params: TrainingParameters = TestTrainingParameters(),
use_unmodified_model: bool = False,
- num_procs: int = 1,
) -> None:
"""Test the train() function."""
with TemporaryDirectory() as tmp_dir:
@@ -62,14 +59,7 @@ def check_train(
replace_fn=replace_fully_connected_with_conv,
input_tensors=["sequential/flatten/Reshape"],
output_tensors=["StatefulPartitionedCall:0"],
- augment=augmentation_preset,
- steps=32,
- learning_rate=1e-3,
- batch_size=batch_size,
- verbose=verbose,
- show_progress=show_progress,
- learning_rate_schedule=lr_schedule,
- num_procs=num_procs,
+ train_params=train_params,
)
assert len(result) == 2
assert all(res >= 0.0 for res in result), f"Results out of bound: {result}"
@@ -79,7 +69,6 @@ def check_train(
@pytest.mark.parametrize(
(
"batch_size",
- "verbose",
"show_progress",
"augmentation_preset",
"lr_schedule",
@@ -87,14 +76,13 @@ def check_train(
"num_procs",
),
(
- (1, False, False, augmentation_presets["none"], "cosine", False, 2),
- (32, True, True, augmentation_presets["gaussian"], "late", True, 1),
- (2, False, False, augmentation_presets["mixup"], "constant", True, 0),
+ (1, False, AUGMENTATION_PRESETS["none"], "cosine", False, 2),
+ (32, True, AUGMENTATION_PRESETS["gaussian"], "late", True, 1),
+ (2, False, AUGMENTATION_PRESETS["mixup"], "constant", True, 0),
(
1,
False,
- False,
- augmentation_presets["mix_gaussian_large"],
+ AUGMENTATION_PRESETS["mix_gaussian_large"],
"cosine",
False,
2,
@@ -105,7 +93,6 @@ def test_train(
test_tflite_model_fp32: Path,
test_tfrecord_fp32: Path,
batch_size: int,
- verbose: bool,
show_progress: bool,
augmentation_preset: tuple[float | None, float | None],
lr_schedule: LearningRateSchedule,
@@ -116,13 +103,14 @@ def test_train(
check_train(
tflite_model=test_tflite_model_fp32,
tfrecord=test_tfrecord_fp32,
- batch_size=batch_size,
- verbose=verbose,
- show_progress=show_progress,
- augmentation_preset=augmentation_preset,
- lr_schedule=lr_schedule,
+ train_params=TestTrainingParameters(
+ batch_size=batch_size,
+ show_progress=show_progress,
+ augmentations=augmentation_preset,
+ learning_rate_schedule=lr_schedule,
+ num_procs=num_procs,
+ ),
use_unmodified_model=use_unmodified_model,
- num_procs=num_procs,
)
@@ -131,11 +119,13 @@ def test_train_invalid_schedule(
test_tfrecord_fp32: Path,
) -> None:
"""Test the train() function with an invalid schedule."""
- with pytest.raises(ValueError):
+ with pytest.raises(AssertionError):
check_train(
tflite_model=test_tflite_model_fp32,
tfrecord=test_tfrecord_fp32,
- lr_schedule="unknown_schedule", # type: ignore
+ train_params=TestTrainingParameters(
+ learning_rate_schedule="unknown_schedule",
+ ),
)
@@ -144,11 +134,13 @@ def test_train_invalid_augmentation(
test_tfrecord_fp32: Path,
) -> None:
"""Test the train() function with an invalid augmentation."""
- with pytest.raises(ValueError):
+ with pytest.raises(AssertionError):
check_train(
tflite_model=test_tflite_model_fp32,
tfrecord=test_tfrecord_fp32,
- augmentation_preset=(1.0, 2.0, 3.0), # type: ignore
+ train_params=TestTrainingParameters(
+ augmentations=(1.0, 2.0, 3.0),
+ ),
)
@@ -159,3 +151,19 @@ def test_mixup() -> None:
assert src.shape == dst.shape
assert np.all(dst >= 0.0)
assert np.all(dst <= 3.0)
+
+
+@pytest.mark.parametrize(
+ "augmentations, expected_error",
+ [
+ (AUGMENTATION_PRESETS["none"], does_not_raise()),
+ (AUGMENTATION_PRESETS["mix_gaussian_large"], does_not_raise()),
+ ((None,) * 3, pytest.raises(AssertionError)),
+ ],
+)
+def test_augment_fn_twins(augmentations: tuple, expected_error: Any) -> None:
+ """Test function augment_fn()."""
+ dataset = tf.data.Dataset.from_tensor_slices({"a": [1, 2, 3], "b": [4, 5, 6]})
+ with expected_error:
+ fn_twins = augment_fn_twins(dataset, augmentations) # type: ignore
+ assert len(fn_twins) == 2
diff --git a/tests/test_nn_rewrite_core_utils.py b/tests/test_nn_rewrite_core_utils.py
deleted file mode 100644
index d806a7b..0000000
--- a/tests/test_nn_rewrite_core_utils.py
+++ /dev/null
@@ -1,33 +0,0 @@
-# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
-# SPDX-License-Identifier: Apache-2.0
-"""Tests for module mlia.nn.rewrite.utils."""
-from pathlib import Path
-
-import pytest
-import tensorflow as tf
-from tensorflow.lite.python.schema_py_generated import ModelT
-
-from mlia.nn.rewrite.core.utils.utils import load
-from mlia.nn.rewrite.core.utils.utils import save
-from tests.utils.rewrite import models_are_equal
-
-
-def test_load_save(test_tflite_model: Path, tmp_path: Path) -> None:
- """Test the load/save functions for TensorFlow Lite models."""
- with pytest.raises(FileNotFoundError):
- load("THIS_IS_NOT_A_REAL_FILE")
-
- model = load(test_tflite_model)
- assert isinstance(model, ModelT)
- assert model.subgraphs
-
- output_file = tmp_path / "test.tflite"
- assert not output_file.is_file()
- save(model, output_file)
- assert output_file.is_file()
-
- model_copy = load(str(output_file))
- assert models_are_equal(model, model_copy)
-
- # Double check that the TensorFlow Lite Interpreter can still load the file.
- tf.lite.Interpreter(model_path=str(output_file))
diff --git a/tests/test_nn_rewrite_core_utils_numpy_tfrecord.py b/tests/test_nn_rewrite_core_utils_numpy_tfrecord.py
index 7fc8048..d030350 100644
--- a/tests/test_nn_rewrite_core_utils_numpy_tfrecord.py
+++ b/tests/test_nn_rewrite_core_utils_numpy_tfrecord.py
@@ -5,6 +5,10 @@ from __future__ import annotations
from pathlib import Path
+import pytest
+import tensorflow as tf
+
+from mlia.nn.rewrite.core.utils.numpy_tfrecord import make_decode_fn
from mlia.nn.rewrite.core.utils.numpy_tfrecord import numpytf_count
from mlia.nn.rewrite.core.utils.numpy_tfrecord import sample_tfrec
@@ -16,3 +20,24 @@ def test_sample_tfrec(test_tfrecord: Path, tmp_path: Path) -> None:
sample_tfrec(input_file=str(test_tfrecord), k=1, output_file=str(output_file))
assert output_file.is_file()
assert numpytf_count(str(output_file)) == 1
+
+
+def test_make_decode_fn(test_tfrecord: Path) -> None:
+ """Test function make_decode_fn()."""
+ decode = make_decode_fn(str(test_tfrecord))
+ dataset = tf.data.TFRecordDataset(str(test_tfrecord))
+ features = decode(next(iter(dataset)))
+ assert isinstance(features, dict)
+ assert len(features) == 1
+ key, val = next(iter(features.items()))
+ assert isinstance(key, str)
+ assert isinstance(val, tf.Tensor)
+ assert val.dtype == tf.int8
+
+ with pytest.raises(FileNotFoundError):
+ make_decode_fn(str(test_tfrecord) + "_")
+
+
+def test_numpytf_count(test_tfrecord: Path) -> None:
+ """Test function numpytf_count()."""
+ assert numpytf_count(test_tfrecord) == 3
diff --git a/tests/test_nn_tensorflow_config.py b/tests/test_nn_tensorflow_config.py
index 656619d..48aec0a 100644
--- a/tests/test_nn_tensorflow_config.py
+++ b/tests/test_nn_tensorflow_config.py
@@ -4,13 +4,28 @@
from contextlib import ExitStack as does_not_raise
from pathlib import Path
from typing import Any
+from typing import Generator
+import numpy as np
import pytest
+from mlia.nn.rewrite.core.utils.numpy_tfrecord import numpytf_read
from mlia.nn.tensorflow.config import get_model
from mlia.nn.tensorflow.config import KerasModel
+from mlia.nn.tensorflow.config import ModelConfiguration
from mlia.nn.tensorflow.config import TFLiteModel
from mlia.nn.tensorflow.config import TfModel
+from tests.conftest import create_tfrecord
+
+
+def test_model_configuration(test_keras_model: Path) -> None:
+ """Test ModelConfiguration class."""
+ model = ModelConfiguration(model_path=test_keras_model)
+ assert test_keras_model.match(model.model_path)
+ with pytest.raises(NotImplementedError):
+ model.convert_to_keras("keras_model.h5")
+ with pytest.raises(NotImplementedError):
+ model.convert_to_tflite("model.tflite")
def test_convert_keras_to_tflite(tmp_path: Path, test_keras_model: Path) -> None:
@@ -38,7 +53,7 @@ def test_convert_tf_to_tflite(tmp_path: Path, test_tf_model: Path) -> None:
@pytest.mark.parametrize(
"model_path, expected_type, expected_error",
[
- ("test.tflite", TFLiteModel, does_not_raise()),
+ ("test.tflite", TFLiteModel, pytest.raises(ValueError)),
("test.h5", KerasModel, does_not_raise()),
("test.hdf5", KerasModel, does_not_raise()),
(
@@ -73,3 +88,26 @@ def test_get_model_dir(
"""Test TensorFlow Lite model type."""
model = get_model(str(test_models_path / model_path))
assert isinstance(model, expected_type)
+
+
+@pytest.fixture(scope="session", name="test_tfrecord_fp32_batch_3")
+def fixture_test_tfrecord_fp32_batch_3(
+ tmp_path_factory: pytest.TempPathFactory,
+) -> Generator[Path, None, None]:
+ """Create tfrecord (same as test_tfrecord_fp32) but with batch size 3."""
+
+ def random_data() -> np.ndarray:
+ return np.random.rand(3, 28, 28, 1).astype(np.float32)
+
+ yield from create_tfrecord(tmp_path_factory, random_data)
+
+
+def test_tflite_model_call(
+ test_tflite_model_fp32: Path, test_tfrecord_fp32_batch_3: Path
+) -> None:
+ """Test inference function of class TFLiteModel."""
+ model = TFLiteModel(test_tflite_model_fp32, batch_size=2)
+ data = numpytf_read(test_tfrecord_fp32_batch_3)
+ for named_input in data.as_numpy_iterator():
+ res = model(named_input)
+ assert res
diff --git a/tests/test_nn_tensorflow_tflite_graph.py b/tests/test_nn_tensorflow_tflite_graph.py
index cd1fad6..3512cdd 100644
--- a/tests/test_nn_tensorflow_tflite_graph.py
+++ b/tests/test_nn_tensorflow_tflite_graph.py
@@ -1,15 +1,22 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for the tflite_graph module."""
import json
from pathlib import Path
+import pytest
+import tensorflow as tf
+from tensorflow.lite.python.schema_py_generated import ModelT
+
+from mlia.nn.tensorflow.tflite_graph import load_fb
from mlia.nn.tensorflow.tflite_graph import Op
from mlia.nn.tensorflow.tflite_graph import parse_subgraphs
+from mlia.nn.tensorflow.tflite_graph import save_fb
from mlia.nn.tensorflow.tflite_graph import TensorInfo
from mlia.nn.tensorflow.tflite_graph import TFL_ACTIVATION_FUNCTION
from mlia.nn.tensorflow.tflite_graph import TFL_OP
from mlia.nn.tensorflow.tflite_graph import TFL_TYPE
+from tests.utils.rewrite import models_are_equal
def test_tensor_info() -> None:
@@ -79,3 +86,24 @@ def test_parse_subgraphs(test_tflite_model: Path) -> None:
assert TFL_OP[oper.type] in TFL_OP
assert len(oper.inputs) > 0
assert len(oper.outputs) > 0
+
+
+def test_load_save(test_tflite_model: Path, tmp_path: Path) -> None:
+ """Test the load/save functions for TensorFlow Lite models."""
+ with pytest.raises(FileNotFoundError):
+ load_fb("THIS_IS_NOT_A_REAL_FILE")
+
+ model = load_fb(test_tflite_model)
+ assert isinstance(model, ModelT)
+ assert model.subgraphs
+
+ output_file = tmp_path / "test.tflite"
+ assert not output_file.is_file()
+ save_fb(model, output_file)
+ assert output_file.is_file()
+
+ model_copy = load_fb(str(output_file))
+ assert models_are_equal(model, model_copy)
+
+ # Double check that the TensorFlow Lite Interpreter can still load the file.
+ tf.lite.Interpreter(model_path=str(output_file))
diff --git a/tests/utils/rewrite.py b/tests/utils/rewrite.py
index 4264b4b..739bb11 100644
--- a/tests/utils/rewrite.py
+++ b/tests/utils/rewrite.py
@@ -3,8 +3,12 @@
"""Common test utils for the rewrite tests."""
from __future__ import annotations
+from typing import Any
+
from tensorflow.lite.python.schema_py_generated import ModelT
+from mlia.nn.rewrite.core.train import TrainingParameters
+
def models_are_equal(model1: ModelT, model2: ModelT) -> bool:
"""Check that the two models are equal."""
@@ -25,3 +29,17 @@ def models_are_equal(model1: ModelT, model2: ModelT) -> bool:
return False # Tensor from graph1 not found in other graph.")
return True
+
+
+class TestTrainingParameters(
+ TrainingParameters
+): # pylint: disable=too-few-public-methods
+ """
+ TrainingParameter class for rewrites with different default values.
+
+ To speed things up for the unit tests.
+ """
+
+ def __init__(self, *args: Any, steps: int = 32, **kwargs: Any) -> None:
+ """Initialize TrainingParameters with different defaults."""
+ super().__init__(*args, steps=steps, **kwargs) # type: ignore