diff options
Diffstat (limited to 'src/mlia/nn/rewrite/core/graph_edit/join.py')
-rw-r--r-- | src/mlia/nn/rewrite/core/graph_edit/join.py | 10 |
1 files changed, 5 insertions, 5 deletions
diff --git a/src/mlia/nn/rewrite/core/graph_edit/join.py b/src/mlia/nn/rewrite/core/graph_edit/join.py index 2530ec8..70109d8 100644 --- a/src/mlia/nn/rewrite/core/graph_edit/join.py +++ b/src/mlia/nn/rewrite/core/graph_edit/join.py @@ -11,8 +11,8 @@ from tensorflow.lite.python.schema_py_generated import ModelT from tensorflow.lite.python.schema_py_generated import OperatorCodeT from tensorflow.lite.python.schema_py_generated import SubGraphT -from mlia.nn.rewrite.core.utils.utils import load -from mlia.nn.rewrite.core.utils.utils import save +from mlia.nn.tensorflow.tflite_graph import load_fb +from mlia.nn.tensorflow.tflite_graph import save_fb os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR) @@ -26,12 +26,12 @@ def join_models( subgraph_dst: int = 0, ) -> None: """Join two models and save the result into a given model file path.""" - src_model = load(input_src) - dst_model = load(input_dst) + src_model = load_fb(input_src) + dst_model = load_fb(input_dst) src_subgraph = src_model.subgraphs[subgraph_src] dst_subgraph = dst_model.subgraphs[subgraph_dst] join_subgraphs(src_model, src_subgraph, dst_model, dst_subgraph) - save(dst_model, output_file) + save_fb(dst_model, output_file) def join_subgraphs( |