aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBenjamin Klimczak <benjamin.klimczak@arm.com>2024-03-21 17:33:17 +0000
committerBenjamin Klimczak <benjamin.klimczak@arm.com>2024-03-22 10:06:28 +0000
commitc7ee5b783f044d7ff641773aa385840f5ff944cc (patch)
tree297f308978b00282d8ebd3a1f71e1ae5e678a767
parent508281df31dc3c18f2e007f4dd505160342a681a (diff)
downloadmlia-c7ee5b783f044d7ff641773aa385840f5ff944cc.tar.gz
refactor: Backend dependencies and more
- Add backend dependencies: One backend can now depend on another backend. - Re-factor 'DownloadArtifact': - Rename 'DownloadArtifact' to 'DownloadConfig' - Remove attributes 'name' and 'version' not relevant for downloads - Add helper properties: - 'filename' parses the URL to extract the file name from the end - 'headers' calls the function to generate a HTML header for the download - Add OutputLogger helper class - Re-factor handling of backend configurations in the target profiles. Change-Id: Ifda6cf12c375d0c1747d7e4130a0370d22c3d33a Signed-off-by: Benjamin Klimczak <benjamin.klimczak@arm.com>
-rw-r--r--src/mlia/backend/corstone/install.py9
-rw-r--r--src/mlia/backend/corstone/performance.py8
-rw-r--r--src/mlia/backend/install.py40
-rw-r--r--src/mlia/backend/manager.py56
-rw-r--r--src/mlia/target/config.py4
-rw-r--r--src/mlia/target/cortex_a/config.py10
-rw-r--r--src/mlia/utils/download.py62
-rw-r--r--src/mlia/utils/proc.py13
-rw-r--r--tests/test_backend_install.py13
-rw-r--r--tests/test_backend_manager.py149
-rw-r--r--tests/test_utils_download.py35
11 files changed, 286 insertions, 113 deletions
diff --git a/src/mlia/backend/corstone/install.py b/src/mlia/backend/corstone/install.py
index 5f11d5b..5c18334 100644
--- a/src/mlia/backend/corstone/install.py
+++ b/src/mlia/backend/corstone/install.py
@@ -19,7 +19,7 @@ from mlia.backend.install import CompoundPathChecker
from mlia.backend.install import Installation
from mlia.backend.install import PackagePathChecker
from mlia.backend.install import StaticPathChecker
-from mlia.utils.download import DownloadArtifact
+from mlia.utils.download import DownloadConfig
from mlia.utils.filesystem import working_directory
@@ -159,8 +159,6 @@ def get_corstone_installation(corstone_name: str) -> Installation:
archive = corstone_fvp.archive
sha256_hash = corstone_fvp.sha256_hash
url = ARM_ECOSYSTEM_FVP_URL + archive
- filename = archive.split("/")[-1]
- version = corstone_fvp.get_fvp_version()
expected_files_fvp = corstone_fvp.fvp_expected_files
expected_files_vht = corstone_fvp.get_vht_expected_files()
backend_subfolder = expected_files_fvp[0].split("FVP")[0]
@@ -169,11 +167,8 @@ def get_corstone_installation(corstone_name: str) -> Installation:
name=corstone_name,
description=corstone_name.capitalize() + " FVP",
fvp_dir_name=corstone_name.replace("-", "_"),
- download_artifact=DownloadArtifact(
- name=corstone_name.capitalize() + " FVP",
+ download_config=DownloadConfig(
url=url,
- filename=filename,
- version=version,
sha256_hash=sha256_hash,
),
supported_platforms=["Linux"],
diff --git a/src/mlia/backend/corstone/performance.py b/src/mlia/backend/corstone/performance.py
index fc50109..fe4e271 100644
--- a/src/mlia/backend/corstone/performance.py
+++ b/src/mlia/backend/corstone/performance.py
@@ -15,6 +15,7 @@ from mlia.backend.errors import BackendExecutionFailed
from mlia.backend.repo import get_backend_repository
from mlia.utils.filesystem import get_mlia_resources
from mlia.utils.proc import Command
+from mlia.utils.proc import OutputLogger
from mlia.utils.proc import process_command_output
@@ -187,15 +188,12 @@ def get_metrics(
) from err
output_parser = GenericInferenceOutputParser()
-
- def redirect_to_log(line: str) -> None:
- """Redirect FVP output to the logger."""
- logger.debug(line.strip())
+ output_logger = OutputLogger(logger)
try:
process_command_output(
command,
- [output_parser, redirect_to_log],
+ [output_parser, output_logger],
)
except subprocess.CalledProcessError as err:
raise BackendExecutionFailed("Backend execution failed.") from err
diff --git a/src/mlia/backend/install.py b/src/mlia/backend/install.py
index 0ced9f6..1a7d58b 100644
--- a/src/mlia/backend/install.py
+++ b/src/mlia/backend/install.py
@@ -16,7 +16,8 @@ from typing import Optional
from typing import Union
from mlia.backend.repo import get_backend_repository
-from mlia.utils.download import DownloadArtifact
+from mlia.utils.download import download
+from mlia.utils.download import DownloadConfig
from mlia.utils.filesystem import all_files_exist
from mlia.utils.filesystem import temp_directory
from mlia.utils.filesystem import working_directory
@@ -45,10 +46,18 @@ InstallationType = Union[InstallFromPath, DownloadAndInstall]
class Installation(ABC):
"""Base class for the installation process of the backends."""
- def __init__(self, name: str, description: str) -> None:
+ def __init__(
+ self, name: str, description: str, dependencies: list[str] | None = None
+ ) -> None:
"""Init the installation."""
+ assert not dependencies or name not in dependencies, (
+ f"Invalid backend configuration: Backend '{name}' has itself as a "
+ "dependency! The backend source code needs to be fixed."
+ )
+
self.name = name
self.description = description
+ self.dependencies = dependencies if dependencies else []
@property
@abstractmethod
@@ -89,21 +98,22 @@ BackendInstaller = Callable[[bool, Path], Path]
class BackendInstallation(Installation):
"""Backend installation."""
- def __init__(
+ def __init__( # pylint: disable=too-many-arguments
self,
name: str,
description: str,
fvp_dir_name: str,
- download_artifact: DownloadArtifact | None,
+ download_config: DownloadConfig | None,
supported_platforms: list[str] | None,
path_checker: PathChecker,
backend_installer: BackendInstaller | None,
+ dependencies: list[str] | None = None,
) -> None:
"""Init the backend installation."""
- super().__init__(name, description)
+ super().__init__(name, description, dependencies)
self.fvp_dir_name = fvp_dir_name
- self.download_artifact = download_artifact
+ self.download_config = download_config
self.supported_platforms = supported_platforms
self.path_checker = path_checker
self.backend_installer = backend_installer
@@ -125,7 +135,7 @@ class BackendInstallation(Installation):
def supports(self, install_type: InstallationType) -> bool:
"""Return true if backends supported type of the installation."""
if isinstance(install_type, DownloadAndInstall):
- return self.download_artifact is not None
+ return self.download_config is not None
if isinstance(install_type, InstallFromPath):
return self.path_checker(install_type.backend_path) is not None
@@ -135,10 +145,10 @@ class BackendInstallation(Installation):
def install(self, install_type: InstallationType) -> None:
"""Install the backend."""
if isinstance(install_type, DownloadAndInstall):
- assert self.download_artifact is not None, "No artifact provided"
+ assert self.download_config is not None, "No artifact provided"
self._download_and_install(
- self.download_artifact, install_type.eula_agreement
+ self.download_config, install_type.eula_agreement
)
elif isinstance(install_type, InstallFromPath):
backend_info = self.path_checker(install_type.backend_path)
@@ -207,13 +217,17 @@ class BackendInstallation(Installation):
ex,
)
- def _download_and_install(
- self, download_artifact: DownloadArtifact, eula_agrement: bool
- ) -> None:
+ def _download_and_install(self, cfg: DownloadConfig, eula_agrement: bool) -> None:
"""Download and install the backend."""
with temp_directory() as tmpdir:
try:
- dest = download_artifact.download_to(tmpdir)
+ dest = tmpdir / cfg.filename
+ download(
+ dest=dest,
+ cfg=cfg,
+ show_progress=True,
+ )
+
except Exception as err:
raise RuntimeError("Unable to download backend artifact.") from err
diff --git a/src/mlia/backend/manager.py b/src/mlia/backend/manager.py
index a4bc8c0..a752791 100644
--- a/src/mlia/backend/manager.py
+++ b/src/mlia/backend/manager.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2024, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Module for installation process."""
from __future__ import annotations
@@ -188,6 +188,33 @@ class DefaultInstallationManager(InstallationManager, InstallationFiltersMixin):
logger.info("%s installation canceled.", installation.name)
return
+ for dependency in installation.dependencies:
+ deps = self.find_by_name(dependency)
+ if not deps:
+ raise ValueError(
+ f"Backend {installation.name} depends on "
+ f"unknown backend '{dependency}'."
+ )
+ missing_deps = [dep for dep in deps if not dep.already_installed]
+ if missing_deps:
+ proceed = self.noninteractive or yes(
+ "The following dependencies are not installed: "
+ f"{[dep.name for dep in missing_deps]}. "
+ "Continue installation anyway?"
+ )
+ logger.warning(
+ "Installing backend %s with the following dependencies "
+ "missing: %s",
+ installation.name,
+ missing_deps,
+ )
+ if not proceed:
+ logger.info(
+ "%s installation canceled due to missing dependencies.",
+ installation.name,
+ )
+ return
+
if installation.already_installed and force:
logger.info(
"Force installing %s, so delete the existing "
@@ -254,16 +281,37 @@ class DefaultInstallationManager(InstallationManager, InstallationFiltersMixin):
installations = self.already_installed(backend_name)
if not installations:
- raise ConfigurationError(f"Backend '{backend_name}' is not installed")
+ raise ConfigurationError(f"Backend '{backend_name}' is not installed.")
if len(installations) != 1:
raise InternalError(
- f"More than one installed backend with name {backend_name} found"
+ f"More than one installed backend with name {backend_name} found."
)
installation = installations[0]
- installation.uninstall()
+ dependent_backends = [
+ dep.name
+ for dep in self.installations
+ if dep.name in installation.dependencies
+ ]
+ if dependent_backends:
+ msg = (
+ f"The following backends depend on '{installation.name}' which "
+ f"you are about to uninstall: {dependent_backends}",
+ )
+ proceed = self.noninteractive or yes(
+ f"{msg}. Continue uninstalling anyway?"
+ )
+ logger.warning(msg)
+ if not proceed:
+ logger.info(
+ "Uninstalling %s canceled due to dependencies.",
+ installation.name,
+ )
+ return
+
+ installation.uninstall()
logger.info("%s successfully uninstalled.", installation.name)
def backend_installed(self, backend_name: str) -> bool:
diff --git a/src/mlia/target/config.py b/src/mlia/target/config.py
index 8ccdad8..8492086 100644
--- a/src/mlia/target/config.py
+++ b/src/mlia/target/config.py
@@ -85,9 +85,11 @@ T = TypeVar("T", bound="TargetProfile")
class TargetProfile(ABC):
"""Base class for target profiles."""
- def __init__(self, target: str) -> None:
+ def __init__(self, target: str, backend_config: dict | None = None) -> None:
"""Init TargetProfile instance with the target name."""
self.target = target
+ # Load backend config(s) to be handled by the backend(s) later.
+ self.backend_config = {} if backend_config is None else backend_config
@classmethod
def load(cls: type[T], path: str | Path) -> T:
diff --git a/src/mlia/target/cortex_a/config.py b/src/mlia/target/cortex_a/config.py
index f91031e..4f33f3d 100644
--- a/src/mlia/target/cortex_a/config.py
+++ b/src/mlia/target/cortex_a/config.py
@@ -15,12 +15,12 @@ class CortexAConfiguration(TargetProfile):
def __init__(self, **kwargs: Any) -> None:
"""Init Cortex-A target configuration."""
target = kwargs["target"]
- super().__init__(target)
+ backend_config = kwargs.get("backend")
+ super().__init__(target, backend_config)
- self.backend_config = kwargs.get("backend")
- self.armnn_tflite_delegate_version = kwargs["backend"]["armnn-tflite-delegate"][
- "version"
- ]
+ self.armnn_tflite_delegate_version = self.backend_config[
+ "armnn-tflite-delegate"
+ ]["version"]
def verify(self) -> None:
"""Check the parameters."""
diff --git a/src/mlia/utils/download.py b/src/mlia/utils/download.py
index e00be28..2b06fed 100644
--- a/src/mlia/utils/download.py
+++ b/src/mlia/utils/download.py
@@ -1,10 +1,11 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Utils for files downloading."""
from __future__ import annotations
from dataclasses import dataclass
from pathlib import Path
+from typing import Callable
from typing import Iterable
import requests
@@ -40,20 +41,46 @@ def download_progress(
yield chunk
+@dataclass
+class DownloadConfig:
+ """Parameters to download an artifact."""
+
+ url: str
+ sha256_hash: str
+ header_gen_fn: Callable[[], dict[str, str]] | None = None
+
+ @property
+ def filename(self) -> str:
+ """Get the filename from the URL."""
+ return self.url.rsplit("/", 1)[-1]
+
+ @property
+ def headers(self) -> dict[str, str]:
+ """Get the headers using the header_gen_fn."""
+ return self.header_gen_fn() if self.header_gen_fn else {}
+
+
def download(
- url: str,
dest: Path,
+ cfg: DownloadConfig,
show_progress: bool = False,
label: str | None = None,
chunk_size: int = 8192,
timeout: int = 30,
) -> None:
"""Download the file."""
- with requests.get(url, stream=True, timeout=timeout) as resp:
+ if dest.exists():
+ raise FileExistsError(f"{dest} already exists.")
+
+ with requests.get(
+ cfg.url, stream=True, timeout=timeout, headers=cfg.headers
+ ) as resp:
resp.raise_for_status()
content_chunks = resp.iter_content(chunk_size=chunk_size)
if show_progress:
+ if not label:
+ label = f"Downloading to {dest}."
content_length = parse_int(resp.headers.get("Content-Length"))
content_chunks = download_progress(content_chunks, content_length, label)
@@ -61,30 +88,5 @@ def download(
for chunk in content_chunks:
file.write(chunk)
-
-@dataclass
-class DownloadArtifact:
- """Download artifact attributes."""
-
- name: str
- url: str
- filename: str
- version: str
- sha256_hash: str
-
- def download_to(self, dest_dir: Path, show_progress: bool = True) -> Path:
- """Download artifact into destination directory."""
- if (dest := dest_dir / self.filename).exists():
- raise ValueError(f"{dest} already exists")
-
- download(
- self.url,
- dest,
- show_progress=show_progress,
- label=f"Downloading {self.name} ver. {self.version}",
- )
-
- if sha256(dest) != self.sha256_hash:
- raise ValueError("Digests do not match")
-
- return dest
+ if cfg.sha256_hash and sha256(dest) != cfg.sha256_hash:
+ raise ValueError("Hashes do not match.")
diff --git a/src/mlia/utils/proc.py b/src/mlia/utils/proc.py
index d11bfc5..236854e 100644
--- a/src/mlia/utils/proc.py
+++ b/src/mlia/utils/proc.py
@@ -6,6 +6,7 @@ from __future__ import annotations
import logging
import subprocess # nosec
from dataclasses import dataclass
+from functools import partial
from pathlib import Path
from typing import Callable
from typing import Generator
@@ -45,6 +46,18 @@ def command_output(command: Command) -> Generator[str, None, None]:
OutputConsumer = Callable[[str], None]
+class OutputLogger:
+ """Log process output to the given logger with the given level."""
+
+ def __init__(self, logger_: logging.Logger, level: int = logging.DEBUG) -> None:
+ """Create log function with the appropriate log level set."""
+ self.log_fn = partial(logger_.log, level)
+
+ def __call__(self, line: str) -> None:
+ """Redirect output to the logger."""
+ self.log_fn(line)
+
+
def process_command_output(
command: Command,
consumers: list[OutputConsumer],
diff --git a/tests/test_backend_install.py b/tests/test_backend_install.py
index 963766e..3636fb4 100644
--- a/tests/test_backend_install.py
+++ b/tests/test_backend_install.py
@@ -20,6 +20,7 @@ from mlia.backend.install import InstallFromPath
from mlia.backend.install import PackagePathChecker
from mlia.backend.install import StaticPathChecker
from mlia.backend.repo import BackendRepository
+from mlia.utils.download import DownloadConfig
@pytest.fixture(name="backend_repo")
@@ -104,11 +105,9 @@ def test_backend_installation_from_path(
def test_backend_installation_download_and_install(
- tmp_path: Path, backend_repo: MagicMock
+ tmp_path: Path, backend_repo: MagicMock, monkeypatch: pytest.MonkeyPatch
) -> None:
"""Test methods of backend installation."""
- download_artifact_mock = MagicMock()
-
tmp_archive = tmp_path.joinpath("sample.tgz")
sample_file = tmp_path.joinpath("sample.txt")
sample_file.touch()
@@ -116,13 +115,17 @@ def test_backend_installation_download_and_install(
with tarfile.open(tmp_archive, "w:gz") as archive:
archive.add(sample_file)
- download_artifact_mock.download_to.return_value = tmp_archive
+ monkeypatch.setattr("mlia.backend.install.download", MagicMock())
+ monkeypatch.setattr(
+ "mlia.utils.download.DownloadConfig.filename",
+ tmp_archive,
+ )
installation = BackendInstallation(
"sample_backend",
"Sample backend",
"sample_backend",
- download_artifact_mock,
+ DownloadConfig(url="NOT_USED", sha256_hash="NOT_USED"),
None,
lambda path: BackendInfo(path, copy_source=False),
lambda eula_agreement, path: path,
diff --git a/tests/test_backend_manager.py b/tests/test_backend_manager.py
index 879353e..63c11ee 100644
--- a/tests/test_backend_manager.py
+++ b/tests/test_backend_manager.py
@@ -3,6 +3,7 @@
"""Tests for installation manager."""
from __future__ import annotations
+from functools import partial
from pathlib import Path
from typing import Any
from unittest.mock import call
@@ -23,6 +24,7 @@ from mlia.core.errors import InternalError
def get_default_installation_manager_mock(
name: str,
already_installed: bool = False,
+ dependencies: list[str] | None = None,
) -> MagicMock:
"""Get mock instance for DefaultInstallationManager."""
mock = MagicMock(spec=DefaultInstallationManager)
@@ -30,6 +32,7 @@ def get_default_installation_manager_mock(
props = {
"name": name,
"already_installed": already_installed,
+ "dependencies": dependencies if dependencies else [],
}
for prop, value in props.items():
setattr(type(mock), prop, PropertyMock(return_value=value))
@@ -49,6 +52,7 @@ def get_installation_mock(
already_installed: bool = False,
could_be_installed: bool = False,
supported_install_type: type | tuple | None = None,
+ dependencies: list[str] | None = None,
) -> MagicMock:
"""Get mock instance for the installation."""
mock = MagicMock(spec=Installation)
@@ -65,6 +69,7 @@ def get_installation_mock(
"name": name,
"already_installed": already_installed,
"could_be_installed": could_be_installed,
+ "dependencies": dependencies if dependencies else [],
}
for prop, value in props.items():
setattr(type(mock), prop, PropertyMock(return_value=value))
@@ -72,38 +77,45 @@ def get_installation_mock(
return mock
-def _already_installed_mock() -> MagicMock:
- return get_installation_mock(
- name="already_installed",
- already_installed=True,
- supported_install_type=(DownloadAndInstall, InstallFromPath),
- )
+_already_installed_mock = partial(
+ get_installation_mock,
+ name="already_installed",
+ already_installed=True,
+ supported_install_type=(DownloadAndInstall, InstallFromPath),
+)
-def _ready_for_installation_mock() -> MagicMock:
- return get_installation_mock(
- name="ready_for_installation",
- already_installed=False,
- could_be_installed=True,
- )
+_ready_for_installation_mock = partial(
+ get_installation_mock,
+ name="ready_for_installation",
+ already_installed=False,
+ could_be_installed=True,
+)
-def _could_be_downloaded_and_installed_mock() -> MagicMock:
- return get_installation_mock(
- name="could_be_downloaded_and_installed",
- already_installed=False,
- could_be_installed=True,
- supported_install_type=DownloadAndInstall,
- )
+_could_be_downloaded_and_installed_mock = partial(
+ get_installation_mock,
+ name="could_be_downloaded_and_installed",
+ already_installed=False,
+ could_be_installed=True,
+ supported_install_type=DownloadAndInstall,
+)
-def _could_be_installed_from_mock() -> MagicMock:
- return get_installation_mock(
- name="could_be_installed_from",
- already_installed=False,
- could_be_installed=True,
- supported_install_type=InstallFromPath,
- )
+_could_be_installed_from_mock = partial(
+ get_installation_mock,
+ name="could_be_installed_from",
+ already_installed=False,
+ could_be_installed=True,
+ supported_install_type=InstallFromPath,
+)
+
+_already_installed_dep_mock = partial(
+ get_installation_mock,
+ name="already_installed_dep",
+ already_installed=True,
+ supported_install_type=(DownloadAndInstall, InstallFromPath),
+)
def get_installation_manager(
@@ -114,13 +126,23 @@ def get_installation_manager(
) -> DefaultInstallationManager:
"""Get installation manager instance."""
if not noninteractive:
- monkeypatch.setattr(
- "mlia.backend.manager.yes", MagicMock(return_value=yes_response)
+ return get_interactive_installation_manager(
+ installations, monkeypatch, MagicMock(return_value=yes_response)
)
return DefaultInstallationManager(installations, noninteractive=noninteractive)
+def get_interactive_installation_manager(
+ installations: list[Any],
+ monkeypatch: pytest.MonkeyPatch,
+ mock_interaction: MagicMock,
+) -> DefaultInstallationManager:
+ """Get and interactive installation manager instance using the given mock."""
+ monkeypatch.setattr("mlia.backend.manager.yes", mock_interaction)
+ return DefaultInstallationManager(installations, noninteractive=False)
+
+
def test_installation_manager_filtering() -> None:
"""Test default installation manager."""
already_installed = _already_installed_mock()
@@ -337,3 +359,74 @@ def test_show_env_details(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch,
)
manager.show_env_details()
+
+
+@pytest.mark.parametrize(
+ "dependency",
+ (
+ _ready_for_installation_mock(),
+ _already_installed_mock(),
+ ),
+)
+@pytest.mark.parametrize("yes_response", (True, False))
+def test_could_be_installed_with_dep(
+ dependency: MagicMock,
+ yes_response: bool,
+ tmp_path: Path,
+ monkeypatch: pytest.MonkeyPatch,
+) -> None:
+ """Test installation with a dependency."""
+ install_mock = _could_be_installed_from_mock(dependencies=[dependency.name])
+
+ yes_mock = MagicMock(return_value=yes_response)
+ manager = get_interactive_installation_manager(
+ [install_mock, dependency], monkeypatch, yes_mock
+ )
+ manager.install_from(tmp_path, install_mock.name)
+
+ if yes_response:
+ install_mock.install.assert_called_once()
+ else:
+ install_mock.install.assert_not_called()
+ install_mock.uninstall.assert_not_called()
+
+ dependency.install.assert_not_called()
+ dependency.uninstall.assert_not_called()
+
+
+def test_install_with_unknown_dep(
+ tmp_path: Path,
+ monkeypatch: pytest.MonkeyPatch,
+) -> None:
+ """Test installation with an unknown dependency."""
+ install_mock = _could_be_installed_from_mock(dependencies=["UNKNOWN_BACKEND"])
+
+ manager = get_installation_manager(False, [install_mock], monkeypatch)
+ with pytest.raises(ValueError):
+ manager.install_from(tmp_path, install_mock.name)
+
+ install_mock.install.assert_not_called()
+ install_mock.uninstall.assert_not_called()
+
+
+@pytest.mark.parametrize("yes_response", (True, False))
+def test_uninstall_with_dep(
+ yes_response: bool, monkeypatch: pytest.MonkeyPatch
+) -> None:
+ """Test uninstalling a backend with a dependency."""
+ dependency = _already_installed_dep_mock()
+ install_mock = _already_installed_mock(dependencies=[dependency.name])
+ yes_mock = MagicMock(return_value=yes_response)
+ manager = get_interactive_installation_manager(
+ [install_mock, dependency], monkeypatch, yes_mock
+ )
+ manager.uninstall(install_mock.name)
+
+ install_mock.install.assert_not_called()
+ if yes_response:
+ install_mock.uninstall.assert_called_once()
+ else:
+ install_mock.uninstall.assert_not_called()
+
+ dependency.install.assert_not_called()
+ dependency.uninstall.assert_not_called()
diff --git a/tests/test_utils_download.py b/tests/test_utils_download.py
index 28af74f..7188c62 100644
--- a/tests/test_utils_download.py
+++ b/tests/test_utils_download.py
@@ -1,8 +1,9 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for download functionality."""
from __future__ import annotations
+import hashlib
from contextlib import ExitStack as does_not_raise
from pathlib import Path
from typing import Any
@@ -14,7 +15,7 @@ import pytest
import requests
from mlia.utils.download import download
-from mlia.utils.download import DownloadArtifact
+from mlia.utils.download import DownloadConfig
def response_mock(
@@ -69,9 +70,18 @@ def test_download(
"mlia.utils.download.requests.get",
MagicMock(return_value=response_mock(content_length, content_chunks)),
)
+ hash_obj = hashlib.sha256()
+ for chunk in content_chunks:
+ hash_obj.update(chunk)
+ sha256_hash = hash_obj.hexdigest()
dest = tmp_path / "sample.bin"
- download("some_url", dest, show_progress=show_progress, label=label)
+ download(
+ dest,
+ DownloadConfig("some_url", sha256_hash=sha256_hash),
+ show_progress=show_progress,
+ label=label,
+ )
assert dest.is_file()
assert dest.read_bytes() == bytes(
@@ -92,7 +102,7 @@ def test_download(
"10",
[bytes(range(10))],
"bad_hash",
- pytest.raises(ValueError, match="Digests do not match"),
+ pytest.raises(ValueError, match="Hashes do not match."),
],
],
)
@@ -111,15 +121,13 @@ def test_download_artifact_download_to(
)
with expected_error:
- artifact = DownloadArtifact(
- "test_artifact",
+ cfg = DownloadConfig(
"some_url",
- "artifact_filename",
- "1.0",
sha256_hash,
)
- dest = artifact.download_to(tmp_path)
+ dest = tmp_path / "artifact_filename"
+ download(dest, cfg)
assert isinstance(dest, Path)
assert dest.name == "artifact_filename"
@@ -133,16 +141,13 @@ def test_download_artifact_unable_to_overwrite(
MagicMock(return_value=response_mock("10", [bytes(range(10))])),
)
- artifact = DownloadArtifact(
- "test_artifact",
+ cfg = DownloadConfig(
"some_url",
- "artifact_filename",
- "1.0",
"sha256_hash",
)
existing_file = tmp_path / "artifact_filename"
existing_file.touch()
- with pytest.raises(ValueError, match=f"{existing_file} already exists"):
- artifact.download_to(tmp_path)
+ with pytest.raises(FileExistsError, match=f"{existing_file} already exists."):
+ download(existing_file, cfg)