aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/nn/rewrite/core/extract.py
diff options
context:
space:
mode:
authorAnnie Tallund <annie.tallund@arm.com>2023-03-15 11:27:08 +0100
committerBenjamin Klimczak <benjamin.klimczak@arm.com>2023-10-11 15:43:14 +0100
commit867f37d643e66c0223457c28f5345f2f21db97f2 (patch)
tree4e3c55896760e24a8b5eadc5176ce7f5586552e1 /src/mlia/nn/rewrite/core/extract.py
parent62768232c5fe4ed6b87136c336b65e13d030e9d4 (diff)
downloadmlia-867f37d643e66c0223457c28f5345f2f21db97f2.tar.gz
Adapt rewrite module to MLIA coding standards
- Fix imports - Update variable names - Refactor helper functions - Add licence headers - Add docstrings - Use f-strings rather than % notation - Create type annotations in rewrite module - Migrate from tqdm to rich progress bar - Use logging module in rewrite module: All print statements are replaced with logging module Resolves: MLIA-831, MLIA-842, MLIA-844, MLIA-846 Signed-off-by: Benjamin Klimczak <benjamin.klimczak@arm.com> Change-Id: Idee37538d72b9f01128a894281a8d10155f7c17c
Diffstat (limited to 'src/mlia/nn/rewrite/core/extract.py')
-rw-r--r--src/mlia/nn/rewrite/core/extract.py35
1 files changed, 20 insertions, 15 deletions
diff --git a/src/mlia/nn/rewrite/core/extract.py b/src/mlia/nn/rewrite/core/extract.py
index 5fcd348..f609955 100644
--- a/src/mlia/nn/rewrite/core/extract.py
+++ b/src/mlia/nn/rewrite/core/extract.py
@@ -1,28 +1,33 @@
# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
+"""Extract module."""
+# pylint: disable=too-many-arguments, too-many-locals
import os
-os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
import tensorflow as tf
-
-tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
+from tensorflow.lite.python.schema_py_generated import SubGraphT
from mlia.nn.rewrite.core.graph_edit.cut import cut_model
from mlia.nn.rewrite.core.graph_edit.record import record_model
+os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
+tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
+
+
def extract(
- output_path,
- model_file,
- input_data,
- input_names,
- output_names,
- subgraph=0,
- skip_outputs=False,
- show_progress=False,
- num_procs=1,
- num_threads=0,
-):
+ output_path: str,
+ model_file: str,
+ input_filename: str,
+ input_names: list,
+ output_names: list,
+ subgraph: SubGraphT = 0,
+ skip_outputs: bool = False,
+ show_progress: bool = False,
+ num_procs: int = 1,
+ num_threads: int = 0,
+) -> None:
+ """Extract a model after cut and record."""
try:
os.mkdir(output_path)
except FileExistsError:
@@ -39,7 +44,7 @@ def extract(
input_tfrec = os.path.join(output_path, "input.tfrec")
record_model(
- input_data,
+ input_filename,
start_file,
input_tfrec,
show_progress=show_progress,