aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/nn/rewrite/core/utils
diff options
context:
space:
mode:
Diffstat (limited to 'src/mlia/nn/rewrite/core/utils')
-rw-r--r--src/mlia/nn/rewrite/core/utils/__init__.py2
-rw-r--r--src/mlia/nn/rewrite/core/utils/numpy_tfrecord.py172
-rw-r--r--src/mlia/nn/rewrite/core/utils/parallel.py113
-rw-r--r--src/mlia/nn/rewrite/core/utils/utils.py26
4 files changed, 313 insertions, 0 deletions
diff --git a/src/mlia/nn/rewrite/core/utils/__init__.py b/src/mlia/nn/rewrite/core/utils/__init__.py
new file mode 100644
index 0000000..48b1622
--- /dev/null
+++ b/src/mlia/nn/rewrite/core/utils/__init__.py
@@ -0,0 +1,2 @@
+# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0 \ No newline at end of file
diff --git a/src/mlia/nn/rewrite/core/utils/numpy_tfrecord.py b/src/mlia/nn/rewrite/core/utils/numpy_tfrecord.py
new file mode 100644
index 0000000..ac3e875
--- /dev/null
+++ b/src/mlia/nn/rewrite/core/utils/numpy_tfrecord.py
@@ -0,0 +1,172 @@
+# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+import json
+import os
+import random
+import tempfile
+from collections import defaultdict
+
+import numpy as np
+
+from mlia.nn.rewrite.core.utils.utils import load
+from mlia.nn.rewrite.core.utils.utils import save
+
+os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
+import tensorflow as tf
+
+tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
+
+from tensorflow.lite.python import interpreter as interpreter_wrapper
+
+
+def make_decode_fn(filename):
+ def decode_fn(record_bytes, type_map):
+ parse_dict = {
+ name: tf.io.FixedLenFeature([], tf.string) for name in type_map.keys()
+ }
+ example = tf.io.parse_single_example(record_bytes, parse_dict)
+ features = {
+ n: tf.io.parse_tensor(example[n], tf.as_dtype(t))
+ for n, t in type_map.items()
+ }
+ return features
+
+ meta_filename = filename + ".meta"
+ with open(meta_filename) as f:
+ type_map = json.load(f)["type_map"]
+ return lambda record_bytes: decode_fn(record_bytes, type_map)
+
+
+def NumpyTFReader(filename):
+ decode_fn = make_decode_fn(filename)
+ dataset = tf.data.TFRecordDataset(filename)
+ return dataset.map(decode_fn)
+
+
+def numpytf_count(filename):
+ meta_filename = filename + ".meta"
+ with open(meta_filename) as f:
+ return json.load(f)["count"]
+
+
+class NumpyTFWriter:
+ def __init__(self, filename):
+ self.filename = filename
+ self.meta_filename = filename + ".meta"
+ self.writer = tf.io.TFRecordWriter(filename)
+ self.type_map = {}
+ self.count = 0
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, type, value, traceback):
+ self.close()
+
+ def __del__(self):
+ self.close()
+
+ def write(self, array_dict):
+ type_map = {n: str(a.dtype.name) for n, a in array_dict.items()}
+ self.type_map.update(type_map)
+ self.count += 1
+
+ feature = {
+ n: tf.train.Feature(
+ bytes_list=tf.train.BytesList(value=[tf.io.serialize_tensor(a).numpy()])
+ )
+ for n, a in array_dict.items()
+ }
+ example = tf.train.Example(features=tf.train.Features(feature=feature))
+ self.writer.write(example.SerializeToString())
+
+ def close(self):
+ with open(self.meta_filename, "w") as f:
+ meta = {"type_map": self.type_map, "count": self.count}
+ json.dump(meta, f)
+ self.writer.close()
+
+
+class TFLiteModel:
+ def __init__(self, filename, batch_size=None, num_threads=None):
+ if num_threads == 0:
+ num_threads = None
+ if batch_size == None:
+ self.interpreter = interpreter_wrapper.Interpreter(
+ model_path=filename, num_threads=num_threads
+ )
+ else: # if a batch size is specified, modify the TFLite model to use this size
+ with tempfile.TemporaryDirectory() as tmp:
+ fb = load(filename)
+ for sg in fb.subgraphs:
+ for t in list(sg.inputs) + list(sg.outputs):
+ sg.tensors[t].shape = np.array(
+ [batch_size] + list(sg.tensors[t].shape[1:]), dtype=np.int32
+ )
+ tempname = os.path.join(tmp, "rewrite_tmp.tflite")
+ save(fb, tempname)
+ self.interpreter = interpreter_wrapper.Interpreter(
+ model_path=tempname, num_threads=num_threads
+ )
+
+ try:
+ self.interpreter.allocate_tensors()
+ except RuntimeError:
+ self.interpreter = interpreter_wrapper.Interpreter(
+ model_path=filename, num_threads=num_threads
+ )
+ self.interpreter.allocate_tensors()
+
+ # Get input and output tensors.
+ self.input_details = self.interpreter.get_input_details()
+ self.output_details = self.interpreter.get_output_details()
+ details = list(self.input_details) + list(self.output_details)
+ self.handle_from_name = {d["name"]: d["index"] for d in details}
+ self.shape_from_name = {d["name"]: d["shape"] for d in details}
+ self.batch_size = next(iter(self.shape_from_name.values()))[0]
+
+ def __call__(self, named_input):
+ """Execute the model on one or a batch of named inputs (a dict of name: numpy array)"""
+ input_len = next(iter(named_input.values())).shape[0]
+ full_steps = input_len // self.batch_size
+ remainder = input_len % self.batch_size
+
+ named_ys = defaultdict(list)
+ for i in range(full_steps):
+ for name, x_batch in named_input.items():
+ x = x_batch[i : i + self.batch_size]
+ self.interpreter.set_tensor(self.handle_from_name[name], x)
+ self.interpreter.invoke()
+ for d in self.output_details:
+ named_ys[d["name"]].append(self.interpreter.get_tensor(d["index"]))
+ if remainder:
+ for name, x_batch in named_input.items():
+ x = np.zeros(self.shape_from_name[name]).astype(x_batch.dtype)
+ x[:remainder] = x_batch[-remainder:]
+ self.interpreter.set_tensor(self.handle_from_name[name], x)
+ self.interpreter.invoke()
+ for d in self.output_details:
+ named_ys[d["name"]].append(
+ self.interpreter.get_tensor(d["index"])[:remainder]
+ )
+ return {k: np.concatenate(v) for k, v in named_ys.items()}
+
+ def input_tensors(self):
+ return [d["name"] for d in self.input_details]
+
+ def output_tensors(self):
+ return [d["name"] for d in self.output_details]
+
+
+def sample_tfrec(input_file, k, output_file):
+ total = numpytf_count(input_file)
+ next = sorted(random.sample(range(total), k=k), reverse=True)
+
+ reader = NumpyTFReader(input_file)
+ with NumpyTFWriter(output_file) as writer:
+ for i, data in enumerate(reader):
+ if i == next[-1]:
+ next.pop()
+ writer.write(data)
+ if not next:
+ break
diff --git a/src/mlia/nn/rewrite/core/utils/parallel.py b/src/mlia/nn/rewrite/core/utils/parallel.py
new file mode 100644
index 0000000..5affc03
--- /dev/null
+++ b/src/mlia/nn/rewrite/core/utils/parallel.py
@@ -0,0 +1,113 @@
+# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+import math
+import os
+from collections import defaultdict
+from multiprocessing import Pool
+
+import numpy as np
+from psutil import cpu_count
+
+os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
+import tensorflow as tf
+
+tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
+
+from mlia.nn.rewrite.core.utils.numpy_tfrecord import TFLiteModel
+
+
+class ParallelTFLiteModel(TFLiteModel):
+ def __init__(self, filename, num_procs=1, num_threads=0, batch_size=None):
+ """num_procs: 0 => detect real cores on system
+ num_threads: 0 => TFLite impl. specific setting, usually 3
+ batch_size: None => automatic (num_procs or file-determined)
+ """
+ self.pool = None
+ self.filename = filename
+ if not num_procs:
+ self.num_procs = cpu_count(logical=False)
+ else:
+ self.num_procs = int(num_procs)
+
+ self.num_threads = num_threads
+
+ if self.num_procs > 1:
+ if not batch_size:
+ batch_size = self.num_procs # default to min effective batch size
+ local_batch_size = int(math.ceil(batch_size / self.num_procs))
+ super().__init__(filename, batch_size=local_batch_size)
+ del self.interpreter
+ self.pool = Pool(
+ processes=self.num_procs,
+ initializer=_pool_create_worker,
+ initargs=[filename, self.batch_size, self.num_threads],
+ )
+ else: # fall back to serial implementation for max performance
+ super().__init__(
+ filename, batch_size=batch_size, num_threads=self.num_threads
+ )
+
+ self.total_batches = 0
+ self.partial_batches = 0
+ self.warned = False
+
+ def close(self):
+ if self.pool:
+ self.pool.close()
+ self.pool.terminate()
+
+ def __del__(self):
+ self.close()
+
+ def __call__(self, named_input):
+ if self.pool:
+ global_batch_size = next(iter(named_input.values())).shape[0]
+ # Note: self.batch_size comes from superclass and is local batch size
+ chunks = int(math.ceil(global_batch_size / self.batch_size))
+ self.total_batches += 1
+ if chunks != self.num_procs:
+ self.partial_batches += 1
+ if (
+ not self.warned
+ and self.total_batches > 10
+ and self.partial_batches / self.total_batches >= 0.5
+ ):
+ print(
+ "ParallelTFLiteModel(%s): warning - %.1f%% of batches do not use all %d processes, set batch size to a multiple of this"
+ % (
+ self.filename,
+ 100 * self.partial_batches / self.total_batches,
+ self.num_procs,
+ )
+ )
+ self.warned = True
+
+ local_batches = [
+ {
+ key: values[i * self.batch_size : (i + 1) * self.batch_size]
+ for key, values in named_input.items()
+ }
+ for i in range(chunks)
+ ]
+ chunk_results = self.pool.map(_pool_run, local_batches)
+ named_ys = defaultdict(list)
+ for chunk in chunk_results:
+ for k, v in chunk.items():
+ named_ys[k].append(v)
+ return {k: np.concatenate(v) for k, v in named_ys.items()}
+ else:
+ return super().__call__(named_input)
+
+
+_local_model = None
+
+
+def _pool_create_worker(filename, local_batch_size=None, num_threads=None):
+ global _local_model
+ _local_model = TFLiteModel(
+ filename, batch_size=local_batch_size, num_threads=num_threads
+ )
+
+
+def _pool_run(named_inputs):
+ return _local_model(named_inputs)
diff --git a/src/mlia/nn/rewrite/core/utils/utils.py b/src/mlia/nn/rewrite/core/utils/utils.py
new file mode 100644
index 0000000..ed6c81d
--- /dev/null
+++ b/src/mlia/nn/rewrite/core/utils/utils.py
@@ -0,0 +1,26 @@
+# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+import os
+
+import flatbuffers
+from tensorflow.lite.python import schema_py_generated as schema_fb
+
+
+def load(input_tflite_file):
+ if not os.path.exists(input_tflite_file):
+ raise RuntimeError("TFLite file not found at %r\n" % input_tflite_file)
+ with open(input_tflite_file, "rb") as file_handle:
+ file_data = bytearray(file_handle.read())
+ model_obj = schema_fb.Model.GetRootAsModel(file_data, 0)
+ model = schema_fb.ModelT.InitFromObj(model_obj)
+ return model
+
+
+def save(model, output_tflite_file):
+ builder = flatbuffers.Builder(1024) # Initial size of the buffer, which
+ # will grow automatically if needed
+ model_offset = model.Pack(builder)
+ builder.Finish(model_offset, file_identifier=b"TFL3")
+ model_data = builder.Output()
+ with open(output_tflite_file, "wb") as out_file:
+ out_file.write(model_data)