aboutsummaryrefslogtreecommitdiff
path: root/src/backends/backendsCommon/test/QLstmEndToEndTestImpl.cpp
blob: 7c87f358d6096cd26a5e3694562226be0f9021c3 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
//
// Copyright © 2020 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include "QLstmEndToEndTestImpl.hpp"

#include "CommonTestUtils.hpp"
#include "EndToEndTestImpl.hpp"

#include <armnn/INetwork.hpp>
#include <armnn/LstmParams.hpp>

#include <doctest/doctest.h>

namespace
{

// Returns true when the absolute difference between value1 and value2 does
// not exceed the given tolerance. Enabled only for arithmetic types; the
// subtraction is ordered (larger minus smaller) so the helper is also safe
// for unsigned types, where a naive abs(value1 - value2) would wrap.
// Throws armnn::InvalidArgumentException for a negative tolerance.
template<typename T>
typename std::enable_if<std::is_arithmetic<T>::value, bool>::type
IsCloseEnough(T value1, T value2, T tolerance)
{
    if (tolerance < 0)
    {
        throw armnn::InvalidArgumentException("Tolerance cannot be < 0");
    }

    T absDifference;
    if (value1 >= value2)
    {
        absDifference = static_cast<T>(value1 - value2);
    }
    else
    {
        absDifference = static_cast<T>(value2 - value1);
    }

    return !(absDifference > tolerance);
}

} // anonymous namespace

