From a4fb8c72f15146c95df16c25e75f03344e9814fd Mon Sep 17 00:00:00 2001 From: Benjamin Klimczak Date: Wed, 11 Jan 2023 12:32:02 +0000 Subject: MLIA-591 Create interface for target profiles New class 'TargetProfile' is used to load and verify target profiles. Change-Id: I76373a923e2e5f55c4e95860635afe9fc5627a5d --- src/mlia/target/ethos_u/advisor.py | 12 +++---- src/mlia/target/ethos_u/config.py | 73 ++++++++++++++------------------------ 2 files changed, 31 insertions(+), 54 deletions(-) (limited to 'src/mlia/target/ethos_u') diff --git a/src/mlia/target/ethos_u/advisor.py b/src/mlia/target/ethos_u/advisor.py index 937e91c..225fd87 100644 --- a/src/mlia/target/ethos_u/advisor.py +++ b/src/mlia/target/ethos_u/advisor.py @@ -19,7 +19,6 @@ from mlia.nn.tensorflow.utils import is_tflite_model from mlia.target.ethos_u.advice_generation import EthosUAdviceProducer from mlia.target.ethos_u.advice_generation import EthosUStaticAdviceProducer from mlia.target.ethos_u.config import EthosUConfiguration -from mlia.target.ethos_u.config import get_target from mlia.target.ethos_u.data_analysis import EthosUDataAnalyzer from mlia.target.ethos_u.data_collection import EthosUOperatorCompatibility from mlia.target.ethos_u.data_collection import EthosUOptimizationPerformance @@ -40,7 +39,7 @@ class EthosUInferenceAdvisor(DefaultInferenceAdvisor): def get_collectors(self, context: Context) -> list[DataCollector]: """Return list of the data collectors.""" model = self.get_model(context) - device = self._get_device(context) + device = self._get_device_cfg(context) backends = self._get_backends(context) collectors: list[DataCollector] = [] @@ -88,17 +87,16 @@ class EthosUInferenceAdvisor(DefaultInferenceAdvisor): def get_events(self, context: Context) -> list[Event]: """Return list of the startup events.""" model = self.get_model(context) - device = self._get_device(context) + device = self._get_device_cfg(context) return [ EthosUAdvisorStartedEvent(device=device, model=model), ] - def _get_device(self, context: Context) -> EthosUConfiguration: - """Get device.""" + def _get_device_cfg(self, context: Context) -> EthosUConfiguration: + """Get device configuration.""" target_profile = self.get_target_profile(context) - - return get_target(target_profile) + return EthosUConfiguration.load_profile(target_profile) def _get_optimization_settings(self, context: Context) -> list[list[dict]]: """Get optimization settings.""" diff --git a/src/mlia/target/ethos_u/config.py b/src/mlia/target/ethos_u/config.py index 8d8f481..eb5691d 100644 --- a/src/mlia/target/ethos_u/config.py +++ b/src/mlia/target/ethos_u/config.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Ethos-U configuration.""" from __future__ import annotations @@ -8,36 +8,49 @@ from typing import Any from mlia.backend.vela.compiler import resolve_compiler_config from mlia.backend.vela.compiler import VelaCompilerOptions -from mlia.target.config import IPConfiguration -from mlia.utils.filesystem import get_profile +from mlia.target.config import TargetProfile from mlia.utils.filesystem import get_vela_config logger = logging.getLogger(__name__) -class EthosUConfiguration(IPConfiguration): +class EthosUConfiguration(TargetProfile): """Ethos-U configuration.""" - def __init__(self, target_profile: str) -> None: + def __init__(self, **kwargs: Any) -> None: """Init Ethos-U target configuration.""" - target_data = get_profile(target_profile) - _check_target_data_complete(target_data) - - target = target_data["target"] + target = kwargs["target"] super().__init__(target) - mac = target_data["mac"] - _check_device_options_valid(target, mac) + mac = kwargs["mac"] self.mac = mac self.compiler_options = VelaCompilerOptions( - system_config=target_data["system_config"], - memory_mode=target_data["memory_mode"], + system_config=kwargs["system_config"], + memory_mode=kwargs["memory_mode"], config_files=str(get_vela_config()), accelerator_config=f"{self.target}-{mac}", # type: ignore ) + def verify(self) -> None: + """Check the parameters.""" + super().verify() + + target_mac_ranges = { + "ethos-u55": [32, 64, 128, 256], + "ethos-u65": [256, 512], + } + + if self.target not in target_mac_ranges: + raise ValueError(f"Unsupported target: {self.target}") + + target_mac_range = target_mac_ranges[self.target] + if self.mac not in target_mac_range: + raise ValueError( + f"Mac value for selected device should be in {target_mac_range}." + ) + @property def resolved_compiler_config(self) -> dict[str, Any]: """Resolve compiler configuration.""" @@ -54,37 +67,3 @@ class EthosUConfiguration(IPConfiguration): def __repr__(self) -> str: """Return string representation.""" return f"" - - -def get_target(target_profile: str) -> EthosUConfiguration: - """Get target instance based on provided params.""" - if not target_profile: - raise Exception("No target profile given") - - return EthosUConfiguration(target_profile) - - -def _check_target_data_complete(target_data: dict[str, Any]) -> None: - """Check if profile contains all needed data.""" - mandatory_keys = {"target", "mac", "system_config", "memory_mode"} - missing_keys = sorted(mandatory_keys - target_data.keys()) - - if missing_keys: - raise Exception(f"Mandatory fields missing from target profile: {missing_keys}") - - -def _check_device_options_valid(target: str, mac: int) -> None: - """Check if mac is valid for selected device.""" - target_mac_ranges = { - "ethos-u55": [32, 64, 128, 256], - "ethos-u65": [256, 512], - } - - if target not in target_mac_ranges: - raise Exception(f"Unsupported target: {target}") - - target_mac_range = target_mac_ranges[target] - if mac not in target_mac_range: - raise Exception( - f"Mac value for selected device should be in {target_mac_range}" - ) -- cgit v1.2.1