aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/nn/rewrite/core/extract.py
diff options
context:
space:
mode:
authorAnnie Tallund <annie.tallund@arm.com>2023-03-13 17:00:31 +0100
committerBenjamin Klimczak <benjamin.klimczak@arm.com>2023-10-11 15:41:48 +0100
commitf0b8ed75fed9dc69ab1f6313339f9f7e38bfc725 (patch)
treebc353fad664040b44915b5cf7ae807894b0b87e8 /src/mlia/nn/rewrite/core/extract.py
parentb236127b9a18ec2668271c6b5baafa6a7c1dde51 (diff)
downloadmlia-f0b8ed75fed9dc69ab1f6313339f9f7e38bfc725.tar.gz
MLIA-845 Migrate rewrite code
- Add required files for rewriting of TensorFlow Lite graphs - Adapt rewrite dependency paths and project name - Add license headers Change-Id: I19c5f63215fe2af2fa7d7d44af08144c6c5f911c Signed-off-by: Benjamin Klimczak <benjamin.klimczak@arm.com>
Diffstat (limited to 'src/mlia/nn/rewrite/core/extract.py')
-rw-r--r--src/mlia/nn/rewrite/core/extract.py87
1 files changed, 87 insertions, 0 deletions
diff --git a/src/mlia/nn/rewrite/core/extract.py b/src/mlia/nn/rewrite/core/extract.py
new file mode 100644
index 0000000..5fcd348
--- /dev/null
+++ b/src/mlia/nn/rewrite/core/extract.py
@@ -0,0 +1,87 @@
+# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+import os
+
+os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
+import tensorflow as tf
+
+tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
+
+from mlia.nn.rewrite.core.graph_edit.cut import cut_model
+from mlia.nn.rewrite.core.graph_edit.record import record_model
+
+
+def extract(
+ output_path,
+ model_file,
+ input_data,
+ input_names,
+ output_names,
+ subgraph=0,
+ skip_outputs=False,
+ show_progress=False,
+ num_procs=1,
+ num_threads=0,
+):
+ try:
+ os.mkdir(output_path)
+ except FileExistsError:
+ pass
+
+ start_file = os.path.join(output_path, "start.tflite")
+ cut_model(
+ model_file,
+ input_names=None,
+ output_names=input_names,
+ subgraph_index=subgraph,
+ output_file=start_file,
+ )
+
+ input_tfrec = os.path.join(output_path, "input.tfrec")
+ record_model(
+ input_data,
+ start_file,
+ input_tfrec,
+ show_progress=show_progress,
+ num_procs=num_procs,
+ num_threads=num_threads,
+ )
+
+ replace_file = os.path.join(output_path, "replace.tflite")
+ cut_model(
+ model_file,
+ input_names=input_names,
+ output_names=output_names,
+ subgraph_index=subgraph,
+ output_file=replace_file,
+ )
+
+ end_file = os.path.join(output_path, "end.tflite")
+ cut_model(
+ model_file,
+ input_names=output_names,
+ output_names=None,
+ subgraph_index=subgraph,
+ output_file=end_file,
+ )
+
+ if not skip_outputs:
+ output_tfrec = os.path.join(output_path, "output.tfrec")
+ record_model(
+ input_tfrec,
+ replace_file,
+ output_tfrec,
+ show_progress=show_progress,
+ num_procs=num_procs,
+ num_threads=num_threads,
+ )
+
+ end_tfrec = os.path.join(output_path, "end.tfrec")
+ record_model(
+ output_tfrec,
+ end_file,
+ end_tfrec,
+ show_progress=show_progress,
+ num_procs=num_procs,
+ num_threads=num_threads,
+ )