aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/nn/rewrite/core/graph_edit
diff options
context:
space:
mode:
Diffstat (limited to 'src/mlia/nn/rewrite/core/graph_edit')
-rw-r--r--src/mlia/nn/rewrite/core/graph_edit/__init__.py2
-rw-r--r--src/mlia/nn/rewrite/core/graph_edit/diff.py30
-rw-r--r--src/mlia/nn/rewrite/core/graph_edit/record.py33
3 files changed, 20 insertions, 45 deletions
diff --git a/src/mlia/nn/rewrite/core/graph_edit/__init__.py b/src/mlia/nn/rewrite/core/graph_edit/__init__.py
index 48b1622..8c1f750 100644
--- a/src/mlia/nn/rewrite/core/graph_edit/__init__.py
+++ b/src/mlia/nn/rewrite/core/graph_edit/__init__.py
@@ -1,2 +1,2 @@
# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
-# SPDX-License-Identifier: Apache-2.0 \ No newline at end of file
+# SPDX-License-Identifier: Apache-2.0
diff --git a/src/mlia/nn/rewrite/core/graph_edit/diff.py b/src/mlia/nn/rewrite/core/graph_edit/diff.py
index b6c9616..0829f0a 100644
--- a/src/mlia/nn/rewrite/core/graph_edit/diff.py
+++ b/src/mlia/nn/rewrite/core/graph_edit/diff.py
@@ -13,36 +13,6 @@ from tensorflow.lite.python import interpreter as interpreter_wrapper
from mlia.nn.rewrite.core.utils.numpy_tfrecord import NumpyTFReader, NumpyTFWriter
-def diff(file1, file2):
- results = []
-
- dataset1 = NumpyTFReader(file1)
- dataset2 = NumpyTFReader(file2)
-
- for i, (x1, x2) in enumerate(zip(dataset1, dataset2)):
- assert x1.keys() == x2.keys(), (
- "At input %d the files have different sets of tensors.\n%s: %s\n%s: %s\n"
- % (
- i,
- file1,
- ", ".join(x1.keys()),
- file2,
- ", ".join(x2.keys()),
- )
- )
- results.append({})
- for k in x1.keys():
- v1 = x1[k].numpy().astype(np.double)
- v2 = x2[k].numpy().astype(np.double)
- mae = abs(v1 - v2).mean()
- results[-1][k] = mae
-
- total = sum(sum(x.values()) for x in results)
- count = sum(len(x.values()) for x in results)
- mean = total / count
- return results, mean
-
-
def diff_stats(file1, file2, per_tensor_and_channel=False):
dataset1 = NumpyTFReader(file1)
dataset2 = NumpyTFReader(file2)
diff --git a/src/mlia/nn/rewrite/core/graph_edit/record.py b/src/mlia/nn/rewrite/core/graph_edit/record.py
index 03cd3f9..ae13313 100644
--- a/src/mlia/nn/rewrite/core/graph_edit/record.py
+++ b/src/mlia/nn/rewrite/core/graph_edit/record.py
@@ -37,7 +37,6 @@ def record_model(
total = numpytf_count(input_filename)
dataset = NumpyTFReader(input_filename)
- writer = NumpyTFWriter(output_filename)
if batch_size > 1:
# Collapse batch-size 1 items into batch-size n. I regret using batch-size 1 items in tfrecs now.
@@ -47,16 +46,22 @@ def record_model(
dataset = dataset.batch(batch_size, drop_remainder=False)
total = int(math.ceil(total / batch_size))
- for j, named_x in enumerate(
- tqdm(dataset.as_numpy_iterator(), total=total, disable=not show_progress)
- ):
- named_y = model(named_x)
- if batch_size > 1:
- for i in range(batch_size):
- # Expand the batches and recreate each dict as a batch-size 1 item for the tfrec output
- d = {k: v[i : i + 1] for k, v in named_y.items() if i < v.shape[0]}
- if d:
- writer.write(d)
- else:
- writer.write(named_y)
- model.close()
+ with NumpyTFWriter(output_filename) as writer:
+ for _, named_x in enumerate(
+ tqdm(dataset.as_numpy_iterator(), total=total, disable=not show_progress)
+ ):
+ named_y = model(named_x)
+ if batch_size > 1:
+ for i in range(batch_size):
+ # Expand the batches and recreate each dict as a
+ # batch-size 1 item for the tfrec output
+ recreated_dict = {
+ k: v[i : i + 1] # noqa: E203
+ for k, v in named_y.items()
+ if i < v.shape[0]
+ }
+ if recreated_dict:
+ writer.write(recreated_dict)
+ else:
+ writer.write(named_y)
+ model.close()