aboutsummaryrefslogtreecommitdiff
path: root/python/pyarmnn/examples/common/network_executor.py
blob: 72262fc520e6f7d4b211ca4252989a2ba3de5ca8 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
# Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
# SPDX-License-Identifier: MIT

import os
from typing import List, Tuple

import pyarmnn as ann
import numpy as np

class ArmnnNetworkExecutor:
    """Run inference on a model file with the Arm NN runtime.

    The network is parsed, optimized for the requested backends and loaded
    onto a runtime once, in the constructor; `run` can then be called
    repeatedly with new input data.
    """

    def __init__(self, model_file: str, backends: list):
        """
        Creates an inference executor for a given network and a list of backends.

        Args:
            model_file: User-specified model file (only `.tflite` is currently supported).
            backends: List of backend names handed, in order, to the network optimizer.

        Raises:
            FileNotFoundError: If `model_file` is not an existing file.
            ValueError: If the model file type is not supported.
        """
        self.model_file = model_file
        self.backends = backends
        self.network_id, self.runtime, self.input_binding_info, self.output_binding_info = self.create_network()
        # Output tensors are allocated once here and reused by every call to `run`.
        self.output_tensors = ann.make_output_tensors(self.output_binding_info)

    def run(self, input_data_list: list) -> List[np.ndarray]:
        """
        Creates input tensors from input data and executes inference with the loaded network.

        Args:
            input_data_list: List of input arrays, paired positionally with
                `self.input_binding_info`.

        Returns:
            list: Inference results as a list of ndarrays.
        """
        input_tensors = ann.make_input_tensors(self.input_binding_info, input_data_list)
        self.runtime.EnqueueWorkload(self.network_id, input_tensors, self.output_tensors)
        # NOTE(review): `self.output_tensors` is reused across calls, so the returned
        # arrays may share memory with it and be overwritten by the next `run`;
        # copy them if results must outlive a later call — confirm against pyarmnn.
        return ann.workload_tensors_to_ndarray(self.output_tensors)

    def create_network(self):
        """
        Creates a network based on the model file and a list of backends.

        Returns:
            tuple: ``(net_id, runtime, input_binding_info, output_binding_info)``:
                net_id: ID of the network loaded onto the runtime.
                runtime: Runtime context for executing inference.
                input_binding_info: One ``(binding_id, tensor_info)`` entry per input
                    tensor of the model's last subgraph.
                output_binding_info: One ``(binding_id, tensor_info)`` entry per output
                    tensor of the model's last subgraph.

        Raises:
            FileNotFoundError: If `self.model_file` is not an existing file.
            ValueError: If the model file extension is not supported.
        """
        # isfile (not exists) so that a directory path is reported up front.
        if not os.path.isfile(self.model_file):
            raise FileNotFoundError(f'Model file not found for: {self.model_file}')

        _, ext = os.path.splitext(self.model_file)
        # Case-insensitive so that e.g. `model.TFLITE` is accepted as well.
        if ext.lower() == '.tflite':
            parser = ann.ITfLiteParser()
        else:
            raise ValueError("Supplied model file type is not supported. Supported types are [ tflite ]")

        network = parser.CreateNetworkFromBinaryFile(self.model_file)

        # Specify backends to optimize network
        preferred_backends = [ann.BackendId(b) for b in self.backends]

        # Select appropriate device context and optimize the network for that device
        options = ann.CreationOptions()
        runtime = ann.IRuntime(options)
        opt_network, messages = ann.Optimize(network, preferred_backends, runtime.GetDeviceSpec(),
                                             ann.OptimizerOptions())
        print(f'Preferred backends: {self.backends}\n{runtime.GetDeviceSpec()}\n'
              f'Optimization warnings: {messages}')

        # Load the optimized network onto the Runtime device
        net_id, _ = runtime.LoadNetwork(opt_network)

        # Get input and output binding information for the last subgraph of the model
        graph_id = parser.GetSubgraphCount() - 1
        input_binding_info = [parser.GetNetworkInputBindingInfo(graph_id, name)
                              for name in parser.GetSubgraphInputTensorNames(graph_id)]
        output_binding_info = [parser.GetNetworkOutputBindingInfo(graph_id, name)
                               for name in parser.GetSubgraphOutputTensorNames(graph_id)]
        return net_id, runtime, input_binding_info, output_binding_info

    def get_data_type(self, idx: int = 0):
        """
        Get the data type of a network input.

        Args:
            idx: Index of the network input to query (defaults to the first input).

        Returns:
            np.float32, np.uint8 or np.int8 for Arm NN Float32, QAsymmU8 and
            QAsymmS8 inputs respectively, or None for any other data type.
        """
        data_type = self.input_binding_info[idx][1].GetDataType()
        if data_type == ann.DataType_Float32:
            return np.float32
        if data_type == ann.DataType_QAsymmU8:
            return np.uint8
        if data_type == ann.DataType_QAsymmS8:
            return np.int8
        return None

    def get_shape(self, idx: int = 0):
        """
        Get the shape of a network input.

        Args:
            idx: Index of the network input to query (defaults to the first input).

        Returns:
            tuple: The shape of the network input.
        """
        return tuple(self.input_binding_info[idx][1].GetShape())

    def get_input_quantization_scale(self, idx: int = 0):
        """
        Get the quantization scale of a network input.

        Args:
            idx: Index of the network input to query (defaults to the first input).

        Returns:
            The quantization scale of the network input.
        """
        return self.input_binding_info[idx][1].GetQuantizationScale()

    def get_input_quantization_offset(self, idx: int = 0):
        """
        Get the quantization offset of a network input.

        Args:
            idx: Index of the network input to query (defaults to the first input).

        Returns:
            The quantization offset of the network input.
        """
        return self.input_binding_info[idx][1].GetQuantizationOffset()

    def is_output_quantized(self, idx: int = 0):
        """
        Report whether a network output tensor is quantized.

        Args:
            idx: Index of the network output to query (defaults to the first output).

        Returns:
            True if the output is quantized and False otherwise.
        """
        return self.output_binding_info[idx][1].IsQuantized()

    def get_output_quantization_scale(self, idx: int = 0):
        """
        Get the quantization scale of a network output.

        Args:
            idx: Index of the network output to query (defaults to the first output).

        Returns:
            The quantization scale of the network output.
        """
        return self.output_binding_info[idx][1].GetQuantizationScale()

    def get_output_quantization_offset(self, idx: int = 0):
        """
        Get the quantization offset of a network output.

        Args:
            idx: Index of the network output to query (defaults to the first output).

        Returns:
            The quantization offset of the network output.
        """
        return self.output_binding_info[idx][1].GetQuantizationOffset()