aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/nn/rewrite/core/utils
diff options
context:
space:
mode:
Diffstat (limited to 'src/mlia/nn/rewrite/core/utils')
-rw-r--r--src/mlia/nn/rewrite/core/utils/numpy_tfrecord.py138
-rw-r--r--src/mlia/nn/rewrite/core/utils/parallel.py4
-rw-r--r--src/mlia/nn/rewrite/core/utils/utils.py32
3 files changed, 28 insertions, 146 deletions
diff --git a/src/mlia/nn/rewrite/core/utils/numpy_tfrecord.py b/src/mlia/nn/rewrite/core/utils/numpy_tfrecord.py
index 9229810..38ac1ed 100644
--- a/src/mlia/nn/rewrite/core/utils/numpy_tfrecord.py
+++ b/src/mlia/nn/rewrite/core/utils/numpy_tfrecord.py
@@ -6,55 +6,56 @@ from __future__ import annotations
import json
import os
import random
-import tempfile
-from collections import defaultdict
+from functools import lru_cache
from pathlib import Path
from typing import Any
from typing import Callable
-import numpy as np
import tensorflow as tf
-from tensorflow.lite.python import interpreter as interpreter_wrapper
-from mlia.nn.rewrite.core.utils.utils import load
-from mlia.nn.rewrite.core.utils.utils import save
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
-def make_decode_fn(filename: str) -> Callable:
- """Make decode filename."""
+def decode_fn(record_bytes: Any, type_map: dict) -> dict:
+ """Decode the given bytes into a name-tensor dict assuming the given type."""
+ parse_dict = {
+ name: tf.io.FixedLenFeature([], tf.string) for name in type_map.keys()
+ }
+ example = tf.io.parse_single_example(record_bytes, parse_dict)
+ features = {
+ n: tf.io.parse_tensor(example[n], tf.as_dtype(t)) for n, t in type_map.items()
+ }
+ return features
- def decode_fn(record_bytes: Any, type_map: dict) -> dict:
- parse_dict = {
- name: tf.io.FixedLenFeature([], tf.string) for name in type_map.keys()
- }
- example = tf.io.parse_single_example(record_bytes, parse_dict)
- features = {
- n: tf.io.parse_tensor(example[n], tf.as_dtype(t))
- for n, t in type_map.items()
- }
- return features
+def make_decode_fn(filename: str, model_filename: str | Path | None = None) -> Callable:
+ """Make decode filename."""
meta_filename = filename + ".meta"
- with open(meta_filename, encoding="utf-8") as file:
- type_map = json.load(file)["type_map"]
+ try:
+ with open(meta_filename, encoding="utf-8") as file:
+ type_map = json.load(file)["type_map"]
return lambda record_bytes: decode_fn(record_bytes, type_map)
def numpytf_read(filename: str | Path) -> Any:
"""Read TFRecord dataset."""
- decode_fn = make_decode_fn(str(filename))
+ decode = make_decode_fn(str(filename))
dataset = tf.data.TFRecordDataset(str(filename))
- return dataset.map(decode_fn)
+ return dataset.map(decode)
-def numpytf_count(filename: str | Path) -> Any:
+@lru_cache
+def numpytf_count(filename: str | Path) -> int:
"""Return count from TFRecord file."""
meta_filename = f"{filename}.meta"
- with open(meta_filename, encoding="utf-8") as file:
- return json.load(file)["count"]
+ try:
+ with open(meta_filename, encoding="utf-8") as file:
+ return int(json.load(file)["count"])
+ except FileNotFoundError:
+ raw_dataset = tf.data.TFRecordDataset(filename)
+ return sum(1 for _ in raw_dataset)
class NumpyTFWriter:
@@ -101,93 +102,6 @@ class NumpyTFWriter:
self.writer.close()
-class TFLiteModel:
- """A representation of a TFLite Model."""
-
- def __init__(
- self,
- filename: str,
- batch_size: int | None = None,
- num_threads: int | None = None,
- ) -> None:
- """Initiate a TFLite Model."""
- if not num_threads:
- num_threads = None
- if not batch_size:
- self.interpreter = interpreter_wrapper.Interpreter(
- model_path=filename, num_threads=num_threads
- )
- else: # if a batch size is specified, modify the TFLite model to use this size
- with tempfile.TemporaryDirectory() as tmp:
- flatbuffer = load(filename)
- for subgraph in flatbuffer.subgraphs:
- for tensor in list(subgraph.inputs) + list(subgraph.outputs):
- subgraph.tensors[tensor].shape = np.array(
- [batch_size] + list(subgraph.tensors[tensor].shape[1:]),
- dtype=np.int32,
- )
- tempname = os.path.join(tmp, "rewrite_tmp.tflite")
- save(flatbuffer, tempname)
- self.interpreter = interpreter_wrapper.Interpreter(
- model_path=tempname, num_threads=num_threads
- )
-
- try:
- self.interpreter.allocate_tensors()
- except RuntimeError:
- self.interpreter = interpreter_wrapper.Interpreter(
- model_path=filename, num_threads=num_threads
- )
- self.interpreter.allocate_tensors()
-
- # Get input and output tensors.
- self.input_details = self.interpreter.get_input_details()
- self.output_details = self.interpreter.get_output_details()
- details = list(self.input_details) + list(self.output_details)
- self.handle_from_name = {d["name"]: d["index"] for d in details}
- self.shape_from_name = {d["name"]: d["shape"] for d in details}
- self.batch_size = next(iter(self.shape_from_name.values()))[0]
-
- def __call__(self, named_input: dict) -> dict:
- """Execute the model on one or a batch of named inputs \
- (a dict of name: numpy array)."""
- input_len = next(iter(named_input.values())).shape[0]
- full_steps = input_len // self.batch_size
- remainder = input_len % self.batch_size
-
- named_ys = defaultdict(list)
- for i in range(full_steps):
- for name, x_batch in named_input.items():
- x_tensor = x_batch[i : i + self.batch_size] # noqa: E203
- self.interpreter.set_tensor(self.handle_from_name[name], x_tensor)
- self.interpreter.invoke()
- for output_detail in self.output_details:
- named_ys[output_detail["name"]].append(
- self.interpreter.get_tensor(output_detail["index"])
- )
- if remainder:
- for name, x_batch in named_input.items():
- x_tensor = np.zeros( # pylint: disable=invalid-name
- self.shape_from_name[name]
- ).astype(x_batch.dtype)
- x_tensor[:remainder] = x_batch[-remainder:]
- self.interpreter.set_tensor(self.handle_from_name[name], x_tensor)
- self.interpreter.invoke()
- for output_detail in self.output_details:
- named_ys[output_detail["name"]].append(
- self.interpreter.get_tensor(output_detail["index"])[:remainder]
- )
- return {k: np.concatenate(v) for k, v in named_ys.items()}
-
- def input_tensors(self) -> list:
- """Return name from input details."""
- return [d["name"] for d in self.input_details]
-
- def output_tensors(self) -> list:
- """Return name from output details."""
- return [d["name"] for d in self.output_details]
-
-
def sample_tfrec(input_file: str, k: int, output_file: str) -> None:
"""Count, read and write TFRecord input and output data."""
total = numpytf_count(input_file)
diff --git a/src/mlia/nn/rewrite/core/utils/parallel.py b/src/mlia/nn/rewrite/core/utils/parallel.py
index d930a1e..b7b390d 100644
--- a/src/mlia/nn/rewrite/core/utils/parallel.py
+++ b/src/mlia/nn/rewrite/core/utils/parallel.py
@@ -15,14 +15,14 @@ from typing import Any
import numpy as np
import tensorflow as tf
-from mlia.nn.rewrite.core.utils.numpy_tfrecord import TFLiteModel
+from mlia.nn.tensorflow.config import TFLiteModel
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
logger = logging.getLogger(__name__)
-class ParallelTFLiteModel(TFLiteModel):
+class ParallelTFLiteModel(TFLiteModel): # pylint: disable=abstract-method
"""A parallel version of a TFLiteModel.
num_procs: 0 => detect real cores on system
diff --git a/src/mlia/nn/rewrite/core/utils/utils.py b/src/mlia/nn/rewrite/core/utils/utils.py
deleted file mode 100644
index ddf0cc2..0000000
--- a/src/mlia/nn/rewrite/core/utils/utils.py
+++ /dev/null
@@ -1,32 +0,0 @@
-# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
-# SPDX-License-Identifier: Apache-2.0
-"""Model and file system utilites."""
-from __future__ import annotations
-
-from pathlib import Path
-
-import flatbuffers
-from tensorflow.lite.python.schema_py_generated import Model
-from tensorflow.lite.python.schema_py_generated import ModelT
-
-
-def load(input_tflite_file: str | Path) -> ModelT:
- """Load a flatbuffer model from file."""
- if not Path(input_tflite_file).exists():
- raise FileNotFoundError(f"TFLite file not found at {input_tflite_file}\n")
- with open(input_tflite_file, "rb") as file_handle:
- file_data = bytearray(file_handle.read())
- model_obj = Model.GetRootAsModel(file_data, 0)
- model = ModelT.InitFromObj(model_obj)
- return model
-
-
-def save(model: ModelT, output_tflite_file: str | Path) -> None:
- """Save a flatbuffer model to a given file."""
- builder = flatbuffers.Builder(1024) # Initial size of the buffer, which
- # will grow automatically if needed
- model_offset = model.Pack(builder)
- builder.Finish(model_offset, file_identifier=b"TFL3")
- model_data = builder.Output()
- with open(output_tflite_file, "wb") as out_file:
- out_file.write(model_data)