diff options
Diffstat (limited to 'src/mlia/api.py')
-rw-r--r-- | src/mlia/api.py | 57 |
1 files changed, 16 insertions, 41 deletions
diff --git a/src/mlia/api.py b/src/mlia/api.py index c7be9ec..2cabf37 100644 --- a/src/mlia/api.py +++ b/src/mlia/api.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Module for the API functions.""" from __future__ import annotations @@ -6,18 +6,14 @@ from __future__ import annotations import logging from pathlib import Path from typing import Any -from typing import Literal from mlia.core.advisor import InferenceAdvisor from mlia.core.common import AdviceCategory +from mlia.core.common import FormattedFilePath from mlia.core.context import ExecutionContext -from mlia.core.typing import PathOrFileLike from mlia.target.cortex_a.advisor import configure_and_get_cortexa_advisor -from mlia.target.cortex_a.operators import report as cortex_a_report from mlia.target.ethos_u.advisor import configure_and_get_ethosu_advisor -from mlia.target.ethos_u.operators import report as ethos_u_report from mlia.target.tosa.advisor import configure_and_get_tosa_advisor -from mlia.target.tosa.operators import report as tosa_report from mlia.utils.filesystem import get_target logger = logging.getLogger(__name__) @@ -26,10 +22,9 @@ logger = logging.getLogger(__name__) def get_advice( target_profile: str, model: str | Path, - category: Literal["all", "operators", "performance", "optimization"] = "all", + category: set[str], optimization_targets: list[dict[str, Any]] | None = None, - working_dir: str | Path = "mlia_output", - output: PathOrFileLike | None = None, + output: FormattedFilePath | None = None, context: ExecutionContext | None = None, backends: list[str] | None = None, ) -> None: @@ -42,17 +37,13 @@ def get_advice( :param target_profile: target profile identifier :param model: path to the NN model - :param category: category of the advice. MLIA supports four categories: - "all", "operators", "performance", "optimization". If not provided - category "all" is used by default. + :param category: set of categories of the advice. MLIA supports three categories: + "compatibility", "performance", "optimization". If not provided + category "compatibility" is used by default. :param optimization_targets: optional model optimization targets that - could be used for generating advice in categories - "all" and "optimization." - :param working_dir: path to the directory that will be used for storing - intermediate files during execution (e.g. converted models) - :param output: path to the report file. If provided MLIA will save - report in this location. Format of the report automatically - detected based on file extension. + could be used for generating advice in "optimization" category. + :param output: path to the report file. If provided, MLIA will save + report in this location. :param context: optional parameter which represents execution context, could be used for advanced use cases :param backends: A list of backends that should be used for the given @@ -63,13 +54,14 @@ def get_advice( Getting the advice for the provided target profile and the model - >>> get_advice("ethos-u55-256", "path/to/the/model") + >>> get_advice("ethos-u55-256", "path/to/the/model", + {"optimization", "compatibility"}) Getting the advice for the category "performance" and save result report in file "report.json" - >>> get_advice("ethos-u55-256", "path/to/the/model", "performance", - output="report.json") + >>> get_advice("ethos-u55-256", "path/to/the/model", {"performance"}, + output=FormattedFilePath("report.json") """ advice_category = AdviceCategory.from_string(category) @@ -78,10 +70,7 @@ def get_advice( context.advice_category = advice_category if context is None: - context = ExecutionContext( - advice_category=advice_category, - working_dir=working_dir, - ) + context = ExecutionContext(advice_category=advice_category) advisor = get_advisor( context, @@ -99,7 +88,7 @@ def get_advisor( context: ExecutionContext, target_profile: str, model: str | Path, - output: PathOrFileLike | None = None, + output: FormattedFilePath | None = None, **extra_args: Any, ) -> InferenceAdvisor: """Find appropriate advisor for the target.""" @@ -123,17 +112,3 @@ def get_advisor( output, **extra_args, ) - - -def generate_supported_operators_report(target_profile: str) -> None: - """Generate a supported operators report based on given target profile.""" - generators_map = { - "ethos-u55": ethos_u_report, - "ethos-u65": ethos_u_report, - "cortex-a": cortex_a_report, - "tosa": tosa_report, - } - - target = get_target(target_profile) - - generators_map[target]() |