aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/utils/download.py
blob: 4658738da425b6d422e909b0790fd75a8bd01958 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Utils for files downloading."""
from dataclasses import dataclass
from pathlib import Path
from typing import Iterable
from typing import List
from typing import Optional

import requests
from rich.progress import BarColumn
from rich.progress import DownloadColumn
from rich.progress import FileSizeColumn
from rich.progress import Progress
from rich.progress import ProgressColumn
from rich.progress import TextColumn

from mlia.utils.filesystem import sha256
from mlia.utils.types import parse_int


def download_progress(
    content_chunks: Iterable[bytes], content_length: Optional[int], label: Optional[str]
) -> Iterable[bytes]:
    """Pass byte chunks through unchanged while rendering a rich progress display.

    When ``content_length`` is known, a progress bar plus downloaded/total
    size (binary units) is shown. Otherwise the total is treated as infinite
    and only the size received so far is displayed.

    :param content_chunks: byte chunks to forward to the caller
    :param content_length: expected total number of bytes, or None if unknown
    :param label: task description; "Downloading" is used when empty/None
    :return: generator yielding exactly the chunks from ``content_chunks``
    """
    columns: List[ProgressColumn] = [TextColumn("{task.description}")]
    total: float

    if content_length is not None:
        # Known size: bar + "downloaded / total" column.
        total = content_length
        columns += [BarColumn(), DownloadColumn(binary_units=True)]
    else:
        # Unknown size: no meaningful bar, just report bytes received.
        total = float("inf")
        columns.append(FileSizeColumn())

    description = label or "Downloading"

    # The display is torn down when iteration ends or the generator is closed.
    with Progress(*columns) as progress:
        task_id = progress.add_task(description, total=total)

        for piece in content_chunks:
            progress.update(task_id, advance=len(piece))
            yield piece


def download(
    url: str,
    dest: Path,
    show_progress: bool = False,
    label: Optional[str] = None,
    chunk_size: int = 8192,
    timeout: Optional[float] = 30.0,
) -> None:
    """Download the resource at ``url`` and write it to ``dest``.

    The response is streamed in ``chunk_size`` pieces. If writing the file
    fails part-way (network error, interrupt, disk error), the partially
    written ``dest`` is removed so a later retry does not find a truncated file.

    :param url: location to fetch
    :param dest: path of the file to create/overwrite
    :param show_progress: render a progress display while downloading
    :param label: progress description (only used with ``show_progress``)
    :param chunk_size: number of bytes requested per streamed chunk
    :param timeout: seconds to wait on connect/read before giving up;
        None waits indefinitely (the previous behaviour)
    :raises requests.HTTPError: if the server returns an error status
    """
    # Without a timeout, requests can block forever on an unresponsive server.
    with requests.get(url, stream=True, timeout=timeout) as resp:
        resp.raise_for_status()
        content_chunks: Iterable[bytes] = resp.iter_content(chunk_size=chunk_size)

        if show_progress:
            content_length = parse_int(resp.headers.get("Content-Length"))
            content_chunks = download_progress(content_chunks, content_length, label)

        file_opened = False
        try:
            with open(dest, "wb") as file:
                file_opened = True
                for chunk in content_chunks:
                    file.write(chunk)
        except BaseException:
            # Only clean up a file we actually (re)created/truncated; any
            # previous contents were already discarded by "wb".
            if file_opened:
                Path(dest).unlink(missing_ok=True)
            raise


@dataclass
class DownloadArtifact:
    """Attributes of a downloadable artifact and its expected SHA-256 digest."""

    name: str  # human-readable name, shown in the progress label
    url: str  # location the artifact is downloaded from
    filename: str  # name of the file created inside the destination directory
    version: str  # version string, shown in the progress label
    sha256_hash: str  # expected digest, compared against sha256() of the file

    def download_to(self, dest_dir: Path, show_progress: bool = True) -> Path:
        """Download the artifact into ``dest_dir`` and verify its digest.

        If the download fails or the digest does not match, the file created
        by this call is removed so a subsequent attempt can retry cleanly.
        An already existing file is never modified or deleted.

        :param dest_dir: directory to place the artifact in
        :param show_progress: render a progress display while downloading
        :return: path of the downloaded, verified file
        :raises ValueError: if the target file already exists or the
            digest of the downloaded file does not match ``sha256_hash``
        """
        dest = dest_dir / self.filename
        if dest.exists():
            raise ValueError(f"{dest} already exists")

        try:
            download(
                self.url,
                dest,
                show_progress=show_progress,
                label=f"Downloading {self.name} ver. {self.version}",
            )
            actual_hash = sha256(dest)
        except BaseException:
            # dest did not exist before this call, so anything there is ours.
            dest.unlink(missing_ok=True)
            raise

        if actual_hash != self.sha256_hash:
            # Do not leave an unverified file behind: it would block retries.
            dest.unlink(missing_ok=True)
            raise ValueError(
                f"Digests do not match for {dest}: "
                f"expected {self.sha256_hash}, got {actual_hash}"
            )

        return dest