aboutsummaryrefslogtreecommitdiff
path: root/python/pyarmnn/examples/image_classification/example_utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyarmnn/examples/image_classification/example_utils.py')
-rw-r--r--python/pyarmnn/examples/image_classification/example_utils.py20
1 files changed, 10 insertions, 10 deletions
diff --git a/python/pyarmnn/examples/image_classification/example_utils.py b/python/pyarmnn/examples/image_classification/example_utils.py
index 090ce2f5b3..f0ba91e981 100644
--- a/python/pyarmnn/examples/image_classification/example_utils.py
+++ b/python/pyarmnn/examples/image_classification/example_utils.py
@@ -38,7 +38,8 @@ def run_inference(runtime, net_id, images, labels, input_binding_info, output_bi
runtime.EnqueueWorkload(net_id, input_tensors, output_tensors)
# Process output
- out_tensor = ann.workload_tensors_to_ndarray(output_tensors)[0]
+ # output tensor has a shape (1, 1001)
+ out_tensor = ann.workload_tensors_to_ndarray(output_tensors)[0][0]
results = np.argsort(out_tensor)[::-1]
print_top_n(5, results, labels, out_tensor)
@@ -121,7 +122,7 @@ def __create_network(model_file: str, backends: list, parser=None):
return net_id, parser, runtime
-def create_tflite_network(model_file: str, backends: list = ['CpuAcc', 'CpuRef']):
+def create_tflite_network(model_file: str, backends: list = ('CpuAcc', 'CpuRef')):
"""Creates a network from a tflite model file.
Args:
@@ -140,7 +141,7 @@ def create_tflite_network(model_file: str, backends: list = ['CpuAcc', 'CpuRef']
return net_id, graph_id, parser, runtime
-def create_onnx_network(model_file: str, backends: list = ['CpuAcc', 'CpuRef']):
+def create_onnx_network(model_file: str, backends: list = ('CpuAcc', 'CpuRef')):
"""Creates a network from an onnx model file.
Args:
@@ -181,7 +182,7 @@ def preprocess_default(img: Image, width: int, height: int, data_type, scale: fl
def load_images(image_files: list, input_width: int, input_height: int, data_type=np.uint8,
- scale: float = 1., mean: list = [0., 0., 0.], stddev: list = [1., 1., 1.],
+ scale: float = 1., mean: list = (0., 0., 0.), stddev: list = (1., 1., 1.),
preprocess_fn=preprocess_default):
"""Loads images, resizes and performs any additional preprocessing to run inference.
@@ -218,7 +219,6 @@ def load_labels(label_file: str):
with open(label_file, 'r') as f:
labels = [l.rstrip() for l in f]
return labels
- return None
def print_top_n(N: int, results: list, labels: list, prob: list):
@@ -299,10 +299,10 @@ def get_model_and_labels(model_dir: str, model: str, labels: str, archive: str =
download_url = [download_url]
for dl in download_url:
archive = download_file(dl)
- if dl.lower().endswith(".zip"):
- unzip_file(archive)
+ if dl.lower().endswith(".zip"):
+ unzip_file(archive)
except RuntimeError:
- print("Unable to download file ({}).".format(archive_url))
+ print("Unable to download file ({}).".format(download_url))
if not os.path.exists(labels) or not os.path.exists(model):
raise RuntimeError("Unable to provide model and labels.")
@@ -310,7 +310,7 @@ def get_model_and_labels(model_dir: str, model: str, labels: str, archive: str =
return model, labels
-def list_images(folder: str = None, formats: list = ['.jpg', '.jpeg']):
+def list_images(folder: str = None, formats: list = ('.jpg', '.jpeg')):
"""Lists files of a certain format in a folder.
Args:
@@ -338,7 +338,7 @@ def get_images(image_dir: str, image_url: str = DEFAULT_IMAGE_URL):
"""Gets image.
Args:
- image (str): Image filename
+ image_dir (str): Image filename
image_url (str): Image url
Returns: