aboutsummaryrefslogtreecommitdiff
path: root/python/pyarmnn/scripts/download_test_resources.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyarmnn/scripts/download_test_resources.py')
-rw-r--r--python/pyarmnn/scripts/download_test_resources.py44
1 files changed, 44 insertions, 0 deletions
diff --git a/python/pyarmnn/scripts/download_test_resources.py b/python/pyarmnn/scripts/download_test_resources.py
new file mode 100644
index 0000000000..b166ed77be
--- /dev/null
+++ b/python/pyarmnn/scripts/download_test_resources.py
@@ -0,0 +1,44 @@
+"""Downloads and extracts resources for unit tests.
+
+It is mandatory to run this script prior to running unit tests. Resources are stored as a tar.gz or a tar.bz2 archive and
+extracted into the test/testdata/shared folder.
+"""
+
+import tarfile
+import requests
+import os
+import uuid
+
+SCRIPTS_DIR = os.path.dirname(os.path.realpath(__file__))
+EXTRACT_DIR = os.path.join(SCRIPTS_DIR, "..", "test")
+ARCHIVE_URL = "https://snapshots.linaro.org/components/pyarmnn-tests/pyarmnn_testdata_200500_20200415.tar.bz2"
+
+
+def download_resources(url, save_path):
+ # download archive - only support tar.gz or tar.bz2
+ print("Downloading '{}'".format(url))
+ temp_filename = str(uuid.uuid4())
+ if url.endswith(".tar.bz2"):
+ temp_filename += ".tar.bz2"
+ elif url.endswith(".tar.gz"):
+ temp_filename += ".tar.gz"
+ else:
+ raise RuntimeError("Unsupported file.")
+ try:
+ r = requests.get(url, stream=True)
+ except requests.exceptions.RequestException as e:
+ raise RuntimeError("Unable to download file: {}".format(e))
+ file_path = os.path.join(save_path, temp_filename)
+ with open(file_path, 'wb') as f:
+ f.write(r.content)
+
+ # extract and delete temp file
+ with tarfile.open(file_path, "r:bz2" if temp_filename.endswith(".tar.bz2") else "r:gz") as tar:
+ print("Extracting '{}'".format(file_path))
+ tar.extractall(save_path)
+ if os.path.exists(file_path):
+ print("Removing '{}'".format(file_path))
+ os.remove(file_path)
+
+
+download_resources(ARCHIVE_URL, EXTRACT_DIR)