about summary refs log tree commit diff
path: root/src/mlia/utils/download.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/mlia/utils/download.py')
-rw-r--r-- src/mlia/utils/download.py 62
1 files changed, 32 insertions, 30 deletions
diff --git a/src/mlia/utils/download.py b/src/mlia/utils/download.py
index e00be28..2b06fed 100644
--- a/src/mlia/utils/download.py
+++ b/src/mlia/utils/download.py
@@ -1,10 +1,11 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Utils for files downloading."""
from __future__ import annotations
from dataclasses import dataclass
from pathlib import Path
+from typing import Callable
from typing import Iterable
import requests
@@ -40,20 +41,46 @@ def download_progress(
yield chunk
+@dataclass
+class DownloadConfig:
+ """Parameters to download an artifact."""
+
+ url: str
+ sha256_hash: str
+ header_gen_fn: Callable[[], dict[str, str]] | None = None
+
+ @property
+ def filename(self) -> str:
+ """Get the filename from the URL."""
+ return self.url.rsplit("/", 1)[-1]
+
+ @property
+ def headers(self) -> dict[str, str]:
+ """Get the headers using the header_gen_fn."""
+ return self.header_gen_fn() if self.header_gen_fn else {}
+
+
def download(
- url: str,
dest: Path,
+ cfg: DownloadConfig,
show_progress: bool = False,
label: str | None = None,
chunk_size: int = 8192,
timeout: int = 30,
) -> None:
"""Download the file."""
- with requests.get(url, stream=True, timeout=timeout) as resp:
+ if dest.exists():
+ raise FileExistsError(f"{dest} already exists.")
+
+ with requests.get(
+ cfg.url, stream=True, timeout=timeout, headers=cfg.headers
+ ) as resp:
resp.raise_for_status()
content_chunks = resp.iter_content(chunk_size=chunk_size)
if show_progress:
+ if not label:
+ label = f"Downloading to {dest}."
content_length = parse_int(resp.headers.get("Content-Length"))
content_chunks = download_progress(content_chunks, content_length, label)
@@ -61,30 +88,5 @@ def download(
for chunk in content_chunks:
file.write(chunk)
-
-@dataclass
-class DownloadArtifact:
- """Download artifact attributes."""
-
- name: str
- url: str
- filename: str
- version: str
- sha256_hash: str
-
- def download_to(self, dest_dir: Path, show_progress: bool = True) -> Path:
- """Download artifact into destination directory."""
- if (dest := dest_dir / self.filename).exists():
- raise ValueError(f"{dest} already exists")
-
- download(
- self.url,
- dest,
- show_progress=show_progress,
- label=f"Downloading {self.name} ver. {self.version}",
- )
-
- if sha256(dest) != self.sha256_hash:
- raise ValueError("Digests do not match")
-
- return dest
+ if cfg.sha256_hash and sha256(dest) != cfg.sha256_hash:
+ raise ValueError("Hashes do not match.")