diff options
author | Dmitrii Agibov <dmitrii.agibov@arm.com> | 2022-07-21 14:06:03 +0100 |
---|---|---|
committer | Benjamin Klimczak <benjamin.klimczak@arm.com> | 2022-08-19 10:23:23 +0100 |
commit | a8ee1aee3e674c78a77801d1bf2256881ab6b4b9 (patch) | |
tree | 8463b24ba0446a49b3e012477b0834c3b5415b86 /src/mlia/core/advisor.py | |
parent | 76ec769ad8f8ed53ec3ff829fdd34d53db8229fd (diff) | |
download | mlia-a8ee1aee3e674c78a77801d1bf2256881ab6b4b9.tar.gz |
MLIA-549 Refactor API module to support several target profiles
- Move target specific details out of API module
- Move common logic for workflow event handler into a
separate class
Change-Id: Ic4a22657b722af1c1fead1d478f606ac57325788
Diffstat (limited to 'src/mlia/core/advisor.py')
-rw-r--r-- | src/mlia/core/advisor.py | 64 |
1 files changed, 64 insertions, 0 deletions
diff --git a/src/mlia/core/advisor.py b/src/mlia/core/advisor.py index 868d0c7..13689fa 100644 --- a/src/mlia/core/advisor.py +++ b/src/mlia/core/advisor.py @@ -2,9 +2,18 @@ # SPDX-License-Identifier: Apache-2.0 """Inference advisor module.""" from abc import abstractmethod +from pathlib import Path +from typing import cast +from typing import List +from mlia.core.advice_generation import AdviceProducer from mlia.core.common import NamedEntity from mlia.core.context import Context +from mlia.core.data_analysis import DataAnalyzer +from mlia.core.data_collection import DataCollector +from mlia.core.events import Event +from mlia.core.mixins import ParameterResolverMixin +from mlia.core.workflow import DefaultWorkflowExecutor from mlia.core.workflow import WorkflowExecutor @@ -19,3 +28,58 @@ class InferenceAdvisor(NamedEntity): """Run inference advisor.""" executor = self.configure(context) executor.run() + + +class DefaultInferenceAdvisor(InferenceAdvisor, ParameterResolverMixin): + """Default implementation for the advisor.""" + + def configure(self, context: Context) -> WorkflowExecutor: + """Configure advisor.""" + return DefaultWorkflowExecutor( + context, + self.get_collectors(context), + self.get_analyzers(context), + self.get_producers(context), + self.get_events(context), + ) + + @abstractmethod + def get_collectors(self, context: Context) -> List[DataCollector]: + """Return list of the data collectors.""" + + @abstractmethod + def get_analyzers(self, context: Context) -> List[DataAnalyzer]: + """Return list of the data analyzers.""" + + @abstractmethod + def get_producers(self, context: Context) -> List[AdviceProducer]: + """Return list of the advice producers.""" + + @abstractmethod + def get_events(self, context: Context) -> List[Event]: + """Return list of the startup events.""" + + def get_string_parameter(self, context: Context, param: str) -> str: + """Get string parameter value.""" + value = self.get_parameter( + self.name(), + param, + expected_type=str, + context=context, + ) + + return cast(str, value) + + def get_model(self, context: Context) -> Path: + """Get path to the model.""" + model_param = self.get_string_parameter(context, "model") + + model = Path(model_param) + if not model.exists(): + raise Exception(f"Path {model} does not exist") + + return model + + def get_target_profile(self, context: Context) -> str: + """Get target profile.""" + return self.get_string_parameter(context, "target_profile") |