aboutsummaryrefslogtreecommitdiff
path: root/src/aiet/utils/fs.py
blob: ea99a698200f3056d4e32de23d1929b988adb09f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Module to host all file system related functions."""
import importlib.resources as pkg_resources
import re
import shutil
from pathlib import Path
from typing import Any
from typing import Literal
from typing import Optional

# Allowed names of the resource sub-folders looked up under the aiet resources
# directory by get_resources() and remove_resource().
ResourceType = Literal["applications", "systems", "tools"]


def get_aiet_resources() -> Path:
    """Return the path of the ``resources`` folder inside the aiet package.

    The folder is located next to the package's ``__init__.py``, which is
    found through ``importlib.resources``.

    NOTE(review): ``importlib.resources.path`` may hand out a temporary copy
    of ``__init__.py`` (e.g. for zipped installs) that is removed when the
    ``with`` block exits; this only works reliably for packages installed as
    plain directories — confirm that is the only supported layout.
    """
    with pkg_resources.path("aiet", "__init__.py") as package_init:
        resources_folder = package_init.parent.joinpath("resources")
    return resources_folder


def get_resources(name: ResourceType) -> Path:
    """Return the path of the resource folder called ``name``.

    The folder is looked up inside the aiet resources directory (resources
    packaged with MANIFEST.in and located through importlib).

    :param name: resource folder name, one of ``ResourceType``.
    :return: path of the existing resource folder.
    :raises ResourceWarning: if ``name`` is empty or the folder does not exist.
    """
    if not name:
        raise ResourceWarning("Resource name is not provided")

    candidate = get_aiet_resources() / name
    if not candidate.is_dir():
        raise ResourceWarning("Resource '{}' not found.".format(name))

    return candidate


def copy_directory_content(source: Path, destination: Path) -> None:
    """Copy every entry of ``source`` into ``destination``.

    Sub-directories are copied recursively with ``shutil.copytree`` (which
    fails if the target sub-directory already exists); files are copied with
    ``shutil.copy2``.
    """
    for entry in source.iterdir():
        target = destination / entry.name
        copy_entry = shutil.copytree if entry.is_dir() else shutil.copy2
        copy_entry(entry, target)


def remove_resource(resource_directory: str, resource_type: ResourceType) -> None:
    """Delete a resource directory and all of its content.

    :param resource_directory: name of the resource folder inside the folder
        returned by ``get_resources(resource_type)``.
    :param resource_type: kind of resource, one of ``ResourceType``.
    :raises ResourceWarning: propagated from ``get_resources`` when the
        resource type is empty or its folder is missing.
    :raises Exception: if the resource does not exist, is not a directory, or
        does not resolve to a location strictly inside the resource type
        folder (e.g. an empty name, ``"."``, an absolute path or a ``".."``
        traversal).
    """
    resources = get_resources(resource_type)

    resource_location = resources / resource_directory
    if not resource_location.exists():
        raise Exception("Resource {} does not exist".format(resource_directory))

    if not resource_location.is_dir():
        raise Exception("Wrong resource {}".format(resource_directory))

    # Guard the recursive delete: "" or "." would otherwise point at the
    # resource type folder itself and wipe every resource of that type, and
    # "../x" or an absolute path would escape it entirely.
    resources_root = resources.resolve()
    if resources_root not in resource_location.resolve().parents:
        raise Exception("Wrong resource {}".format(resource_directory))

    shutil.rmtree(resource_location)


def remove_directory(directory_path: Optional[Path]) -> None:
    """Recursively delete ``directory_path`` and everything inside it.

    :param directory_path: directory to delete.
    :raises Exception: if no path is given, or if the path does not point to
        an existing directory.
    """
    if not directory_path:
        raise Exception("No directory path provided")

    if not directory_path.is_dir():
        # A path was supplied, so report what is actually wrong with it.
        raise Exception(
            "Path {} is not an existing directory".format(str(directory_path))
        )

    shutil.rmtree(directory_path)


def recreate_directory(directory_path: Optional[Path]) -> None:
    """Create ``directory_path`` as a fresh, empty directory.

    An existing directory at that location is removed first (with all of its
    content); the parent directory must already exist.

    :raises Exception: if no path is given, or if the path exists but is not
        a directory.
    """
    if not directory_path:
        raise Exception("No directory path provided")

    if directory_path.is_dir():
        remove_directory(directory_path)
    elif directory_path.exists():
        raise Exception(
            "Path {} does exist and it is not a directory".format(str(directory_path))
        )

    directory_path.mkdir()


def read_file(file_path: Path, mode: Optional[str] = None) -> Any:
    """Read a whole file as a string or as bytes.

    :param file_path: file to read.
    :param mode: optional ``open()`` mode. Binary modes (containing ``"b"``)
        return ``bytes``; any other mode, or ``None``, returns a ``str``
        decoded as UTF-8.
    :return: the file content, or an empty ``bytes``/``str`` (matching the
        requested mode) when ``file_path`` is not an existing regular file.
    """
    open_mode = mode or "r"
    is_binary = "b" in open_mode

    if not file_path.is_file():
        return b"" if is_binary else ""

    if is_binary:
        # Binary modes must not be given an encoding.
        with open(file_path, open_mode) as file:  # pylint: disable=unspecified-encoding
            return file.read()

    # Decode as UTF-8 also for explicit text modes such as "r", instead of
    # silently falling back to the platform's locale encoding.
    with open(file_path, open_mode, encoding="utf-8") as file:
        return file.read()


def read_file_as_string(file_path: Path) -> str:
    """Return the text content of ``file_path`` as read by ``read_file``."""
    content = read_file(file_path)
    return str(content)


def read_file_as_bytearray(file_path: Path) -> bytearray:
    """Return the raw bytes of ``file_path`` as a mutable ``bytearray``."""
    raw_content = read_file(file_path, mode="rb")
    return bytearray(raw_content)


def valid_for_filename(value: str, replacement: str = "") -> str:
    """Make ``value`` usable as a file name.

    Every character other than an ASCII letter, digit, underscore or dot is
    substituted with ``replacement`` (i.e. removed by default).
    """
    disallowed_chars = re.compile(r"[^\w.]", flags=re.ASCII)
    return disallowed_chars.sub(replacement, value)