about | summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author: Benjamin Klimczak <benjamin.klimczak@arm.com> 2023-07-19 16:35:57 +0100
committer: Benjamin Klimczak <benjamin.klimczak@arm.com> 2023-10-11 16:06:17 +0100
commit: 3cd84481fa25e64c29e57396d4bf32d7a3ca490a (patch)
tree: ad81fb520a965bd3a3c7c983833b7cd48f9b8dea
parent: f3e6597dd50ec70f043d692b773f2d9fd31519ae (diff)
download: mlia-3cd84481fa25e64c29e57396d4bf32d7a3ca490a.tar.gz
Bug-fixes and re-factoring for the rewrite module
- Fix input shape of rewrite replacement:
  During and after training of the replacement model for a rewrite the
  Keras model is converted and saved in TensorFlow Lite format. If the
  input shape does not match the teacher model exactly, e.g. if the
  batch size is undefined, the TFLiteConverter adds extra operators
  during conversion.
- Fix rewritten model output
- Save the model output with the rewritten operator in the output dir
- Log MAE and NRMSE of the rewrite
- Remove 'verbose' flag from rewrite module and rely on the logging
  mechanism to control verbose output.
- Re-factor utility classes for rewrites
- Merge the two TFLiteModel classes
- Move functionality to load/save TensorFlow Lite flatbuffers to
  nn/tensorflow/tflite_graph
- Fix issue with unknown shape in datasets
  After upgrading to TensorFlow 2.12 the unknown shape of the
  TFRecordDataset is causing problems when training the replacement
  models for rewrites. By explicitly setting the right shape of the
  tensors we can work around the issue.
- Adapt default parameters for rewrites.
  The training steps especially had to be increased significantly to be
  effective.

