1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
|
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""TOSA advisor."""
from pathlib import Path
from typing import Any
from typing import Dict
from typing import List
from typing import Optional
from typing import Union
from mlia.core._typing import PathOrFileLike
from mlia.core.advice_generation import AdviceCategory
from mlia.core.advice_generation import AdviceProducer
from mlia.core.advisor import DefaultInferenceAdvisor
from mlia.core.advisor import InferenceAdvisor
from mlia.core.context import Context
from mlia.core.context import ExecutionContext
from mlia.core.data_analysis import DataAnalyzer
from mlia.core.data_collection import DataCollector
from mlia.core.events import Event
from mlia.devices.tosa.advice_generation import TOSAAdviceProducer
from mlia.devices.tosa.config import TOSAConfiguration
from mlia.devices.tosa.data_analysis import TOSADataAnalyzer
from mlia.devices.tosa.data_collection import TOSAOperatorCompatibility
from mlia.devices.tosa.events import TOSAAdvisorStartedEvent
from mlia.devices.tosa.handlers import TOSAEventHandler
class TOSAInferenceAdvisor(DefaultInferenceAdvisor):
    """Inference advisor for the TOSA target.

    Supplies the TOSA-specific data collectors, data analyzers, advice
    producers and startup events to the default advisor workflow.
    """

    @classmethod
    def name(cls) -> str:
        """Return the identifier of this advisor."""
        return "tosa_inference_advisor"

    def get_collectors(self, context: Context) -> List[DataCollector]:
        """Build the data collectors for the requested advice categories.

        Operator compatibility is collected only when the OPERATORS advice
        category is requested; otherwise no collectors are returned.
        """
        # The model is resolved up front regardless of the advice category.
        tosa_model = self.get_model(context)
        wants_operators = AdviceCategory.OPERATORS in context.advice_category
        selected: List[DataCollector] = (
            [TOSAOperatorCompatibility(tosa_model)] if wants_operators else []
        )
        return selected

    def get_analyzers(self, context: Context) -> List[DataAnalyzer]:
        """Build the data analyzers (a single TOSA analyzer)."""
        analyzers: List[DataAnalyzer] = [TOSADataAnalyzer()]
        return analyzers

    def get_producers(self, context: Context) -> List[AdviceProducer]:
        """Build the advice producers (a single TOSA advice producer)."""
        producers: List[AdviceProducer] = [TOSAAdviceProducer()]
        return producers

    def get_events(self, context: Context) -> List[Event]:
        """Build the startup events announcing the model and configuration."""
        tosa_model = self.get_model(context)
        profile = self.get_target_profile(context)
        started = TOSAAdvisorStartedEvent(tosa_model, TOSAConfiguration(profile))
        return [started]
def configure_and_get_tosa_advisor(
    context: ExecutionContext,
    target_profile: str,
    model: Union[Path, str],
    output: Optional[PathOrFileLike] = None,
    **_extra_args: Any,
) -> InferenceAdvisor:
    """Create a TOSA advisor, filling in any unset context settings.

    The context is updated in place. Event handlers and configuration
    parameters that are already set (not None) are left untouched.

    :param context: execution context to populate
    :param target_profile: target profile name stored in the advisor config
    :param model: model to analyze, as a path or string
    :param output: optional path or file-like object forwarded to
        TOSAEventHandler
    :param _extra_args: additional keyword arguments, accepted and ignored
    :return: a new TOSAInferenceAdvisor instance
    """
    handlers_missing = context.event_handlers is None
    if handlers_missing:
        context.event_handlers = [TOSAEventHandler(output)]

    params_missing = context.config_parameters is None
    if params_missing:
        context.config_parameters = _get_config_parameters(model, target_profile)

    advisor = TOSAInferenceAdvisor()
    return advisor
def _get_config_parameters(
    model: Union[Path, str], target_profile: str
) -> Dict[str, Any]:
    """Get configuration parameters for the TOSA advisor.

    The parameters are keyed by the advisor's own name, taken from
    TOSAInferenceAdvisor.name(), so the key cannot drift out of sync with
    the advisor if it is ever renamed.

    :param model: model to analyze; stored as a string
    :param target_profile: target profile name
    :return: mapping from the advisor name to its "model" and
        "target_profile" settings
    """
    advisor_parameters: Dict[str, Any] = {
        TOSAInferenceAdvisor.name(): {
            "model": str(model),
            "target_profile": target_profile,
        }
    }
    return advisor_parameters
|