diff options
Diffstat (limited to 'python/pyarmnn/examples/image_classification/example_utils.py')
-rw-r--r-- | python/pyarmnn/examples/image_classification/example_utils.py | 20 |
1 files changed, 10 insertions, 10 deletions
diff --git a/python/pyarmnn/examples/image_classification/example_utils.py b/python/pyarmnn/examples/image_classification/example_utils.py index 090ce2f5b3..f0ba91e981 100644 --- a/python/pyarmnn/examples/image_classification/example_utils.py +++ b/python/pyarmnn/examples/image_classification/example_utils.py @@ -38,7 +38,8 @@ def run_inference(runtime, net_id, images, labels, input_binding_info, output_bi runtime.EnqueueWorkload(net_id, input_tensors, output_tensors) # Process output - out_tensor = ann.workload_tensors_to_ndarray(output_tensors)[0] + # output tensor has a shape (1, 1001) + out_tensor = ann.workload_tensors_to_ndarray(output_tensors)[0][0] results = np.argsort(out_tensor)[::-1] print_top_n(5, results, labels, out_tensor) @@ -121,7 +122,7 @@ def __create_network(model_file: str, backends: list, parser=None): return net_id, parser, runtime -def create_tflite_network(model_file: str, backends: list = ['CpuAcc', 'CpuRef']): +def create_tflite_network(model_file: str, backends: list = ('CpuAcc', 'CpuRef')): """Creates a network from a tflite model file. Args: @@ -140,7 +141,7 @@ def create_tflite_network(model_file: str, backends: list = ['CpuAcc', 'CpuRef'] return net_id, graph_id, parser, runtime -def create_onnx_network(model_file: str, backends: list = ['CpuAcc', 'CpuRef']): +def create_onnx_network(model_file: str, backends: list = ('CpuAcc', 'CpuRef')): """Creates a network from an onnx model file. Args: @@ -181,7 +182,7 @@ def preprocess_default(img: Image, width: int, height: int, data_type, scale: fl def load_images(image_files: list, input_width: int, input_height: int, data_type=np.uint8, - scale: float = 1., mean: list = [0., 0., 0.], stddev: list = [1., 1., 1.], + scale: float = 1., mean: list = (0., 0., 0.), stddev: list = (1., 1., 1.), preprocess_fn=preprocess_default): """Loads images, resizes and performs any additional preprocessing to run inference. @@ -218,7 +219,6 @@ def load_labels(label_file: str): with open(label_file, 'r') as f: labels = [l.rstrip() for l in f] return labels - return None def print_top_n(N: int, results: list, labels: list, prob: list): @@ -299,10 +299,10 @@ def get_model_and_labels(model_dir: str, model: str, labels: str, archive: str = download_url = [download_url] for dl in download_url: archive = download_file(dl) - if dl.lower().endswith(".zip"): - unzip_file(archive) + if dl.lower().endswith(".zip"): + unzip_file(archive) except RuntimeError: - print("Unable to download file ({}).".format(archive_url)) + print("Unable to download file ({}).".format(download_url)) if not os.path.exists(labels) or not os.path.exists(model): raise RuntimeError("Unable to provide model and labels.") @@ -310,7 +310,7 @@ def get_model_and_labels(model_dir: str, model: str, labels: str, archive: str = return model, labels -def list_images(folder: str = None, formats: list = ['.jpg', '.jpeg']): +def list_images(folder: str = None, formats: list = ('.jpg', '.jpeg')): """Lists files of a certain format in a folder. Args: @@ -338,7 +338,7 @@ def get_images(image_dir: str, image_url: str = DEFAULT_IMAGE_URL): """Gets image. Args: - image (str): Image filename + image_dir (str): Image filename image_url (str): Image url Returns: |