aboutsummaryrefslogtreecommitdiff
path: root/src/armnnTfLiteParser/test/FullyConnected.cpp
blob: 2853fe96abdd8f87c7684959d267f92b760a3ee0 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include <boost/test/unit_test.hpp>
#include "ParserFlatbuffersFixture.hpp"
#include "../TfLiteParser.hpp"

#include <string>
#include <iostream>

BOOST_AUTO_TEST_SUITE(TensorflowLiteParser)

/// Test fixture that builds a one-operator TfLite model (as a JSON string) containing a
/// single FULLY_CONNECTED operator, then hands it to the parser via
/// SetupSingleInputSingleOutput("inputTensor", "outputTensor").
///
/// Tensors in the generated subgraph:
///   0 - "inputTensor"  UINT8, scale 1.0, zero_point 0, graph input,  buffer 0
///   1 - "outputTensor" UINT8, scale 2.0, zero_point 0, graph output, buffer 1
///   2 - "filterTensor" UINT8, scale 1.0, zero_point 0, constant,     buffer 2 (filterData)
///   3 - "biasTensor"   INT32, scale 1.0, zero_point 0, constant,     buffer 3 (biasData)
///       Tensor 3 is only emitted when BOTH biasShape and biasData are non-empty.
///
/// @param inputShape  JSON array literal for the input tensor shape, e.g. "[ 1, 4, 1, 1 ]".
/// @param outputShape JSON array literal for the output tensor shape.
/// @param filterShape JSON array literal for the filter (weights) tensor shape.
/// @param filterData  JSON array literal used verbatim as the "data" of buffer 2.
/// @param biasShape   Optional JSON array literal for the bias tensor shape.
/// @param biasData    Optional JSON array literal used verbatim as the "data" of buffer 3.
struct FullyConnectedFixture : public ParserFlatbuffersFixture
{
    explicit FullyConnectedFixture(const std::string& inputShape,
                                   const std::string& outputShape,
                                   const std::string& filterShape,
                                   const std::string& filterData,
                                   const std::string& biasShape = "",
                                   const std::string& biasData = "")
    {
        // Operator inputs: [input, filter], plus the bias tensor (index 3) when one is supplied.
        std::string inputTensors = "[ 0, 2 ]";
        std::string biasTensor;
        std::string biasBuffer;
        if (!biasShape.empty() && !biasData.empty())
        {
            inputTensors = "[ 0, 2, 3 ]";
            biasTensor = R"(
                        {
                            "shape": )" + biasShape + R"( ,
                            "type": "INT32",
                            "buffer": 3,
                            "name": "biasTensor",
                            "quantization": {
                                "min": [ 0.0 ],
                                "max": [ 255.0 ],
                                "scale": [ 1.0 ],
                                "zero_point": [ 0 ],
                            }
                        } )";
            biasBuffer = R"(
                    { "data": )" + biasData + R"(, }, )";
        }
        m_JsonString = R"(
            {
                "version": 3,
                "operator_codes": [ { "builtin_code": "FULLY_CONNECTED" } ],
                "subgraphs": [ {
                    "tensors": [
                        {
                            "shape": )" + inputShape + R"(,
                            "type": "UINT8",
                            "buffer": 0,
                            "name": "inputTensor",
                            "quantization": {
                                "min": [ 0.0 ],
                                "max": [ 255.0 ],
                                "scale": [ 1.0 ],
                                "zero_point": [ 0 ],
                            }
                        },
                        {
                            "shape": )" + outputShape + R"(,
                            "type": "UINT8",
                            "buffer": 1,
                            "name": "outputTensor",
                            "quantization": {
                                "min": [ 0.0 ],
                                "max": [ 511.0 ],
                                "scale": [ 2.0 ],
                                "zero_point": [ 0 ],
                            }
                        },
                        {
                            "shape": )" + filterShape + R"(,
                            "type": "UINT8",
                            "buffer": 2,
                            "name": "filterTensor",
                            "quantization": {
                                "min": [ 0.0 ],
                                "max": [ 255.0 ],
                                "scale": [ 1.0 ],
                                "zero_point": [ 0 ],
                            }
                        }, )" + biasTensor + R"(
                    ],
                    "inputs": [ 0 ],
                    "outputs": [ 1 ],
                    "operators": [
                        {
                            "opcode_index": 0,
                            "inputs": )" + inputTensors + R"(,
                            "outputs": [ 1 ],
                            "builtin_options_type": "FullyConnectedOptions",
                            "builtin_options": {
                                "fused_activation_function": "NONE"
                            },
                            "custom_options_format": "FLEXBUFFERS"
                        }
                    ],
                } ],
                "buffers" : [
                    { },
                    { },
                    { "data": )" + filterData + R"(, }, )"
                       + biasBuffer + R"(
                ]
            }
        )";
        SetupSingleInputSingleOutput("inputTensor", "outputTensor");
    }
};

/// FULLY_CONNECTED model without a bias tensor: the input holds four elements,
/// the filter holds the four weights 2, 3, 4, 5, and the output is a single value.
struct FullyConnectedWithNoBiasFixture : public FullyConnectedFixture
{
    FullyConnectedWithNoBiasFixture()
        : FullyConnectedFixture(/* inputShape  */ "[ 1, 4, 1, 1 ]",
                                /* outputShape */ "[ 1, 1 ]",
                                /* filterShape */ "[ 4, 1 ]",
                                /* filterData  */ "[ 2, 3, 4, 5 ]")
    {
    }
};

BOOST_FIXTURE_TEST_CASE(FullyConnectedWithNoBias, FullyConnectedWithNoBiasFixture)
{
    // Sum of input * weight: 10*2 + 20*3 + 30*4 + 40*5 = 400.
    // The output tensor is quantized with scale 2.0, so the stored value is 400 / 2 = 200.
    RunTest<2, uint8_t>(0,
                        { 10, 20, 30, 40 },
                        { (10 * 2 + 20 * 3 + 30 * 4 + 40 * 5) / 2 });
}

/// FULLY_CONNECTED model with a bias: same input/filter setup as the no-bias case,
/// plus a one-element INT32 bias tensor. The bias buffer lists the four raw bytes
/// of that int32 value (10, 0, 0, 0), i.e. a bias of 10 assuming little-endian storage.
struct FullyConnectedWithBiasFixture : public FullyConnectedFixture
{
    FullyConnectedWithBiasFixture()
        : FullyConnectedFixture(/* inputShape  */ "[ 1, 4, 1, 1 ]",
                                /* outputShape */ "[ 1, 1 ]",
                                /* filterShape */ "[ 4, 1 ]",
                                /* filterData  */ "[ 2, 3, 4, 5 ]",
                                /* biasShape   */ "[ 1 ]",
                                /* biasData    */ "[ 10, 0, 0, 0 ]")
    {
    }
};

BOOST_FIXTURE_TEST_CASE(ParseFullyConnectedWithBias, FullyConnectedWithBiasFixture)
{
    // Sum of input * weight plus bias: 10*2 + 20*3 + 30*4 + 40*5 + 10 = 410.
    // The output tensor is quantized with scale 2.0, so the stored value is 410 / 2 = 205.
    RunTest<2, uint8_t>(0,
                        { 10, 20, 30, 40 },
                        { (10 * 2 + 20 * 3 + 30 * 4 + 40 * 5 + 10) / 2 });
}

BOOST_AUTO_TEST_SUITE_END()