ArmNN
 20.08
ClQLstmWorkload.cpp
Go to the documentation of this file.
1 //
2 // Copyright © 2020 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include "ClQLstmWorkload.hpp"
7 #include "ClWorkloadUtils.hpp"
8 
10 
11 #include "cl/ClTensorHandle.hpp"
12 
13 namespace armnn
14 {
15 using namespace armcomputetensorutils;
16 
18  : BaseWorkload<QLstmQueueDescriptor>(descriptor, info)
19 {
20  arm_compute::LSTMParams<arm_compute::ICLTensor> qLstmParams;
21 
22  // Mandatory params
23  m_InputToForgetWeightsTensor = std::make_unique<arm_compute::CLTensor>();
24  BuildArmComputeTensor(*m_InputToForgetWeightsTensor, m_Data.m_InputToForgetWeights->GetTensorInfo());
25 
26  m_InputToCellWeightsTensor = std::make_unique<arm_compute::CLTensor>();
27  BuildArmComputeTensor(*m_InputToCellWeightsTensor, m_Data.m_InputToCellWeights->GetTensorInfo());
28 
29  m_InputToOutputWeightsTensor = std::make_unique<arm_compute::CLTensor>();
30  BuildArmComputeTensor(*m_InputToOutputWeightsTensor, m_Data.m_InputToOutputWeights->GetTensorInfo());
31 
32  m_RecurrentToForgetWeightsTensor = std::make_unique<arm_compute::CLTensor>();
33  BuildArmComputeTensor(*m_RecurrentToForgetWeightsTensor, m_Data.m_RecurrentToForgetWeights->GetTensorInfo());
34 
35  m_RecurrentToCellWeightsTensor = std::make_unique<arm_compute::CLTensor>();
36  BuildArmComputeTensor(*m_RecurrentToCellWeightsTensor, m_Data.m_RecurrentToCellWeights->GetTensorInfo());
37 
38  m_RecurrentToOutputWeightsTensor = std::make_unique<arm_compute::CLTensor>();
39  BuildArmComputeTensor(*m_RecurrentToOutputWeightsTensor, m_Data.m_RecurrentToOutputWeights->GetTensorInfo());
40 
41  m_ForgetGateBiasTensor = std::make_unique<arm_compute::CLTensor>();
42  BuildArmComputeTensor(*m_ForgetGateBiasTensor, m_Data.m_ForgetGateBias->GetTensorInfo());
43 
44  m_CellBiasTensor = std::make_unique<arm_compute::CLTensor>();
45  BuildArmComputeTensor(*m_CellBiasTensor, m_Data.m_CellBias->GetTensorInfo());
46 
47  m_OutputGateBiasTensor = std::make_unique<arm_compute::CLTensor>();
48  BuildArmComputeTensor(*m_OutputGateBiasTensor, m_Data.m_OutputGateBias->GetTensorInfo());
49 
50  // Create tensors for optional params if they are enabled
52  {
53  m_CellToInputWeightsTensor = std::make_unique<arm_compute::CLTensor>();
54 
56  {
57  // In ACL this is categorised as a CIFG param and not a Peephole param
58  BuildArmComputeTensor(*m_CellToInputWeightsTensor, m_Data.m_CellToInputWeights->GetTensorInfo());
59  }
60 
61  m_CellToForgetWeightsTensor = std::make_unique<arm_compute::CLTensor>();
62  BuildArmComputeTensor(*m_CellToForgetWeightsTensor, m_Data.m_CellToForgetWeights->GetTensorInfo());
63 
64  m_CellToOutputWeightsTensor = std::make_unique<arm_compute::CLTensor>();
65  BuildArmComputeTensor(*m_CellToOutputWeightsTensor, m_Data.m_CellToOutputWeights->GetTensorInfo());
66 
67  // Set Peephole params
68  qLstmParams.set_peephole_params(m_CellToForgetWeightsTensor.get(),
69  m_CellToOutputWeightsTensor.get());
70  }
71 
73  {
74  m_ProjectionWeightsTensor = std::make_unique<arm_compute::CLTensor>();
75  BuildArmComputeTensor(*m_ProjectionWeightsTensor, m_Data.m_ProjectionWeights->GetTensorInfo());
76 
77  m_ProjectionBiasTensor = std::make_unique<arm_compute::CLTensor>();
78  if (m_Data.m_ProjectionBias != nullptr)
79  {
80  BuildArmComputeTensor(*m_ProjectionBiasTensor, m_Data.m_ProjectionBias->GetTensorInfo());
81  }
82 
83  // Set projection params
84  qLstmParams.set_projection_params(
85  m_ProjectionWeightsTensor.get(),
86  m_Data.m_ProjectionBias != nullptr ? m_ProjectionBiasTensor.get() : nullptr);
87  }
88 
90  {
91  m_InputLayerNormWeightsTensor = std::make_unique<arm_compute::CLTensor>();
92 
94  {
95  BuildArmComputeTensor(*m_InputLayerNormWeightsTensor, m_Data.m_InputLayerNormWeights->GetTensorInfo());
96  }
97 
98  m_ForgetLayerNormWeightsTensor = std::make_unique<arm_compute::CLTensor>();
99  BuildArmComputeTensor(*m_ForgetLayerNormWeightsTensor, m_Data.m_ForgetLayerNormWeights->GetTensorInfo());
100 
101  m_CellLayerNormWeightsTensor = std::make_unique<arm_compute::CLTensor>();
102  BuildArmComputeTensor(*m_CellLayerNormWeightsTensor, m_Data.m_CellLayerNormWeights->GetTensorInfo());
103 
104  m_OutputLayerNormWeightsTensor = std::make_unique<arm_compute::CLTensor>();
105  BuildArmComputeTensor(*m_OutputLayerNormWeightsTensor, m_Data.m_OutputLayerNormWeights->GetTensorInfo());
106 
107  qLstmParams.set_layer_normalization_params(
108  m_Data.m_InputLayerNormWeights != nullptr ? m_InputLayerNormWeightsTensor.get() : nullptr,
109  m_ForgetLayerNormWeightsTensor.get(),
110  m_CellLayerNormWeightsTensor.get(),
111  m_OutputLayerNormWeightsTensor.get());
112  }
113 
115  {
116  m_InputToInputWeightsTensor = std::make_unique<arm_compute::CLTensor>();
117  BuildArmComputeTensor(*m_InputToInputWeightsTensor, m_Data.m_InputToInputWeights->GetTensorInfo());
118 
119  m_RecurrentToInputWeightsTensor = std::make_unique<arm_compute::CLTensor>();
120  BuildArmComputeTensor(*m_RecurrentToInputWeightsTensor, m_Data.m_RecurrentToInputWeights->GetTensorInfo());
121 
122  m_InputGateBiasTensor = std::make_unique<arm_compute::CLTensor>();
123  BuildArmComputeTensor(*m_InputGateBiasTensor, m_Data.m_InputGateBias->GetTensorInfo());
124 
125  qLstmParams.set_cifg_params(
126  m_InputToInputWeightsTensor.get(),
127  m_RecurrentToInputWeightsTensor.get(),
128  m_Data.m_CellToInputWeights != nullptr ? m_CellToInputWeightsTensor.get() : nullptr,
129  m_InputGateBiasTensor.get());
130  }
131 
132  // Input/Output tensors
133  const arm_compute::ICLTensor& input = static_cast<IClTensorHandle*>(m_Data.m_Inputs[0])->GetTensor();
134  const arm_compute::ICLTensor& outputStateIn = static_cast<IClTensorHandle*>(m_Data.m_Inputs[1])->GetTensor();
135  arm_compute::ICLTensor& cellStateIn = static_cast<IClTensorHandle*>(m_Data.m_Inputs[2])->GetTensor();
136 
137  arm_compute::ICLTensor& outputStateOut = static_cast<IClTensorHandle*>(m_Data.m_Outputs[0])->GetTensor();
138  arm_compute::ICLTensor& cellStateOut = static_cast<IClTensorHandle*>(m_Data.m_Outputs[1])->GetTensor();
139  arm_compute::ICLTensor& output = static_cast<IClTensorHandle*>(m_Data.m_Outputs[2])->GetTensor();
140 
141  // Set scalar descriptor params
142  qLstmParams.set_cell_clip_params(m_Data.m_Parameters.m_CellClip);
143  qLstmParams.set_projection_clip_params(m_Data.m_Parameters.m_ProjectionClip);
144  qLstmParams.set_hidden_state_params(m_Data.m_Parameters.m_HiddenStateZeroPoint,
146  qLstmParams.set_matmul_scale_params(m_Data.m_Parameters.m_InputIntermediateScale,
150 
151  m_QLstmLayer.configure(&input,
152  m_InputToForgetWeightsTensor.get(),
153  m_InputToCellWeightsTensor.get(),
154  m_InputToOutputWeightsTensor.get(),
155  m_RecurrentToForgetWeightsTensor.get(),
156  m_RecurrentToCellWeightsTensor.get(),
157  m_RecurrentToOutputWeightsTensor.get(),
158  m_ForgetGateBiasTensor.get(),
159  m_CellBiasTensor.get(),
160  m_OutputGateBiasTensor.get(),
161  &cellStateIn,
162  &outputStateIn,
163  &cellStateOut,
164  &outputStateOut,
165  &output,
166  qLstmParams);
167 
168  // InitializeArmComputeTensorData for mandatory params
169  InitializeArmComputeClTensorData(*m_InputToForgetWeightsTensor, m_Data.m_InputToForgetWeights);
170  InitializeArmComputeClTensorData(*m_InputToCellWeightsTensor, m_Data.m_InputToCellWeights);
171  InitializeArmComputeClTensorData(*m_InputToOutputWeightsTensor, m_Data.m_InputToOutputWeights);
172 
173  InitializeArmComputeClTensorData(*m_RecurrentToForgetWeightsTensor, m_Data.m_RecurrentToForgetWeights);
174  InitializeArmComputeClTensorData(*m_RecurrentToCellWeightsTensor, m_Data.m_RecurrentToCellWeights);
175  InitializeArmComputeClTensorData(*m_RecurrentToOutputWeightsTensor, m_Data.m_RecurrentToOutputWeights);
176 
177  InitializeArmComputeClTensorData(*m_ForgetGateBiasTensor, m_Data.m_ForgetGateBias);
179  InitializeArmComputeClTensorData(*m_OutputGateBiasTensor, m_Data.m_OutputGateBias);
180 
182  {
183  InitializeArmComputeClTensorData(*m_InputToInputWeightsTensor, m_Data.m_InputToInputWeights);
184  InitializeArmComputeClTensorData(*m_RecurrentToInputWeightsTensor, m_Data.m_RecurrentToInputWeights);
186  }
187 
189  {
190  InitializeArmComputeClTensorData(*m_ProjectionWeightsTensor, m_Data.m_ProjectionWeights);
191 
192  if (m_Data.m_ProjectionBias != nullptr)
193  {
194  InitializeArmComputeClTensorData(*m_ProjectionBiasTensor, m_Data.m_ProjectionBias);
195  }
196  }
197 
199  {
201  {
202  InitializeArmComputeClTensorData(*m_CellToInputWeightsTensor, m_Data.m_CellToInputWeights);
203  }
204 
205  InitializeArmComputeClTensorData(*m_CellToForgetWeightsTensor, m_Data.m_CellToForgetWeights);
206  InitializeArmComputeClTensorData(*m_CellToOutputWeightsTensor, m_Data.m_CellToOutputWeights);
207  }
208 
210  {
212  {
213  InitializeArmComputeClTensorData(*m_InputLayerNormWeightsTensor, m_Data.m_InputLayerNormWeights);
214  }
215  InitializeArmComputeClTensorData(*m_ForgetLayerNormWeightsTensor, m_Data.m_ForgetLayerNormWeights);
216  InitializeArmComputeClTensorData(*m_CellLayerNormWeightsTensor, m_Data.m_CellLayerNormWeights);
217  InitializeArmComputeClTensorData(*m_OutputLayerNormWeightsTensor, m_Data.m_OutputLayerNormWeights);
218  }
219 
220  m_QLstmLayer.prepare();
221 
222  FreeUnusedTensors();
223 }
224 
226 {
227  m_QLstmLayer.run();
228 }
229 
231  const TensorInfo& cellStateIn,
232  const TensorInfo& outputStateIn,
233  const TensorInfo& cellStateOut,
234  const TensorInfo& outputStateOut,
235  const TensorInfo& output,
236  const QLstmDescriptor& descriptor,
237  const LstmInputParamsInfo& paramsInfo)
238 {
239  arm_compute::LSTMParams<arm_compute::ITensorInfo> aclParamsInfo;
240 
241  // The inputs and outputs
242  const arm_compute::TensorInfo aclInputInfo = BuildArmComputeTensorInfo(input);
243  const arm_compute::TensorInfo aclOutputStateInInfo = BuildArmComputeTensorInfo(outputStateIn);
244  const arm_compute::TensorInfo aclCellStateInInfo = BuildArmComputeTensorInfo(cellStateIn);
245 
246  const arm_compute::TensorInfo aclOutputStateOutInfo = BuildArmComputeTensorInfo(outputStateOut);
247  const arm_compute::TensorInfo aclCellStateOutInfo = BuildArmComputeTensorInfo(cellStateOut);
248  const arm_compute::TensorInfo aclOutputInfo = BuildArmComputeTensorInfo(output);
249 
250  // Mandatory tensor info
251  const arm_compute::TensorInfo aclInputToForgetWeightsInfo
252  = BuildArmComputeTensorInfo(paramsInfo.GetInputToForgetWeights());
253  const arm_compute::TensorInfo aclInputToCellWeightsInfo
254  = BuildArmComputeTensorInfo(paramsInfo.GetInputToCellWeights());
255  const arm_compute::TensorInfo aclInputToOutputWeightsInfo
256  = BuildArmComputeTensorInfo(paramsInfo.GetInputToOutputWeights());
257  const arm_compute::TensorInfo aclRecurrentToForgetWeightsInfo
258  = BuildArmComputeTensorInfo(paramsInfo.GetRecurrentToForgetWeights());
259  const arm_compute::TensorInfo aclRecurrentToCellWeightsInfo
260  = BuildArmComputeTensorInfo(paramsInfo.GetRecurrentToCellWeights());
261  const arm_compute::TensorInfo aclRecurrentToOutputWeightsInfo
262  = BuildArmComputeTensorInfo(paramsInfo.GetRecurrentToOutputWeights());
263  const arm_compute::TensorInfo aclForgetGateBiasInfo
264  = BuildArmComputeTensorInfo(paramsInfo.GetForgetGateBias());
265  const arm_compute::TensorInfo aclCellBiasInfo
266  = BuildArmComputeTensorInfo(paramsInfo.GetCellBias());
267  const arm_compute::TensorInfo aclOutputGateBiasInfo
268  = BuildArmComputeTensorInfo(paramsInfo.GetOutputGateBias());
269 
270  // Optional tensor info
271  arm_compute::TensorInfo aclInputToInputWeightsInfo;
272  arm_compute::TensorInfo aclRecurrentToInputWeightsInfo;
273  arm_compute::TensorInfo aclCellToInputWeightsInfo;
274  arm_compute::TensorInfo aclCellToForgetWeightsInfo;
275  arm_compute::TensorInfo aclCellToOutputWeightsInfo;
276  arm_compute::TensorInfo aclInputGateBiasInfo;
277  arm_compute::TensorInfo aclProjectionWeightsInfo;
278  arm_compute::TensorInfo aclProjectionBiasInfo;
279  arm_compute::TensorInfo aclInputLayerNormWeightsInfo;
280  arm_compute::TensorInfo aclForgetLayerNormWeightsInfo;
281  arm_compute::TensorInfo aclCellLayerNormWeightsInfo;
282  arm_compute::TensorInfo aclOutputLayerNormWeightsInfo;
283 
284 
285  if (descriptor.m_PeepholeEnabled)
286  {
287  if (!descriptor.m_CifgEnabled)
288  {
289  aclCellToInputWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.GetCellToInputWeights());
290  }
291 
292  aclCellToForgetWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.GetCellToForgetWeights());
293  aclCellToOutputWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.GetCellToOutputWeights());
294 
295  aclParamsInfo.set_peephole_params(&aclCellToForgetWeightsInfo,
296  &aclCellToOutputWeightsInfo);
297  }
298 
299  if (descriptor.m_ProjectionEnabled)
300  {
301  aclProjectionWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.GetProjectionWeights());
302 
303  if (paramsInfo.m_ProjectionBias != nullptr)
304  {
305  aclProjectionBiasInfo = BuildArmComputeTensorInfo(paramsInfo.GetProjectionBias());
306  }
307 
308  aclParamsInfo.set_projection_params(
309  &aclProjectionWeightsInfo,
310  paramsInfo.m_ProjectionBias != nullptr ? &aclProjectionBiasInfo : nullptr);
311  }
312 
313  if (descriptor.m_LayerNormEnabled)
314  {
315  if (!descriptor.m_CifgEnabled)
316  {
317  aclInputLayerNormWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.GetInputLayerNormWeights());
318  }
319 
320  aclForgetLayerNormWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.GetForgetLayerNormWeights());
321  aclCellLayerNormWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.GetCellLayerNormWeights());
322  aclOutputLayerNormWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.GetOutputLayerNormWeights());
323 
324  aclParamsInfo.set_layer_normalization_params(
325  paramsInfo.m_InputLayerNormWeights != nullptr ? &aclInputLayerNormWeightsInfo : nullptr,
326  &aclForgetLayerNormWeightsInfo,
327  &aclCellLayerNormWeightsInfo,
328  &aclOutputLayerNormWeightsInfo);
329  }
330 
331  if (!descriptor.m_CifgEnabled)
332  {
333  aclInputToInputWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.GetInputToInputWeights());
334  aclRecurrentToInputWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.GetRecurrentToInputWeights());
335  aclInputGateBiasInfo = BuildArmComputeTensorInfo(paramsInfo.GetInputGateBias());
336 
337 
338  aclParamsInfo.set_cifg_params(
339  &aclInputToInputWeightsInfo,
340  &aclRecurrentToInputWeightsInfo,
341  paramsInfo.m_CellToInputWeights != nullptr ? &aclCellToInputWeightsInfo : nullptr,
342  &aclInputGateBiasInfo);
343  }
344 
345  aclParamsInfo.set_cell_clip_params(descriptor.m_CellClip);
346  aclParamsInfo.set_projection_clip_params(descriptor.m_ProjectionClip);
347  aclParamsInfo.set_hidden_state_params(descriptor.m_HiddenStateZeroPoint, descriptor.m_HiddenStateScale);
348  aclParamsInfo.set_matmul_scale_params(descriptor.m_InputIntermediateScale,
349  descriptor.m_ForgetIntermediateScale,
350  descriptor.m_CellIntermediateScale,
351  descriptor.m_OutputIntermediateScale);
352 
353  return arm_compute::CLQLSTMLayer::validate(&aclInputInfo,
354  &aclInputToForgetWeightsInfo,
355  &aclInputToCellWeightsInfo,
356  &aclInputToOutputWeightsInfo,
357  &aclRecurrentToForgetWeightsInfo,
358  &aclRecurrentToCellWeightsInfo,
359  &aclRecurrentToOutputWeightsInfo,
360  &aclForgetGateBiasInfo,
361  &aclCellBiasInfo,
362  &aclOutputGateBiasInfo,
363  &aclCellStateInInfo,
364  &aclOutputStateInInfo,
365  &aclCellStateOutInfo,
366  &aclOutputStateOutInfo,
367  &aclOutputInfo,
368  aclParamsInfo);
369 }
370 
371 void ClQLstmWorkload::FreeUnusedTensors()
372 {
373  FreeTensorIfUnused(m_InputToInputWeightsTensor);
374  FreeTensorIfUnused(m_InputToForgetWeightsTensor);
375  FreeTensorIfUnused(m_InputToCellWeightsTensor);
376  FreeTensorIfUnused(m_InputToOutputWeightsTensor);
377 
378  FreeTensorIfUnused(m_RecurrentToInputWeightsTensor);
379  FreeTensorIfUnused(m_RecurrentToForgetWeightsTensor);
380  FreeTensorIfUnused(m_RecurrentToCellWeightsTensor);
381  FreeTensorIfUnused(m_RecurrentToOutputWeightsTensor);
382 
383  FreeTensorIfUnused(m_CellToInputWeightsTensor);
384  FreeTensorIfUnused(m_CellToForgetWeightsTensor);
385  FreeTensorIfUnused(m_CellToOutputWeightsTensor);
386 
387  FreeTensorIfUnused(m_InputGateBiasTensor);
388  FreeTensorIfUnused(m_ForgetGateBiasTensor);
389  FreeTensorIfUnused(m_CellBiasTensor);
390  FreeTensorIfUnused(m_OutputGateBiasTensor);
391 
392  FreeTensorIfUnused(m_ProjectionWeightsTensor);
393  FreeTensorIfUnused(m_ProjectionBiasTensor);
394 
395  FreeTensorIfUnused(m_InputLayerNormWeightsTensor);
396  FreeTensorIfUnused(m_ForgetLayerNormWeightsTensor);
397  FreeTensorIfUnused(m_CellLayerNormWeightsTensor);
398  FreeTensorIfUnused(m_OutputLayerNormWeightsTensor);
399 }
400 
401 } //namespace armnn
const ConstCpuTensorHandle * m_CellToForgetWeights
const TensorInfo * m_InputLayerNormWeights
Definition: LstmParams.hpp:106
const TensorInfo & GetRecurrentToCellWeights() const
Definition: LstmParams.hpp:145
const ConstCpuTensorHandle * m_ProjectionWeights
const TensorInfo & GetCellBias() const
Definition: LstmParams.hpp:173
void InitializeArmComputeClTensorData(arm_compute::CLTensor &clTensor, const ConstCpuTensorHandle *handle)
const TensorInfo & GetRecurrentToInputWeights() const
Definition: LstmParams.hpp:137
const TensorInfo & GetCellLayerNormWeights() const
Definition: LstmParams.hpp:197
const TensorInfo & GetRecurrentToOutputWeights() const
Definition: LstmParams.hpp:149
bool m_PeepholeEnabled
Enable/disable peephole.
const ConstCpuTensorHandle * m_ProjectionBias
virtual void Execute() const override
float m_HiddenStateScale
Hidden State quantization scale.
const ConstCpuTensorHandle * m_ForgetLayerNormWeights
const QLstmQueueDescriptor m_Data
Definition: Workload.hpp:46
float m_OutputIntermediateScale
Output intermediate quantization scale.
const TensorInfo & GetCellToInputWeights() const
Definition: LstmParams.hpp:153
const ConstCpuTensorHandle * m_CellLayerNormWeights
const ConstCpuTensorHandle * m_RecurrentToCellWeights
const ConstCpuTensorHandle * m_RecurrentToInputWeights
const ConstCpuTensorHandle * m_OutputGateBias
const ConstCpuTensorHandle * m_CellBias
Copyright (c) 2020 ARM Limited.
const TensorInfo & GetCellToForgetWeights() const
Definition: LstmParams.hpp:157
const TensorInfo & GetForgetLayerNormWeights() const
Definition: LstmParams.hpp:193
const TensorInfo & GetCellToOutputWeights() const
Definition: LstmParams.hpp:161
const ConstCpuTensorHandle * m_CellToOutputWeights
const TensorInfo & GetInputToCellWeights() const
Definition: LstmParams.hpp:129
const ConstCpuTensorHandle * m_OutputLayerNormWeights
bool m_LayerNormEnabled
Enable/disable layer normalization.
const ConstCpuTensorHandle * m_InputToForgetWeights
const TensorInfo & GetInputToOutputWeights() const
Definition: LstmParams.hpp:133
float m_ProjectionClip
Clipping threshold value for the projection.
float m_InputIntermediateScale
Input intermediate quantization scale.
const TensorInfo * m_ProjectionBias
Definition: LstmParams.hpp:105
Status
enumeration
Definition: Types.hpp:26
A QLstmDescriptor for the QLstmLayer.
const TensorInfo * m_CellToInputWeights
Definition: LstmParams.hpp:97
const TensorInfo & GetRecurrentToForgetWeights() const
Definition: LstmParams.hpp:141
float m_ForgetIntermediateScale
Forget intermediate quantization scale.
const ConstCpuTensorHandle * m_CellToInputWeights
const TensorInfo & GetInputToInputWeights() const
Definition: LstmParams.hpp:121
const TensorInfo & GetOutputLayerNormWeights() const
Definition: LstmParams.hpp:201
float m_CellClip
Clipping threshold value for the cell state.
const ConstCpuTensorHandle * m_RecurrentToOutputWeights
const TensorInfo & GetForgetGateBias() const
Definition: LstmParams.hpp:169
std::vector< ITensorHandle * > m_Outputs
bool m_ProjectionEnabled
Enable/disable the projection layer.
const ConstCpuTensorHandle * m_InputGateBias
const TensorInfo & GetInputGateBias() const
Definition: LstmParams.hpp:165
const TensorInfo & GetProjectionWeights() const
Definition: LstmParams.hpp:181
const TensorInfo & GetInputToForgetWeights() const
Definition: LstmParams.hpp:125
Contains information about inputs and outputs to a layer.
arm_compute::Status ClQLstmWorkloadValidate(const TensorInfo &input, const TensorInfo &cellStateIn, const TensorInfo &outputStateIn, const TensorInfo &cellStateOut, const TensorInfo &outputStateOut, const TensorInfo &output, const QLstmDescriptor &descriptor, const LstmInputParamsInfo &paramsInfo)
const TensorInfo & GetInputLayerNormWeights() const
Definition: LstmParams.hpp:189
std::vector< ITensorHandle * > m_Inputs
const ConstCpuTensorHandle * m_InputLayerNormWeights
const ConstCpuTensorHandle * m_RecurrentToForgetWeights
ClQLstmWorkload(const QLstmQueueDescriptor &descriptor, const WorkloadInfo &info)
const TensorInfo & GetOutputGateBias() const
Definition: LstmParams.hpp:177
const ConstCpuTensorHandle * m_ForgetGateBias
const TensorInfo & GetProjectionBias() const
Definition: LstmParams.hpp:185
float m_CellIntermediateScale
Cell intermediate quantization scale.
const ConstCpuTensorHandle * m_InputToOutputWeights
bool m_CifgEnabled
Enable/disable CIFG (coupled input & forget gate).
const TensorInfo & GetTensorInfo() const
const ConstCpuTensorHandle * m_InputToInputWeights
int32_t m_HiddenStateZeroPoint
Hidden State zero point.
const ConstCpuTensorHandle * m_InputToCellWeights