aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/nn/rewrite/core/utils/numpy_tfrecord.py
diff options
context:
space:
mode:
authorBenjamin Klimczak <benjamin.klimczak@arm.com>2023-07-19 16:35:57 +0100
committerBenjamin Klimczak <benjamin.klimczak@arm.com>2023-10-11 16:06:17 +0100
commit3cd84481fa25e64c29e57396d4bf32d7a3ca490a (patch)
treead81fb520a965bd3a3c7c983833b7cd48f9b8dea /src/mlia/nn/rewrite/core/utils/numpy_tfrecord.py
parentf3e6597dd50ec70f043d692b773f2d9fd31519ae (diff)
downloadmlia-3cd84481fa25e64c29e57396d4bf32d7a3ca490a.tar.gz
Bug-fixes and re-factoring for the rewrite module
- Fix input shape of rewrite replacement: During and after training of the replacement model for a rewrite the Keras model is converted and saved in TensorFlow Lite format. If the input shape does not match the teacher model exactly, e.g. if the batch size is undefined, the TFLiteConverter adds extra operators during conversion. - Fix rewritten model output - Save the model output with the rewritten operator in the output dir - Log MAE and NRMSE of the rewrite - Remove 'verbose' flag from rewrite module and rely on the logging mechanism to control verbose output. - Re-factor utility classes for rewrites - Merge the two TFLiteModel classes - Move functionality to load/save TensorFlow Lite flatbuffers to nn/tensorflow/tflite_graph - Fix issue with unknown shape in datasets After upgrading to TensorFlow 2.12 the unknown shape of the TFRecordDataset is causing problems when training the replacement models for rewrites. By explicitly setting the right shape of the tensors we can work around the issue. - Adapt default parameters for rewrites. The training steps especially had to be increased significantly to be effective. Resolves: MLIA-895, MLIA-907, MLIA-946, MLIA-979 Signed-off-by: Benjamin Klimczak <benjamin.klimczak@arm.com> Change-Id: I887ad165aed0f2c6e5a0041f64cec5e6c5ab5c5c
Diffstat (limited to 'src/mlia/nn/rewrite/core/utils/numpy_tfrecord.py')
-rw-r--r--src/mlia/nn/rewrite/core/utils/numpy_tfrecord.py138
1 files changed, 26 insertions, 112 deletions
diff --git a/src/mlia/nn/rewrite/core/utils/numpy_tfrecord.py b/src/mlia/nn/rewrite/core/utils/numpy_tfrecord.py
index 9229810..38ac1ed 100644
--- a/src/mlia/nn/rewrite/core/utils/numpy_tfrecord.py
+++ b/src/mlia/nn/rewrite/core/utils/numpy_tfrecord.py
@@ -6,55 +6,56 @@ from __future__ import annotations
import json
import os
import random
-import tempfile
-from collections import defaultdict
+from functools import lru_cache
from pathlib import Path
from typing import Any
from typing import Callable
-import numpy as np
import tensorflow as tf
-from tensorflow.lite.python import interpreter as interpreter_wrapper
-from mlia.nn.rewrite.core.utils.utils import load
-from mlia.nn.rewrite.core.utils.utils import save
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
-def make_decode_fn(filename: str) -> Callable:
- """Make decode filename."""
+def decode_fn(record_bytes: Any, type_map: dict) -> dict:
+ """Decode the given bytes into a name-tensor dict assuming the given type."""
+ parse_dict = {
+ name: tf.io.FixedLenFeature([], tf.string) for name in type_map.keys()
+ }
+ example = tf.io.parse_single_example(record_bytes, parse_dict)
+ features = {
+ n: tf.io.parse_tensor(example[n], tf.as_dtype(t)) for n, t in type_map.items()
+ }
+ return features
- def decode_fn(record_bytes: Any, type_map: dict) -> dict:
- parse_dict = {
- name: tf.io.FixedLenFeature([], tf.string) for name in type_map.keys()
- }
- example = tf.io.parse_single_example(record_bytes, parse_dict)
- features = {
- n: tf.io.parse_tensor(example[n], tf.as_dtype(t))
- for n, t in type_map.items()
- }
- return features
+def make_decode_fn(filename: str, model_filename: str | Path | None = None) -> Callable:
+ """Make decode filename."""
meta_filename = filename + ".meta"
- with open(meta_filename, encoding="utf-8") as file:
- type_map = json.load(file)["type_map"]
+ try:
+ with open(meta_filename, encoding="utf-8") as file:
+ type_map = json.load(file)["type_map"]
return lambda record_bytes: decode_fn(record_bytes, type_map)
def numpytf_read(filename: str | Path) -> Any:
"""Read TFRecord dataset."""
- decode_fn = make_decode_fn(str(filename))
+ decode = make_decode_fn(str(filename))
dataset = tf.data.TFRecordDataset(str(filename))
- return dataset.map(decode_fn)
+ return dataset.map(decode)
-def numpytf_count(filename: str | Path) -> Any:
+@lru_cache
+def numpytf_count(filename: str | Path) -> int:
"""Return count from TFRecord file."""
meta_filename = f"{filename}.meta"
- with open(meta_filename, encoding="utf-8") as file:
- return json.load(file)["count"]
+ try:
+ with open(meta_filename, encoding="utf-8") as file:
+ return int(json.load(file)["count"])
+ except FileNotFoundError:
+ raw_dataset = tf.data.TFRecordDataset(filename)
+ return sum(1 for _ in raw_dataset)
class NumpyTFWriter:
@@ -101,93 +102,6 @@ class NumpyTFWriter:
self.writer.close()
-class TFLiteModel:
- """A representation of a TFLite Model."""
-
- def __init__(
- self,
- filename: str,
- batch_size: int | None = None,
- num_threads: int | None = None,
- ) -> None:
- """Initiate a TFLite Model."""
- if not num_threads:
- num_threads = None
- if not batch_size:
- self.interpreter = interpreter_wrapper.Interpreter(
- model_path=filename, num_threads=num_threads
- )
- else: # if a batch size is specified, modify the TFLite model to use this size
- with tempfile.TemporaryDirectory() as tmp:
- flatbuffer = load(filename)
- for subgraph in flatbuffer.subgraphs:
- for tensor in list(subgraph.inputs) + list(subgraph.outputs):
- subgraph.tensors[tensor].shape = np.array(
- [batch_size] + list(subgraph.tensors[tensor].shape[1:]),
- dtype=np.int32,
- )
- tempname = os.path.join(tmp, "rewrite_tmp.tflite")
- save(flatbuffer, tempname)
- self.interpreter = interpreter_wrapper.Interpreter(
- model_path=tempname, num_threads=num_threads
- )
-
- try:
- self.interpreter.allocate_tensors()
- except RuntimeError:
- self.interpreter = interpreter_wrapper.Interpreter(
- model_path=filename, num_threads=num_threads
- )
- self.interpreter.allocate_tensors()
-
- # Get input and output tensors.
- self.input_details = self.interpreter.get_input_details()
- self.output_details = self.interpreter.get_output_details()
- details = list(self.input_details) + list(self.output_details)
- self.handle_from_name = {d["name"]: d["index"] for d in details}
- self.shape_from_name = {d["name"]: d["shape"] for d in details}
- self.batch_size = next(iter(self.shape_from_name.values()))[0]
-
- def __call__(self, named_input: dict) -> dict:
- """Execute the model on one or a batch of named inputs \
- (a dict of name: numpy array)."""
- input_len = next(iter(named_input.values())).shape[0]
- full_steps = input_len // self.batch_size
- remainder = input_len % self.batch_size
-
- named_ys = defaultdict(list)
- for i in range(full_steps):
- for name, x_batch in named_input.items():
- x_tensor = x_batch[i : i + self.batch_size] # noqa: E203
- self.interpreter.set_tensor(self.handle_from_name[name], x_tensor)
- self.interpreter.invoke()
- for output_detail in self.output_details:
- named_ys[output_detail["name"]].append(
- self.interpreter.get_tensor(output_detail["index"])
- )
- if remainder:
- for name, x_batch in named_input.items():
- x_tensor = np.zeros( # pylint: disable=invalid-name
- self.shape_from_name[name]
- ).astype(x_batch.dtype)
- x_tensor[:remainder] = x_batch[-remainder:]
- self.interpreter.set_tensor(self.handle_from_name[name], x_tensor)
- self.interpreter.invoke()
- for output_detail in self.output_details:
- named_ys[output_detail["name"]].append(
- self.interpreter.get_tensor(output_detail["index"])[:remainder]
- )
- return {k: np.concatenate(v) for k, v in named_ys.items()}
-
- def input_tensors(self) -> list:
- """Return name from input details."""
- return [d["name"] for d in self.input_details]
-
- def output_tensors(self) -> list:
- """Return name from output details."""
- return [d["name"] for d in self.output_details]
-
-
def sample_tfrec(input_file: str, k: int, output_file: str) -> None:
"""Count, read and write TFRecord input and output data."""
total = numpytf_count(input_file)