// End-to-end test for the QLSTM (quantized LSTM) layer.
//
// Builds a single-layer QLstm network with CIFG (coupled input-forget gate)
// and layer normalisation enabled, peephole and projection disabled, runs one
// inference step on the given backends and compares the quantized output and
// output-state tensors against pre-computed reference values (tolerance of
// one quantization step).
void QLstmEndToEnd(const std::vector<armnn::BackendId>& backends)
{
    const unsigned int numBatches = 2;
    const unsigned int inputSize  = 5;
    const unsigned int outputSize = 4;
    const unsigned int numUnits   = 4;

    bool cifgEnabled       = true;
    bool peepholeEnabled   = false;
    bool projectionEnabled = false;
    bool layerNormEnabled  = true;

    // Scale/Offset quantization info
    const float inputScale    = 0.0078125f;
    const int32_t inputOffset = 0;

    const int32_t hiddenStateZeroPoint = 0;
    const float hiddenStateScale       = 0.007f;

    // if (!projectionEnabled) outputScale == hiddenStateScale
    const float outputScale    = hiddenStateScale;
    const int32_t outputOffset = hiddenStateZeroPoint;

    const float cellStateScale    = 3.05176e-05f;
    const int32_t cellStateOffset = 0;

    const float weightsScale    = 0.00784314f;
    const int32_t weightsOffset = 0;

    const float layerNormScale    = 3.05182e-05f;
    const int32_t layerNormOffset = 0;

    // Bias scale is derived from the layer norm scale for this configuration.
    const float biasScale    = layerNormScale / 1024;
    const int32_t biasOffset = 0;

    const float inputIntermediateScale  = 0.007059f;
    const float forgetIntermediateScale = 0.007812f;
    const float cellIntermediateScale   = inputIntermediateScale;
    const float outputIntermediateScale = forgetIntermediateScale;

    // A clip value of 0.0f disables clipping.
    const float cellClip       = 0.0f;
    const float projectionClip = 0.0f;

    // Weights and bias tensor info
    const armnn::TensorInfo inputWeightsInfo({outputSize, inputSize},
                                             armnn::DataType::QSymmS8,
                                             weightsScale,
                                             weightsOffset,
                                             true);

    const armnn::TensorInfo recurrentWeightsInfo({outputSize, outputSize},
                                                 armnn::DataType::QSymmS8,
                                                 weightsScale,
                                                 weightsOffset,
                                                 true);

    const armnn::TensorInfo biasInfo({outputSize},
                                     armnn::DataType::Signed32,
                                     biasScale,
                                     biasOffset,
                                     true);

    const armnn::TensorInfo layerNormWeightsInfo({numUnits},
                                                 armnn::DataType::QSymmS16,
                                                 layerNormScale,
                                                 layerNormOffset,
                                                 true);

    // Mandatory params (CIFG enabled, so no input-gate weights are required)
    const std::vector<int8_t> inputToForgetWeightsVector =
            {-77, -13, 38, 25, 115, -64, -25, -51, 38, -102, -51, 38, -64, -51, -77, 38, -51, -77, -64, -64};
    const std::vector<int8_t> inputToCellWeightsTensorVector =
            {-51, -38, -25, -13, -64, 64, -25, -38, -25, -77, 77, -13, -51, -38, -89, 89, -115, -64, 102, 77};
    const std::vector<int8_t> inputToOutputWeightsTensorVector =
            {-102, -51, -25, -115, -13, -89, 38, -38, -102, -25, 77, -25, 51, -89, -38, -64, 13, 64, -77, -51};

    armnn::ConstTensor inputToForgetWeightsTensor(inputWeightsInfo, inputToForgetWeightsVector.data());
    armnn::ConstTensor inputToCellWeightsTensor(inputWeightsInfo, inputToCellWeightsTensorVector.data());
    armnn::ConstTensor inputToOutputWeightsTensor(inputWeightsInfo, inputToOutputWeightsTensorVector.data());

    const std::vector<int8_t> recurrentToForgetWeightsTensorVector =
            {-64, -38, -64, -25, 77, 51, 115, 38, -13, 25, 64, 25, 25, 38, -13, 51};
    const std::vector<int8_t> recurrentToCellWeightsTensorVector =
            {-38, 25, 13, -38, 102, -10, -25, 38, 102, -77, -13, 25, 38, -13, 25, 64};
    const std::vector<int8_t> recurrentToOutputWeightsTensorVector =
            {38, -13, 13, -25, -64, -89, -25, -77, -13, -51, -89, -25, 13, 64, 25, -38};

    armnn::ConstTensor recurrentToForgetWeightsTensor(recurrentWeightsInfo,
                                                      recurrentToForgetWeightsTensorVector.data());
    armnn::ConstTensor recurrentToCellWeightsTensor(recurrentWeightsInfo,
                                                    recurrentToCellWeightsTensorVector.data());
    armnn::ConstTensor recurrentToOutputWeightsTensor(recurrentWeightsInfo,
                                                      recurrentToOutputWeightsTensorVector.data());

    const std::vector<int32_t> forgetGateBiasTensorVector = {2147484, -6442451, -4294968, 2147484};
    const std::vector<int32_t> cellBiasTensorVector       = {-1073742, 15461883, 5368709, 1717987};
    const std::vector<int32_t> outputGateBiasTensorVector = {1073742, -214748, 4294968, 2147484};

    armnn::ConstTensor forgetGateBiasTensor(biasInfo, forgetGateBiasTensorVector.data());
    armnn::ConstTensor cellBiasTensor(biasInfo, cellBiasTensorVector.data());
    armnn::ConstTensor outputGateBiasTensor(biasInfo, outputGateBiasTensorVector.data());

    // Layer Norm
    const std::vector<int16_t> forgetLayerNormWeightsVector = {6553, 6553, 13107, 9830};
    const std::vector<int16_t> cellLayerNormWeightsVector   = {22937, 6553, 9830, 26214};
    const std::vector<int16_t> outputLayerNormWeightsVector = {19660, 6553, 6553, 16384};

    armnn::ConstTensor forgetLayerNormWeights(layerNormWeightsInfo, forgetLayerNormWeightsVector.data());
    armnn::ConstTensor cellLayerNormWeights(layerNormWeightsInfo, cellLayerNormWeightsVector.data());
    armnn::ConstTensor outputLayerNormWeights(layerNormWeightsInfo, outputLayerNormWeightsVector.data());

    // Set up params
    armnn::LstmInputParams params;
    params.m_InputToForgetWeights = &inputToForgetWeightsTensor;
    params.m_InputToCellWeights   = &inputToCellWeightsTensor;
    params.m_InputToOutputWeights = &inputToOutputWeightsTensor;

    params.m_RecurrentToForgetWeights = &recurrentToForgetWeightsTensor;
    params.m_RecurrentToCellWeights   = &recurrentToCellWeightsTensor;
    params.m_RecurrentToOutputWeights = &recurrentToOutputWeightsTensor;

    params.m_ForgetGateBias = &forgetGateBiasTensor;
    params.m_CellBias       = &cellBiasTensor;
    params.m_OutputGateBias = &outputGateBiasTensor;

    params.m_ForgetLayerNormWeights = &forgetLayerNormWeights;
    params.m_CellLayerNormWeights   = &cellLayerNormWeights;
    params.m_OutputLayerNormWeights = &outputLayerNormWeights;

    QLstmDescriptor descriptor;
    descriptor.m_CifgEnabled       = cifgEnabled;
    descriptor.m_PeepholeEnabled   = peepholeEnabled;
    descriptor.m_ProjectionEnabled = projectionEnabled;
    descriptor.m_LayerNormEnabled  = layerNormEnabled;

    descriptor.m_CellClip       = cellClip;
    descriptor.m_ProjectionClip = projectionClip;

    descriptor.m_HiddenStateZeroPoint = hiddenStateZeroPoint;
    descriptor.m_HiddenStateScale     = hiddenStateScale;

    descriptor.m_InputIntermediateScale  = inputIntermediateScale;
    descriptor.m_ForgetIntermediateScale = forgetIntermediateScale;
    descriptor.m_CellIntermediateScale   = cellIntermediateScale;
    descriptor.m_OutputIntermediateScale = outputIntermediateScale;

    // Input/Output tensor info
    const armnn::TensorInfo inputInfo({numBatches , inputSize},
                                      armnn::DataType::QAsymmS8,
                                      inputScale,
                                      inputOffset,
                                      true);

    const armnn::TensorInfo cellStateInfo({numBatches , numUnits},
                                          armnn::DataType::QSymmS16,
                                          cellStateScale,
                                          cellStateOffset,
                                          true);

    const armnn::TensorInfo outputStateInfo({numBatches , outputSize},
                                            armnn::DataType::QAsymmS8,
                                            outputScale,
                                            outputOffset,
                                            true);

    // Input tensor data (initial output/cell state is all zeros)
    const std::vector<int8_t> inputVector         = {90, 102, 13, 26, 38, 102, 13, 26, 51, 64};
    const std::vector<int8_t> outputStateInVector = {0, 0, 0, 0, 0, 0, 0, 0};
    const std::vector<int16_t> cellStateInVector  = {0, 0, 0, 0, 0, 0, 0, 0};

    // Expected output tensor data
    // NOTE(review): cellStateOutVector is declared as the expected cell state
    // but is never compared against cellStateOutResult below — confirm whether
    // an int16 tolerance check should be added.
    const std::vector<int8_t> outputStateOutVector = {-15, 21, 14, 20, -15, 15, 5, 27};
    const std::vector<int16_t> cellStateOutVector  = {-11692, 9960, 5491, 8861, -9422, 7726, 2056, 13149};
    const std::vector<int8_t> outputVector         = {-15, 21, 14, 20, -15, 15, 5, 27};

    // Build network
    armnn::INetworkPtr net(armnn::INetwork::Create());

    armnn::IConnectableLayer* const input         = net->AddInputLayer(0);
    armnn::IConnectableLayer* const outputStateIn = net->AddInputLayer(1);
    armnn::IConnectableLayer* const cellStateIn   = net->AddInputLayer(2);

    armnn::IConnectableLayer* const qLstmLayer = net->AddQLstmLayer(descriptor, params, "qLstm");

    armnn::IConnectableLayer* const outputStateOut = net->AddOutputLayer(0);
    armnn::IConnectableLayer* const cellStateOut   = net->AddOutputLayer(1);
    armnn::IConnectableLayer* const output         = net->AddOutputLayer(2);

    // Connect input/output slots
    Connect(input, qLstmLayer, inputInfo, 0, 0);
    Connect(outputStateIn, qLstmLayer, outputStateInfo, 0, 1);
    Connect(cellStateIn, qLstmLayer, cellStateInfo, 0, 2);

    Connect(qLstmLayer, outputStateOut, outputStateInfo, 0, 0);
    Connect(qLstmLayer, cellStateOut, cellStateInfo, 1, 0);
    Connect(qLstmLayer, output, outputStateInfo, 2, 0);

    // Create runtime
    IRuntime::CreationOptions options;
    IRuntimePtr runtime(IRuntime::Create(options));

    // Optimize the network
    IOptimizedNetworkPtr optNet = Optimize(*net, backends, runtime->GetDeviceSpec());
    CHECK(optNet);

    // Loads network into runtime; fail fast with a clear message if the
    // backend rejects the network instead of asserting on garbage outputs.
    NetworkId netId;
    CHECK(runtime->LoadNetwork(netId, std::move(optNet)) == armnn::Status::Success);

    // Push back input tensors
    InputTensors inputTensors;
    inputTensors.reserve(3);

    inputTensors.push_back({0, ConstTensor(runtime->GetInputTensorInfo(netId, 0), inputVector.data())});
    inputTensors.push_back({1, ConstTensor(runtime->GetInputTensorInfo(netId, 1), outputStateInVector.data())});
    inputTensors.push_back({2, ConstTensor(runtime->GetInputTensorInfo(netId, 2), cellStateInVector.data())});

    // Push back output tensors
    OutputTensors outputTensors;
    outputTensors.reserve(3);

    std::vector<int8_t> outputStateOutResult(outputStateOutVector.size());
    std::vector<int16_t> cellStateOutResult(cellStateOutVector.size());
    // Size from outputVector (the expected values this buffer is compared to),
    // not from outputStateOutVector.
    std::vector<int8_t> outputResult(outputVector.size());

    outputTensors.push_back({0, Tensor(runtime->GetOutputTensorInfo(netId, 0), outputStateOutResult.data())});
    outputTensors.push_back({1, Tensor(runtime->GetOutputTensorInfo(netId, 1), cellStateOutResult.data())});
    outputTensors.push_back({2, Tensor(runtime->GetOutputTensorInfo(netId, 2), outputResult.data())});

    // Execute inference
    CHECK(runtime->EnqueueWorkload(netId, inputTensors, outputTensors) == armnn::Status::Success);

    // Compare quantized results with a tolerance of one quantization step.
    constexpr int8_t toleranceInt8 = 1;
    for (size_t i = 0u; i < outputStateOutResult.size(); ++i)
    {
        CHECK(IsCloseEnough(outputStateOutVector[i], outputStateOutResult[i], toleranceInt8));
    }

    for (size_t i = 0u; i < outputResult.size(); ++i)
    {
        CHECK(IsCloseEnough(outputVector[i], outputResult[i], toleranceInt8));
    }
}