diff options
Diffstat (limited to 'tests/test_utils_download.py')
-rw-r--r-- | tests/test_utils_download.py | 35 |
1 files changed, 20 insertions, 15 deletions
diff --git a/tests/test_utils_download.py b/tests/test_utils_download.py index 28af74f..7188c62 100644 --- a/tests/test_utils_download.py +++ b/tests/test_utils_download.py @@ -1,8 +1,9 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for download functionality.""" from __future__ import annotations +import hashlib from contextlib import ExitStack as does_not_raise from pathlib import Path from typing import Any @@ -14,7 +15,7 @@ import pytest import requests from mlia.utils.download import download -from mlia.utils.download import DownloadArtifact +from mlia.utils.download import DownloadConfig def response_mock( @@ -69,9 +70,18 @@ def test_download( "mlia.utils.download.requests.get", MagicMock(return_value=response_mock(content_length, content_chunks)), ) + hash_obj = hashlib.sha256() + for chunk in content_chunks: + hash_obj.update(chunk) + sha256_hash = hash_obj.hexdigest() dest = tmp_path / "sample.bin" - download("some_url", dest, show_progress=show_progress, label=label) + download( + dest, + DownloadConfig("some_url", sha256_hash=sha256_hash), + show_progress=show_progress, + label=label, + ) assert dest.is_file() assert dest.read_bytes() == bytes( @@ -92,7 +102,7 @@ def test_download( "10", [bytes(range(10))], "bad_hash", - pytest.raises(ValueError, match="Digests do not match"), + pytest.raises(ValueError, match="Hashes do not match."), ], ], ) @@ -111,15 +121,13 @@ def test_download_artifact_download_to( ) with expected_error: - artifact = DownloadArtifact( - "test_artifact", + cfg = DownloadConfig( "some_url", - "artifact_filename", - "1.0", sha256_hash, ) - dest = artifact.download_to(tmp_path) + dest = tmp_path / "artifact_filename" + download(dest, cfg) assert isinstance(dest, Path) assert dest.name == "artifact_filename" @@ -133,16 +141,13 @@ def test_download_artifact_unable_to_overwrite( MagicMock(return_value=response_mock("10", [bytes(range(10))])), ) - artifact = DownloadArtifact( - "test_artifact", + cfg = DownloadConfig( "some_url", - "artifact_filename", - "1.0", "sha256_hash", ) existing_file = tmp_path / "artifact_filename" existing_file.touch() - with pytest.raises(ValueError, match=f"{existing_file} already exists"): - artifact.download_to(tmp_path) + with pytest.raises(FileExistsError, match=f"{existing_file} already exists."): + download(existing_file, cfg) |