aboutsummaryrefslogtreecommitdiff
path: root/delegate/python/test/utils.py
blob: f3761ec8a1149f8727f45f0c58257de535375e5b (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
# Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
# SPDX-License-Identifier: MIT

import tflite_runtime.interpreter as tflite
import numpy as np
import os


def run_mock_model(delegate, test_data_folder):
    """Run 'mock_model.tflite' once through the given delegate on random input.

    The outputs are not read back; the call only exercises model loading,
    tensor allocation and a single invocation.

    Args:
        delegate: Delegate object handed to the interpreter as its only
            experimental delegate.
        test_data_folder: Directory that contains 'mock_model.tflite'.
    """
    model_path = os.path.join(test_data_folder, 'mock_model.tflite')
    interpreter = tflite.Interpreter(model_path=model_path,
                                     experimental_delegates=[delegate])
    interpreter.allocate_tensors()

    # Only the first input tensor is fed.
    input_details = interpreter.get_input_details()

    # Test model on random input data.
    # Draw integers over the whole uint8 range: casting random_sample()'s
    # [0, 1) floats to uint8 would truncate every element to 0.
    input_shape = tuple(input_details[0]['shape'])
    input_data = np.random.randint(0, 256, size=input_shape, dtype=np.uint8)
    interpreter.set_tensor(input_details[0]['index'], input_data)

    interpreter.invoke()

def run_inference(test_data_folder, model_filename, inputs, delegates=None):
    """Load a model, feed it the given inputs, invoke it and return its outputs.

    Args:
        test_data_folder: Directory that contains the model file.
        model_filename: File name of the model inside test_data_folder.
        inputs: Input values; the value at position i is written to the
            tensor described by the i-th entry of the interpreter's input
            details.
        delegates: Passed to the interpreter as experimental_delegates.

    Returns:
        A list with one tensor per entry of the interpreter's output details,
        in the same order.
    """
    interpreter = tflite.Interpreter(
        model_path=os.path.join(test_data_folder, model_filename),
        experimental_delegates=delegates)
    interpreter.allocate_tensors()

    # Query the tensor descriptions up front, before any data is written.
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()

    # Bind each supplied value to the input tensor at the same position.
    for position, tensor_data in enumerate(inputs):
        interpreter.set_tensor(input_details[position]['index'], tensor_data)

    interpreter.invoke()

    return [interpreter.get_tensor(detail['index']) for detail in output_details]

def compare_outputs(outputs, expected_outputs):
    """Assert that every output array matches its expected counterpart.

    Args:
        outputs: Sequence of NumPy arrays produced by the model.
        expected_outputs: Sequence of NumPy arrays holding the reference
            values, in the same order as outputs.

    Raises:
        AssertionError: If the number of outputs differs, or if any pair of
            arrays differs in shape, dtype or element values.
    """
    assert len(outputs) == len(expected_outputs), 'Incorrect number of outputs'
    for i, (actual, expected) in enumerate(zip(outputs, expected_outputs)):
        assert actual.shape == expected.shape, 'Incorrect output shape on output#{}'.format(i)
        assert actual.dtype == expected.dtype, 'Incorrect output data type on output#{}'.format(i)
        # Compare the contents exactly: `a.all() == b.all()` only compares two
        # booleans, so arrays with different non-zero values would pass.
        assert np.array_equal(actual, expected), 'Incorrect output value on output#{}'.format(i)