diff options
Diffstat (limited to 'src/mlia/nn/rewrite/core/extract.py')
-rw-r--r-- | src/mlia/nn/rewrite/core/extract.py | 87 |
1 files changed, 87 insertions, 0 deletions
diff --git a/src/mlia/nn/rewrite/core/extract.py b/src/mlia/nn/rewrite/core/extract.py new file mode 100644 index 0000000..5fcd348 --- /dev/null +++ b/src/mlia/nn/rewrite/core/extract.py @@ -0,0 +1,87 @@ +# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +import os + +os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" +import tensorflow as tf + +tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR) + +from mlia.nn.rewrite.core.graph_edit.cut import cut_model +from mlia.nn.rewrite.core.graph_edit.record import record_model + + +def extract( + output_path, + model_file, + input_data, + input_names, + output_names, + subgraph=0, + skip_outputs=False, + show_progress=False, + num_procs=1, + num_threads=0, +): + try: + os.mkdir(output_path) + except FileExistsError: + pass + + start_file = os.path.join(output_path, "start.tflite") + cut_model( + model_file, + input_names=None, + output_names=input_names, + subgraph_index=subgraph, + output_file=start_file, + ) + + input_tfrec = os.path.join(output_path, "input.tfrec") + record_model( + input_data, + start_file, + input_tfrec, + show_progress=show_progress, + num_procs=num_procs, + num_threads=num_threads, + ) + + replace_file = os.path.join(output_path, "replace.tflite") + cut_model( + model_file, + input_names=input_names, + output_names=output_names, + subgraph_index=subgraph, + output_file=replace_file, + ) + + end_file = os.path.join(output_path, "end.tflite") + cut_model( + model_file, + input_names=output_names, + output_names=None, + subgraph_index=subgraph, + output_file=end_file, + ) + + if not skip_outputs: + output_tfrec = os.path.join(output_path, "output.tfrec") + record_model( + input_tfrec, + replace_file, + output_tfrec, + show_progress=show_progress, + num_procs=num_procs, + num_threads=num_threads, + ) + + end_tfrec = os.path.join(output_path, "end.tfrec") + record_model( + output_tfrec, + end_file, + end_tfrec, + show_progress=show_progress, + num_procs=num_procs, + num_threads=num_threads, + ) |