diff options
-rw-r--r-- | README.md | 6 | ||||
-rw-r--r-- | src/mlia/api.py | 18 | ||||
-rw-r--r-- | src/mlia/cli/commands.py | 4 | ||||
-rw-r--r-- | src/mlia/cli/options.py | 3 | ||||
-rw-r--r-- | src/mlia/devices/cortexa/operators.py | 8 | ||||
-rw-r--r-- | src/mlia/devices/ethosu/operators.py | 2 | ||||
-rw-r--r-- | src/mlia/devices/tosa/operators.py | 8 | ||||
-rw-r--r-- | tests/test_api.py | 51 |
8 files changed, 92 insertions, 8 deletions
@@ -216,7 +216,7 @@ mlia operators --target-profile ethos-u55-256 ~/models/mobilenet_v1_1.0_224_quan * -h/--help: Show the general help document and exit. * --supported-ops-report: Generate the SUPPORTED_OPS.md file in the current working - directory and exit. + directory and exit (Ethos-U target profiles only). ### **Performance** (perf) @@ -272,7 +272,7 @@ mlia performance ~/models/mobilenet_v1_1.0_224_quant.tflite \ * -h/--help: Show the general help document and exit. * --supported-ops-report: Generate the SUPPORTED_OPS.md file in the current - working directory and exit. + working directory and exit (Ethos-U target profiles only). ### **Model optimization** (opt) @@ -334,7 +334,7 @@ mlia optimization \ * -h/--help: Show the general help document and exit. * --supported-ops-report: Generate the SUPPORTED_OPS.md file in the current - working directory and exit. + working directory and exit (Ethos-U target profiles only). ### **All tests** (all) diff --git a/src/mlia/api.py b/src/mlia/api.py index fc61af0..6af7db2 100644 --- a/src/mlia/api.py +++ b/src/mlia/api.py @@ -13,11 +13,13 @@ from mlia.core.common import AdviceCategory from mlia.core.context import ExecutionContext from mlia.core.typing import PathOrFileLike from mlia.devices.cortexa.advisor import configure_and_get_cortexa_advisor +from mlia.devices.cortexa.operators import report as cortex_a_report from mlia.devices.ethosu.advisor import configure_and_get_ethosu_advisor +from mlia.devices.ethosu.operators import report as ethos_u_report from mlia.devices.tosa.advisor import configure_and_get_tosa_advisor +from mlia.devices.tosa.operators import report as tosa_report from mlia.utils.filesystem import get_target - logger = logging.getLogger(__name__) @@ -121,3 +123,17 @@ def get_advisor( output, **extra_args, ) + + +def generate_supported_operators_report(target_profile: str) -> None: + """Generate a supported operators report based on given target profile.""" + generators_map = { + "ethos-u55": ethos_u_report, + "ethos-u65": ethos_u_report, + "cortex-a": cortex_a_report, + "tosa": tosa_report, + } + + target = get_target(target_profile) + + generators_map[target]() diff --git a/src/mlia/cli/commands.py b/src/mlia/cli/commands.py index e044e1a..72ae4bb 100644 --- a/src/mlia/cli/commands.py +++ b/src/mlia/cli/commands.py @@ -23,11 +23,11 @@ from pathlib import Path from typing import cast from mlia.api import ExecutionContext +from mlia.api import generate_supported_operators_report from mlia.api import get_advice from mlia.api import PathOrFileLike from mlia.cli.config import get_installation_manager from mlia.cli.options import parse_optimization_parameters -from mlia.devices.ethosu.operators import generate_supported_operators_report from mlia.utils.console import create_section_header from mlia.utils.types import only_one_selected @@ -129,7 +129,7 @@ def operators( "model.tflite") """ if supported_ops_report: - generate_supported_operators_report() + generate_supported_operators_report(target_profile) logger.info("Report saved into SUPPORTED_OPS.md") return diff --git a/src/mlia/cli/options.py b/src/mlia/cli/options.py index e5e85f0..b28fa8f 100644 --- a/src/mlia/cli/options.py +++ b/src/mlia/cli/options.py @@ -124,7 +124,8 @@ def add_custom_supported_operators_options(parser: argparse.ArgumentParser) -> N default=False, help=( "Generate the SUPPORTED_OPS.md file in the " - "current working directory and exit" + "current working directory and exit " + "(Ethos-U target profiles only)" ), ) diff --git a/src/mlia/devices/cortexa/operators.py b/src/mlia/devices/cortexa/operators.py index 8fd2571..d46b107 100644 --- a/src/mlia/devices/cortexa/operators.py +++ b/src/mlia/devices/cortexa/operators.py @@ -29,3 +29,11 @@ def get_cortex_a_compatibility_info( ) -> CortexACompatibilityInfo | None: """Return list of model's operators.""" return None + + +def report() -> None: + """Generate supported operators report.""" + raise Exception( + "Generating a supported operators report is not " + "currently supported with Cortex-A target profile." + ) diff --git a/src/mlia/devices/ethosu/operators.py b/src/mlia/devices/ethosu/operators.py index ff0d99f..1a4ce8d 100644 --- a/src/mlia/devices/ethosu/operators.py +++ b/src/mlia/devices/ethosu/operators.py @@ -9,6 +9,6 @@ from mlia.tools import vela_wrapper logger = logging.getLogger(__name__) -def generate_supported_operators_report() -> None: +def report() -> None: """Generate supported operators report.""" vela_wrapper.generate_supported_operators_report() diff --git a/src/mlia/devices/tosa/operators.py b/src/mlia/devices/tosa/operators.py index 6cfb87f..03f6fb8 100644 --- a/src/mlia/devices/tosa/operators.py +++ b/src/mlia/devices/tosa/operators.py @@ -68,3 +68,11 @@ def get_tosa_checker(tflite_model_path: PathOrFileLike) -> TOSAChecker | None: checker = tc.TOSAChecker(str(tflite_model_path)) return cast(TOSAChecker, checker) + + +def report() -> None: + """Generate supported operators report.""" + raise Exception( + "Generating a supported operators report is not " + "currently supported with TOSA target profile." + ) diff --git a/tests/test_api.py b/tests/test_api.py index 7b567bf..6fa15b3 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -1,11 +1,15 @@ # SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for the API functions.""" +from __future__ import annotations + from pathlib import Path from unittest.mock import MagicMock +from unittest.mock import patch import pytest +from mlia.api import generate_supported_operators_report from mlia.api import get_advice from mlia.api import get_advisor from mlia.core.common import AdviceCategory @@ -107,3 +111,50 @@ def test_get_advisor( tosa_advisor = get_advisor(ExecutionContext(), "tosa", str(test_keras_model)) assert isinstance(tosa_advisor, TOSAInferenceAdvisor) + + +@pytest.mark.parametrize( + ["target_profile", "required_calls", "exception_msg"], + [ + [ + "ethos-u55-128", + "mlia.tools.vela_wrapper.generate_supported_operators_report", + None, + ], + [ + "ethos-u65-256", + "mlia.tools.vela_wrapper.generate_supported_operators_report", + None, + ], + [ + "tosa", + None, + "Generating a supported operators report is not " + "currently supported with TOSA target profile.", + ], + [ + "cortex-a", + None, + "Generating a supported operators report is not " + "currently supported with Cortex-A target profile.", + ], + [ + "Unknown", + None, + "Unable to find target profile Unknown", + ], + ], +) +def test_supported_ops_report_generator( + target_profile: str, required_calls: str | None, exception_msg: str | None +) -> None: + """Test supported operators report generator with different target profiles.""" + if exception_msg: + with pytest.raises(Exception) as exc: + generate_supported_operators_report(target_profile) + assert str(exc.value) == exception_msg + + if required_calls: + with patch(required_calls) as mock_method: + generate_supported_operators_report(target_profile) + mock_method.assert_called_once() |