aboutsummaryrefslogtreecommitdiff
path: root/tests/test_nn_rewrite_core_graph_edit_record.py
diff options
context:
space:
mode:
authorAnnie Tallund <annie.tallund@arm.com>2023-03-15 11:27:08 +0100
committerBenjamin Klimczak <benjamin.klimczak@arm.com>2023-10-11 15:43:14 +0100
commit867f37d643e66c0223457c28f5345f2f21db97f2 (patch)
tree4e3c55896760e24a8b5eadc5176ce7f5586552e1 /tests/test_nn_rewrite_core_graph_edit_record.py
parent62768232c5fe4ed6b87136c336b65e13d030e9d4 (diff)
downloadmlia-867f37d643e66c0223457c28f5345f2f21db97f2.tar.gz
Adapt rewrite module to MLIA coding standards
- Fix imports - Update variable names - Refactor helper functions - Add licence headers - Add docstrings - Use f-strings rather than % notation - Create type annotations in rewrite module - Migrate from tqdm to rich progress bar - Use logging module in rewrite module: All print statements are replaced with logging module Resolves: MLIA-831, MLIA-842, MLIA-844, MLIA-846 Signed-off-by: Benjamin Klimczak <benjamin.klimczak@arm.com> Change-Id: Idee37538d72b9f01128a894281a8d10155f7c17c
Diffstat (limited to 'tests/test_nn_rewrite_core_graph_edit_record.py')
-rw-r--r--tests/test_nn_rewrite_core_graph_edit_record.py4
1 files changed, 2 insertions, 2 deletions
diff --git a/tests/test_nn_rewrite_core_graph_edit_record.py b/tests/test_nn_rewrite_core_graph_edit_record.py
index 39aeef5..cd728af 100644
--- a/tests/test_nn_rewrite_core_graph_edit_record.py
+++ b/tests/test_nn_rewrite_core_graph_edit_record.py
@@ -7,7 +7,7 @@ import pytest
import tensorflow as tf
from mlia.nn.rewrite.core.graph_edit.record import record_model
-from mlia.nn.rewrite.core.utils.numpy_tfrecord import NumpyTFReader
+from mlia.nn.rewrite.core.utils.numpy_tfrecord import numpytf_read
@pytest.mark.parametrize("batch_size", (None, 1, 2))
@@ -46,7 +46,7 @@ def test_record_model(
# any of the model outputs
interpreter = tf.lite.Interpreter(str(test_tflite_model))
model_outputs = interpreter.get_output_details()
- dataset = NumpyTFReader(str(output_file))
+ dataset = numpytf_read(str(output_file))
for data in dataset:
for name, tensor in data.items():
assert data_matches_outputs(name, tensor, model_outputs)