Diffstat (limited to 'src/mlia/nn/rewrite')
-rw-r--r--  src/mlia/nn/rewrite/core/graph_edit/cut.py           8
-rw-r--r--  src/mlia/nn/rewrite/core/graph_edit/join.py         10
-rw-r--r--  src/mlia/nn/rewrite/core/rewrite.py                  77
-rw-r--r--  src/mlia/nn/rewrite/core/train.py                   306
-rw-r--r--  src/mlia/nn/rewrite/core/utils/numpy_tfrecord.py    138
-rw-r--r--  src/mlia/nn/rewrite/core/utils/parallel.py            4
-rw-r--r--  src/mlia/nn/rewrite/core/utils/utils.py              32
-rw-r--r--  src/mlia/nn/rewrite/library/fc_layer.py              14
8 files changed, 267 insertions(+), 322 deletions(-)
diff --git a/src/mlia/nn/rewrite/core/graph_edit/cut.py b/src/mlia/nn/rewrite/core/graph_edit/cut.py
index 2707eb1..13a5268 100644
--- a/src/mlia/nn/rewrite/core/graph_edit/cut.py
+++ b/src/mlia/nn/rewrite/core/graph_edit/cut.py
@@ -9,8 +9,8 @@ import tensorflow as tf
from tensorflow.lite.python.schema_py_generated import ModelT
from tensorflow.lite.python.schema_py_generated import SubGraphT
-from mlia.nn.rewrite.core.utils.utils import load
-from mlia.nn.rewrite.core.utils.utils import save
+from mlia.nn.tensorflow.tflite_graph import load_fb
+from mlia.nn.tensorflow.tflite_graph import save_fb
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
@@ -138,8 +138,8 @@ def cut_model(
output_file: str,
) -> None:
"""Cut subgraphs and simplify a given model."""
- model = load(model_file)
+ model = load_fb(model_file)
subgraph = model.subgraphs[subgraph_index]
cut_subgraph(subgraph, input_names, output_names)
simplify(model)
- save(model, output_file)
+ save_fb(model, output_file)
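
A minimal usage sketch of the renamed cut flow, assuming the keyword names visible in the function body above (file and tensor names are placeholders):

    from mlia.nn.rewrite.core.graph_edit.cut import cut_model

    # Extract the sub-graph between the two named tensors; the result is
    # the part that the trainer later retrains and replaces.
    cut_model(
        model_file="model.tflite",
        input_names=["MobileNet/avg_pool/AvgPool"],
        output_names=["MobileNet/fc1/BiasAdd"],
        subgraph_index=0,
        output_file="replace.tflite",
    )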
diff --git a/src/mlia/nn/rewrite/core/graph_edit/join.py b/src/mlia/nn/rewrite/core/graph_edit/join.py
index 2530ec8..70109d8 100644
--- a/src/mlia/nn/rewrite/core/graph_edit/join.py
+++ b/src/mlia/nn/rewrite/core/graph_edit/join.py
@@ -11,8 +11,8 @@ from tensorflow.lite.python.schema_py_generated import ModelT
from tensorflow.lite.python.schema_py_generated import OperatorCodeT
from tensorflow.lite.python.schema_py_generated import SubGraphT
-from mlia.nn.rewrite.core.utils.utils import load
-from mlia.nn.rewrite.core.utils.utils import save
+from mlia.nn.tensorflow.tflite_graph import load_fb
+from mlia.nn.tensorflow.tflite_graph import save_fb
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
@@ -26,12 +26,12 @@ def join_models(
subgraph_dst: int = 0,
) -> None:
"""Join two models and save the result into a given model file path."""
- src_model = load(input_src)
- dst_model = load(input_dst)
+ src_model = load_fb(input_src)
+ dst_model = load_fb(input_dst)
src_subgraph = src_model.subgraphs[subgraph_src]
dst_subgraph = dst_model.subgraphs[subgraph_dst]
join_subgraphs(src_model, src_subgraph, dst_model, dst_subgraph)
- save(dst_model, output_file)
+ save_fb(dst_model, output_file)
def join_subgraphs(
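
join_models is the inverse of the cut step; a sketch with placeholder paths (both subgraph indices default to 0, as in the signature above):

    from mlia.nn.rewrite.core.graph_edit.join import join_models

    # Stitch a retrained replacement back onto the untouched front part
    # of the original model.
    join_models("start.tflite", "new.tflite", "joined.tflite")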
diff --git a/src/mlia/nn/rewrite/core/rewrite.py b/src/mlia/nn/rewrite/core/rewrite.py
index 0d182df..6b27984 100644
--- a/src/mlia/nn/rewrite/core/rewrite.py
+++ b/src/mlia/nn/rewrite/core/rewrite.py
@@ -4,6 +4,7 @@
from __future__ import annotations
import importlib
+import logging
import tempfile
from dataclasses import dataclass
from pathlib import Path
@@ -12,13 +13,14 @@ from typing import Any
from mlia.core.errors import ConfigurationError
from mlia.nn.common import Optimizer
from mlia.nn.common import OptimizerConfiguration
-from mlia.nn.rewrite.core.train import eval_in_dir
-from mlia.nn.rewrite.core.train import join_in_dir
from mlia.nn.rewrite.core.train import train
-from mlia.nn.rewrite.core.train import train_in_dir
+from mlia.nn.rewrite.core.train import TrainingParameters
from mlia.nn.tensorflow.config import TFLiteModel
+logger = logging.getLogger(__name__)
+
+
@dataclass
class RewriteConfiguration(OptimizerConfiguration):
"""Rewrite configuration."""
@@ -26,6 +28,7 @@ class RewriteConfiguration(OptimizerConfiguration):
optimization_target: str
layers_to_optimize: list[str] | None = None
dataset: Path | None = None
+ train_params: TrainingParameters = TrainingParameters()
def __str__(self) -> str:
"""Return string representation of the configuration."""
@@ -40,8 +43,8 @@ class Rewriter(Optimizer):
):
"""Init Rewriter instance."""
self.model = TFLiteModel(tflite_model_path)
+ self.model_path = tflite_model_path
self.optimizer_configuration = optimizer_configuration
- self.train_dir = ""
def apply_optimization(self) -> None:
"""Apply the rewrite flow."""
@@ -61,50 +64,36 @@ class Rewriter(Optimizer):
replace_fn = get_function(replace_function)
- augmentation_preset = (None, None)
use_unmodified_model = True
tflite_model = self.model.model_path
tfrecord = str(self.optimizer_configuration.dataset)
- with tempfile.TemporaryDirectory() as tmp_dir:
- tmp_output = Path(tmp_dir, "output.tflite")
-
- if self.train_dir:
- tmp_new = Path(tmp_dir, "new.tflite")
- new_part = train_in_dir(
- train_dir=self.train_dir,
- baseline_dir=None,
- output_filename=tmp_new,
- replace_fn=replace_fn,
- augmentations=augmentation_preset,
- steps=32,
- learning_rate=1e-3,
- batch_size=1,
- verbose=True,
- show_progress=True,
- )
- eval_in_dir(self.train_dir, new_part[0])
- join_in_dir(self.train_dir, new_part[0], str(tmp_output))
- else:
- if not self.optimizer_configuration.layers_to_optimize:
- raise ConfigurationError(
- "Input and output tensor names need to be set for rewrite."
- )
- train(
- source_model=tflite_model,
- unmodified_model=tflite_model if use_unmodified_model else None,
- output_model=str(tmp_output),
- input_tfrec=str(tfrecord),
- replace_fn=replace_fn,
- input_tensors=[self.optimizer_configuration.layers_to_optimize[0]],
- output_tensors=[self.optimizer_configuration.layers_to_optimize[1]],
- augment=augmentation_preset,
- steps=32,
- learning_rate=1e-3,
- batch_size=1,
- verbose=True,
- show_progress=True,
- )
+ tmp_dir = tempfile.mkdtemp()
+ tmp_output = Path(tmp_dir, "output.tflite")
+
+ if not self.optimizer_configuration.layers_to_optimize:
+ raise ConfigurationError(
+ "Input and output tensor names need to be set for rewrite."
+ )
+ result = train(
+ source_model=tflite_model,
+ unmodified_model=tflite_model if use_unmodified_model else None,
+ output_model=str(tmp_output),
+ input_tfrec=str(tfrecord),
+ replace_fn=replace_fn,
+ input_tensors=[self.optimizer_configuration.layers_to_optimize[0]],
+ output_tensors=[self.optimizer_configuration.layers_to_optimize[1]],
+ train_params=self.optimizer_configuration.train_params,
+ )
+
+ self.model = TFLiteModel(tmp_output)
+
+ if result:
+ stats_as_str = ", ".join(str(stats) for stats in result)
+ logger.info(
+ "The MAE and NRMSE between original and replacement [%s]",
+ stats_as_str,
+ )
def get_model(self) -> TFLiteModel:
"""Return optimized model."""
diff --git a/src/mlia/nn/rewrite/core/train.py b/src/mlia/nn/rewrite/core/train.py
index c8497a4..42bf653 100644
--- a/src/mlia/nn/rewrite/core/train.py
+++ b/src/mlia/nn/rewrite/core/train.py
@@ -1,8 +1,8 @@
# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Sequential trainer."""
-# pylint: disable=too-many-arguments, too-many-instance-attributes,
-# pylint: disable=too-many-locals, too-many-branches, too-many-statements
+# pylint: disable=too-many-locals
+# pylint: disable=too-many-statements
from __future__ import annotations
import logging
@@ -10,10 +10,13 @@ import math
import os
import tempfile
from collections import defaultdict
+from contextlib import contextmanager
+from dataclasses import dataclass
from pathlib import Path
from typing import Any
from typing import Callable
from typing import cast
+from typing import Generator as GeneratorType
from typing import get_args
from typing import Literal
@@ -27,10 +30,10 @@ from mlia.nn.rewrite.core.graph_edit.join import join_models
from mlia.nn.rewrite.core.graph_edit.record import record_model
from mlia.nn.rewrite.core.utils.numpy_tfrecord import numpytf_count
from mlia.nn.rewrite.core.utils.numpy_tfrecord import numpytf_read
-from mlia.nn.rewrite.core.utils.numpy_tfrecord import TFLiteModel
from mlia.nn.rewrite.core.utils.parallel import ParallelTFLiteModel
-from mlia.nn.rewrite.core.utils.utils import load
-from mlia.nn.rewrite.core.utils.utils import save
+from mlia.nn.tensorflow.config import TFLiteModel
+from mlia.nn.tensorflow.tflite_graph import load_fb
+from mlia.nn.tensorflow.tflite_graph import save_fb
from mlia.utils.logging import log_action
@@ -38,7 +41,7 @@ os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
logger = logging.getLogger(__name__)
-augmentation_presets = {
+AUGMENTATION_PRESETS = {
"none": (None, None),
"gaussian": (None, 1.0),
"mixup": (1.0, None),
@@ -51,6 +54,21 @@ LearningRateSchedule = Literal["cosine", "late", "constant"]
LEARNING_RATE_SCHEDULES = get_args(LearningRateSchedule)
+@dataclass
+class TrainingParameters:
+ """Define default parameters for the training."""
+
+ augmentations: tuple[float | None, float | None] = AUGMENTATION_PRESETS["gaussian"]
+ batch_size: int = 32
+ steps: int = 48000
+ learning_rate: float = 1e-3
+ learning_rate_schedule: LearningRateSchedule = "cosine"
+ num_procs: int = 1
+ num_threads: int = 0
+ show_progress: bool = True
+ checkpoint_at: list | None = None
+
+
def train(
source_model: str,
unmodified_model: Any,
@@ -59,16 +77,7 @@ def train(
replace_fn: Callable,
input_tensors: list,
output_tensors: list,
- augment: tuple[float | None, float | None],
- steps: int,
- learning_rate: float,
- batch_size: int,
- verbose: bool,
- show_progress: bool,
- learning_rate_schedule: LearningRateSchedule = "cosine",
- checkpoint_at: list | None = None,
- num_procs: int = 1,
- num_threads: int = 0,
+ train_params: TrainingParameters = TrainingParameters(),
) -> Any:
"""Extract and train a model, and return the results."""
if unmodified_model:
@@ -95,29 +104,27 @@ def train(
input_tfrec,
input_tensors,
output_tensors,
- num_procs=num_procs,
- num_threads=num_threads,
+ num_procs=train_params.num_procs,
+ num_threads=train_params.num_threads,
)
tflite_filenames = train_in_dir(
- train_dir,
- unmodified_model_dir_path,
- Path(train_dir, "new.tflite"),
- replace_fn,
- augment,
- steps,
- learning_rate,
- batch_size,
- checkpoint_at=checkpoint_at,
- verbose=verbose,
- show_progress=show_progress,
- num_procs=num_procs,
- num_threads=num_threads,
- schedule=learning_rate_schedule,
+ train_dir=train_dir,
+ baseline_dir=unmodified_model_dir_path,
+ output_filename=Path(train_dir, "new.tflite"),
+ replace_fn=replace_fn,
+ train_params=train_params,
)
for i, filename in enumerate(tflite_filenames):
- results.append(eval_in_dir(train_dir, filename, num_procs, num_threads))
+ results.append(
+ eval_in_dir(
+ train_dir,
+ filename,
+ train_params.num_procs,
+ train_params.num_threads,
+ )
+ )
if output_model:
if i + 1 < len(tflite_filenames):
@@ -133,7 +140,7 @@ def train(
cast(tempfile.TemporaryDirectory, unmodified_model_dir).cleanup()
return (
- results if checkpoint_at else results[0]
+ results if train_params.checkpoint_at else results[0]
) # only return a list if multiple checkpoints are asked for
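
Called directly, train() now takes the whole parameter bundle in one TrainingParameters object; a sketch with placeholder paths and tensor names:

    from mlia.nn.rewrite.core.train import train, TrainingParameters
    from mlia.nn.rewrite.library.fc_layer import get_keras_model

    results = train(
        source_model="model.tflite",
        unmodified_model=None,
        output_model="rewritten.tflite",
        input_tfrec="dataset.tfrec",
        replace_fn=get_keras_model,
        input_tensors=["MobileNet/avg_pool/AvgPool"],
        output_tensors=["MobileNet/fc1/BiasAdd"],
        train_params=TrainingParameters(steps=100, batch_size=8),
    )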
@@ -176,46 +183,24 @@ def join_in_dir(model_dir: str, new_part: str, output_model: str) -> None:
join_models(Path(model_dir, "start.tflite"), new_end, output_model)
-def train_in_dir(
- train_dir: str,
- baseline_dir: Any,
- output_filename: Path,
- replace_fn: Callable,
- augmentations: tuple[float | None, float | None],
- steps: int,
- learning_rate: float = 1e-3,
- batch_size: int = 32,
- checkpoint_at: list | None = None,
- schedule: str = "cosine",
- verbose: bool = False,
- show_progress: bool = False,
- num_procs: int = 0,
- num_threads: int = 1,
-) -> list:
- """Train a replacement for replace.tflite using the input.tfrec \
- and output.tfrec in train_dir.
-
- If baseline_dir is provided, train the replacement to match baseline
- outputs for train_dir inputs. Result saved as new.tflite in train_dir.
- """
- teacher_dir = baseline_dir if baseline_dir else train_dir
- teacher = ParallelTFLiteModel(
- f"{teacher_dir}/replace.tflite", num_procs, num_threads, batch_size=batch_size
- )
- replace = TFLiteModel(f"{train_dir}/replace.tflite")
+def _get_io_tensors(model: TFLiteModel) -> tuple[str, str]:
assert (
- len(teacher.input_tensors()) == 1
+ len(model.input_tensors()) == 1
), f"Can only train replacements with a single input tensor right now, \
- found {teacher.input_tensors()}"
+ found {model.input_tensors()}"
assert (
- len(teacher.output_tensors()) == 1
+ len(model.output_tensors()) == 1
), f"Can only train replacements with a single output tensor right now, \
- found {teacher.output_tensors()}"
+ found {model.output_tensors()}"
+
+ input_name = model.input_tensors()[0]
+ output_name = model.output_tensors()[0]
+ return (input_name, output_name)
- input_name = teacher.input_tensors()[0]
- output_name = teacher.output_tensors()[0]
+def _check_model_compatibility(teacher: TFLiteModel, replace: TFLiteModel) -> None:
+ """Assert that teacher and replaced sub-graph are compatible."""
assert len(teacher.shape_from_name) == len(
replace.shape_from_name
), f"Baseline and train models must have the same number of inputs and outputs. \
@@ -230,10 +215,37 @@ def train_in_dir(
subgraph being replaced. Teacher: {teacher.shape_from_name}\n \
Train dir: {replace.shape_from_name}"
+
+def set_up_data_pipeline(
+ teacher: TFLiteModel,
+ replace: TFLiteModel,
+ train_dir: str,
+ augmentations: tuple[float | None, float | None],
+ steps: int,
+ batch_size: int = 32,
+) -> tf.data.Dataset:
+ """Create a data pipeline for training of the replacement model."""
+ _check_model_compatibility(teacher, replace)
+
+ input_name, output_name = _get_io_tensors(teacher)
+
input_filename = Path(train_dir, "input.tfrec")
total = numpytf_count(str(input_filename))
dict_inputs = numpytf_read(str(input_filename))
+
inputs = dict_inputs.map(lambda d: tf.squeeze(d[input_name], axis=0))
+
+ steps_per_epoch = math.ceil(total / batch_size)
+ epochs = int(math.ceil(steps / steps_per_epoch))
+ logger.info(
+ "Training on %d items for %d steps (%d epochs with batch size %d)",
+ total,
+ epochs * steps_per_epoch,
+ epochs,
+ batch_size,
+ )
+
+ teacher_dir = Path(teacher.model_path).parent
if any(augmentations):
# Map the teacher inputs here because the augmentation stage passes these
# through a TFLite model to get the outputs
@@ -245,17 +257,6 @@ def train_in_dir(
lambda d: tf.squeeze(d[output_name], axis=0)
)
- steps_per_epoch = math.ceil(total / batch_size)
- epochs = int(math.ceil(steps / steps_per_epoch))
- if verbose:
- logger.info(
- "Training on %d items for %d steps (%d epochs with batch size %d)",
- total,
- epochs * steps_per_epoch,
- epochs,
- batch_size,
- )
-
dataset = tf.data.Dataset.zip((inputs, teacher_outputs))
if epochs > 1:
dataset = dataset.cache()
@@ -268,10 +269,9 @@ def train_in_dir(
train: Any, teach: Any # pylint: disable=redefined-outer-name
) -> tuple:
"""Return results of train and teach based on augmentations."""
- return (
- augment_train({input_name: train})[input_name],
- teacher(augment_teacher({input_name: teach}))[output_name],
- )
+ augmented_train = augment_train({input_name: train})[input_name]
+ augmented_teach = teacher(augment_teacher({input_name: teach}))[output_name]
+ return (augmented_train, augmented_teach)
dataset = dataset.map(
lambda augment_train, augment_teach: tf.py_function(
@@ -281,18 +281,67 @@ def train_in_dir(
)
)
+ # Restore data shapes of the dataset as they are set to unknown per default
+ # and get lost during augmentation with tf.py_function.
+ shape_in, shape_out = (
+ teacher.shape_from_name[name].tolist() for name in (input_name, output_name)
+ )
+ for shape in (shape_in, shape_out):
+ shape[0] = None # set dynamic batch size
+
+ def restore_shapes(input_: Any, output: Any) -> tuple[Any, Any]:
+ input_.set_shape(shape_in)
+ output.set_shape(shape_out)
+ return input_, output
+
+ dataset = dataset.map(restore_shapes)
dataset = dataset.prefetch(tf.data.AUTOTUNE)
+ return dataset
+
+
+def train_in_dir(
+ train_dir: str,
+ baseline_dir: Any,
+ output_filename: Path,
+ replace_fn: Callable,
+ train_params: TrainingParameters = TrainingParameters(),
+) -> list[str]:
+ """Train a replacement for replace.tflite using the input.tfrec \
+ and output.tfrec in train_dir.
+
+ If baseline_dir is provided, train the replacement to match baseline
+ outputs for train_dir inputs. Result saved as new.tflite in train_dir.
+ """
+ teacher_dir = baseline_dir if baseline_dir else train_dir
+ teacher = ParallelTFLiteModel(
+ f"{teacher_dir}/replace.tflite",
+ train_params.num_procs,
+ train_params.num_threads,
+ batch_size=train_params.batch_size,
+ )
+ replace = TFLiteModel(f"{train_dir}/replace.tflite")
+
+ input_name, output_name = _get_io_tensors(teacher)
+
+ dataset = set_up_data_pipeline(
+ teacher,
+ replace,
+ train_dir,
+ augmentations=train_params.augmentations,
+ steps=train_params.steps,
+ batch_size=train_params.batch_size,
+ )
input_shape = teacher.shape_from_name[input_name][1:]
output_shape = teacher.shape_from_name[output_name][1:]
+
model = replace_fn(input_shape, output_shape)
- optimizer = tf.keras.optimizers.Nadam(learning_rate=learning_rate)
+ optimizer = tf.keras.optimizers.Nadam(learning_rate=train_params.learning_rate)
loss_fn = tf.keras.losses.MeanSquaredError()
- model.compile(optimizer=optimizer, loss=loss_fn)
+ model.compile(optimizer=optimizer, loss=loss_fn, metrics=["mae"])
- if verbose:
- model.summary()
+ logger.info(model.summary())
steps_so_far = 0
@@ -302,7 +351,9 @@ def train_in_dir(
"""Cosine decay from learning rate at start of the run to zero at the end."""
current_step = epoch_step + steps_so_far
cd_learning_rate = (
- learning_rate * (math.cos(math.pi * current_step / steps) + 1) / 2.0
+ train_params.learning_rate
+ * (math.cos(math.pi * current_step / train_params.steps) + 1)
+ / 2.0
)
tf.keras.backend.set_value(optimizer.learning_rate, cd_learning_rate)
@@ -311,28 +362,29 @@ def train_in_dir(
) -> None:
"""Constant until the last 20% of the run, then linear decay to zero."""
current_step = epoch_step + steps_so_far
- steps_remaining = steps - current_step
- decay_length = steps // 5
+ steps_remaining = train_params.steps - current_step
+ decay_length = train_params.steps // 5
decay_fraction = min(steps_remaining, decay_length) / decay_length
- ld_learning_rate = learning_rate * decay_fraction
+ ld_learning_rate = train_params.learning_rate * decay_fraction
tf.keras.backend.set_value(optimizer.learning_rate, ld_learning_rate)
- if schedule == "cosine":
+ assert train_params.learning_rate_schedule in LEARNING_RATE_SCHEDULES, (
+ f'Learning rate schedule "{train_params.learning_rate_schedule}" '
+ f"not implemented - expected one of {LEARNING_RATE_SCHEDULES}."
+ )
+ if train_params.learning_rate_schedule == "cosine":
callbacks = [tf.keras.callbacks.LambdaCallback(on_batch_begin=cosine_decay)]
- elif schedule == "late":
+ elif train_params.learning_rate_schedule == "late":
callbacks = [tf.keras.callbacks.LambdaCallback(on_batch_begin=late_decay)]
- elif schedule == "constant":
+ elif train_params.learning_rate_schedule == "constant":
callbacks = []
- else:
- assert schedule not in LEARNING_RATE_SCHEDULES
- raise ValueError(
- f'Learning rate schedule "{schedule}" not implemented - '
- f"expected one of {LEARNING_RATE_SCHEDULES}."
- )
output_filenames = []
- checkpoints = (checkpoint_at if checkpoint_at else []) + [steps]
- while steps_so_far < steps:
+ checkpoints = (train_params.checkpoint_at if train_params.checkpoint_at else []) + [
+ train_params.steps
+ ]
+
+ while steps_so_far < train_params.steps:
steps_to_train = checkpoints.pop(0) - steps_so_far
lr_start = optimizer.learning_rate.numpy()
model.fit(
@@ -340,7 +392,7 @@ def train_in_dir(
epochs=1,
steps_per_epoch=steps_to_train,
callbacks=callbacks,
- verbose=show_progress,
+ verbose=train_params.show_progress,
)
steps_so_far += steps_to_train
logger.info(
@@ -350,12 +402,14 @@ def train_in_dir(
steps_to_train,
)
- if steps_so_far < steps:
+ if steps_so_far < train_params.steps:
filename, ext = Path(output_filename).parts[1:]
checkpoint_filename = filename + (f"_@{steps_so_far}") + ext
else:
checkpoint_filename = str(output_filename)
- with log_action(f"{steps_so_far}/{steps}: Saved as {checkpoint_filename}"):
+ with log_action(
+ f"{steps_so_far}/{train_params.steps}: Saved as {checkpoint_filename}"
+ ):
save_as_tflite(
model,
checkpoint_filename,
@@ -379,14 +433,30 @@ def save_as_tflite(
output_shape: list,
) -> None:
"""Save Keras model as TFLite file."""
- converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
+
+ @contextmanager
+ def fixed_input(keras_model: tf.keras.Model, tmp_shape: list) -> GeneratorType:
+ """Fix the input shape of the Keras model temporarily.
+
+ This avoids artifacts during conversion to TensorFlow Lite.
+ """
+ orig_shape = keras_model.input.shape
+ keras_model.input.set_shape(tf.TensorShape(tmp_shape))
+ try:
+ yield keras_model
+ finally:
+ # Restore original shape to not interfere with further training
+ keras_model.input.set_shape(orig_shape)
+
+ with fixed_input(keras_model, input_shape) as fixed_model:
+ converter = tf.lite.TFLiteConverter.from_keras_model(fixed_model)
tflite_model = converter.convert()
with open(filename, "wb") as file:
file.write(tflite_model)
# Now fix the shapes and names to match those we expect
- flatbuffer = load(filename)
+ flatbuffer = load_fb(filename)
i = flatbuffer.subgraphs[0].inputs[0]
flatbuffer.subgraphs[0].tensors[i].shape = np.array(input_shape, dtype=np.int32)
flatbuffer.subgraphs[0].tensors[i].name = input_name.encode("utf-8")
@@ -395,11 +465,11 @@ def save_as_tflite(
output_shape, dtype=np.int32
)
flatbuffer.subgraphs[0].tensors[output].name = output_name.encode("utf-8")
- save(flatbuffer, filename)
+ save_fb(flatbuffer, filename)
def augment_fn_twins(
- inputs: dict, augmentations: tuple[float | None, float | None]
+ inputs: tf.data.Dataset, augmentations: tuple[float | None, float | None]
) -> Any:
"""Return a pair of twinned augmentation functions with the same sequence \
of random numbers."""
@@ -415,6 +485,11 @@ def augment_fn(
inputs: Any, augmentations: tuple[float | None, float | None], rng: Generator
) -> Any:
"""Augmentation module."""
+ assert len(augmentations) == 2, (
+ f"Unexpected number of augmentation parameters: {len(augmentations)} "
+ "(must be 2)"
+ )
+
mixup_strength, gaussian_strength = augmentations
augments = []
@@ -449,17 +524,16 @@ def augment_fn(
augments.append(gaussian_strength_augment)
- if len(augments) == 0: # pylint: disable=no-else-return
+ if len(augments) == 0:
return lambda x: x
- elif len(augments) == 1:
+ if len(augments) == 1:
return augments[0]
- elif len(augments) == 2:
+ if len(augments) == 2:
return lambda x: augments[1](augments[0](x))
- else:
- assert (
- False
- ), f"Unexpected number of augmentation \
- functions ({len(augments)})"
+
+ raise RuntimeError(
+ "Unexpected number of augmentation functions (must be <=2): " f"{len(augments)}"
+ )
def mixup(rng: Generator, batch: Any, beta_range: tuple = (0.0, 1.0)) -> Any:
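
For reference, the mixup augmentation blends each example with a shuffled partner under a random per-example weight; a standalone numpy sketch of the idea (the module's own implementation may differ in detail):

    import numpy as np

    def mixup_sketch(rng: np.random.Generator, batch: np.ndarray,
                     beta_range: tuple = (0.0, 1.0)) -> np.ndarray:
        """Blend each example with a shuffled partner (illustrative)."""
        weights = rng.uniform(beta_range[0], beta_range[1], batch.shape[0])
        # Broadcast the per-example weight over the remaining axes.
        weights = weights.reshape((-1,) + (1,) * (batch.ndim - 1))
        shuffled = batch[rng.permutation(batch.shape[0])]
        return weights * batch + (1.0 - weights) * shuffled

    batch = np.random.default_rng(0).random((4, 8), dtype=np.float32)
    mixed = mixup_sketch(np.random.default_rng(1), batch)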
diff --git a/src/mlia/nn/rewrite/core/utils/numpy_tfrecord.py b/src/mlia/nn/rewrite/core/utils/numpy_tfrecord.py
index 9229810..38ac1ed 100644
--- a/src/mlia/nn/rewrite/core/utils/numpy_tfrecord.py
+++ b/src/mlia/nn/rewrite/core/utils/numpy_tfrecord.py
@@ -6,55 +6,56 @@ from __future__ import annotations
import json
import os
import random
-import tempfile
-from collections import defaultdict
+from functools import lru_cache
from pathlib import Path
from typing import Any
from typing import Callable
-import numpy as np
import tensorflow as tf
-from tensorflow.lite.python import interpreter as interpreter_wrapper
-from mlia.nn.rewrite.core.utils.utils import load
-from mlia.nn.rewrite.core.utils.utils import save
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
-def make_decode_fn(filename: str) -> Callable:
- """Make decode filename."""
+def decode_fn(record_bytes: Any, type_map: dict) -> dict:
+ """Decode the given bytes into a name-tensor dict assuming the given type."""
+ parse_dict = {
+ name: tf.io.FixedLenFeature([], tf.string) for name in type_map.keys()
+ }
+ example = tf.io.parse_single_example(record_bytes, parse_dict)
+ features = {
+ n: tf.io.parse_tensor(example[n], tf.as_dtype(t)) for n, t in type_map.items()
+ }
+ return features
- def decode_fn(record_bytes: Any, type_map: dict) -> dict:
- parse_dict = {
- name: tf.io.FixedLenFeature([], tf.string) for name in type_map.keys()
- }
- example = tf.io.parse_single_example(record_bytes, parse_dict)
- features = {
- n: tf.io.parse_tensor(example[n], tf.as_dtype(t))
- for n, t in type_map.items()
- }
- return features
+def make_decode_fn(filename: str, model_filename: str | Path | None = None) -> Callable:
+ """Make decode filename."""
meta_filename = filename + ".meta"
- with open(meta_filename, encoding="utf-8") as file:
- type_map = json.load(file)["type_map"]
+    try:
+        with open(meta_filename, encoding="utf-8") as file:
+            type_map = json.load(file)["type_map"]
+    except FileNotFoundError as exc:
+        # Fallback behaviour here is an assumption: the .meta sidecar is
+        # required to decode the records, so fail with a message that
+        # also names the related model file.
+        raise FileNotFoundError(
+            f"Metadata file {meta_filename} not found "
+            f"(model: {model_filename or filename})"
+        ) from exc
     return lambda record_bytes: decode_fn(record_bytes, type_map)
def numpytf_read(filename: str | Path) -> Any:
"""Read TFRecord dataset."""
- decode_fn = make_decode_fn(str(filename))
+ decode = make_decode_fn(str(filename))
dataset = tf.data.TFRecordDataset(str(filename))
- return dataset.map(decode_fn)
+ return dataset.map(decode)
-def numpytf_count(filename: str | Path) -> Any:
+@lru_cache
+def numpytf_count(filename: str | Path) -> int:
"""Return count from TFRecord file."""
meta_filename = f"{filename}.meta"
- with open(meta_filename, encoding="utf-8") as file:
- return json.load(file)["count"]
+ try:
+ with open(meta_filename, encoding="utf-8") as file:
+ return int(json.load(file)["count"])
+ except FileNotFoundError:
+ raw_dataset = tf.data.TFRecordDataset(filename)
+ return sum(1 for _ in raw_dataset)
class NumpyTFWriter:
@@ -101,93 +102,6 @@ class NumpyTFWriter:
self.writer.close()
-class TFLiteModel:
- """A representation of a TFLite Model."""
-
- def __init__(
- self,
- filename: str,
- batch_size: int | None = None,
- num_threads: int | None = None,
- ) -> None:
- """Initiate a TFLite Model."""
- if not num_threads:
- num_threads = None
- if not batch_size:
- self.interpreter = interpreter_wrapper.Interpreter(
- model_path=filename, num_threads=num_threads
- )
- else: # if a batch size is specified, modify the TFLite model to use this size
- with tempfile.TemporaryDirectory() as tmp:
- flatbuffer = load(filename)
- for subgraph in flatbuffer.subgraphs:
- for tensor in list(subgraph.inputs) + list(subgraph.outputs):
- subgraph.tensors[tensor].shape = np.array(
- [batch_size] + list(subgraph.tensors[tensor].shape[1:]),
- dtype=np.int32,
- )
- tempname = os.path.join(tmp, "rewrite_tmp.tflite")
- save(flatbuffer, tempname)
- self.interpreter = interpreter_wrapper.Interpreter(
- model_path=tempname, num_threads=num_threads
- )
-
- try:
- self.interpreter.allocate_tensors()
- except RuntimeError:
- self.interpreter = interpreter_wrapper.Interpreter(
- model_path=filename, num_threads=num_threads
- )
- self.interpreter.allocate_tensors()
-
- # Get input and output tensors.
- self.input_details = self.interpreter.get_input_details()
- self.output_details = self.interpreter.get_output_details()
- details = list(self.input_details) + list(self.output_details)
- self.handle_from_name = {d["name"]: d["index"] for d in details}
- self.shape_from_name = {d["name"]: d["shape"] for d in details}
- self.batch_size = next(iter(self.shape_from_name.values()))[0]
-
- def __call__(self, named_input: dict) -> dict:
- """Execute the model on one or a batch of named inputs \
- (a dict of name: numpy array)."""
- input_len = next(iter(named_input.values())).shape[0]
- full_steps = input_len // self.batch_size
- remainder = input_len % self.batch_size
-
- named_ys = defaultdict(list)
- for i in range(full_steps):
- for name, x_batch in named_input.items():
- x_tensor = x_batch[i : i + self.batch_size] # noqa: E203
- self.interpreter.set_tensor(self.handle_from_name[name], x_tensor)
- self.interpreter.invoke()
- for output_detail in self.output_details:
- named_ys[output_detail["name"]].append(
- self.interpreter.get_tensor(output_detail["index"])
- )
- if remainder:
- for name, x_batch in named_input.items():
- x_tensor = np.zeros( # pylint: disable=invalid-name
- self.shape_from_name[name]
- ).astype(x_batch.dtype)
- x_tensor[:remainder] = x_batch[-remainder:]
- self.interpreter.set_tensor(self.handle_from_name[name], x_tensor)
- self.interpreter.invoke()
- for output_detail in self.output_details:
- named_ys[output_detail["name"]].append(
- self.interpreter.get_tensor(output_detail["index"])[:remainder]
- )
- return {k: np.concatenate(v) for k, v in named_ys.items()}
-
- def input_tensors(self) -> list:
- """Return name from input details."""
- return [d["name"] for d in self.input_details]
-
- def output_tensors(self) -> list:
- """Return name from output details."""
- return [d["name"] for d in self.output_details]
-
-
def sample_tfrec(input_file: str, k: int, output_file: str) -> None:
"""Count, read and write TFRecord input and output data."""
total = numpytf_count(input_file)
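
The counting fallback pairs with the writer's .meta sidecar; a round-trip sketch (NumpyTFWriter's write() method name and context-manager use are assumptions):

    import numpy as np
    from mlia.nn.rewrite.core.utils.numpy_tfrecord import (
        NumpyTFWriter,
        numpytf_count,
        numpytf_read,
    )

    with NumpyTFWriter("data.tfrec") as writer:  # assumed API
        for _ in range(4):
            writer.write({"input": np.zeros((1, 8), dtype=np.float32)})

    assert numpytf_count("data.tfrec") == 4  # reads data.tfrec.meta
    for record in numpytf_read("data.tfrec"):
        print(record["input"].shape)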
diff --git a/src/mlia/nn/rewrite/core/utils/parallel.py b/src/mlia/nn/rewrite/core/utils/parallel.py
index d930a1e..b7b390d 100644
--- a/src/mlia/nn/rewrite/core/utils/parallel.py
+++ b/src/mlia/nn/rewrite/core/utils/parallel.py
@@ -15,14 +15,14 @@ from typing import Any
import numpy as np
import tensorflow as tf
-from mlia.nn.rewrite.core.utils.numpy_tfrecord import TFLiteModel
+from mlia.nn.tensorflow.config import TFLiteModel
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
logger = logging.getLogger(__name__)
-class ParallelTFLiteModel(TFLiteModel):
+class ParallelTFLiteModel(TFLiteModel): # pylint: disable=abstract-method
"""A parallel version of a TFLiteModel.
num_procs: 0 => detect real cores on system
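
Construction mirrors the plain TFLiteModel it now subclasses; a sketch with a placeholder model path (the input dtype is an assumption):

    import numpy as np
    from mlia.nn.rewrite.core.utils.parallel import ParallelTFLiteModel

    # num_procs=0 asks the wrapper to detect real cores, per the
    # docstring above.
    teacher = ParallelTFLiteModel("replace.tflite", 0, 1, batch_size=32)
    input_name = teacher.input_tensors()[0]
    batch = np.zeros(teacher.shape_from_name[input_name], dtype=np.float32)
    outputs = teacher({input_name: batch})  # dict of name -> numpy array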
diff --git a/src/mlia/nn/rewrite/core/utils/utils.py b/src/mlia/nn/rewrite/core/utils/utils.py
deleted file mode 100644
index ddf0cc2..0000000
--- a/src/mlia/nn/rewrite/core/utils/utils.py
+++ /dev/null
@@ -1,32 +0,0 @@
-# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
-# SPDX-License-Identifier: Apache-2.0
-"""Model and file system utilites."""
-from __future__ import annotations
-
-from pathlib import Path
-
-import flatbuffers
-from tensorflow.lite.python.schema_py_generated import Model
-from tensorflow.lite.python.schema_py_generated import ModelT
-
-
-def load(input_tflite_file: str | Path) -> ModelT:
- """Load a flatbuffer model from file."""
- if not Path(input_tflite_file).exists():
- raise FileNotFoundError(f"TFLite file not found at {input_tflite_file}\n")
- with open(input_tflite_file, "rb") as file_handle:
- file_data = bytearray(file_handle.read())
- model_obj = Model.GetRootAsModel(file_data, 0)
- model = ModelT.InitFromObj(model_obj)
- return model
-
-
-def save(model: ModelT, output_tflite_file: str | Path) -> None:
- """Save a flatbuffer model to a given file."""
- builder = flatbuffers.Builder(1024) # Initial size of the buffer, which
- # will grow automatically if needed
- model_offset = model.Pack(builder)
- builder.Finish(model_offset, file_identifier=b"TFL3")
- model_data = builder.Output()
- with open(output_tflite_file, "wb") as out_file:
- out_file.write(model_data)
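
The deleted helpers live on as load_fb/save_fb in mlia.nn.tensorflow.tflite_graph; a round-trip sketch with a placeholder path:

    from mlia.nn.tensorflow.tflite_graph import load_fb, save_fb

    model = load_fb("model.tflite")        # ModelT object API, mutable
    model.description = b"edited by mlia"  # flatbuffer strings are bytes
    save_fb(model, "model_edited.tflite")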
diff --git a/src/mlia/nn/rewrite/library/fc_layer.py b/src/mlia/nn/rewrite/library/fc_layer.py
index 8704154..2480500 100644
--- a/src/mlia/nn/rewrite/library/fc_layer.py
+++ b/src/mlia/nn/rewrite/library/fc_layer.py
@@ -7,12 +7,12 @@ import tensorflow as tf
def get_keras_model(input_shape: Any, output_shape: Any) -> tf.keras.Model:
- """Generate tflite model for rewrite."""
- input_tensor = tf.keras.layers.Input(
-        shape=input_shape, name="MobileNet/avg_pool/AvgPool"
+ """Generate TensorFlow Lite model for rewrite."""
+ model = tf.keras.Sequential(
+ (
+ tf.keras.layers.InputLayer(input_shape=input_shape),
+ tf.keras.layers.Reshape([-1]),
+ tf.keras.layers.Dense(output_shape),
+ )
)
- output_tensor = tf.keras.layers.Dense(output_shape, name="MobileNet/fc1/BiasAdd")(
- input_tensor
- )
- model = tf.keras.Model(input_tensor, output_tensor)
return model
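
A quick check of the replacement builder (shapes chosen to mimic a MobileNet classifier head; both values are placeholders):

    from mlia.nn.rewrite.library.fc_layer import get_keras_model

    model = get_keras_model(input_shape=(1, 1, 1024), output_shape=1001)
    model.summary()  # Reshape flattens to 1024, Dense projects to 1001
    assert model.output_shape == (None, 1001)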