1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
|
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for download functionality."""
from __future__ import annotations

import re
from contextlib import ExitStack as does_not_raise
from pathlib import Path
from typing import Any
from typing import Iterable
from unittest.mock import MagicMock
from unittest.mock import PropertyMock

import pytest
import requests

from mlia.utils.download import download
from mlia.utils.download import DownloadArtifact
def response_mock(
    content_length: str | None, content_chunks: Iterable[bytes]
) -> MagicMock:
    """Build a fake ``requests.Response`` for the download tests.

    The fake can be used as its own context manager, exposes
    ``content_length`` under the ``Content-Length`` header and returns
    ``content_chunks`` from ``iter_content``.
    """
    fake_headers = {"Content-Length": content_length}
    fake_response = MagicMock(spec=requests.Response)
    # PropertyMock has to be attached to the mock's type, not the instance,
    # for attribute reads of ``headers`` to return the prepared dict.
    type(fake_response).headers = PropertyMock(return_value=fake_headers)
    fake_response.iter_content.return_value = content_chunks
    # ``with requests.get(...) as resp`` should bind the same fake object.
    fake_response.__enter__.return_value = fake_response
    return fake_response
@pytest.mark.parametrize("show_progress", [True, False])
@pytest.mark.parametrize(
    "content_length, content_chunks, label",
    [
        [
            "5",
            [bytes(range(5))],
            "Downloading artifact",
        ],
        [
            "10",
            [bytes(range(5)), bytes(range(5))],
            None,
        ],
        [
            None,
            [bytes(range(5))],
            "Downloading no size",
        ],
        [
            "abc",
            [bytes(range(5))],
            "Downloading wrong size",
        ],
    ],
)
def test_download(
    show_progress: bool,
    tmp_path: Path,
    monkeypatch: pytest.MonkeyPatch,
    content_length: str | None,
    content_chunks: Iterable[bytes],
    label: str | None,
) -> None:
    """Test function download.

    ``requests.get`` is replaced by a fake response, and the test checks that
    ``download`` writes exactly the streamed chunks to ``dest``. The cases
    cover a numeric, a missing (``None``) and a non-numeric (``"abc"``)
    Content-Length header, with and without a label and progress display.
    """
    # Materialise the chunks once. They are consumed by the fake response
    # during the download and read again below to build the expected bytes;
    # a one-shot iterator would otherwise make the expectation b"".
    chunks = list(content_chunks)
    monkeypatch.setattr(
        "mlia.utils.download.requests.get",
        MagicMock(return_value=response_mock(content_length, chunks)),
    )
    dest = tmp_path / "sample.bin"
    download("some_url", dest, show_progress=show_progress, label=label)
    assert dest.is_file()
    assert dest.read_bytes() == b"".join(chunks)
@pytest.mark.parametrize(
    "content_length, content_chunks, sha256_hash, expected_error",
    [
        (
            "10",
            [bytes(range(10))],
            "1f825aa2f0020ef7cf91dfa30da4668d791c5d4824fc8e41354b89ec05795ab3",
            does_not_raise(),
        ),
        (
            "10",
            [bytes(range(10))],
            "bad_hash",
            pytest.raises(ValueError, match="Digests do not match"),
        ),
    ],
)
def test_download_artifact_download_to(
    monkeypatch: pytest.MonkeyPatch,
    content_length: str | None,
    content_chunks: Iterable[bytes],
    sha256_hash: str,
    expected_error: Any,
    tmp_path: Path,
) -> None:
    """Check DownloadArtifact.download_to with a matching and a bad digest.

    With the correct SHA-256, ``download_to`` returns a ``Path`` named after
    the artifact filename. With a wrong digest it must raise ``ValueError``.
    """
    fake_response = response_mock(content_length, content_chunks)
    monkeypatch.setattr(
        "mlia.utils.download.requests.get",
        MagicMock(return_value=fake_response),
    )

    # ``expected_error`` is either a no-op context (ExitStack) or a
    # ``pytest.raises`` guard. In the failing case the checks after
    # ``download_to`` are skipped because the exception leaves the block.
    with expected_error:
        download_artifact = DownloadArtifact(
            "test_artifact",
            "some_url",
            "artifact_filename",
            "1.0",
            sha256_hash,
        )
        saved_path = download_artifact.download_to(tmp_path)
        assert isinstance(saved_path, Path)
        assert saved_path.name == "artifact_filename"
def test_download_artifact_unable_to_overwrite(
    monkeypatch: pytest.MonkeyPatch, tmp_path: Path
) -> None:
    """Test that download process cannot overwrite file.

    A file named like the artifact already exists in the target directory,
    so ``download_to`` must raise ``ValueError`` and leave that file as it was.
    """
    monkeypatch.setattr(
        "mlia.utils.download.requests.get",
        MagicMock(return_value=response_mock("10", [bytes(range(10))])),
    )
    artifact = DownloadArtifact(
        "test_artifact",
        "some_url",
        "artifact_filename",
        "1.0",
        "sha256_hash",
    )
    existing_file = tmp_path / "artifact_filename"
    existing_file.touch()
    # ``match`` is a regular expression searched in the message. Escape the
    # path so Windows backslashes, dots, etc. are matched literally.
    with pytest.raises(
        ValueError, match=re.escape(f"{existing_file} already exists")
    ):
        artifact.download_to(tmp_path)
    # The pre-existing (empty) file must not have been replaced.
    assert existing_file.read_bytes() == b""
|