aboutsummaryrefslogtreecommitdiff
path: root/driver_library/python/test/testdata/download.py
blob: 18aa9afd8d395105e09f9eea7b35e9455ccbaab8 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
#
# SPDX-FileCopyrightText: Copyright 2021-2022 Arm Limited and/or its affiliates <open-source-office@arm.com>
# SPDX-License-Identifier: Apache-2.0
#
import os
from pathlib import Path
from typing import List
from urllib.request import urlopen
"""
Downloads the resources needed by the tests from the Arm public model zoo.
Run this script before executing tests.
"""


# Base URL of the 'models' directory of the ARM-software/ML-zoo repository,
# addressed through a fixed revision hash.
PMZ_URL = 'https://github.com/ARM-software/ML-zoo/raw/9f506fe52b39df545f0e6c5ff9223f671bc5ae00/models'

# One dictionary per test model. Each maps a resource kind to its download URL:
#   'model' - the .tflite network file,
#   'ifm'   - a .npy file from the model's testing_input directory,
#   'ofm'   - a .npy file from the model's testing_output directory.
test_resources = [
    {
        'model': PMZ_URL + '/visual_wake_words/micronet_vww2/tflite_int8/vww2_50_50_INT8.tflite',
        'ifm': PMZ_URL + '/visual_wake_words/micronet_vww2/tflite_int8/testing_input/input/0.npy',
        'ofm': PMZ_URL + '/visual_wake_words/micronet_vww2/tflite_int8/testing_output/Identity/0.npy',
    },
]


def download(path: str, url: str, chunk_size: int = 64 * 1024):
    """
    Downloads the content found at `url` and stores it in the file `path`.

    The payload is streamed in chunks into a temporary '<path>.part' file that
    is moved onto `path` only once the transfer has finished. A download that
    fails half-way therefore never leaves a truncated file at `path`.

    Args:
        path: Destination file. It is replaced if it already exists.
        url: Location to fetch; passed unchanged to urllib.request.urlopen.
        chunk_size: Maximum number of bytes read from the response at a time.

    Raises:
        urllib.error.URLError: If the URL cannot be opened.
        OSError: If the destination cannot be written or replaced.
    """
    partial_path = '{}.part'.format(path)

    # Announce the download before the connection is opened, so the message is
    # visible even while urlopen is still waiting for the server.
    print("Downloading {} ...".format(url))
    try:
        with urlopen(url) as response, open(partial_path, 'wb') as file:
            # read() returns b'' at end of stream, which stops the iteration.
            for chunk in iter(lambda: response.read(chunk_size), b''):
                file.write(chunk)
        # Both handles are closed here; publish the complete file in one step.
        os.replace(partial_path, path)
    finally:
        # After a successful os.replace the partial file is gone; it only
        # remains when something above raised, in which case it is discarded.
        if os.path.exists(partial_path):
            os.remove(partial_path)
    print("Finished downloading {}.".format(url))


def download_test_resources(test_res_entries: List[dict], where_to: str):
    """
    Fetches the model, ifm and ofm files of every entry into `where_to`.

    The directory is created when it does not exist yet. Each entry is stored
    under the same three file names (model.tflite, model_ifm.npy and
    model_ofm.npy), so files of a later entry replace those of an earlier one.

    Args:
        test_res_entries: Dictionaries holding the URLs under the keys
            'model', 'ifm' and 'ofm'.
        where_to: Directory receiving the downloaded files.
    """
    # (local file name, key of the URL inside an entry), in download order.
    targets = (
        ('model.tflite', 'model'),
        ('model_ifm.npy', 'ifm'),
        ('model_ofm.npy', 'ofm'),
    )

    os.makedirs(where_to, exist_ok=True)

    for entry in test_res_entries:
        for file_name, url_key in targets:
            download(os.path.join(where_to, file_name), entry[url_key])


def main():
    """
    Downloads all test resources into the 'shared' directory located next to
    this script.
    """
    script_dir = Path(__file__).parent.absolute()
    destination = os.path.join(str(script_dir), 'shared')
    download_test_resources(test_resources, destination)


# Start the download only when this file is executed as a script; importing it
# as a module has no side effects.
if __name__ == '__main__':
    main()