aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/target
diff options
context:
space:
mode:
authorBenjamin Klimczak <benjamin.klimczak@arm.com>2023-01-11 12:32:02 +0000
committerBenjamin Klimczak <benjamin.klimczak@arm.com>2023-02-08 15:23:29 +0000
commita4fb8c72f15146c95df16c25e75f03344e9814fd (patch)
treece6d9cf39951a0c85d2773d436cc5010ecf78a8f /src/mlia/target
parent09ecc5c8acb758e8def33155feb746a34dd7b560 (diff)
downloadmlia-a4fb8c72f15146c95df16c25e75f03344e9814fd.tar.gz
MLIA-591 Create interface for target profiles
New class 'TargetProfile' is used to load and verify target profiles. Change-Id: I76373a923e2e5f55c4e95860635afe9fc5627a5d
Diffstat (limited to 'src/mlia/target')
-rw-r--r--src/mlia/target/config.py99
-rw-r--r--src/mlia/target/cortex_a/advisor.py4
-rw-r--r--src/mlia/target/cortex_a/config.py23
-rw-r--r--src/mlia/target/ethos_u/advisor.py12
-rw-r--r--src/mlia/target/ethos_u/config.py73
-rw-r--r--src/mlia/target/tosa/advisor.py2
-rw-r--r--src/mlia/target/tosa/config.py24
7 files changed, 155 insertions, 82 deletions
diff --git a/src/mlia/target/config.py b/src/mlia/target/config.py
index f257784..ec3fb4c 100644
--- a/src/mlia/target/config.py
+++ b/src/mlia/target/config.py
@@ -1,21 +1,110 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
-"""IP configuration module."""
+"""Target configuration module."""
from __future__ import annotations
+from abc import ABC
+from abc import abstractmethod
from dataclasses import dataclass
+from pathlib import Path
+from typing import Any
+from typing import cast
+from typing import TypeVar
+
+try:
+ import tomllib
+except ModuleNotFoundError:
+ import tomli as tomllib # type: ignore
from mlia.backend.registry import registry as backend_registry
from mlia.core.common import AdviceCategory
+from mlia.utils.filesystem import get_mlia_target_profiles_dir
+
+
+def get_profile_file(target_profile: str | Path) -> Path:
+ """Get the target profile toml file."""
+ if not target_profile:
+ raise Exception("Target profile is not provided.")
+
+ profile_file = Path(get_mlia_target_profiles_dir() / f"{target_profile}.toml")
+ if not profile_file.is_file():
+ profile_file = Path(target_profile)
+
+ if not profile_file.exists():
+ raise Exception(f"File not found: {profile_file}.")
+ return profile_file
+
+
+def load_profile(path: str | Path) -> dict[str, Any]:
+ """Get settings for the provided target profile."""
+ with open(path, "rb") as file:
+ profile = tomllib.load(file)
+
+ return cast(dict, profile)
+
+
+def get_builtin_supported_profile_names() -> list[str]:
+ """Return list of default profiles in the target profiles directory."""
+ return sorted(
+ [
+ item.stem
+ for item in get_mlia_target_profiles_dir().iterdir()
+ if item.is_file() and item.suffix == ".toml"
+ ]
+ )
+
+def get_target(target_profile: str | Path) -> str:
+ """Return target for the provided target_profile."""
+ profile_file = get_profile_file(target_profile)
+ profile = load_profile(profile_file)
+ return cast(str, profile["target"])
-class IPConfiguration: # pylint: disable=too-few-public-methods
- """Base class for IP configuration."""
+
+T = TypeVar("T", bound="TargetProfile")
+
+
+class TargetProfile(ABC):
+ """Base class for target profiles."""
def __init__(self, target: str) -> None:
- """Init IP configuration instance."""
+ """Init TargetProfile instance with the target name."""
self.target = target
+ @classmethod
+ def load(cls: type[T], path: str | Path) -> T:
+ """Load and verify a target profile from file and return new instance."""
+ profile = load_profile(path)
+
+ try:
+ new_instance = cls(**profile)
+ except KeyError as ex:
+ raise KeyError(f"Missing key in file {path}.") from ex
+
+ new_instance.verify()
+
+ return new_instance
+
+ @classmethod
+ def load_profile(cls: type[T], target_profile: str) -> T:
+ """Load a target profile by name."""
+ profile_file = get_profile_file(target_profile)
+ return cls.load(profile_file)
+
+ def save(self, path: str | Path) -> None:
+ """Save this target profile to a file."""
+ raise NotImplementedError("Saving target profiles is currently not supported.")
+
+ @abstractmethod
+ def verify(self) -> None:
+ """
+ Check that all attributes contain valid values etc.
+
+ Raises a ValueError, if an issue is detected.
+ """
+ if not self.target:
+ raise ValueError(f"Invalid target name: {self.target}")
+
@dataclass
class TargetInfo:
diff --git a/src/mlia/target/cortex_a/advisor.py b/src/mlia/target/cortex_a/advisor.py
index 52af592..518c9f1 100644
--- a/src/mlia/target/cortex_a/advisor.py
+++ b/src/mlia/target/cortex_a/advisor.py
@@ -58,7 +58,9 @@ class CortexAInferenceAdvisor(DefaultInferenceAdvisor):
target_profile = self.get_target_profile(context)
return [
- CortexAAdvisorStartedEvent(model, CortexAConfiguration(target_profile)),
+ CortexAAdvisorStartedEvent(
+ model, CortexAConfiguration.load_profile(target_profile)
+ ),
]
diff --git a/src/mlia/target/cortex_a/config.py b/src/mlia/target/cortex_a/config.py
index b2b51ea..fd39e0a 100644
--- a/src/mlia/target/cortex_a/config.py
+++ b/src/mlia/target/cortex_a/config.py
@@ -1,20 +1,23 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Cortex-A configuration."""
from __future__ import annotations
-from mlia.target.config import IPConfiguration
-from mlia.utils.filesystem import get_profile
+from typing import Any
+from mlia.target.config import TargetProfile
-class CortexAConfiguration(IPConfiguration): # pylint: disable=too-few-public-methods
+
+class CortexAConfiguration(TargetProfile):
"""Cortex-A configuration."""
- def __init__(self, target_profile: str) -> None:
+ def __init__(self, **kwargs: Any) -> None:
"""Init Cortex-A target configuration."""
- target_data = get_profile(target_profile)
-
- target = target_data["target"]
- if target != "cortex-a":
- raise Exception(f"Wrong target {target} for Cortex-A configuration")
+ target = kwargs["target"]
super().__init__(target)
+
+ def verify(self) -> None:
+ """Check the parameters."""
+ super().verify()
+ if self.target != "cortex-a":
+ raise ValueError(f"Wrong target {self.target} for Cortex-A configuration.")
diff --git a/src/mlia/target/ethos_u/advisor.py b/src/mlia/target/ethos_u/advisor.py
index 937e91c..225fd87 100644
--- a/src/mlia/target/ethos_u/advisor.py
+++ b/src/mlia/target/ethos_u/advisor.py
@@ -19,7 +19,6 @@ from mlia.nn.tensorflow.utils import is_tflite_model
from mlia.target.ethos_u.advice_generation import EthosUAdviceProducer
from mlia.target.ethos_u.advice_generation import EthosUStaticAdviceProducer
from mlia.target.ethos_u.config import EthosUConfiguration
-from mlia.target.ethos_u.config import get_target
from mlia.target.ethos_u.data_analysis import EthosUDataAnalyzer
from mlia.target.ethos_u.data_collection import EthosUOperatorCompatibility
from mlia.target.ethos_u.data_collection import EthosUOptimizationPerformance
@@ -40,7 +39,7 @@ class EthosUInferenceAdvisor(DefaultInferenceAdvisor):
def get_collectors(self, context: Context) -> list[DataCollector]:
"""Return list of the data collectors."""
model = self.get_model(context)
- device = self._get_device(context)
+ device = self._get_device_cfg(context)
backends = self._get_backends(context)
collectors: list[DataCollector] = []
@@ -88,17 +87,16 @@ class EthosUInferenceAdvisor(DefaultInferenceAdvisor):
def get_events(self, context: Context) -> list[Event]:
"""Return list of the startup events."""
model = self.get_model(context)
- device = self._get_device(context)
+ device = self._get_device_cfg(context)
return [
EthosUAdvisorStartedEvent(device=device, model=model),
]
- def _get_device(self, context: Context) -> EthosUConfiguration:
- """Get device."""
+ def _get_device_cfg(self, context: Context) -> EthosUConfiguration:
+ """Get device configuration."""
target_profile = self.get_target_profile(context)
-
- return get_target(target_profile)
+ return EthosUConfiguration.load_profile(target_profile)
def _get_optimization_settings(self, context: Context) -> list[list[dict]]:
"""Get optimization settings."""
diff --git a/src/mlia/target/ethos_u/config.py b/src/mlia/target/ethos_u/config.py
index 8d8f481..eb5691d 100644
--- a/src/mlia/target/ethos_u/config.py
+++ b/src/mlia/target/ethos_u/config.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Ethos-U configuration."""
from __future__ import annotations
@@ -8,36 +8,49 @@ from typing import Any
from mlia.backend.vela.compiler import resolve_compiler_config
from mlia.backend.vela.compiler import VelaCompilerOptions
-from mlia.target.config import IPConfiguration
-from mlia.utils.filesystem import get_profile
+from mlia.target.config import TargetProfile
from mlia.utils.filesystem import get_vela_config
logger = logging.getLogger(__name__)
-class EthosUConfiguration(IPConfiguration):
+class EthosUConfiguration(TargetProfile):
"""Ethos-U configuration."""
- def __init__(self, target_profile: str) -> None:
+ def __init__(self, **kwargs: Any) -> None:
"""Init Ethos-U target configuration."""
- target_data = get_profile(target_profile)
- _check_target_data_complete(target_data)
-
- target = target_data["target"]
+ target = kwargs["target"]
super().__init__(target)
- mac = target_data["mac"]
- _check_device_options_valid(target, mac)
+ mac = kwargs["mac"]
self.mac = mac
self.compiler_options = VelaCompilerOptions(
- system_config=target_data["system_config"],
- memory_mode=target_data["memory_mode"],
+ system_config=kwargs["system_config"],
+ memory_mode=kwargs["memory_mode"],
config_files=str(get_vela_config()),
accelerator_config=f"{self.target}-{mac}", # type: ignore
)
+ def verify(self) -> None:
+ """Check the parameters."""
+ super().verify()
+
+ target_mac_ranges = {
+ "ethos-u55": [32, 64, 128, 256],
+ "ethos-u65": [256, 512],
+ }
+
+ if self.target not in target_mac_ranges:
+ raise ValueError(f"Unsupported target: {self.target}")
+
+ target_mac_range = target_mac_ranges[self.target]
+ if self.mac not in target_mac_range:
+ raise ValueError(
+ f"Mac value for selected device should be in {target_mac_range}."
+ )
+
@property
def resolved_compiler_config(self) -> dict[str, Any]:
"""Resolve compiler configuration."""
@@ -54,37 +67,3 @@ class EthosUConfiguration(IPConfiguration):
def __repr__(self) -> str:
"""Return string representation."""
return f"<Ethos-U configuration target={self.target}>"
-
-
-def get_target(target_profile: str) -> EthosUConfiguration:
- """Get target instance based on provided params."""
- if not target_profile:
- raise Exception("No target profile given")
-
- return EthosUConfiguration(target_profile)
-
-
-def _check_target_data_complete(target_data: dict[str, Any]) -> None:
- """Check if profile contains all needed data."""
- mandatory_keys = {"target", "mac", "system_config", "memory_mode"}
- missing_keys = sorted(mandatory_keys - target_data.keys())
-
- if missing_keys:
- raise Exception(f"Mandatory fields missing from target profile: {missing_keys}")
-
-
-def _check_device_options_valid(target: str, mac: int) -> None:
- """Check if mac is valid for selected device."""
- target_mac_ranges = {
- "ethos-u55": [32, 64, 128, 256],
- "ethos-u65": [256, 512],
- }
-
- if target not in target_mac_ranges:
- raise Exception(f"Unsupported target: {target}")
-
- target_mac_range = target_mac_ranges[target]
- if mac not in target_mac_range:
- raise Exception(
- f"Mac value for selected device should be in {target_mac_range}"
- )
diff --git a/src/mlia/target/tosa/advisor.py b/src/mlia/target/tosa/advisor.py
index e8aad53..5588d0f 100644
--- a/src/mlia/target/tosa/advisor.py
+++ b/src/mlia/target/tosa/advisor.py
@@ -66,7 +66,7 @@ class TOSAInferenceAdvisor(DefaultInferenceAdvisor):
return [
TOSAAdvisorStartedEvent(
model,
- TOSAConfiguration(target_profile),
+ TOSAConfiguration.load_profile(target_profile),
MetadataDisplay(
TOSAMetadata("tosa-checker"),
MLIAMetadata("mlia"),
diff --git a/src/mlia/target/tosa/config.py b/src/mlia/target/tosa/config.py
index 22805b7..826e719 100644
--- a/src/mlia/target/tosa/config.py
+++ b/src/mlia/target/tosa/config.py
@@ -1,19 +1,21 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""TOSA target configuration."""
-from mlia.target.config import IPConfiguration
-from mlia.utils.filesystem import get_profile
+from typing import Any
+from mlia.target.config import TargetProfile
-class TOSAConfiguration(IPConfiguration): # pylint: disable=too-few-public-methods
+
+class TOSAConfiguration(TargetProfile):
"""TOSA configuration."""
- def __init__(self, target_profile: str) -> None:
+ def __init__(self, **kwargs: Any) -> None:
"""Init configuration."""
- target_data = get_profile(target_profile)
- target = target_data["target"]
-
- if target != "tosa":
- raise Exception(f"Wrong target {target} for TOSA configuration")
-
+ target = kwargs["target"]
super().__init__(target)
+
+ def verify(self) -> None:
+ """Check the parameters."""
+ super().verify()
+ if self.target != "tosa":
+ raise ValueError(f"Wrong target {self.target} for TOSA configuration.")