aboutsummaryrefslogtreecommitdiff
path: root/tests/test_nn_rewrite_core_graph_edit_record.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_nn_rewrite_core_graph_edit_record.py')
-rw-r--r--tests/test_nn_rewrite_core_graph_edit_record.py4
1 files changed, 2 insertions, 2 deletions
diff --git a/tests/test_nn_rewrite_core_graph_edit_record.py b/tests/test_nn_rewrite_core_graph_edit_record.py
index 39aeef5..cd728af 100644
--- a/tests/test_nn_rewrite_core_graph_edit_record.py
+++ b/tests/test_nn_rewrite_core_graph_edit_record.py
@@ -7,7 +7,7 @@ import pytest
import tensorflow as tf
from mlia.nn.rewrite.core.graph_edit.record import record_model
-from mlia.nn.rewrite.core.utils.numpy_tfrecord import NumpyTFReader
+from mlia.nn.rewrite.core.utils.numpy_tfrecord import numpytf_read
@pytest.mark.parametrize("batch_size", (None, 1, 2))
@@ -46,7 +46,7 @@ def test_record_model(
# any of the model outputs
interpreter = tf.lite.Interpreter(str(test_tflite_model))
model_outputs = interpreter.get_output_details()
- dataset = NumpyTFReader(str(output_file))
+ dataset = numpytf_read(str(output_file))
for data in dataset:
for name, tensor in data.items():
assert data_matches_outputs(name, tensor, model_outputs)