ArmNN 21.08 — RefLstmWorkload.cpp
Go to the documentation of this file.
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include "RefLstmWorkload.hpp"
7 #include "Activation.hpp"
8 #include "Encoders.hpp"
9 #include "Decoders.hpp"
10 #include "Lstm.hpp"
11 #include "LstmUtils.hpp"
12 #include "RefWorkloadUtils.hpp"
13 
14 namespace armnn
15 {
16 
18  : BaseWorkload<LstmQueueDescriptor>(descriptor, info)
19  , m_InputToInputWeightsTensor (AssignScopedTensorHandle(descriptor.m_InputToInputWeights))
20  , m_InputToForgetWeightsTensor (AssignScopedTensorHandle(descriptor.m_InputToForgetWeights))
21  , m_InputToCellWeightsTensor (AssignScopedTensorHandle(descriptor.m_InputToCellWeights))
22  , m_InputToOutputWeightsTensor (AssignScopedTensorHandle(descriptor.m_InputToOutputWeights))
23  , m_RecurrentToInputWeightsTensor (AssignScopedTensorHandle(descriptor.m_RecurrentToInputWeights))
24  , m_RecurrentToForgetWeightsTensor(AssignScopedTensorHandle(descriptor.m_RecurrentToForgetWeights))
25  , m_RecurrentToCellWeightsTensor (AssignScopedTensorHandle(descriptor.m_RecurrentToCellWeights))
26  , m_RecurrentToOutputWeightsTensor(AssignScopedTensorHandle(descriptor.m_RecurrentToOutputWeights))
27  , m_CellToInputWeightsTensor (AssignScopedTensorHandle(descriptor.m_CellToInputWeights))
28  , m_CellToForgetWeightsTensor (AssignScopedTensorHandle(descriptor.m_CellToForgetWeights))
29  , m_CellToOutputWeightsTensor (AssignScopedTensorHandle(descriptor.m_CellToOutputWeights))
30  , m_InputGateBiasTensor (AssignScopedTensorHandle(descriptor.m_InputGateBias))
31  , m_ForgetGateBiasTensor (AssignScopedTensorHandle(descriptor.m_ForgetGateBias))
32  , m_CellBiasTensor (AssignScopedTensorHandle(descriptor.m_CellBias))
33  , m_OutputGateBiasTensor (AssignScopedTensorHandle(descriptor.m_OutputGateBias))
34  , m_ProjectionWeightsTensor (AssignScopedTensorHandle(descriptor.m_ProjectionWeights))
35  , m_ProjectionBiasTensor (AssignScopedTensorHandle(descriptor.m_ProjectionBias))
36  , m_InputLayerNormWeights (AssignScopedTensorHandle(descriptor.m_InputLayerNormWeights))
37  , m_ForgetLayerNormWeights (AssignScopedTensorHandle(descriptor.m_ForgetLayerNormWeights))
38  , m_CellLayerNormWeights (AssignScopedTensorHandle(descriptor.m_CellLayerNormWeights))
39  , m_OutputLayerNormWeights (AssignScopedTensorHandle(descriptor.m_OutputLayerNormWeights))
40 {}
41 
43 {
45 }
46 
48 {
49  Execute(workingMemDescriptor.m_Inputs, workingMemDescriptor.m_Outputs);
50 }
51 
52 void RefLstmWorkload::Execute(std::vector<ITensorHandle*> inputs, std::vector<ITensorHandle*> outputs) const
53 {
54  // This is a porting of the LSTM::Eval() method in the Android code base
55  // Refer to: android/frameworks/ml/nn/common/operations/LSTM.cpp
56 
57  const TensorInfo& inputInfo = GetTensorInfo(inputs[0]);
58  const TensorInfo& outputInfo = GetTensorInfo(outputs[0]);
59 
60  const TensorShape& inputShape = inputInfo.GetShape();
61 
62  std::unique_ptr<Encoder<float>> outputStateOut = MakeEncoder<float>(outputInfo, outputs[1]->Map());
63  std::unique_ptr<Encoder<float>> cellStateOut = MakeEncoder<float>(outputInfo, outputs[2]->Map());
64  std::unique_ptr<Encoder<float>> output = MakeEncoder<float>(outputInfo, outputs[3]->Map());
65 
66  std::unique_ptr<Decoder<float>> cellStateOutDecoder = MakeDecoder<float>(outputInfo, outputs[2]->Map());
67  std::unique_ptr<Decoder<float>> outputDecoder = MakeDecoder<float>(outputInfo, outputs[3]->Map());
68 
69  std::unique_ptr<Decoder<float>> inputData = MakeDecoder<float>(inputInfo, inputs[0]->Map());
70  std::unique_ptr<Decoder<float>> outputStateIn = MakeDecoder<float>(inputInfo, inputs[1]->Map());
71  std::unique_ptr<Decoder<float>> cellStateIn = MakeDecoder<float>(inputInfo, inputs[2]->Map());
72 
73  const uint32_t nBatch = inputShape[0];
74  const uint32_t nCell = m_InputToOutputWeightsTensor->GetShape()[0];
75 
76  const bool useCifg = m_Data.m_Parameters.m_CifgEnabled;
77  const bool usePeephole = m_Data.m_Parameters.m_PeepholeEnabled;
78  const bool useLayerNorm = m_Data.m_Parameters.m_LayerNormEnabled;
79 
80  // Index the scratch buffers pointers to the global scratch buffer.
81  std::unique_ptr<Encoder<float>> inputGateScratch = MakeEncoder<float>(outputInfo, outputs[0]->Map());
82  std::unique_ptr<Encoder<float>> cellScratch = MakeEncoder<float>(outputInfo, outputs[0]->Map());
83  std::unique_ptr<Encoder<float>> forgetGateScratch = MakeEncoder<float>(outputInfo, outputs[0]->Map());
84  std::unique_ptr<Encoder<float>> outputGateScratch = MakeEncoder<float>(outputInfo, outputs[0]->Map());
85 
86  std::unique_ptr<Decoder<float>> inputGateScratchDecoder =
87  MakeDecoder<float>(outputInfo, outputs[0]->Map());
88  std::unique_ptr<Decoder<float>> cellScratchDecoder =
89  MakeDecoder<float>(outputInfo, outputs[0]->Map());
90  std::unique_ptr<Decoder<float>> forgetGateScratchDecoder =
91  MakeDecoder<float>(outputInfo, outputs[0]->Map());
92  std::unique_ptr<Decoder<float>> outputGateScratchDecoder =
93  MakeDecoder<float>(outputInfo, outputs[0]->Map());
94 
95  if (useCifg)
96  {
97  *cellScratch += (0 * nCell * nBatch);
98  *forgetGateScratch += (1 * nCell * nBatch);
99  *outputGateScratch += (2 * nCell * nBatch);
100 
101  *cellScratchDecoder += (0 * nCell * nBatch);
102  *forgetGateScratchDecoder += (1 * nCell * nBatch);
103  *outputGateScratchDecoder += (2 * nCell * nBatch);
104  }
105  else
106  {
107  *inputGateScratch += (0 * nCell * nBatch);
108  *cellScratch += (1 * nCell * nBatch);
109  *forgetGateScratch += (2 * nCell * nBatch);
110  *outputGateScratch += (3 * nCell * nBatch);
111 
112  *inputGateScratchDecoder += (0 * nCell * nBatch);
113  *cellScratchDecoder += (1 * nCell * nBatch);
114  *forgetGateScratchDecoder += (2 * nCell * nBatch);
115  *outputGateScratchDecoder += (3 * nCell * nBatch);
116  }
117 
118  std::unique_ptr<Decoder<float>> inputToInputWeightsTensor;
119  std::unique_ptr<Decoder<float>> inputToForgetWeightsTensor = MakeDecoder<float>(
120  m_InputToForgetWeightsTensor->GetTensorInfo(), m_InputToForgetWeightsTensor->GetConstTensor<void>());
121  std::unique_ptr<Decoder<float>> inputToCellWeightsTensor = MakeDecoder<float>(
122  m_InputToCellWeightsTensor->GetTensorInfo(), m_InputToCellWeightsTensor->GetConstTensor<void>());
123  std::unique_ptr<Decoder<float>> inputToOutputWeightsTensor = MakeDecoder<float>(
124  m_InputToOutputWeightsTensor->GetTensorInfo(), m_InputToOutputWeightsTensor->GetConstTensor<void>());
125 
126  std::unique_ptr<Decoder<float>> recurrentToInputWeightsTensor;
127  std::unique_ptr<Decoder<float>> recurrentToForgetWeightsTensor = MakeDecoder<float>(
128  m_RecurrentToForgetWeightsTensor->GetTensorInfo(), m_RecurrentToForgetWeightsTensor->GetConstTensor<void>());
129  std::unique_ptr<Decoder<float>> recurrentToCellWeightsTensor = MakeDecoder<float>(
130  m_RecurrentToCellWeightsTensor->GetTensorInfo(), m_RecurrentToCellWeightsTensor->GetConstTensor<void>());
131  std::unique_ptr<Decoder<float>> recurrentToOutputWeightsTensor = MakeDecoder<float>(
132  m_RecurrentToOutputWeightsTensor->GetTensorInfo(), m_RecurrentToOutputWeightsTensor->GetConstTensor<void>());
133 
134  std::unique_ptr<Decoder<float>> inputGateBiasTensor;
135  std::unique_ptr<Decoder<float>> forgetGateBiasTensor = MakeDecoder<float>(
136  m_ForgetGateBiasTensor->GetTensorInfo(), m_ForgetGateBiasTensor->GetConstTensor<void>());
137  std::unique_ptr<Decoder<float>> cellBiasTensor = MakeDecoder<float>(
138  m_CellBiasTensor->GetTensorInfo(), m_CellBiasTensor->GetConstTensor<void>());
139  std::unique_ptr<Decoder<float>> outputGateBiasTensor = MakeDecoder<float>(
140  m_OutputGateBiasTensor->GetTensorInfo(), m_OutputGateBiasTensor->GetConstTensor<void>());
141 
142  std::unique_ptr<Decoder<float>> cellToInputWeightsTensor;
143  std::unique_ptr<Decoder<float>> cellToForgetWeightsTensor;
144  std::unique_ptr<Decoder<float>> cellToOutputWeightsTensor;
145 
146  std::unique_ptr<Decoder<float>> projectionWeightsTensor;
147  std::unique_ptr<Decoder<float>> projectionBiasTensor;
148 
149  std::unique_ptr<Decoder<float>> inputLayerNormWeights;
150  std::unique_ptr<Decoder<float>> forgetLayerNormWeights;
151  std::unique_ptr<Decoder<float>> cellLayerNormWeights;
152  std::unique_ptr<Decoder<float>> outputLayerNormWeights;
153 
154  const TensorShape& inputToOutputWeightsShape = m_InputToOutputWeightsTensor->GetShape();
155  const TensorShape& recurrentToOutputWeightsShape = m_RecurrentToOutputWeightsTensor->GetShape();
156 
157  if (useLayerNorm)
158  {
159  if (!useCifg)
160  {
161  inputLayerNormWeights = MakeDecoder<float>(
162  m_InputLayerNormWeights->GetTensorInfo(), m_InputLayerNormWeights->GetConstTensor<void>());
163  }
164  forgetLayerNormWeights = MakeDecoder<float>(
165  m_ForgetLayerNormWeights->GetTensorInfo(), m_ForgetLayerNormWeights->GetConstTensor<void>());
166  cellLayerNormWeights = MakeDecoder<float>(
167  m_CellLayerNormWeights->GetTensorInfo(), m_CellLayerNormWeights->GetConstTensor<void>());
168  outputLayerNormWeights = MakeDecoder<float>(
169  m_OutputLayerNormWeights->GetTensorInfo(), m_OutputLayerNormWeights->GetConstTensor<void>());
170  }
171 
172  if (!useCifg)
173  {
174  inputToInputWeightsTensor = MakeDecoder<float>(
175  m_InputToInputWeightsTensor->GetTensorInfo(), m_InputToInputWeightsTensor->GetConstTensor<void>());
176  inputGateBiasTensor = MakeDecoder<float>(
177  m_InputGateBiasTensor->GetTensorInfo(), m_InputGateBiasTensor->GetConstTensor<void>());
178  recurrentToInputWeightsTensor = MakeDecoder<float>(
179  m_RecurrentToInputWeightsTensor->GetTensorInfo(), m_RecurrentToInputWeightsTensor->GetConstTensor<void>());
180  }
181 
182  if (usePeephole)
183  {
184  cellToForgetWeightsTensor = MakeDecoder<float>(
185  m_CellToForgetWeightsTensor->GetTensorInfo(), m_CellToForgetWeightsTensor->GetConstTensor<void>());
186  cellToOutputWeightsTensor = MakeDecoder<float>(
187  m_CellToOutputWeightsTensor->GetTensorInfo(), m_CellToOutputWeightsTensor->GetConstTensor<void>());
188  }
189 
190  if (!useCifg && usePeephole)
191  {
192  cellToInputWeightsTensor = MakeDecoder<float>(
193  m_CellToInputWeightsTensor->GetTensorInfo(), m_CellToInputWeightsTensor->GetConstTensor<void>());
194  }
195 
197  {
198  projectionWeightsTensor = MakeDecoder<float>(
199  m_ProjectionWeightsTensor->GetTensorInfo(), m_ProjectionWeightsTensor->GetConstTensor<void>());
200  if (m_ProjectionBiasTensor)
201  {
202  projectionBiasTensor = MakeDecoder<float>(
203  m_ProjectionBiasTensor->GetTensorInfo(), m_ProjectionBiasTensor->GetConstTensor<void>());
204  }
205  }
206 
208  inputInfo,
209  outputInfo,
210  inputToOutputWeightsShape,
211  recurrentToOutputWeightsShape,
212  inputData,
213  outputStateIn,
214  cellStateIn,
215  outputStateOut,
216  cellStateOut,
217  output,
218  cellStateOutDecoder,
219  outputDecoder,
220  inputToInputWeightsTensor,
221  inputToForgetWeightsTensor,
222  inputToCellWeightsTensor,
223  inputToOutputWeightsTensor,
224  recurrentToInputWeightsTensor,
225  recurrentToForgetWeightsTensor,
226  recurrentToCellWeightsTensor,
227  recurrentToOutputWeightsTensor,
228  cellToInputWeightsTensor,
229  cellToForgetWeightsTensor,
230  cellToOutputWeightsTensor,
231  inputGateBiasTensor,
232  forgetGateBiasTensor,
233  cellBiasTensor,
234  outputGateBiasTensor,
235  projectionWeightsTensor,
236  projectionBiasTensor,
237  inputLayerNormWeights,
238  forgetLayerNormWeights,
239  cellLayerNormWeights,
240  outputLayerNormWeights,
241  inputGateScratch,
242  cellScratch,
243  forgetGateScratch,
244  outputGateScratch,
245  inputGateScratchDecoder,
246  cellScratchDecoder,
247  forgetGateScratchDecoder,
248  outputGateScratchDecoder,
249  m_LayerNormEpsilon);
250 }
251 
252 } //namespace armnn
bool m_ProjectionEnabled
Enable/disable the projection layer.
const TensorShape & GetShape() const
Definition: Tensor.hpp:191
void LstmImpl(const LstmDescriptor &descriptor, const TensorInfo &inputInfo, const TensorInfo &outputInfo, const TensorShape &inputToOutputWeightsShape, const TensorShape &recurrentToOutputWeightsShape, std::unique_ptr< Decoder< float >> &inputData, std::unique_ptr< Decoder< float >> &outputStateIn, std::unique_ptr< Decoder< float >> &cellStateIn, std::unique_ptr< Encoder< float >> &outputStateOut, std::unique_ptr< Encoder< float >> &cellStateOut, std::unique_ptr< Encoder< float >> &output, std::unique_ptr< Decoder< float >> &cellStateOutDecoder, std::unique_ptr< Decoder< float >> &outputDecoder, std::unique_ptr< Decoder< float >> &inputToInputWeightsTensor, std::unique_ptr< Decoder< float >> &inputToForgetWeightsTensor, std::unique_ptr< Decoder< float >> &inputToCellWeightsTensor, std::unique_ptr< Decoder< float >> &inputToOutputWeightsTensor, std::unique_ptr< Decoder< float >> &recurrentToInputWeightsTensor, std::unique_ptr< Decoder< float >> &recurrentToForgetWeightsTensor, std::unique_ptr< Decoder< float >> &recurrentToCellWeightsTensor, std::unique_ptr< Decoder< float >> &recurrentToOutputWeightsTensor, std::unique_ptr< Decoder< float >> &cellToInputWeightsTensor, std::unique_ptr< Decoder< float >> &cellToForgetWeightsTensor, std::unique_ptr< Decoder< float >> &cellToOutputWeightsTensor, std::unique_ptr< Decoder< float >> &inputGateBiasTensor, std::unique_ptr< Decoder< float >> &forgetGateBiasTensor, std::unique_ptr< Decoder< float >> &cellBiasTensor, std::unique_ptr< Decoder< float >> &outputGateBiasTensor, std::unique_ptr< Decoder< float >> &projectionWeightsTensor, std::unique_ptr< Decoder< float >> &projectionBiasTensor, std::unique_ptr< Decoder< float >> &inputLayerNormWeights, std::unique_ptr< Decoder< float >> &forgetLayerNormWeights, std::unique_ptr< Decoder< float >> &cellLayerNormWeights, std::unique_ptr< Decoder< float >> &outputLayerNormWeights, std::unique_ptr< Encoder< float >> &inputGateScratch, std::unique_ptr< Encoder< float >> &cellScratch, 
std::unique_ptr< Encoder< float >> &forgetGateScratch, std::unique_ptr< Encoder< float >> &outputGateScratch, std::unique_ptr< Decoder< float >> &inputGateScratchDecoder, std::unique_ptr< Decoder< float >> &cellScratchDecoder, std::unique_ptr< Decoder< float >> &forgetGateScratchDecoder, std::unique_ptr< Decoder< float >> &outputGateScratchDecoder, float layerNormEpsilon)
Definition: Lstm.cpp:13
RefLstmWorkload(const LstmQueueDescriptor &descriptor, const WorkloadInfo &info)
std::unique_ptr< armnn::ScopedTensorHandle > AssignScopedTensorHandle(const armnn::ConstTensorHandle *ptr)
Definition: LstmUtils.cpp:299
Copyright (c) 2021 ARM Limited and Contributors.
void ExecuteAsync(WorkingMemDescriptor &workingMemDescriptor) override
bool m_PeepholeEnabled
Enable/disable peephole.
void Execute() const override
bool m_CifgEnabled
Enable/disable cifg (coupled input & forget gate).
std::vector< ITensorHandle * > m_Outputs
bool m_LayerNormEnabled
Enable/disable layer normalization.
Contains information about TensorInfos of a layer.
std::vector< ITensorHandle * > m_Inputs
const TensorInfo & GetTensorInfo(const ITensorHandle *tensorHandle)
float32 helpers