aboutsummaryrefslogtreecommitdiff
path: root/tests/test_nn_rewrite_core_graph_edit_join.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_nn_rewrite_core_graph_edit_join.py')
-rw-r--r--tests/test_nn_rewrite_core_graph_edit_join.py6
1 files changed, 3 insertions, 3 deletions
diff --git a/tests/test_nn_rewrite_core_graph_edit_join.py b/tests/test_nn_rewrite_core_graph_edit_join.py
index cb3e4e2..0cb121e 100644
--- a/tests/test_nn_rewrite_core_graph_edit_join.py
+++ b/tests/test_nn_rewrite_core_graph_edit_join.py
@@ -10,7 +10,7 @@ import pytest
from mlia.nn.rewrite.core.graph_edit.cut import cut_model
from mlia.nn.rewrite.core.graph_edit.join import append_relabel
from mlia.nn.rewrite.core.graph_edit.join import join_models
-from mlia.nn.rewrite.core.utils.utils import load
+from mlia.nn.tensorflow.tflite_graph import load_fb
from tests.utils.rewrite import models_are_equal
@@ -49,8 +49,8 @@ def test_join_model(test_tflite_model: Path, tmp_path: Path) -> None:
)
assert joined_file.is_file()
- orig_model = load(str(test_tflite_model))
- joined_model = load(str(joined_file))
+ orig_model = load_fb(str(test_tflite_model))
+ joined_model = load_fb(str(joined_file))
assert models_are_equal(orig_model, joined_model)