diff options
Diffstat (limited to 'src/mlia/nn/rewrite/core/graph_edit')
-rw-r--r-- | src/mlia/nn/rewrite/core/graph_edit/__init__.py | 2 | ||||
-rw-r--r-- | src/mlia/nn/rewrite/core/graph_edit/diff.py | 30 | ||||
-rw-r--r-- | src/mlia/nn/rewrite/core/graph_edit/record.py | 33 |
3 files changed, 20 insertions, 45 deletions
diff --git a/src/mlia/nn/rewrite/core/graph_edit/__init__.py b/src/mlia/nn/rewrite/core/graph_edit/__init__.py index 48b1622..8c1f750 100644 --- a/src/mlia/nn/rewrite/core/graph_edit/__init__.py +++ b/src/mlia/nn/rewrite/core/graph_edit/__init__.py @@ -1,2 +1,2 @@ # SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates. -# SPDX-License-Identifier: Apache-2.0
\ No newline at end of file +# SPDX-License-Identifier: Apache-2.0 diff --git a/src/mlia/nn/rewrite/core/graph_edit/diff.py b/src/mlia/nn/rewrite/core/graph_edit/diff.py index b6c9616..0829f0a 100644 --- a/src/mlia/nn/rewrite/core/graph_edit/diff.py +++ b/src/mlia/nn/rewrite/core/graph_edit/diff.py @@ -13,36 +13,6 @@ from tensorflow.lite.python import interpreter as interpreter_wrapper from mlia.nn.rewrite.core.utils.numpy_tfrecord import NumpyTFReader, NumpyTFWriter -def diff(file1, file2): - results = [] - - dataset1 = NumpyTFReader(file1) - dataset2 = NumpyTFReader(file2) - - for i, (x1, x2) in enumerate(zip(dataset1, dataset2)): - assert x1.keys() == x2.keys(), ( - "At input %d the files have different sets of tensors.\n%s: %s\n%s: %s\n" - % ( - i, - file1, - ", ".join(x1.keys()), - file2, - ", ".join(x2.keys()), - ) - ) - results.append({}) - for k in x1.keys(): - v1 = x1[k].numpy().astype(np.double) - v2 = x2[k].numpy().astype(np.double) - mae = abs(v1 - v2).mean() - results[-1][k] = mae - - total = sum(sum(x.values()) for x in results) - count = sum(len(x.values()) for x in results) - mean = total / count - return results, mean - - def diff_stats(file1, file2, per_tensor_and_channel=False): dataset1 = NumpyTFReader(file1) dataset2 = NumpyTFReader(file2) diff --git a/src/mlia/nn/rewrite/core/graph_edit/record.py b/src/mlia/nn/rewrite/core/graph_edit/record.py index 03cd3f9..ae13313 100644 --- a/src/mlia/nn/rewrite/core/graph_edit/record.py +++ b/src/mlia/nn/rewrite/core/graph_edit/record.py @@ -37,7 +37,6 @@ def record_model( total = numpytf_count(input_filename) dataset = NumpyTFReader(input_filename) - writer = NumpyTFWriter(output_filename) if batch_size > 1: # Collapse batch-size 1 items into batch-size n. I regret using batch-size 1 items in tfrecs now. @@ -47,16 +46,22 @@ def record_model( dataset = dataset.batch(batch_size, drop_remainder=False) total = int(math.ceil(total / batch_size)) - for j, named_x in enumerate( - tqdm(dataset.as_numpy_iterator(), total=total, disable=not show_progress) - ): - named_y = model(named_x) - if batch_size > 1: - for i in range(batch_size): - # Expand the batches and recreate each dict as a batch-size 1 item for the tfrec output - d = {k: v[i : i + 1] for k, v in named_y.items() if i < v.shape[0]} - if d: - writer.write(d) - else: - writer.write(named_y) - model.close() + with NumpyTFWriter(output_filename) as writer: + for _, named_x in enumerate( + tqdm(dataset.as_numpy_iterator(), total=total, disable=not show_progress) + ): + named_y = model(named_x) + if batch_size > 1: + for i in range(batch_size): + # Expand the batches and recreate each dict as a + # batch-size 1 item for the tfrec output + recreated_dict = { + k: v[i : i + 1] # noqa: E203 + for k, v in named_y.items() + if i < v.shape[0] + } + if recreated_dict: + writer.write(recreated_dict) + else: + writer.write(named_y) + model.close() |