aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/nn/rewrite/core/extract.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/mlia/nn/rewrite/core/extract.py')
-rw-r--r--src/mlia/nn/rewrite/core/extract.py35
1 files changed, 20 insertions, 15 deletions
diff --git a/src/mlia/nn/rewrite/core/extract.py b/src/mlia/nn/rewrite/core/extract.py
index 5fcd348..f609955 100644
--- a/src/mlia/nn/rewrite/core/extract.py
+++ b/src/mlia/nn/rewrite/core/extract.py
@@ -1,28 +1,33 @@
# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
+"""Extract module."""
+# pylint: disable=too-many-arguments, too-many-locals
import os
-os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
import tensorflow as tf
-
-tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
+from tensorflow.lite.python.schema_py_generated import SubGraphT
from mlia.nn.rewrite.core.graph_edit.cut import cut_model
from mlia.nn.rewrite.core.graph_edit.record import record_model
+os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
+tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
+
+
def extract(
- output_path,
- model_file,
- input_data,
- input_names,
- output_names,
- subgraph=0,
- skip_outputs=False,
- show_progress=False,
- num_procs=1,
- num_threads=0,
-):
+ output_path: str,
+ model_file: str,
+ input_filename: str,
+ input_names: list,
+ output_names: list,
+ subgraph: SubGraphT = 0,
+ skip_outputs: bool = False,
+ show_progress: bool = False,
+ num_procs: int = 1,
+ num_threads: int = 0,
+) -> None:
+ """Extract a model after cut and record."""
try:
os.mkdir(output_path)
except FileExistsError:
@@ -39,7 +44,7 @@ def extract(
input_tfrec = os.path.join(output_path, "input.tfrec")
record_model(
- input_data,
+ input_filename,
start_file,
input_tfrec,
show_progress=show_progress,