path: root/src/mlia/api.py
# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Module for the API functions."""
from __future__ import annotations

import logging
from pathlib import Path
from typing import Any

from mlia.core.advisor import InferenceAdvisor
from mlia.core.common import AdviceCategory
from mlia.core.common import FormattedFilePath
from mlia.core.context import ExecutionContext
from mlia.target.cortex_a.advisor import configure_and_get_cortexa_advisor
from mlia.target.ethos_u.advisor import configure_and_get_ethosu_advisor
from mlia.target.tosa.advisor import configure_and_get_tosa_advisor
from mlia.utils.filesystem import get_target

logger = logging.getLogger(__name__)


def get_advice(
    target_profile: str,
    model: str | Path,
    category: set[str],
    optimization_targets: list[dict[str, Any]] | None = None,
    output: FormattedFilePath | None = None,
    context: ExecutionContext | None = None,
    backends: list[str] | None = None,
) -> None:
    """Get the advice.

    This function represents an entry point to the library API.

    Based on the provided parameters, it will collect and analyze the data
    and produce the advice.

    :param target_profile: target profile identifier
    :param model: path to the NN model
    :param category: set of advice categories. MLIA supports three categories:
           "compatibility", "performance" and "optimization". If not provided,
           the category "compatibility" is used by default.
    :param optimization_targets: optional model optimization targets that
           can be used for generating advice in the "optimization" category.
    :param output: path to the report file. If provided, MLIA will save
           the report in this location.
    :param context: optional parameter that represents the execution context;
           can be used for advanced use cases.
    :param backends: A list of backends that should be used for the given
           target. Default settings will be used if None.

    Examples:
        NB: Before launching MLIA, the logging functionality should be configured!

        Getting the advice for the provided target profile and the model

        >>> get_advice("ethos-u55-256", "path/to/the/model",
                       {"optimization", "compatibility"})

        Getting the advice for the category "performance" and saving the resulting
        report in the file "report.json"

        >>> get_advice("ethos-u55-256", "path/to/the/model", {"performance"},
                       output=FormattedFilePath("report.json"))

    """
    advice_category = AdviceCategory.from_string(category)

    if context is None:
        context = ExecutionContext(advice_category=advice_category)
    else:
        context.advice_category = advice_category

    advisor = get_advisor(
        context,
        target_profile,
        model,
        output,
        optimization_targets=optimization_targets,
        backends=backends,
    )

    advisor.run(context)
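
# Example (a minimal sketch, not part of the module API): calling get_advice from
# user code. The target profile, model path and report path below are the
# illustrative placeholders from the docstring above, not real files.
#
#     import logging
#     from mlia.api import get_advice
#     from mlia.core.common import FormattedFilePath
#
#     logging.basicConfig(level=logging.INFO)  # configure logging before launching MLIA
#     get_advice(
#         "ethos-u55-256",                   # target profile identifier
#         "path/to/the/model",               # path to the NN model
#         {"compatibility", "performance"},  # advice categories
#         output=FormattedFilePath("report.json"),
#     )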


def get_advisor(
    context: ExecutionContext,
    target_profile: str,
    model: str | Path,
    output: FormattedFilePath | None = None,
    **extra_args: Any,
) -> InferenceAdvisor:
    """Find appropriate advisor for the target."""
    target_factories = {
        "ethos-u55": configure_and_get_ethosu_advisor,
        "ethos-u65": configure_and_get_ethosu_advisor,
        "tosa": configure_and_get_tosa_advisor,
        "cortex-a": configure_and_get_cortexa_advisor,
    }

    try:
        target = get_target(target_profile)
        factory_function = target_factories[target]
    except KeyError as err:
        raise Exception(f"Unsupported profile {target_profile}") from err

    return factory_function(
        context,
        target_profile,
        model,
        output,
        **extra_args,
    )
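
# Example (a minimal sketch): selecting and running an advisor directly via
# get_advisor, mirroring what get_advice does above. The profile and model path
# are illustrative placeholders.
#
#     from mlia.api import get_advisor
#     from mlia.core.common import AdviceCategory
#     from mlia.core.context import ExecutionContext
#
#     ctx = ExecutionContext(
#         advice_category=AdviceCategory.from_string({"performance"})
#     )
#     advisor = get_advisor(ctx, "ethos-u55-256", "path/to/the/model")
#     advisor.run(ctx)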