aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/nn/rewrite/core/extract.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/mlia/nn/rewrite/core/extract.py')
-rw-r--r--src/mlia/nn/rewrite/core/extract.py87
1 files changed, 87 insertions, 0 deletions
diff --git a/src/mlia/nn/rewrite/core/extract.py b/src/mlia/nn/rewrite/core/extract.py
new file mode 100644
index 0000000..5fcd348
--- /dev/null
+++ b/src/mlia/nn/rewrite/core/extract.py
@@ -0,0 +1,87 @@
+# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+import os
+
+os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
+import tensorflow as tf
+
+tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
+
+from mlia.nn.rewrite.core.graph_edit.cut import cut_model
+from mlia.nn.rewrite.core.graph_edit.record import record_model
+
+
+def extract(
+ output_path,
+ model_file,
+ input_data,
+ input_names,
+ output_names,
+ subgraph=0,
+ skip_outputs=False,
+ show_progress=False,
+ num_procs=1,
+ num_threads=0,
+):
+ try:
+ os.mkdir(output_path)
+ except FileExistsError:
+ pass
+
+ start_file = os.path.join(output_path, "start.tflite")
+ cut_model(
+ model_file,
+ input_names=None,
+ output_names=input_names,
+ subgraph_index=subgraph,
+ output_file=start_file,
+ )
+
+ input_tfrec = os.path.join(output_path, "input.tfrec")
+ record_model(
+ input_data,
+ start_file,
+ input_tfrec,
+ show_progress=show_progress,
+ num_procs=num_procs,
+ num_threads=num_threads,
+ )
+
+ replace_file = os.path.join(output_path, "replace.tflite")
+ cut_model(
+ model_file,
+ input_names=input_names,
+ output_names=output_names,
+ subgraph_index=subgraph,
+ output_file=replace_file,
+ )
+
+ end_file = os.path.join(output_path, "end.tflite")
+ cut_model(
+ model_file,
+ input_names=output_names,
+ output_names=None,
+ subgraph_index=subgraph,
+ output_file=end_file,
+ )
+
+ if not skip_outputs:
+ output_tfrec = os.path.join(output_path, "output.tfrec")
+ record_model(
+ input_tfrec,
+ replace_file,
+ output_tfrec,
+ show_progress=show_progress,
+ num_procs=num_procs,
+ num_threads=num_threads,
+ )
+
+ end_tfrec = os.path.join(output_path, "end.tfrec")
+ record_model(
+ output_tfrec,
+ end_file,
+ end_tfrec,
+ show_progress=show_progress,
+ num_procs=num_procs,
+ num_threads=num_threads,
+ )