From a91ee307d920b2acc90360278c466433caacaecc Mon Sep 17 00:00:00 2001 From: Benjamin Klimczak Date: Fri, 8 Mar 2024 16:48:45 +0000 Subject: fix: Relax filtering during archive installation - Relax the filtering when unpacking an archive - Add unit tests for the filtering Resolves: MLIA-1042 Change-Id: I8acd6a1596bef1c624a8fc67cdfbac961e0b179d --- src/mlia/backend/install.py | 91 +++++++++++++++++++++++++------------------ tests/test_backend_install.py | 82 +++++++++++++++++++++++++++++++++++++- 2 files changed, 134 insertions(+), 39 deletions(-) diff --git a/src/mlia/backend/install.py b/src/mlia/backend/install.py index f405511..0ced9f6 100644 --- a/src/mlia/backend/install.py +++ b/src/mlia/backend/install.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2024, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Module for installation process.""" from __future__ import annotations @@ -166,58 +166,73 @@ class BackendInstallation(Installation): backend_info.settings, ) + @staticmethod + def _filter_tar_members( + members: Iterable[tarfile.TarInfo], dst_dir: Path + ) -> Iterable[tarfile.TarInfo]: + """ + Make sure we only handle safe files from the tar file. + + To avoid traversal attacks we only allow files that are + - relative paths, i.e. no absolute file paths + - not including directory traversal sequences '..' + """ + + def check_rel(path: Path) -> None: + if path.is_absolute(): + raise ValueError("Path is absolute, but must be relative.") + + def check_in_dir(path: Path) -> None: + abs_path = (dst_dir / path).resolve() + abs_path.relative_to(dst_dir) + + for member in members: + try: + path = Path(member.path) + check_rel(path) + check_in_dir(path) + + if member.islnk() or member.issym(): + # Make sure we are only linking within the + # archive. + lnk = Path(member.linkname) + check_rel(lnk) + check_in_dir(lnk) + + yield member + except ValueError as ex: + logger.warning( + "File '%s' ignored while extracting: %s", + member.path, + ex, + ) + def _download_and_install( self, download_artifact: DownloadArtifact, eula_agrement: bool ) -> None: """Download and install the backend.""" with temp_directory() as tmpdir: try: - downloaded_to = download_artifact.download_to(tmpdir) + dest = download_artifact.download_to(tmpdir) except Exception as err: raise RuntimeError("Unable to download backend artifact.") from err with working_directory(tmpdir / "dist", create_dir=True) as dist_dir: - with tarfile.open(downloaded_to) as archive: - - def get_filtered_members( - members: Iterable[tarfile.TarInfo], - ) -> Iterable[tarfile.TarInfo]: - """ - Make sure we only handle safe files from the tar file. - - To avoid traversal attacks we only allow files that are - - regular files, i.e. not a symlink etc. - - relative paths, i.e. no absolute file paths - - not including directory traversal sequences '..' - """ - for member in members: - try: - if not (member.isfile() or member.isdir()): - raise ValueError("Path is not a regular file.") - path = Path(member.path) - if path.is_absolute(): - raise ValueError( - "Path is absolute, but must be relative." - ) - abs_path = (dist_dir / path).resolve() - abs_path.relative_to(dist_dir) - yield member - except ValueError as ex: - logger.warning( - "File '%s' ignored while extracting from %s: %s", - member.path, - downloaded_to, - ex, - ) - + with tarfile.open(dest) as archive: # Filter files from the tarfile to avoid traversal attacks. # Note: bandit is still putting out a low severity / # low confidence warning despite the check - # From Python 3.8.17 on there is a built-in feature to fix + # From Python 3.9.17 on there is a built-in feature to fix # this using the new argument filter="data", see - # https://docs.python.org/3.8/library/tarfile.html#tarfile.TarFile.extractall + # https://docs.python.org/3.9/library/tarfile.html#tarfile.TarFile.extractall + logger.debug( + "Extracting downloaded artifact %s to %s.", dest, dist_dir + ) archive.extractall( # nosec - dist_dir, members=get_filtered_members(archive.getmembers()) + dist_dir, + members=self._filter_tar_members( + archive.getmembers(), dist_dir + ), ) backend_path = dist_dir diff --git a/tests/test_backend_install.py b/tests/test_backend_install.py index dacb1aa..963766e 100644 --- a/tests/test_backend_install.py +++ b/tests/test_backend_install.py @@ -1,10 +1,12 @@ -# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2024, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for common management functionality.""" from __future__ import annotations import tarfile +import tempfile from pathlib import Path +from typing import Callable from unittest.mock import ANY from unittest.mock import MagicMock @@ -203,3 +205,81 @@ def test_backend_installation_uninstall(backend_repo: MagicMock) -> None: installation.uninstall() backend_repo.remove_backend.assert_called_with("sample_backend") + + +def _gen_rel_file(dir_path: Path) -> Path: + file_path = dir_path / "test.txt" + if not file_path.exists(): + file_path.touch() + return file_path + + +def _gen_abs_file(dir_path: Path) -> Path: + return _gen_rel_file(dir_path).resolve() + + +def _gen_rel_sym(dir_path: Path) -> Path: + file_path = _gen_rel_file(dir_path) + lnk_path = dir_path / "symlink-rel" + lnk_path.symlink_to(file_path.relative_to(dir_path)) + return lnk_path + + +def _gen_abs_sym(dir_path: Path) -> Path: + file_path = _gen_abs_file(dir_path) + lnk_path = dir_path / Path("symlink-abs") + lnk_path.symlink_to(file_path) + return lnk_path + + +def _gen_rel_lnk(dir_path: Path) -> Path: + file_path = _gen_rel_file(dir_path) + lnk_path = dir_path / "hardlink-rel" + lnk_path.symlink_to(file_path.relative_to(dir_path)) + return lnk_path + + +def _gen_abs_lnk(dir_path: Path) -> Path: + file_path = _gen_abs_file(dir_path) + lnk_path = dir_path / Path("hardlink-abs") + lnk_path.symlink_to(file_path) + return lnk_path + + +@pytest.mark.parametrize( + ("gen_file_func", "num_members_in", "num_members_out"), + ( + (_gen_rel_file, 1, 1), + (_gen_rel_sym, 1, 1), + (_gen_abs_sym, 1, 0), + (_gen_rel_lnk, 1, 1), + (_gen_abs_lnk, 1, 0), + ), +) +def test_filter_tar_members( + gen_file_func: Callable[[Path], Path], + num_members_in: int, + num_members_out: int, + tmp_path: Path, +) -> None: + """Test function BackendInstallation._filter_tar_members().""" + + def create_tar(file_to_add: Path, archive_path: Path) -> None: + with tarfile.open(archive_path, "w") as archive: + archive.add(file_to_add, arcname=file_to_add, recursive=False) + + with tempfile.TemporaryDirectory() as tmp_dir_2: + with pytest.raises(ValueError): + Path(tmp_dir_2).relative_to(tmp_path) + + archive_path = Path(tmp_dir_2) / "test.tar.gz" + file_path = gen_file_func(tmp_path) + create_tar(file_path, archive_path) + with tarfile.open(archive_path) as archive: + orig_members = list(archive.getmembers()) + assert len(orig_members) == num_members_in + filtered_members = list( + # pylint: disable=protected-access + BackendInstallation._filter_tar_members(orig_members, tmp_path) + ) + assert len(filtered_members) == num_members_out -- cgit v1.2.1