aboutsummaryrefslogtreecommitdiff
path: root/src/aiet/backend/source.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/aiet/backend/source.py')
-rw-r--r--src/aiet/backend/source.py209
1 files changed, 209 insertions, 0 deletions
diff --git a/src/aiet/backend/source.py b/src/aiet/backend/source.py
new file mode 100644
index 0000000..dec175a
--- /dev/null
+++ b/src/aiet/backend/source.py
@@ -0,0 +1,209 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Contain source related classes and functions."""
+import os
+import shutil
+import tarfile
+from abc import ABC
+from abc import abstractmethod
+from pathlib import Path
+from tarfile import TarFile
+from typing import Optional
+from typing import Union
+
+from aiet.backend.common import AIET_CONFIG_FILE
+from aiet.backend.common import ConfigurationException
+from aiet.backend.common import get_backend_config
+from aiet.backend.common import is_backend_directory
+from aiet.backend.common import load_config
+from aiet.backend.config import BackendConfig
+from aiet.utils.fs import copy_directory_content
+
+
+class Source(ABC):
+ """Source class."""
+
+ @abstractmethod
+ def name(self) -> Optional[str]:
+ """Get source name."""
+
+ @abstractmethod
+ def config(self) -> Optional[BackendConfig]:
+ """Get configuration file content."""
+
+ @abstractmethod
+ def install_into(self, destination: Path) -> None:
+ """Install source into destination directory."""
+
+ @abstractmethod
+ def create_destination(self) -> bool:
+ """Return True if destination folder should be created before installation."""
+
+
+class DirectorySource(Source):
+ """DirectorySource class."""
+
+ def __init__(self, directory_path: Path) -> None:
+ """Create the DirectorySource instance."""
+ assert isinstance(directory_path, Path)
+ self.directory_path = directory_path
+
+ def name(self) -> str:
+ """Return name of source."""
+ return self.directory_path.name
+
+ def config(self) -> Optional[BackendConfig]:
+ """Return configuration file content."""
+ if not is_backend_directory(self.directory_path):
+ raise ConfigurationException("No configuration file found")
+
+ config_file = get_backend_config(self.directory_path)
+ return load_config(config_file)
+
+ def install_into(self, destination: Path) -> None:
+ """Install source into destination directory."""
+ if not destination.is_dir():
+ raise ConfigurationException("Wrong destination {}".format(destination))
+
+ if not self.directory_path.is_dir():
+ raise ConfigurationException(
+ "Directory {} does not exist".format(self.directory_path)
+ )
+
+ copy_directory_content(self.directory_path, destination)
+
+ def create_destination(self) -> bool:
+ """Return True if destination folder should be created before installation."""
+ return True
+
+
+class TarArchiveSource(Source):
+ """TarArchiveSource class."""
+
+ def __init__(self, archive_path: Path) -> None:
+ """Create the TarArchiveSource class."""
+ assert isinstance(archive_path, Path)
+ self.archive_path = archive_path
+ self._config: Optional[BackendConfig] = None
+ self._has_top_level_folder: Optional[bool] = None
+ self._name: Optional[str] = None
+
+ def _read_archive_content(self) -> None:
+ """Read various information about archive."""
+ # get source name from archive name (everything without extensions)
+ extensions = "".join(self.archive_path.suffixes)
+ self._name = self.archive_path.name.rstrip(extensions)
+
+ if not self.archive_path.exists():
+ return
+
+ with self._open(self.archive_path) as archive:
+ try:
+ config_entry = archive.getmember(AIET_CONFIG_FILE)
+ self._has_top_level_folder = False
+ except KeyError as error_no_config:
+ try:
+ archive_entries = archive.getnames()
+ entries_common_prefix = os.path.commonprefix(archive_entries)
+ top_level_dir = entries_common_prefix.rstrip("/")
+
+ if not top_level_dir:
+ raise RuntimeError(
+ "Archive has no top level directory"
+ ) from error_no_config
+
+ config_path = "{}/{}".format(top_level_dir, AIET_CONFIG_FILE)
+
+ config_entry = archive.getmember(config_path)
+ self._has_top_level_folder = True
+ self._name = top_level_dir
+ except (KeyError, RuntimeError) as error_no_root_dir_or_config:
+ raise ConfigurationException(
+ "No configuration file found"
+ ) from error_no_root_dir_or_config
+
+ content = archive.extractfile(config_entry)
+ self._config = load_config(content)
+
+ def config(self) -> Optional[BackendConfig]:
+ """Return configuration file content."""
+ if self._config is None:
+ self._read_archive_content()
+
+ return self._config
+
+ def name(self) -> Optional[str]:
+ """Return name of the source."""
+ if self._name is None:
+ self._read_archive_content()
+
+ return self._name
+
+ def create_destination(self) -> bool:
+ """Return True if destination folder must be created before installation."""
+ if self._has_top_level_folder is None:
+ self._read_archive_content()
+
+ return not self._has_top_level_folder
+
+ def install_into(self, destination: Path) -> None:
+ """Install source into destination directory."""
+ if not destination.is_dir():
+ raise ConfigurationException("Wrong destination {}".format(destination))
+
+ with self._open(self.archive_path) as archive:
+ archive.extractall(destination)
+
+ def _open(self, archive_path: Path) -> TarFile:
+ """Open archive file."""
+ if not archive_path.is_file():
+ raise ConfigurationException("File {} does not exist".format(archive_path))
+
+ if archive_path.name.endswith("tar.gz") or archive_path.name.endswith("tgz"):
+ mode = "r:gz"
+ else:
+ raise ConfigurationException(
+ "Unsupported archive type {}".format(archive_path)
+ )
+
+ # The returned TarFile object can be used as a context manager (using
+ # 'with') by the calling instance.
+ return tarfile.open( # pylint: disable=consider-using-with
+ self.archive_path, mode=mode
+ )
+
+
+def get_source(source_path: Path) -> Union[TarArchiveSource, DirectorySource]:
+ """Return appropriate source instance based on provided source path."""
+ if source_path.is_file():
+ return TarArchiveSource(source_path)
+
+ if source_path.is_dir():
+ return DirectorySource(source_path)
+
+ raise ConfigurationException("Unable to read {}".format(source_path))
+
+
+def create_destination_and_install(source: Source, resource_path: Path) -> None:
+ """Create destination directory and install source.
+
+ This function is used for actual installation of system/backend New
+ directory will be created inside :resource_path: if needed If for example
+ archive contains top level folder then no need to create new directory
+ """
+ destination = resource_path
+ create_destination = source.create_destination()
+
+ if create_destination:
+ name = source.name()
+ if not name:
+ raise ConfigurationException("Unable to get source name")
+
+ destination = resource_path / name
+ destination.mkdir()
+ try:
+ source.install_into(destination)
+ except Exception as error:
+ if create_destination:
+ shutil.rmtree(destination)
+ raise error