aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/target/ethos_u/advisor.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/mlia/target/ethos_u/advisor.py')
-rw-r--r--src/mlia/target/ethos_u/advisor.py12
1 files changed, 5 insertions, 7 deletions
diff --git a/src/mlia/target/ethos_u/advisor.py b/src/mlia/target/ethos_u/advisor.py
index 937e91c..225fd87 100644
--- a/src/mlia/target/ethos_u/advisor.py
+++ b/src/mlia/target/ethos_u/advisor.py
@@ -19,7 +19,6 @@ from mlia.nn.tensorflow.utils import is_tflite_model
from mlia.target.ethos_u.advice_generation import EthosUAdviceProducer
from mlia.target.ethos_u.advice_generation import EthosUStaticAdviceProducer
from mlia.target.ethos_u.config import EthosUConfiguration
-from mlia.target.ethos_u.config import get_target
from mlia.target.ethos_u.data_analysis import EthosUDataAnalyzer
from mlia.target.ethos_u.data_collection import EthosUOperatorCompatibility
from mlia.target.ethos_u.data_collection import EthosUOptimizationPerformance
@@ -40,7 +39,7 @@ class EthosUInferenceAdvisor(DefaultInferenceAdvisor):
def get_collectors(self, context: Context) -> list[DataCollector]:
"""Return list of the data collectors."""
model = self.get_model(context)
- device = self._get_device(context)
+ device = self._get_device_cfg(context)
backends = self._get_backends(context)
collectors: list[DataCollector] = []
@@ -88,17 +87,16 @@ class EthosUInferenceAdvisor(DefaultInferenceAdvisor):
def get_events(self, context: Context) -> list[Event]:
"""Return list of the startup events."""
model = self.get_model(context)
- device = self._get_device(context)
+ device = self._get_device_cfg(context)
return [
EthosUAdvisorStartedEvent(device=device, model=model),
]
- def _get_device(self, context: Context) -> EthosUConfiguration:
- """Get device."""
+ def _get_device_cfg(self, context: Context) -> EthosUConfiguration:
+ """Get device configuration."""
target_profile = self.get_target_profile(context)
-
- return get_target(target_profile)
+ return EthosUConfiguration.load_profile(target_profile)
def _get_optimization_settings(self, context: Context) -> list[list[dict]]:
"""Get optimization settings."""