RefUnidirectionalSequenceLstmWorkload.cpp
//
// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include "RefUnidirectionalSequenceLstmWorkload.hpp"
#include "Activation.hpp"
#include "Encoders.hpp"
#include "Decoders.hpp"
#include "Lstm.hpp"
#include "LstmUtils.hpp"
#include "RefWorkloadUtils.hpp"

#include <armnnUtils/Permute.hpp>

namespace armnn
{

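// The constructor copies each weight/bias ConstTensorHandle from the queue descriptor into a
// ScopedTensorHandle owned by the workload. Optional tensors (CIFG, peephole, projection and
// layer-norm weights) that were not supplied stay null and are only dereferenced below when the
// matching feature flag is enabled.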
RefUnidirectionalSequenceLstmWorkload::RefUnidirectionalSequenceLstmWorkload(
    const UnidirectionalSequenceLstmQueueDescriptor& descriptor,
    const WorkloadInfo& info)
    : RefBaseWorkload<UnidirectionalSequenceLstmQueueDescriptor>(descriptor, info)
    , m_InputToInputWeightsTensor     (AssignScopedTensorHandle(descriptor.m_InputToInputWeights))
    , m_InputToForgetWeightsTensor    (AssignScopedTensorHandle(descriptor.m_InputToForgetWeights))
    , m_InputToCellWeightsTensor      (AssignScopedTensorHandle(descriptor.m_InputToCellWeights))
    , m_InputToOutputWeightsTensor    (AssignScopedTensorHandle(descriptor.m_InputToOutputWeights))
    , m_RecurrentToInputWeightsTensor (AssignScopedTensorHandle(descriptor.m_RecurrentToInputWeights))
    , m_RecurrentToForgetWeightsTensor(AssignScopedTensorHandle(descriptor.m_RecurrentToForgetWeights))
    , m_RecurrentToCellWeightsTensor  (AssignScopedTensorHandle(descriptor.m_RecurrentToCellWeights))
    , m_RecurrentToOutputWeightsTensor(AssignScopedTensorHandle(descriptor.m_RecurrentToOutputWeights))
    , m_CellToInputWeightsTensor      (AssignScopedTensorHandle(descriptor.m_CellToInputWeights))
    , m_CellToForgetWeightsTensor     (AssignScopedTensorHandle(descriptor.m_CellToForgetWeights))
    , m_CellToOutputWeightsTensor     (AssignScopedTensorHandle(descriptor.m_CellToOutputWeights))
    , m_InputGateBiasTensor           (AssignScopedTensorHandle(descriptor.m_InputGateBias))
    , m_ForgetGateBiasTensor          (AssignScopedTensorHandle(descriptor.m_ForgetGateBias))
    , m_CellBiasTensor                (AssignScopedTensorHandle(descriptor.m_CellBias))
    , m_OutputGateBiasTensor          (AssignScopedTensorHandle(descriptor.m_OutputGateBias))
    , m_ProjectionWeightsTensor       (AssignScopedTensorHandle(descriptor.m_ProjectionWeights))
    , m_ProjectionBiasTensor          (AssignScopedTensorHandle(descriptor.m_ProjectionBias))
    , m_InputLayerNormWeights         (AssignScopedTensorHandle(descriptor.m_InputLayerNormWeights))
    , m_ForgetLayerNormWeights        (AssignScopedTensorHandle(descriptor.m_ForgetLayerNormWeights))
    , m_CellLayerNormWeights          (AssignScopedTensorHandle(descriptor.m_CellLayerNormWeights))
    , m_OutputLayerNormWeights        (AssignScopedTensorHandle(descriptor.m_OutputLayerNormWeights))
{}

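// Execute() runs on the tensor handles recorded in m_Data; ExecuteAsync() takes them from the
// caller-supplied working memory instead. Both forward to the three-argument Execute() below.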
void RefUnidirectionalSequenceLstmWorkload::Execute() const
{
    Execute(m_Data.m_Inputs, m_Data.m_Outputs);
}

void RefUnidirectionalSequenceLstmWorkload::ExecuteAsync(ExecutionData& executionData)
{
    WorkingMemDescriptor* workingMemDescriptor = static_cast<WorkingMemDescriptor*>(executionData.m_Data);
    Execute(workingMemDescriptor->m_Inputs, workingMemDescriptor->m_Outputs);
}

void RefUnidirectionalSequenceLstmWorkload::Execute(std::vector<ITensorHandle*> inputs,
                                                    std::vector<ITensorHandle*> outputs) const
{
    TensorInfo inputInfo = GetTensorInfo(inputs[0]);
    const TensorInfo& outputStateInfo = GetTensorInfo(inputs[1]);
    const TensorInfo& cellStateInfo = GetTensorInfo(inputs[2]);
    TensorInfo outputStateOutInfo = GetTensorInfo(outputs[0]);
    TensorInfo cellStateOutInfo = GetTensorInfo(outputs[1]);
    TensorInfo outputInfo = GetTensorInfo(outputs[2]);
    TensorShape& inputShape = inputInfo.GetShape();
    TensorShape& outputShape = outputInfo.GetShape();
    auto inputTensor = reinterpret_cast<float*>(inputs[0]->Map());

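    // The reference LSTM processes the sequence in time-major order ([maxTime, batchSize, features]).
    // Batch-major input is permuted in place here, and the output is permuted back at the end.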
    if (!m_Data.m_Parameters.m_TimeMajor)
    {
        // Permute to time major
        const PermutationVector& mappings = {1U, 0U, 2U};
        std::vector<float> inputValue(inputTensor, inputTensor + inputInfo.GetNumElements());
        inputShape = armnnUtils::Permuted(inputInfo.GetShape(), mappings);
        inputInfo.SetShape(inputShape);
        armnnUtils::Permute(inputShape, mappings, inputValue.data(), inputTensor, sizeof(float));

        outputShape = armnnUtils::Permuted(outputInfo.GetShape(), mappings);
        outputInfo.SetShape(outputShape);
    }
    unsigned int maxTime = inputShape[0];
    unsigned int batchSize = inputShape[1];
    unsigned int outputSize = outputShape[2];
    unsigned int inputSize = inputShape[2];

    TensorInfo scratchInfo = outputInfo;
    scratchInfo.SetShape({batchSize, cellStateInfo.GetShape()[1]});

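    // One [batchSize, numUnits] scratch buffer per gate (numUnits taken from the cell state shape).
    // The input gate scratch buffer is only allocated further down, when CIFG is disabled.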
    std::vector<float> inputGateScratchBuffer;
    std::vector<float> cellScratchBuffer(scratchInfo.GetNumElements(), 0.);
    std::vector<float> forgetGateScratchBuffer(scratchInfo.GetNumElements(), 0.);
    std::vector<float> outputGateScratchBuffer(scratchInfo.GetNumElements(), 0.);

    std::vector<float> outputStateOutBuffer(outputStateInfo.GetNumElements(), 0.);
    std::vector<float> cellStateOutBuffer(cellStateInfo.GetNumElements(), 0.);

    void* outputStateOutData = outputStateOutBuffer.data();
    void* cellStateOutData = cellStateOutBuffer.data();

    std::unique_ptr<Encoder<float>> inputGateScratch;
    std::unique_ptr<Encoder<float>> cellScratch = MakeEncoder<float>(scratchInfo, cellScratchBuffer.data());
    std::unique_ptr<Encoder<float>> forgetGateScratch = MakeEncoder<float>(scratchInfo, forgetGateScratchBuffer.data());
    std::unique_ptr<Encoder<float>> outputGateScratch = MakeEncoder<float>(scratchInfo, outputGateScratchBuffer.data());

    std::unique_ptr<Decoder<float>> inputGateScratchDecoder;
    std::unique_ptr<Decoder<float>> cellScratchDecoder = MakeDecoder<float>(scratchInfo, cellScratchBuffer.data());
    std::unique_ptr<Decoder<float>> forgetGateScratchDecoder = MakeDecoder<float>(scratchInfo,
                                                                                  forgetGateScratchBuffer.data());
    std::unique_ptr<Decoder<float>> outputGateScratchDecoder = MakeDecoder<float>(scratchInfo,
                                                                                  outputGateScratchBuffer.data());

    const bool useCifg = m_Data.m_Parameters.m_CifgEnabled;
    const bool usePeephole = m_Data.m_Parameters.m_PeepholeEnabled;
    const bool useLayerNorm = m_Data.m_Parameters.m_LayerNormEnabled;

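    // With CIFG (coupled input and forget gate) the input gate is derived from the forget gate,
    // so its scratch encoder/decoder is only needed when CIFG is disabled.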
    if (!useCifg)
    {
        inputGateScratchBuffer.resize(scratchInfo.GetNumElements(), 0.);
        inputGateScratch = MakeEncoder<float>(scratchInfo, inputGateScratchBuffer.data());
        inputGateScratchDecoder = MakeDecoder<float>(scratchInfo, inputGateScratchBuffer.data());
    }

    std::unique_ptr<Encoder<float>> outputStateOut = MakeEncoder<float>(outputStateInfo, outputStateOutData);
    std::unique_ptr<Encoder<float>> cellStateOut = MakeEncoder<float>(cellStateInfo, cellStateOutData);
    std::unique_ptr<Decoder<float>> cellStateOutDecoder = MakeDecoder<float>(cellStateInfo, cellStateOutData);

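    // The LSTM core works on one time slice at a time, so build 2D per-step TensorInfos of shape
    // [batchSize, inputSize] for the input and [batchSize, outputSize] for the output.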
    TensorInfo lstmInputInfo = inputInfo;
    TensorShape batchInputShape = TensorShape({batchSize, inputSize});
    lstmInputInfo.SetShape(batchInputShape);

    TensorInfo lstmOutputInfo = outputInfo;
    lstmOutputInfo.SetShape({batchSize, outputSize});

    const TensorShape& inputToOutputWeightsShape = m_InputToOutputWeightsTensor->GetShape();
    const TensorShape& recurrentToOutputWeightsShape = m_RecurrentToOutputWeightsTensor->GetShape();
    unsigned int nOutput = recurrentToOutputWeightsShape[1];
    auto outputStateInData = inputs[1]->Map();
    std::unique_ptr<Decoder<float>> outputStateIn = MakeDecoder<float>(outputStateInfo, outputStateInData);

    auto cellStateInData = inputs[2]->Map();
    std::unique_ptr<Decoder<float>> cellStateIn = MakeDecoder<float>(cellStateInfo, cellStateInData);

    auto currentInputData = reinterpret_cast<float*>(inputs[0]->Map());
    std::unique_ptr<Decoder<float>> inputData = MakeDecoder<float>(lstmInputInfo, currentInputData);
    auto currentOutputData = reinterpret_cast<float*>(outputs[2]->Map());
    std::unique_ptr<Encoder<float>> output = MakeEncoder<float>(lstmOutputInfo, currentOutputData);
    std::unique_ptr<Decoder<float>> outputDecoder = MakeDecoder<float>(lstmOutputInfo, currentOutputData);

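    // Wrap the constant weights and biases in Decoders so LstmImpl can read them as float regardless
    // of the stored data type. Decoders for optional tensors are only created when the corresponding
    // feature (CIFG disabled, peephole, projection, layer normalization) is actually in use.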
    std::unique_ptr<Decoder<float>> inputToInputWeightsTensor;
    std::unique_ptr<Decoder<float>> inputToForgetWeightsTensor = MakeDecoder<float>(
        m_InputToForgetWeightsTensor->GetTensorInfo(), m_InputToForgetWeightsTensor->GetConstTensor<void>());
    std::unique_ptr<Decoder<float>> inputToCellWeightsTensor = MakeDecoder<float>(
        m_InputToCellWeightsTensor->GetTensorInfo(), m_InputToCellWeightsTensor->GetConstTensor<void>());
    std::unique_ptr<Decoder<float>> inputToOutputWeightsTensor = MakeDecoder<float>(
        m_InputToOutputWeightsTensor->GetTensorInfo(), m_InputToOutputWeightsTensor->GetConstTensor<void>());

    std::unique_ptr<Decoder<float>> recurrentToInputWeightsTensor;
    std::unique_ptr<Decoder<float>> recurrentToForgetWeightsTensor = MakeDecoder<float>(
        m_RecurrentToForgetWeightsTensor->GetTensorInfo(), m_RecurrentToForgetWeightsTensor->GetConstTensor<void>());
    std::unique_ptr<Decoder<float>> recurrentToCellWeightsTensor = MakeDecoder<float>(
        m_RecurrentToCellWeightsTensor->GetTensorInfo(), m_RecurrentToCellWeightsTensor->GetConstTensor<void>());
    std::unique_ptr<Decoder<float>> recurrentToOutputWeightsTensor = MakeDecoder<float>(
        m_RecurrentToOutputWeightsTensor->GetTensorInfo(), m_RecurrentToOutputWeightsTensor->GetConstTensor<void>());

    std::unique_ptr<Decoder<float>> inputGateBiasTensor;
    std::unique_ptr<Decoder<float>> forgetGateBiasTensor = MakeDecoder<float>(
        m_ForgetGateBiasTensor->GetTensorInfo(), m_ForgetGateBiasTensor->GetConstTensor<void>());
    std::unique_ptr<Decoder<float>> cellBiasTensor = MakeDecoder<float>(
        m_CellBiasTensor->GetTensorInfo(), m_CellBiasTensor->GetConstTensor<void>());
    std::unique_ptr<Decoder<float>> outputGateBiasTensor = MakeDecoder<float>(
        m_OutputGateBiasTensor->GetTensorInfo(), m_OutputGateBiasTensor->GetConstTensor<void>());

    std::unique_ptr<Decoder<float>> cellToInputWeightsTensor;
    std::unique_ptr<Decoder<float>> cellToForgetWeightsTensor;
    std::unique_ptr<Decoder<float>> cellToOutputWeightsTensor;

    std::unique_ptr<Decoder<float>> projectionWeightsTensor;
    std::unique_ptr<Decoder<float>> projectionBiasTensor;

    std::unique_ptr<Decoder<float>> inputLayerNormWeights;
    std::unique_ptr<Decoder<float>> forgetLayerNormWeights;
    std::unique_ptr<Decoder<float>> cellLayerNormWeights;
    std::unique_ptr<Decoder<float>> outputLayerNormWeights;

    if (useLayerNorm)
    {
        if (!useCifg)
        {
            inputLayerNormWeights = MakeDecoder<float>(
                m_InputLayerNormWeights->GetTensorInfo(), m_InputLayerNormWeights->GetConstTensor<void>());
        }
        forgetLayerNormWeights = MakeDecoder<float>(
            m_ForgetLayerNormWeights->GetTensorInfo(), m_ForgetLayerNormWeights->GetConstTensor<void>());
        cellLayerNormWeights = MakeDecoder<float>(
            m_CellLayerNormWeights->GetTensorInfo(), m_CellLayerNormWeights->GetConstTensor<void>());
        outputLayerNormWeights = MakeDecoder<float>(
            m_OutputLayerNormWeights->GetTensorInfo(), m_OutputLayerNormWeights->GetConstTensor<void>());
    }

    if (!useCifg)
    {
        inputToInputWeightsTensor = MakeDecoder<float>(
            m_InputToInputWeightsTensor->GetTensorInfo(), m_InputToInputWeightsTensor->GetConstTensor<void>());
        inputGateBiasTensor = MakeDecoder<float>(
            m_InputGateBiasTensor->GetTensorInfo(), m_InputGateBiasTensor->GetConstTensor<void>());
        recurrentToInputWeightsTensor = MakeDecoder<float>(
            m_RecurrentToInputWeightsTensor->GetTensorInfo(), m_RecurrentToInputWeightsTensor->GetConstTensor<void>());
    }

    if (usePeephole)
    {
        cellToForgetWeightsTensor = MakeDecoder<float>(
            m_CellToForgetWeightsTensor->GetTensorInfo(), m_CellToForgetWeightsTensor->GetConstTensor<void>());
        cellToOutputWeightsTensor = MakeDecoder<float>(
            m_CellToOutputWeightsTensor->GetTensorInfo(), m_CellToOutputWeightsTensor->GetConstTensor<void>());
    }

    if (!useCifg && usePeephole)
    {
        cellToInputWeightsTensor = MakeDecoder<float>(
            m_CellToInputWeightsTensor->GetTensorInfo(), m_CellToInputWeightsTensor->GetConstTensor<void>());
    }

    if (m_Data.m_Parameters.m_ProjectionEnabled)
    {
        projectionWeightsTensor = MakeDecoder<float>(
            m_ProjectionWeightsTensor->GetTensorInfo(), m_ProjectionWeightsTensor->GetConstTensor<void>());
        if (m_ProjectionBiasTensor)
        {
            projectionBiasTensor = MakeDecoder<float>(
                m_ProjectionBiasTensor->GetTensorInfo(), m_ProjectionBiasTensor->GetConstTensor<void>());
        }
    }

    unsigned int batchInputSize = batchSize * inputSize;
    unsigned int batchOutputSize = batchSize * nOutput;

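    // Main sequence loop: run the single-step LSTM once per time step. After each step the
    // input/output pointers advance by one batch slice, and the freshly written output/cell state
    // buffers are re-wrapped as the next step's state inputs.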
    for (unsigned int t = 0; t < maxTime; ++t)
    {
        LstmImpl(m_Data.m_Parameters,
                 lstmInputInfo,
                 lstmOutputInfo,
                 inputToOutputWeightsShape,
                 recurrentToOutputWeightsShape,
                 inputData,
                 outputStateIn,
                 cellStateIn,
                 outputStateOut,
                 cellStateOut,
                 output,
                 cellStateOutDecoder,
                 outputDecoder,
                 inputToInputWeightsTensor,
                 inputToForgetWeightsTensor,
                 inputToCellWeightsTensor,
                 inputToOutputWeightsTensor,
                 recurrentToInputWeightsTensor,
                 recurrentToForgetWeightsTensor,
                 recurrentToCellWeightsTensor,
                 recurrentToOutputWeightsTensor,
                 cellToInputWeightsTensor,
                 cellToForgetWeightsTensor,
                 cellToOutputWeightsTensor,
                 inputGateBiasTensor,
                 forgetGateBiasTensor,
                 cellBiasTensor,
                 outputGateBiasTensor,
                 projectionWeightsTensor,
                 projectionBiasTensor,
                 inputLayerNormWeights,
                 forgetLayerNormWeights,
                 cellLayerNormWeights,
                 outputLayerNormWeights,
                 inputGateScratch,
                 cellScratch,
                 forgetGateScratch,
                 outputGateScratch,
                 inputGateScratchDecoder,
                 cellScratchDecoder,
                 forgetGateScratchDecoder,
                 outputGateScratchDecoder,
                 m_LayerNormEpsilon);

        currentInputData += batchInputSize;
        inputData = MakeDecoder<float>(lstmInputInfo, currentInputData);
        currentOutputData += batchOutputSize;
        output = MakeEncoder<float>(lstmOutputInfo, currentOutputData);
        outputDecoder = MakeDecoder<float>(lstmOutputInfo, currentOutputData);

        // Assign output state out to the next output state in
        outputStateIn = MakeDecoder<float>(outputStateInfo, outputStateOutData);

        // Assign cell state out to the next cell state in
        cellStateIn = MakeDecoder<float>(cellStateInfo, cellStateOutData);
    }

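    // For batch-major callers, permute the concatenated output back from time major to batch major.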
    if (!m_Data.m_Parameters.m_TimeMajor)
    {
        // Permute Output back to batch major
        const PermutationVector& mappings = {1U, 0U, 2U};
        auto outputData = reinterpret_cast<float*>(outputs[2]->Map());
        std::vector<float> outputValue(outputData, outputData + outputInfo.GetNumElements());
        outputShape = armnnUtils::Permuted(outputInfo.GetShape(), mappings);
        outputInfo.SetShape(outputShape);
        armnnUtils::Permute(outputShape, mappings, outputValue.data(), outputData, sizeof(float));
    }
}

} //namespace armnn