aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBenjamin Klimczak <benjamin.klimczak@arm.com>2024-03-08 16:48:45 +0000
committerBenjamin Klimczak <benjamin.klimczak@arm.com>2024-03-13 11:58:33 +0000
commita91ee307d920b2acc90360278c466433caacaecc (patch)
tree505ac13f565fb9289b9d56686cf3c63d46b64274
parent2ba39623502551ec073fbc67b59e0458af084c7e (diff)
downloadmlia-a91ee307d920b2acc90360278c466433caacaecc.tar.gz
fix: Relax filtering during archive installation
- Relax the filtering when unpacking an archive - Add unit tests for the filtering Resolves: MLIA-1042 Change-Id: I8acd6a1596bef1c624a8fc67cdfbac961e0b179d
-rw-r--r--src/mlia/backend/install.py91
-rw-r--r--tests/test_backend_install.py82
2 files changed, 134 insertions, 39 deletions
diff --git a/src/mlia/backend/install.py b/src/mlia/backend/install.py
index f405511..0ced9f6 100644
--- a/src/mlia/backend/install.py
+++ b/src/mlia/backend/install.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2024, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Module for installation process."""
from __future__ import annotations
@@ -166,58 +166,73 @@ class BackendInstallation(Installation):
backend_info.settings,
)
+ @staticmethod
+ def _filter_tar_members(
+ members: Iterable[tarfile.TarInfo], dst_dir: Path
+ ) -> Iterable[tarfile.TarInfo]:
+ """
+ Make sure we only handle safe files from the tar file.
+
+ To avoid traversal attacks we only allow files that are
+ - relative paths, i.e. no absolute file paths
+ - not including directory traversal sequences '..'
+ """
+
+ def check_rel(path: Path) -> None:
+ if path.is_absolute():
+ raise ValueError("Path is absolute, but must be relative.")
+
+ def check_in_dir(path: Path) -> None:
+ abs_path = (dst_dir / path).resolve()
+ abs_path.relative_to(dst_dir)
+
+ for member in members:
+ try:
+ path = Path(member.path)
+ check_rel(path)
+ check_in_dir(path)
+
+ if member.islnk() or member.issym():
+ # Make sure we are only linking within the
+ # archive.
+ lnk = Path(member.linkname)
+ check_rel(lnk)
+ check_in_dir(lnk)
+
+ yield member
+ except ValueError as ex:
+ logger.warning(
+ "File '%s' ignored while extracting: %s",
+ member.path,
+ ex,
+ )
+
def _download_and_install(
self, download_artifact: DownloadArtifact, eula_agrement: bool
) -> None:
"""Download and install the backend."""
with temp_directory() as tmpdir:
try:
- downloaded_to = download_artifact.download_to(tmpdir)
+ dest = download_artifact.download_to(tmpdir)
except Exception as err:
raise RuntimeError("Unable to download backend artifact.") from err
with working_directory(tmpdir / "dist", create_dir=True) as dist_dir:
- with tarfile.open(downloaded_to) as archive:
-
- def get_filtered_members(
- members: Iterable[tarfile.TarInfo],
- ) -> Iterable[tarfile.TarInfo]:
- """
- Make sure we only handle safe files from the tar file.
-
- To avoid traversal attacks we only allow files that are
- - regular files, i.e. not a symlink etc.
- - relative paths, i.e. no absolute file paths
- - not including directory traversal sequences '..'
- """
- for member in members:
- try:
- if not (member.isfile() or member.isdir()):
- raise ValueError("Path is not a regular file.")
- path = Path(member.path)
- if path.is_absolute():
- raise ValueError(
- "Path is absolute, but must be relative."
- )
- abs_path = (dist_dir / path).resolve()
- abs_path.relative_to(dist_dir)
- yield member
- except ValueError as ex:
- logger.warning(
- "File '%s' ignored while extracting from %s: %s",
- member.path,
- downloaded_to,
- ex,
- )
-
+ with tarfile.open(dest) as archive:
# Filter files from the tarfile to avoid traversal attacks.
# Note: bandit is still putting out a low severity /
# low confidence warning despite the check
- # From Python 3.8.17 on there is a built-in feature to fix
+ # From Python 3.9.17 on there is a built-in feature to fix
# this using the new argument filter="data", see
- # https://docs.python.org/3.8/library/tarfile.html#tarfile.TarFile.extractall
+ # https://docs.python.org/3.9/library/tarfile.html#tarfile.TarFile.extractall
+ logger.debug(
+ "Extracting downloaded artifact %s to %s.", dest, dist_dir
+ )
archive.extractall( # nosec
- dist_dir, members=get_filtered_members(archive.getmembers())
+ dist_dir,
+ members=self._filter_tar_members(
+ archive.getmembers(), dist_dir
+ ),
)
backend_path = dist_dir
diff --git a/tests/test_backend_install.py b/tests/test_backend_install.py
index dacb1aa..963766e 100644
--- a/tests/test_backend_install.py
+++ b/tests/test_backend_install.py
@@ -1,10 +1,12 @@
-# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2024, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for common management functionality."""
from __future__ import annotations
import tarfile
+import tempfile
from pathlib import Path
+from typing import Callable
from unittest.mock import ANY
from unittest.mock import MagicMock
@@ -203,3 +205,81 @@ def test_backend_installation_uninstall(backend_repo: MagicMock) -> None:
installation.uninstall()
backend_repo.remove_backend.assert_called_with("sample_backend")
+
+
+def _gen_rel_file(dir_path: Path) -> Path:
+ file_path = dir_path / "test.txt"
+ if not file_path.exists():
+ file_path.touch()
+ return file_path
+
+
+def _gen_abs_file(dir_path: Path) -> Path:
+ return _gen_rel_file(dir_path).resolve()
+
+
+def _gen_rel_sym(dir_path: Path) -> Path:
+ file_path = _gen_rel_file(dir_path)
+ lnk_path = dir_path / "symlink-rel"
+ lnk_path.symlink_to(file_path.relative_to(dir_path))
+ return lnk_path
+
+
+def _gen_abs_sym(dir_path: Path) -> Path:
+ file_path = _gen_abs_file(dir_path)
+ lnk_path = dir_path / Path("symlink-abs")
+ lnk_path.symlink_to(file_path)
+ return lnk_path
+
+
+def _gen_rel_lnk(dir_path: Path) -> Path:
+ file_path = _gen_rel_file(dir_path)
+ lnk_path = dir_path / "hardlink-rel"
+ lnk_path.symlink_to(file_path.relative_to(dir_path))
+ return lnk_path
+
+
+def _gen_abs_lnk(dir_path: Path) -> Path:
+ file_path = _gen_abs_file(dir_path)
+ lnk_path = dir_path / Path("hardlink-abs")
+ lnk_path.symlink_to(file_path)
+ return lnk_path
+
+
+@pytest.mark.parametrize(
+ ("gen_file_func", "num_members_in", "num_members_out"),
+ (
+ (_gen_rel_file, 1, 1),
+ (_gen_rel_sym, 1, 1),
+ (_gen_abs_sym, 1, 0),
+ (_gen_rel_lnk, 1, 1),
+ (_gen_abs_lnk, 1, 0),
+ ),
+)
+def test_filter_tar_members(
+ gen_file_func: Callable[[Path], Path],
+ num_members_in: int,
+ num_members_out: int,
+ tmp_path: Path,
+) -> None:
+ """Test function BackendInstallation._filter_tar_members()."""
+
+ def create_tar(file_to_add: Path, archive_path: Path) -> None:
+ with tarfile.open(archive_path, "w") as archive:
+ archive.add(file_to_add, arcname=file_to_add, recursive=False)
+
+ with tempfile.TemporaryDirectory() as tmp_dir_2:
+ with pytest.raises(ValueError):
+ Path(tmp_dir_2).relative_to(tmp_path)
+
+ archive_path = Path(tmp_dir_2) / "test.tar.gz"
+ file_path = gen_file_func(tmp_path)
+ create_tar(file_path, archive_path)
+ with tarfile.open(archive_path) as archive:
+ orig_members = list(archive.getmembers())
+ assert len(orig_members) == num_members_in
+ filtered_members = list(
+ # pylint: disable=protected-access
+ BackendInstallation._filter_tar_members(orig_members, tmp_path)
+ )
+ assert len(filtered_members) == num_members_out