aboutsummaryrefslogtreecommitdiff
path: root/python/pyarmnn/examples/common/network_executor_tflite.py
blob: 10f5e6e6fbe6c4e69f024a27be53537106eb813e (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
# Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
# SPDX-License-Identifier: MIT

import os
from typing import List, Tuple

import numpy as np
from tflite_runtime import interpreter as tflite

class TFLiteNetworkExecutor:
    """Runs inference on a TensorFlow Lite model through a TFLite delegate (e.g. the Arm NN delegate)."""

    # Class-level default so instances created without __init__ (or by subclasses that
    # do not forward the new keyword) keep the original "optimizations on" behavior.
    enable_optimizations = True

    def __init__(self, model_file: str, backends: list, delegate_path: str, *,
                 enable_optimizations: bool = True):
        """
        Creates an inference executor for a given network and a list of backends.

        Args:
            model_file: User-specified model file (must have a .tflite extension).
            backends: List of backend names; they are comma-joined and passed to the
                delegate as its "backends" option.
            delegate_path: tflite delegate file path (.so).
            enable_optimizations: Whether to request the delegate's "enable-fast-math" and
                "reduce-fp32-to-fp16" options. Optimizations can improve performance in some
                cases but can also degrade it in others, and accuracy might be affected.
                Defaults to True, which matches the previous hard-coded behavior.

        Raises:
            FileNotFoundError: If model_file does not exist.
            ValueError: If model_file does not have a .tflite extension.
        """
        self.model_file = model_file
        self.backends = backends
        self.delegate_path = delegate_path
        self.enable_optimizations = enable_optimizations
        self.interpreter, self.input_details, self.output_details = self.create_network()

    def run(self, input_data_list: list) -> List[np.ndarray]:
        """
        Executes inference for the loaded network.

        Args:
            input_data_list: Input data (e.g. preprocessed frames), one entry per model
                input, in the same order as the interpreter's input details.

        Returns:
            list: Inference results as a list of ndarrays, one per model output, in the
            order of the interpreter's output details.

        Raises:
            ValueError: If the number of inputs differs from the number of model inputs.
        """
        inputs = list(input_data_list)
        if len(inputs) != len(self.input_details):
            # Previously, extra inputs raised a bare IndexError and missing inputs were
            # silently left unset before invoke(); fail fast with a clear message instead.
            raise ValueError(f'Expected {len(self.input_details)} input(s), got {len(inputs)}')

        for detail, input_data in zip(self.input_details, inputs):
            self.interpreter.set_tensor(detail['index'], input_data)
        self.interpreter.invoke()
        return [self.interpreter.get_tensor(detail['index']) for detail in self.output_details]

    def create_network(self):
        """
        Creates a network based on the model file and a list of backends.

        Returns:
            interpreter: A TensorFlow Lite object for executing inference.
            input_details: Contains essential information about the model input.
            output_details: Used to map output tensor and its memory.

        Raises:
            FileNotFoundError: If the model file does not exist.
            ValueError: If the model file does not have a .tflite extension.
        """
        if not os.path.exists(self.model_file):
            raise FileNotFoundError(f'Model file not found for: {self.model_file}')

        _, ext = os.path.splitext(self.model_file)
        # Compare case-insensitively so names such as "model.TFLite" are accepted too.
        if ext.lower() != '.tflite':
            raise ValueError("Supplied model file type is not supported. Supported types are [ tflite ]")

        # Delegate options are passed as strings. Optimizations can improve performance in
        # some cases but can also degrade it in others; accuracy might also be affected.
        optimization_enable = "true" if self.enable_optimizations else "false"
        armnn_delegate = tflite.load_delegate(library=self.delegate_path,
                                              options={"backends": ','.join(self.backends),
                                                       "logging-severity": "info",
                                                       "enable-fast-math": optimization_enable,
                                                       "reduce-fp32-to-fp16": optimization_enable})
        interpreter = tflite.Interpreter(model_path=self.model_file,
                                         experimental_delegates=[armnn_delegate])
        interpreter.allocate_tensors()

        # Get input and output binding information
        input_details = interpreter.get_input_details()
        output_details = interpreter.get_output_details()

        return interpreter, input_details, output_details

    def get_data_type(self):
        """
        Get the input data type of the initiated network.

        Returns:
            The 'dtype' entry of the network's first input, as reported by the
            interpreter's input details.
        """
        return self.input_details[0]['dtype']

    def get_shape(self):
        """
        Get the input shape of the initiated network.

        Returns:
            tuple: The shape of the network's first input.
        """
        return tuple(self.input_details[0]['shape'])