aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/target/tosa/advisor.py
blob: 3619a8396a54ecc2dd468de3ff69480c9b23a7ab (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""TOSA advisor."""
from __future__ import annotations

from pathlib import Path
from typing import Any
from typing import cast

from mlia.core.advice_generation import AdviceCategory
from mlia.core.advice_generation import AdviceProducer
from mlia.core.advisor import DefaultInferenceAdvisor
from mlia.core.advisor import InferenceAdvisor
from mlia.core.context import Context
from mlia.core.context import ExecutionContext
from mlia.core.data_analysis import DataAnalyzer
from mlia.core.data_collection import DataCollector
from mlia.core.events import Event
from mlia.target.common.optimization import add_common_optimization_params
from mlia.target.common.optimization import OptimizingDataCollector
from mlia.target.registry import profile
from mlia.target.tosa.advice_generation import TOSAAdviceProducer
from mlia.target.tosa.config import TOSAConfiguration
from mlia.target.tosa.data_analysis import TOSADataAnalyzer
from mlia.target.tosa.data_collection import TOSAOperatorCompatibility
from mlia.target.tosa.events import TOSAAdvisorStartedEvent
from mlia.target.tosa.handlers import TOSAEventHandler
from mlia.target.tosa.reporters import MetadataDisplay


class TOSAInferenceAdvisor(DefaultInferenceAdvisor):
    """TOSA inference advisor."""

    @classmethod
    def name(cls) -> str:
        """Return name of the advisor."""
        return "tosa_inference_advisor"

    def get_collectors(self, context: Context) -> list[DataCollector]:
        """Return list of the data collectors."""
        model = self.get_model(context)

        collectors: list[DataCollector] = []

        if context.category_enabled(AdviceCategory.COMPATIBILITY):
            collectors.append(TOSAOperatorCompatibility(model))

        if context.category_enabled(AdviceCategory.PERFORMANCE):
            raise RuntimeError(
                "Performance estimation is currently not supported for TOSA."
            )

        if context.category_enabled(AdviceCategory.OPTIMIZATION):
            target_profile = self.get_target_profile(context)
            target_config = profile(target_profile)
            collectors.append(OptimizingDataCollector(model, target_config))

        return collectors

    def get_analyzers(self, context: Context) -> list[DataAnalyzer]:
        """Return list of the data analyzers."""
        return [
            TOSADataAnalyzer(),
        ]

    def get_producers(self, context: Context) -> list[AdviceProducer]:
        """Return list of the advice producers."""
        return [
            TOSAAdviceProducer(),
        ]

    def get_events(self, context: Context) -> list[Event]:
        """Return list of the startup events."""
        model = self.get_model(context)
        target_profile = self.get_target_profile(context)

        return [
            TOSAAdvisorStartedEvent(
                model,
                cast(TOSAConfiguration, profile(target_profile)),
                MetadataDisplay(model),
            )
        ]


def configure_and_get_tosa_advisor(
    context: ExecutionContext,
    target_profile: str | Path,
    model: str | Path,
    **extra_args: Any,
) -> InferenceAdvisor:
    """Create and configure TOSA advisor."""
    if context.event_handlers is None:
        context.event_handlers = [TOSAEventHandler()]

    if context.config_parameters is None:
        context.config_parameters = _get_config_parameters(
            model, target_profile, **extra_args
        )

    return TOSAInferenceAdvisor()


def _get_config_parameters(
    model: str | Path, target_profile: str | Path, **extra_args: Any
) -> dict[str, Any]:
    """Get configuration parameters for the advisor."""
    advisor_parameters: dict[str, Any] = {
        "tosa_inference_advisor": {
            "model": str(model),
            "target_profile": target_profile,
        }
    }
    add_common_optimization_params(advisor_parameters, extra_args)
    return advisor_parameters