Diffstat (limited to 'src/mlia/nn/tensorflow/config.py')
 src/mlia/nn/tensorflow/config.py | 100 ++++++++++++++++++++++++++++++++++----
 1 file changed, 91 insertions(+), 9 deletions(-)
diff --git a/src/mlia/nn/tensorflow/config.py b/src/mlia/nn/tensorflow/config.py
index d7d430f..c6a7c88 100644
--- a/src/mlia/nn/tensorflow/config.py
+++ b/src/mlia/nn/tensorflow/config.py
@@ -4,13 +4,16 @@
 from __future__ import annotations

 import logging
+import tempfile
+from collections import defaultdict
 from pathlib import Path
-from typing import cast
-from typing import List

+import numpy as np
 import tensorflow as tf

 from mlia.core.context import Context
+from mlia.nn.tensorflow.tflite_graph import load_fb
+from mlia.nn.tensorflow.tflite_graph import save_fb
 from mlia.nn.tensorflow.utils import convert_to_tflite
 from mlia.nn.tensorflow.utils import is_keras_model
 from mlia.nn.tensorflow.utils import is_saved_model
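The new tflite_graph imports expose the .tflite flatbuffer as a mutable Python object. A minimal round-trip sketch, restricted to the attributes this patch actually touches (the file names are hypothetical stand-ins):

    from pathlib import Path

    from mlia.nn.tensorflow.tflite_graph import load_fb
    from mlia.nn.tensorflow.tflite_graph import save_fb

    # Parse the flatbuffer, print each subgraph's input shapes, then
    # write an unmodified copy back out.
    model = load_fb("model.tflite")  # hypothetical input file
    for subgraph in model.subgraphs:
        for tensor_idx in subgraph.inputs:
            print(subgraph.tensors[tensor_idx].shape)
    save_fb(model, Path("copy.tflite"))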
@@ -71,10 +74,89 @@ class KerasModel(ModelConfiguration):
 class TFLiteModel(ModelConfiguration):  # pylint: disable=abstract-method
     """TensorFlow Lite model configuration."""

-    def input_details(self) -> list[dict]:
-        """Get model's input details."""
-        interpreter = tf.lite.Interpreter(model_path=self.model_path)
-        return cast(List[dict], interpreter.get_input_details())
+    def __init__(
+        self,
+        model_path: str | Path,
+        batch_size: int | None = None,
+        num_threads: int | None = None,
+    ) -> None:
+        """Initialize a TFLite model configuration."""
+        super().__init__(model_path)
+        if not num_threads:
+            num_threads = None  # treat 0 as "let TFLite choose"
+        if not batch_size:
+            self.interpreter = tf.lite.Interpreter(
+                model_path=self.model_path, num_threads=num_threads
+            )
+        else:  # if a batch size is specified, rewrite the model to use it
+            with tempfile.TemporaryDirectory() as tmp:
+                flatbuffer = load_fb(self.model_path)
+                for subgraph in flatbuffer.subgraphs:
+                    for tensor in list(subgraph.inputs) + list(subgraph.outputs):
+                        subgraph.tensors[tensor].shape = np.array(
+                            [batch_size] + list(subgraph.tensors[tensor].shape[1:]),
+                            dtype=np.int32,
+                        )
+                tempname = Path(tmp, "rewrite_tmp.tflite")
+                save_fb(flatbuffer, tempname)
+                self.interpreter = tf.lite.Interpreter(
+                    model_path=str(tempname), num_threads=num_threads
+                )
+
+        try:
+            self.interpreter.allocate_tensors()
+        except RuntimeError:
+            # Fall back to the unmodified model if the batch-resized
+            # copy cannot allocate its tensors.
+            self.interpreter = tf.lite.Interpreter(
+                model_path=self.model_path, num_threads=num_threads
+            )
+            self.interpreter.allocate_tensors()
+
+        # Cache tensor details and map tensor names to handles and shapes.
+        self.input_details = self.interpreter.get_input_details()
+        self.output_details = self.interpreter.get_output_details()
+        details = list(self.input_details) + list(self.output_details)
+        self.handle_from_name = {d["name"]: d["index"] for d in details}
+        self.shape_from_name = {d["name"]: d["shape"] for d in details}
+        self.batch_size = next(iter(self.shape_from_name.values()))[0]
+
+    def __call__(self, named_input: dict) -> dict:
+        """Execute the model on named inputs (a dict of name: numpy array)."""
+        input_len = next(iter(named_input.values())).shape[0]
+        full_steps = input_len // self.batch_size
+        remainder = input_len % self.batch_size
+
+        named_ys = defaultdict(list)
+        for i in range(full_steps):
+            for name, x_batch in named_input.items():
+                x_tensor = x_batch[
+                    i * self.batch_size : (i + 1) * self.batch_size  # noqa: E203
+                ]
+                self.interpreter.set_tensor(self.handle_from_name[name], x_tensor)
+            self.interpreter.invoke()
+            for output_detail in self.output_details:
+                named_ys[output_detail["name"]].append(
+                    self.interpreter.get_tensor(output_detail["index"])
+                )
+        if remainder:
+            # Zero-pad the final partial batch, run it, and keep only
+            # the first `remainder` rows of each output.
+            for name, x_batch in named_input.items():
+                x_tensor = np.zeros(  # pylint: disable=invalid-name
+                    self.shape_from_name[name]
+                ).astype(x_batch.dtype)
+                x_tensor[:remainder] = x_batch[-remainder:]
+                self.interpreter.set_tensor(self.handle_from_name[name], x_tensor)
+            self.interpreter.invoke()
+            for output_detail in self.output_details:
+                named_ys[output_detail["name"]].append(
+                    self.interpreter.get_tensor(output_detail["index"])[:remainder]
+                )
+        return {k: np.concatenate(v) for k, v in named_ys.items()}
+
+    def input_tensors(self) -> list:
+        """Return the names of the model's input tensors."""
+        return [d["name"] for d in self.input_details]
+
+    def output_tensors(self) -> list:
+        """Return the names of the model's output tensors."""
+        return [d["name"] for d in self.output_details]

     def convert_to_tflite(
         self, tflite_model_path: str | Path, quantized: bool = False
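What the new batched interface looks like in use, as a minimal sketch; the file name model.tflite, the 224x224x3 input shape, and the sample count are hypothetical stand-ins, not part of this patch:

    import numpy as np

    from mlia.nn.tensorflow.config import TFLiteModel

    # Force a batch size of 8; the constructor rewrites the input/output
    # shapes in a temporary copy of the flatbuffer before loading it.
    model = TFLiteModel("model.tflite", batch_size=8)  # hypothetical file

    # 20 samples against a batch size of 8: two full steps, then one
    # zero-padded remainder step of 4; outputs are concatenated back to 20.
    x = np.random.rand(20, 224, 224, 3).astype(np.float32)
    outputs = model({model.input_tensors()[0]: x})
    for name, y in outputs.items():
        print(name, y.shape)  # leading dimension is 20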
@@ -118,10 +200,10 @@ def get_model(model: str | Path) -> ModelConfiguration:

 def get_tflite_model(model: str | Path, ctx: Context) -> TFLiteModel:
     """Convert the input model to TensorFlow Lite and return a TFLiteModel."""
-    tflite_model_path = ctx.get_model_path("converted_model.tflite")
-    converted_model = get_model(model)
+    dst_model_path = ctx.get_model_path("converted_model.tflite")
+    src_model = get_model(model)

-    return converted_model.convert_to_tflite(tflite_model_path, True)
+    return src_model.convert_to_tflite(dst_model_path, quantized=True)


 def get_keras_model(model: str | Path, ctx: Context) -> KerasModel:
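Finally, a sketch of the renamed conversion helper in use. It assumes mlia.core.context.ExecutionContext can stand in for the Context parameter and that default construction suffices (check the local source; none of this is part of the patch), and the Keras file name is hypothetical:

    from mlia.core.context import ExecutionContext
    from mlia.nn.tensorflow.config import get_tflite_model

    ctx = ExecutionContext()  # assumption: defaults are enough here
    # Converts the source model, writes converted_model.tflite under the
    # context's model directory, and returns it wrapped in a TFLiteModel.
    tflite_model = get_tflite_model("my_model.h5", ctx)  # hypothetical file
    print(tflite_model.input_tensors(), tflite_model.output_tensors())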