aboutsummaryrefslogtreecommitdiff
path: root/tests/mlia/test_utils_download.py
blob: 4f8e2dcbddc660cf156165ef990c9c923b206d17 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for download functionality."""
from contextlib import ExitStack as does_not_raise
from pathlib import Path
from typing import Any
from typing import Iterable
from typing import Optional
from unittest.mock import MagicMock
from unittest.mock import PropertyMock

import pytest
import requests

from mlia.utils.download import download
from mlia.utils.download import DownloadArtifact


def response_mock(
    content_length: Optional[str], content_chunks: Iterable[bytes]
) -> MagicMock:
    """Build a fake ``requests.Response`` for the download tests.

    :param content_length: value reported under the ``Content-Length``
        header (may be ``None`` to simulate a missing/unknown size).
    :param content_chunks: byte chunks handed back by ``iter_content``.
    :return: a ``MagicMock`` spec'd on ``requests.Response`` that also acts
        as its own context manager.
    """
    fake_response = MagicMock(spec=requests.Response)
    fake_response.iter_content.return_value = content_chunks

    # Patch ``headers`` on the mock's own type so it is served as a property.
    reported_headers = {"Content-Length": content_length}
    type(fake_response).headers = PropertyMock(return_value=reported_headers)

    # Entering the mock in a ``with`` statement yields the very same mock.
    fake_response.__enter__.return_value = fake_response

    return fake_response


@pytest.mark.parametrize("show_progress", [True, False])
@pytest.mark.parametrize(
    "content_length, content_chunks, label",
    [
        [
            "5",
            [bytes(range(5))],
            "Downloading artifact",
        ],
        [
            "10",
            [bytes(range(5)), bytes(range(5))],
            None,
        ],
        [
            None,
            [bytes(range(5))],
            "Downloading no size",
        ],
        [
            "abc",
            [bytes(range(5))],
            "Downloading wrong size",
        ],
    ],
)
def test_download(
    show_progress: bool,
    tmp_path: Path,
    monkeypatch: pytest.MonkeyPatch,
    content_length: Optional[str],
    content_chunks: Iterable[bytes],
    label: Optional[str],
) -> None:
    """Test function download.

    ``requests.get`` is patched to return a fake response. The test then
    checks that ``download`` writes all received chunks, in order, to the
    destination file. It is parametrized over progress display on/off, an
    optional label, and valid, missing (``None``) and non-numeric
    ``Content-Length`` header values.
    """
    monkeypatch.setattr(
        "mlia.utils.download.requests.get",
        MagicMock(return_value=response_mock(content_length, content_chunks)),
    )

    dest = tmp_path / "sample.bin"
    download("some_url", dest, show_progress=show_progress, label=label)

    assert dest.is_file()
    # The written file must be exactly the concatenation of the chunks.
    assert dest.read_bytes() == b"".join(content_chunks)


@pytest.mark.parametrize(
    "content_length, content_chunks, sha256_hash, expected_error",
    [
        [
            "10",
            [bytes(range(10))],
            "1f825aa2f0020ef7cf91dfa30da4668d791c5d4824fc8e41354b89ec05795ab3",
            does_not_raise(),
        ],
        [
            "10",
            [bytes(range(10))],
            "bad_hash",
            pytest.raises(ValueError, match="Digests do not match"),
        ],
    ],
)
def test_download_artifact_download_to(
    monkeypatch: pytest.MonkeyPatch,
    content_length: Optional[str],
    content_chunks: Iterable[bytes],
    sha256_hash: str,
    expected_error: Any,
    tmp_path: Path,
) -> None:
    """Test artifact downloading.

    ``requests.get`` is patched to serve ``content_chunks``. With a matching
    SHA-256 digest, ``download_to`` must return a ``Path`` named after the
    artifact's filename. With a wrong digest, it must raise ``ValueError``
    ("Digests do not match").

    :param expected_error: context manager wrapping the ``download_to``
        call; either ``does_not_raise()`` or ``pytest.raises(...)``.
    """
    monkeypatch.setattr(
        "mlia.utils.download.requests.get",
        MagicMock(return_value=response_mock(content_length, content_chunks)),
    )

    # Build the artifact outside the expected-error context so that only
    # ``download_to`` can satisfy ``pytest.raises``; a failure in the
    # constructor must surface as a real test error instead.
    artifact = DownloadArtifact(
        "test_artifact",
        "some_url",
        "artifact_filename",
        "1.0",
        sha256_hash,
    )

    with expected_error:
        dest = artifact.download_to(tmp_path)
        assert isinstance(dest, Path)
        assert dest.name == "artifact_filename"


def test_download_artifact_unable_to_overwrite(
    monkeypatch: pytest.MonkeyPatch, tmp_path: Path
) -> None:
    """Test that download process cannot overwrite file.

    A file with the artifact's filename is created in the destination
    directory first. ``download_to`` must then raise ``ValueError`` saying
    the file already exists, rather than replacing it.
    """
    monkeypatch.setattr(
        "mlia.utils.download.requests.get",
        MagicMock(return_value=response_mock("10", [bytes(range(10))])),
    )

    artifact = DownloadArtifact(
        "test_artifact",
        "some_url",
        "artifact_filename",
        "1.0",
        "sha256_hash",
    )

    existing_file = tmp_path / "artifact_filename"
    existing_file.touch()

    # ``pytest.raises(match=...)`` treats its argument as a regular
    # expression. A temp path can contain regex metacharacters (``.``, and
    # ``\`` on Windows) that weaken or break the pattern, so compare the
    # message as a literal substring instead.
    with pytest.raises(ValueError) as exc_info:
        artifact.download_to(tmp_path)

    assert f"{existing_file} already exists" in str(exc_info.value)