ArmNN 22.11
RefLstmWorkload.cpp
//
// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include "RefLstmWorkload.hpp"
#include "Activation.hpp"
#include "Encoders.hpp"
#include "Decoders.hpp"
#include "Lstm.hpp"
#include "LstmUtils.hpp"
#include "RefWorkloadUtils.hpp"

namespace armnn
{

RefLstmWorkload::RefLstmWorkload(const LstmQueueDescriptor& descriptor, const WorkloadInfo& info)
    : RefBaseWorkload<LstmQueueDescriptor>(descriptor, info)
    , m_InputToInputWeightsTensor      (AssignScopedTensorHandle(descriptor.m_InputToInputWeights))
    , m_InputToForgetWeightsTensor     (AssignScopedTensorHandle(descriptor.m_InputToForgetWeights))
    , m_InputToCellWeightsTensor       (AssignScopedTensorHandle(descriptor.m_InputToCellWeights))
    , m_InputToOutputWeightsTensor     (AssignScopedTensorHandle(descriptor.m_InputToOutputWeights))
    , m_RecurrentToInputWeightsTensor  (AssignScopedTensorHandle(descriptor.m_RecurrentToInputWeights))
    , m_RecurrentToForgetWeightsTensor (AssignScopedTensorHandle(descriptor.m_RecurrentToForgetWeights))
    , m_RecurrentToCellWeightsTensor   (AssignScopedTensorHandle(descriptor.m_RecurrentToCellWeights))
    , m_RecurrentToOutputWeightsTensor (AssignScopedTensorHandle(descriptor.m_RecurrentToOutputWeights))
    , m_CellToInputWeightsTensor       (AssignScopedTensorHandle(descriptor.m_CellToInputWeights))
    , m_CellToForgetWeightsTensor      (AssignScopedTensorHandle(descriptor.m_CellToForgetWeights))
    , m_CellToOutputWeightsTensor      (AssignScopedTensorHandle(descriptor.m_CellToOutputWeights))
    , m_InputGateBiasTensor            (AssignScopedTensorHandle(descriptor.m_InputGateBias))
    , m_ForgetGateBiasTensor           (AssignScopedTensorHandle(descriptor.m_ForgetGateBias))
    , m_CellBiasTensor                 (AssignScopedTensorHandle(descriptor.m_CellBias))
    , m_OutputGateBiasTensor           (AssignScopedTensorHandle(descriptor.m_OutputGateBias))
    , m_ProjectionWeightsTensor        (AssignScopedTensorHandle(descriptor.m_ProjectionWeights))
    , m_ProjectionBiasTensor           (AssignScopedTensorHandle(descriptor.m_ProjectionBias))
    , m_InputLayerNormWeights          (AssignScopedTensorHandle(descriptor.m_InputLayerNormWeights))
    , m_ForgetLayerNormWeights         (AssignScopedTensorHandle(descriptor.m_ForgetLayerNormWeights))
    , m_CellLayerNormWeights           (AssignScopedTensorHandle(descriptor.m_CellLayerNormWeights))
    , m_OutputLayerNormWeights         (AssignScopedTensorHandle(descriptor.m_OutputLayerNormWeights))
{}

void RefLstmWorkload::Execute() const
{
    Execute(m_Data.m_Inputs, m_Data.m_Outputs);
}

void RefLstmWorkload::ExecuteAsync(ExecutionData& executionData)
{
    WorkingMemDescriptor* workingMemDescriptor = static_cast<WorkingMemDescriptor*>(executionData.m_Data);
    Execute(workingMemDescriptor->m_Inputs, workingMemDescriptor->m_Outputs);
}

void RefLstmWorkload::Execute(std::vector<ITensorHandle*> inputs, std::vector<ITensorHandle*> outputs) const
{
    // This is a port of the LSTM::Eval() method in the Android code base.
    // Refer to: android/frameworks/ml/nn/common/operations/LSTM.cpp

    const TensorInfo& inputInfo = GetTensorInfo(inputs[0]);
    const TensorInfo& outputInfo = GetTensorInfo(outputs[0]);

    const TensorShape& inputShape = inputInfo.GetShape();

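    // Tensor layout used by this workload (see the encoder/decoder creation below):
    //   inputs[0]  - input,           inputs[1]  - output state in,   inputs[2] - cell state in
    //   outputs[0] - scratch buffer,  outputs[1] - output state out,
    //   outputs[2] - cell state out,  outputs[3] - output
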
    std::unique_ptr<Encoder<float>> outputStateOut = MakeEncoder<float>(outputInfo, outputs[1]->Map());
    std::unique_ptr<Encoder<float>> cellStateOut   = MakeEncoder<float>(outputInfo, outputs[2]->Map());
    std::unique_ptr<Encoder<float>> output         = MakeEncoder<float>(outputInfo, outputs[3]->Map());

    std::unique_ptr<Decoder<float>> cellStateOutDecoder = MakeDecoder<float>(outputInfo, outputs[2]->Map());
    std::unique_ptr<Decoder<float>> outputDecoder       = MakeDecoder<float>(outputInfo, outputs[3]->Map());

    std::unique_ptr<Decoder<float>> inputData     = MakeDecoder<float>(inputInfo, inputs[0]->Map());
    std::unique_ptr<Decoder<float>> outputStateIn = MakeDecoder<float>(inputInfo, inputs[1]->Map());
    std::unique_ptr<Decoder<float>> cellStateIn   = MakeDecoder<float>(inputInfo, inputs[2]->Map());

    const uint32_t nBatch = inputShape[0];
    const uint32_t nCell  = m_InputToOutputWeightsTensor->GetShape()[0];

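    // Optional LSTM variants selected by the layer parameters:
    //  - CIFG couples the input and forget gates, so there is no separate input gate.
    //  - Peephole adds cell-to-gate connections.
    //  - Layer normalisation normalises each gate's pre-activation values.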
    const bool useCifg      = m_Data.m_Parameters.m_CifgEnabled;
    const bool usePeephole  = m_Data.m_Parameters.m_PeepholeEnabled;
    const bool useLayerNorm = m_Data.m_Parameters.m_LayerNormEnabled;

    // Index the scratch buffer pointers into the global scratch buffer.
    std::unique_ptr<Encoder<float>> inputGateScratch  = MakeEncoder<float>(outputInfo, outputs[0]->Map());
    std::unique_ptr<Encoder<float>> cellScratch       = MakeEncoder<float>(outputInfo, outputs[0]->Map());
    std::unique_ptr<Encoder<float>> forgetGateScratch = MakeEncoder<float>(outputInfo, outputs[0]->Map());
    std::unique_ptr<Encoder<float>> outputGateScratch = MakeEncoder<float>(outputInfo, outputs[0]->Map());

    std::unique_ptr<Decoder<float>> inputGateScratchDecoder  = MakeDecoder<float>(outputInfo, outputs[0]->Map());
    std::unique_ptr<Decoder<float>> cellScratchDecoder       = MakeDecoder<float>(outputInfo, outputs[0]->Map());
    std::unique_ptr<Decoder<float>> forgetGateScratchDecoder = MakeDecoder<float>(outputInfo, outputs[0]->Map());
    std::unique_ptr<Decoder<float>> outputGateScratchDecoder = MakeDecoder<float>(outputInfo, outputs[0]->Map());

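    // outputs[0] is one contiguous scratch buffer holding an nCell * nBatch block per gate.
    // Advance each encoder/decoder to its own block; with CIFG there is no input-gate block,
    // so only three blocks are used.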
    if (useCifg)
    {
        *cellScratch       += (0 * nCell * nBatch);
        *forgetGateScratch += (1 * nCell * nBatch);
        *outputGateScratch += (2 * nCell * nBatch);

        *cellScratchDecoder       += (0 * nCell * nBatch);
        *forgetGateScratchDecoder += (1 * nCell * nBatch);
        *outputGateScratchDecoder += (2 * nCell * nBatch);
    }
    else
    {
        *inputGateScratch  += (0 * nCell * nBatch);
        *cellScratch       += (1 * nCell * nBatch);
        *forgetGateScratch += (2 * nCell * nBatch);
        *outputGateScratch += (3 * nCell * nBatch);

        *inputGateScratchDecoder  += (0 * nCell * nBatch);
        *cellScratchDecoder       += (1 * nCell * nBatch);
        *forgetGateScratchDecoder += (2 * nCell * nBatch);
        *outputGateScratchDecoder += (3 * nCell * nBatch);
    }

    std::unique_ptr<Decoder<float>> inputToInputWeightsTensor;
    std::unique_ptr<Decoder<float>> inputToForgetWeightsTensor = MakeDecoder<float>(
        m_InputToForgetWeightsTensor->GetTensorInfo(), m_InputToForgetWeightsTensor->GetConstTensor<void>());
    std::unique_ptr<Decoder<float>> inputToCellWeightsTensor = MakeDecoder<float>(
        m_InputToCellWeightsTensor->GetTensorInfo(), m_InputToCellWeightsTensor->GetConstTensor<void>());
    std::unique_ptr<Decoder<float>> inputToOutputWeightsTensor = MakeDecoder<float>(
        m_InputToOutputWeightsTensor->GetTensorInfo(), m_InputToOutputWeightsTensor->GetConstTensor<void>());

    std::unique_ptr<Decoder<float>> recurrentToInputWeightsTensor;
    std::unique_ptr<Decoder<float>> recurrentToForgetWeightsTensor = MakeDecoder<float>(
        m_RecurrentToForgetWeightsTensor->GetTensorInfo(), m_RecurrentToForgetWeightsTensor->GetConstTensor<void>());
    std::unique_ptr<Decoder<float>> recurrentToCellWeightsTensor = MakeDecoder<float>(
        m_RecurrentToCellWeightsTensor->GetTensorInfo(), m_RecurrentToCellWeightsTensor->GetConstTensor<void>());
    std::unique_ptr<Decoder<float>> recurrentToOutputWeightsTensor = MakeDecoder<float>(
        m_RecurrentToOutputWeightsTensor->GetTensorInfo(), m_RecurrentToOutputWeightsTensor->GetConstTensor<void>());

    std::unique_ptr<Decoder<float>> inputGateBiasTensor;
    std::unique_ptr<Decoder<float>> forgetGateBiasTensor = MakeDecoder<float>(
        m_ForgetGateBiasTensor->GetTensorInfo(), m_ForgetGateBiasTensor->GetConstTensor<void>());
    std::unique_ptr<Decoder<float>> cellBiasTensor = MakeDecoder<float>(
        m_CellBiasTensor->GetTensorInfo(), m_CellBiasTensor->GetConstTensor<void>());
    std::unique_ptr<Decoder<float>> outputGateBiasTensor = MakeDecoder<float>(
        m_OutputGateBiasTensor->GetTensorInfo(), m_OutputGateBiasTensor->GetConstTensor<void>());

    std::unique_ptr<Decoder<float>> cellToInputWeightsTensor;
    std::unique_ptr<Decoder<float>> cellToForgetWeightsTensor;
    std::unique_ptr<Decoder<float>> cellToOutputWeightsTensor;

    std::unique_ptr<Decoder<float>> projectionWeightsTensor;
    std::unique_ptr<Decoder<float>> projectionBiasTensor;

    std::unique_ptr<Decoder<float>> inputLayerNormWeights;
    std::unique_ptr<Decoder<float>> forgetLayerNormWeights;
    std::unique_ptr<Decoder<float>> cellLayerNormWeights;
    std::unique_ptr<Decoder<float>> outputLayerNormWeights;

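    // The optional decoders above remain null unless the corresponding feature
    // (CIFG disabled, peephole, layer normalisation or projection) is enabled below.
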
    const TensorShape& inputToOutputWeightsShape     = m_InputToOutputWeightsTensor->GetShape();
    const TensorShape& recurrentToOutputWeightsShape = m_RecurrentToOutputWeightsTensor->GetShape();

    if (useLayerNorm)
    {
        if (!useCifg)
        {
            inputLayerNormWeights = MakeDecoder<float>(
                m_InputLayerNormWeights->GetTensorInfo(), m_InputLayerNormWeights->GetConstTensor<void>());
        }
        forgetLayerNormWeights = MakeDecoder<float>(
            m_ForgetLayerNormWeights->GetTensorInfo(), m_ForgetLayerNormWeights->GetConstTensor<void>());
        cellLayerNormWeights = MakeDecoder<float>(
            m_CellLayerNormWeights->GetTensorInfo(), m_CellLayerNormWeights->GetConstTensor<void>());
        outputLayerNormWeights = MakeDecoder<float>(
            m_OutputLayerNormWeights->GetTensorInfo(), m_OutputLayerNormWeights->GetConstTensor<void>());
    }

    if (!useCifg)
    {
        inputToInputWeightsTensor = MakeDecoder<float>(
            m_InputToInputWeightsTensor->GetTensorInfo(), m_InputToInputWeightsTensor->GetConstTensor<void>());
        inputGateBiasTensor = MakeDecoder<float>(
            m_InputGateBiasTensor->GetTensorInfo(), m_InputGateBiasTensor->GetConstTensor<void>());
        recurrentToInputWeightsTensor = MakeDecoder<float>(
            m_RecurrentToInputWeightsTensor->GetTensorInfo(), m_RecurrentToInputWeightsTensor->GetConstTensor<void>());
    }

    if (usePeephole)
    {
        cellToForgetWeightsTensor = MakeDecoder<float>(
            m_CellToForgetWeightsTensor->GetTensorInfo(), m_CellToForgetWeightsTensor->GetConstTensor<void>());
        cellToOutputWeightsTensor = MakeDecoder<float>(
            m_CellToOutputWeightsTensor->GetTensorInfo(), m_CellToOutputWeightsTensor->GetConstTensor<void>());
    }

    if (!useCifg && usePeephole)
    {
        cellToInputWeightsTensor = MakeDecoder<float>(
            m_CellToInputWeightsTensor->GetTensorInfo(), m_CellToInputWeightsTensor->GetConstTensor<void>());
    }

    if (m_Data.m_Parameters.m_ProjectionEnabled)
    {
        projectionWeightsTensor = MakeDecoder<float>(
            m_ProjectionWeightsTensor->GetTensorInfo(), m_ProjectionWeightsTensor->GetConstTensor<void>());
        if (m_ProjectionBiasTensor)
        {
            projectionBiasTensor = MakeDecoder<float>(
                m_ProjectionBiasTensor->GetTensorInfo(), m_ProjectionBiasTensor->GetConstTensor<void>());
        }
    }

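    // The gate and state computations are performed by LstmImpl (defined in Lstm.cpp).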
    LstmImpl(m_Data.m_Parameters,
             inputInfo,
             outputInfo,
             inputToOutputWeightsShape,
             recurrentToOutputWeightsShape,
             inputData,
             outputStateIn,
             cellStateIn,
             outputStateOut,
             cellStateOut,
             output,
             cellStateOutDecoder,
             outputDecoder,
             inputToInputWeightsTensor,
             inputToForgetWeightsTensor,
             inputToCellWeightsTensor,
             inputToOutputWeightsTensor,
             recurrentToInputWeightsTensor,
             recurrentToForgetWeightsTensor,
             recurrentToCellWeightsTensor,
             recurrentToOutputWeightsTensor,
             cellToInputWeightsTensor,
             cellToForgetWeightsTensor,
             cellToOutputWeightsTensor,
             inputGateBiasTensor,
             forgetGateBiasTensor,
             cellBiasTensor,
             outputGateBiasTensor,
             projectionWeightsTensor,
             projectionBiasTensor,
             inputLayerNormWeights,
             forgetLayerNormWeights,
             cellLayerNormWeights,
             outputLayerNormWeights,
             inputGateScratch,
             cellScratch,
             forgetGateScratch,
             outputGateScratch,
             inputGateScratchDecoder,
             cellScratchDecoder,
             forgetGateScratchDecoder,
             outputGateScratchDecoder,
             m_LayerNormEpsilon);
}

} //namespace armnn
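
Note on the scratch-buffer layout: the "+= n * nCell * nBatch" offsets in Execute() assume that outputs[0] concatenates one nCell * nBatch block per gate (input, cell, forget, output), dropping the input-gate block when CIFG is enabled, which implies a scratch buffer of nBatch * nCell * (3 or 4) elements. The standalone sketch below is not ArmNN code; names such as ScratchOffsets and MakeScratchOffsets are illustrative only, and it simply reproduces the offset arithmetic used above.

#include <cstdint>
#include <iostream>

// Illustrative helper, not an ArmNN API: element offsets of each gate's block inside
// the scratch buffer, matching the offsets applied in RefLstmWorkload::Execute().
struct ScratchOffsets
{
    uint32_t inputGate;   // unused when cifgEnabled is true
    uint32_t cellGate;
    uint32_t forgetGate;
    uint32_t outputGate;
};

ScratchOffsets MakeScratchOffsets(uint32_t nBatch, uint32_t nCell, bool cifgEnabled)
{
    const uint32_t block = nCell * nBatch;
    if (cifgEnabled)
    {
        // No input-gate block: cell, forget and output occupy blocks 0..2.
        return { 0u, 0u * block, 1u * block, 2u * block };
    }
    return { 0u * block, 1u * block, 2u * block, 3u * block };
}

int main()
{
    // nBatch = 2, nCell = 4, CIFG disabled -> blocks start at elements 0, 8, 16 and 24.
    const ScratchOffsets offsets = MakeScratchOffsets(2, 4, false);
    std::cout << offsets.inputGate << " " << offsets.cellGate << " "
              << offsets.forgetGate << " " << offsets.outputGate << "\n";
    return 0;
}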