diff options
Diffstat (limited to 'src/mlia/utils/download.py')
-rw-r--r-- | src/mlia/utils/download.py | 62 |
1 files changed, 32 insertions, 30 deletions
diff --git a/src/mlia/utils/download.py b/src/mlia/utils/download.py index e00be28..2b06fed 100644 --- a/src/mlia/utils/download.py +++ b/src/mlia/utils/download.py @@ -1,10 +1,11 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Utils for files downloading.""" from __future__ import annotations from dataclasses import dataclass from pathlib import Path +from typing import Callable from typing import Iterable import requests @@ -40,20 +41,46 @@ def download_progress( yield chunk +@dataclass +class DownloadConfig: + """Parameters to download an artifact.""" + + url: str + sha256_hash: str + header_gen_fn: Callable[[], dict[str, str]] | None = None + + @property + def filename(self) -> str: + """Get the filename from the URL.""" + return self.url.rsplit("/", 1)[-1] + + @property + def headers(self) -> dict[str, str]: + """Get the headers using the header_gen_fn.""" + return self.header_gen_fn() if self.header_gen_fn else {} + + def download( - url: str, dest: Path, + cfg: DownloadConfig, show_progress: bool = False, label: str | None = None, chunk_size: int = 8192, timeout: int = 30, ) -> None: """Download the file.""" - with requests.get(url, stream=True, timeout=timeout) as resp: + if dest.exists(): + raise FileExistsError(f"{dest} already exists.") + + with requests.get( + cfg.url, stream=True, timeout=timeout, headers=cfg.headers + ) as resp: resp.raise_for_status() content_chunks = resp.iter_content(chunk_size=chunk_size) if show_progress: + if not label: + label = f"Downloading to {dest}." content_length = parse_int(resp.headers.get("Content-Length")) content_chunks = download_progress(content_chunks, content_length, label) @@ -61,30 +88,5 @@ def download( for chunk in content_chunks: file.write(chunk) - -@dataclass -class DownloadArtifact: - """Download artifact attributes.""" - - name: str - url: str - filename: str - version: str - sha256_hash: str - - def download_to(self, dest_dir: Path, show_progress: bool = True) -> Path: - """Download artifact into destination directory.""" - if (dest := dest_dir / self.filename).exists(): - raise ValueError(f"{dest} already exists") - - download( - self.url, - dest, - show_progress=show_progress, - label=f"Downloading {self.name} ver. {self.version}", - ) - - if sha256(dest) != self.sha256_hash: - raise ValueError("Digests do not match") - - return dest + if cfg.sha256_hash and sha256(dest) != cfg.sha256_hash: + raise ValueError("Hashes do not match.") |