Resolves: MLIA-895, MLIA-907, MLIA-946, MLIA-979
Signed-off-by: Benjamin Klimczak <benjamin.klimczak@arm.com>
Change-Id: I887ad165aed0f2c6e5a0041f64cec5e6c5ab5c5c
-rw-r--r-- src/mlia/nn/rewrite/core/graph_edit/cut.py 8
-rw-r--r-- src/mlia/nn/rewrite/core/graph_edit/join.py 10
-rw-r--r-- src/mlia/nn/rewrite/core/rewrite.py 77
-rw-r--r-- src/mlia/nn/rewrite/core/train.py 306
-rw-r--r-- src/mlia/nn/rewrite/core/utils/numpy_tfrecord.py 138
-rw-r--r-- src/mlia/nn/rewrite/core/utils/parallel.py 4
-rw-r--r-- src/mlia/nn/rewrite/core/utils/utils.py 32
-rw-r--r-- src/mlia/nn/rewrite/library/fc_layer.py 14
-rw-r--r-- src/mlia/nn/select.py 15
-rw-r--r-- src/mlia/nn/tensorflow/config.py 100
-rw-r--r-- src/mlia/nn/tensorflow/tflite_graph.py 27
-rw-r--r-- src/mlia/target/ethos_u/data_collection.py 8
-rw-r--r-- tests/conftest.py 39
-rw-r--r-- tests/test_nn_rewrite_core_graph_edit_join.py 6
-rw-r--r-- tests/test_nn_rewrite_core_graph_edit_record.py 4
-rw-r--r-- tests/test_nn_rewrite_core_rewrite.py 6
-rw-r--r-- tests/test_nn_rewrite_core_train.py 76
-rw-r--r-- tests/test_nn_rewrite_core_utils.py 33
-rw-r--r-- tests/test_nn_rewrite_core_utils_numpy_tfrecord.py 25
-rw-r--r-- tests/test_nn_tensorflow_config.py 40
-rw-r--r-- tests/test_nn_tensorflow_tflite_graph.py 30
-rw-r--r-- tests/utils/rewrite.py 18
22 files changed, 591 insertions, 425 deletions
diff --git a/src/mlia/nn/rewrite/core/graph_edit/cut.py b/src/mlia/nn/rewrite/core/graph_edit/cut.py
index 2707eb1..13a5268 100644
--- a/src/mlia/nn/rewrite/core/graph_edit/cut.py
+++ b/src/mlia/nn/rewrite/core/graph_edit/cut.py
@@ -9,8 +9,8 @@ import tensorflow as tf
from tensorflow.lite.python.schema_py_generated import ModelT
from tensorflow.lite.python.schema_py_generated import SubGraphT
-from mlia.nn.rewrite.core.utils.utils import load
-from mlia.nn.rewrite.core.utils.utils import save
+from mlia.nn.tensorflow.tflite_graph import load_fb
+from mlia.nn.tensorflow.tflite_graph import save_fb
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
@@ -138,8 +138,8 @@ def cut_model(
output_file: str,
) -> None:
"""Cut subgraphs and simplify a given model."""
- model = load(model_file)
+ model = load_fb(model_file)
subgraph = model.subgraphs[subgraph_index]
cut_subgraph(subgraph, input_names, output_names)
simplify(model)
- save(model, output_file)
+ save_fb(model, output_file)
diff --git a/src/mlia/nn/rewrite/core/graph_edit/join.py b/src/mlia/nn/rewrite/core/graph_edit/join.py
index 2530ec8..70109d8 100644
--- a/src/mlia/nn/rewrite/core/graph_edit/join.py
+++ b/src/mlia/nn/rewrite/core/graph_edit/join.py
@@ -11,8 +11,8 @@ from tensorflow.lite.python.schema_py_generated import ModelT
from tensorflow.lite.python.schema_py_generated import OperatorCodeT
from tensorflow.lite.python.schema_py_generated import SubGraphT
-from mlia.nn.rewrite.core.utils.utils import load
-from mlia.nn.rewrite.core.utils.utils import save
+from mlia.nn.tensorflow.tflite_graph import load_fb
+from mlia.nn.tensorflow.tflite_graph import save_fb
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
@@ -26,12 +26,12 @@ def join_models(
subgraph_dst: int = 0,
) -> None:
"""Join two models and save the result into a given model file path."""
- src_model = load(input_src)
- dst_model = load(input_dst)
+ src_model = load_fb(input_src)
+ dst_model = load_fb(input_dst)
src_subgraph = src_model.subgraphs[subgraph_src]
dst_subgraph = dst_model.subgraphs[subgraph_dst]
join_subgraphs(src_model, src_subgraph, dst_model, dst_subgraph)
- save(dst_model, output_file)
+ save_fb(dst_model, output_file)
def join_subgraphs(
diff --git a/src/mlia/nn/rewrite/core/rewrite.py b/src/mlia/nn/rewrite/core/rewrite.py
index 0d182df..6b27984 100644
--- a/src/mlia/nn/rewrite/core/rewrite.py
+++ b/src/mlia/nn/rewrite/core/rewrite.py
@@ -4,6 +4,7 @@
from __future__ import annotations
import importlib
+import logging
import tempfile
from dataclasses import dataclass
from pathlib import Path
@@ -12,13 +13,14 @@ from typing import Any
from mlia.core.errors import ConfigurationError
from mlia.nn.common import Optimizer
from mlia.nn.common import OptimizerConfiguration
-from mlia.nn.rewrite.core.train import eval_in_dir
-from mlia.nn.rewrite.core.train import join_in_dir
from mlia.nn.rewrite.core.train import train
-from mlia.nn.rewrite.core.train import train_in_dir
+from mlia.nn.rewrite.core.train import TrainingParameters
from mlia.nn.tensorflow.config import TFLiteModel
+logger = logging.getLogger(__name__)
+
+
@dataclass
class RewriteConfiguration(OptimizerConfiguration):
"""Rewrite configuration."""
@@ -26,6 +28,7 @@ class RewriteConfiguration(OptimizerConfiguration):
optimization_target: str
layers_to_optimize: list[str] | None = None
dataset: Path | None = None
+ train_params: TrainingParameters = TrainingParameters()
def __str__(self) -> str:
"""Return string representation of the configuration."""
@@ -40,8 +43,8 @@ class Rewriter(Optimizer):
):
"""Init Rewriter instance."""
self.model = TFLiteModel(tflite_model_path)
+ self.model_path = tflite_model_path
self.optimizer_configuration = optimizer_configuration
- self.train_dir = ""
def apply_optimization(self) -> None:
"""Apply the rewrite flow."""
@@ -61,50 +64,36 @@ class Rewriter(Optimizer):
replace_fn = get_function(replace_function)
- augmentation_preset = (None, None)
use_unmodified_model = True
tflite_model = self.model.model_path
tfrecord = str(self.optimizer_configuration.dataset)
- with tempfile.TemporaryDirectory() as tmp_dir:
- tmp_output = Path(tmp_dir, "output.tflite")
-
- if self.train_dir:
- tmp_new = Path(tmp_dir, "new.tflite")
- new_part = train_in_dir(
- train_dir=self.train_dir,
- baseline_dir=None,
- output_filename=tmp_new,
- replace_fn=replace_fn,
- augmentations=augmentation_preset,
- steps=32,
- learning_rate=1e-3,
- batch_size=1,
- verbose=True,
- show_progress=True,
- )
- eval_in_dir(self.train_dir, new_part[0])
- join_in_dir(self.train_dir, new_part[0], str(tmp_output))
- else:
- if not self.optimizer_configuration.layers_to_optimize:
- raise ConfigurationError(
- "Input and output tensor names need to be set for rewrite."
- )
- train(
- source_model=tflite_model,
- unmodified_model=tflite_model if use_unmodified_model else None,
- output_model=str(tmp_output),
- input_tfrec=str(tfrecord),
- replace_fn=replace_fn,
- input_tensors=[self.optimizer_configuration.layers_to_optimize[0]],
- output_tensors=[self.optimizer_configuration.layers_to_optimize[1]],
- augment=augmentation_preset,
- steps=32,
- learning_rate=1e-3,
- batch_size=1,
- verbose=True,
- show_progress=True,
- )
+ tmp_dir = tempfile.mkdtemp()
+ tmp_output = Path(tmp_dir, "output.tflite")
+
+ if not self.optimizer_configuration.layers_to_optimize:
+ raise ConfigurationError(
+ "Input and output tensor names need to be set for rewrite."
+ )
+ result = train(
+ source_model=tflite_model,
+ unmodified_model=tflite_model if use_unmodified_model else None,
+ output_model=str(tmp_output),
+ input_tfrec=str(tfrecord),
+ replace_fn=replace_fn,
+ input_tensors=[self.optimizer_configuration.layers_to_optimize[0]],
+ output_tensors=[self.optimizer_configuration.layers_to_optimize[1]],
+ train_params=self.optimizer_configuration.train_params,
+ )
+
+ self.model = TFLiteModel(tmp_output)
+
+ if result:
+ stats_as_str = ", ".join(str(stats) for stats in result)
+ logger.info(
+ "The MAE and NRMSE between original and replacement [%s]",
+ stats_as_str,
+ )
def get_model(self) -> TFLiteModel:
"""Return optimized model."""
diff --git a/src/mlia/nn/rewrite/core/train.py b/src/mlia/nn/rewrite/core/train.py
index c8497a4..42bf653 100644
--- a/src/mlia/nn/rewrite/core/train.py
+++ b/src/mlia/nn/rewrite/core/train.py
@@ -1,8 +1,8 @@
# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Sequential trainer."""
-# pylint: disable=too-many-arguments, too-many-instance-attributes,
-# pylint: disable=too-many-locals, too-many-branches, too-many-statements
+# pylint: disable=too-many-locals
+# pylint: disable=too-many-statements
from __future__ import annotations
import logging
@@ -10,10 +10,13 @@ import math
import os
import tempfile
from collections import defaultdict
+from contextlib import contextmanager
+from dataclasses import dataclass
from pathlib import Path
from typing import Any
from typing import Callable
from typing import cast
+from typing import Generator as GeneratorType
from typing import get_args
from typing import Literal
@@ -27,10 +30,10 @@ from mlia.nn.rewrite.core.graph_edit.join import join_models
from mlia.nn.rewrite.core.graph_edit.record import record_model
from mlia.nn.rewrite.core.utils.numpy_tfrecord import numpytf_count
from mlia.nn.rewrite.core.utils.numpy_tfrecord import numpytf_read
-from mlia.nn.rewrite.core.utils.numpy_tfrecord import TFLiteModel
from mlia.nn.rewrite.core.utils.parallel import ParallelTFLiteModel
-from mlia.nn.rewrite.core.utils.utils import load
-from mlia.nn.rewrite.core.utils.utils import save
+from mlia.nn.tensorflow.config import TFLiteModel
+from mlia.nn.tensorflow.tflite_graph import load_fb
+from mlia.nn.tensorflow.tflite_graph import save_fb
from mlia.utils.logging import log_action
@@ -38,7 +41,7 @@ os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
logger = logging.getLogger(__name__)
-augmentation_presets = {
+AUGMENTATION_PRESETS = {
"none": (None, None),
"gaussian": (None, 1.0),
"mixup": (1.0, None),
@@ -51,6 +54,21 @@ LearningRateSchedule = Literal["cosine", "late", "constant"]
LEARNING_RATE_SCHEDULES = get_args(LearningRateSchedule)
+@dataclass
+class TrainingParameters:
+ """Define default parameters for the training."""
+
+ augmentations: tuple[float | None, float | None] = AUGMENTATION_PRESETS["gaussian"]
+ batch_size: int = 32
+ steps: int = 48000
+ learning_rate: float = 1e-3
+ learning_rate_schedule: LearningRateSchedule = "cosine"
+ num_procs: int = 1
+ num_threads: int = 0
+ show_progress: bool = True
+ checkpoint_at: list | None = None
+
+
def train(
source_model: str,
unmodified_model: Any,
@@ -59,16 +77,7 @@ def train(
replace_fn: Callable,
input_tensors: list,
output_tensors: list,
- augment: tuple[float | None, float | None],
- steps: int,
- learning_rate: float,
- batch_size: int,
- verbose: bool,
- show_progress: bool,
- learning_rate_schedule: LearningRateSchedule = "cosine",
- checkpoint_at: list | None = None,
- num_procs: int = 1,
- num_threads: int = 0,
+ train_params: TrainingParameters = TrainingParameters(),
) -> Any:
"""Extract and train a model, and return the results."""
if unmodified_model:
@@ -95,29 +104,27 @@ def train(
input_tfrec,
input_tensors,
output_tensors,
- num_procs=num_procs,
- num_threads=num_threads,
+ num_procs=train_params.num_procs,
+ num_threads=train_params.num_threads,
)
tflite_filenames = train_in_dir(
- train_dir,
- unmodified_model_dir_path,
- Path(train_dir, "new.tflite"),
- replace_fn,
- augment,
- steps,
- learning_rate,
- batch_size,
- checkpoint_at=checkpoint_at,
- verbose=verbose,
- show_progress=show_progress,
- num_procs=num_procs,
- num_threads=num_threads,
- schedule=learning_rate_schedule,
+ train_dir=train_dir,
+ baseline_dir=unmodified_model_dir_path,
+ output_filename=Path(train_dir, "new.tflite"),
+ replace_fn=replace_fn,
+ train_params=train_params,
)
for i, filename in enumerate(tflite_filenames):
- results.append(eval_in_dir(train_dir, filename, num_procs, num_threads))
+ results.append(
+ eval_in_dir(
+ train_dir,
+ filename,
+ train_params.num_procs,
+ train_params.num_threads,
+ )
+ )
if output_model:
if i + 1 < len(tflite_filenames):
@@ -133,7 +140,7 @@ def train(
cast(tempfile.TemporaryDirectory, unmodified_model_dir).cleanup()
return (
- results if checkpoint_at else results[0]
+ results if train_params.checkpoint_at else results[0]
) # only return a list if multiple checkpoints are asked for
@@ -176,46 +183,24 @@ def join_in_dir(model_dir: str, new_part: str, output_model: str) -> None:
join_models(Path(model_dir, "start.tflite"), new_end, output_model)
-def train_in_dir(
- train_dir: str,
- baseline_dir: Any,
- output_filename: Path,
- replace_fn: Callable,
- augmentations: tuple[float | None, float | None],
- steps: int,
- learning_rate: float = 1e-3,
- batch_size: int = 32,
- checkpoint_at: list | None = None,
- schedule: str = "cosine",
- verbose: bool = False,
- show_progress: bool = False,
- num_procs: int = 0,
- num_threads: int = 1,
-) -> list:
- """Train a replacement for replace.tflite using the input.tfrec \
- and output.tfrec in train_dir.
-
- If baseline_dir is provided, train the replacement to match baseline
- outputs for train_dir inputs. Result saved as new.tflite in train_dir.
- """
- teacher_dir = baseline_dir if baseline_dir else train_dir
- teacher = ParallelTFLiteModel(
- f"{teacher_dir}/replace.tflite", num_procs, num_threads, batch_size=batch_size
- )
- replace = TFLiteModel(f"{train_dir}/replace.tflite")
+def _get_io_tensors(model: TFLiteModel) -> tuple[str, str]:
assert (
- len(teacher.input_tensors()) == 1
+ len(model.input_tensors()) == 1
), f"Can only train replacements with a single input tensor right now, \
- found {teacher.input_tensors()}"
+ found {model.input_tensors()}"
assert (
- len(teacher.output_tensors()) == 1
+ len(model.output_tensors()) == 1
), f"Can only train replacements with a single output tensor right now, \
- found {teacher.output_tensors()}"
+ found {model.output_tensors()}"
+
+ input_name = model.input_tensors()[0]
+ output_name = model.output_tensors()[0]
+ return (input_name, output_name)
- input_name = teacher.input_tensors()[0]
- output_name = teacher.output_tensors()[0]
+def _check_model_compatibility(teacher: TFLiteModel, replace: TFLiteModel) -> None:
+ """Assert that teacher and replaced sub-graph are compatible."""
assert len(teacher.shape_from_name) == len(
replace.shape_from_name
), f"Baseline and train models must have the same number of inputs and outputs. \
@@ -230,10 +215,37 @@ def train_in_dir(
subgraph being replaced. Teacher: {teacher.shape_from_name}\n \
Train dir: {replace.shape_from_name}"
+
+def set_up_data_pipeline(
+ teacher: TFLiteModel,
+ replace: TFLiteModel,
+ train_dir: str,
+ augmentations: tuple[float | None, float | None],
+ steps: int,
+ batch_size: int = 32,
+) -> tf.data.Dataset:
+ """Create a data pipeline for training of the replacement model."""
+ _check_model_compatibility(teacher, replace)
+
+ input_name, output_name = _get_io_tensors(teacher)
+
input_filename = Path(train_dir, "input.tfrec")
total = numpytf_count(str(input_filename))
dict_inputs = numpytf_read(str(input_filename))
+
inputs = dict_inputs.map(lambda d: tf.squeeze(d[input_name], axis=0))
+
+ steps_per_epoch = math.ceil(total / batch_size)
+ epochs = int(math.ceil(steps / steps_per_epoch))
+ logger.info(
+ "Training on %d items for %d steps (%d epochs with batch size %d)",
+ total,
+ epochs * steps_per_epoch,
+ epochs,
+ batch_size,
+ )
+
+ teacher_dir = Path(teacher.model_path).parent
if any(augmentations):
# Map the teacher inputs here because the augmentation stage passes these
# through a TFLite model to get the outputs
@@ -245,17 +257,6 @@ def train_in_dir(
lambda d: tf.squeeze(d[output_name], axis=0)
)
- steps_per_epoch = math.ceil(total / batch_size)
- epochs = int(math.ceil(steps / steps_per_epoch))
- if verbose:
- logger.info(
- "Training on %d items for %d steps (%d epochs with batch size %d)",
- total,
- epochs * steps_per_epoch,
- epochs,
- batch_size,
- )
-
dataset = tf.data.Dataset.zip((inputs, teacher_outputs))
if epochs > 1:
dataset = dataset.cache()
@@ -268,10 +269,9 @@ def train_in_dir(
train: Any, teach: Any # pylint: disable=redefined-outer-name
) -> tuple:
"""Return results of train and teach based on augmentations."""
- return (
- augment_train({input_name: train})[input_name],
- teacher(augment_teacher({input_name: teach}))[output_name],
- )
+ augmented_train = augment_train({input_name: train})[input_name]
+ augmented_teach = teacher(augment_teacher({input_name: teach}))[output_name]
+ return (augmented_train, augmented_teach)
dataset = dataset.map(
lambda augment_train, augment_teach: tf.py_function(
@@ -281,18 +281,67 @@ def train_in_dir(
)
)
+ # Restore data shapes of the dataset as they are set to unknown per default
+ # and get lost during augmentation with tf.py_function.
+ shape_in, shape_out = (
+ teacher.shape_from_name[name].tolist() for name in (input_name, output_name)
+ )
+ for shape in (shape_in, shape_out):
+ shape[0] = None # set dynamic batch size
+
+ def restore_shapes(input_: Any, output: Any) -> tuple[Any, Any]:
+ input_.set_shape(shape_in)
+ output.set_shape(shape_out)
+ return input_, output
+
+ dataset = dataset.map(restore_shapes)
dataset = dataset.prefetch(tf.data.AUTOTUNE)
+ return dataset
+
+
+def train_in_dir(
+ train_dir: str,
+ baseline_dir: Any,
+ output_filename: Path,
+ replace_fn: Callable,
+ train_params: TrainingParameters = TrainingParameters(),
+) -> list[str]:
+ """Train a replacement for replace.tflite using the input.tfrec \
+ and output.tfrec in train_dir.
+
+ If baseline_dir is provided, train the replacement to match baseline
+ outputs for train_dir inputs. Result saved as new.tflite in train_dir.
+ """
+ teacher_dir = baseline_dir if baseline_dir else train_dir
+ teacher = ParallelTFLiteModel(
+ f"{teacher_dir}/replace.tflite",
+ train_params.num_procs,
+ train_params.num_threads,
+ batch_size=train_params.batch_size,
+ )
+ replace = TFLiteModel(f"{train_dir}/replace.tflite")
+
+ input_name, output_name = _get_io_tensors(teacher)
+
+ dataset = set_up_data_pipeline(
+ teacher,
+ replace,
+ train_dir,
+ augmentations=train_params.augmentations,
+ steps=train_params.steps,
+ batch_size=train_params.batch_size,
+ )
input_shape = teacher.shape_from_name[input_name][1:]
output_shape = teacher.shape_from_name[output_name][1:]
+
model = replace_fn(input_shape, output_shape)
- optimizer = tf.keras.optimizers.Nadam(learning_rate=learning_rate)
+ optimizer = tf.keras.optimizers.Nadam(learning_rate=train_params.learning_rate)
loss_fn = tf.keras.losses.MeanSquaredError()
- model.compile(optimizer=optimizer, loss=loss_fn)
+ model.compile(optimizer=optimizer, loss=loss_fn, metrics=["mae"])
- if verbose:
- model.summary()
+ logger.info(model.summary())
steps_so_far = 0
@@ -302,7 +351,9 @@ def train_in_dir(
"""Cosine decay from learning rate at start of the run to zero at the end."""
current_step = epoch_step + steps_so_far
cd_learning_rate = (
- learning_rate * (math.cos(math.pi * current_step / steps) + 1) / 2.0
+ train_params.learning_rate
+ * (math.cos(math.pi * current_step / train_params.steps) + 1)
+ / 2.0
)
tf.keras.backend.set_value(optimizer.learning_rate, cd_learning_rate)
@@ -311,28 +362,29 @@ def train_in_dir(
) -> None:
"""Constant until the last 20% of the run, then linear decay to zero."""
current_step = epoch_step + steps_so_far
- steps_remaining = steps - current_step
- decay_length = steps // 5
+ steps_remaining = train_params.steps - current_step
+ decay_length = train_params.steps // 5
decay_fraction = min(steps_remaining, decay_length) / decay_length
- ld_learning_rate = learning_rate * decay_fraction
+ ld_learning_rate = train_params.learning_rate * decay_fraction
tf.keras.backend.set_value(optimizer.learning_rate, ld_learning_rate)
- if schedule == "cosine":
+ assert train_params.learning_rate_schedule in LEARNING_RATE_SCHEDULES, (
+ f'Learning rate schedule "{train_params.learning_rate_schedule}" '
+ f"not implemented - expected one of {LEARNING_RATE_SCHEDULES}."
+ )
+ if train_params.learning_rate_schedule == "cosine":
callbacks = [tf.keras.callbacks.LambdaCallback(on_batch_begin=cosine_decay)]
- elif schedule == "late":
+ elif train_params.learning_rate_schedule == "late":
callbacks = [tf.keras.callbacks.LambdaCallback(on_batch_begin=late_decay)]
- elif schedule == "constant":
+ elif train_params.learning_rate_schedule == "constant":
callbacks = []
- else:
- assert schedule not in LEARNING_RATE_SCHEDULES
- raise ValueError(
- f'Learning rate schedule "{schedule}" not implemented - '
- f"expected one of {LEARNING_RATE_SCHEDULES}."
- )
output_filenames = []
- checkpoints = (checkpoint_at if checkpoint_at else []) + [steps]
- while steps_so_far < steps:
+ checkpoints = (train_params.checkpoint_at if train_params.checkpoint_at else []) + [
+ train_params.steps
+ ]
+
+ while steps_so_far < train_params.steps:
steps_to_train = checkpoints.pop(0) - steps_so_far
lr_start = optimizer.learning_rate.numpy()
model.fit(
@@ -340,7 +392,7 @@ def train_in_dir(
epochs=1,
steps_per_epoch=steps_to_train,
callbacks=callbacks,
- verbose=show_progress,
+ verbose=train_params.show_progress,
)
steps_so_far += steps_to_train
logger.info(
@@ -350,12 +402,14 @@ def train_in_dir(
steps_to_train,
)
- if steps_so_far < steps:
+ if steps_so_far < train_params.steps:
filename, ext = Path(output_filename).parts[1:]
checkpoint_filename = filename + (f"_@{steps_so_far}") + ext
else:
checkpoint_filename = str(output_filename)
- with log_action(f"{steps_so_far}/{steps}: Saved as {checkpoint_filename}"):
+ with log_action(
+ f"{steps_so_far}/{train_params.steps}: Saved as {checkpoint_filename}"
+ ):
save_as_tflite(
model,
checkpoint_filename,
@@ -379,14 +433,30 @@ def save_as_tflite(
output_shape: list,
) -> None:
"""Save Keras model as TFLite file."""
- converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
+
+ @contextmanager
+ def fixed_input(keras_model: tf.keras.Model, tmp_shape: list) -> GeneratorType:
+ """Fix the input shape of the Keras model temporarily.
+
+ This avoids artifacts during conversion to TensorFlow Lite.
+ """
+ orig_shape = keras_model.input.shape
+ keras_model.input.set_shape(tf.TensorShape(tmp_shape))
+ try:
+ yield keras_model
+ finally:
+ # Restore original shape to not interfere with further training
+ keras_model.input.set_shape(orig_shape)
+
+ with fixed_input(keras_model, input_shape) as fixed_model:
+ converter = tf.lite.TFLiteConverter.from_keras_model(fixed_model)
tflite_model = converter.convert()
with open(filename, "wb") as file:
file.write(tflite_model)
# Now fix the shapes and names to match those we expect
- flatbuffer = load(filename)
+ flatbuffer = load_fb(filename)
i = flatbuffer.subgraphs[0].inputs[0]
flatbuffer.subgraphs[0].tensors[i].shape = np.array(input_shape, dtype=np.int32)
flatbuffer.subgraphs[0].tensors[i].name = input_name.encode("utf-8")
@@ -395,11 +465,11 @@ def save_as_tflite(
output_shape, dtype=np.int32
)
flatbuffer.subgraphs[0].tensors[output].name = output_name.encode("utf-8")
- save(flatbuffer, filename)
+ save_fb(flatbuffer, filename)
def augment_fn_twins(
- inputs: dict, augmentations: tuple[float | None, float | None]
+ inputs: tf.data.Dataset, augmentations: tuple[float | None, float | None]
) -> Any:
"""Return a pair of twinned augmentation functions with the same sequence \
of random numbers."""
@@ -415,6 +485,11 @@ def augment_fn(
inputs: Any, augmentations: tuple[float | None, float | None], rng: Generator
) -> Any:
"""Augmentation module."""
+ assert len(augmentations) == 2, (
+ f"Unexpected number of augmentation parameters: {len(augmentations)} "
+ "(must be 2)"
+ )
+
mixup_strength, gaussian_strength = augmentations
augments = []
@@ -449,17 +524,16 @@ def augment_fn(
augments.append(gaussian_strength_augment)
- if len(augments) == 0: # pylint: disable=no-else-return
+ if len(augments) == 0:
return lambda x: x
- elif len(augments) == 1:
+ if len(augments) == 1:
return augments[0]
- elif len(augments) == 2:
+ if len(augments) == 2:
return lambda x: augments[1](augments[0](x))
- else:
- assert (
- False
- ), f"Unexpected number of augmentation \
- functions ({len(augments)})"
+
+ raise RuntimeError(
+ "Unexpected number of augmentation functions (must be <=2): " f"{len(augments)}"
+ )
def mixup(rng: Generator, batch: Any, beta_range: tuple = (0.0, 1.0)) -> Any:
diff --git a/src/mlia/nn/rewrite/core/utils/numpy_tfrecord.py b/src/mlia/nn/rewrite/core/utils/numpy_tfrecord.py
index 9229810..38ac1ed 100644
--- a/src/mlia/nn/rewrite/core/utils/numpy_tfrecord.py
+++ b/src/mlia/nn/rewrite/core/utils/numpy_tfrecord.py
@@ -6,55 +6,56 @@ from __future__ import annotations
import json
import os
import random
-import tempfile
-from collections import defaultdict
+from functools import lru_cache
from pathlib import Path
from typing import Any
from typing import Callable
-import numpy as np
import tensorflow as tf
-from tensorflow.lite.python import interpreter as interpreter_wrapper
-from mlia.nn.rewrite.core.utils.utils import load
-from mlia.nn.rewrite.core.utils.utils import save
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
-def make_decode_fn(filename: str) -> Callable:
- """Make decode filename."""
+def decode_fn(record_bytes: Any, type_map: dict) -> dict:
+ """Decode the given bytes into a name-tensor dict assuming the given type."""
+ parse_dict = {
+ name: tf.io.FixedLenFeature([], tf.string) for name in type_map.keys()
+ }
+ example = tf.io.parse_single_example(record_bytes, parse_dict)
+ features = {
+ n: tf.io.parse_tensor(example[n], tf.as_dtype(t)) for n, t in type_map.items()
+ }
+ return features
- def decode_fn(record_bytes: Any, type_map: dict) -> dict:
- parse_dict = {
- name: tf.io.FixedLenFeature([], tf.string) for name in type_map.keys()
- }
- example = tf.io.parse_single_example(record_bytes, parse_dict)
- features = {
- n: tf.io.parse_tensor(example[n], tf.as_dtype(t))
- for n, t in type_map.items()
- }
- return features
+def make_decode_fn(filename: str, model_filename: str | Path | None = None) -> Callable:
+ """Make decode filename."""
meta_filename = filename + ".meta"
- with open(meta_filename, encoding="utf-8") as file:
- type_map = json.load(file)["type_map"]
+ try:
+ with open(meta_filename, encoding="utf-8") as file:
+ type_map = json.load(file)["type_map"]
return lambda record_bytes: decode_fn(record_bytes, type_map)
def numpytf_read(filename: str | Path) -> Any:
"""Read TFRecord dataset."""
- decode_fn = make_decode_fn(str(filename))
+ decode = make_decode_fn(str(filename))
dataset = tf.data.TFRecordDataset(str(filename))
- return dataset.map(decode_fn)
+ return dataset.map(decode)
-def numpytf_count(filename: str | Path) -> Any:
+@lru_cache
+def numpytf_count(filename: str | Path) -> int:
"""Return count from TFRecord file."""
meta_filename = f"{filename}.meta"
- with open(meta_filename, encoding="utf-8") as file:
- return json.load(file)["count"]
+ try:
+ with open(meta_filename, encoding="utf-8") as file:
+ return int(json.load(file)["count"])
+ except FileNotFoundError:
+ raw_dataset = tf.data.TFRecordDataset(filename)
+ return sum(1 for _ in raw_dataset)
class NumpyTFWriter:
@@ -101,93 +102,6 @@ class NumpyTFWriter:
self.writer.close()
-class TFLiteModel:
- """A representation of a TFLite Model."""
-
- def __init__(
- self,
- filename: str,
- batch_size: int | None = None,
- num_threads: int | None = None,
- ) -> None:
- """Initiate a TFLite Model."""
- if not num_threads:
- num_threads = None
- if not batch_size:
- self.interpreter = interpreter_wrapper.Interpreter(
- model_path=filename, num_threads=num_threads
- )
- else: # if a batch size is specified, modify the TFLite model to use this size
- with tempfile.TemporaryDirectory() as tmp:
- flatbuffer = load(filename)
- for subgraph in flatbuffer.subgraphs:
- for tensor in list(subgraph.inputs) + list(subgraph.outputs):
- subgraph.tensors[tensor].shape = np.array(
- [batch_size] + list(subgraph.tensors[tensor].shape[1:]),
- dtype=np.int32,
- )
- tempname = os.path.join(tmp, "rewrite_tmp.tflite")
- save(flatbuffer, tempname)
- self.interpreter = interpreter_wrapper.Interpreter(
- model_path=tempname, num_threads=num_threads
- )
-
- try:
- self.interpreter.allocate_tensors()
- except RuntimeError:
- self.interpreter = interpreter_wrapper.Interpreter(
- model_path=filename, num_threads=num_threads
- )
- self.interpreter.allocate_tensors()
-
- # Get input and output tensors.
- self.input_details = self.interpreter.get_input_details()
- self.output_details = self.interpreter.get_output_details()
- details = list(self.input_details) + list(self.output_details)
- self.handle_from_name = {d["name"]: d["index"] for d in details}
- self.shape_from_name = {d["name"]: d["shape"] for d in details}
- self.batch_size = next(iter(self.shape_from_name.values()))[0]
-
- def __call__(self, named_input: dict) -> dict:
- """Execute the model on one or a batch of named inputs \
- (a dict of name: numpy array)."""
- input_len = next(iter(named_input.values())).shape[0]
- full_steps = input_len // self.batch_size
- remainder = input_len % self.batch_size
-
- named_ys = defaultdict(list)
- for i in range(full_steps):
- for name, x_batch in named_input.items():
- x_tensor = x_batch[i : i + self.batch_size] # noqa: E203
- self.interpreter.set_tensor(self.handle_from_name[name], x_tensor)
- self.interpreter.invoke()
- for output_detail in self.output_details:
- named_ys[output_detail["name"]].append(
- self.interpreter.get_tensor(output_detail["index"])
- )
- if remainder:
- for name, x_batch in named_input.items():
- x_tensor = np.zeros( # pylint: disable=invalid-name
- self.shape_from_name[name]
- ).astype(x_batch.dtype)
- x_tensor[:remainder] = x_batch[-remainder:]
- self.interpreter.set_tensor(self.handle_from_name[name], x_tensor)
- self.interpreter.invoke()
- for output_detail in self.output_details:
- named_ys[output_detail["name"]].append(
- self.interpreter.get_tensor(output_detail["index"])[:remainder]
- )
- return {k: np.concatenate(v) for k, v in named_ys.items()}
-
- def input_tensors(self) -> list:
- """Return name from input details."""
- return [d["name"] for d in self.input_details]
-
- def output_tensors(self) -> list:
- """Return name from output details."""
- return [d["name"] for d in self.output_details]
-
-
def sample_tfrec(input_file: str, k: int, output_file: str) -> None:
"""Count, read and write TFRecord input and output data."""
total = numpytf_count(input_file)
diff --git a/src/mlia/nn/rewrite/core/utils/parallel.py b/src/mlia/nn/rewrite/core/utils/parallel.py
index d930a1e..b7b390d 100644
--- a/src/mlia/nn/rewrite/core/utils/parallel.py
+++ b/src/mlia/nn/rewrite/core/utils/parallel.py
@@ -15,14 +15,14 @@ from typing import Any
import numpy as np
import tensorflow as tf
-from mlia.nn.rewrite.core.utils.numpy_tfrecord import TFLiteModel
+from mlia.nn.tensorflow.config import TFLiteModel
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
logger = logging.getLogger(__name__)
-class ParallelTFLiteModel(TFLiteModel):
+class ParallelTFLiteModel(TFLiteModel): # pylint: disable=abstract-method
"""A parallel version of a TFLiteModel.
num_procs: 0 => detect real cores on system
diff --git a/src/mlia/nn/rewrite/core/utils/utils.py b/src/mlia/nn/rewrite/core/utils/utils.py
deleted file mode 100644
index ddf0cc2..0000000
--- a/src/mlia/nn/rewrite/core/utils/utils.py
+++ /dev/null
@@ -1,32 +0,0 @@
-# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
-# SPDX-License-Identifier: Apache-2.0
-"""Model and file system utilites."""
-from __future__ import annotations
-
-from pathlib import Path
-
-import flatbuffers
-from tensorflow.lite.python.schema_py_generated import Model
-from tensorflow.lite.python.schema_py_generated import ModelT
-
-
-def load(input_tflite_file: str | Path) -> ModelT:
- """Load a flatbuffer model from file."""
- if not Path(input_tflite_file).exists():
- raise FileNotFoundError(f"TFLite file not found at {input_tflite_file}\n")
- with open(input_tflite_file, "rb") as file_handle:
- file_data = bytearray(file_handle.read())
- model_obj = Model.GetRootAsModel(file_data, 0)
- model = ModelT.InitFromObj(model_obj)
- return model
-
-
-def save(model: ModelT, output_tflite_file: str | Path) -> None:
- """Save a flatbuffer model to a given file."""
- builder = flatbuffers.Builder(1024) # Initial size of the buffer, which
- # will grow automatically if needed
- model_offset = model.Pack(builder)
- builder.Finish(model_offset, file_identifier=b"TFL3")
- model_data = builder.Output()
- with open(output_tflite_file, "wb") as out_file:
- out_file.write(model_data)
diff --git a/src/mlia/nn/rewrite/library/fc_layer.py b/src/mlia/nn/rewrite/library/fc_layer.py
index 8704154..2480500 100644
--- a/src/mlia/nn/rewrite/library/fc_layer.py
+++ b/src/mlia/nn/rewrite/library/fc_layer.py
@@ -7,12 +7,12 @@ import tensorflow as tf
def get_keras_model(input_shape: Any, output_shape: Any) -> tf.keras.Model:
- """Generate tflite model for rewrite."""
- input_tensor = tf.keras.layers.Input(
- shape=input_shape, name="MbileNet/avg_pool/AvgPool"
+ """Generate TensorFlow Lite model for rewrite."""
+ model = tf.keras.Sequential(
+ (
+ tf.keras.layers.InputLayer(input_shape=input_shape),
+ tf.keras.layers.Reshape([-1]),
+ tf.keras.layers.Dense(output_shape),
+ )
)
- output_tensor = tf.keras.layers.Dense(output_shape, name="MobileNet/fc1/BiasAdd")(
- input_tensor
- )
- model = tf.keras.Model(input_tensor, output_tensor)
return model
diff --git a/src/mlia/nn/select.py b/src/mlia/nn/select.py
index 5a7f289..983426b 100644
--- a/src/mlia/nn/select.py
+++ b/src/mlia/nn/select.py
@@ -17,6 +17,7 @@ from mlia.nn.common import Optimizer
from mlia.nn.common import OptimizerConfiguration
from mlia.nn.rewrite.core.rewrite import RewriteConfiguration
from mlia.nn.rewrite.core.rewrite import Rewriter
+from mlia.nn.rewrite.core.rewrite import TrainingParameters
from mlia.nn.tensorflow.config import KerasModel
from mlia.nn.tensorflow.config import TFLiteModel
from mlia.nn.tensorflow.optimizations.clustering import Clusterer
@@ -164,6 +165,15 @@ def _get_optimizer(
return MultiStageOptimizer(model, optimizer_configs)
+def _get_rewrite_train_params() -> TrainingParameters:
+ """Get the rewrite TrainingParameters.
+
+ Return the default constructed TrainingParameters() per default, but can be
+ overwritten in the unit tests.
+ """
+ return TrainingParameters()
+
+
def _get_optimizer_configuration(
optimization_type: str,
optimization_target: int | float | str,
@@ -190,7 +200,10 @@ def _get_optimizer_configuration(
if opt_type == "rewrite":
if isinstance(optimization_target, str):
return RewriteConfiguration(
- str(optimization_target), layers_to_optimize, dataset
+ optimization_target=str(optimization_target),
+ layers_to_optimize=layers_to_optimize,
+ dataset=dataset,
+ train_params=_get_rewrite_train_params(),
)
raise ConfigurationError(
diff --git a/src/mlia/nn/tensorflow/config.py b/src/mlia/nn/tensorflow/config.py
index d7d430f..c6a7c88 100644
--- a/src/mlia/nn/tensorflow/config.py
+++ b/src/mlia/nn/tensorflow/config.py
@@ -4,13 +4,16 @@
from __future__ import annotations
import logging
+import tempfile
+from collections import defaultdict
from pathlib import Path
-from typing import cast
-from typing import List
+import numpy as np
import tensorflow as tf
from mlia.core.context import Context
+from mlia.nn.tensorflow.tflite_graph import load_fb
+from mlia.nn.tensorflow.tflite_graph import save_fb
from mlia.nn.tensorflow.utils import convert_to_tflite
from mlia.nn.tensorflow.utils import is_keras_model
from mlia.nn.tensorflow.utils import is_saved_model
@@ -71,10 +74,89 @@ class KerasModel(ModelConfiguration):
class TFLiteModel(ModelConfiguration): # pylint: disable=abstract-method
"""TensorFlow Lite model configuration."""
- def input_details(self) -> list[dict]:
- """Get model's input details."""
- interpreter = tf.lite.Interpreter(model_path=self.model_path)
- return cast(List[dict], interpreter.get_input_details())
+ def __init__(
+ self,
+ model_path: str | Path,
+ batch_size: int | None = None,
+ num_threads: int | None = None,
+ ) -> None:
+ """Initiate a TFLite Model."""
+ super().__init__(model_path)
+ if not num_threads:
+ num_threads = None
+ if not batch_size:
+ self.interpreter = tf.lite.Interpreter(
+ model_path=self.model_path, num_threads=num_threads
+ )
+ else: # if a batch size is specified, modify the TFLite model to use this size
+ with tempfile.TemporaryDirectory() as tmp:
+ flatbuffer = load_fb(self.model_path)
+ for subgraph in flatbuffer.subgraphs:
+ for tensor in list(subgraph.inputs) + list(subgraph.outputs):
+ subgraph.tensors[tensor].shape = np.array(
+ [batch_size] + list(subgraph.tensors[tensor].shape[1:]),
+ dtype=np.int32,
+ )
+ tempname = Path(tmp, "rewrite_tmp.tflite")
+ save_fb(flatbuffer, tempname)
+ self.interpreter = tf.lite.Interpreter(
+ model_path=str(tempname), num_threads=num_threads
+ )
+
+ try:
+ self.interpreter.allocate_tensors()
+ except RuntimeError:
+ self.interpreter = tf.lite.Interpreter(
+ model_path=self.model_path, num_threads=num_threads
+ )
+ self.interpreter.allocate_tensors()
+
+ # Get input and output tensors.
+ self.input_details = self.interpreter.get_input_details()
+ self.output_details = self.interpreter.get_output_details()
+ details = list(self.input_details) + list(self.output_details)
+ self.handle_from_name = {d["name"]: d["index"] for d in details}
+ self.shape_from_name = {d["name"]: d["shape"] for d in details}
+ self.batch_size = next(iter(self.shape_from_name.values()))[0]
+
+ def __call__(self, named_input: dict) -> dict:
+ """Execute the model on one or a batch of named inputs \
+ (a dict of name: numpy array)."""
+ input_len = next(iter(named_input.values())).shape[0]
+ full_steps = input_len // self.batch_size
+ remainder = input_len % self.batch_size
+
+ named_ys = defaultdict(list)
+ for i in range(full_steps):
+ for name, x_batch in named_input.items():
+ x_tensor = x_batch[i : i + self.batch_size] # noqa: E203
+ self.interpreter.set_tensor(self.handle_from_name[name], x_tensor)
+ self.interpreter.invoke()
+ for output_detail in self.output_details:
+ named_ys[output_detail["name"]].append(
+ self.interpreter.get_tensor(output_detail["index"])
+ )
+ if remainder:
+ for name, x_batch in named_input.items():
+ x_tensor = np.zeros( # pylint: disable=invalid-name
+ self.shape_from_name[name]
+ ).astype(x_batch.dtype)
+ x_tensor[:remainder] = x_batch[-remainder:]
+ self.interpreter.set_tensor(self.handle_from_name[name], x_tensor)
+ self.interpreter.invoke()
+ for output_detail in self.output_details:
+ named_ys[output_detail["name"]].append(
+ self.interpreter.get_tensor(output_detail["index"])[:remainder]
+ )
+ return {k: np.concatenate(v) for k, v in named_ys.items()}
+
+ def input_tensors(self) -> list:
+ """Return name from input details."""
+ return [d["name"] for d in self.input_details]
+
+ def output_tensors(self) -> list:
+ """Return name from output details."""
+ return [d["name"] for d in self.output_details]
def convert_to_tflite(
self, tflite_model_path: str | Path, quantized: bool = False
@@ -118,10 +200,10 @@ def get_model(model: str | Path) -> ModelConfiguration:
def get_tflite_model(model: str | Path, ctx: Context) -> TFLiteModel:
"""Convert input model to TensorFlow Lite and returns TFLiteModel object."""
- tflite_model_path = ctx.get_model_path("converted_model.tflite")
- converted_model = get_model(model)
+ dst_model_path = ctx.get_model_path("converted_model.tflite")
+ src_model = get_model(model)
- return converted_model.convert_to_tflite(tflite_model_path, True)
+ return src_model.convert_to_tflite(dst_model_path, quantized=True)
def get_keras_model(model: str | Path, ctx: Context) -> KerasModel:
diff --git a/src/mlia/nn/tensorflow/tflite_graph.py b/src/mlia/nn/tensorflow/tflite_graph.py
index 4f5e85f..7ca9337 100644
--- a/src/mlia/nn/tensorflow/tflite_graph.py
+++ b/src/mlia/nn/tensorflow/tflite_graph.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Utilities for TensorFlow Lite graphs."""
from __future__ import annotations
@@ -10,7 +10,10 @@ from pathlib import Path
from typing import Any
from typing import cast
+import flatbuffers
from tensorflow.lite.python import schema_py_generated as schema_fb
+from tensorflow.lite.python.schema_py_generated import Model
+from tensorflow.lite.python.schema_py_generated import ModelT
from tensorflow.lite.tools import visualize
@@ -137,3 +140,25 @@ def parse_subgraphs(tflite_file: Path) -> list[list[Op]]:
]
return graphs
+
+
+def load_fb(input_tflite_file: str | Path) -> ModelT:
+ """Load a flatbuffer model from file."""
+ if not Path(input_tflite_file).exists():
+ raise FileNotFoundError(f"TFLite file not found at {input_tflite_file}\n")
+ with open(input_tflite_file, "rb") as file_handle:
+ file_data = bytearray(file_handle.read())
+ model_obj = Model.GetRootAsModel(file_data, 0)
+ model = ModelT.InitFromObj(model_obj)
+ return model
+
+
+def save_fb(model: ModelT, output_tflite_file: str | Path) -> None:
+ """Save a flatbuffer model to a given file."""
+ builder = flatbuffers.Builder(1024) # Initial size of the buffer, which
+ # will grow automatically if needed
+ model_offset = model.Pack(builder)
+ builder.Finish(model_offset, file_identifier=b"TFL3")
+ model_data = builder.Output()
+ with open(output_tflite_file, "wb") as out_file:
+ out_file.write(model_data)
diff --git a/src/mlia/target/ethos_u/data_collection.py b/src/mlia/target/ethos_u/data_collection.py
index ba8b0fe..4ea6120 100644
--- a/src/mlia/target/ethos_u/data_collection.py
+++ b/src/mlia/target/ethos_u/data_collection.py
@@ -106,15 +106,14 @@ class OptimizeModel:
self.context = context
self.opt_settings = opt_settings
- def __call__(self, keras_model: KerasModel) -> Any:
+ def __call__(self, model: KerasModel | TFLiteModel) -> Any:
"""Run optimization."""
- optimizer = get_optimizer(keras_model, self.opt_settings)
+ optimizer = get_optimizer(model, self.opt_settings)
opts_as_str = ", ".join(str(opt) for opt in self.opt_settings)
logger.info("Applying model optimizations - [%s]", opts_as_str)
optimizer.apply_optimization()
-
- model = optimizer.get_model()
+ model = optimizer.get_model() # type: ignore
if isinstance(model, Path):
return model
@@ -178,6 +177,7 @@ class EthosUOptimizationPerformance(ContextAwareDataCollector):
self.target,
self.backends,
)
+
original_metrics, *optimized_metrics = estimate_performance(
model, estimator, optimizers # type: ignore
)
diff --git a/tests/conftest.py b/tests/conftest.py
index c42b8cb..bb2423f 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -5,6 +5,7 @@ import shutil
from pathlib import Path
from typing import Callable
from typing import Generator
+from unittest.mock import MagicMock
import numpy as np
import pytest
@@ -17,6 +18,7 @@ from mlia.nn.tensorflow.utils import convert_to_tflite
from mlia.nn.tensorflow.utils import save_keras_model
from mlia.nn.tensorflow.utils import save_tflite_model
from mlia.target.ethos_u.config import EthosUConfiguration
+from tests.utils.rewrite import TestTrainingParameters
@pytest.fixture(scope="session", name="test_resources_path")
@@ -168,16 +170,12 @@ def _write_tfrecord(
writer.write({input_name: data_generator()})
-@pytest.fixture(scope="session", name="test_tfrecord")
-def fixture_test_tfrecord(
- tmp_path_factory: pytest.TempPathFactory,
+def create_tfrecord(
+ tmp_path_factory: pytest.TempPathFactory, random_data: Callable
) -> Generator[Path, None, None]:
"""Create a tfrecord with random data matching fixture 'test_tflite_model'."""
tmp_path = tmp_path_factory.mktemp("tfrecords")
- tfrecord_file = tmp_path / "test_int8.tfrecord"
-
- def random_data() -> np.ndarray:
- return np.random.randint(low=-127, high=128, size=(1, 28, 28, 1), dtype=np.int8)
+ tfrecord_file = tmp_path / "test.tfrecord"
_write_tfrecord(tfrecord_file, random_data)
@@ -186,19 +184,36 @@ def fixture_test_tfrecord(
shutil.rmtree(tmp_path)
+@pytest.fixture(scope="session", name="test_tfrecord")
+def fixture_test_tfrecord(
+ tmp_path_factory: pytest.TempPathFactory,
+) -> Generator[Path, None, None]:
+ """Create a tfrecord with random data matching fixture 'test_tflite_model'."""
+
+ def random_data() -> np.ndarray:
+ return np.random.randint(low=-127, high=128, size=(1, 28, 28, 1), dtype=np.int8)
+
+ yield from create_tfrecord(tmp_path_factory, random_data)
+
+
@pytest.fixture(scope="session", name="test_tfrecord_fp32")
def fixture_test_tfrecord_fp32(
tmp_path_factory: pytest.TempPathFactory,
) -> Generator[Path, None, None]:
"""Create tfrecord with random data matching fixture 'test_tflite_model_fp32'."""
- tmp_path = tmp_path_factory.mktemp("tfrecords")
- tfrecord_file = tmp_path / "test_fp32.tfrecord"
def random_data() -> np.ndarray:
return np.random.rand(1, 28, 28, 1).astype(np.float32)
- _write_tfrecord(tfrecord_file, random_data)
+ yield from create_tfrecord(tmp_path_factory, random_data)
- yield tfrecord_file
- shutil.rmtree(tmp_path)
+@pytest.fixture(scope="session", autouse=True)
+def set_training_steps() -> Generator[None, None, None]:
+ """Speed up tests by using TestTrainingParameters."""
+ with pytest.MonkeyPatch.context() as monkeypatch:
+ monkeypatch.setattr(
+ "mlia.nn.select._get_rewrite_train_params",
+ MagicMock(return_value=TestTrainingParameters()),
+ )
+ yield
diff --git a/tests/test_nn_rewrite_core_graph_edit_join.py b/tests/test_nn_rewrite_core_graph_edit_join.py
index cb3e4e2..0cb121e 100644
--- a/tests/test_nn_rewrite_core_graph_edit_join.py
+++ b/tests/test_nn_rewrite_core_graph_edit_join.py
@@ -10,7 +10,7 @@ import pytest
from mlia.nn.rewrite.core.graph_edit.cut import cut_model
from mlia.nn.rewrite.core.graph_edit.join import append_relabel
from mlia.nn.rewrite.core.graph_edit.join import join_models
-from mlia.nn.rewrite.core.utils.utils import load
+from mlia.nn.tensorflow.tflite_graph import load_fb
from tests.utils.rewrite import models_are_equal
@@ -49,8 +49,8 @@ def test_join_model(test_tflite_model: Path, tmp_path: Path) -> None:
)
assert joined_file.is_file()
- orig_model = load(str(test_tflite_model))
- joined_model = load(str(joined_file))
+ orig_model = load_fb(str(test_tflite_model))
+ joined_model = load_fb(str(joined_file))
assert models_are_equal(orig_model, joined_model)
diff --git a/tests/test_nn_rewrite_core_graph_edit_record.py b/tests/test_nn_rewrite_core_graph_edit_record.py
index cd728af..41b9c50 100644
--- a/tests/test_nn_rewrite_core_graph_edit_record.py
+++ b/tests/test_nn_rewrite_core_graph_edit_record.py
@@ -3,15 +3,13 @@
"""Tests for module mlia.nn.rewrite.graph_edit.record."""
from pathlib import Path
-import pytest
import tensorflow as tf
from mlia.nn.rewrite.core.graph_edit.record import record_model
from mlia.nn.rewrite.core.utils.numpy_tfrecord import numpytf_read
-@pytest.mark.parametrize("batch_size", (None, 1, 2))
-def test_record_model(
+def check_record_model(
test_tflite_model: Path,
tmp_path: Path,
test_tfrecord: Path,
diff --git a/tests/test_nn_rewrite_core_rewrite.py b/tests/test_nn_rewrite_core_rewrite.py
index b98971e..2542db2 100644
--- a/tests/test_nn_rewrite_core_rewrite.py
+++ b/tests/test_nn_rewrite_core_rewrite.py
@@ -12,6 +12,7 @@ import pytest
from mlia.nn.rewrite.core.rewrite import RewriteConfiguration
from mlia.nn.rewrite.core.rewrite import Rewriter
from mlia.nn.tensorflow.config import TFLiteModel
+from tests.utils.rewrite import TestTrainingParameters
@pytest.mark.parametrize(
@@ -32,12 +33,14 @@ def test_rewrite_configuration(
None,
)
+ assert config_obj.optimization_target in str(config_obj)
+
rewriter_obj = Rewriter(test_tflite_model_fp32, config_obj)
assert rewriter_obj.optimizer_configuration.optimization_target == rewrite_name
assert isinstance(rewriter_obj, Rewriter)
-def test_rewriter(
+def test_rewriting_optimizer(
test_tflite_model_fp32: Path,
test_tfrecord_fp32: Path,
) -> None:
@@ -46,6 +49,7 @@ def test_rewriter(
"fully_connected",
["sequential/flatten/Reshape", "StatefulPartitionedCall:0"],
test_tfrecord_fp32,
+ train_params=TestTrainingParameters(),
)
test_obj = Rewriter(test_tflite_model_fp32, config_obj)
diff --git a/tests/test_nn_rewrite_core_train.py b/tests/test_nn_rewrite_core_train.py
index 3c2ef3e..4493671 100644
--- a/tests/test_nn_rewrite_core_train.py
+++ b/tests/test_nn_rewrite_core_train.py
@@ -4,6 +4,7 @@
# pylint: disable=too-many-arguments
from __future__ import annotations
+from contextlib import ExitStack as does_not_raise
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Any
@@ -12,10 +13,13 @@ import numpy as np
import pytest
import tensorflow as tf
-from mlia.nn.rewrite.core.train import augmentation_presets
+from mlia.nn.rewrite.core.train import augment_fn_twins
+from mlia.nn.rewrite.core.train import AUGMENTATION_PRESETS
from mlia.nn.rewrite.core.train import LearningRateSchedule
from mlia.nn.rewrite.core.train import mixup
from mlia.nn.rewrite.core.train import train
+from mlia.nn.rewrite.core.train import TrainingParameters
+from tests.utils.rewrite import TestTrainingParameters
def replace_fully_connected_with_conv(
@@ -41,15 +45,8 @@ def replace_fully_connected_with_conv(
def check_train(
tflite_model: Path,
tfrecord: Path,
- batch_size: int = 1,
- verbose: bool = False,
- show_progress: bool = False,
- augmentation_preset: tuple[float | None, float | None] = augmentation_presets[
- "none"
- ],
- lr_schedule: LearningRateSchedule = "cosine",
+ train_params: TrainingParameters = TestTrainingParameters(),
use_unmodified_model: bool = False,
- num_procs: int = 1,
) -> None:
"""Test the train() function."""
with TemporaryDirectory() as tmp_dir:
@@ -62,14 +59,7 @@ def check_train(
replace_fn=replace_fully_connected_with_conv,
input_tensors=["sequential/flatten/Reshape"],
output_tensors=["StatefulPartitionedCall:0"],
- augment=augmentation_preset,
- steps=32,
- learning_rate=1e-3,
- batch_size=batch_size,
- verbose=verbose,
- show_progress=show_progress,
- learning_rate_schedule=lr_schedule,
- num_procs=num_procs,
+ train_params=train_params,
)
assert len(result) == 2
assert all(res >= 0.0 for res in result), f"Results out of bound: {result}"
@@ -79,7 +69,6 @@ def check_train(
@pytest.mark.parametrize(
(
"batch_size",
- "verbose",
"show_progress",
"augmentation_preset",
"lr_schedule",
@@ -87,14 +76,13 @@ def check_train(
"num_procs",
),
(
- (1, False, False, augmentation_presets["none"], "cosine", False, 2),
- (32, True, True, augmentation_presets["gaussian"], "late", True, 1),
- (2, False, False, augmentation_presets["mixup"], "constant", True, 0),
+ (1, False, AUGMENTATION_PRESETS["none"], "cosine", False, 2),
+ (32, True, AUGMENTATION_PRESETS["gaussian"], "late", True, 1),
+ (2, False, AUGMENTATION_PRESETS["mixup"], "constant", True, 0),
(
1,
False,
- False,
- augmentation_presets["mix_gaussian_large"],
+ AUGMENTATION_PRESETS["mix_gaussian_large"],
"cosine",
False,
2,
@@ -105,7 +93,6 @@ def test_train(
test_tflite_model_fp32: Path,
test_tfrecord_fp32: Path,
batch_size: int,
- verbose: bool,
show_progress: bool,
augmentation_preset: tuple[float | None, float | None],
lr_schedule: LearningRateSchedule,
@@ -116,13 +103,14 @@ def test_train(
check_train(
tflite_model=test_tflite_model_fp32,
tfrecord=test_tfrecord_fp32,
- batch_size=batch_size,
- verbose=verbose,
- show_progress=show_progress,
- augmentation_preset=augmentation_preset,
- lr_schedule=lr_schedule,
+ train_params=TestTrainingParameters(
+ batch_size=batch_size,
+ show_progress=show_progress,
+ augmentations=augmentation_preset,
+ learning_rate_schedule=lr_schedule,
+ num_procs=num_procs,
+ ),
use_unmodified_model=use_unmodified_model,
- num_procs=num_procs,
)
@@ -131,11 +119,13 @@ def test_train_invalid_schedule(
test_tfrecord_fp32: Path,
) -> None:
"""Test the train() function with an invalid schedule."""
- with pytest.raises(ValueError):
+ with pytest.raises(AssertionError):
check_train(
tflite_model=test_tflite_model_fp32,
tfrecord=test_tfrecord_fp32,
- lr_schedule="unknown_schedule", # type: ignore
+ train_params=TestTrainingParameters(
+ learning_rate_schedule="unknown_schedule",
+ ),
)
@@ -144,11 +134,13 @@ def test_train_invalid_augmentation(
test_tfrecord_fp32: Path,
) -> None:
"""Test the train() function with an invalid augmentation."""
- with pytest.raises(ValueError):
+ with pytest.raises(AssertionError):
check_train(
tflite_model=test_tflite_model_fp32,
tfrecord=test_tfrecord_fp32,
- augmentation_preset=(1.0, 2.0, 3.0), # type: ignore
+ train_params=TestTrainingParameters(
+ augmentations=(1.0, 2.0, 3.0),
+ ),
)
@@ -159,3 +151,19 @@ def test_mixup() -> None:
assert src.shape == dst.shape
assert np.all(dst >= 0.0)
assert np.all(dst <= 3.0)
+
+
+@pytest.mark.parametrize(
+ "augmentations, expected_error",
+ [
+ (AUGMENTATION_PRESETS["none"], does_not_raise()),
+ (AUGMENTATION_PRESETS["mix_gaussian_large"], does_not_raise()),
+ ((None,) * 3, pytest.raises(AssertionError)),
+ ],
+)
+def test_augment_fn_twins(augmentations: tuple, expected_error: Any) -> None:
+ """Test function augment_fn()."""
+ dataset = tf.data.Dataset.from_tensor_slices({"a": [1, 2, 3], "b": [4, 5, 6]})
+ with expected_error:
+ fn_twins = augment_fn_twins(dataset, augmentations) # type: ignore
+ assert len(fn_twins) == 2
diff --git a/tests/test_nn_rewrite_core_utils.py b/tests/test_nn_rewrite_core_utils.py
deleted file mode 100644
index d806a7b..0000000
--- a/tests/test_nn_rewrite_core_utils.py
+++ /dev/null
@@ -1,33 +0,0 @@
-# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
-# SPDX-License-Identifier: Apache-2.0
-"""Tests for module mlia.nn.rewrite.utils."""
-from pathlib import Path
-
-import pytest
-import tensorflow as tf
-from tensorflow.lite.python.schema_py_generated import ModelT
-
-from mlia.nn.rewrite.core.utils.utils import load
-from mlia.nn.rewrite.core.utils.utils import save
-from tests.utils.rewrite import models_are_equal
-
-
-def test_load_save(test_tflite_model: Path, tmp_path: Path) -> None:
- """Test the load/save functions for TensorFlow Lite models."""
- with pytest.raises(FileNotFoundError):
- load("THIS_IS_NOT_A_REAL_FILE")
-
- model = load(test_tflite_model)
- assert isinstance(model, ModelT)
- assert model.subgraphs
-
- output_file = tmp_path / "test.tflite"
- assert not output_file.is_file()
- save(model, output_file)
- assert output_file.is_file()
-
- model_copy = load(str(output_file))
- assert models_are_equal(model, model_copy)
-
- # Double check that the TensorFlow Lite Interpreter can still load the file.
- tf.lite.Interpreter(model_path=str(output_file))
diff --git a/tests/test_nn_rewrite_core_utils_numpy_tfrecord.py b/tests/test_nn_rewrite_core_utils_numpy_tfrecord.py
index 7fc8048..d030350 100644
--- a/tests/test_nn_rewrite_core_utils_numpy_tfrecord.py
+++ b/tests/test_nn_rewrite_core_utils_numpy_tfrecord.py
@@ -5,6 +5,10 @@ from __future__ import annotations
from pathlib import Path
+import pytest
+import tensorflow as tf
+
+from mlia.nn.rewrite.core.utils.numpy_tfrecord import make_decode_fn
from mlia.nn.rewrite.core.utils.numpy_tfrecord import numpytf_count
from mlia.nn.rewrite.core.utils.numpy_tfrecord import sample_tfrec
@@ -16,3 +20,24 @@ def test_sample_tfrec(test_tfrecord: Path, tmp_path: Path) -> None:
sample_tfrec(input_file=str(test_tfrecord), k=1, output_file=str(output_file))
assert output_file.is_file()
assert numpytf_count(str(output_file)) == 1
+
+
+def test_make_decode_fn(test_tfrecord: Path) -> None:
+ """Test function make_decode_fn()."""
+ decode = make_decode_fn(str(test_tfrecord))
+ dataset = tf.data.TFRecordDataset(str(test_tfrecord))
+ features = decode(next(iter(dataset)))
+ assert isinstance(features, dict)
+ assert len(features) == 1
+ key, val = next(iter(features.items()))
+ assert isinstance(key, str)
+ assert isinstance(val, tf.Tensor)
+ assert val.dtype == tf.int8
+
+ with pytest.raises(FileNotFoundError):
+ make_decode_fn(str(test_tfrecord) + "_")
+
+
+def test_numpytf_count(test_tfrecord: Path) -> None:
+ """Test function numpytf_count()."""
+ assert numpytf_count(test_tfrecord) == 3
diff --git a/tests/test_nn_tensorflow_config.py b/tests/test_nn_tensorflow_config.py
index 656619d..48aec0a 100644
--- a/tests/test_nn_tensorflow_config.py
+++ b/tests/test_nn_tensorflow_config.py
@@ -4,13 +4,28 @@
from contextlib import ExitStack as does_not_raise
from pathlib import Path
from typing import Any
+from typing import Generator
+import numpy as np
import pytest
+from mlia.nn.rewrite.core.utils.numpy_tfrecord import numpytf_read
from mlia.nn.tensorflow.config import get_model
from mlia.nn.tensorflow.config import KerasModel
+from mlia.nn.tensorflow.config import ModelConfiguration
from mlia.nn.tensorflow.config import TFLiteModel
from mlia.nn.tensorflow.config import TfModel
+from tests.conftest import create_tfrecord
+
+
+def test_model_configuration(test_keras_model: Path) -> None:
+ """Test ModelConfiguration class."""
+ model = ModelConfiguration(model_path=test_keras_model)
+ assert test_keras_model.match(model.model_path)
+ with pytest.raises(NotImplementedError):
+ model.convert_to_keras("keras_model.h5")
+ with pytest.raises(NotImplementedError):
+ model.convert_to_tflite("model.tflite")
def test_convert_keras_to_tflite(tmp_path: Path, test_keras_model: Path) -> None:
@@ -38,7 +53,7 @@ def test_convert_tf_to_tflite(tmp_path: Path, test_tf_model: Path) -> None:
@pytest.mark.parametrize(
"model_path, expected_type, expected_error",
[
- ("test.tflite", TFLiteModel, does_not_raise()),
+ ("test.tflite", TFLiteModel, pytest.raises(ValueError)),
("test.h5", KerasModel, does_not_raise()),
("test.hdf5", KerasModel, does_not_raise()),
(
@@ -73,3 +88,26 @@ def test_get_model_dir(
"""Test TensorFlow Lite model type."""
model = get_model(str(test_models_path / model_path))
assert isinstance(model, expected_type)
+
+
+@pytest.fixture(scope="session", name="test_tfrecord_fp32_batch_3")
+def fixture_test_tfrecord_fp32_batch_3(
+ tmp_path_factory: pytest.TempPathFactory,
+) -> Generator[Path, None, None]:
+ """Create tfrecord (same as test_tfrecord_fp32) but with batch size 3."""
+
+ def random_data() -> np.ndarray:
+ return np.random.rand(3, 28, 28, 1).astype(np.float32)
+
+ yield from create_tfrecord(tmp_path_factory, random_data)
+
+
+def test_tflite_model_call(
+ test_tflite_model_fp32: Path, test_tfrecord_fp32_batch_3: Path
+) -> None:
+ """Test inference function of class TFLiteModel."""
+ model = TFLiteModel(test_tflite_model_fp32, batch_size=2)
+ data = numpytf_read(test_tfrecord_fp32_batch_3)
+ for named_input in data.as_numpy_iterator():
+ res = model(named_input)
+ assert res
diff --git a/tests/test_nn_tensorflow_tflite_graph.py b/tests/test_nn_tensorflow_tflite_graph.py
index cd1fad6..3512cdd 100644
--- a/tests/test_nn_tensorflow_tflite_graph.py
+++ b/tests/test_nn_tensorflow_tflite_graph.py
@@ -1,15 +1,22 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for the tflite_graph module."""
import json
from pathlib import Path
+import pytest
+import tensorflow as tf
+from tensorflow.lite.python.schema_py_generated import ModelT
+
+from mlia.nn.tensorflow.tflite_graph import load_fb
from mlia.nn.tensorflow.tflite_graph import Op
from mlia.nn.tensorflow.tflite_graph import parse_subgraphs
+from mlia.nn.tensorflow.tflite_graph import save_fb
from mlia.nn.tensorflow.tflite_graph import TensorInfo
from mlia.nn.tensorflow.tflite_graph import TFL_ACTIVATION_FUNCTION
from mlia.nn.tensorflow.tflite_graph import TFL_OP
from mlia.nn.tensorflow.tflite_graph import TFL_TYPE
+from tests.utils.rewrite import models_are_equal
def test_tensor_info() -> None:
@@ -79,3 +86,24 @@ def test_parse_subgraphs(test_tflite_model: Path) -> None:
assert TFL_OP[oper.type] in TFL_OP
assert len(oper.inputs) > 0
assert len(oper.outputs) > 0
+
+
+def test_load_save(test_tflite_model: Path, tmp_path: Path) -> None:
+ """Test the load/save functions for TensorFlow Lite models."""
+ with pytest.raises(FileNotFoundError):
+ load_fb("THIS_IS_NOT_A_REAL_FILE")
+
+ model = load_fb(test_tflite_model)
+ assert isinstance(model, ModelT)
+ assert model.subgraphs
+
+ output_file = tmp_path / "test.tflite"
+ assert not output_file.is_file()
+ save_fb(model, output_file)
+ assert output_file.is_file()
+
+ model_copy = load_fb(str(output_file))
+ assert models_are_equal(model, model_copy)
+
+ # Double check that the TensorFlow Lite Interpreter can still load the file.
+ tf.lite.Interpreter(model_path=str(output_file))
diff --git a/tests/utils/rewrite.py b/tests/utils/rewrite.py
index 4264b4b..739bb11 100644
--- a/tests/utils/rewrite.py
+++ b/tests/utils/rewrite.py
@@ -3,8 +3,12 @@
"""Common test utils for the rewrite tests."""
from __future__ import annotations
+from typing import Any
+
from tensorflow.lite.python.schema_py_generated import ModelT
+from mlia.nn.rewrite.core.train import TrainingParameters
+
def models_are_equal(model1: ModelT, model2: ModelT) -> bool:
"""Check that the two models are equal."""
@@ -25,3 +29,17 @@ def models_are_equal(model1: ModelT, model2: ModelT) -> bool:
return False # Tensor from graph1 not found in other graph.")
return True
+
+
+class TestTrainingParameters(
+ TrainingParameters
+): # pylint: disable=too-few-public-methods
+ """
+ TrainingParameter class for rewrites with different default values.
+
+ To speed things up for the unit tests.
+ """
+
+ def __init__(self, *args: Any, steps: int = 32, **kwargs: Any) -> None:
+ """Initialize TrainingParameters with different defaults."""
+ super().__init__(*args, steps=steps, **kwargs) # type: ignore