aboutsummaryrefslogtreecommitdiff
path: root/python/pyarmnn/test/test_tensor_conversion.py
blob: bfff200e49431de963eb1f26b8e50ac434318e52 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
# Copyright © 2019 Arm Ltd. All rights reserved.
# SPDX-License-Identifier: MIT
import os

import pytest
import pyarmnn as ann
import numpy as np


@pytest.fixture(scope="function")
def get_tensor_info_input(shared_data_folder):
    """
    Provide input binding information for the sample SSD MobileNet v1 model.

    Parses 'ssd_mobilenetv1.tflite' from ``shared_data_folder`` with a TfLite
    parser and yields a single-element list holding the binding info of the
    'normalized_input_image_tensor' input on subgraph 0.
    """
    tflite_parser = ann.ITfLiteParser()
    model_path = os.path.join(shared_data_folder, 'ssd_mobilenetv1.tflite')
    tflite_parser.CreateNetworkFromBinaryFile(model_path)

    subgraph_id = 0
    binding_info = tflite_parser.GetNetworkInputBindingInfo(subgraph_id, 'normalized_input_image_tensor')

    yield [binding_info]


@pytest.fixture(scope="function")
def get_tensor_info_output(shared_data_folder):
    """
    Sample output tensor information.

    Parses 'ssd_mobilenetv1.tflite' from ``shared_data_folder`` and yields a
    list with the output binding info of every output tensor of subgraph 0,
    in the order returned by ``GetSubgraphOutputTensorNames``.
    """
    parser = ann.ITfLiteParser()
    parser.CreateNetworkFromBinaryFile(os.path.join(shared_data_folder, 'ssd_mobilenetv1.tflite'))
    graph_id = 0

    # Build the list in one comprehension rather than an append loop.
    outputs_binding_info = [parser.GetNetworkOutputBindingInfo(graph_id, output_name)
                            for output_name in parser.GetSubgraphOutputTensorNames(graph_id)]

    yield outputs_binding_info


def test_make_input_tensors(get_tensor_info_input):
    """Check that make_input_tensors wraps uint8 input data in ConstTensors matching the binding info."""
    input_tensor_info = get_tensor_info_input

    # randint's upper bound is exclusive: 256 is required to cover the full uint8 range [0, 255].
    input_data = [np.random.randint(0, 256, size=(1, tensor_info.GetNumElements())).astype(np.uint8)
                  for _, tensor_info in input_tensor_info]

    input_tensors = ann.make_input_tensors(input_tensor_info, input_data)
    assert len(input_tensors) == 1

    for tensor, binding_info in zip(input_tensors, input_tensor_info):
        # ConstTensor is created through a function, so its type cannot be compared directly;
        # compare the class name instead.
        assert type(tensor[1]).__name__ == 'ConstTensor'
        assert str(tensor[1].GetInfo()) == str(binding_info[1])


def test_make_output_tensors(get_tensor_info_output):
    """Check that make_output_tensors creates one ann.Tensor per output binding with matching tensor info."""
    output_binding_info = get_tensor_info_output

    output_tensors = ann.make_output_tensors(output_binding_info)
    # The sample SSD model has four outputs.
    assert len(output_tensors) == 4

    for tensor, binding_info in zip(output_tensors, output_binding_info):
        # Exact-type check: compare type identity with `is`, not `==` (pycodestyle E721).
        assert type(tensor[1]) is ann.Tensor
        assert str(tensor[1].GetInfo()) == str(binding_info[1])


def test_workload_tensors_to_ndarray(get_tensor_info_output):
    """Check that workload_tensors_to_ndarray returns one array per output tensor, each sized like its tensor."""
    output_binding_info = get_tensor_info_output
    output_tensors = ann.make_output_tensors(output_binding_info)

    data = ann.workload_tensors_to_ndarray(output_tensors)

    # Catch missing or extra arrays; zip below would otherwise only compare the common prefix.
    assert len(data) == len(output_tensors)

    for array, output_tensor in zip(data, output_tensors):
        # np.size counts every element, so the check also holds if an array is not 1-D.
        assert np.size(array) == output_tensor[1].GetNumElements()


def test_make_input_tensors_fp16(get_tensor_info_input):
    """Check that make_input_tensors produces Float16 ConstTensors once the binding info is set to Float16."""
    input_tensor_info = get_tensor_info_input
    input_data = []

    for _, tensor_info in input_tensor_info:
        # randint's upper bound is exclusive; 256 matches the uint8 test's [0, 255] value range.
        input_data.append(np.random.randint(0, 256, size=(1, tensor_info.GetNumElements())).astype(np.float16))
        # Switch the binding's TensorInfo to Float16 so make_input_tensors builds fp16 tensors.
        tensor_info.SetDataType(ann.DataType_Float16)

    input_tensors = ann.make_input_tensors(input_tensor_info, input_data)
    assert len(input_tensors) == 1

    fp16_itemsize = np.dtype(np.float16).itemsize  # bytes per float16 element (2)
    for tensor, binding_info in zip(input_tensors, input_tensor_info):
        # ConstTensor is created through a function, so its type cannot be compared directly;
        # compare the class name instead.
        assert type(tensor[1]).__name__ == 'ConstTensor'
        assert str(tensor[1].GetInfo()) == str(binding_info[1])
        assert tensor[1].GetDataType() == ann.DataType_Float16
        # Element count of the sample model's input (presumably 1x300x300x3 -- TODO confirm).
        assert tensor[1].GetNumElements() == 270000
        # Each Float16 element must occupy exactly two bytes.
        assert tensor[1].GetNumBytes() == tensor[1].GetNumElements() * fp16_itemsize