aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBenjamin Klimczak <benjamin.klimczak@arm.com>2023-01-11 12:32:02 +0000
committerBenjamin Klimczak <benjamin.klimczak@arm.com>2023-02-08 15:23:29 +0000
commita4fb8c72f15146c95df16c25e75f03344e9814fd (patch)
treece6d9cf39951a0c85d2773d436cc5010ecf78a8f
parent09ecc5c8acb758e8def33155feb746a34dd7b560 (diff)
downloadmlia-a4fb8c72f15146c95df16c25e75f03344e9814fd.tar.gz
MLIA-591 Create interface for target profiles
New class 'TargetProfile' is used to load and verify target profiles. Change-Id: I76373a923e2e5f55c4e95860635afe9fc5627a5d
-rw-r--r--src/mlia/api.py2
-rw-r--r--src/mlia/cli/command_validators.py2
-rw-r--r--src/mlia/cli/options.py2
-rw-r--r--src/mlia/target/config.py99
-rw-r--r--src/mlia/target/cortex_a/advisor.py4
-rw-r--r--src/mlia/target/cortex_a/config.py23
-rw-r--r--src/mlia/target/ethos_u/advisor.py12
-rw-r--r--src/mlia/target/ethos_u/config.py73
-rw-r--r--src/mlia/target/tosa/advisor.py2
-rw-r--r--src/mlia/target/tosa/config.py24
-rw-r--r--src/mlia/utils/filesystem.py51
-rw-r--r--tests/conftest.py2
-rw-r--r--tests/test_backend_vela_compat.py4
-rw-r--r--tests/test_backend_vela_compiler.py8
-rw-r--r--tests/test_backend_vela_performance.py10
-rw-r--r--tests/test_cli_commands.py2
-rw-r--r--tests/test_target_config.py63
-rw-r--r--tests/test_target_cortex_a_reporters.py4
-rw-r--r--tests/test_target_ethos_u_config.py42
-rw-r--r--tests/test_target_ethos_u_data_analysis.py8
-rw-r--r--tests/test_target_ethos_u_data_collection.py8
-rw-r--r--tests/test_target_ethos_u_reporters.py2
-rw-r--r--tests/test_target_tosa_reporters.py2
-rw-r--r--tests/test_utils_filesystem.py27
-rw-r--r--tests_e2e/test_e2e.py2
25 files changed, 251 insertions, 227 deletions
diff --git a/src/mlia/api.py b/src/mlia/api.py
index 437c457..fd5fc13 100644
--- a/src/mlia/api.py
+++ b/src/mlia/api.py
@@ -10,10 +10,10 @@ from typing import Any
from mlia.core.advisor import InferenceAdvisor
from mlia.core.common import AdviceCategory
from mlia.core.context import ExecutionContext
+from mlia.target.config import get_target
from mlia.target.cortex_a.advisor import configure_and_get_cortexa_advisor
from mlia.target.ethos_u.advisor import configure_and_get_ethosu_advisor
from mlia.target.tosa.advisor import configure_and_get_tosa_advisor
-from mlia.utils.filesystem import get_target
logger = logging.getLogger(__name__)
diff --git a/src/mlia/cli/command_validators.py b/src/mlia/cli/command_validators.py
index 1974a1d..eb04192 100644
--- a/src/mlia/cli/command_validators.py
+++ b/src/mlia/cli/command_validators.py
@@ -8,8 +8,8 @@ import logging
import sys
from mlia.cli.config import get_default_backends
+from mlia.target.config import get_target
from mlia.target.registry import supported_backends
-from mlia.utils.filesystem import get_target
logger = logging.getLogger(__name__)
diff --git a/src/mlia/cli/options.py b/src/mlia/cli/options.py
index 1b92958..dac8c82 100644
--- a/src/mlia/cli/options.py
+++ b/src/mlia/cli/options.py
@@ -13,7 +13,7 @@ from mlia.cli.config import DEFAULT_PRUNING_TARGET
from mlia.cli.config import get_available_backends
from mlia.cli.config import is_corstone_backend
from mlia.core.typing import OutputFormat
-from mlia.utils.filesystem import get_builtin_supported_profile_names
+from mlia.target.config import get_builtin_supported_profile_names
def add_check_category_options(parser: argparse.ArgumentParser) -> None:
diff --git a/src/mlia/target/config.py b/src/mlia/target/config.py
index f257784..ec3fb4c 100644
--- a/src/mlia/target/config.py
+++ b/src/mlia/target/config.py
@@ -1,21 +1,110 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
-"""IP configuration module."""
+"""Target configuration module."""
from __future__ import annotations
+from abc import ABC
+from abc import abstractmethod
from dataclasses import dataclass
+from pathlib import Path
+from typing import Any
+from typing import cast
+from typing import TypeVar
+
+try:
+ import tomllib
+except ModuleNotFoundError:
+ import tomli as tomllib # type: ignore
from mlia.backend.registry import registry as backend_registry
from mlia.core.common import AdviceCategory
+from mlia.utils.filesystem import get_mlia_target_profiles_dir
+
+
+def get_profile_file(target_profile: str | Path) -> Path:
+ """Get the target profile toml file."""
+ if not target_profile:
+ raise Exception("Target profile is not provided.")
+
+ profile_file = Path(get_mlia_target_profiles_dir() / f"{target_profile}.toml")
+ if not profile_file.is_file():
+ profile_file = Path(target_profile)
+
+ if not profile_file.exists():
+ raise Exception(f"File not found: {profile_file}.")
+ return profile_file
+
+
+def load_profile(path: str | Path) -> dict[str, Any]:
+ """Get settings for the provided target profile."""
+ with open(path, "rb") as file:
+ profile = tomllib.load(file)
+
+ return cast(dict, profile)
+
+
+def get_builtin_supported_profile_names() -> list[str]:
+ """Return list of default profiles in the target profiles directory."""
+ return sorted(
+ [
+ item.stem
+ for item in get_mlia_target_profiles_dir().iterdir()
+ if item.is_file() and item.suffix == ".toml"
+ ]
+ )
+
+def get_target(target_profile: str | Path) -> str:
+ """Return target for the provided target_profile."""
+ profile_file = get_profile_file(target_profile)
+ profile = load_profile(profile_file)
+ return cast(str, profile["target"])
-class IPConfiguration: # pylint: disable=too-few-public-methods
- """Base class for IP configuration."""
+
+T = TypeVar("T", bound="TargetProfile")
+
+
+class TargetProfile(ABC):
+ """Base class for target profiles."""
def __init__(self, target: str) -> None:
- """Init IP configuration instance."""
+ """Init TargetProfile instance with the target name."""
self.target = target
+ @classmethod
+ def load(cls: type[T], path: str | Path) -> T:
+ """Load and verify a target profile from file and return new instance."""
+ profile = load_profile(path)
+
+ try:
+ new_instance = cls(**profile)
+ except KeyError as ex:
+ raise KeyError(f"Missing key in file {path}.") from ex
+
+ new_instance.verify()
+
+ return new_instance
+
+ @classmethod
+ def load_profile(cls: type[T], target_profile: str) -> T:
+ """Load a target profile by name."""
+ profile_file = get_profile_file(target_profile)
+ return cls.load(profile_file)
+
+ def save(self, path: str | Path) -> None:
+ """Save this target profile to a file."""
+ raise NotImplementedError("Saving target profiles is currently not supported.")
+
+ @abstractmethod
+ def verify(self) -> None:
+ """
+ Check that all attributes contain valid values etc.
+
+ Raises a ValueError, if an issue is detected.
+ """
+ if not self.target:
+ raise ValueError(f"Invalid target name: {self.target}")
+
@dataclass
class TargetInfo:
diff --git a/src/mlia/target/cortex_a/advisor.py b/src/mlia/target/cortex_a/advisor.py
index 52af592..518c9f1 100644
--- a/src/mlia/target/cortex_a/advisor.py
+++ b/src/mlia/target/cortex_a/advisor.py
@@ -58,7 +58,9 @@ class CortexAInferenceAdvisor(DefaultInferenceAdvisor):
target_profile = self.get_target_profile(context)
return [
- CortexAAdvisorStartedEvent(model, CortexAConfiguration(target_profile)),
+ CortexAAdvisorStartedEvent(
+ model, CortexAConfiguration.load_profile(target_profile)
+ ),
]
diff --git a/src/mlia/target/cortex_a/config.py b/src/mlia/target/cortex_a/config.py
index b2b51ea..fd39e0a 100644
--- a/src/mlia/target/cortex_a/config.py
+++ b/src/mlia/target/cortex_a/config.py
@@ -1,20 +1,23 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Cortex-A configuration."""
from __future__ import annotations
-from mlia.target.config import IPConfiguration
-from mlia.utils.filesystem import get_profile
+from typing import Any
+from mlia.target.config import TargetProfile
-class CortexAConfiguration(IPConfiguration): # pylint: disable=too-few-public-methods
+
+class CortexAConfiguration(TargetProfile):
"""Cortex-A configuration."""
- def __init__(self, target_profile: str) -> None:
+ def __init__(self, **kwargs: Any) -> None:
"""Init Cortex-A target configuration."""
- target_data = get_profile(target_profile)
-
- target = target_data["target"]
- if target != "cortex-a":
- raise Exception(f"Wrong target {target} for Cortex-A configuration")
+ target = kwargs["target"]
super().__init__(target)
+
+ def verify(self) -> None:
+ """Check the parameters."""
+ super().verify()
+ if self.target != "cortex-a":
+ raise ValueError(f"Wrong target {self.target} for Cortex-A configuration.")
diff --git a/src/mlia/target/ethos_u/advisor.py b/src/mlia/target/ethos_u/advisor.py
index 937e91c..225fd87 100644
--- a/src/mlia/target/ethos_u/advisor.py
+++ b/src/mlia/target/ethos_u/advisor.py
@@ -19,7 +19,6 @@ from mlia.nn.tensorflow.utils import is_tflite_model
from mlia.target.ethos_u.advice_generation import EthosUAdviceProducer
from mlia.target.ethos_u.advice_generation import EthosUStaticAdviceProducer
from mlia.target.ethos_u.config import EthosUConfiguration
-from mlia.target.ethos_u.config import get_target
from mlia.target.ethos_u.data_analysis import EthosUDataAnalyzer
from mlia.target.ethos_u.data_collection import EthosUOperatorCompatibility
from mlia.target.ethos_u.data_collection import EthosUOptimizationPerformance
@@ -40,7 +39,7 @@ class EthosUInferenceAdvisor(DefaultInferenceAdvisor):
def get_collectors(self, context: Context) -> list[DataCollector]:
"""Return list of the data collectors."""
model = self.get_model(context)
- device = self._get_device(context)
+ device = self._get_device_cfg(context)
backends = self._get_backends(context)
collectors: list[DataCollector] = []
@@ -88,17 +87,16 @@ class EthosUInferenceAdvisor(DefaultInferenceAdvisor):
def get_events(self, context: Context) -> list[Event]:
"""Return list of the startup events."""
model = self.get_model(context)
- device = self._get_device(context)
+ device = self._get_device_cfg(context)
return [
EthosUAdvisorStartedEvent(device=device, model=model),
]
- def _get_device(self, context: Context) -> EthosUConfiguration:
- """Get device."""
+ def _get_device_cfg(self, context: Context) -> EthosUConfiguration:
+ """Get device configuration."""
target_profile = self.get_target_profile(context)
-
- return get_target(target_profile)
+ return EthosUConfiguration.load_profile(target_profile)
def _get_optimization_settings(self, context: Context) -> list[list[dict]]:
"""Get optimization settings."""
diff --git a/src/mlia/target/ethos_u/config.py b/src/mlia/target/ethos_u/config.py
index 8d8f481..eb5691d 100644
--- a/src/mlia/target/ethos_u/config.py
+++ b/src/mlia/target/ethos_u/config.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Ethos-U configuration."""
from __future__ import annotations
@@ -8,36 +8,49 @@ from typing import Any
from mlia.backend.vela.compiler import resolve_compiler_config
from mlia.backend.vela.compiler import VelaCompilerOptions
-from mlia.target.config import IPConfiguration
-from mlia.utils.filesystem import get_profile
+from mlia.target.config import TargetProfile
from mlia.utils.filesystem import get_vela_config
logger = logging.getLogger(__name__)
-class EthosUConfiguration(IPConfiguration):
+class EthosUConfiguration(TargetProfile):
"""Ethos-U configuration."""
- def __init__(self, target_profile: str) -> None:
+ def __init__(self, **kwargs: Any) -> None:
"""Init Ethos-U target configuration."""
- target_data = get_profile(target_profile)
- _check_target_data_complete(target_data)
-
- target = target_data["target"]
+ target = kwargs["target"]
super().__init__(target)
- mac = target_data["mac"]
- _check_device_options_valid(target, mac)
+ mac = kwargs["mac"]
self.mac = mac
self.compiler_options = VelaCompilerOptions(
- system_config=target_data["system_config"],
- memory_mode=target_data["memory_mode"],
+ system_config=kwargs["system_config"],
+ memory_mode=kwargs["memory_mode"],
config_files=str(get_vela_config()),
accelerator_config=f"{self.target}-{mac}", # type: ignore
)
+ def verify(self) -> None:
+ """Check the parameters."""
+ super().verify()
+
+ target_mac_ranges = {
+ "ethos-u55": [32, 64, 128, 256],
+ "ethos-u65": [256, 512],
+ }
+
+ if self.target not in target_mac_ranges:
+ raise ValueError(f"Unsupported target: {self.target}")
+
+ target_mac_range = target_mac_ranges[self.target]
+ if self.mac not in target_mac_range:
+ raise ValueError(
+ f"Mac value for selected device should be in {target_mac_range}."
+ )
+
@property
def resolved_compiler_config(self) -> dict[str, Any]:
"""Resolve compiler configuration."""
@@ -54,37 +67,3 @@ class EthosUConfiguration(IPConfiguration):
def __repr__(self) -> str:
"""Return string representation."""
return f"<Ethos-U configuration target={self.target}>"
-
-
-def get_target(target_profile: str) -> EthosUConfiguration:
- """Get target instance based on provided params."""
- if not target_profile:
- raise Exception("No target profile given")
-
- return EthosUConfiguration(target_profile)
-
-
-def _check_target_data_complete(target_data: dict[str, Any]) -> None:
- """Check if profile contains all needed data."""
- mandatory_keys = {"target", "mac", "system_config", "memory_mode"}
- missing_keys = sorted(mandatory_keys - target_data.keys())
-
- if missing_keys:
- raise Exception(f"Mandatory fields missing from target profile: {missing_keys}")
-
-
-def _check_device_options_valid(target: str, mac: int) -> None:
- """Check if mac is valid for selected device."""
- target_mac_ranges = {
- "ethos-u55": [32, 64, 128, 256],
- "ethos-u65": [256, 512],
- }
-
- if target not in target_mac_ranges:
- raise Exception(f"Unsupported target: {target}")
-
- target_mac_range = target_mac_ranges[target]
- if mac not in target_mac_range:
- raise Exception(
- f"Mac value for selected device should be in {target_mac_range}"
- )
diff --git a/src/mlia/target/tosa/advisor.py b/src/mlia/target/tosa/advisor.py
index e8aad53..5588d0f 100644
--- a/src/mlia/target/tosa/advisor.py
+++ b/src/mlia/target/tosa/advisor.py
@@ -66,7 +66,7 @@ class TOSAInferenceAdvisor(DefaultInferenceAdvisor):
return [
TOSAAdvisorStartedEvent(
model,
- TOSAConfiguration(target_profile),
+ TOSAConfiguration.load_profile(target_profile),
MetadataDisplay(
TOSAMetadata("tosa-checker"),
MLIAMetadata("mlia"),
diff --git a/src/mlia/target/tosa/config.py b/src/mlia/target/tosa/config.py
index 22805b7..826e719 100644
--- a/src/mlia/target/tosa/config.py
+++ b/src/mlia/target/tosa/config.py
@@ -1,19 +1,21 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""TOSA target configuration."""
-from mlia.target.config import IPConfiguration
-from mlia.utils.filesystem import get_profile
+from typing import Any
+from mlia.target.config import TargetProfile
-class TOSAConfiguration(IPConfiguration): # pylint: disable=too-few-public-methods
+
+class TOSAConfiguration(TargetProfile):
"""TOSA configuration."""
- def __init__(self, target_profile: str) -> None:
+ def __init__(self, **kwargs: Any) -> None:
"""Init configuration."""
- target_data = get_profile(target_profile)
- target = target_data["target"]
-
- if target != "tosa":
- raise Exception(f"Wrong target {target} for TOSA configuration")
-
+ target = kwargs["target"]
super().__init__(target)
+
+ def verify(self) -> None:
+ """Check the parameters."""
+ super().verify()
+ if self.target != "tosa":
+ raise ValueError(f"Wrong target {self.target} for TOSA configuration.")
diff --git a/src/mlia/utils/filesystem.py b/src/mlia/utils/filesystem.py
index 4734a84..f92629b 100644
--- a/src/mlia/utils/filesystem.py
+++ b/src/mlia/utils/filesystem.py
@@ -11,16 +11,9 @@ from contextlib import contextmanager
from pathlib import Path
from tempfile import mkstemp
from tempfile import TemporaryDirectory
-from typing import Any
-from typing import cast
from typing import Generator
from typing import Iterable
-try:
- import tomllib
-except ModuleNotFoundError:
- import tomli as tomllib # type: ignore
-
def get_mlia_resources() -> Path:
"""Get the path to the resources directory."""
@@ -39,50 +32,6 @@ def get_mlia_target_profiles_dir() -> Path:
return get_mlia_resources() / "target_profiles"
-def get_profile_toml_file(target_profile: str | Path) -> str | Path:
- """Get the target profile toml file."""
- if not target_profile:
- raise Exception("Target profile is not provided")
-
- profile_toml_file = Path(get_mlia_target_profiles_dir() / f"{target_profile}.toml")
- if not profile_toml_file.is_file():
- profile_toml_file = Path(target_profile)
-
- if not profile_toml_file.exists():
- raise Exception(f"File not found: {profile_toml_file}.")
- return profile_toml_file
-
-
-def get_profile(target_profile: str | Path) -> dict[str, Any]:
- """Get settings for the provided target profile."""
- if not target_profile:
- raise Exception("Target profile is not provided")
-
- toml_file = get_profile_toml_file(target_profile)
-
- with open(toml_file, "rb") as file:
- profile = tomllib.load(file)
-
- return cast(dict, profile)
-
-
-def get_builtin_supported_profile_names() -> list[str]:
- """Return list of default profiles in the target profiles directory."""
- return sorted(
- [
- item.stem
- for item in get_mlia_target_profiles_dir().iterdir()
- if item.is_file() and item.suffix == ".toml"
- ]
- )
-
-
-def get_target(target_profile: str | Path) -> str:
- """Return target for the provided target_profile."""
- profile = get_profile(target_profile)
- return cast(str, profile["target"])
-
-
@contextmanager
def temp_file(suffix: str | None = None) -> Generator[Path, None, None]:
"""Create temp file and remove it after."""
diff --git a/tests/conftest.py b/tests/conftest.py
index b698a73..67549e7 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -157,7 +157,7 @@ def fixture_test_models_path(
save_tflite_model(tflite_model, tflite_model_path)
tflite_vela_model = tmp_path / "test_model_vela.tflite"
- device = EthosUConfiguration("ethos-u55-256")
+ device = EthosUConfiguration.load_profile("ethos-u55-256")
optimize_model(tflite_model_path, device.compiler_options, tflite_vela_model)
tf.saved_model.save(keras_model, str(tmp_path / "tf_model_test_model"))
diff --git a/tests/test_backend_vela_compat.py b/tests/test_backend_vela_compat.py
index a2e7f90..4653d7d 100644
--- a/tests/test_backend_vela_compat.py
+++ b/tests/test_backend_vela_compat.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for module vela/compat."""
from pathlib import Path
@@ -55,7 +55,7 @@ from mlia.utils.filesystem import working_directory
)
def test_operators(test_models_path: Path, model: str, expected_ops: Operators) -> None:
"""Test operators function."""
- device = EthosUConfiguration("ethos-u55-256")
+ device = EthosUConfiguration.load_profile("ethos-u55-256")
operators = supported_operators(test_models_path / model, device.compiler_options)
for expected, actual in zip(expected_ops.ops, operators.ops):
diff --git a/tests/test_backend_vela_compiler.py b/tests/test_backend_vela_compiler.py
index 20121d6..2d937ea 100644
--- a/tests/test_backend_vela_compiler.py
+++ b/tests/test_backend_vela_compiler.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for module vela/compiler."""
from pathlib import Path
@@ -146,7 +146,9 @@ def test_vela_compiler_with_parameters(test_resources_path: Path) -> None:
def test_compile_model(test_tflite_model: Path) -> None:
"""Test model optimization."""
- compiler = VelaCompiler(EthosUConfiguration("ethos-u55-256").compiler_options)
+ compiler = VelaCompiler(
+ EthosUConfiguration.load_profile("ethos-u55-256").compiler_options
+ )
optimized_model = compiler.compile_model(test_tflite_model)
assert isinstance(optimized_model, OptimizedModel)
@@ -156,7 +158,7 @@ def test_optimize_model(tmp_path: Path, test_tflite_model: Path) -> None:
"""Test model optimization and saving into file."""
tmp_file = tmp_path / "temp.tflite"
- device = EthosUConfiguration("ethos-u55-256")
+ device = EthosUConfiguration.load_profile("ethos-u55-256")
optimize_model(test_tflite_model, device.compiler_options, tmp_file.absolute())
assert tmp_file.is_file()
diff --git a/tests/test_backend_vela_performance.py b/tests/test_backend_vela_performance.py
index 34c11ab..569de61 100644
--- a/tests/test_backend_vela_performance.py
+++ b/tests/test_backend_vela_performance.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for module vela/performance."""
from pathlib import Path
@@ -14,7 +14,7 @@ from mlia.target.ethos_u.config import EthosUConfiguration
def test_estimate_performance(test_tflite_model: Path) -> None:
"""Test getting performance estimations."""
- device = EthosUConfiguration("ethos-u55-256")
+ device = EthosUConfiguration.load_profile("ethos-u55-256")
perf_metrics = estimate_performance(test_tflite_model, device.compiler_options)
assert isinstance(perf_metrics, PerformanceMetrics)
@@ -24,7 +24,7 @@ def test_estimate_performance_already_optimized(
tmp_path: Path, test_tflite_model: Path
) -> None:
"""Test that performance estimation should fail for already optimized model."""
- device = EthosUConfiguration("ethos-u55-256")
+ device = EthosUConfiguration.load_profile("ethos-u55-256")
optimized_model_path = tmp_path / "optimized_model.tflite"
@@ -41,7 +41,7 @@ def test_read_invalid_model(test_tflite_invalid_model: Path) -> None:
with pytest.raises(
Exception, match=f"Unable to read model {test_tflite_invalid_model}"
):
- device = EthosUConfiguration("ethos-u55-256")
+ device = EthosUConfiguration.load_profile("ethos-u55-256")
estimate_performance(test_tflite_invalid_model, device.compiler_options)
@@ -58,7 +58,7 @@ def test_compile_invalid_model(
with pytest.raises(
Exception, match="Model could not be optimized with Vela compiler"
):
- device = EthosUConfiguration("ethos-u55-256")
+ device = EthosUConfiguration.load_profile("ethos-u55-256")
optimize_model(test_tflite_model, device.compiler_options, model_path)
assert not model_path.exists()
diff --git a/tests/test_cli_commands.py b/tests/test_cli_commands.py
index b65d90e..61cc5a6 100644
--- a/tests/test_cli_commands.py
+++ b/tests/test_cli_commands.py
@@ -79,7 +79,7 @@ def test_opt_valid_optimization_target(
def mock_performance_estimation(monkeypatch: pytest.MonkeyPatch) -> None:
"""Mock performance estimation."""
metrics = PerformanceMetrics(
- EthosUConfiguration("ethos-u55-256"),
+ EthosUConfiguration.load_profile("ethos-u55-256"),
NPUCycles(1, 2, 3, 4, 5, 6),
MemoryUsage(1, 2, 3, 4, 5),
)
diff --git a/tests/test_target_config.py b/tests/test_target_config.py
index 48f0a58..26f524e 100644
--- a/tests/test_target_config.py
+++ b/tests/test_target_config.py
@@ -9,15 +9,68 @@ from mlia.backend.config import BackendConfiguration
from mlia.backend.config import BackendType
from mlia.backend.config import System
from mlia.core.common import AdviceCategory
-from mlia.target.config import IPConfiguration
+from mlia.target.config import get_builtin_supported_profile_names
+from mlia.target.config import get_profile_file
+from mlia.target.config import load_profile
from mlia.target.config import TargetInfo
+from mlia.target.config import TargetProfile
from mlia.utils.registry import Registry
-def test_ip_config() -> None:
- """Test the class 'IPConfiguration'."""
- cfg = IPConfiguration("AnyTarget")
- assert cfg.target == "AnyTarget"
+def test_get_builtin_supported_profile_names() -> None:
+ """Test profile names getter."""
+ assert get_builtin_supported_profile_names() == [
+ "cortex-a",
+ "ethos-u55-128",
+ "ethos-u55-256",
+ "ethos-u65-256",
+ "ethos-u65-512",
+ "tosa",
+ ]
+
+
+def test_get_profile_file() -> None:
+ """Test function 'get_profile_file'."""
+ profile_file = get_profile_file("cortex-a")
+ assert profile_file.is_file()
+ assert profile_file == get_profile_file(profile_file)
+
+ with pytest.raises(Exception):
+ get_profile_file("UNKNOWN")
+ with pytest.raises(Exception):
+ get_profile_file("")
+
+
+def test_load_profile() -> None:
+ """Test getting profile data."""
+ profile_file = get_profile_file("ethos-u55-256")
+ assert load_profile(profile_file) == {
+ "target": "ethos-u55",
+ "mac": 256,
+ "memory_mode": "Shared_Sram",
+ "system_config": "Ethos_U55_High_End_Embedded",
+ }
+
+ with pytest.raises(Exception, match=r"No such file or directory: 'unknown'"):
+ load_profile("unknown")
+
+
+def test_target_profile() -> None:
+ """Test the class 'TargetProfile'."""
+
+ class MyTargetProfile(TargetProfile):
+ """Test class deriving from TargetProfile."""
+
+ def verify(self) -> None:
+ super().verify()
+ assert self.target
+
+ profile = MyTargetProfile("AnyTarget")
+ assert profile.target == "AnyTarget"
+
+ profile = MyTargetProfile("")
+ with pytest.raises(ValueError):
+ profile.verify()
@pytest.mark.parametrize(
diff --git a/tests/test_target_cortex_a_reporters.py b/tests/test_target_cortex_a_reporters.py
index 4b39aa1..c32ef7b 100644
--- a/tests/test_target_cortex_a_reporters.py
+++ b/tests/test_target_cortex_a_reporters.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for Cortex-A reporters."""
from typing import Any
@@ -18,7 +18,7 @@ from mlia.target.cortex_a.reporters import report_device
def test_report_device() -> None:
"""Test function report_device()."""
- report = report_device(CortexAConfiguration("cortex-a"))
+ report = report_device(CortexAConfiguration.load_profile("cortex-a"))
assert report.to_plain_text()
diff --git a/tests/test_target_ethos_u_config.py b/tests/test_target_ethos_u_config.py
index 08a20ff..7f13b26 100644
--- a/tests/test_target_ethos_u_config.py
+++ b/tests/test_target_ethos_u_config.py
@@ -5,14 +5,11 @@ from __future__ import annotations
from contextlib import ExitStack as does_not_raise
from typing import Any
-from unittest.mock import MagicMock
import pytest
from mlia.backend.vela.compiler import VelaCompilerOptions
from mlia.target.ethos_u.config import EthosUConfiguration
-from mlia.target.ethos_u.config import get_target
-from mlia.utils.filesystem import get_vela_config
def test_compiler_options_default_init() -> None:
@@ -33,48 +30,28 @@ def test_compiler_options_default_init() -> None:
def test_ethosu_target() -> None:
"""Test Ethos-U target configuration init."""
- default_config = EthosUConfiguration("ethos-u55-256")
+ default_config = EthosUConfiguration.load_profile("ethos-u55-256")
assert default_config.target == "ethos-u55"
assert default_config.mac == 256
assert default_config.compiler_options is not None
-def test_get_target() -> None:
- """Test function get_target."""
- with pytest.raises(Exception, match="No target profile given"):
- get_target(None) # type: ignore
-
- with pytest.raises(Exception, match=r"File not found:*"):
- get_target("unknown")
-
- u65_device = get_target("ethos-u65-512")
-
- assert isinstance(u65_device, EthosUConfiguration)
- assert u65_device.target == "ethos-u65"
- assert u65_device.mac == 512
- assert u65_device.compiler_options.accelerator_config == "ethos-u65-512"
- assert u65_device.compiler_options.memory_mode == "Dedicated_Sram"
- assert u65_device.compiler_options.config_files == str(get_vela_config())
-
-
@pytest.mark.parametrize(
"profile_data, expected_error",
[
[
{},
pytest.raises(
- Exception,
- match="Mandatory fields missing from target profile: "
- r"\['mac', 'memory_mode', 'system_config', 'target'\]",
+ KeyError,
+ match=r"'target'",
),
],
[
{"target": "ethos-u65", "mac": 512},
pytest.raises(
- Exception,
- match="Mandatory fields missing from target profile: "
- r"\['memory_mode', 'system_config'\]",
+ KeyError,
+ match=r"'system_config'",
),
],
[
@@ -114,12 +91,9 @@ def test_get_target() -> None:
],
)
def test_ethosu_configuration(
- monkeypatch: pytest.MonkeyPatch, profile_data: dict[str, Any], expected_error: Any
+ profile_data: dict[str, Any], expected_error: Any
) -> None:
"""Test creating Ethos-U configuration."""
- monkeypatch.setattr(
- "mlia.target.ethos_u.config.get_profile", MagicMock(return_value=profile_data)
- )
-
with expected_error:
- EthosUConfiguration("target")
+ cfg = EthosUConfiguration(**profile_data)
+ cfg.verify()
diff --git a/tests/test_target_ethos_u_data_analysis.py b/tests/test_target_ethos_u_data_analysis.py
index bac27ad..e919f5d 100644
--- a/tests/test_target_ethos_u_data_analysis.py
+++ b/tests/test_target_ethos_u_data_analysis.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for Ethos-U data analysis module."""
from __future__ import annotations
@@ -84,7 +84,7 @@ def test_perf_metrics_diff() -> None:
[
OptimizationPerformanceMetrics(
PerformanceMetrics(
- EthosUConfiguration("ethos-u55-256"),
+ EthosUConfiguration.load_profile("ethos-u55-256"),
NPUCycles(1, 2, 3, 4, 5, 6),
# memory metrics are in kilobytes
MemoryUsage(*[i * 1024 for i in range(1, 6)]), # type: ignore
@@ -95,7 +95,7 @@ def test_perf_metrics_diff() -> None:
OptimizationSettings("pruning", 0.5, None),
],
PerformanceMetrics(
- EthosUConfiguration("ethos-u55-256"),
+ EthosUConfiguration.load_profile("ethos-u55-256"),
NPUCycles(1, 2, 3, 4, 5, 6),
# memory metrics are in kilobytes
MemoryUsage(
@@ -127,7 +127,7 @@ def test_perf_metrics_diff() -> None:
[
OptimizationPerformanceMetrics(
PerformanceMetrics(
- EthosUConfiguration("ethos-u55-256"),
+ EthosUConfiguration.load_profile("ethos-u55-256"),
NPUCycles(1, 2, 3, 4, 5, 6),
# memory metrics are in kilobytes
MemoryUsage(*[i * 1024 for i in range(1, 6)]), # type: ignore
diff --git a/tests/test_target_ethos_u_data_collection.py b/tests/test_target_ethos_u_data_collection.py
index 2cf7482..829d2a7 100644
--- a/tests/test_target_ethos_u_data_collection.py
+++ b/tests/test_target_ethos_u_data_collection.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for the data collection module for Ethos-U."""
from pathlib import Path
@@ -50,7 +50,7 @@ def test_operator_compatibility_collector(
sample_context: Context, test_tflite_model: Path
) -> None:
"""Test operator compatibility data collector."""
- device = EthosUConfiguration("ethos-u55-256")
+ device = EthosUConfiguration.load_profile("ethos-u55-256")
collector = EthosUOperatorCompatibility(test_tflite_model, device)
collector.set_context(sample_context)
@@ -63,7 +63,7 @@ def test_performance_collector(
monkeypatch: pytest.MonkeyPatch, sample_context: Context, test_tflite_model: Path
) -> None:
"""Test performance data collector."""
- device = EthosUConfiguration("ethos-u55-256")
+ device = EthosUConfiguration.load_profile("ethos-u55-256")
mock_performance_estimation(monkeypatch, device)
@@ -81,7 +81,7 @@ def test_optimization_performance_collector(
test_tflite_model: Path,
) -> None:
"""Test optimization performance data collector."""
- device = EthosUConfiguration("ethos-u55-256")
+ device = EthosUConfiguration.load_profile("ethos-u55-256")
mock_performance_estimation(monkeypatch, device)
collector = EthosUOptimizationPerformance(
diff --git a/tests/test_target_ethos_u_reporters.py b/tests/test_target_ethos_u_reporters.py
index bc764a0..0c5764e 100644
--- a/tests/test_target_ethos_u_reporters.py
+++ b/tests/test_target_ethos_u_reporters.py
@@ -118,7 +118,7 @@ def test_report_operators(
"device, expected_plain_text, expected_json_dict",
[
[
- EthosUConfiguration("ethos-u55-256"),
+ EthosUConfiguration.load_profile("ethos-u55-256"),
"""Device information:
Target ethos-u55
MAC 256
diff --git a/tests/test_target_tosa_reporters.py b/tests/test_target_tosa_reporters.py
index 59da270..43d2a56 100644
--- a/tests/test_target_tosa_reporters.py
+++ b/tests/test_target_tosa_reporters.py
@@ -18,7 +18,7 @@ from mlia.target.tosa.reporters import tosa_formatters
def test_tosa_report_device() -> None:
"""Test function report_device()."""
- report = report_device(TOSAConfiguration("tosa"))
+ report = report_device(TOSAConfiguration.load_profile("tosa"))
assert report.to_plain_text()
diff --git a/tests/test_utils_filesystem.py b/tests/test_utils_filesystem.py
index 954f9e3..d0a6e6f 100644
--- a/tests/test_utils_filesystem.py
+++ b/tests/test_utils_filesystem.py
@@ -9,10 +9,8 @@ import pytest
from mlia.utils.filesystem import all_files_exist
from mlia.utils.filesystem import all_paths_valid
from mlia.utils.filesystem import copy_all
-from mlia.utils.filesystem import get_builtin_supported_profile_names
from mlia.utils.filesystem import get_mlia_resources
from mlia.utils.filesystem import get_mlia_target_profiles_dir
-from mlia.utils.filesystem import get_profile
from mlia.utils.filesystem import get_vela_config
from mlia.utils.filesystem import sha256
from mlia.utils.filesystem import temp_directory
@@ -36,31 +34,6 @@ def test_get_mlia_target_profiles() -> None:
assert get_mlia_target_profiles_dir().is_dir()
-def test_get_builtin_supported_profile_names() -> None:
- """Test profile names getter."""
- assert get_builtin_supported_profile_names() == [
- "cortex-a",
- "ethos-u55-128",
- "ethos-u55-256",
- "ethos-u65-256",
- "ethos-u65-512",
- "tosa",
- ]
-
-
-def test_get_profile() -> None:
- """Test getting profile data."""
- assert get_profile("ethos-u55-256") == {
- "target": "ethos-u55",
- "mac": 256,
- "memory_mode": "Shared_Sram",
- "system_config": "Ethos_U55_High_End_Embedded",
- }
-
- with pytest.raises(Exception, match=r"File not found:*"):
- get_profile("unknown")
-
-
@pytest.mark.parametrize("raise_exception", [True, False])
def test_temp_file(raise_exception: bool) -> None:
"""Test temp_file context manager."""
diff --git a/tests_e2e/test_e2e.py b/tests_e2e/test_e2e.py
index 546bd7e..d09d0ab 100644
--- a/tests_e2e/test_e2e.py
+++ b/tests_e2e/test_e2e.py
@@ -24,7 +24,7 @@ from mlia.cli.config import get_available_backends
from mlia.cli.main import get_commands
from mlia.cli.main import get_possible_command_names
from mlia.cli.main import init_parser
-from mlia.utils.filesystem import get_builtin_supported_profile_names
+from mlia.target.config import get_builtin_supported_profile_names
from mlia.utils.types import is_list_of