aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/api.py
blob: fd5fc13e6488b3e95637e4f284a83b77781c6806 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Module for the API functions."""
from __future__ import annotations

import logging
from pathlib import Path
from typing import Any

from mlia.core.advisor import InferenceAdvisor
from mlia.core.common import AdviceCategory
from mlia.core.context import ExecutionContext
from mlia.target.config import get_target
from mlia.target.cortex_a.advisor import configure_and_get_cortexa_advisor
from mlia.target.ethos_u.advisor import configure_and_get_ethosu_advisor
from mlia.target.tosa.advisor import configure_and_get_tosa_advisor

logger = logging.getLogger(__name__)


def get_advice(
    target_profile: str,
    model: str | Path,
    category: set[str] | None = None,
    optimization_targets: list[dict[str, Any]] | None = None,
    context: ExecutionContext | None = None,
    backends: list[str] | None = None,
) -> None:
    """Get the advice.

    This function represents an entry point to the library API.

    Based on provided parameters it will collect and analyze the data
    and produce the advice.

    :param target_profile: target profile identifier
    :param model: path to the NN model
    :param category: set of categories of the advice. MLIA supports three categories:
           "compatibility", "performance", "optimization". If not provided
           (None), category "compatibility" is used by default.
    :param optimization_targets: optional model optimization targets that
           could be used for generating advice in "optimization" category.
    :param context: optional parameter which represents execution context,
           could be used for advanced use cases. If a context is passed, its
           ``advice_category`` is overwritten with the resolved category.
    :param backends: A list of backends that should be used for the given
           target. Default settings will be used if None.

    Examples:
        NB: Before launching MLIA, the logging functionality should be configured!

        Getting the advice for the provided target profile and the model

        >>> get_advice("ethos-u55-256", "path/to/the/model",
        ...            {"optimization", "compatibility"})

        Getting the advice for the category "performance".

        >>> get_advice("ethos-u55-256", "path/to/the/model", {"performance"})

    """
    # Honour the documented default instead of requiring callers to pass it.
    # A fresh set is built per call (no shared mutable default).
    if category is None:
        category = {"compatibility"}

    advice_category = AdviceCategory.from_string(category)

    if context is None:
        context = ExecutionContext(advice_category=advice_category)
    else:
        # The caller's context is reused, but the category requested for this
        # call takes precedence over whatever it was configured with.
        context.advice_category = advice_category

    advisor = get_advisor(
        context,
        target_profile,
        model,
        optimization_targets=optimization_targets,
        backends=backends,
    )

    advisor.run(context)


def get_advisor(
    context: ExecutionContext,
    target_profile: str | Path,
    model: str | Path,
    **extra_args: Any,
) -> InferenceAdvisor:
    """Find appropriate advisor for the target.

    The target profile is mapped to a target name with ``get_target``; that
    name selects the factory which configures and returns the advisor.

    :param context: execution context handed to the advisor factory
    :param target_profile: target profile identifier
    :param model: path to the NN model
    :param extra_args: extra keyword arguments forwarded unchanged to the
           advisor factory
    :return: the advisor created by the selected factory
    :raises Exception: when the resolved target has no registered factory;
            any ``KeyError`` raised by ``get_target`` itself is reported the
            same way
    """
    # Built on every call so the factory names are looked up at call time.
    # Both Ethos-U variants share a single advisor factory.
    factories_by_target = {
        "ethos-u55": configure_and_get_ethosu_advisor,
        "ethos-u65": configure_and_get_ethosu_advisor,
        "tosa": configure_and_get_tosa_advisor,
        "cortex-a": configure_and_get_cortexa_advisor,
    }

    # The guarded region intentionally spans both the profile resolution and
    # the factory lookup: a KeyError from either becomes "Unsupported profile".
    try:
        make_advisor = factories_by_target[get_target(target_profile)]
    except KeyError as err:
        raise Exception(f"Unsupported profile {target_profile}") from err

    return make_advisor(context, target_profile, model, **extra_args)