aboutsummaryrefslogtreecommitdiff
path: root/src/backends/reference/workloads/RefLstmWorkload.cpp
blob: 1ff6f50ed53bdedabe13fe6f043d45d41d4a2672 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include "RefLstmWorkload.hpp"
#include "Activation.hpp"
#include "Encoders.hpp"
#include "Decoders.hpp"
#include "Lstm.hpp"
#include "LstmUtils.hpp"
#include "RefWorkloadUtils.hpp"

namespace armnn
{

RefLstmWorkload::RefLstmWorkload(const LstmQueueDescriptor &descriptor, const WorkloadInfo &info)
    : BaseWorkload<LstmQueueDescriptor>(descriptor, info)
    , m_InputToInputWeightsTensor     (AssignScopedTensorHandle(descriptor.m_InputToInputWeights))
    , m_InputToForgetWeightsTensor    (AssignScopedTensorHandle(descriptor.m_InputToForgetWeights))
    , m_InputToCellWeightsTensor      (AssignScopedTensorHandle(descriptor.m_InputToCellWeights))
    , m_InputToOutputWeightsTensor    (AssignScopedTensorHandle(descriptor.m_InputToOutputWeights))
    , m_RecurrentToInputWeightsTensor (AssignScopedTensorHandle(descriptor.m_RecurrentToInputWeights))
    , m_RecurrentToForgetWeightsTensor(AssignScopedTensorHandle(descriptor.m_RecurrentToForgetWeights))
    , m_RecurrentToCellWeightsTensor  (AssignScopedTensorHandle(descriptor.m_RecurrentToCellWeights))
    , m_RecurrentToOutputWeightsTensor(AssignScopedTensorHandle(descriptor.m_RecurrentToOutputWeights))
    , m_CellToInputWeightsTensor      (AssignScopedTensorHandle(descriptor.m_CellToInputWeights))
    , m_CellToForgetWeightsTensor     (AssignScopedTensorHandle(descriptor.m_CellToForgetWeights))
    , m_CellToOutputWeightsTensor     (AssignScopedTensorHandle(descriptor.m_CellToOutputWeights))
    , m_InputGateBiasTensor           (AssignScopedTensorHandle(descriptor.m_InputGateBias))
    , m_ForgetGateBiasTensor          (AssignScopedTensorHandle(descriptor.m_ForgetGateBias))
    , m_CellBiasTensor                (AssignScopedTensorHandle(descriptor.m_CellBias))
    , m_OutputGateBiasTensor          (AssignScopedTensorHandle(descriptor.m_OutputGateBias))
    , m_ProjectionWeightsTensor       (AssignScopedTensorHandle(descriptor.m_ProjectionWeights))
    , m_ProjectionBiasTensor          (AssignScopedTensorHandle(descriptor.m_ProjectionBias))
    , m_InputLayerNormWeights         (AssignScopedTensorHandle(descriptor.m_InputLayerNormWeights))
    , m_ForgetLayerNormWeights        (AssignScopedTensorHandle(descriptor.m_ForgetLayerNormWeights))
    , m_CellLayerNormWeights          (AssignScopedTensorHandle(descriptor.m_CellLayerNormWeights))
    , m_OutputLayerNormWeights        (AssignScopedTensorHandle(descriptor.m_OutputLayerNormWeights))
{}

/// Synchronous entry point: runs the LSTM on the tensor handles bound in m_Data.
void RefLstmWorkload::Execute() const
{
    const auto& boundInputs  = m_Data.m_Inputs;
    const auto& boundOutputs = m_Data.m_Outputs;
    Execute(boundInputs, boundOutputs);
}

/// Asynchronous entry point: runs the LSTM on the tensor handles supplied by
/// @p workingMemDescriptor instead of the ones bound in m_Data.
void RefLstmWorkload::ExecuteAsync(WorkingMemDescriptor& workingMemDescriptor)
{
    const auto& asyncInputs  = workingMemDescriptor.m_Inputs;
    const auto& asyncOutputs = workingMemDescriptor.m_Outputs;
    Execute(asyncInputs, asyncOutputs);
}

/// Runs one LSTM evaluation on the given tensor handles.
///
/// This is a port of the LSTM::Eval() method in the Android code base
/// (android/frameworks/ml/nn/common/operations/LSTM.cpp). This function only wires up
/// encoders/decoders for every tensor; the arithmetic is performed by LstmImpl().
///
/// @param inputs  inputs[0] = input data, inputs[1] = output state in, inputs[2] = cell state in.
/// @param outputs outputs[0] = scratch buffer, outputs[1] = output state out,
///                outputs[2] = cell state out, outputs[3] = output.
void RefLstmWorkload::Execute(std::vector<ITensorHandle*> inputs, std::vector<ITensorHandle*> outputs) const
{
    // Each buffer is wrapped with its OWN TensorInfo so that its data type and, for quantized
    // tensors, its scale/offset are honoured. The state tensors are not guaranteed to share
    // quantization parameters with inputs[0] or with the scratch buffer (outputs[0]).
    const TensorInfo& inputInfo          = GetTensorInfo(inputs[0]);
    const TensorInfo& outputStateInInfo  = GetTensorInfo(inputs[1]);
    const TensorInfo& cellStateInInfo    = GetTensorInfo(inputs[2]);
    const TensorInfo& scratchInfo        = GetTensorInfo(outputs[0]);
    const TensorInfo& outputStateOutInfo = GetTensorInfo(outputs[1]);
    const TensorInfo& cellStateOutInfo   = GetTensorInfo(outputs[2]);
    const TensorInfo& outputInfo         = GetTensorInfo(outputs[3]);

    const TensorShape& inputShape = inputInfo.GetShape();

    std::unique_ptr<Encoder<float>> outputStateOut = MakeEncoder<float>(outputStateOutInfo, outputs[1]->Map());
    std::unique_ptr<Encoder<float>> cellStateOut   = MakeEncoder<float>(cellStateOutInfo, outputs[2]->Map());
    std::unique_ptr<Encoder<float>> output         = MakeEncoder<float>(outputInfo, outputs[3]->Map());

    // Decoders over the same buffers as cellStateOut / output. Each one must use the same
    // TensorInfo as its matching encoder so that values round-trip consistently.
    std::unique_ptr<Decoder<float>> cellStateOutDecoder = MakeDecoder<float>(cellStateOutInfo, outputs[2]->Map());
    std::unique_ptr<Decoder<float>> outputDecoder       = MakeDecoder<float>(outputInfo, outputs[3]->Map());

    std::unique_ptr<Decoder<float>> inputData     = MakeDecoder<float>(inputInfo, inputs[0]->Map());
    std::unique_ptr<Decoder<float>> outputStateIn = MakeDecoder<float>(outputStateInInfo, inputs[1]->Map());
    std::unique_ptr<Decoder<float>> cellStateIn   = MakeDecoder<float>(cellStateInInfo, inputs[2]->Map());

    const uint32_t nBatch = inputShape[0];
    const uint32_t nCell  = m_InputToOutputWeightsTensor->GetShape()[0];

    const bool useCifg      = m_Data.m_Parameters.m_CifgEnabled;
    const bool usePeephole  = m_Data.m_Parameters.m_PeepholeEnabled;
    const bool useLayerNorm = m_Data.m_Parameters.m_LayerNormEnabled;

    // Index the scratch buffer pointers into the global scratch buffer (outputs[0]).
    // It is split into consecutive regions of nCell * nBatch elements, one per gate:
    // with CIFG there is no input gate (cell, forget, output), otherwise (input, cell, forget, output).
    std::unique_ptr<Encoder<float>> inputGateScratch  = MakeEncoder<float>(scratchInfo, outputs[0]->Map());
    std::unique_ptr<Encoder<float>> cellScratch       = MakeEncoder<float>(scratchInfo, outputs[0]->Map());
    std::unique_ptr<Encoder<float>> forgetGateScratch = MakeEncoder<float>(scratchInfo, outputs[0]->Map());
    std::unique_ptr<Encoder<float>> outputGateScratch = MakeEncoder<float>(scratchInfo, outputs[0]->Map());

    std::unique_ptr<Decoder<float>> inputGateScratchDecoder  = MakeDecoder<float>(scratchInfo, outputs[0]->Map());
    std::unique_ptr<Decoder<float>> cellScratchDecoder       = MakeDecoder<float>(scratchInfo, outputs[0]->Map());
    std::unique_ptr<Decoder<float>> forgetGateScratchDecoder = MakeDecoder<float>(scratchInfo, outputs[0]->Map());
    std::unique_ptr<Decoder<float>> outputGateScratchDecoder = MakeDecoder<float>(scratchInfo, outputs[0]->Map());

    if (useCifg)
    {
        *cellScratch       += (0 * nCell * nBatch);
        *forgetGateScratch += (1 * nCell * nBatch);
        *outputGateScratch += (2 * nCell * nBatch);

        *cellScratchDecoder       += (0 * nCell * nBatch);
        *forgetGateScratchDecoder += (1 * nCell * nBatch);
        *outputGateScratchDecoder += (2 * nCell * nBatch);
    }
    else
    {
        *inputGateScratch  += (0 * nCell * nBatch);
        *cellScratch       += (1 * nCell * nBatch);
        *forgetGateScratch += (2 * nCell * nBatch);
        *outputGateScratch += (3 * nCell * nBatch);

        *inputGateScratchDecoder  += (0 * nCell * nBatch);
        *cellScratchDecoder       += (1 * nCell * nBatch);
        *forgetGateScratchDecoder += (2 * nCell * nBatch);
        *outputGateScratchDecoder += (3 * nCell * nBatch);
    }

    // Mandatory weights and biases are always decoded. The optional ones below start out empty
    // and are only created when the corresponding feature (CIFG off, peephole, projection,
    // layer norm) is in use.
    std::unique_ptr<Decoder<float>> inputToInputWeightsTensor;
    std::unique_ptr<Decoder<float>> inputToForgetWeightsTensor = MakeDecoder<float>(
        m_InputToForgetWeightsTensor->GetTensorInfo(), m_InputToForgetWeightsTensor->GetConstTensor<void>());
    std::unique_ptr<Decoder<float>> inputToCellWeightsTensor = MakeDecoder<float>(
        m_InputToCellWeightsTensor->GetTensorInfo(), m_InputToCellWeightsTensor->GetConstTensor<void>());
    std::unique_ptr<Decoder<float>> inputToOutputWeightsTensor = MakeDecoder<float>(
        m_InputToOutputWeightsTensor->GetTensorInfo(), m_InputToOutputWeightsTensor->GetConstTensor<void>());

    std::unique_ptr<Decoder<float>> recurrentToInputWeightsTensor;
    std::unique_ptr<Decoder<float>> recurrentToForgetWeightsTensor = MakeDecoder<float>(
        m_RecurrentToForgetWeightsTensor->GetTensorInfo(), m_RecurrentToForgetWeightsTensor->GetConstTensor<void>());
    std::unique_ptr<Decoder<float>> recurrentToCellWeightsTensor = MakeDecoder<float>(
        m_RecurrentToCellWeightsTensor->GetTensorInfo(), m_RecurrentToCellWeightsTensor->GetConstTensor<void>());
    std::unique_ptr<Decoder<float>> recurrentToOutputWeightsTensor = MakeDecoder<float>(
        m_RecurrentToOutputWeightsTensor->GetTensorInfo(), m_RecurrentToOutputWeightsTensor->GetConstTensor<void>());

    std::unique_ptr<Decoder<float>> inputGateBiasTensor;
    std::unique_ptr<Decoder<float>> forgetGateBiasTensor = MakeDecoder<float>(
        m_ForgetGateBiasTensor->GetTensorInfo(), m_ForgetGateBiasTensor->GetConstTensor<void>());
    std::unique_ptr<Decoder<float>> cellBiasTensor = MakeDecoder<float>(
        m_CellBiasTensor->GetTensorInfo(), m_CellBiasTensor->GetConstTensor<void>());
    std::unique_ptr<Decoder<float>> outputGateBiasTensor = MakeDecoder<float>(
        m_OutputGateBiasTensor->GetTensorInfo(), m_OutputGateBiasTensor->GetConstTensor<void>());

    std::unique_ptr<Decoder<float>> cellToInputWeightsTensor;
    std::unique_ptr<Decoder<float>> cellToForgetWeightsTensor;
    std::unique_ptr<Decoder<float>> cellToOutputWeightsTensor;

    std::unique_ptr<Decoder<float>> projectionWeightsTensor;
    std::unique_ptr<Decoder<float>> projectionBiasTensor;

    std::unique_ptr<Decoder<float>> inputLayerNormWeights;
    std::unique_ptr<Decoder<float>> forgetLayerNormWeights;
    std::unique_ptr<Decoder<float>> cellLayerNormWeights;
    std::unique_ptr<Decoder<float>> outputLayerNormWeights;

    const TensorShape& inputToOutputWeightsShape     = m_InputToOutputWeightsTensor->GetShape();
    const TensorShape& recurrentToOutputWeightsShape = m_RecurrentToOutputWeightsTensor->GetShape();

    if (useLayerNorm)
    {
        // With CIFG there is no input gate, hence no input layer-norm weights.
        if (!useCifg)
        {
            inputLayerNormWeights = MakeDecoder<float>(
                m_InputLayerNormWeights->GetTensorInfo(), m_InputLayerNormWeights->GetConstTensor<void>());
        }
        forgetLayerNormWeights = MakeDecoder<float>(
            m_ForgetLayerNormWeights->GetTensorInfo(), m_ForgetLayerNormWeights->GetConstTensor<void>());
        cellLayerNormWeights = MakeDecoder<float>(
            m_CellLayerNormWeights->GetTensorInfo(), m_CellLayerNormWeights->GetConstTensor<void>());
        outputLayerNormWeights = MakeDecoder<float>(
            m_OutputLayerNormWeights->GetTensorInfo(), m_OutputLayerNormWeights->GetConstTensor<void>());
    }

    if (!useCifg)
    {
        inputToInputWeightsTensor = MakeDecoder<float>(
            m_InputToInputWeightsTensor->GetTensorInfo(), m_InputToInputWeightsTensor->GetConstTensor<void>());
        inputGateBiasTensor = MakeDecoder<float>(
            m_InputGateBiasTensor->GetTensorInfo(), m_InputGateBiasTensor->GetConstTensor<void>());
        recurrentToInputWeightsTensor = MakeDecoder<float>(
            m_RecurrentToInputWeightsTensor->GetTensorInfo(), m_RecurrentToInputWeightsTensor->GetConstTensor<void>());
    }

    if (usePeephole)
    {
        cellToForgetWeightsTensor = MakeDecoder<float>(
            m_CellToForgetWeightsTensor->GetTensorInfo(), m_CellToForgetWeightsTensor->GetConstTensor<void>());
        cellToOutputWeightsTensor = MakeDecoder<float>(
            m_CellToOutputWeightsTensor->GetTensorInfo(), m_CellToOutputWeightsTensor->GetConstTensor<void>());
    }

    if (!useCifg && usePeephole)
    {
        cellToInputWeightsTensor = MakeDecoder<float>(
            m_CellToInputWeightsTensor->GetTensorInfo(), m_CellToInputWeightsTensor->GetConstTensor<void>());
    }

    if (m_Data.m_Parameters.m_ProjectionEnabled)
    {
        projectionWeightsTensor = MakeDecoder<float>(
            m_ProjectionWeightsTensor->GetTensorInfo(), m_ProjectionWeightsTensor->GetConstTensor<void>());
        // The projection bias is optional even when projection is enabled.
        if (m_ProjectionBiasTensor)
        {
            projectionBiasTensor = MakeDecoder<float>(
                m_ProjectionBiasTensor->GetTensorInfo(), m_ProjectionBiasTensor->GetConstTensor<void>());
        }
    }

    // LstmImpl's 'outputInfo' argument has always been the scratch buffer's info (outputs[0]);
    // it is passed unchanged here.
    LstmImpl(m_Data.m_Parameters,
             inputInfo,
             scratchInfo,
             inputToOutputWeightsShape,
             recurrentToOutputWeightsShape,
             inputData,
             outputStateIn,
             cellStateIn,
             outputStateOut,
             cellStateOut,
             output,
             cellStateOutDecoder,
             outputDecoder,
             inputToInputWeightsTensor,
             inputToForgetWeightsTensor,
             inputToCellWeightsTensor,
             inputToOutputWeightsTensor,
             recurrentToInputWeightsTensor,
             recurrentToForgetWeightsTensor,
             recurrentToCellWeightsTensor,
             recurrentToOutputWeightsTensor,
             cellToInputWeightsTensor,
             cellToForgetWeightsTensor,
             cellToOutputWeightsTensor,
             inputGateBiasTensor,
             forgetGateBiasTensor,
             cellBiasTensor,
             outputGateBiasTensor,
             projectionWeightsTensor,
             projectionBiasTensor,
             inputLayerNormWeights,
             forgetLayerNormWeights,
             cellLayerNormWeights,
             outputLayerNormWeights,
             inputGateScratch,
             cellScratch,
             forgetGateScratch,
             outputGateScratch,
             inputGateScratchDecoder,
             cellScratchDecoder,
             forgetGateScratchDecoder,
             outputGateScratchDecoder,
             m_LayerNormEpsilon);
}

} //namespace armnn