1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
|
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Inference advisor module."""
from abc import abstractmethod
from pathlib import Path
from typing import cast
from typing import List
from mlia.core.advice_generation import AdviceProducer
from mlia.core.common import NamedEntity
from mlia.core.context import Context
from mlia.core.data_analysis import DataAnalyzer
from mlia.core.data_collection import DataCollector
from mlia.core.events import Event
from mlia.core.mixins import ParameterResolverMixin
from mlia.core.workflow import DefaultWorkflowExecutor
from mlia.core.workflow import WorkflowExecutor
class InferenceAdvisor(NamedEntity):
    """Abstract base for inference advisors.

    Subclasses implement ``configure`` to build a workflow executor;
    ``run`` builds that executor and runs it.
    """

    @abstractmethod
    def configure(self, context: Context) -> WorkflowExecutor:
        """Build the workflow executor for this advisor.

        Args:
            context: the ``Context`` the advisor is run with.

        Returns:
            A ``WorkflowExecutor`` whose ``run()`` is called by ``run``.
        """

    def run(self, context: Context) -> None:
        """Configure the advisor for *context*, then run the resulting executor."""
        self.configure(context).run()
class DefaultInferenceAdvisor(InferenceAdvisor, ParameterResolverMixin):
    """Default implementation for the advisor.

    ``configure`` assembles a ``DefaultWorkflowExecutor`` from the results of
    the abstract ``get_*`` hooks. The remaining methods are helpers that read
    this advisor's parameters through ``ParameterResolverMixin.get_parameter``.
    """

    def configure(self, context: Context) -> WorkflowExecutor:
        """Configure advisor.

        Calls the hooks in order (collectors, analyzers, producers, events)
        and passes their results, together with *context*, to a
        ``DefaultWorkflowExecutor``.

        Args:
            context: the ``Context`` the advisor is run with.

        Returns:
            The newly constructed ``DefaultWorkflowExecutor``.
        """
        return DefaultWorkflowExecutor(
            context,
            self.get_collectors(context),
            self.get_analyzers(context),
            self.get_producers(context),
            self.get_events(context),
        )

    @abstractmethod
    def get_collectors(self, context: Context) -> List[DataCollector]:
        """Return list of the data collectors."""

    @abstractmethod
    def get_analyzers(self, context: Context) -> List[DataAnalyzer]:
        """Return list of the data analyzers."""

    @abstractmethod
    def get_producers(self, context: Context) -> List[AdviceProducer]:
        """Return list of the advice producers."""

    @abstractmethod
    def get_events(self, context: Context) -> List[Event]:
        """Return list of the startup events."""

    def get_string_parameter(self, context: Context, param: str) -> str:
        """Get string parameter value.

        Looks *param* up via ``get_parameter``, passing this advisor's
        ``name()`` and requesting ``expected_type=str``.

        Args:
            context: the ``Context`` the parameter is resolved against.
            param: name of the parameter to read.

        Returns:
            The parameter value, typed as ``str``.
        """
        value = self.get_parameter(
            self.name(),
            param,
            expected_type=str,
            context=context,
        )
        # get_parameter's declared return type is not str-specific; narrow it
        # for type checkers (expected_type=str was requested above).
        return cast(str, value)

    def get_model(self, context: Context) -> Path:
        """Get path to the model.

        Reads the ``model`` string parameter and converts it to a ``Path``.

        Args:
            context: the ``Context`` the parameter is resolved against.

        Returns:
            The model path, which exists at the time of the check.

        Raises:
            FileNotFoundError: if the path does not exist. This subclasses
                ``Exception``, so existing ``except Exception`` handlers still
                catch it.
        """
        model = Path(self.get_string_parameter(context, "model"))
        if not model.exists():
            raise FileNotFoundError(f"Path {model} does not exist")
        return model

    def get_target_profile(self, context: Context) -> str:
        """Get target profile.

        Returns:
            The value of the ``target_profile`` string parameter.
        """
        return self.get_string_parameter(context, "target_profile")
|