diff options
Diffstat (limited to 'src/mlia/utils/py_manager.py')
-rw-r--r-- | src/mlia/utils/py_manager.py | 62 |
1 files changed, 62 insertions, 0 deletions
diff --git a/src/mlia/utils/py_manager.py b/src/mlia/utils/py_manager.py new file mode 100644 index 0000000..5f98fcc --- /dev/null +++ b/src/mlia/utils/py_manager.py @@ -0,0 +1,62 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Util functions for managing python packages.""" +from __future__ import annotations + +import sys +from importlib.metadata import distribution +from importlib.metadata import PackageNotFoundError +from subprocess import check_call # nosec + + +class PyPackageManager: + """Python package manager.""" + + @staticmethod + def package_installed(pkg_name: str) -> bool: + """Return true if package installed.""" + try: + distribution(pkg_name) + except PackageNotFoundError: + return False + + return True + + def packages_installed(self, pkg_names: list[str]) -> bool: + """Return true if all provided packages installed.""" + return all(self.package_installed(pkg) for pkg in pkg_names) + + def install(self, pkg_names: list[str]) -> None: + """Install provided packages.""" + if not pkg_names: + raise ValueError("No package names provided") + + self._execute_pip_cmd("install", pkg_names) + + def uninstall(self, pkg_names: list[str]) -> None: + """Uninstall provided packages.""" + if not pkg_names: + raise ValueError("No package names provided") + + self._execute_pip_cmd("uninstall", ["--yes", *pkg_names]) + + @staticmethod + def _execute_pip_cmd(subcommand: str, params: list[str]) -> None: + """Execute pip command.""" + assert sys.executable, "Unable to launch pip command" + + check_call( + [ + sys.executable, + "-m", + "pip", + "--disable-pip-version-check", + subcommand, + *params, + ] + ) + + +def get_package_manager() -> PyPackageManager: + """Get python packages manager.""" + return PyPackageManager() |