#!/usr/bin/env python3
# Copyright (c) 2021 Arm Limited.
#
# SPDX-License-Identifier: MIT
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to
# deal in the Software without restriction, including without limitation the
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
# sell copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import json
import logging
import os
import sys
from argparse import ArgumentParser

import tflite

# Make the sibling "utils" package importable when this file is run as a script.
sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../")

from utils.model_identification import identify_model_type
from utils.tflite_helpers import tflite_op2acl, tflite_typecode2name, tflite_typecode2aclname

SUPPORTED_MODEL_TYPES = ["tflite"]

logger = logging.getLogger("report_model_ops")


def get_ops_types_from_tflite_graph(model):
    """
    Helper function that extracts operator related meta-data from a TFLite model

    Parameters
    ----------
    model: str
        Path to the TFLite model file to analyse

    Returns
    ----------
    supported_ops, unsupported_ops, supported_data_types, unsupported_data_types: tuple
        A tuple of four sets:
        - the ComputeLibrary operator names returned by tflite_op2acl
        - the TFLite operator names for which tflite_op2acl raised ValueError
        - the ComputeLibrary data-type names returned by tflite_typecode2aclname
        - the TFLite data-type names for which tflite_typecode2aclname raised ValueError
    """
    logger.debug(f"Analysing TFLite mode '{model}'!")

    with open(model, "rb") as f:
        buf = f.read()
    # Use a separate name for the parsed flatbuffer so the path argument is not shadowed
    graph = tflite.Model.GetRootAsModel(buf, 0)

    # Extract unique operators
    unique_ops = {
        tflite.opcode2name(graph.OperatorCodes(op_id).BuiltinCode())
        for op_id in range(graph.OperatorCodesLength())
    }

    # Collect the unique tensor type codes first, so that every data type is
    # mapped (and, if unsupported, reported) once instead of once per tensor
    type_codes = set()
    for subgraph_id in range(graph.SubgraphsLength()):
        subgraph = graph.Subgraphs(subgraph_id)
        for tensor_id in range(subgraph.TensorsLength()):
            type_codes.add(subgraph.Tensors(tensor_id).Type())

    # Perform mapping between TfLite data-types to ComputeLibrary ones
    supported_data_types = set()
    unsupported_data_types = set()
    for type_code in sorted(type_codes):
        try:
            supported_data_types.add(tflite_typecode2aclname(type_code))
        except ValueError:
            type_name = tflite_typecode2name(type_code)
            unsupported_data_types.add(type_name)
            logger.warning(f"Data type {type_name} is not supported by ComputeLibrary")

    # Perform mapping between TfLite ops to ComputeLibrary ones
    supported_ops = set()
    unsupported_ops = set()
    for top in sorted(unique_ops):
        try:
            supported_ops.add(tflite_op2acl(top))
        except ValueError:
            unsupported_ops.add(top)
            logger.warning(f"Operator {top} does not have ComputeLibrary mapping")

    return (supported_ops, unsupported_ops, supported_data_types, unsupported_data_types)


def extract_model_meta(model, model_type):
    """
    Function that calls the appropriate model parser to extract model related meta-data
    Supported parsers: TFLite

    Parameters
    ----------
    model: str
        Path to model that we want to analyze
    model_type:
        Type of the model; only "tflite" is handled

    Returns
    ----------
    supported_ops, unsupported_ops, supported_data_types, unsupported_data_types: tuple
        The four sets produced by the model parser.
        For any other model type a warning is logged and an empty tuple is
        returned, so callers must check the result before unpacking it.
    """
    if model_type == "tflite":
        return get_ops_types_from_tflite_graph(model)

    logger.warning(f"Model type '{model_type}' is unsupported!")
    return ()


def generate_build_config(ops, data_types, data_layouts):
    """
    Function that generates a compatible ComputeLibrary operator-based build configuration

    Parameters
    ----------
    ops: set
        Set with the operators to add in the build configuration
    data_types:
        Set with the data types to add in the build configuration
    data_layouts:
        Set with the data layouts to add in the build configuration

    Returns
    ----------
    config_data: dict
        Dictionary with the keys "operators", "data_types" and "data_layouts",
        each holding a list. Entries are ordered by their string form so that
        the generated configuration is reproducible between runs (plain set
        iteration order is not).
    """
    return {
        "operators": sorted(ops, key=str),
        "data_types": sorted(data_types, key=str),
        "data_layouts": sorted(data_layouts, key=str),
    }


def main():
    """
    Script entry point: parse the command line, analyse every model and report
    the aggregated operators, data types and data layouts. When a config path
    is given, the supported ones are also written out as a JSON build
    configuration.
    """
    # argparse re-flows the description, so only the wording matters here
    parser = ArgumentParser(
        description=(
            "Report map of operations in a list of models. "
            "The script consumes deep learning models and reports the type of operations and data-types used "
            "Supported model types: TFLite"
        )
    )
    parser.add_argument(
        "-m",
        "--models",
        nargs="+",
        required=True,
        type=str,
        help=f"List of models; supported model types: {SUPPORTED_MODEL_TYPES}",
    )
    parser.add_argument("-D", "--debug", action="store_true", help="Enable script debugging output")
    parser.add_argument(
        "-c",
        "--config",
        type=str,
        help="JSON configuration file used that can be used for custom ComputeLibrary builds",
    )
    args = parser.parse_args()

    # Setup Logger
    logging_level = logging.DEBUG if args.debug else logging.INFO
    logging.basicConfig(level=logging_level)

    # Extract operator mapping
    final_supported_ops = set()
    final_unsupported_ops = set()
    final_supported_dts = set()
    final_unsupported_dts = set()
    final_layouts = {"nhwc"}  # Data layout for TFLite is always NHWC
    for model in args.models:
        logger.debug(f"Starting analyzing {model} model")

        model_type = identify_model_type(model)
        model_meta = extract_model_meta(model, model_type)
        if not model_meta:
            # Unsupported model type: extract_model_meta has already logged a
            # warning and returned an empty tuple, which cannot be unpacked
            continue

        supported_model_ops, unsupported_model_ops, supported_model_dts, unsupported_model_dts = model_meta
        final_supported_ops.update(supported_model_ops)
        final_unsupported_ops.update(unsupported_model_ops)
        final_supported_dts.update(supported_model_dts)
        final_unsupported_dts.update(unsupported_model_dts)

    logger.info("=== Supported Operators")
    logger.info(final_supported_ops)
    if final_unsupported_ops:
        logger.info("=== Unsupported Operators")
        logger.info(final_unsupported_ops)
    logger.info("=== Data Types")
    logger.info(final_supported_dts)
    if final_unsupported_dts:
        logger.info("=== Unsupported Data Types")
        logger.info(final_unsupported_dts)
    logger.info("=== Data Layouts")
    logger.info(final_layouts)

    # Generate JSON file
    if args.config:
        logger.debug("Generating JSON build configuration file")
        config_data = generate_build_config(final_supported_ops, final_supported_dts, final_layouts)
        with open(args.config, "w") as f:
            json.dump(config_data, f)


if __name__ == "__main__":
    main()