aboutsummaryrefslogtreecommitdiff
path: root/tests/test_utils_download.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_utils_download.py')
-rw-r--r--tests/test_utils_download.py35
1 files changed, 20 insertions, 15 deletions
diff --git a/tests/test_utils_download.py b/tests/test_utils_download.py
index 28af74f..7188c62 100644
--- a/tests/test_utils_download.py
+++ b/tests/test_utils_download.py
@@ -1,8 +1,9 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for download functionality."""
from __future__ import annotations
+import hashlib
from contextlib import ExitStack as does_not_raise
from pathlib import Path
from typing import Any
@@ -14,7 +15,7 @@ import pytest
import requests
from mlia.utils.download import download
-from mlia.utils.download import DownloadArtifact
+from mlia.utils.download import DownloadConfig
def response_mock(
@@ -69,9 +70,18 @@ def test_download(
"mlia.utils.download.requests.get",
MagicMock(return_value=response_mock(content_length, content_chunks)),
)
+ hash_obj = hashlib.sha256()
+ for chunk in content_chunks:
+ hash_obj.update(chunk)
+ sha256_hash = hash_obj.hexdigest()
dest = tmp_path / "sample.bin"
- download("some_url", dest, show_progress=show_progress, label=label)
+ download(
+ dest,
+ DownloadConfig("some_url", sha256_hash=sha256_hash),
+ show_progress=show_progress,
+ label=label,
+ )
assert dest.is_file()
assert dest.read_bytes() == bytes(
@@ -92,7 +102,7 @@ def test_download(
"10",
[bytes(range(10))],
"bad_hash",
- pytest.raises(ValueError, match="Digests do not match"),
+ pytest.raises(ValueError, match="Hashes do not match."),
],
],
)
@@ -111,15 +121,13 @@ def test_download_artifact_download_to(
)
with expected_error:
- artifact = DownloadArtifact(
- "test_artifact",
+ cfg = DownloadConfig(
"some_url",
- "artifact_filename",
- "1.0",
sha256_hash,
)
- dest = artifact.download_to(tmp_path)
+ dest = tmp_path / "artifact_filename"
+ download(dest, cfg)
assert isinstance(dest, Path)
assert dest.name == "artifact_filename"
@@ -133,16 +141,13 @@ def test_download_artifact_unable_to_overwrite(
MagicMock(return_value=response_mock("10", [bytes(range(10))])),
)
- artifact = DownloadArtifact(
- "test_artifact",
+ cfg = DownloadConfig(
"some_url",
- "artifact_filename",
- "1.0",
"sha256_hash",
)
existing_file = tmp_path / "artifact_filename"
existing_file.touch()
- with pytest.raises(ValueError, match=f"{existing_file} already exists"):
- artifact.download_to(tmp_path)
+ with pytest.raises(FileExistsError, match=f"{existing_file} already exists."):
+ download(existing_file, cfg)