# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
# Origin: commit 0efca3cadbad5517a59884576ddb90cfe7ac30f8 ("Add MLIA codebase",
# Diego Russo, 2022-05-30), Change-Id: Id41707559bd721edd114793618d12ccd188d8dbd.
"""Utils for files downloading."""
from dataclasses import dataclass
from pathlib import Path
from typing import Iterable
from typing import List
from typing import Optional

import requests
from rich.progress import BarColumn
from rich.progress import DownloadColumn
from rich.progress import FileSizeColumn
from rich.progress import Progress
from rich.progress import ProgressColumn
from rich.progress import TextColumn

from mlia.utils.filesystem import sha256
from mlia.utils.types import parse_int


def download_progress(
    content_chunks: Iterable[bytes], content_length: Optional[int], label: Optional[str]
) -> Iterable[bytes]:
    """Show progress info while reading content.

    Generator that re-yields every chunk of ``content_chunks`` unchanged,
    advancing a rich progress display by the chunk length first.

    :param content_chunks: chunks of the content being read
    :param content_length: total size in bytes, or None if it is not known
    :param label: task description; "Downloading" is used when empty or None
    :return: the same chunks, in the same order
    """
    columns: List[ProgressColumn] = [TextColumn("{task.description}")]

    if content_length is None:
        # Size is unknown: no bar can be drawn, so only show how much
        # has been received so far.
        total = float("inf")
        columns.append(FileSizeColumn())
    else:
        total = content_length
        columns.extend([BarColumn(), DownloadColumn(binary_units=True)])

    with Progress(*columns) as progress:
        task = progress.add_task(label or "Downloading", total=total)

        for chunk in content_chunks:
            progress.update(task, advance=len(chunk))
            yield chunk


def download(
    url: str,
    dest: Path,
    show_progress: bool = False,
    label: Optional[str] = None,
    chunk_size: int = 8192,
    timeout: Optional[float] = 30.0,
) -> None:
    """Download the file.

    The response body is streamed to ``dest`` in chunks, so the whole
    content is never held in memory.

    :param url: address of the file to fetch
    :param dest: path of the file to write (opened in "wb" mode)
    :param show_progress: display a progress indicator while downloading
    :param label: description shown next to the progress indicator
    :param chunk_size: number of bytes requested per chunk
    :param timeout: seconds to wait for the connection and between reads
        before giving up; None waits forever. This is not a limit on the
        total download time.
    :raises requests.HTTPError: if the server responds with an error status
    :raises requests.Timeout: if the server does not respond in time
    """
    # Without a timeout requests waits indefinitely on a stalled server.
    with requests.get(url, stream=True, timeout=timeout) as resp:
        resp.raise_for_status()
        content_chunks = resp.iter_content(chunk_size=chunk_size)

        if show_progress:
            # NOTE(review): parse_int is presumed to return None when the
            # header is missing or not a number - confirm in mlia.utils.types.
            content_length = parse_int(resp.headers.get("Content-Length"))
            content_chunks = download_progress(content_chunks, content_length, label)

        with open(dest, "wb") as file:
            for chunk in content_chunks:
                file.write(chunk)


@dataclass
class DownloadArtifact:
    """Download artifact attributes."""

    name: str  # used in the progress label
    url: str  # location the artifact is fetched from
    filename: str  # file name used inside the destination directory
    version: str  # used in the progress label
    sha256_hash: str  # expected digest, compared with sha256() of the file

    def download_to(self, dest_dir: Path, show_progress: bool = True) -> Path:
        """Download artifact into destination directory.

        The file is removed again if the download fails or its digest does
        not match, so that a later retry is not blocked by a leftover file.

        :param dest_dir: directory in which the artifact file is created
        :param show_progress: display a progress indicator while downloading
        :return: path to the downloaded and verified file
        :raises ValueError: if the destination file already exists or the
            digest of the downloaded file differs from ``sha256_hash``
        """
        if (dest := dest_dir / self.filename).exists():
            raise ValueError(f"{dest} already exists")

        verified = False
        try:
            download(
                self.url,
                dest,
                show_progress=show_progress,
                label=f"Downloading {self.name} ver. {self.version}",
            )

            if sha256(dest) != self.sha256_hash:
                raise ValueError("Digests do not match")

            verified = True
        finally:
            # The file did not exist before this call, so anything left
            # behind after a failure (including an interrupt) is partial or
            # corrupt and would make every retry fail with "already exists".
            if not verified:
                dest.unlink(missing_ok=True)

        return dest