ArmNN 22.05
RefUnidirectionalSequenceLstmWorkload.cpp
//
// Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include "RefUnidirectionalSequenceLstmWorkload.hpp"
#include "Activation.hpp"
#include "Encoders.hpp"
#include "Decoders.hpp"
#include "Lstm.hpp"
#include "LstmUtils.hpp"
#include "RefWorkloadUtils.hpp"

#include <armnnUtils/Permute.hpp>

namespace armnn
{

RefUnidirectionalSequenceLstmWorkload::RefUnidirectionalSequenceLstmWorkload(
    const UnidirectionalSequenceLstmQueueDescriptor& descriptor,
    const WorkloadInfo& info)
    : RefBaseWorkload<UnidirectionalSequenceLstmQueueDescriptor>(descriptor, info)
    , m_InputToInputWeightsTensor     (AssignScopedTensorHandle(descriptor.m_InputToInputWeights))
    , m_InputToForgetWeightsTensor    (AssignScopedTensorHandle(descriptor.m_InputToForgetWeights))
    , m_InputToCellWeightsTensor      (AssignScopedTensorHandle(descriptor.m_InputToCellWeights))
    , m_InputToOutputWeightsTensor    (AssignScopedTensorHandle(descriptor.m_InputToOutputWeights))
    , m_RecurrentToInputWeightsTensor (AssignScopedTensorHandle(descriptor.m_RecurrentToInputWeights))
    , m_RecurrentToForgetWeightsTensor(AssignScopedTensorHandle(descriptor.m_RecurrentToForgetWeights))
    , m_RecurrentToCellWeightsTensor  (AssignScopedTensorHandle(descriptor.m_RecurrentToCellWeights))
    , m_RecurrentToOutputWeightsTensor(AssignScopedTensorHandle(descriptor.m_RecurrentToOutputWeights))
    , m_CellToInputWeightsTensor      (AssignScopedTensorHandle(descriptor.m_CellToInputWeights))
    , m_CellToForgetWeightsTensor     (AssignScopedTensorHandle(descriptor.m_CellToForgetWeights))
    , m_CellToOutputWeightsTensor     (AssignScopedTensorHandle(descriptor.m_CellToOutputWeights))
    , m_InputGateBiasTensor           (AssignScopedTensorHandle(descriptor.m_InputGateBias))
    , m_ForgetGateBiasTensor          (AssignScopedTensorHandle(descriptor.m_ForgetGateBias))
    , m_CellBiasTensor                (AssignScopedTensorHandle(descriptor.m_CellBias))
    , m_OutputGateBiasTensor          (AssignScopedTensorHandle(descriptor.m_OutputGateBias))
    , m_ProjectionWeightsTensor       (AssignScopedTensorHandle(descriptor.m_ProjectionWeights))
    , m_ProjectionBiasTensor          (AssignScopedTensorHandle(descriptor.m_ProjectionBias))
    , m_InputLayerNormWeights         (AssignScopedTensorHandle(descriptor.m_InputLayerNormWeights))
    , m_ForgetLayerNormWeights        (AssignScopedTensorHandle(descriptor.m_ForgetLayerNormWeights))
    , m_CellLayerNormWeights          (AssignScopedTensorHandle(descriptor.m_CellLayerNormWeights))
    , m_OutputLayerNormWeights        (AssignScopedTensorHandle(descriptor.m_OutputLayerNormWeights))
{}

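// Execute() runs the workload on the tensor handles captured in the queue descriptor,
// while ExecuteAsync() runs it on the handles supplied in the caller's working-memory
// descriptor; both forward to the two-argument Execute() overload below.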
void RefUnidirectionalSequenceLstmWorkload::Execute() const
{
    Execute(m_Data.m_Inputs, m_Data.m_Outputs);
}

void RefUnidirectionalSequenceLstmWorkload::ExecuteAsync(WorkingMemDescriptor& workingMemDescriptor)
{
    Execute(workingMemDescriptor.m_Inputs, workingMemDescriptor.m_Outputs);
}

void RefUnidirectionalSequenceLstmWorkload::Execute(std::vector<ITensorHandle*> inputs,
                                                    std::vector<ITensorHandle*> outputs) const
{
    TensorInfo inputInfo = GetTensorInfo(inputs[0]);
    const TensorInfo& outputStateInfo = GetTensorInfo(inputs[1]);
    const TensorInfo& cellStateInfo = GetTensorInfo(inputs[2]);
    TensorInfo outputStateOutInfo = GetTensorInfo(outputs[0]);
    TensorInfo cellStateOutInfo = GetTensorInfo(outputs[1]);
    TensorInfo outputInfo = GetTensorInfo(outputs[2]);
    TensorShape& inputShape = inputInfo.GetShape();
    TensorShape& outputShape = outputInfo.GetShape();
    auto inputTensor = reinterpret_cast<float*>(inputs[0]->Map());

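    // The reference implementation works on time-major data ([time, batch, feature]).
    // If the network supplied batch-major data, permute the input (and the expected
    // output shape) before stepping through the sequence.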
    if (!m_Data.m_Parameters.m_TimeMajor)
    {
        // Permute to time major
        const PermutationVector& mappings = {1U, 0U, 2U};
        std::vector<float> inputValue(inputTensor, inputTensor + inputInfo.GetNumElements());
        inputShape = armnnUtils::Permuted(inputInfo.GetShape(), mappings);
        inputInfo.SetShape(inputShape);
        armnnUtils::Permute(inputShape, mappings, inputValue.data(), inputTensor, sizeof(float));

        outputShape = armnnUtils::Permuted(outputInfo.GetShape(), mappings);
        outputInfo.SetShape(outputShape);
    }
    unsigned int maxTime = inputShape[0];
    unsigned int batchSize = inputShape[1];
    unsigned int outputSize = outputShape[2];
    unsigned int inputSize = inputShape[2];

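    // Scratch buffers hold the per-gate intermediate results for one time step;
    // they are shaped [batchSize, numUnits], with numUnits taken from the cell state shape.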
    TensorInfo scratchInfo = outputInfo;
    scratchInfo.SetShape({batchSize, cellStateInfo.GetShape()[1]});

    std::vector<float> inputGateScratchBuffer;
    std::vector<float> cellScratchBuffer(scratchInfo.GetNumElements(), 0.);
    std::vector<float> forgetGateScratchBuffer(scratchInfo.GetNumElements(), 0.);
    std::vector<float> outputGateScratchBuffer(scratchInfo.GetNumElements(), 0.);

    std::vector<float> outputStateOutBuffer(outputStateInfo.GetNumElements(), 0.);
    std::vector<float> cellStateOutBuffer(cellStateInfo.GetNumElements(), 0.);

    void* outputStateOutData = outputStateOutBuffer.data();
    void* cellStateOutData = cellStateOutBuffer.data();

    std::unique_ptr<Encoder<float>> inputGateScratch;
    std::unique_ptr<Encoder<float>> cellScratch = MakeEncoder<float>(scratchInfo, cellScratchBuffer.data());
    std::unique_ptr<Encoder<float>> forgetGateScratch = MakeEncoder<float>(scratchInfo, forgetGateScratchBuffer.data());
    std::unique_ptr<Encoder<float>> outputGateScratch = MakeEncoder<float>(scratchInfo, outputGateScratchBuffer.data());

    std::unique_ptr<Decoder<float>> inputGateScratchDecoder;
    std::unique_ptr<Decoder<float>> cellScratchDecoder = MakeDecoder<float>(scratchInfo, cellScratchBuffer.data());
    std::unique_ptr<Decoder<float>> forgetGateScratchDecoder = MakeDecoder<float>(scratchInfo,
                                                                                  forgetGateScratchBuffer.data());
    std::unique_ptr<Decoder<float>> outputGateScratchDecoder = MakeDecoder<float>(scratchInfo,
                                                                                  outputGateScratchBuffer.data());

    const bool useCifg = m_Data.m_Parameters.m_CifgEnabled;
    const bool usePeephole = m_Data.m_Parameters.m_PeepholeEnabled;
    const bool useLayerNorm = m_Data.m_Parameters.m_LayerNormEnabled;

    if (!useCifg)
    {
        inputGateScratchBuffer.resize(scratchInfo.GetNumElements(), 0.);
        inputGateScratch = MakeEncoder<float>(scratchInfo, inputGateScratchBuffer.data());
        inputGateScratchDecoder = MakeDecoder<float>(scratchInfo, inputGateScratchBuffer.data());
    }

    std::unique_ptr<Encoder<float>> outputStateOut = MakeEncoder<float>(outputStateInfo, outputStateOutData);
    std::unique_ptr<Encoder<float>> cellStateOut = MakeEncoder<float>(cellStateInfo, cellStateOutData);
    std::unique_ptr<Decoder<float>> cellStateOutDecoder = MakeDecoder<float>(cellStateInfo, cellStateOutData);

    TensorInfo lstmInputInfo = inputInfo;
    TensorShape batchInputShape = TensorShape({batchSize, inputSize});
    lstmInputInfo.SetShape(batchInputShape);

    TensorInfo lstmOutputInfo = outputInfo;
    lstmOutputInfo.SetShape({batchSize, outputSize});

    const TensorShape& inputToOutputWeightsShape = m_InputToOutputWeightsTensor->GetShape();
    const TensorShape& recurrentToOutputWeightsShape = m_RecurrentToOutputWeightsTensor->GetShape();
    unsigned int nOutput = recurrentToOutputWeightsShape[1];
    auto outputStateInData = inputs[1]->Map();
    std::unique_ptr<Decoder<float>> outputStateIn = MakeDecoder<float>(outputStateInfo, outputStateInData);

    auto cellStateInData = inputs[2]->Map();
    std::unique_ptr<Decoder<float>> cellStateIn = MakeDecoder<float>(cellStateInfo, cellStateInData);

    auto currentInputData = reinterpret_cast<float*>(inputs[0]->Map());
    std::unique_ptr<Decoder<float>> inputData = MakeDecoder<float>(lstmInputInfo, currentInputData);
    auto currentOutputData = reinterpret_cast<float*>(outputs[2]->Map());
    std::unique_ptr<Encoder<float>> output = MakeEncoder<float>(lstmOutputInfo, currentOutputData);
    std::unique_ptr<Decoder<float>> outputDecoder = MakeDecoder<float>(lstmOutputInfo, currentOutputData);

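    // Wrap the constant weight and bias tensors in Decoder objects so that LstmImpl can
    // read them independently of the underlying data type. Decoders for the CIFG, peephole,
    // projection and layer-norm tensors are only created when the corresponding feature is
    // enabled (see the conditional blocks below).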
    std::unique_ptr<Decoder<float>> inputToInputWeightsTensor;
    std::unique_ptr<Decoder<float>> inputToForgetWeightsTensor = MakeDecoder<float>(
        m_InputToForgetWeightsTensor->GetTensorInfo(), m_InputToForgetWeightsTensor->GetConstTensor<void>());
    std::unique_ptr<Decoder<float>> inputToCellWeightsTensor = MakeDecoder<float>(
        m_InputToCellWeightsTensor->GetTensorInfo(), m_InputToCellWeightsTensor->GetConstTensor<void>());
    std::unique_ptr<Decoder<float>> inputToOutputWeightsTensor = MakeDecoder<float>(
        m_InputToOutputWeightsTensor->GetTensorInfo(), m_InputToOutputWeightsTensor->GetConstTensor<void>());

    std::unique_ptr<Decoder<float>> recurrentToInputWeightsTensor;
    std::unique_ptr<Decoder<float>> recurrentToForgetWeightsTensor = MakeDecoder<float>(
        m_RecurrentToForgetWeightsTensor->GetTensorInfo(), m_RecurrentToForgetWeightsTensor->GetConstTensor<void>());
    std::unique_ptr<Decoder<float>> recurrentToCellWeightsTensor = MakeDecoder<float>(
        m_RecurrentToCellWeightsTensor->GetTensorInfo(), m_RecurrentToCellWeightsTensor->GetConstTensor<void>());
    std::unique_ptr<Decoder<float>> recurrentToOutputWeightsTensor = MakeDecoder<float>(
        m_RecurrentToOutputWeightsTensor->GetTensorInfo(), m_RecurrentToOutputWeightsTensor->GetConstTensor<void>());

    std::unique_ptr<Decoder<float>> inputGateBiasTensor;
    std::unique_ptr<Decoder<float>> forgetGateBiasTensor = MakeDecoder<float>(
        m_ForgetGateBiasTensor->GetTensorInfo(), m_ForgetGateBiasTensor->GetConstTensor<void>());
    std::unique_ptr<Decoder<float>> cellBiasTensor = MakeDecoder<float>(
        m_CellBiasTensor->GetTensorInfo(), m_CellBiasTensor->GetConstTensor<void>());
    std::unique_ptr<Decoder<float>> outputGateBiasTensor = MakeDecoder<float>(
        m_OutputGateBiasTensor->GetTensorInfo(), m_OutputGateBiasTensor->GetConstTensor<void>());

    std::unique_ptr<Decoder<float>> cellToInputWeightsTensor;
    std::unique_ptr<Decoder<float>> cellToForgetWeightsTensor;
    std::unique_ptr<Decoder<float>> cellToOutputWeightsTensor;

    std::unique_ptr<Decoder<float>> projectionWeightsTensor;
    std::unique_ptr<Decoder<float>> projectionBiasTensor;

    std::unique_ptr<Decoder<float>> inputLayerNormWeights;
    std::unique_ptr<Decoder<float>> forgetLayerNormWeights;
    std::unique_ptr<Decoder<float>> cellLayerNormWeights;
    std::unique_ptr<Decoder<float>> outputLayerNormWeights;

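    // Only create decoders for the optional tensors that the descriptor enables:
    // layer normalisation, the non-CIFG input gate, peephole connections and projection.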
    if (useLayerNorm)
    {
        if (!useCifg)
        {
            inputLayerNormWeights = MakeDecoder<float>(
                m_InputLayerNormWeights->GetTensorInfo(), m_InputLayerNormWeights->GetConstTensor<void>());
        }
        forgetLayerNormWeights = MakeDecoder<float>(
            m_ForgetLayerNormWeights->GetTensorInfo(), m_ForgetLayerNormWeights->GetConstTensor<void>());
        cellLayerNormWeights = MakeDecoder<float>(
            m_CellLayerNormWeights->GetTensorInfo(), m_CellLayerNormWeights->GetConstTensor<void>());
        outputLayerNormWeights = MakeDecoder<float>(
            m_OutputLayerNormWeights->GetTensorInfo(), m_OutputLayerNormWeights->GetConstTensor<void>());
    }

    if (!useCifg)
    {
        inputToInputWeightsTensor = MakeDecoder<float>(
            m_InputToInputWeightsTensor->GetTensorInfo(), m_InputToInputWeightsTensor->GetConstTensor<void>());
        inputGateBiasTensor = MakeDecoder<float>(
            m_InputGateBiasTensor->GetTensorInfo(), m_InputGateBiasTensor->GetConstTensor<void>());
        recurrentToInputWeightsTensor = MakeDecoder<float>(
            m_RecurrentToInputWeightsTensor->GetTensorInfo(), m_RecurrentToInputWeightsTensor->GetConstTensor<void>());
    }

    if (usePeephole)
    {
        cellToForgetWeightsTensor = MakeDecoder<float>(
            m_CellToForgetWeightsTensor->GetTensorInfo(), m_CellToForgetWeightsTensor->GetConstTensor<void>());
        cellToOutputWeightsTensor = MakeDecoder<float>(
            m_CellToOutputWeightsTensor->GetTensorInfo(), m_CellToOutputWeightsTensor->GetConstTensor<void>());
    }

    if (!useCifg && usePeephole)
    {
        cellToInputWeightsTensor = MakeDecoder<float>(
            m_CellToInputWeightsTensor->GetTensorInfo(), m_CellToInputWeightsTensor->GetConstTensor<void>());
    }

    if (m_Data.m_Parameters.m_ProjectionEnabled)
    {
        projectionWeightsTensor = MakeDecoder<float>(
            m_ProjectionWeightsTensor->GetTensorInfo(), m_ProjectionWeightsTensor->GetConstTensor<void>());
        if (m_ProjectionBiasTensor)
        {
            projectionBiasTensor = MakeDecoder<float>(
                m_ProjectionBiasTensor->GetTensorInfo(), m_ProjectionBiasTensor->GetConstTensor<void>());
        }
    }

    unsigned int batchInputSize = batchSize * inputSize;
    unsigned int batchOutputSize = batchSize * nOutput;

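    // Run one LSTM step per time slice. After each step the raw input/output pointers are
    // advanced by one [batch, feature] slice, and the state written by this step becomes
    // the state read by the next one.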
    for (unsigned int t = 0; t < maxTime; ++t)
    {
        LstmImpl(m_Data.m_Parameters,
                 lstmInputInfo,
                 lstmOutputInfo,
                 inputToOutputWeightsShape,
                 recurrentToOutputWeightsShape,
                 inputData,
                 outputStateIn,
                 cellStateIn,
                 outputStateOut,
                 cellStateOut,
                 output,
                 cellStateOutDecoder,
                 outputDecoder,
                 inputToInputWeightsTensor,
                 inputToForgetWeightsTensor,
                 inputToCellWeightsTensor,
                 inputToOutputWeightsTensor,
                 recurrentToInputWeightsTensor,
                 recurrentToForgetWeightsTensor,
                 recurrentToCellWeightsTensor,
                 recurrentToOutputWeightsTensor,
                 cellToInputWeightsTensor,
                 cellToForgetWeightsTensor,
                 cellToOutputWeightsTensor,
                 inputGateBiasTensor,
                 forgetGateBiasTensor,
                 cellBiasTensor,
                 outputGateBiasTensor,
                 projectionWeightsTensor,
                 projectionBiasTensor,
                 inputLayerNormWeights,
                 forgetLayerNormWeights,
                 cellLayerNormWeights,
                 outputLayerNormWeights,
                 inputGateScratch,
                 cellScratch,
                 forgetGateScratch,
                 outputGateScratch,
                 inputGateScratchDecoder,
                 cellScratchDecoder,
                 forgetGateScratchDecoder,
                 outputGateScratchDecoder,
                 m_LayerNormEpsilon);

        currentInputData += batchInputSize;
        inputData = MakeDecoder<float>(lstmInputInfo, currentInputData);
        currentOutputData += batchOutputSize;
        output = MakeEncoder<float>(lstmOutputInfo, currentOutputData);
        outputDecoder = MakeDecoder<float>(lstmOutputInfo, currentOutputData);

        // Assign output state out to the next output state in
        outputStateIn = MakeDecoder<float>(outputStateInfo, outputStateOutData);

        // Assign cell state out to the next cell state in
        cellStateIn = MakeDecoder<float>(cellStateInfo, cellStateOutData);
    }

    if (!m_Data.m_Parameters.m_TimeMajor)
    {
        // Permute Output back to batch major
        const PermutationVector& mappings = {1U, 0U, 2U};
        auto outputData = reinterpret_cast<float*>(outputs[2]->Map());
        std::vector<float> outputValue(outputData, outputData + outputInfo.GetNumElements());
        outputShape = armnnUtils::Permuted(outputInfo.GetShape(), mappings);
        outputInfo.SetShape(outputShape);
        armnnUtils::Permute(outputShape, mappings, outputValue.data(), outputData, sizeof(float));
    }
}

} //namespace armnn
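The batch-major handling above relies on the Permute helpers included via armnnUtils/Permute.hpp. The short standalone sketch below (not part of this file) illustrates how the {1U, 0U, 2U} mapping used in Execute() rearranges a [batch, time, feature] buffer into [time, batch, feature]; the shapes and the MakeTimeMajor helper name are illustrative assumptions, only the Permuted()/Permute() calls mirror the workload.

// Sketch only: same Permuted()/Permute() calls as the workload, applied out of place.
#include <armnn/Tensor.hpp>
#include <armnn/Types.hpp>
#include <armnnUtils/Permute.hpp>
#include <vector>

std::vector<float> MakeTimeMajor(const std::vector<float>& batchMajorData,
                                 const armnn::TensorShape& batchMajorShape) // [batch, time, feature]
{
    const armnn::PermutationVector toTimeMajor = {1U, 0U, 2U}; // swap dims 0 and 1

    // Destination shape becomes [time, batch, feature].
    armnn::TensorShape timeMajorShape = armnnUtils::Permuted(batchMajorShape, toTimeMajor);

    std::vector<float> timeMajorData(batchMajorData.size());
    armnnUtils::Permute(timeMajorShape, toTimeMajor,
                        batchMajorData.data(), timeMajorData.data(), sizeof(float));
    return timeMajorData;
}

Because swapping dimensions 0 and 1 is its own inverse, the workload reuses the same {1U, 0U, 2U} mapping at the end of Execute() to restore the batch-major layout of the output.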
Symbols referenced in this listing:

bool m_ProjectionEnabled: Enable/disable the projection layer.
const TensorShape& GetShape() const (Definition: Tensor.hpp:191)
void LstmImpl(const LstmDescriptor& descriptor, const TensorInfo& inputInfo, const TensorInfo& outputInfo, const TensorShape& inputToOutputWeightsShape, const TensorShape& recurrentToOutputWeightsShape, std::unique_ptr<Decoder<float>>& inputData, std::unique_ptr<Decoder<float>>& outputStateIn, std::unique_ptr<Decoder<float>>& cellStateIn, std::unique_ptr<Encoder<float>>& outputStateOut, std::unique_ptr<Encoder<float>>& cellStateOut, std::unique_ptr<Encoder<float>>& output, std::unique_ptr<Decoder<float>>& cellStateOutDecoder, std::unique_ptr<Decoder<float>>& outputDecoder, std::unique_ptr<Decoder<float>>& inputToInputWeightsTensor, std::unique_ptr<Decoder<float>>& inputToForgetWeightsTensor, std::unique_ptr<Decoder<float>>& inputToCellWeightsTensor, std::unique_ptr<Decoder<float>>& inputToOutputWeightsTensor, std::unique_ptr<Decoder<float>>& recurrentToInputWeightsTensor, std::unique_ptr<Decoder<float>>& recurrentToForgetWeightsTensor, std::unique_ptr<Decoder<float>>& recurrentToCellWeightsTensor, std::unique_ptr<Decoder<float>>& recurrentToOutputWeightsTensor, std::unique_ptr<Decoder<float>>& cellToInputWeightsTensor, std::unique_ptr<Decoder<float>>& cellToForgetWeightsTensor, std::unique_ptr<Decoder<float>>& cellToOutputWeightsTensor, std::unique_ptr<Decoder<float>>& inputGateBiasTensor, std::unique_ptr<Decoder<float>>& forgetGateBiasTensor, std::unique_ptr<Decoder<float>>& cellBiasTensor, std::unique_ptr<Decoder<float>>& outputGateBiasTensor, std::unique_ptr<Decoder<float>>& projectionWeightsTensor, std::unique_ptr<Decoder<float>>& projectionBiasTensor, std::unique_ptr<Decoder<float>>& inputLayerNormWeights, std::unique_ptr<Decoder<float>>& forgetLayerNormWeights, std::unique_ptr<Decoder<float>>& cellLayerNormWeights, std::unique_ptr<Decoder<float>>& outputLayerNormWeights, std::unique_ptr<Encoder<float>>& inputGateScratch, std::unique_ptr<Encoder<float>>& cellScratch, std::unique_ptr<Encoder<float>>& forgetGateScratch, std::unique_ptr<Encoder<float>>& outputGateScratch, std::unique_ptr<Decoder<float>>& inputGateScratchDecoder, std::unique_ptr<Decoder<float>>& cellScratchDecoder, std::unique_ptr<Decoder<float>>& forgetGateScratchDecoder, std::unique_ptr<Decoder<float>>& outputGateScratchDecoder, float layerNormEpsilon) (Definition: Lstm.cpp:13)
std::unique_ptr<armnn::ScopedTensorHandle> AssignScopedTensorHandle(const armnn::ConstTensorHandle* ptr) (Definition: LstmUtils.cpp:299)
bool m_TimeMajor: Enable/disable time major.
namespace armnn: Copyright (c) 2021 ARM Limited and Contributors.
void SetShape(const TensorShape& newShape) (Definition: Tensor.hpp:193)
UnidirectionalSequenceLstmQueueDescriptor m_Data (Definition: Workload.hpp:81)
bool m_PeepholeEnabled: Enable/disable peephole.
bool m_CifgEnabled: Enable/disable cifg (coupled input & forget gate).
std::vector<ITensorHandle*> m_Outputs
void ExecuteAsync(WorkingMemDescriptor& workingMemDescriptor) override
bool m_LayerNormEnabled: Enable/disable layer normalization.
WorkloadInfo: Contains information about TensorInfos of a layer.
std::vector<ITensorHandle*> m_Inputs
const TensorInfo& GetTensorInfo(const ITensorHandle* tensorHandle)
armnn::TensorShape Permuted(const armnn::TensorShape& srcShape, const armnn::PermutationVector& mappings) (Definition: Permute.cpp:98)
RefUnidirectionalSequenceLstmWorkload(const UnidirectionalSequenceLstmQueueDescriptor& descriptor, const WorkloadInfo& info)
unsigned int GetNumElements() const (Definition: Tensor.hpp:196)
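The enable flags listed above are read from m_Data.m_Parameters inside Execute(). A minimal sketch of setting them follows; it assumes the UnidirectionalSequenceLstmDescriptor alias from armnn/Descriptors.hpp, and the chosen values are purely illustrative rather than ArmNN defaults.

// Sketch only: flags consulted by RefUnidirectionalSequenceLstmWorkload::Execute().
#include <armnn/Descriptors.hpp>

armnn::UnidirectionalSequenceLstmDescriptor MakeSequenceLstmDescriptor()
{
    armnn::UnidirectionalSequenceLstmDescriptor descriptor;
    descriptor.m_TimeMajor         = false; // inputs are [batch, time, feature], so Execute() permutes them
    descriptor.m_CifgEnabled       = true;  // couple input and forget gates: no input-gate tensors needed
    descriptor.m_PeepholeEnabled   = false; // no cell-to-gate peephole weights
    descriptor.m_ProjectionEnabled = false; // no projection layer after the output gate
    descriptor.m_LayerNormEnabled  = false; // no per-gate layer normalization
    return descriptor;
}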