aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorBenjamin Klimczak <benjamin.klimczak@arm.com>2024-03-21 17:33:17 +0000
committerBenjamin Klimczak <benjamin.klimczak@arm.com>2024-03-22 10:06:28 +0000
commitc7ee5b783f044d7ff641773aa385840f5ff944cc (patch)
tree297f308978b00282d8ebd3a1f71e1ae5e678a767 /src
parent508281df31dc3c18f2e007f4dd505160342a681a (diff)
downloadmlia-c7ee5b783f044d7ff641773aa385840f5ff944cc.tar.gz
refactor: Backend dependencies and more
- Add backend dependencies: One backend can now depend on another backend. - Re-factor 'DownloadArtifact': - Rename 'DownloadArtifact' to 'DownloadConfig' - Remove attributes 'name' and 'version' not relevant for downloads - Add helper properties: - 'filename' parses the URL to extract the file name from the end - 'headers' calls the function to generate a HTML header for the download - Add OutputLogger helper class - Re-factor handling of backend configurations in the target profiles. Change-Id: Ifda6cf12c375d0c1747d7e4130a0370d22c3d33a Signed-off-by: Benjamin Klimczak <benjamin.klimczak@arm.com>
Diffstat (limited to 'src')
-rw-r--r--src/mlia/backend/corstone/install.py9
-rw-r--r--src/mlia/backend/corstone/performance.py8
-rw-r--r--src/mlia/backend/install.py40
-rw-r--r--src/mlia/backend/manager.py56
-rw-r--r--src/mlia/target/config.py4
-rw-r--r--src/mlia/target/cortex_a/config.py10
-rw-r--r--src/mlia/utils/download.py62
-rw-r--r--src/mlia/utils/proc.py13
8 files changed, 137 insertions, 65 deletions
diff --git a/src/mlia/backend/corstone/install.py b/src/mlia/backend/corstone/install.py
index 5f11d5b..5c18334 100644
--- a/src/mlia/backend/corstone/install.py
+++ b/src/mlia/backend/corstone/install.py
@@ -19,7 +19,7 @@ from mlia.backend.install import CompoundPathChecker
from mlia.backend.install import Installation
from mlia.backend.install import PackagePathChecker
from mlia.backend.install import StaticPathChecker
-from mlia.utils.download import DownloadArtifact
+from mlia.utils.download import DownloadConfig
from mlia.utils.filesystem import working_directory
@@ -159,8 +159,6 @@ def get_corstone_installation(corstone_name: str) -> Installation:
archive = corstone_fvp.archive
sha256_hash = corstone_fvp.sha256_hash
url = ARM_ECOSYSTEM_FVP_URL + archive
- filename = archive.split("/")[-1]
- version = corstone_fvp.get_fvp_version()
expected_files_fvp = corstone_fvp.fvp_expected_files
expected_files_vht = corstone_fvp.get_vht_expected_files()
backend_subfolder = expected_files_fvp[0].split("FVP")[0]
@@ -169,11 +167,8 @@ def get_corstone_installation(corstone_name: str) -> Installation:
name=corstone_name,
description=corstone_name.capitalize() + " FVP",
fvp_dir_name=corstone_name.replace("-", "_"),
- download_artifact=DownloadArtifact(
- name=corstone_name.capitalize() + " FVP",
+ download_config=DownloadConfig(
url=url,
- filename=filename,
- version=version,
sha256_hash=sha256_hash,
),
supported_platforms=["Linux"],
diff --git a/src/mlia/backend/corstone/performance.py b/src/mlia/backend/corstone/performance.py
index fc50109..fe4e271 100644
--- a/src/mlia/backend/corstone/performance.py
+++ b/src/mlia/backend/corstone/performance.py
@@ -15,6 +15,7 @@ from mlia.backend.errors import BackendExecutionFailed
from mlia.backend.repo import get_backend_repository
from mlia.utils.filesystem import get_mlia_resources
from mlia.utils.proc import Command
+from mlia.utils.proc import OutputLogger
from mlia.utils.proc import process_command_output
@@ -187,15 +188,12 @@ def get_metrics(
) from err
output_parser = GenericInferenceOutputParser()
-
- def redirect_to_log(line: str) -> None:
- """Redirect FVP output to the logger."""
- logger.debug(line.strip())
+ output_logger = OutputLogger(logger)
try:
process_command_output(
command,
- [output_parser, redirect_to_log],
+ [output_parser, output_logger],
)
except subprocess.CalledProcessError as err:
raise BackendExecutionFailed("Backend execution failed.") from err
diff --git a/src/mlia/backend/install.py b/src/mlia/backend/install.py
index 0ced9f6..1a7d58b 100644
--- a/src/mlia/backend/install.py
+++ b/src/mlia/backend/install.py
@@ -16,7 +16,8 @@ from typing import Optional
from typing import Union
from mlia.backend.repo import get_backend_repository
-from mlia.utils.download import DownloadArtifact
+from mlia.utils.download import download
+from mlia.utils.download import DownloadConfig
from mlia.utils.filesystem import all_files_exist
from mlia.utils.filesystem import temp_directory
from mlia.utils.filesystem import working_directory
@@ -45,10 +46,18 @@ InstallationType = Union[InstallFromPath, DownloadAndInstall]
class Installation(ABC):
"""Base class for the installation process of the backends."""
- def __init__(self, name: str, description: str) -> None:
+ def __init__(
+ self, name: str, description: str, dependencies: list[str] | None = None
+ ) -> None:
"""Init the installation."""
+ assert not dependencies or name not in dependencies, (
+ f"Invalid backend configuration: Backend '{name}' has itself as a "
+ "dependency! The backend source code needs to be fixed."
+ )
+
self.name = name
self.description = description
+ self.dependencies = dependencies if dependencies else []
@property
@abstractmethod
@@ -89,21 +98,22 @@ BackendInstaller = Callable[[bool, Path], Path]
class BackendInstallation(Installation):
"""Backend installation."""
- def __init__(
+ def __init__( # pylint: disable=too-many-arguments
self,
name: str,
description: str,
fvp_dir_name: str,
- download_artifact: DownloadArtifact | None,
+ download_config: DownloadConfig | None,
supported_platforms: list[str] | None,
path_checker: PathChecker,
backend_installer: BackendInstaller | None,
+ dependencies: list[str] | None = None,
) -> None:
"""Init the backend installation."""
- super().__init__(name, description)
+ super().__init__(name, description, dependencies)
self.fvp_dir_name = fvp_dir_name
- self.download_artifact = download_artifact
+ self.download_config = download_config
self.supported_platforms = supported_platforms
self.path_checker = path_checker
self.backend_installer = backend_installer
@@ -125,7 +135,7 @@ class BackendInstallation(Installation):
def supports(self, install_type: InstallationType) -> bool:
"""Return true if backends supported type of the installation."""
if isinstance(install_type, DownloadAndInstall):
- return self.download_artifact is not None
+ return self.download_config is not None
if isinstance(install_type, InstallFromPath):
return self.path_checker(install_type.backend_path) is not None
@@ -135,10 +145,10 @@ class BackendInstallation(Installation):
def install(self, install_type: InstallationType) -> None:
"""Install the backend."""
if isinstance(install_type, DownloadAndInstall):
- assert self.download_artifact is not None, "No artifact provided"
+ assert self.download_config is not None, "No artifact provided"
self._download_and_install(
- self.download_artifact, install_type.eula_agreement
+ self.download_config, install_type.eula_agreement
)
elif isinstance(install_type, InstallFromPath):
backend_info = self.path_checker(install_type.backend_path)
@@ -207,13 +217,17 @@ class BackendInstallation(Installation):
ex,
)
- def _download_and_install(
- self, download_artifact: DownloadArtifact, eula_agrement: bool
- ) -> None:
+ def _download_and_install(self, cfg: DownloadConfig, eula_agrement: bool) -> None:
"""Download and install the backend."""
with temp_directory() as tmpdir:
try:
- dest = download_artifact.download_to(tmpdir)
+ dest = tmpdir / cfg.filename
+ download(
+ dest=dest,
+ cfg=cfg,
+ show_progress=True,
+ )
+
except Exception as err:
raise RuntimeError("Unable to download backend artifact.") from err
diff --git a/src/mlia/backend/manager.py b/src/mlia/backend/manager.py
index a4bc8c0..a752791 100644
--- a/src/mlia/backend/manager.py
+++ b/src/mlia/backend/manager.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2024, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Module for installation process."""
from __future__ import annotations
@@ -188,6 +188,33 @@ class DefaultInstallationManager(InstallationManager, InstallationFiltersMixin):
logger.info("%s installation canceled.", installation.name)
return
+ for dependency in installation.dependencies:
+ deps = self.find_by_name(dependency)
+ if not deps:
+ raise ValueError(
+ f"Backend {installation.name} depends on "
+ f"unknown backend '{dependency}'."
+ )
+ missing_deps = [dep for dep in deps if not dep.already_installed]
+ if missing_deps:
+ proceed = self.noninteractive or yes(
+ "The following dependencies are not installed: "
+ f"{[dep.name for dep in missing_deps]}. "
+ "Continue installation anyway?"
+ )
+ logger.warning(
+ "Installing backend %s with the following dependencies "
+ "missing: %s",
+ installation.name,
+ missing_deps,
+ )
+ if not proceed:
+ logger.info(
+ "%s installation canceled due to missing dependencies.",
+ installation.name,
+ )
+ return
+
if installation.already_installed and force:
logger.info(
"Force installing %s, so delete the existing "
@@ -254,16 +281,37 @@ class DefaultInstallationManager(InstallationManager, InstallationFiltersMixin):
installations = self.already_installed(backend_name)
if not installations:
- raise ConfigurationError(f"Backend '{backend_name}' is not installed")
+ raise ConfigurationError(f"Backend '{backend_name}' is not installed.")
if len(installations) != 1:
raise InternalError(
- f"More than one installed backend with name {backend_name} found"
+ f"More than one installed backend with name {backend_name} found."
)
installation = installations[0]
- installation.uninstall()
+ dependent_backends = [
+ dep.name
+ for dep in self.installations
+ if dep.name in installation.dependencies
+ ]
+ if dependent_backends:
+ msg = (
+ f"The following backends depend on '{installation.name}' which "
+ f"you are about to uninstall: {dependent_backends}",
+ )
+ proceed = self.noninteractive or yes(
+ f"{msg}. Continue uninstalling anyway?"
+ )
+ logger.warning(msg)
+ if not proceed:
+ logger.info(
+ "Uninstalling %s canceled due to dependencies.",
+ installation.name,
+ )
+ return
+
+ installation.uninstall()
logger.info("%s successfully uninstalled.", installation.name)
def backend_installed(self, backend_name: str) -> bool:
diff --git a/src/mlia/target/config.py b/src/mlia/target/config.py
index 8ccdad8..8492086 100644
--- a/src/mlia/target/config.py
+++ b/src/mlia/target/config.py
@@ -85,9 +85,11 @@ T = TypeVar("T", bound="TargetProfile")
class TargetProfile(ABC):
"""Base class for target profiles."""
- def __init__(self, target: str) -> None:
+ def __init__(self, target: str, backend_config: dict | None = None) -> None:
"""Init TargetProfile instance with the target name."""
self.target = target
+ # Load backend config(s) to be handled by the backend(s) later.
+ self.backend_config = {} if backend_config is None else backend_config
@classmethod
def load(cls: type[T], path: str | Path) -> T:
diff --git a/src/mlia/target/cortex_a/config.py b/src/mlia/target/cortex_a/config.py
index f91031e..4f33f3d 100644
--- a/src/mlia/target/cortex_a/config.py
+++ b/src/mlia/target/cortex_a/config.py
@@ -15,12 +15,12 @@ class CortexAConfiguration(TargetProfile):
def __init__(self, **kwargs: Any) -> None:
"""Init Cortex-A target configuration."""
target = kwargs["target"]
- super().__init__(target)
+ backend_config = kwargs.get("backend")
+ super().__init__(target, backend_config)
- self.backend_config = kwargs.get("backend")
- self.armnn_tflite_delegate_version = kwargs["backend"]["armnn-tflite-delegate"][
- "version"
- ]
+ self.armnn_tflite_delegate_version = self.backend_config[
+ "armnn-tflite-delegate"
+ ]["version"]
def verify(self) -> None:
"""Check the parameters."""
diff --git a/src/mlia/utils/download.py b/src/mlia/utils/download.py
index e00be28..2b06fed 100644
--- a/src/mlia/utils/download.py
+++ b/src/mlia/utils/download.py
@@ -1,10 +1,11 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Utils for files downloading."""
from __future__ import annotations
from dataclasses import dataclass
from pathlib import Path
+from typing import Callable
from typing import Iterable
import requests
@@ -40,20 +41,46 @@ def download_progress(
yield chunk
+@dataclass
+class DownloadConfig:
+ """Parameters to download an artifact."""
+
+ url: str
+ sha256_hash: str
+ header_gen_fn: Callable[[], dict[str, str]] | None = None
+
+ @property
+ def filename(self) -> str:
+ """Get the filename from the URL."""
+ return self.url.rsplit("/", 1)[-1]
+
+ @property
+ def headers(self) -> dict[str, str]:
+ """Get the headers using the header_gen_fn."""
+ return self.header_gen_fn() if self.header_gen_fn else {}
+
+
def download(
- url: str,
dest: Path,
+ cfg: DownloadConfig,
show_progress: bool = False,
label: str | None = None,
chunk_size: int = 8192,
timeout: int = 30,
) -> None:
"""Download the file."""
- with requests.get(url, stream=True, timeout=timeout) as resp:
+ if dest.exists():
+ raise FileExistsError(f"{dest} already exists.")
+
+ with requests.get(
+ cfg.url, stream=True, timeout=timeout, headers=cfg.headers
+ ) as resp:
resp.raise_for_status()
content_chunks = resp.iter_content(chunk_size=chunk_size)
if show_progress:
+ if not label:
+ label = f"Downloading to {dest}."
content_length = parse_int(resp.headers.get("Content-Length"))
content_chunks = download_progress(content_chunks, content_length, label)
@@ -61,30 +88,5 @@ def download(
for chunk in content_chunks:
file.write(chunk)
-
-@dataclass
-class DownloadArtifact:
- """Download artifact attributes."""
-
- name: str
- url: str
- filename: str
- version: str
- sha256_hash: str
-
- def download_to(self, dest_dir: Path, show_progress: bool = True) -> Path:
- """Download artifact into destination directory."""
- if (dest := dest_dir / self.filename).exists():
- raise ValueError(f"{dest} already exists")
-
- download(
- self.url,
- dest,
- show_progress=show_progress,
- label=f"Downloading {self.name} ver. {self.version}",
- )
-
- if sha256(dest) != self.sha256_hash:
- raise ValueError("Digests do not match")
-
- return dest
+ if cfg.sha256_hash and sha256(dest) != cfg.sha256_hash:
+ raise ValueError("Hashes do not match.")
diff --git a/src/mlia/utils/proc.py b/src/mlia/utils/proc.py
index d11bfc5..236854e 100644
--- a/src/mlia/utils/proc.py
+++ b/src/mlia/utils/proc.py
@@ -6,6 +6,7 @@ from __future__ import annotations
import logging
import subprocess # nosec
from dataclasses import dataclass
+from functools import partial
from pathlib import Path
from typing import Callable
from typing import Generator
@@ -45,6 +46,18 @@ def command_output(command: Command) -> Generator[str, None, None]:
OutputConsumer = Callable[[str], None]
+class OutputLogger:
+ """Log process output to the given logger with the given level."""
+
+ def __init__(self, logger_: logging.Logger, level: int = logging.DEBUG) -> None:
+ """Create log function with the appropriate log level set."""
+ self.log_fn = partial(logger_.log, level)
+
+ def __call__(self, line: str) -> None:
+ """Redirect output to the logger."""
+ self.log_fn(line)
+
+
def process_command_output(
command: Command,
consumers: list[OutputConsumer],