aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/core/advisor.py
blob: 7db5cfbe66fda71b8b3ed9cbc56d54e2d81a8069 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Inference advisor module."""
from __future__ import annotations

from abc import abstractmethod
from pathlib import Path
from typing import cast

from mlia.core.advice_generation import AdviceProducer
from mlia.core.common import NamedEntity
from mlia.core.context import Context
from mlia.core.data_analysis import DataAnalyzer
from mlia.core.data_collection import DataCollector
from mlia.core.events import Event
from mlia.core.mixins import ParameterResolverMixin
from mlia.core.workflow import DefaultWorkflowExecutor
from mlia.core.workflow import WorkflowExecutor


class InferenceAdvisor(NamedEntity):
    """Base class for inference advisors."""

    @abstractmethod
    def configure(self, context: Context) -> WorkflowExecutor:
        """Configure advisor execution."""

    def run(self, context: Context) -> None:
        """Run inference advisor."""
        executor = self.configure(context)
        executor.run()


class DefaultInferenceAdvisor(InferenceAdvisor, ParameterResolverMixin):
    """Default implementation for the advisor."""

    def configure(self, context: Context) -> WorkflowExecutor:
        """Configure advisor."""
        return DefaultWorkflowExecutor(
            context,
            self.get_collectors(context),
            self.get_analyzers(context),
            self.get_producers(context),
            self.get_events(context),
        )

    @abstractmethod
    def get_collectors(self, context: Context) -> list[DataCollector]:
        """Return list of the data collectors."""

    @abstractmethod
    def get_analyzers(self, context: Context) -> list[DataAnalyzer]:
        """Return list of the data analyzers."""

    @abstractmethod
    def get_producers(self, context: Context) -> list[AdviceProducer]:
        """Return list of the advice producers."""

    @abstractmethod
    def get_events(self, context: Context) -> list[Event]:
        """Return list of the startup events."""

    def get_string_parameter(self, context: Context, param: str) -> str:
        """Get string parameter value."""
        value = self.get_parameter(
            self.name(),
            param,
            expected_type=str,
            context=context,
        )

        return cast(str, value)

    def get_model(self, context: Context) -> Path:
        """Get path to the model."""
        model_param = self.get_string_parameter(context, "model")

        model = Path(model_param)
        if not model.exists():
            raise FileNotFoundError(f"Path {model} does not exist.")

        return model

    def get_target_profile(self, context: Context) -> str:
        """Get target profile."""
        return self.get_string_parameter(context, "target_profile")