diff options
author | Benjamin Klimczak <benjamin.klimczak@arm.com> | 2024-03-21 17:33:17 +0000 |
---|---|---|
committer | Benjamin Klimczak <benjamin.klimczak@arm.com> | 2024-03-22 10:06:28 +0000 |
commit | c7ee5b783f044d7ff641773aa385840f5ff944cc (patch) | |
tree | 297f308978b00282d8ebd3a1f71e1ae5e678a767 /src/mlia/utils/download.py | |
parent | 508281df31dc3c18f2e007f4dd505160342a681a (diff) | |
download | mlia-c7ee5b783f044d7ff641773aa385840f5ff944cc.tar.gz |
refactor: Backend dependencies and more
- Add backend dependencies: One backend can now depend on another
backend.
- Re-factor 'DownloadArtifact':
- Rename 'DownloadArtifact' to 'DownloadConfig'
- Remove attributes 'name' and 'version' not relevant for downloads
- Add helper properties:
- 'filename' parses the URL to extract the file name from the end
- 'headers' calls the function to generate an HTTP header for the
download
- Add OutputLogger helper class
- Re-factor handling of backend configurations in the target profiles.
Change-Id: Ifda6cf12c375d0c1747d7e4130a0370d22c3d33a
Signed-off-by: Benjamin Klimczak <benjamin.klimczak@arm.com>
Diffstat (limited to 'src/mlia/utils/download.py')
-rw-r--r-- | src/mlia/utils/download.py | 62 |
1 file changed, 32 insertions, 30 deletions
diff --git a/src/mlia/utils/download.py b/src/mlia/utils/download.py index e00be28..2b06fed 100644 --- a/src/mlia/utils/download.py +++ b/src/mlia/utils/download.py @@ -1,10 +1,11 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Utils for files downloading.""" from __future__ import annotations from dataclasses import dataclass from pathlib import Path +from typing import Callable from typing import Iterable import requests @@ -40,20 +41,46 @@ def download_progress( yield chunk +@dataclass +class DownloadConfig: + """Parameters to download an artifact.""" + + url: str + sha256_hash: str + header_gen_fn: Callable[[], dict[str, str]] | None = None + + @property + def filename(self) -> str: + """Get the filename from the URL.""" + return self.url.rsplit("/", 1)[-1] + + @property + def headers(self) -> dict[str, str]: + """Get the headers using the header_gen_fn.""" + return self.header_gen_fn() if self.header_gen_fn else {} + + def download( - url: str, dest: Path, + cfg: DownloadConfig, show_progress: bool = False, label: str | None = None, chunk_size: int = 8192, timeout: int = 30, ) -> None: """Download the file.""" - with requests.get(url, stream=True, timeout=timeout) as resp: + if dest.exists(): + raise FileExistsError(f"{dest} already exists.") + + with requests.get( + cfg.url, stream=True, timeout=timeout, headers=cfg.headers + ) as resp: resp.raise_for_status() content_chunks = resp.iter_content(chunk_size=chunk_size) if show_progress: + if not label: + label = f"Downloading to {dest}." 
content_length = parse_int(resp.headers.get("Content-Length")) content_chunks = download_progress(content_chunks, content_length, label) @@ -61,30 +88,5 @@ def download( for chunk in content_chunks: file.write(chunk) - -@dataclass -class DownloadArtifact: - """Download artifact attributes.""" - - name: str - url: str - filename: str - version: str - sha256_hash: str - - def download_to(self, dest_dir: Path, show_progress: bool = True) -> Path: - """Download artifact into destination directory.""" - if (dest := dest_dir / self.filename).exists(): - raise ValueError(f"{dest} already exists") - - download( - self.url, - dest, - show_progress=show_progress, - label=f"Downloading {self.name} ver. {self.version}", - ) - - if sha256(dest) != self.sha256_hash: - raise ValueError("Digests do not match") - - return dest + if cfg.sha256_hash and sha256(dest) != cfg.sha256_hash: + raise ValueError("Hashes do not match.") |