ArmNN 22.02
RefUnidirectionalSequenceLstmWorkload.cpp
//
// Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include "RefUnidirectionalSequenceLstmWorkload.hpp"
#include "Activation.hpp"
#include "Encoders.hpp"
#include "Decoders.hpp"
#include "Lstm.hpp"
#include "LstmUtils.hpp"
#include "RefWorkloadUtils.hpp"

#include <armnnUtils/Permute.hpp>
namespace armnn
{

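// The constructor copies the constant weight and bias tensors out of the queue descriptor
// into ScopedTensorHandles owned by the workload; optional tensors that were not supplied
// remain null handles (see the null checks further down, e.g. on m_ProjectionBiasTensor).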
RefUnidirectionalSequenceLstmWorkload::RefUnidirectionalSequenceLstmWorkload(
    const UnidirectionalSequenceLstmQueueDescriptor& descriptor,
    const WorkloadInfo& info)
    : BaseWorkload<UnidirectionalSequenceLstmQueueDescriptor>(descriptor, info)
    , m_InputToInputWeightsTensor     (AssignScopedTensorHandle(descriptor.m_InputToInputWeights))
    , m_InputToForgetWeightsTensor    (AssignScopedTensorHandle(descriptor.m_InputToForgetWeights))
    , m_InputToCellWeightsTensor      (AssignScopedTensorHandle(descriptor.m_InputToCellWeights))
    , m_InputToOutputWeightsTensor    (AssignScopedTensorHandle(descriptor.m_InputToOutputWeights))
    , m_RecurrentToInputWeightsTensor (AssignScopedTensorHandle(descriptor.m_RecurrentToInputWeights))
    , m_RecurrentToForgetWeightsTensor(AssignScopedTensorHandle(descriptor.m_RecurrentToForgetWeights))
    , m_RecurrentToCellWeightsTensor  (AssignScopedTensorHandle(descriptor.m_RecurrentToCellWeights))
    , m_RecurrentToOutputWeightsTensor(AssignScopedTensorHandle(descriptor.m_RecurrentToOutputWeights))
    , m_CellToInputWeightsTensor      (AssignScopedTensorHandle(descriptor.m_CellToInputWeights))
    , m_CellToForgetWeightsTensor     (AssignScopedTensorHandle(descriptor.m_CellToForgetWeights))
    , m_CellToOutputWeightsTensor     (AssignScopedTensorHandle(descriptor.m_CellToOutputWeights))
    , m_InputGateBiasTensor           (AssignScopedTensorHandle(descriptor.m_InputGateBias))
    , m_ForgetGateBiasTensor          (AssignScopedTensorHandle(descriptor.m_ForgetGateBias))
    , m_CellBiasTensor                (AssignScopedTensorHandle(descriptor.m_CellBias))
    , m_OutputGateBiasTensor          (AssignScopedTensorHandle(descriptor.m_OutputGateBias))
    , m_ProjectionWeightsTensor       (AssignScopedTensorHandle(descriptor.m_ProjectionWeights))
    , m_ProjectionBiasTensor          (AssignScopedTensorHandle(descriptor.m_ProjectionBias))
    , m_InputLayerNormWeights         (AssignScopedTensorHandle(descriptor.m_InputLayerNormWeights))
    , m_ForgetLayerNormWeights        (AssignScopedTensorHandle(descriptor.m_ForgetLayerNormWeights))
    , m_CellLayerNormWeights          (AssignScopedTensorHandle(descriptor.m_CellLayerNormWeights))
    , m_OutputLayerNormWeights        (AssignScopedTensorHandle(descriptor.m_OutputLayerNormWeights))
{}

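// Execute() and ExecuteAsync() differ only in where the input/output tensor handles come
// from; both forward to the handle-based Execute() overload below.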
void RefUnidirectionalSequenceLstmWorkload::Execute() const
{
    Execute(m_Data.m_Inputs, m_Data.m_Outputs);
}

void RefUnidirectionalSequenceLstmWorkload::ExecuteAsync(WorkingMemDescriptor& workingMemDescriptor)
{
    Execute(workingMemDescriptor.m_Inputs, workingMemDescriptor.m_Outputs);
}

void RefUnidirectionalSequenceLstmWorkload::Execute(std::vector<ITensorHandle*> inputs,
                                                    std::vector<ITensorHandle*> outputs) const
{
    TensorInfo inputInfo = GetTensorInfo(inputs[0]);
    const TensorInfo& outputStateInfo = GetTensorInfo(inputs[1]);
    const TensorInfo& cellStateInfo   = GetTensorInfo(inputs[2]);
    TensorInfo outputInfo = GetTensorInfo(outputs[0]);
    TensorShape& inputShape  = inputInfo.GetShape();
    TensorShape& outputShape = outputInfo.GetShape();
    auto inputTensor = reinterpret_cast<float*>(inputs[0]->Map());

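    // The reference implementation walks the sequence in time-major order
    // ([maxTime, batchSize, ...]); batch-major input is permuted into that layout here
    // and the output is permuted back at the end of Execute().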
    if (!m_Data.m_Parameters.m_TimeMajor)
    {
        // Permute to time major
        const PermutationVector& mappings = {1U, 0U, 2U};
        std::vector<float> inputValue(inputTensor, inputTensor + inputInfo.GetNumElements());
        inputShape = armnnUtils::Permuted(inputInfo.GetShape(), mappings);
        inputInfo.SetShape(inputShape);
        armnnUtils::Permute(inputShape, mappings, inputValue.data(), inputTensor, sizeof(float));

        outputShape = armnnUtils::Permuted(outputInfo.GetShape(), mappings);
        outputInfo.SetShape(outputShape);
    }
    unsigned int maxTime    = inputShape[0];
    unsigned int batchSize  = inputShape[1];
    unsigned int outputSize = outputShape[2];
    unsigned int inputSize  = inputShape[2];

    TensorInfo scratchInfo = outputInfo;
    scratchInfo.SetShape({batchSize, cellStateInfo.GetShape()[1]});

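    // Per-gate scratch buffers hold a single time step of [batchSize, numUnits] data,
    // with numUnits taken from the cell state shape above.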
    std::vector<float> inputGateScratchBuffer;
    std::vector<float> cellScratchBuffer(scratchInfo.GetNumElements(), 0.);
    std::vector<float> forgetGateScratchBuffer(scratchInfo.GetNumElements(), 0.);
    std::vector<float> outputGateScratchBuffer(scratchInfo.GetNumElements(), 0.);

    std::vector<float> outputStateOutBuffer(outputStateInfo.GetNumElements(), 0.);
    std::vector<float> cellStateOutBuffer(cellStateInfo.GetNumElements(), 0.);

    void* outputStateOutData = outputStateOutBuffer.data();
    void* cellStateOutData = cellStateOutBuffer.data();

    std::unique_ptr<Encoder<float>> inputGateScratch;
    std::unique_ptr<Encoder<float>> cellScratch = MakeEncoder<float>(scratchInfo, cellScratchBuffer.data());
    std::unique_ptr<Encoder<float>> forgetGateScratch = MakeEncoder<float>(scratchInfo, forgetGateScratchBuffer.data());
    std::unique_ptr<Encoder<float>> outputGateScratch = MakeEncoder<float>(scratchInfo, outputGateScratchBuffer.data());

    std::unique_ptr<Decoder<float>> inputGateScratchDecoder;
    std::unique_ptr<Decoder<float>> cellScratchDecoder = MakeDecoder<float>(scratchInfo, cellScratchBuffer.data());
    std::unique_ptr<Decoder<float>> forgetGateScratchDecoder = MakeDecoder<float>(scratchInfo,
                                                                                  forgetGateScratchBuffer.data());
    std::unique_ptr<Decoder<float>> outputGateScratchDecoder = MakeDecoder<float>(scratchInfo,
                                                                                  outputGateScratchBuffer.data());

    const bool useCifg      = m_Data.m_Parameters.m_CifgEnabled;
    const bool usePeephole  = m_Data.m_Parameters.m_PeepholeEnabled;
    const bool useLayerNorm = m_Data.m_Parameters.m_LayerNormEnabled;

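    // With CIFG (coupled input and forget gate) the input gate is derived from the forget
    // gate, so its scratch buffer and weights are only needed when CIFG is disabled.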
    if (!useCifg)
    {
        inputGateScratchBuffer.resize(scratchInfo.GetNumElements(), 0.);
        inputGateScratch = MakeEncoder<float>(scratchInfo, inputGateScratchBuffer.data());
        inputGateScratchDecoder = MakeDecoder<float>(scratchInfo, inputGateScratchBuffer.data());
    }

    std::unique_ptr<Encoder<float>> outputStateOut = MakeEncoder<float>(outputStateInfo, outputStateOutData);
    std::unique_ptr<Encoder<float>> cellStateOut = MakeEncoder<float>(cellStateInfo, cellStateOutData);
    std::unique_ptr<Decoder<float>> cellStateOutDecoder = MakeDecoder<float>(cellStateInfo, cellStateOutData);

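    // The LSTM core processes one time step at a time, so build per-step TensorInfos of
    // shape [batchSize, inputSize] and [batchSize, outputSize].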
    TensorInfo lstmInputInfo = inputInfo;
    TensorShape batchInputShape = TensorShape({batchSize, inputSize});
    lstmInputInfo.SetShape(batchInputShape);

    TensorInfo lstmOutputInfo = outputInfo;
    lstmOutputInfo.SetShape({batchSize, outputSize});

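    // Output-side weight shapes are passed to LstmImpl so it can derive the cell/output
    // dimensions; the state decoders below initially point at the network's state inputs.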
    const TensorShape& inputToOutputWeightsShape = m_InputToOutputWeightsTensor->GetShape();
    const TensorShape& recurrentToOutputWeightsShape = m_RecurrentToOutputWeightsTensor->GetShape();
    unsigned int nOutput = recurrentToOutputWeightsShape[1];
    auto outputStateInData = inputs[1]->Map();
    std::unique_ptr<Decoder<float>> outputStateIn = MakeDecoder<float>(outputStateInfo, outputStateInData);

    auto cellStateInData = inputs[2]->Map();
    std::unique_ptr<Decoder<float>> cellStateIn = MakeDecoder<float>(cellStateInfo, cellStateInData);

    auto currentInputData = reinterpret_cast<float*>(inputs[0]->Map());
    std::unique_ptr<Decoder<float>> inputData = MakeDecoder<float>(lstmInputInfo, currentInputData);
    auto currentOutputData = reinterpret_cast<float*>(outputs[0]->Map());
    std::unique_ptr<Encoder<float>> output = MakeEncoder<float>(lstmOutputInfo, currentOutputData);
    std::unique_ptr<Decoder<float>> outputDecoder = MakeDecoder<float>(lstmOutputInfo, currentOutputData);

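    // Wrap every constant weight and bias tensor in a Decoder. Tensors belonging to optional
    // features (CIFG, peephole, projection, layer normalization) are left null unless the
    // corresponding flag in the descriptor enables them.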
    std::unique_ptr<Decoder<float>> inputToInputWeightsTensor;
    std::unique_ptr<Decoder<float>> inputToForgetWeightsTensor = MakeDecoder<float>(
        m_InputToForgetWeightsTensor->GetTensorInfo(), m_InputToForgetWeightsTensor->GetConstTensor<void>());
    std::unique_ptr<Decoder<float>> inputToCellWeightsTensor = MakeDecoder<float>(
        m_InputToCellWeightsTensor->GetTensorInfo(), m_InputToCellWeightsTensor->GetConstTensor<void>());
    std::unique_ptr<Decoder<float>> inputToOutputWeightsTensor = MakeDecoder<float>(
        m_InputToOutputWeightsTensor->GetTensorInfo(), m_InputToOutputWeightsTensor->GetConstTensor<void>());

    std::unique_ptr<Decoder<float>> recurrentToInputWeightsTensor;
    std::unique_ptr<Decoder<float>> recurrentToForgetWeightsTensor = MakeDecoder<float>(
        m_RecurrentToForgetWeightsTensor->GetTensorInfo(), m_RecurrentToForgetWeightsTensor->GetConstTensor<void>());
    std::unique_ptr<Decoder<float>> recurrentToCellWeightsTensor = MakeDecoder<float>(
        m_RecurrentToCellWeightsTensor->GetTensorInfo(), m_RecurrentToCellWeightsTensor->GetConstTensor<void>());
    std::unique_ptr<Decoder<float>> recurrentToOutputWeightsTensor = MakeDecoder<float>(
        m_RecurrentToOutputWeightsTensor->GetTensorInfo(), m_RecurrentToOutputWeightsTensor->GetConstTensor<void>());

    std::unique_ptr<Decoder<float>> inputGateBiasTensor;
    std::unique_ptr<Decoder<float>> forgetGateBiasTensor = MakeDecoder<float>(
        m_ForgetGateBiasTensor->GetTensorInfo(), m_ForgetGateBiasTensor->GetConstTensor<void>());
    std::unique_ptr<Decoder<float>> cellBiasTensor = MakeDecoder<float>(
        m_CellBiasTensor->GetTensorInfo(), m_CellBiasTensor->GetConstTensor<void>());
    std::unique_ptr<Decoder<float>> outputGateBiasTensor = MakeDecoder<float>(
        m_OutputGateBiasTensor->GetTensorInfo(), m_OutputGateBiasTensor->GetConstTensor<void>());

    std::unique_ptr<Decoder<float>> cellToInputWeightsTensor;
    std::unique_ptr<Decoder<float>> cellToForgetWeightsTensor;
    std::unique_ptr<Decoder<float>> cellToOutputWeightsTensor;

    std::unique_ptr<Decoder<float>> projectionWeightsTensor;
    std::unique_ptr<Decoder<float>> projectionBiasTensor;

    std::unique_ptr<Decoder<float>> inputLayerNormWeights;
    std::unique_ptr<Decoder<float>> forgetLayerNormWeights;
    std::unique_ptr<Decoder<float>> cellLayerNormWeights;
    std::unique_ptr<Decoder<float>> outputLayerNormWeights;

    if (useLayerNorm)
    {
        if (!useCifg)
        {
            inputLayerNormWeights = MakeDecoder<float>(
                m_InputLayerNormWeights->GetTensorInfo(), m_InputLayerNormWeights->GetConstTensor<void>());
        }
        forgetLayerNormWeights = MakeDecoder<float>(
            m_ForgetLayerNormWeights->GetTensorInfo(), m_ForgetLayerNormWeights->GetConstTensor<void>());
        cellLayerNormWeights = MakeDecoder<float>(
            m_CellLayerNormWeights->GetTensorInfo(), m_CellLayerNormWeights->GetConstTensor<void>());
        outputLayerNormWeights = MakeDecoder<float>(
            m_OutputLayerNormWeights->GetTensorInfo(), m_OutputLayerNormWeights->GetConstTensor<void>());
    }

    if (!useCifg)
    {
        inputToInputWeightsTensor = MakeDecoder<float>(
            m_InputToInputWeightsTensor->GetTensorInfo(), m_InputToInputWeightsTensor->GetConstTensor<void>());
        inputGateBiasTensor = MakeDecoder<float>(
            m_InputGateBiasTensor->GetTensorInfo(), m_InputGateBiasTensor->GetConstTensor<void>());
        recurrentToInputWeightsTensor = MakeDecoder<float>(
            m_RecurrentToInputWeightsTensor->GetTensorInfo(), m_RecurrentToInputWeightsTensor->GetConstTensor<void>());
    }

    if (usePeephole)
    {
        cellToForgetWeightsTensor = MakeDecoder<float>(
            m_CellToForgetWeightsTensor->GetTensorInfo(), m_CellToForgetWeightsTensor->GetConstTensor<void>());
        cellToOutputWeightsTensor = MakeDecoder<float>(
            m_CellToOutputWeightsTensor->GetTensorInfo(), m_CellToOutputWeightsTensor->GetConstTensor<void>());
    }

    if (!useCifg && usePeephole)
    {
        cellToInputWeightsTensor = MakeDecoder<float>(
            m_CellToInputWeightsTensor->GetTensorInfo(), m_CellToInputWeightsTensor->GetConstTensor<void>());
    }

    if (m_Data.m_Parameters.m_ProjectionEnabled)
    {
        projectionWeightsTensor = MakeDecoder<float>(
            m_ProjectionWeightsTensor->GetTensorInfo(), m_ProjectionWeightsTensor->GetConstTensor<void>());
        if (m_ProjectionBiasTensor)
        {
            projectionBiasTensor = MakeDecoder<float>(
                m_ProjectionBiasTensor->GetTensorInfo(), m_ProjectionBiasTensor->GetConstTensor<void>());
        }
    }

    unsigned int batchInputSize  = batchSize * inputSize;
    unsigned int batchOutputSize = batchSize * nOutput;

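    // Run the LSTM cell once per time step, advancing the input and output pointers by one
    // batch worth of elements after each step and feeding each step's output/cell state back
    // in as the next step's input state.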
    for (unsigned int t = 0; t < maxTime; ++t)
    {
        LstmImpl(m_Data.m_Parameters,
                 lstmInputInfo,
                 lstmOutputInfo,
                 inputToOutputWeightsShape,
                 recurrentToOutputWeightsShape,
                 inputData,
                 outputStateIn,
                 cellStateIn,
                 outputStateOut,
                 cellStateOut,
                 output,
                 cellStateOutDecoder,
                 outputDecoder,
                 inputToInputWeightsTensor,
                 inputToForgetWeightsTensor,
                 inputToCellWeightsTensor,
                 inputToOutputWeightsTensor,
                 recurrentToInputWeightsTensor,
                 recurrentToForgetWeightsTensor,
                 recurrentToCellWeightsTensor,
                 recurrentToOutputWeightsTensor,
                 cellToInputWeightsTensor,
                 cellToForgetWeightsTensor,
                 cellToOutputWeightsTensor,
                 inputGateBiasTensor,
                 forgetGateBiasTensor,
                 cellBiasTensor,
                 outputGateBiasTensor,
                 projectionWeightsTensor,
                 projectionBiasTensor,
                 inputLayerNormWeights,
                 forgetLayerNormWeights,
                 cellLayerNormWeights,
                 outputLayerNormWeights,
                 inputGateScratch,
                 cellScratch,
                 forgetGateScratch,
                 outputGateScratch,
                 inputGateScratchDecoder,
                 cellScratchDecoder,
                 forgetGateScratchDecoder,
                 outputGateScratchDecoder,
                 m_LayerNormEpsilon);

        currentInputData += batchInputSize;
        inputData = MakeDecoder<float>(lstmInputInfo, currentInputData);
        currentOutputData += batchOutputSize;
        output = MakeEncoder<float>(lstmOutputInfo, currentOutputData);
        outputDecoder = MakeDecoder<float>(lstmOutputInfo, currentOutputData);

        // Assign output state out to the next output state in
        outputStateIn = MakeDecoder<float>(outputStateInfo, outputStateOutData);

        // Assign cell state out to the next cell state in
        cellStateIn = MakeDecoder<float>(cellStateInfo, cellStateOutData);
    }

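    // If the caller supplied batch-major tensors, permute the accumulated output back from
    // time-major to batch-major before returning.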
    if (!m_Data.m_Parameters.m_TimeMajor)
    {
        // Permute Output back to batch major
        const PermutationVector& mappings = {1U, 0U, 2U};
        auto outputData = reinterpret_cast<float*>(outputs[0]->Map());
        std::vector<float> outputValue(outputData, outputData + outputInfo.GetNumElements());
        outputShape = armnnUtils::Permuted(outputInfo.GetShape(), mappings);
        outputInfo.SetShape(outputShape);
        armnnUtils::Permute(outputShape, mappings, outputValue.data(), outputData, sizeof(float));
    }
}

} //namespace armnn