From 0efca3cadbad5517a59884576ddb90cfe7ac30f8 Mon Sep 17 00:00:00 2001 From: Diego Russo Date: Mon, 30 May 2022 13:34:14 +0100 Subject: Add MLIA codebase Add MLIA codebase including sources and tests. Change-Id: Id41707559bd721edd114793618d12ccd188d8dbd --- src/aiet/backend/source.py | 209 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 209 insertions(+) create mode 100644 src/aiet/backend/source.py (limited to 'src/aiet/backend/source.py') diff --git a/src/aiet/backend/source.py b/src/aiet/backend/source.py new file mode 100644 index 0000000..dec175a --- /dev/null +++ b/src/aiet/backend/source.py @@ -0,0 +1,209 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Contain source related classes and functions.""" +import os +import shutil +import tarfile +from abc import ABC +from abc import abstractmethod +from pathlib import Path +from tarfile import TarFile +from typing import Optional +from typing import Union + +from aiet.backend.common import AIET_CONFIG_FILE +from aiet.backend.common import ConfigurationException +from aiet.backend.common import get_backend_config +from aiet.backend.common import is_backend_directory +from aiet.backend.common import load_config +from aiet.backend.config import BackendConfig +from aiet.utils.fs import copy_directory_content + + +class Source(ABC): + """Source class.""" + + @abstractmethod + def name(self) -> Optional[str]: + """Get source name.""" + + @abstractmethod + def config(self) -> Optional[BackendConfig]: + """Get configuration file content.""" + + @abstractmethod + def install_into(self, destination: Path) -> None: + """Install source into destination directory.""" + + @abstractmethod + def create_destination(self) -> bool: + """Return True if destination folder should be created before installation.""" + + +class DirectorySource(Source): + """DirectorySource class.""" + + def __init__(self, directory_path: Path) -> None: + """Create the DirectorySource instance.""" + assert isinstance(directory_path, Path) + self.directory_path = directory_path + + def name(self) -> str: + """Return name of source.""" + return self.directory_path.name + + def config(self) -> Optional[BackendConfig]: + """Return configuration file content.""" + if not is_backend_directory(self.directory_path): + raise ConfigurationException("No configuration file found") + + config_file = get_backend_config(self.directory_path) + return load_config(config_file) + + def install_into(self, destination: Path) -> None: + """Install source into destination directory.""" + if not destination.is_dir(): + raise ConfigurationException("Wrong destination {}".format(destination)) + + if not self.directory_path.is_dir(): + raise ConfigurationException( + "Directory {} does not exist".format(self.directory_path) + ) + + copy_directory_content(self.directory_path, destination) + + def create_destination(self) -> bool: + """Return True if destination folder should be created before installation.""" + return True + + +class TarArchiveSource(Source): + """TarArchiveSource class.""" + + def __init__(self, archive_path: Path) -> None: + """Create the TarArchiveSource class.""" + assert isinstance(archive_path, Path) + self.archive_path = archive_path + self._config: Optional[BackendConfig] = None + self._has_top_level_folder: Optional[bool] = None + self._name: Optional[str] = None + + def _read_archive_content(self) -> None: + """Read various information about archive.""" + # get source name from archive name (everything without extensions) + extensions = "".join(self.archive_path.suffixes) + self._name = self.archive_path.name.rstrip(extensions) + + if not self.archive_path.exists(): + return + + with self._open(self.archive_path) as archive: + try: + config_entry = archive.getmember(AIET_CONFIG_FILE) + self._has_top_level_folder = False + except KeyError as error_no_config: + try: + archive_entries = archive.getnames() + entries_common_prefix = os.path.commonprefix(archive_entries) + top_level_dir = entries_common_prefix.rstrip("/") + + if not top_level_dir: + raise RuntimeError( + "Archive has no top level directory" + ) from error_no_config + + config_path = "{}/{}".format(top_level_dir, AIET_CONFIG_FILE) + + config_entry = archive.getmember(config_path) + self._has_top_level_folder = True + self._name = top_level_dir + except (KeyError, RuntimeError) as error_no_root_dir_or_config: + raise ConfigurationException( + "No configuration file found" + ) from error_no_root_dir_or_config + + content = archive.extractfile(config_entry) + self._config = load_config(content) + + def config(self) -> Optional[BackendConfig]: + """Return configuration file content.""" + if self._config is None: + self._read_archive_content() + + return self._config + + def name(self) -> Optional[str]: + """Return name of the source.""" + if self._name is None: + self._read_archive_content() + + return self._name + + def create_destination(self) -> bool: + """Return True if destination folder must be created before installation.""" + if self._has_top_level_folder is None: + self._read_archive_content() + + return not self._has_top_level_folder + + def install_into(self, destination: Path) -> None: + """Install source into destination directory.""" + if not destination.is_dir(): + raise ConfigurationException("Wrong destination {}".format(destination)) + + with self._open(self.archive_path) as archive: + archive.extractall(destination) + + def _open(self, archive_path: Path) -> TarFile: + """Open archive file.""" + if not archive_path.is_file(): + raise ConfigurationException("File {} does not exist".format(archive_path)) + + if archive_path.name.endswith("tar.gz") or archive_path.name.endswith("tgz"): + mode = "r:gz" + else: + raise ConfigurationException( + "Unsupported archive type {}".format(archive_path) + ) + + # The returned TarFile object can be used as a context manager (using + # 'with') by the calling instance. + return tarfile.open( # pylint: disable=consider-using-with + self.archive_path, mode=mode + ) + + +def get_source(source_path: Path) -> Union[TarArchiveSource, DirectorySource]: + """Return appropriate source instance based on provided source path.""" + if source_path.is_file(): + return TarArchiveSource(source_path) + + if source_path.is_dir(): + return DirectorySource(source_path) + + raise ConfigurationException("Unable to read {}".format(source_path)) + + +def create_destination_and_install(source: Source, resource_path: Path) -> None: + """Create destination directory and install source. + + This function is used for actual installation of system/backend New + directory will be created inside :resource_path: if needed If for example + archive contains top level folder then no need to create new directory + """ + destination = resource_path + create_destination = source.create_destination() + + if create_destination: + name = source.name() + if not name: + raise ConfigurationException("Unable to get source name") + + destination = resource_path / name + destination.mkdir() + try: + source.install_into(destination) + except Exception as error: + if create_destination: + shutil.rmtree(destination) + raise error -- cgit v1.2.1