aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/nn/rewrite/core/graph_edit/join.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/mlia/nn/rewrite/core/graph_edit/join.py')
-rw-r--r--src/mlia/nn/rewrite/core/graph_edit/join.py10
1 files changed, 5 insertions, 5 deletions
diff --git a/src/mlia/nn/rewrite/core/graph_edit/join.py b/src/mlia/nn/rewrite/core/graph_edit/join.py
index 2530ec8..70109d8 100644
--- a/src/mlia/nn/rewrite/core/graph_edit/join.py
+++ b/src/mlia/nn/rewrite/core/graph_edit/join.py
@@ -11,8 +11,8 @@ from tensorflow.lite.python.schema_py_generated import ModelT
from tensorflow.lite.python.schema_py_generated import OperatorCodeT
from tensorflow.lite.python.schema_py_generated import SubGraphT
-from mlia.nn.rewrite.core.utils.utils import load
-from mlia.nn.rewrite.core.utils.utils import save
+from mlia.nn.tensorflow.tflite_graph import load_fb
+from mlia.nn.tensorflow.tflite_graph import save_fb
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
@@ -26,12 +26,12 @@ def join_models(
subgraph_dst: int = 0,
) -> None:
"""Join two models and save the result into a given model file path."""
- src_model = load(input_src)
- dst_model = load(input_dst)
+ src_model = load_fb(input_src)
+ dst_model = load_fb(input_dst)
src_subgraph = src_model.subgraphs[subgraph_src]
dst_subgraph = dst_model.subgraphs[subgraph_dst]
join_subgraphs(src_model, src_subgraph, dst_model, dst_subgraph)
- save(dst_model, output_file)
+ save_fb(dst_model, output_file)
def join_subgraphs(