diff options
author | Annie Tallund <annie.tallund@arm.com> | 2023-03-15 11:27:08 +0100 |
---|---|---|
committer | Benjamin Klimczak <benjamin.klimczak@arm.com> | 2023-10-11 15:43:14 +0100 |
commit | 867f37d643e66c0223457c28f5345f2f21db97f2 (patch) | |
tree | 4e3c55896760e24a8b5eadc5176ce7f5586552e1 /src/mlia/nn/rewrite/core/utils/numpy_tfrecord.py | |
parent | 62768232c5fe4ed6b87136c336b65e13d030e9d4 (diff) | |
download | mlia-867f37d643e66c0223457c28f5345f2f21db97f2.tar.gz |
Adapt rewrite module to MLIA coding standards
- Fix imports
- Update variable names
- Refactor helper functions
- Add licence headers
- Add docstrings
- Use f-strings rather than % notation
- Create type annotations in rewrite module
- Migrate from tqdm to rich progress bar
- Use logging module in rewrite module: All print statements are
replaced with logging module
Resolves: MLIA-831, MLIA-842, MLIA-844, MLIA-846
Signed-off-by: Benjamin Klimczak <benjamin.klimczak@arm.com>
Change-Id: Idee37538d72b9f01128a894281a8d10155f7c17c
Diffstat (limited to 'src/mlia/nn/rewrite/core/utils/numpy_tfrecord.py')
-rw-r--r-- | src/mlia/nn/rewrite/core/utils/numpy_tfrecord.py | 140 |
1 files changed, 87 insertions, 53 deletions
diff --git a/src/mlia/nn/rewrite/core/utils/numpy_tfrecord.py b/src/mlia/nn/rewrite/core/utils/numpy_tfrecord.py index 2141003..9229810 100644 --- a/src/mlia/nn/rewrite/core/utils/numpy_tfrecord.py +++ b/src/mlia/nn/rewrite/core/utils/numpy_tfrecord.py @@ -1,26 +1,32 @@ # SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 +"""Numpy TFRecord utils.""" +from __future__ import annotations + import json import os import random import tempfile from collections import defaultdict +from pathlib import Path +from typing import Any +from typing import Callable import numpy as np +import tensorflow as tf +from tensorflow.lite.python import interpreter as interpreter_wrapper from mlia.nn.rewrite.core.utils.utils import load from mlia.nn.rewrite.core.utils.utils import save os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" -import tensorflow as tf - tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR) -from tensorflow.lite.python import interpreter as interpreter_wrapper +def make_decode_fn(filename: str) -> Callable: + """Make decode filename.""" -def make_decode_fn(filename): - def decode_fn(record_bytes, type_map): + def decode_fn(record_bytes: Any, type_map: dict) -> dict: parse_dict = { name: tf.io.FixedLenFeature([], tf.string) for name in type_map.keys() } @@ -32,38 +38,48 @@ def make_decode_fn(filename): return features meta_filename = filename + ".meta" - with open(meta_filename) as f: - type_map = json.load(f)["type_map"] + with open(meta_filename, encoding="utf-8") as file: + type_map = json.load(file)["type_map"] return lambda record_bytes: decode_fn(record_bytes, type_map) -def NumpyTFReader(filename): - decode_fn = make_decode_fn(filename) - dataset = tf.data.TFRecordDataset(filename) +def numpytf_read(filename: str | Path) -> Any: + """Read TFRecord dataset.""" + decode_fn = make_decode_fn(str(filename)) + dataset = tf.data.TFRecordDataset(str(filename)) return dataset.map(decode_fn) -def 
numpytf_count(filename): - meta_filename = filename + ".meta" - with open(meta_filename) as f: - return json.load(f)["count"] +def numpytf_count(filename: str | Path) -> Any: + """Return count from TFRecord file.""" + meta_filename = f"{filename}.meta" + with open(meta_filename, encoding="utf-8") as file: + return json.load(file)["count"] class NumpyTFWriter: - def __init__(self, filename): + """Numpy TF serializer.""" + + def __init__(self, filename: str | Path) -> None: + """Initiate a Numpy TF Serializer.""" self.filename = filename - self.meta_filename = filename + ".meta" - self.writer = tf.io.TFRecordWriter(filename) - self.type_map = {} + self.meta_filename = f"{filename}.meta" + self.writer = tf.io.TFRecordWriter(str(filename)) + self.type_map: dict = {} self.count = 0 - def __enter__(self): + def __enter__(self) -> Any: + """Enter instance.""" return self - def __exit__(self, type, value, traceback): + def __exit__( + self, exception_type: Any, exception_value: Any, exception_traceback: Any + ) -> None: + """Close instance.""" self.close() - def write(self, array_dict): + def write(self, array_dict: dict) -> None: + """Write array dict.""" type_map = {n: str(a.dtype.name) for n, a in array_dict.items()} self.type_map.update(type_map) self.count += 1 @@ -77,31 +93,41 @@ class NumpyTFWriter: example = tf.train.Example(features=tf.train.Features(feature=feature)) self.writer.write(example.SerializeToString()) - def close(self): - with open(self.meta_filename, "w") as f: + def close(self) -> None: + """Close NumpyTFWriter.""" + with open(self.meta_filename, "w", encoding="utf-8") as file: meta = {"type_map": self.type_map, "count": self.count} - json.dump(meta, f) + json.dump(meta, file) self.writer.close() class TFLiteModel: - def __init__(self, filename, batch_size=None, num_threads=None): - if num_threads == 0: + """A representation of a TFLite Model.""" + + def __init__( + self, + filename: str, + batch_size: int | None = None, + num_threads: int | None = 
None, + ) -> None: + """Initiate a TFLite Model.""" + if not num_threads: num_threads = None - if batch_size == None: + if not batch_size: self.interpreter = interpreter_wrapper.Interpreter( model_path=filename, num_threads=num_threads ) else: # if a batch size is specified, modify the TFLite model to use this size with tempfile.TemporaryDirectory() as tmp: - fb = load(filename) - for sg in fb.subgraphs: - for t in list(sg.inputs) + list(sg.outputs): - sg.tensors[t].shape = np.array( - [batch_size] + list(sg.tensors[t].shape[1:]), dtype=np.int32 + flatbuffer = load(filename) + for subgraph in flatbuffer.subgraphs: + for tensor in list(subgraph.inputs) + list(subgraph.outputs): + subgraph.tensors[tensor].shape = np.array( + [batch_size] + list(subgraph.tensors[tensor].shape[1:]), + dtype=np.int32, ) tempname = os.path.join(tmp, "rewrite_tmp.tflite") - save(fb, tempname) + save(flatbuffer, tempname) self.interpreter = interpreter_wrapper.Interpreter( model_path=tempname, num_threads=num_threads ) @@ -122,8 +148,9 @@ class TFLiteModel: self.shape_from_name = {d["name"]: d["shape"] for d in details} self.batch_size = next(iter(self.shape_from_name.values()))[0] - def __call__(self, named_input): - """Execute the model on one or a batch of named inputs (a dict of name: numpy array)""" + def __call__(self, named_input: dict) -> dict: + """Execute the model on one or a batch of named inputs \ + (a dict of name: numpy array).""" input_len = next(iter(named_input.values())).shape[0] full_steps = input_len // self.batch_size remainder = input_len % self.batch_size @@ -131,39 +158,46 @@ class TFLiteModel: named_ys = defaultdict(list) for i in range(full_steps): for name, x_batch in named_input.items(): - x = x_batch[i : i + self.batch_size] - self.interpreter.set_tensor(self.handle_from_name[name], x) + x_tensor = x_batch[i : i + self.batch_size] # noqa: E203 + self.interpreter.set_tensor(self.handle_from_name[name], x_tensor) self.interpreter.invoke() - for d in 
self.output_details: - named_ys[d["name"]].append(self.interpreter.get_tensor(d["index"])) + for output_detail in self.output_details: + named_ys[output_detail["name"]].append( + self.interpreter.get_tensor(output_detail["index"]) + ) if remainder: for name, x_batch in named_input.items(): - x = np.zeros(self.shape_from_name[name]).astype(x_batch.dtype) - x[:remainder] = x_batch[-remainder:] - self.interpreter.set_tensor(self.handle_from_name[name], x) + x_tensor = np.zeros( # pylint: disable=invalid-name + self.shape_from_name[name] + ).astype(x_batch.dtype) + x_tensor[:remainder] = x_batch[-remainder:] + self.interpreter.set_tensor(self.handle_from_name[name], x_tensor) self.interpreter.invoke() - for d in self.output_details: - named_ys[d["name"]].append( - self.interpreter.get_tensor(d["index"])[:remainder] + for output_detail in self.output_details: + named_ys[output_detail["name"]].append( + self.interpreter.get_tensor(output_detail["index"])[:remainder] ) return {k: np.concatenate(v) for k, v in named_ys.items()} - def input_tensors(self): + def input_tensors(self) -> list: + """Return name from input details.""" return [d["name"] for d in self.input_details] - def output_tensors(self): + def output_tensors(self) -> list: + """Return name from output details.""" return [d["name"] for d in self.output_details] -def sample_tfrec(input_file, k, output_file): +def sample_tfrec(input_file: str, k: int, output_file: str) -> None: + """Count, read and write TFRecord input and output data.""" total = numpytf_count(input_file) - next = sorted(random.sample(range(total), k=k), reverse=True) + next_sample = sorted(random.sample(range(total), k=k), reverse=True) - reader = NumpyTFReader(input_file) + reader = numpytf_read(input_file) with NumpyTFWriter(output_file) as writer: for i, data in enumerate(reader): - if i == next[-1]: - next.pop() + if i == next_sample[-1]: + next_sample.pop() writer.write(data) - if not next: + if not next_sample: break |