From 6a88ee5315b4ce5b023370c1e55e48bf9f2b6f67 Mon Sep 17 00:00:00 2001 From: Dmitrii Agibov Date: Fri, 18 Nov 2022 17:21:09 +0000 Subject: Rename modules - Rename module "mlia.devices" into "mlia.target" - Rename module "mlia.target.ethosu" into "mlia.target.ethos_u" - Rename module "mlia.target.cortexa" into "mlia.target.cortex_a" - Rename and update tests Change-Id: I6dca7c8646d881f739fb6b5914d1cc7e45e63dc2 --- src/mlia/target/cortex_a/__init__.py | 3 + src/mlia/target/cortex_a/advice_generation.py | 153 +++++++++++++++++ src/mlia/target/cortex_a/advisor.py | 92 +++++++++++ src/mlia/target/cortex_a/config.py | 20 +++ src/mlia/target/cortex_a/data_analysis.py | 128 ++++++++++++++ src/mlia/target/cortex_a/data_collection.py | 51 ++++++ src/mlia/target/cortex_a/events.py | 24 +++ src/mlia/target/cortex_a/handlers.py | 39 +++++ src/mlia/target/cortex_a/operator_compatibility.py | 184 +++++++++++++++++++++ src/mlia/target/cortex_a/operators.py | 148 +++++++++++++++++ src/mlia/target/cortex_a/reporters.py | 140 ++++++++++++++++ 11 files changed, 982 insertions(+) create mode 100644 src/mlia/target/cortex_a/__init__.py create mode 100644 src/mlia/target/cortex_a/advice_generation.py create mode 100644 src/mlia/target/cortex_a/advisor.py create mode 100644 src/mlia/target/cortex_a/config.py create mode 100644 src/mlia/target/cortex_a/data_analysis.py create mode 100644 src/mlia/target/cortex_a/data_collection.py create mode 100644 src/mlia/target/cortex_a/events.py create mode 100644 src/mlia/target/cortex_a/handlers.py create mode 100644 src/mlia/target/cortex_a/operator_compatibility.py create mode 100644 src/mlia/target/cortex_a/operators.py create mode 100644 src/mlia/target/cortex_a/reporters.py (limited to 'src/mlia/target/cortex_a') diff --git a/src/mlia/target/cortex_a/__init__.py b/src/mlia/target/cortex_a/__init__.py new file mode 100644 index 0000000..fe01835 --- /dev/null +++ b/src/mlia/target/cortex_a/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Cortex-A target module.""" diff --git a/src/mlia/target/cortex_a/advice_generation.py b/src/mlia/target/cortex_a/advice_generation.py new file mode 100644 index 0000000..b68106e --- /dev/null +++ b/src/mlia/target/cortex_a/advice_generation.py @@ -0,0 +1,153 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Cortex-A advice generation.""" +from functools import singledispatchmethod + +from mlia.core.advice_generation import advice_category +from mlia.core.advice_generation import FactBasedAdviceProducer +from mlia.core.common import AdviceCategory +from mlia.core.common import DataItem +from mlia.target.cortex_a.data_analysis import ModelHasCustomOperators +from mlia.target.cortex_a.data_analysis import ModelIsCortexACompatible +from mlia.target.cortex_a.data_analysis import ModelIsNotCortexACompatible +from mlia.target.cortex_a.data_analysis import ModelIsNotTFLiteCompatible +from mlia.target.cortex_a.data_analysis import TFLiteCompatibilityCheckFailed + + +class CortexAAdviceProducer(FactBasedAdviceProducer): + """Cortex-A advice producer.""" + + cortex_a_disclaimer = ( + "Note that the provided compatibility information is general. " + "At runtime individual operators in the given model might fall back to " + "the TensorFlow Lite reference or might produce errors based on the " + "specific parameters." + ) + + @singledispatchmethod + def produce_advice(self, _data_item: DataItem) -> None: # type: ignore + """Produce advice.""" + + @produce_advice.register + @advice_category(AdviceCategory.ALL, AdviceCategory.OPERATORS) + def handle_model_is_cortex_a_compatible( + self, data_item: ModelIsCortexACompatible + ) -> None: + """Advice for Cortex-A compatibility.""" + self.add_advice( + [ + f"Model is fully compatible with {data_item.backend_info} for " + "Cortex-A.", + self.cortex_a_disclaimer, + ] + ) + + @produce_advice.register + @advice_category(AdviceCategory.ALL, AdviceCategory.OPERATORS) + def handle_model_is_not_cortex_a_compatible( + self, data_item: ModelIsNotCortexACompatible + ) -> None: + """Advice for Cortex-A compatibility.""" + if data_item.unsupported_ops: + self.add_advice( + [ + "The following operators are not supported by " + f"{data_item.backend_info} and will fall back to the " + "TensorFlow Lite runtime:", + "\n".join(f" - {op}" for op in data_item.unsupported_ops), + ] + ) + + if data_item.activation_func_support: + self.add_advice( + [ + "The fused activation functions of the following operators " + f"are not supported by {data_item.backend_info}. Please " + "consider using one of the supported activation functions " + "instead:", + "\n".join( + f" - {op}\n" + f" - Used unsupported: {act.used_unsupported}\n" + f" - Supported: {act.supported}" + for op, act in data_item.activation_func_support.items() + ), + ] + ) + + self.add_advice( + [ + "Please, refer to the full table of operators above for more " + "information.", + self.cortex_a_disclaimer, + ] + ) + + @produce_advice.register + @advice_category(AdviceCategory.ALL, AdviceCategory.OPERATORS) + def handle_model_is_not_tflite_compatible( + self, data_item: ModelIsNotTFLiteCompatible + ) -> None: + """Advice for TensorFlow Lite compatibility.""" + if data_item.flex_ops: + self.add_advice( + [ + "The following operators are not natively " + "supported by TensorFlow Lite: " + f"{', '.join(data_item.flex_ops)}.", + "Using select TensorFlow operators in TensorFlow Lite model " + "requires special initialization of TFLiteConverter and " + "TensorFlow Lite run-time.", + "Please refer to the TensorFlow documentation for more " + "details: https://www.tensorflow.org/lite/guide/ops_select", + "Note, such models are not supported by the ML Inference Advisor.", + ] + ) + + if data_item.custom_ops: + self.add_advice( + [ + "The following operators appear to be custom and not natively " + "supported by TensorFlow Lite: " + f"{', '.join(data_item.custom_ops)}.", + "Using custom operators in TensorFlow Lite model " + "requires special initialization of TFLiteConverter and " + "TensorFlow Lite run-time.", + "Please refer to the TensorFlow documentation for more " + "details: https://www.tensorflow.org/lite/guide/ops_custom", + "Note, such models are not supported by the ML Inference Advisor.", + ] + ) + + if not data_item.flex_ops and not data_item.custom_ops: + self.add_advice( + [ + "Model could not be converted into TensorFlow Lite format.", + "Please refer to the table for more details.", + ] + ) + + @produce_advice.register + @advice_category(AdviceCategory.ALL, AdviceCategory.OPERATORS) + def handle_tflite_check_failed( + self, _data_item: TFLiteCompatibilityCheckFailed + ) -> None: + """Advice for the failed TensorFlow Lite compatibility checks.""" + self.add_advice( + [ + "Model could not be converted into TensorFlow Lite format.", + "Please refer to the table for more details.", + ] + ) + + @produce_advice.register + @advice_category(AdviceCategory.ALL, AdviceCategory.OPERATORS) + def handle_model_has_custom_operators( + self, _data_item: ModelHasCustomOperators + ) -> None: + """Advice for the models with custom operators.""" + self.add_advice( + [ + "Models with custom operators require special initialization " + "and currently are not supported by the ML Inference Advisor.", + ] + ) diff --git a/src/mlia/target/cortex_a/advisor.py b/src/mlia/target/cortex_a/advisor.py new file mode 100644 index 0000000..5912e38 --- /dev/null +++ b/src/mlia/target/cortex_a/advisor.py @@ -0,0 +1,92 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Cortex-A MLIA module.""" +from __future__ import annotations + +from pathlib import Path +from typing import Any + +from mlia.core.advice_generation import AdviceProducer +from mlia.core.advisor import DefaultInferenceAdvisor +from mlia.core.advisor import InferenceAdvisor +from mlia.core.common import AdviceCategory +from mlia.core.context import Context +from mlia.core.context import ExecutionContext +from mlia.core.data_analysis import DataAnalyzer +from mlia.core.data_collection import DataCollector +from mlia.core.events import Event +from mlia.core.typing import PathOrFileLike +from mlia.target.cortex_a.advice_generation import CortexAAdviceProducer +from mlia.target.cortex_a.config import CortexAConfiguration +from mlia.target.cortex_a.data_analysis import CortexADataAnalyzer +from mlia.target.cortex_a.data_collection import CortexAOperatorCompatibility +from mlia.target.cortex_a.events import CortexAAdvisorStartedEvent +from mlia.target.cortex_a.handlers import CortexAEventHandler + + +class CortexAInferenceAdvisor(DefaultInferenceAdvisor): + """Cortex-A Inference Advisor.""" + + @classmethod + def name(cls) -> str: + """Return name of the advisor.""" + return "cortex_a_inference_advisor" + + def get_collectors(self, context: Context) -> list[DataCollector]: + """Return list of the data collectors.""" + model = self.get_model(context) + + collectors: list[DataCollector] = [] + + if AdviceCategory.OPERATORS in context.advice_category: + collectors.append(CortexAOperatorCompatibility(model)) + + return collectors + + def get_analyzers(self, context: Context) -> list[DataAnalyzer]: + """Return list of the data analyzers.""" + return [ + CortexADataAnalyzer(), + ] + + def get_producers(self, context: Context) -> list[AdviceProducer]: + """Return list of the advice producers.""" + return [CortexAAdviceProducer()] + + def get_events(self, context: Context) -> list[Event]: + """Return list of the startup events.""" + model = self.get_model(context) + target_profile = self.get_target_profile(context) + + return [ + CortexAAdvisorStartedEvent(model, CortexAConfiguration(target_profile)), + ] + + +def configure_and_get_cortexa_advisor( + context: ExecutionContext, + target_profile: str, + model: str | Path, + output: PathOrFileLike | None = None, + **_extra_args: Any, +) -> InferenceAdvisor: + """Create and configure Cortex-A advisor.""" + if context.event_handlers is None: + context.event_handlers = [CortexAEventHandler(output)] + + if context.config_parameters is None: + context.config_parameters = _get_config_parameters(model, target_profile) + + return CortexAInferenceAdvisor() + + +def _get_config_parameters(model: str | Path, target_profile: str) -> dict[str, Any]: + """Get configuration parameters for the advisor.""" + advisor_parameters: dict[str, Any] = { + "cortex_a_inference_advisor": { + "model": str(model), + "target_profile": target_profile, + }, + } + + return advisor_parameters diff --git a/src/mlia/target/cortex_a/config.py b/src/mlia/target/cortex_a/config.py new file mode 100644 index 0000000..b2b51ea --- /dev/null +++ b/src/mlia/target/cortex_a/config.py @@ -0,0 +1,20 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Cortex-A configuration.""" +from __future__ import annotations + +from mlia.target.config import IPConfiguration +from mlia.utils.filesystem import get_profile + + +class CortexAConfiguration(IPConfiguration): # pylint: disable=too-few-public-methods + """Cortex-A configuration.""" + + def __init__(self, target_profile: str) -> None: + """Init Cortex-A target configuration.""" + target_data = get_profile(target_profile) + + target = target_data["target"] + if target != "cortex-a": + raise Exception(f"Wrong target {target} for Cortex-A configuration") + super().__init__(target) diff --git a/src/mlia/target/cortex_a/data_analysis.py b/src/mlia/target/cortex_a/data_analysis.py new file mode 100644 index 0000000..4a3a068 --- /dev/null +++ b/src/mlia/target/cortex_a/data_analysis.py @@ -0,0 +1,128 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Cortex-A data analysis module.""" +from __future__ import annotations + +from collections import defaultdict +from dataclasses import dataclass +from dataclasses import field +from functools import singledispatchmethod + +from mlia.core.common import DataItem +from mlia.core.data_analysis import Fact +from mlia.core.data_analysis import FactExtractor +from mlia.nn.tensorflow.tflite_compat import TFLiteCompatibilityInfo +from mlia.target.cortex_a.operators import CortexACompatibilityInfo +from mlia.target.cortex_a.operators import Operator + + +class CortexADataAnalyzer(FactExtractor): + """Cortex-A data analyzer.""" + + @singledispatchmethod + def analyze_data(self, data_item: DataItem) -> None: # type: ignore + """Analyse the data.""" + + @analyze_data.register + def analyze_operator_compatibility( + self, data_item: CortexACompatibilityInfo + ) -> None: + """Analyse operator compatibility information.""" + if data_item.cortex_a_compatible: + self.add_fact(ModelIsCortexACompatible(data_item.backend_info)) + else: + unsupported_ops = set() + activation_func_support: defaultdict[ + str, ModelIsNotCortexACompatible.ActivationFunctionSupport + ] = defaultdict(ModelIsNotCortexACompatible.ActivationFunctionSupport) + for oper in data_item.operators: + if oper.support_type == Operator.SupportType.OP_NOT_SUPPORTED: + unsupported_ops.add(oper.full_name) + + if oper.support_type == Operator.SupportType.ACTIVATION_NOT_SUPPORTED: + # Add used but unsupported actication functions + activation_func_support[oper.full_name].used_unsupported.add( + oper.activation_func.name + ) + # Add supported activation functions + activation_func_support[oper.full_name].supported.update( + oper.supported_activation_functions + ) + + assert ( + unsupported_ops or activation_func_support or not data_item.operators + ), ( + "The model is marked as not compatible with Cortex-A but there " + "are no unsupported ops activation functions listed." + ) + + self.add_fact( + ModelIsNotCortexACompatible( + data_item.backend_info, unsupported_ops, activation_func_support + ) + ) + + @analyze_data.register + def analyze_tflite_compatibility(self, data_item: TFLiteCompatibilityInfo) -> None: + """Analyze TensorFlow Lite compatibility information.""" + if data_item.compatible: + return + + if data_item.conversion_failed_with_errors: + self.add_fact( + ModelIsNotTFLiteCompatible( + custom_ops=data_item.required_custom_ops, + flex_ops=data_item.required_flex_ops, + ) + ) + + if data_item.check_failed_with_unknown_error: + self.add_fact(TFLiteCompatibilityCheckFailed()) + + if data_item.conversion_failed_for_model_with_custom_ops: + self.add_fact(ModelHasCustomOperators()) + + +@dataclass +class CortexACompatibility(Fact): + """Base class for Cortex-A compatibility providing backend info.""" + + backend_info: str + + +@dataclass +class ModelIsCortexACompatible(CortexACompatibility): + """Model is completely compatible with Cortex-A.""" + + +@dataclass +class ModelIsNotCortexACompatible(CortexACompatibility): + """Model is not compatible with Cortex-A.""" + + @dataclass + class ActivationFunctionSupport: + """Activation function support per operator.""" + + used_unsupported: set[str] = field(default_factory=set) + supported: set[str] = field(default_factory=set) + + unsupported_ops: set[str] + activation_func_support: dict[str, ActivationFunctionSupport] + + +@dataclass +class ModelIsNotTFLiteCompatible(Fact): + """Model could not be converted into TensorFlow Lite format.""" + + custom_ops: list[str] | None = None + flex_ops: list[str] | None = None + + +@dataclass +class TFLiteCompatibilityCheckFailed(Fact): + """TensorFlow Lite compatibility check failed by unknown reason.""" + + +@dataclass +class ModelHasCustomOperators(Fact): + """Model could not be loaded because it contains custom ops.""" diff --git a/src/mlia/target/cortex_a/data_collection.py b/src/mlia/target/cortex_a/data_collection.py new file mode 100644 index 0000000..3ec63e2 --- /dev/null +++ b/src/mlia/target/cortex_a/data_collection.py @@ -0,0 +1,51 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Data collection module for Cortex-A.""" +from __future__ import annotations + +import logging +from pathlib import Path + +from mlia.core.data_collection import ContextAwareDataCollector +from mlia.nn.tensorflow.config import get_tflite_model +from mlia.nn.tensorflow.tflite_compat import TFLiteChecker +from mlia.nn.tensorflow.tflite_compat import TFLiteCompatibilityInfo +from mlia.nn.tensorflow.utils import is_tflite_model +from mlia.target.cortex_a.operators import CortexACompatibilityInfo +from mlia.target.cortex_a.operators import get_cortex_a_compatibility_info +from mlia.utils.logging import log_action + + +logger = logging.getLogger(__name__) + + +class CortexAOperatorCompatibility(ContextAwareDataCollector): + """Collect operator compatibility information.""" + + def __init__(self, model: Path) -> None: + """Init operator compatibility data collector.""" + self.model = model + + def collect_data(self) -> TFLiteCompatibilityInfo | CortexACompatibilityInfo | None: + """Collect operator compatibility information.""" + if not is_tflite_model(self.model): + with log_action("Checking TensorFlow Lite compatibility ..."): + tflite_checker = TFLiteChecker() + tflite_compat = tflite_checker.check_compatibility(self.model) + + if not tflite_compat.compatible: + return tflite_compat + + tflite_model = get_tflite_model(self.model, self.context) + + with log_action("Checking operator compatibility ..."): + return ( + get_cortex_a_compatibility_info( # pylint: disable=assignment-from-none + Path(tflite_model.model_path) + ) + ) + + @classmethod + def name(cls) -> str: + """Return name of the collector.""" + return "cortex_a_operator_compatibility" diff --git a/src/mlia/target/cortex_a/events.py b/src/mlia/target/cortex_a/events.py new file mode 100644 index 0000000..a172d0d --- /dev/null +++ b/src/mlia/target/cortex_a/events.py @@ -0,0 +1,24 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Cortex-A MLIA module events.""" +from dataclasses import dataclass +from pathlib import Path + +from mlia.core.events import Event +from mlia.core.events import EventDispatcher +from mlia.target.cortex_a.config import CortexAConfiguration + + +@dataclass +class CortexAAdvisorStartedEvent(Event): + """Event with Cortex-A advisor parameters.""" + + model: Path + device: CortexAConfiguration + + +class CortexAAdvisorEventHandler(EventDispatcher): + """Event handler for the Cortex-A inference advisor.""" + + def on_cortex_a_advisor_started(self, event: CortexAAdvisorStartedEvent) -> None: + """Handle CortexAAdvisorStarted event.""" diff --git a/src/mlia/target/cortex_a/handlers.py b/src/mlia/target/cortex_a/handlers.py new file mode 100644 index 0000000..b2d5faa --- /dev/null +++ b/src/mlia/target/cortex_a/handlers.py @@ -0,0 +1,39 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Event handler.""" +from __future__ import annotations + +import logging + +from mlia.core.events import CollectedDataEvent +from mlia.core.handlers import WorkflowEventsHandler +from mlia.core.typing import PathOrFileLike +from mlia.nn.tensorflow.tflite_compat import TFLiteCompatibilityInfo +from mlia.target.cortex_a.events import CortexAAdvisorEventHandler +from mlia.target.cortex_a.events import CortexAAdvisorStartedEvent +from mlia.target.cortex_a.operators import CortexACompatibilityInfo +from mlia.target.cortex_a.reporters import cortex_a_formatters + +logger = logging.getLogger(__name__) + + +class CortexAEventHandler(WorkflowEventsHandler, CortexAAdvisorEventHandler): + """CLI event handler.""" + + def __init__(self, output: PathOrFileLike | None = None) -> None: + """Init event handler.""" + super().__init__(cortex_a_formatters, output) + + def on_collected_data(self, event: CollectedDataEvent) -> None: + """Handle CollectedDataEvent event.""" + data_item = event.data_item + + if isinstance(data_item, CortexACompatibilityInfo): + self.reporter.submit(data_item.operators, delay_print=True) + + if isinstance(data_item, TFLiteCompatibilityInfo) and not data_item.compatible: + self.reporter.submit(data_item, delay_print=True) + + def on_cortex_a_advisor_started(self, event: CortexAAdvisorStartedEvent) -> None: + """Handle CortexAAdvisorStarted event.""" + self.reporter.submit(event.device) diff --git a/src/mlia/target/cortex_a/operator_compatibility.py b/src/mlia/target/cortex_a/operator_compatibility.py new file mode 100644 index 0000000..c474e75 --- /dev/null +++ b/src/mlia/target/cortex_a/operator_compatibility.py @@ -0,0 +1,184 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Collection of Cortex-A operator compatibility information.""" +from __future__ import annotations + +from typing import Any + +ARMNN_TFLITE_DELEGATE: dict[str, dict[str, Any]] = { + "metadata": { + "backend": "Arm NN TensorFlow Lite delegate", + "version": "22.08", + }, + # BUILTIN OPERATORS + "builtin_ops": { + "ABS": {}, + "ADD": {}, + "ARG_MAX": {}, + "ARG_MIN": {}, + "AVERAGE_POOL_2D": { + "supported_fused_activation": [ + "RELU", + "RELU6", + "RELU_N1_TO_1", + "SIGMOID", + "TANH", + "NONE", + ] + }, + "BATCH_TO_SPACE_ND": {}, + "CAST": {}, + "CONCATENATION": { + "supported_fused_activation": [ + "RELU", + "RELU6", + "RELU_N1_TO_1", + "SIGMOID", + "TANH", + "NONE", + ] + }, + "CONV_2D": { + "supported_fused_activation": [ + "RELU", + "RELU6", + "RELU_N1_TO_1", + "SIGMOID", + "TANH", + "NONE", + ] + }, + "CONV_3D": { + "supported_fused_activation": [ + "RELU", + "RELU6", + "RELU_N1_TO_1", + "SIGMOID", + "TANH", + "NONE", + ] + }, + "DEPTH_TO_SPACE": {}, + "DEPTHWISE_CONV_2D": { + "supported_fused_activation": [ + "RELU", + "RELU6", + "RELU_N1_TO_1", + "SIGMOID", + "TANH", + "NONE", + ] + }, + "DEQUANTIZE": {}, + "DIV": {}, + "EQUAL": {}, + "ELU": {}, + "EXP": {}, + "EXPAND_DIMS": {}, + "FILL": {}, + "FLOOR": {}, + "FLOOR_DIV": {}, + "FULLY_CONNECTED": { + "supported_fused_activation": [ + "RELU", + "RELU6", + "RELU_N1_TO_1", + "SIGMOID", + "TANH", + "NONE", + ] + }, + "GATHER": {}, + "GATHER_ND": {}, + "GREATER": {}, + "GREATER_EQUAL": {}, + "HARD_SWISH": {}, + "L2_NORMALIZATION": {}, + "L2_POOL_2D": {}, + "LESS": {}, + "LESS_EQUAL": {}, + "LOCAL_RESPONSE_NORMALIZATION": {}, + "LOG": {}, + "LOGICAL_AND": {}, + "LOGICAL_NOT": {}, + "LOGICAL_OR": {}, + "LOGISTIC": {}, + "LOG_SOFTMAX": {}, + "LSTM": {}, + "MAXIMUM": {}, + "MAX_POOL_2D": { + "supported_fused_activation": [ + "RELU", + "RELU6", + "RELU_N1_TO_1", + "SIGMOID", + "TANH", + "NONE", + ] + }, + "MEAN": {}, + "MINIMUM": {}, + "MIRROR_PAD": {}, + "MUL": {}, + "NEG": {}, + "NOT_EQUAL": {}, + "PACK": {}, + "PAD": {}, + "PADV2": {}, + "PRELU": {}, + "QUANTIZE": {}, + "RANK": {}, + "REDUCE_MAX": {}, + "REDUCE_MIN": {}, + "REDUCE_PROD": {}, + "RELU": {}, + "RELU6": {}, + "RELU_N1_TO_1": {}, + "RESHAPE": {}, + "RESIZE_BILINEAR": {}, + "RESIZE_NEAREST_NEIGHBOR": {}, + "RSQRT": {}, + "SHAPE": {}, + "SIN": {}, + "SOFTMAX": {}, + "SPACE_TO_BATCH_ND": {}, + "SPACE_TO_DEPTH": {}, + "SPLIT": {}, + "SPLIT_V": {}, + "SQRT": {}, + "SQUEEZE": {}, + "STRIDED_SLICE": {}, + "SUB": {}, + "SUM": {}, + "TANH": {}, + "TRANSPOSE": {}, + "TRANSPOSE_CONV": {}, + "UNIDIRECTIONAL_SEQUENCE_LSTM": {}, + "UNPACK": {}, + }, + # CUSTOM OPERATORS + "custom_ops": { + "AveragePool3D": { + "supported_fused_activation": [ + "RELU", + "RELU6", + "RELU_N1_TO_1", + "SIGMOID", + "SIGN_BIT", + "TANH", + "NONE", + ] + }, + "MaxPool3D": { + "supported_fused_activation": [ + "RELU", + "RELU6", + "RELU_N1_TO_1", + "SIGMOID", + "SIGN_BIT", + "TANH", + "NONE", + ] + }, + }, +} diff --git a/src/mlia/target/cortex_a/operators.py b/src/mlia/target/cortex_a/operators.py new file mode 100644 index 0000000..91f1886 --- /dev/null +++ b/src/mlia/target/cortex_a/operators.py @@ -0,0 +1,148 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Cortex-A tools module.""" +from __future__ import annotations + +from dataclasses import dataclass +from enum import Enum +from pathlib import Path +from typing import Any +from typing import ClassVar + +from mlia.nn.tensorflow.tflite_graph import Op +from mlia.nn.tensorflow.tflite_graph import parse_subgraphs +from mlia.nn.tensorflow.tflite_graph import TFL_ACTIVATION_FUNCTION +from mlia.target.cortex_a.operator_compatibility import ( + ARMNN_TFLITE_DELEGATE as TFLITE_DELEGATE_COMPAT, +) + + +@dataclass +class Operator: + """Cortex-A compatibility information of the operator.""" + + BUILTIN_COMPATIBILITY = TFLITE_DELEGATE_COMPAT["builtin_ops"] + CUSTOM_COMPATIBILITY = TFLITE_DELEGATE_COMPAT["custom_ops"] + + class SupportType(Enum): + """Type of operator support.""" + + COMPATIBLE = "Compatible" + OP_NOT_SUPPORTED = "Operator not supported" + ACTIVATION_NOT_SUPPORTED = "Activation not supported" + + name: str + location: str + support_type: SupportType + activation_func: TFL_ACTIVATION_FUNCTION + custom_name: str | None = None + + @property + def is_cortex_a_compatible(self) -> bool: + """Check if this operator is compatible.""" + return self.support_type == Operator.SupportType.COMPATIBLE + + @property + def full_name(self) -> str: + """Returun the full name including the custom name if applicable.""" + return self.name + (f" - '{self.custom_name}'" if self.custom_name else "") + + @property + def is_custom(self) -> bool: + """Check if this is a custom operator.""" + return bool(self.custom_name) + + @property + def compatibility_data(self) -> dict[str, dict[str, Any]]: + """Get the compatibility data (builtin or custom ops).""" + return ( + Operator.CUSTOM_COMPATIBILITY + if self.is_custom + else Operator.BUILTIN_COMPATIBILITY + ) + + @property + def supported_activation_functions(self) -> list[str]: + """Return a list of fused activation functions supported by this op.""" + op_name = self.custom_name if self.custom_name else self.name + return self.compatibility_data[op_name].get("supported_fused_activation", []) + + @classmethod + def from_tflite_op(cls, tfl_op: Op, location: str) -> Operator: + """Create a new instance from TensorFlow Lite operator and location.""" + support_type = cls._get_support_type(tfl_op) + activation_func = ( + tfl_op.builtin_options["fused_activation_function"] + if ( + tfl_op.builtin_options + and "fused_activation_function" in tfl_op.builtin_options + ) + else TFL_ACTIVATION_FUNCTION.NONE + ) + return Operator( + tfl_op.type, + location, + support_type, + activation_func=activation_func, + custom_name=(tfl_op.custom_type if tfl_op.is_custom else None), + ) + + @staticmethod + def _get_support_type(tfl_op: Op) -> Operator.SupportType: + """Get the support type from the TensorFlow Lite operator.""" + compat_data = ( + Operator.CUSTOM_COMPATIBILITY + if tfl_op.is_custom + else Operator.BUILTIN_COMPATIBILITY + ) + op_type = tfl_op.custom_type if tfl_op.is_custom else tfl_op.type + + if op_type not in compat_data: + return Operator.SupportType.OP_NOT_SUPPORTED + + compat_op = compat_data[op_type] + if "supported_fused_activation" in compat_op: + assert tfl_op.builtin_options + assert "fused_activation_function" in tfl_op.builtin_options + if ( + tfl_op.builtin_options["fused_activation_function"] + not in compat_op["supported_fused_activation"] + ): + return Operator.SupportType.ACTIVATION_NOT_SUPPORTED + + return Operator.SupportType.COMPATIBLE + + +@dataclass +class CortexACompatibilityInfo: + """Model's operators.""" + + cortex_a_compatible: bool + operators: list[Operator] + backend_info: ClassVar[str] = ( + f"{TFLITE_DELEGATE_COMPAT['metadata']['backend']} " + f"{TFLITE_DELEGATE_COMPAT['metadata']['version']}" + ) + + +def get_cortex_a_compatibility_info(model_path: Path) -> CortexACompatibilityInfo: + """Return list of model's operators.""" + model = parse_subgraphs(model_path) + + op_list = [ + Operator.from_tflite_op(oper, f"subgraph:{g_idx},oper:{op_idx}") + for g_idx, g in enumerate(model) + for op_idx, oper in enumerate(g) + ] + all_compatible = all(oper.is_cortex_a_compatible for oper in op_list) + compat_info = CortexACompatibilityInfo(all_compatible, op_list) + + return compat_info + + +def report() -> None: + """Generate supported operators report.""" + raise Exception( + "Generating a supported operators report is not " + "currently supported with Cortex-A target profile." + ) diff --git a/src/mlia/target/cortex_a/reporters.py b/src/mlia/target/cortex_a/reporters.py new file mode 100644 index 0000000..d43d6c3 --- /dev/null +++ b/src/mlia/target/cortex_a/reporters.py @@ -0,0 +1,140 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Reports module.""" +from __future__ import annotations + +from typing import Any +from typing import Callable +from typing import cast + +from mlia.core.advice_generation import Advice +from mlia.core.reporters import report_advice +from mlia.core.reporting import Cell +from mlia.core.reporting import Column +from mlia.core.reporting import Format +from mlia.core.reporting import NestedReport +from mlia.core.reporting import Report +from mlia.core.reporting import ReportItem +from mlia.core.reporting import Table +from mlia.nn.tensorflow.tflite_compat import TFLiteCompatibilityInfo +from mlia.target.cortex_a.config import CortexAConfiguration +from mlia.target.cortex_a.operators import Operator +from mlia.utils.console import style_improvement +from mlia.utils.types import is_list_of + + +def report_device(device: CortexAConfiguration) -> Report: + """Generate report for the device.""" + return NestedReport( + "Device information", + "device", + [ + ReportItem("Target", alias="target", value=device.target), + ], + ) + + +def report_tflite_compatiblity(compat_info: TFLiteCompatibilityInfo) -> Report: + """Generate report for the TensorFlow Lite compatibility information.""" + if compat_info.conversion_errors: + return Table( + [ + Column("#", only_for=["plain_text"]), + Column("Operator", alias="operator"), + Column( + "Operator location", + alias="operator_location", + fmt=Format(wrap_width=25), + ), + Column("Error code", alias="error_code"), + Column( + "Error message", alias="error_message", fmt=Format(wrap_width=25) + ), + ], + [ + ( + index + 1, + err.operator, + ", ".join(err.location), + err.code.name, + err.message, + ) + for index, err in enumerate(compat_info.conversion_errors) + ], + name="TensorFlow Lite conversion errors", + alias="tensorflow_lite_conversion_errors", + ) + + return Table( + columns=[ + Column("Reason", alias="reason"), + Column( + "Exception details", + alias="exception_details", + fmt=Format(wrap_width=40), + ), + ], + rows=[ + ( + "TensorFlow Lite compatibility check failed with exception", + str(compat_info.conversion_exception), + ), + ], + name="TensorFlow Lite compatibility errors", + alias="tflite_compatibility", + ) + + +def report_cortex_a_operators(ops: list[Operator]) -> Report: + """Generate report for the operators.""" + return Table( + [ + Column("#", only_for=["plain_text"]), + Column( + "Operator location", + alias="operator_location", + fmt=Format(wrap_width=30), + ), + Column("Operator name", alias="operator_name", fmt=Format(wrap_width=20)), + Column( + "Arm NN TFLite Delegate compatibility", + alias="cortex_a_compatible", + fmt=Format(wrap_width=40), + ), + ], + [ + ( + index + 1, + op.location, + op.full_name, + Cell( + op.support_type, + Format( + wrap_width=30, + style=style_improvement(op.is_cortex_a_compatible), + str_fmt=lambda v: cast(str, v.value), + ), + ), + ) + for index, op in enumerate(ops) + ], + name="Operators", + alias="operators", + ) + + +def cortex_a_formatters(data: Any) -> Callable[[Any], Report]: + """Find appropriate formatter for the provided data.""" + if is_list_of(data, Advice): + return report_advice + + if isinstance(data, CortexAConfiguration): + return report_device + + if isinstance(data, TFLiteCompatibilityInfo): + return report_tflite_compatiblity + + if is_list_of(data, Operator): + return report_cortex_a_operators + + raise Exception(f"Unable to find appropriate formatter for {data}") -- cgit v1.2.1