diff options
author | Benjamin Klimczak <benjamin.klimczak@arm.com> | 2023-07-19 16:35:57 +0100 |
---|---|---|
committer | Benjamin Klimczak <benjamin.klimczak@arm.com> | 2023-10-11 16:06:17 +0100 |
commit | 3cd84481fa25e64c29e57396d4bf32d7a3ca490a (patch) | |
tree | ad81fb520a965bd3a3c7c983833b7cd48f9b8dea /src/mlia/nn/tensorflow/config.py | |
parent | f3e6597dd50ec70f043d692b773f2d9fd31519ae (diff) | |
download | mlia-3cd84481fa25e64c29e57396d4bf32d7a3ca490a.tar.gz |
Bug-fixes and re-factoring for the rewrite module
- Fix input shape of rewrite replacement:
During and after training of the replacement model for a rewrite the
Keras model is converted and saved in TensorFlow Lite format. If the
input shape does not match the teacher model exactly, e.g. if the
batch size is undefined, the TFLiteConverter adds extra operators
during conversion.
- Fix rewritten model output
- Save the model output with the rewritten operator in the output dir
- Log MAE and NRMSE of the rewrite
- Remove 'verbose' flag from rewrite module and rely on the logging
mechanism to control verbose output.
- Re-factor utility classes for rewrites
- Merge the two TFLiteModel classes
- Move functionality to load/save TensorFlow Lite flatbuffers to
nn/tensorflow/tflite_graph
- Fix issue with unknown shape in datasets
After upgrading to TensorFlow 2.12 the unknown shape of the
TFRecordDataset is causing problems when training the replacement models
for rewrites. By explicitly setting the right shape of the tensors we
can work around the issue.
- Adapt default parameters for rewrites. The training steps especially
had to be increased significantly to be effective.
Resolves: MLIA-895, MLIA-907, MLIA-946, MLIA-979
Signed-off-by: Benjamin Klimczak <benjamin.klimczak@arm.com>
Change-Id: I887ad165aed0f2c6e5a0041f64cec5e6c5ab5c5c
Diffstat (limited to 'src/mlia/nn/tensorflow/config.py')
-rw-r--r-- | src/mlia/nn/tensorflow/config.py | 100 |
1 files changed, 91 insertions, 9 deletions
diff --git a/src/mlia/nn/tensorflow/config.py b/src/mlia/nn/tensorflow/config.py index d7d430f..c6a7c88 100644 --- a/src/mlia/nn/tensorflow/config.py +++ b/src/mlia/nn/tensorflow/config.py @@ -4,13 +4,16 @@ from __future__ import annotations import logging +import tempfile +from collections import defaultdict from pathlib import Path -from typing import cast -from typing import List +import numpy as np import tensorflow as tf from mlia.core.context import Context +from mlia.nn.tensorflow.tflite_graph import load_fb +from mlia.nn.tensorflow.tflite_graph import save_fb from mlia.nn.tensorflow.utils import convert_to_tflite from mlia.nn.tensorflow.utils import is_keras_model from mlia.nn.tensorflow.utils import is_saved_model @@ -71,10 +74,89 @@ class KerasModel(ModelConfiguration): class TFLiteModel(ModelConfiguration): # pylint: disable=abstract-method """TensorFlow Lite model configuration.""" - def input_details(self) -> list[dict]: - """Get model's input details.""" - interpreter = tf.lite.Interpreter(model_path=self.model_path) - return cast(List[dict], interpreter.get_input_details()) + def __init__( + self, + model_path: str | Path, + batch_size: int | None = None, + num_threads: int | None = None, + ) -> None: + """Initiate a TFLite Model.""" + super().__init__(model_path) + if not num_threads: + num_threads = None + if not batch_size: + self.interpreter = tf.lite.Interpreter( + model_path=self.model_path, num_threads=num_threads + ) + else: # if a batch size is specified, modify the TFLite model to use this size + with tempfile.TemporaryDirectory() as tmp: + flatbuffer = load_fb(self.model_path) + for subgraph in flatbuffer.subgraphs: + for tensor in list(subgraph.inputs) + list(subgraph.outputs): + subgraph.tensors[tensor].shape = np.array( + [batch_size] + list(subgraph.tensors[tensor].shape[1:]), + dtype=np.int32, + ) + tempname = Path(tmp, "rewrite_tmp.tflite") + save_fb(flatbuffer, tempname) + self.interpreter = tf.lite.Interpreter( + model_path=str(tempname), num_threads=num_threads + ) + + try: + self.interpreter.allocate_tensors() + except RuntimeError: + self.interpreter = tf.lite.Interpreter( + model_path=self.model_path, num_threads=num_threads + ) + self.interpreter.allocate_tensors() + + # Get input and output tensors. + self.input_details = self.interpreter.get_input_details() + self.output_details = self.interpreter.get_output_details() + details = list(self.input_details) + list(self.output_details) + self.handle_from_name = {d["name"]: d["index"] for d in details} + self.shape_from_name = {d["name"]: d["shape"] for d in details} + self.batch_size = next(iter(self.shape_from_name.values()))[0] + + def __call__(self, named_input: dict) -> dict: + """Execute the model on one or a batch of named inputs \ + (a dict of name: numpy array).""" + input_len = next(iter(named_input.values())).shape[0] + full_steps = input_len // self.batch_size + remainder = input_len % self.batch_size + + named_ys = defaultdict(list) + for i in range(full_steps): + for name, x_batch in named_input.items(): + x_tensor = x_batch[i : i + self.batch_size] # noqa: E203 + self.interpreter.set_tensor(self.handle_from_name[name], x_tensor) + self.interpreter.invoke() + for output_detail in self.output_details: + named_ys[output_detail["name"]].append( + self.interpreter.get_tensor(output_detail["index"]) + ) + if remainder: + for name, x_batch in named_input.items(): + x_tensor = np.zeros( # pylint: disable=invalid-name + self.shape_from_name[name] + ).astype(x_batch.dtype) + x_tensor[:remainder] = x_batch[-remainder:] + self.interpreter.set_tensor(self.handle_from_name[name], x_tensor) + self.interpreter.invoke() + for output_detail in self.output_details: + named_ys[output_detail["name"]].append( + self.interpreter.get_tensor(output_detail["index"])[:remainder] + ) + return {k: np.concatenate(v) for k, v in named_ys.items()} + + def input_tensors(self) -> list: + """Return name from input details.""" + return [d["name"] for d in self.input_details] + + def output_tensors(self) -> list: + """Return name from output details.""" + return [d["name"] for d in self.output_details] def convert_to_tflite( self, tflite_model_path: str | Path, quantized: bool = False @@ -118,10 +200,10 @@ def get_model(model: str | Path) -> ModelConfiguration: def get_tflite_model(model: str | Path, ctx: Context) -> TFLiteModel: """Convert input model to TensorFlow Lite and returns TFLiteModel object.""" - tflite_model_path = ctx.get_model_path("converted_model.tflite") - converted_model = get_model(model) + dst_model_path = ctx.get_model_path("converted_model.tflite") + src_model = get_model(model) - return converted_model.convert_to_tflite(tflite_model_path, True) + return src_model.convert_to_tflite(dst_model_path, quantized=True) def get_keras_model(model: str | Path, ctx: Context) -> KerasModel: |