From 302ce432829ae7c25e100a5cca718f0aadbe4fd4 Mon Sep 17 00:00:00 2001 From: Dmitrii Agibov Date: Tue, 15 Nov 2022 13:19:53 +0000 Subject: MLIA-649 Support tosa-checker as a backend - Add new type of the backend based on python packages - Add installation class for TOSA checker - Update documentation - Extend support of the parameter "force" in the "install" command Change-Id: I95567b75e1cfe85daa1f1c3d359975bb67b2504e --- README.md | 21 ++-- src/mlia/cli/commands.py | 23 ++--- src/mlia/cli/config.py | 3 +- src/mlia/cli/main.py | 6 ++ src/mlia/cli/options.py | 2 +- src/mlia/core/errors.py | 4 + src/mlia/devices/tosa/operators.py | 3 +- src/mlia/tools/metadata/common.py | 173 ++++++++++++++++---------------- src/mlia/tools/metadata/corstone.py | 1 - src/mlia/tools/metadata/py_package.py | 84 ++++++++++++++++ src/mlia/utils/py_manager.py | 62 ++++++++++++ tests/test_cli_commands.py | 10 +- tests/test_mlia_utils_py_manager.py | 73 ++++++++++++++ tests/test_tools_metadata_common.py | 59 ++++++++--- tests/test_tools_metadata_py_package.py | 62 ++++++++++++ 15 files changed, 452 insertions(+), 134 deletions(-) create mode 100644 src/mlia/tools/metadata/py_package.py create mode 100644 src/mlia/utils/py_manager.py create mode 100644 tests/test_mlia_utils_py_manager.py create mode 100644 tests/test_tools_metadata_py_package.py diff --git a/README.md b/README.md index 1b9f494..a0a82ad 100644 --- a/README.md +++ b/README.md @@ -58,7 +58,7 @@ The `mlia-backend` command is used to manage the installation of new backends. The usage is: ```bash -mlia backend install --help +mlia-backend install --help ``` and the result looks like: @@ -71,22 +71,18 @@ optional arguments: * -h/--help: Show this help message and exit * --path PATH: Path to the installed backend -* --download: Download and install a backend +* --force: Force reinstalling backend in the specified path * --noninteractive: Non interactive mode with automatic confirmation of every action Example: ```bash -# Use this command to see what backends can be downloaded. -mlia backend install --download +mlia-backend install Corstone-300 ``` After a successful installation of the backend(s), start using mlia in your virtual environment. -*Please note*: Backends cannot be removed once installed. -Consider creating a new environment and reinstall backends when needed. - ### Backend compatibility table Not all backends work on any platform. Please refer to the compatibility table @@ -112,14 +108,13 @@ below: ### Using Corstone™-300 -To install Corstone™-300 as a backend for Ethos™-U both options (`--download` -and `--path`) can be used: +To install Corstone™-300 as a backend for Ethos™-U next commands can be used: ```bash # To download and install Corstone-300 automatically -mlia backend install --download Corstone-300 +mlia-backend install Corstone-300 # To point MLIA to an already locally installed version of Corstone-300 -mlia backend install --path YOUR_LOCAL_PATH_TO_CORSTONE_300 +mlia-backend install Corstone-300 --path YOUR_LOCAL_PATH_TO_CORSTONE_300 ``` Please note: Corstone™-300 used in the example above is available only @@ -136,11 +131,11 @@ Corstone™-310 is available as Arm® Virtual Hardware (AVH). ### Using TOSA checker -TOSA compatibility checker is available in MLIA as an external dependency. +TOSA compatibility checker is available in MLIA as a separate backend. Please, install it into the same environment as MLIA using next command: ```bash -pip install mlia[tosa] +mlia-backend install tosa-checker ``` TOSA checker resources: diff --git a/src/mlia/cli/commands.py b/src/mlia/cli/commands.py index 4be7f3e..09fe9de 100644 --- a/src/mlia/cli/commands.py +++ b/src/mlia/cli/commands.py @@ -20,7 +20,6 @@ from __future__ import annotations import logging from pathlib import Path -from typing import cast from mlia.api import ExecutionContext from mlia.api import generate_supported_operators_report @@ -249,29 +248,29 @@ def backend_install( noninteractive: bool = False, force: bool = False, ) -> None: - """Install configuration.""" + """Install backend.""" logger.info(CONFIG) manager = get_installation_manager(noninteractive) - install_from_path = path is not None - - if install_from_path: - manager.install_from(cast(Path, path), name, force) + if path is not None: + manager.install_from(path, name, force) else: eula_agreement = not i_agree_to_the_contained_eula - manager.download_and_install(name, eula_agreement) + manager.download_and_install(name, eula_agreement, force) -def backend_uninstall( - name: str, -) -> None: - """Uninstall backend(s).""" +def backend_uninstall(name: str) -> None: + """Uninstall backend.""" + logger.info(CONFIG) + manager = get_installation_manager(noninteractive=True) manager.uninstall(name) def backend_list() -> None: - """List backend status.""" + """List backends status.""" + logger.info(CONFIG) + manager = get_installation_manager(noninteractive=True) manager.show_env_details() diff --git a/src/mlia/cli/config.py b/src/mlia/cli/config.py index 30373e4..6ea9bb4 100644 --- a/src/mlia/cli/config.py +++ b/src/mlia/cli/config.py @@ -10,13 +10,14 @@ import mlia.backend.manager as backend_manager from mlia.tools.metadata.common import DefaultInstallationManager from mlia.tools.metadata.common import InstallationManager from mlia.tools.metadata.corstone import get_corstone_installations +from mlia.tools.metadata.py_package import get_pypackage_backend_installations logger = logging.getLogger(__name__) def get_installation_manager(noninteractive: bool = False) -> InstallationManager: """Return installation manager.""" - backends = get_corstone_installations() + backends = get_corstone_installations() + get_pypackage_backend_installations() return DefaultInstallationManager(backends, noninteractive=noninteractive) diff --git a/src/mlia/cli/main.py b/src/mlia/cli/main.py index 61b8f05..6c74a11 100644 --- a/src/mlia/cli/main.py +++ b/src/mlia/cli/main.py @@ -33,6 +33,8 @@ from mlia.cli.options import add_output_options from mlia.cli.options import add_target_options from mlia.cli.options import add_tflite_model_options from mlia.core.context import ExecutionContext +from mlia.core.errors import ConfigurationError +from mlia.core.errors import InternalError logger = logging.getLogger(__name__) @@ -219,6 +221,10 @@ def run_command(args: argparse.Namespace) -> int: return 0 except KeyboardInterrupt: logger.error("Execution has been interrupted") + except InternalError as err: + logger.error("Internal error: %s", err) + except ConfigurationError as err: + logger.error(err) except Exception as err: # pylint: disable=broad-except logger.error( "\nExecution finished with error: %s", diff --git a/src/mlia/cli/options.py b/src/mlia/cli/options.py index bf2f09b..5eab9aa 100644 --- a/src/mlia/cli/options.py +++ b/src/mlia/cli/options.py @@ -154,7 +154,7 @@ def add_backend_install_options(parser: argparse.ArgumentParser) -> None: "--force", default=False, action="store_true", - help="Force reinstall backend in the specified path", + help="Force reinstalling backend in the specified path", ) parser.add_argument( "--noninteractive", diff --git a/src/mlia/core/errors.py b/src/mlia/core/errors.py index 7d6beb1..d2356c2 100644 --- a/src/mlia/core/errors.py +++ b/src/mlia/core/errors.py @@ -7,6 +7,10 @@ class ConfigurationError(Exception): """Configuration error.""" +class InternalError(Exception): + """Internal error.""" + + class FunctionalityNotSupportedError(Exception): """Functionality is not supported error.""" diff --git a/src/mlia/devices/tosa/operators.py b/src/mlia/devices/tosa/operators.py index 03f6fb8..1e4581a 100644 --- a/src/mlia/devices/tosa/operators.py +++ b/src/mlia/devices/tosa/operators.py @@ -47,8 +47,7 @@ def get_tosa_compatibility_info( if checker is None: raise Exception( "TOSA checker is not available. " - "Please make sure that 'tosa_checker' package is installed: " - "pip install mlia[tosa]" + "Please make sure that 'tosa-checker' backend is installed." ) ops = [ diff --git a/src/mlia/tools/metadata/common.py b/src/mlia/tools/metadata/common.py index 927be74..5019da9 100644 --- a/src/mlia/tools/metadata/common.py +++ b/src/mlia/tools/metadata/common.py @@ -11,9 +11,10 @@ from pathlib import Path from typing import Callable from typing import Union +from mlia.core.errors import ConfigurationError +from mlia.core.errors import InternalError from mlia.utils.misc import yes - logger = logging.getLogger(__name__) @@ -124,7 +125,9 @@ class InstallationManager(ABC): """Install backend from the local directory.""" @abstractmethod - def download_and_install(self, backend_name: str, eula_agreement: bool) -> None: + def download_and_install( + self, backend_name: str, eula_agreement: bool, force: bool + ) -> None: """Download and install backends.""" @abstractmethod @@ -153,29 +156,15 @@ class InstallationFiltersMixin: if all(filter_(installation) for filter_ in filters) ] - def could_be_installed_from( - self, backend_path: Path, backend_name: str - ) -> list[Installation]: - """Return installations that could be installed from provided directory.""" - return self.filter_by( - SupportsInstallTypeFilter(InstallFromPath(backend_path)), - SearchByNameFilter(backend_name), - ) - - def could_be_downloaded_and_installed( - self, backend_name: str - ) -> list[Installation]: - """Return installations that could be downloaded and installed.""" - return self.filter_by( - SupportsInstallTypeFilter(DownloadAndInstall()), - SearchByNameFilter(backend_name), - ReadyForInstallationFilter(), - ) + def find_by_name(self, backend_name: str) -> list[Installation]: + """Return list of the backends filtered by name.""" + return self.filter_by(SearchByNameFilter(backend_name)) def already_installed(self, backend_name: str = None) -> list[Installation]: """Return list of backends that are already installed.""" return self.filter_by( - AlreadyInstalledFilter(), SearchByNameFilter(backend_name) + AlreadyInstalledFilter(), + SearchByNameFilter(backend_name), ) def ready_for_installation(self) -> list[Installation]: @@ -193,83 +182,96 @@ class DefaultInstallationManager(InstallationManager, InstallationFiltersMixin): self.installations = installations self.noninteractive = noninteractive - def choose_installation_for_path( - self, backend_path: Path, backend_name: str, force: bool - ) -> Installation | None: - """Check available installation and select one if possible.""" - installs = self.could_be_installed_from(backend_path, backend_name) + def _install( + self, + backend_name: str, + install_type: InstallationType, + prompt: Callable[[Installation], str], + force: bool, + ) -> None: + """Check metadata and install backend.""" + installs = self.find_by_name(backend_name) if not installs: + logger.info("Unknown backend '%s'.", backend_name) logger.info( - "Unfortunatelly, it was not possible to automatically " - "detect type of the installed FVP. " - "Please, check provided path to the installed FVP." + "Please run command 'mlia-backend list' to get list of " + "supported backend names." ) - return None - if len(installs) != 1: - names = ",".join(install.name for install in installs) - logger.info( - "Unable to correctly detect type of the installed FVP." - "The following FVPs are detected %s. Installation skipped.", - names, - ) - return None + return + + if len(installs) > 1: + raise InternalError(f"More than one backend with name {backend_name} found") installation = installs[0] - if installation.already_installed: + if not installation.supports(install_type): + if isinstance(install_type, InstallFromPath): + logger.info( + "Backend '%s' could not be installed using path '%s'.", + installation.name, + install_type.backend_path, + ) + logger.info( + "Please check that '%s' is a valid path to the installed backend.", + install_type.backend_path, + ) + else: + logger.info( + "Backend '%s' could not be downloaded and installed", + installation.name, + ) + logger.info( + "Please refer to the project's documentation for more details." + ) + + return + + if installation.already_installed and not force: + logger.info("Backend '%s' is already installed.", installation.name) + logger.info("Please, consider using --force option.") + return + + proceed = self.noninteractive or yes(prompt(installation)) + if not proceed: + logger.info("%s installation canceled.", installation.name) + return + + if installation.already_installed and force: logger.info( - "%s was found in %s, but it has been already installed " - "in the ML Inference Advisor.", + "Force installing %s, so delete the existing " + "installed backend first.", installation.name, - backend_path, ) - return installation if force else None + installation.uninstall() - return installation + installation.install(install_type) + logger.info("%s successfully installed.", installation.name) def install_from( self, backend_path: Path, backend_name: str, force: bool = False ) -> None: """Install from the provided directory.""" - installation = self.choose_installation_for_path( - backend_path, backend_name, force - ) - - if not installation: - return - if force: - self.uninstall(backend_name) - logger.info( - "Force installing %s, so delete the existing installed backend first.", - installation.name, + def prompt(install: Installation) -> str: + return ( + f"{install.name} was found in {backend_path}. " + "Would you like to install it?" ) - prompt = ( - f"{installation.name} was found in {backend_path}. " - "Would you like to install it?" - ) - self._install(installation, InstallFromPath(backend_path), prompt) + install_type = InstallFromPath(backend_path) + self._install(backend_name, install_type, prompt, force) def download_and_install( - self, backend_name: str, eula_agreement: bool = True + self, backend_name: str, eula_agreement: bool = True, force: bool = False ) -> None: """Download and install available backends.""" - installations = self.could_be_downloaded_and_installed(backend_name) - if not installations: - logger.info("No backends available for the installation.") - return + def prompt(install: Installation) -> str: + return f"Would you like to download and install {install.name}?" - names = ",".join(installation.name for installation in installations) - logger.info("Following backends are available for downloading: %s", names) - - for installation in installations: - prompt = f"Would you like to download and install {installation.name}?" - self._install( - installation, DownloadAndInstall(eula_agreement=eula_agreement), prompt - ) + install_type = DownloadAndInstall(eula_agreement=eula_agreement) + self._install(backend_name, install_type, prompt, force) def show_env_details(self) -> None: """Print current state of the execution environment.""" @@ -299,24 +301,19 @@ class DefaultInstallationManager(InstallationManager, InstallationFiltersMixin): def uninstall(self, backend_name: str) -> None: """Uninstall the backend with name backend_name.""" installations = self.already_installed(backend_name) + if not installations: - raise Exception("No backend available for uninstall") - for installation in installations: - installation.uninstall() + raise ConfigurationError(f"Backend '{backend_name}' is not installed") - def _install( - self, - installation: Installation, - installation_type: InstallationType, - prompt: str, - ) -> None: - proceed = self.noninteractive or yes(prompt) + if len(installations) != 1: + raise InternalError( + f"More than one installed backend with name {backend_name} found" + ) - if proceed: - installation.install(installation_type) - logger.info("%s successfully installed.", installation.name) - else: - logger.info("%s installation canceled.", installation.name) + installation = installations[0] + installation.uninstall() + + logger.info("%s successfully uninstalled.", installation.name) def backend_installed(self, backend_name: str) -> bool: """Return true if requested backend installed.""" diff --git a/src/mlia/tools/metadata/corstone.py b/src/mlia/tools/metadata/corstone.py index 04b13b5..df2dcdb 100644 --- a/src/mlia/tools/metadata/corstone.py +++ b/src/mlia/tools/metadata/corstone.py @@ -209,7 +209,6 @@ class BackendInstallation(Installation): def uninstall(self) -> None: """Uninstall the backend.""" remove_system(self.metadata.fvp_dir_name) - logger.info("%s successfully uninstalled.", self.name) class PackagePathChecker: diff --git a/src/mlia/tools/metadata/py_package.py b/src/mlia/tools/metadata/py_package.py new file mode 100644 index 0000000..716b62a --- /dev/null +++ b/src/mlia/tools/metadata/py_package.py @@ -0,0 +1,84 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Module for python package based installations.""" +from __future__ import annotations + +from mlia.tools.metadata.common import DownloadAndInstall +from mlia.tools.metadata.common import Installation +from mlia.tools.metadata.common import InstallationType +from mlia.utils.py_manager import get_package_manager + + +class PyPackageBackendInstallation(Installation): + """Backend based on the python package.""" + + def __init__( + self, + name: str, + description: str, + packages_to_install: list[str], + packages_to_uninstall: list[str], + expected_packages: list[str], + ) -> None: + """Init the backend installation.""" + self._name = name + self._description = description + self._packages_to_install = packages_to_install + self._packages_to_uninstall = packages_to_uninstall + self._expected_packages = expected_packages + + self.package_manager = get_package_manager() + + @property + def name(self) -> str: + """Return name of the backend.""" + return self._name + + @property + def description(self) -> str: + """Return description of the backend.""" + return self._description + + @property + def could_be_installed(self) -> bool: + """Check if backend could be installed.""" + return True + + @property + def already_installed(self) -> bool: + """Check if backend already installed.""" + return self.package_manager.packages_installed(self._expected_packages) + + def supports(self, install_type: InstallationType) -> bool: + """Return true if installation supports requested installation type.""" + return isinstance(install_type, DownloadAndInstall) + + def install(self, install_type: InstallationType) -> None: + """Install the backend.""" + if not self.supports(install_type): + raise Exception(f"Unsupported installation type {install_type}") + + self.package_manager.install(self._packages_to_install) + + def uninstall(self) -> None: + """Uninstall the backend.""" + self.package_manager.uninstall(self._packages_to_uninstall) + + +def get_tosa_backend_installation() -> Installation: + """Get TOSA backend installation.""" + return PyPackageBackendInstallation( + name="tosa-checker", + description="Tool to check if a ML model is compatible " + "with the TOSA specification", + packages_to_install=["mlia[tosa]"], + packages_to_uninstall=["tosa-checker"], + expected_packages=["tosa-checker"], + ) + + +def get_pypackage_backend_installations() -> list[Installation]: + """Return list of the backend installations based on python packages.""" + return [ + get_tosa_backend_installation(), + ] diff --git a/src/mlia/utils/py_manager.py b/src/mlia/utils/py_manager.py new file mode 100644 index 0000000..5f98fcc --- /dev/null +++ b/src/mlia/utils/py_manager.py @@ -0,0 +1,62 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Util functions for managing python packages.""" +from __future__ import annotations + +import sys +from importlib.metadata import distribution +from importlib.metadata import PackageNotFoundError +from subprocess import check_call # nosec + + +class PyPackageManager: + """Python package manager.""" + + @staticmethod + def package_installed(pkg_name: str) -> bool: + """Return true if package installed.""" + try: + distribution(pkg_name) + except PackageNotFoundError: + return False + + return True + + def packages_installed(self, pkg_names: list[str]) -> bool: + """Return true if all provided packages installed.""" + return all(self.package_installed(pkg) for pkg in pkg_names) + + def install(self, pkg_names: list[str]) -> None: + """Install provided packages.""" + if not pkg_names: + raise ValueError("No package names provided") + + self._execute_pip_cmd("install", pkg_names) + + def uninstall(self, pkg_names: list[str]) -> None: + """Uninstall provided packages.""" + if not pkg_names: + raise ValueError("No package names provided") + + self._execute_pip_cmd("uninstall", ["--yes", *pkg_names]) + + @staticmethod + def _execute_pip_cmd(subcommand: str, params: list[str]) -> None: + """Execute pip command.""" + assert sys.executable, "Unable to launch pip command" + + check_call( + [ + sys.executable, + "-m", + "pip", + "--disable-pip-version-check", + subcommand, + *params, + ] + ) + + +def get_package_manager() -> PyPackageManager: + """Get python packages manager.""" + return PyPackageManager() diff --git a/tests/test_cli_commands.py b/tests/test_cli_commands.py index f6e0843..3a01f78 100644 --- a/tests/test_cli_commands.py +++ b/tests/test_cli_commands.py @@ -175,16 +175,17 @@ def test_backend_command_action_uninstall( @pytest.mark.parametrize( - "i_agree_to_the_contained_eula, backend_name, expected_calls", + "i_agree_to_the_contained_eula, force, backend_name, expected_calls", [ - [False, "backend_name", [call("backend_name", True)]], - [True, "backend_name", [call("backend_name", False)]], - [True, "BACKEND_NAME", [call("BACKEND_NAME", False)]], + [False, False, "backend_name", [call("backend_name", True, False)]], + [True, False, "backend_name", [call("backend_name", False, False)]], + [True, True, "BACKEND_NAME", [call("BACKEND_NAME", False, True)]], ], ) def test_backend_command_action_add_download( installation_manager_mock: MagicMock, i_agree_to_the_contained_eula: bool, + force: bool, backend_name: str, expected_calls: Any, ) -> None: @@ -192,6 +193,7 @@ def test_backend_command_action_add_download( backend_install( name=backend_name, i_agree_to_the_contained_eula=i_agree_to_the_contained_eula, + force=force, ) assert installation_manager_mock.download_and_install.mock_calls == expected_calls diff --git a/tests/test_mlia_utils_py_manager.py b/tests/test_mlia_utils_py_manager.py new file mode 100644 index 0000000..e41680d --- /dev/null +++ b/tests/test_mlia_utils_py_manager.py @@ -0,0 +1,73 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Tests for python package manager.""" +import sys +from unittest.mock import MagicMock + +import pytest + +from mlia.utils.py_manager import get_package_manager +from mlia.utils.py_manager import PyPackageManager + + +def test_get_package_manager() -> None: + """Test function get_package_manager.""" + manager = get_package_manager() + assert isinstance(manager, PyPackageManager) + + +@pytest.fixture(name="mock_check_call") +def mock_check_call_fixture(monkeypatch: pytest.MonkeyPatch) -> MagicMock: + """Mock check_call function.""" + mock_check_call = MagicMock() + monkeypatch.setattr("mlia.utils.py_manager.check_call", mock_check_call) + + return mock_check_call + + +def test_py_package_manager_metadata() -> None: + """Test getting package status.""" + manager = PyPackageManager() + assert manager.package_installed("pytest") + assert manager.packages_installed(["pytest", "mlia"]) + + +def test_py_package_manager_install(mock_check_call: MagicMock) -> None: + """Test package installation.""" + manager = PyPackageManager() + with pytest.raises(ValueError, match="No package names provided"): + manager.install([]) + + manager.install(["mlia", "pytest"]) + mock_check_call.assert_called_once_with( + [ + sys.executable, + "-m", + "pip", + "--disable-pip-version-check", + "install", + "mlia", + "pytest", + ] + ) + + +def test_py_package_manager_uninstall(mock_check_call: MagicMock) -> None: + """Test package removal.""" + manager = PyPackageManager() + with pytest.raises(ValueError, match="No package names provided"): + manager.uninstall([]) + + manager.uninstall(["mlia", "pytest"]) + mock_check_call.assert_called_once_with( + [ + sys.executable, + "-m", + "pip", + "--disable-pip-version-check", + "uninstall", + "--yes", + "mlia", + "pytest", + ] + ) diff --git a/tests/test_tools_metadata_common.py b/tests/test_tools_metadata_common.py index fefb024..9811852 100644 --- a/tests/test_tools_metadata_common.py +++ b/tests/test_tools_metadata_common.py @@ -46,7 +46,7 @@ def get_installation_mock( name: str, already_installed: bool = False, could_be_installed: bool = False, - supported_install_type: type | None = None, + supported_install_type: type | tuple | None = None, ) -> MagicMock: """Get mock instance for the installation.""" mock = MagicMock(spec=Installation) @@ -74,6 +74,7 @@ def _already_installed_mock() -> MagicMock: return get_installation_mock( name="already_installed", already_installed=True, + supported_install_type=(DownloadAndInstall, InstallFromPath), ) @@ -136,32 +137,38 @@ def test_installation_manager_filtering() -> None: ready_for_installation, could_be_downloaded_and_installed, ] - assert manager.could_be_downloaded_and_installed( - "could_be_downloaded_and_installed" - ) == [could_be_downloaded_and_installed] - assert manager.could_be_downloaded_and_installed("some_installation") == [] @pytest.mark.parametrize("noninteractive", [True, False]) @pytest.mark.parametrize( - "install_mock, eula_agreement, backend_name, expected_call", + "install_mock, eula_agreement, backend_name, force, expected_call", [ [ _could_be_downloaded_and_installed_mock(), True, - None, + "could_be_downloaded_and_installed", + False, [call(DownloadAndInstall(eula_agreement=True))], ], [ _could_be_downloaded_and_installed_mock(), False, - None, + "could_be_downloaded_and_installed", + True, + [call(DownloadAndInstall(eula_agreement=False))], + ], + [ + _already_installed_mock(), + False, + "already_installed", + True, [call(DownloadAndInstall(eula_agreement=False))], ], [ _could_be_downloaded_and_installed_mock(), False, "unknown", + True, [], ], ], @@ -171,6 +178,7 @@ def test_installation_manager_download_and_install( noninteractive: bool, eula_agreement: bool, backend_name: str, + force: bool, expected_call: Any, monkeypatch: pytest.MonkeyPatch, ) -> None: @@ -179,35 +187,58 @@ def test_installation_manager_download_and_install( manager = get_installation_manager(noninteractive, [install_mock], monkeypatch) - manager.download_and_install(backend_name, eula_agreement=eula_agreement) + manager.download_and_install( + backend_name, eula_agreement=eula_agreement, force=force + ) + assert install_mock.install.mock_calls == expected_call + if force and install_mock.already_installed: + install_mock.uninstall.assert_called_once() + else: + install_mock.uninstall.assert_not_called() @pytest.mark.parametrize("noninteractive", [True, False]) @pytest.mark.parametrize( - "install_mock, backend_name, expected_call", + "install_mock, backend_name, force, expected_call", [ [ _could_be_installed_from_mock(), - None, + "could_be_installed_from", + False, [call(InstallFromPath(Path("some_path")))], ], [ _could_be_installed_from_mock(), "unknown", + False, + [], + ], + [ + _could_be_installed_from_mock(), + "unknown", + True, [], ], [ _already_installed_mock(), "already_installed", + False, [], ], + [ + _already_installed_mock(), + "already_installed", + True, + [call(InstallFromPath(Path("some_path")))], + ], ], ) def test_installation_manager_install_from( install_mock: MagicMock, noninteractive: bool, backend_name: str, + force: bool, expected_call: Any, monkeypatch: pytest.MonkeyPatch, ) -> None: @@ -215,9 +246,13 @@ def test_installation_manager_install_from( install_mock.reset_mock() manager = get_installation_manager(noninteractive, [install_mock], monkeypatch) - manager.install_from(Path("some_path"), backend_name) + manager.install_from(Path("some_path"), backend_name, force=force) assert install_mock.install.mock_calls == expected_call + if force and install_mock.already_installed: + install_mock.uninstall.assert_called_once() + else: + install_mock.uninstall.assert_not_called() @pytest.mark.parametrize("noninteractive", [True, False]) diff --git a/tests/test_tools_metadata_py_package.py b/tests/test_tools_metadata_py_package.py new file mode 100644 index 0000000..8b93e33 --- /dev/null +++ b/tests/test_tools_metadata_py_package.py @@ -0,0 +1,62 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Tests for python package based installations.""" +from pathlib import Path +from unittest.mock import MagicMock + +import pytest + +from mlia.tools.metadata.common import DownloadAndInstall +from mlia.tools.metadata.common import InstallFromPath +from mlia.tools.metadata.py_package import get_pypackage_backend_installations +from mlia.tools.metadata.py_package import get_tosa_backend_installation +from mlia.tools.metadata.py_package import PyPackageBackendInstallation + + +def test_get_pypackage_backends() -> None: + """Test function get_pypackage_backends.""" + backend_installs = get_pypackage_backend_installations() + + assert isinstance(backend_installs, list) + assert len(backend_installs) == 1 + + tosa_installation = backend_installs[0] + assert isinstance(tosa_installation, PyPackageBackendInstallation) + + +def test_get_tosa_backend_installation( + tmp_path: Path, monkeypatch: pytest.MonkeyPatch +) -> None: + """Test function get_tosa_backend_installation.""" + mock_package_manager = MagicMock() + monkeypatch.setattr( + "mlia.tools.metadata.py_package.get_package_manager", + lambda: mock_package_manager, + ) + + tosa_installation = get_tosa_backend_installation() + + assert isinstance(tosa_installation, PyPackageBackendInstallation) + assert tosa_installation.name == "tosa-checker" + assert ( + tosa_installation.description + == "Tool to check if a ML model is compatible with the TOSA specification" + ) + assert tosa_installation.could_be_installed + assert tosa_installation.supports(DownloadAndInstall()) + assert not tosa_installation.supports(InstallFromPath(tmp_path)) + + mock_package_manager.packages_installed.return_value = True + assert tosa_installation.already_installed + mock_package_manager.packages_installed.assert_called_once_with(["tosa-checker"]) + + with pytest.raises(Exception, match=r"Unsupported installation type.*"): + tosa_installation.install(InstallFromPath(tmp_path)) + + mock_package_manager.install.assert_not_called() + + tosa_installation.install(DownloadAndInstall()) + mock_package_manager.install.assert_called_once_with(["mlia[tosa]"]) + + tosa_installation.uninstall() + mock_package_manager.uninstall.assert_called_once_with(["tosa-checker"]) -- cgit v1.2.1