aboutsummaryrefslogtreecommitdiff
path: root/python/scripts/report-model-ops/report_model_ops.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/scripts/report-model-ops/report_model_ops.py')
-rwxr-xr-xpython/scripts/report-model-ops/report_model_ops.py205
1 files changed, 205 insertions, 0 deletions
diff --git a/python/scripts/report-model-ops/report_model_ops.py b/python/scripts/report-model-ops/report_model_ops.py
new file mode 100755
index 0000000000..1549005da5
--- /dev/null
+++ b/python/scripts/report-model-ops/report_model_ops.py
@@ -0,0 +1,205 @@
+#!/usr/bin/env python3
+# Copyright (c) 2021 Arm Limited.
+#
+# SPDX-License-Identifier: MIT
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to
+# deal in the Software without restriction, including without limitation the
+# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+# sell copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+import json
+import logging
+import os
+import sys
+from argparse import ArgumentParser
+
+import tflite
+
+sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../")
+
+from utils.model_identification import identify_model_type
+from utils.tflite_helpers import tflite_op2acl, tflite_typecode2name, tflite_typecode2aclname
+
+SUPPORTED_MODEL_TYPES = ["tflite"]
+logger = logging.getLogger("report_model_ops")
+
+
+def get_ops_types_from_tflite_graph(model):
+ """
+ Helper function that extract operator related meta-data from a TFLite model
+
+ Parameters
+ ----------
+ model: str
+ Respective TFLite model to analyse
+
+ Returns
+ ----------
+ supported_ops, unsupported_ops, data_types: tuple
+ A tuple with the sets of unique operator types and data-types that are present in the model
+ """
+
+ logger.debug(f"Analysing TFLite mode '{model}'!")
+
+ with open(model, "rb") as f:
+ buf = f.read()
+ model = tflite.Model.GetRootAsModel(buf, 0)
+
+ # Extract unique operators
+ nr_unique_ops = model.OperatorCodesLength()
+ unique_ops = {tflite.opcode2name(model.OperatorCodes(op_id).BuiltinCode()) for op_id in range(0, nr_unique_ops)}
+
+ # Extract IO data-types
+ supported_data_types = set()
+ unsupported_data_types = set()
+ for subgraph_id in range(0, model.SubgraphsLength()):
+ subgraph = model.Subgraphs(subgraph_id)
+ for tensor_id in range(0, subgraph.TensorsLength()):
+ try:
+ supported_data_types.add(tflite_typecode2aclname(subgraph.Tensors(tensor_id).Type()))
+ except ValueError:
+ unsupported_data_types.add(tflite_typecode2name(subgraph.Tensors(tensor_id).Type()))
+ logger.warning(f"Data type {tflite_typecode2name(subgraph.Tensors(tensor_id).Type())} is not supported by ComputeLibrary")
+
+ # Perform mapping between TfLite ops to ComputeLibrary ones
+ supported_ops = set()
+ unsupported_ops = set()
+ for top in unique_ops:
+ try:
+ supported_ops.add(tflite_op2acl(top))
+ except ValueError:
+ unsupported_ops.add(top)
+ logger.warning(f"Operator {top} does not have ComputeLibrary mapping")
+
+ return (supported_ops, unsupported_ops, supported_data_types, unsupported_data_types)
+
+
+def extract_model_meta(model, model_type):
+ """
+ Function that calls the appropriate model parser to extract model related meta-data
+ Supported parsers: TFLite
+
+ Parameters
+ ----------
+ model: str
+ Path to model that we want to analyze
+ model_type:
+ type of the model
+
+ Returns
+ ----------
+ ops, data_types: (tuple)
+ A tuple with the list of unique operator types and data-types that are present in the model
+ """
+
+ if model_type == "tflite":
+ return get_ops_types_from_tflite_graph(model)
+ else:
+ logger.warning(f"Model type '{model_type}' is unsupported!")
+ return ()
+
+
+def generate_build_config(ops, data_types, data_layouts):
+ """
+ Function that generates a compatible ComputeLibrary operator-based build configuration
+
+ Parameters
+ ----------
+ ops: set
+ Set with the operators to add in the build configuration
+ data_types:
+ Set with the data types to add in the build configuration
+ data_layouts:
+ Set with the data layouts to add in the build configuration
+
+ Returns
+ ----------
+ config_data: dict
+ Dictionary compatible with ComputeLibrary
+ """
+ config_data = {}
+ config_data["operators"] = list(ops)
+ config_data["data_types"] = list(data_types)
+ config_data["data_layouts"] = list(data_layouts)
+
+ return config_data
+
+
+if __name__ == "__main__":
+ parser = ArgumentParser(
+ description="""Report map of operations in a list of models.
+ The script consumes deep learning models and reports the type of operations and data-types used
+ Supported model types: TFLite """
+ )
+
+ parser.add_argument(
+ "-m",
+ "--models",
+ nargs="+",
+ required=True,
+ type=str,
+ help=f"List of models; supported model types: {SUPPORTED_MODEL_TYPES}",
+ )
+ parser.add_argument("-D", "--debug", action="store_true", help="Enable script debugging output")
+ parser.add_argument(
+ "-c",
+ "--config",
+ type=str,
+ help="JSON configuration file used that can be used for custom ComputeLibrary builds",
+ )
+ args = parser.parse_args()
+
+ # Setup Logger
+ logging_level = logging.INFO
+ if args.debug:
+ logging_level = logging.DEBUG
+ logging.basicConfig(level=logging_level)
+
+ # Extract operator mapping
+ final_supported_ops = set()
+ final_unsupported_ops = set()
+ final_supported_dts = set()
+ final_unsupported_dts = set()
+ final_layouts = {"nhwc"} # Data layout for TFLite is always NHWC
+ for model in args.models:
+ logger.debug(f"Starting analyzing {model} model")
+
+ model_type = identify_model_type(model)
+ supported_model_ops, unsupported_mode_ops, supported_model_dts, unsupported_model_dts = extract_model_meta(model, model_type)
+ final_supported_ops.update(supported_model_ops)
+ final_unsupported_ops.update(unsupported_mode_ops)
+ final_supported_dts.update(supported_model_dts)
+ final_unsupported_dts.update(unsupported_model_dts)
+
+ logger.info("=== Supported Operators")
+ logger.info(final_supported_ops)
+ if(len(final_unsupported_ops)):
+ logger.info("=== Unsupported Operators")
+ logger.info(final_unsupported_ops)
+ logger.info("=== Data Types")
+ logger.info(final_supported_dts)
+ if(len(final_unsupported_dts)):
+ logger.info("=== Unsupported Data Types")
+ logger.info(final_unsupported_dts)
+ logger.info("=== Data Layouts")
+ logger.info(final_layouts)
+
+ # Generate JSON file
+ if args.config:
+ logger.debug("Generating JSON build configuration file")
+ config_data = generate_build_config(final_supported_ops, final_supported_dts, final_layouts)
+ with open(args.config, "w") as f:
+ json.dump(config_data, f)