67 const uint32_t numBatches = inputShape[0];
68 const uint32_t inputSize = inputShape[1];
69 const uint32_t outputSize = outputStateInShape[1];
70 const uint32_t numUnits = cellStateInShape[1];
79 std::unique_ptr<Decoder<float>> inputDecoder =
81 std::unique_ptr<Decoder<float>> outputStateInDecoder =
83 std::unique_ptr<Decoder<float>> cellStateInDecoder =
87 std::unique_ptr<Decoder<float>> outputStateOutDecoder =
89 std::unique_ptr<Decoder<float>> cellStateOutDecoder =
91 std::unique_ptr<Decoder<float>> outputDecoder =
95 std::unique_ptr<Encoder<float>> outputStateOutEncoder =
97 std::unique_ptr<Encoder<float>> cellStateOutEncoder =
99 std::unique_ptr<Encoder<float>> outputEncoder =
103 std::unique_ptr<Decoder<float>> inputToForgetWeightsDecoder = MakeDecoder<float>(
104 m_InputToForgetWeightsTensor->GetTensorInfo(), m_InputToForgetWeightsTensor->GetTensor<
void>());
105 std::unique_ptr<Decoder<float>> inputToCellWeightsDecoder = MakeDecoder<float>(
106 m_InputToCellWeightsTensor->GetTensorInfo(), m_InputToCellWeightsTensor->GetTensor<
void>());
107 std::unique_ptr<Decoder<float>> inputToOutputWeightsDecoder = MakeDecoder<float>(
108 m_InputToOutputWeightsTensor->GetTensorInfo(), m_InputToOutputWeightsTensor->GetTensor<
void>());
110 std::unique_ptr<Decoder<float>> recurrentToForgetWeightsDecoder = MakeDecoder<float>(
111 m_RecurrentToForgetWeightsTensor->GetTensorInfo(), m_RecurrentToForgetWeightsTensor->GetTensor<
void>());
112 std::unique_ptr<Decoder<float>> recurrentToCellWeightsDecoder = MakeDecoder<float>(
113 m_RecurrentToCellWeightsTensor->GetTensorInfo(), m_RecurrentToCellWeightsTensor->GetTensor<
void>());
114 std::unique_ptr<Decoder<float>> recurrentToOutputWeightsDecoder = MakeDecoder<float>(
115 m_RecurrentToOutputWeightsTensor->GetTensorInfo(), m_RecurrentToOutputWeightsTensor->GetTensor<
void>());
118 std::unique_ptr<Decoder<float>> inputToInputWeightsDecoder;
119 std::unique_ptr<Decoder<float>> recurrentToInputWeightsDecoder;
120 std::unique_ptr<Decoder<float>> inputGateBiasDecoder;
123 std::unique_ptr<Decoder<float>> cellToInputWeightsDecoder;
124 std::unique_ptr<Decoder<float>> cellToForgetWeightsDecoder;
125 std::unique_ptr<Decoder<float>> cellToOutputWeightsDecoder;
128 std::unique_ptr<Decoder<float>> projectionWeightsDecoder;
129 std::unique_ptr<Decoder<float>> projectionBiasDecoder;
132 std::unique_ptr<Decoder<float>> inputLayerNormWeightsDecoder;
133 std::unique_ptr<Decoder<float>> forgetLayerNormWeightsDecoder;
134 std::unique_ptr<Decoder<float>> cellLayerNormWeightsDecoder;
135 std::unique_ptr<Decoder<float>> outputLayerNormWeightsDecoder;
138 std::unique_ptr<Decoder<float>> forgetGateBiasDecoder;
139 std::unique_ptr<Decoder<float>> cellGateBiasDecoder;
140 std::unique_ptr<Decoder<float>> outputGateBiasDecoder;
143 const uint32_t stateTensorSize = numBatches * numUnits;
144 std::vector<int16_t> inputGateData(stateTensorSize);
145 std::vector<int16_t> cellGateData(stateTensorSize);
146 std::vector<int16_t> forgetGateData(stateTensorSize);
147 std::vector<int16_t> outputGateData(stateTensorSize);
148 std::vector<int32_t> hiddenStateData(stateTensorSize);
149 std::vector<int16_t> outputInt16Data(numBatches * outputSize);
165 outputInfo.GetQuantizationScale(),
166 outputInfo.GetQuantizationOffset());
169 std::unique_ptr<Decoder<float>> inputGateDecoder =
170 MakeDecoder<float>(inputGateInfo, inputGateData.data());
171 std::unique_ptr<Decoder<float>> cellGateDecoder =
172 MakeDecoder<float>(cellGateInfo, cellGateData.data());
173 std::unique_ptr<Decoder<float>> forgetGateDecoder =
174 MakeDecoder<float>(forgetGateInfo, forgetGateData.data());
175 std::unique_ptr<Decoder<float>> outputGateDecoder =
176 MakeDecoder<float>(outputGateInfo, outputGateData.data());
177 std::unique_ptr<Decoder<float>> hiddenStateDecoder =
178 MakeDecoder<float>(hiddenStateInfo, hiddenStateData.data());
180 std::unique_ptr<Encoder<float>> inputGateEncoder =
181 MakeEncoder<float>(inputGateInfo, inputGateData.data());
182 std::unique_ptr<Encoder<float>> cellGateEncoder =
183 MakeEncoder<float>(cellGateInfo, cellGateData.data());
184 std::unique_ptr<Encoder<float>> forgetGateEncoder =
185 MakeEncoder<float>(forgetGateInfo, forgetGateData.data());
186 std::unique_ptr<Encoder<float>> outputGateEncoder =
187 MakeEncoder<float>(outputGateInfo, outputGateData.data());
188 std::unique_ptr<Encoder<float>> hiddenStateEncoder =
189 MakeEncoder<float>(hiddenStateInfo, hiddenStateData.data());
192 std::unique_ptr<Decoder<float>> outputInt16Decoder =
193 MakeDecoder<float>(outputInt16Info, outputInt16Data.data());
194 std::unique_ptr<Encoder<float>> outputInt16Encoder =
195 MakeEncoder<float>(outputInt16Info, outputInt16Data.data());
200 inputToInputWeightsDecoder = MakeDecoder<float>(
201 m_InputToInputWeightsTensor->GetTensorInfo(), m_InputToInputWeightsTensor->GetTensor<
void>());
202 recurrentToInputWeightsDecoder = MakeDecoder<float>(
203 m_RecurrentToInputWeightsTensor->GetTensorInfo(), m_RecurrentToInputWeightsTensor->GetTensor<
void>());
210 cellToInputWeightsDecoder = MakeDecoder<float>(
211 m_CellToInputWeightsTensor->GetTensorInfo(), m_CellToInputWeightsTensor->GetTensor<
void>());
213 cellToForgetWeightsDecoder = MakeDecoder<float>(
214 m_CellToForgetWeightsTensor->GetTensorInfo(), m_CellToForgetWeightsTensor->GetTensor<
void>());
215 cellToOutputWeightsDecoder = MakeDecoder<float>(
216 m_CellToOutputWeightsTensor->GetTensorInfo(), m_CellToOutputWeightsTensor->GetTensor<
void>());
219 if (projectionEnabled)
221 projectionWeightsDecoder = MakeDecoder<float>(
222 m_ProjectionWeightsTensor->GetTensorInfo(), m_ProjectionWeightsTensor->GetTensor<
void>());
223 if (m_ProjectionBiasTensor)
225 projectionBiasDecoder = MakeDecoder<float>(
226 m_ProjectionBiasTensor->GetTensorInfo(), m_ProjectionBiasTensor->GetTensor<
void>());
230 if (layerNormEnabled)
234 inputLayerNormWeightsDecoder = MakeDecoder<float>(
235 m_InputLayerNormWeightsTensor->GetTensorInfo(), m_InputLayerNormWeightsTensor->GetTensor<
void>());
239 m_InputLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() / 1024, 0);
240 inputGateBiasDecoder = MakeDecoder<float>(
241 inputGateBiasTensorInfo, m_InputGateBiasTensor->GetTensor<
void>());
244 forgetLayerNormWeightsDecoder = MakeDecoder<float>(
245 m_ForgetLayerNormWeightsTensor->GetTensorInfo(), m_ForgetLayerNormWeightsTensor->GetTensor<
void>());
246 cellLayerNormWeightsDecoder = MakeDecoder<float>(
247 m_CellLayerNormWeightsTensor->GetTensorInfo(), m_CellLayerNormWeightsTensor->GetTensor<
void>());
248 outputLayerNormWeightsDecoder = MakeDecoder<float>(
249 m_OutputLayerNormWeightsTensor->GetTensorInfo(), m_OutputLayerNormWeightsTensor->GetTensor<
void>());
253 m_ForgetLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() / 1024, 0);
254 forgetGateBiasDecoder = MakeDecoder<float>(
255 forgetGateBiasTensorInfo, m_ForgetGateBiasTensor->GetTensor<
void>());
258 m_CellLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() / 1024, 0);
259 cellGateBiasDecoder = MakeDecoder<float>(
260 cellGateBiasTensorInfo, m_CellBiasTensor->GetTensor<
void>());
263 m_OutputLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() / 1024, 0);
264 outputGateBiasDecoder = MakeDecoder<float>(
265 outputGateBiasTensorInfo, m_OutputGateBiasTensor->GetTensor<
void>());
271 ZeroVector(*inputGateEncoder, stateTensorSize);
273 ZeroVector(*forgetGateEncoder, stateTensorSize);
274 ZeroVector(*cellGateEncoder, stateTensorSize);
275 ZeroVector(*outputGateEncoder, stateTensorSize);
276 ZeroVector(*hiddenStateEncoder, stateTensorSize);
282 numUnits, inputSize, *inputDecoder, numBatches, *inputGateEncoder);
286 numUnits, inputSize, *inputDecoder, numBatches, *forgetGateEncoder);
289 numUnits, inputSize, *inputDecoder, numBatches, *cellGateEncoder);
292 numUnits, inputSize, *inputDecoder, numBatches, *outputGateEncoder);
298 numUnits, outputSize, *outputStateInDecoder, numBatches, *inputGateEncoder);
302 numUnits, outputSize, *outputStateInDecoder, numBatches, *forgetGateEncoder);
305 numUnits, outputSize, *outputStateInDecoder, numBatches, *cellGateEncoder);
308 numUnits, outputSize, *outputStateInDecoder, numBatches, *outputGateEncoder);
316 numUnits, *cellStateInDecoder, numBatches, *inputGateEncoder);
319 if (layerNormEnabled)
321 inputGateInfo.SetQuantizationScale(inputInfo.GetQuantizationScale() *
322 m_InputLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() *
324 inputGateEncoder = MakeEncoder<float>(inputGateInfo, inputGateData.data());
327 *inputGateEncoder, numUnits, numBatches, m_LayerNormEpsilon);
329 inputGateDecoder = MakeDecoder<float>(inputGateInfo, inputGateData.data());
332 numUnits, *inputGateDecoder, numBatches, *inputGateEncoder);
334 inputGateInfo.SetQuantizationScale(1.f / 4096);
335 inputGateEncoder = MakeEncoder<float>(inputGateInfo, inputGateData.data());
338 numUnits, *inputGateDecoder, numBatches, *inputGateEncoder);
340 inputGateDecoder = MakeDecoder<float>(inputGateInfo, inputGateData.data());
343 inputGateInfo.SetQuantizationScale(cellStateOutInfo.GetQuantizationScale());
344 inputGateEncoder = MakeEncoder<float>(inputGateInfo, inputGateData.data());
347 Activation(*inputGateDecoder, *inputGateEncoder,
348 TensorInfo({numUnits, numBatches}, internalType),
351 inputGateDecoder = MakeDecoder<float>(inputGateInfo, inputGateData.data());
358 *cellStateInDecoder, numBatches, *forgetGateEncoder);
361 if (layerNormEnabled)
364 forgetGateInfo.SetQuantizationScale(inputInfo.GetQuantizationScale() *
365 m_ForgetLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() *
367 forgetGateEncoder = MakeEncoder<float>(forgetGateInfo, forgetGateData.data());
372 *forgetGateEncoder, numUnits, numBatches, m_LayerNormEpsilon);
375 forgetGateDecoder = MakeDecoder<float>(forgetGateInfo, forgetGateData.data());
378 numUnits, *forgetGateDecoder, numBatches, *forgetGateEncoder);
382 forgetGateInfo.SetQuantizationScale(1.f / 4096);
383 forgetGateEncoder = MakeEncoder<float>(forgetGateInfo, forgetGateData.data());
386 numUnits, *forgetGateDecoder, numBatches, *forgetGateEncoder);
389 forgetGateDecoder = MakeDecoder<float>(forgetGateInfo, forgetGateData.data());
392 forgetGateInfo.SetQuantizationScale(cellStateOutInfo.GetQuantizationScale());
393 forgetGateEncoder = MakeEncoder<float>(forgetGateInfo, forgetGateData.data());
396 Activation(*forgetGateDecoder, *forgetGateEncoder,
397 TensorInfo({numUnits, numBatches}, internalType),
400 forgetGateDecoder = MakeDecoder<float>(forgetGateInfo, forgetGateData.data());
403 if (layerNormEnabled)
405 cellGateInfo.SetQuantizationScale(inputInfo.GetQuantizationScale() *
406 m_CellLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() *
408 cellGateEncoder = MakeEncoder<float>(cellGateInfo, cellGateData.data());
412 cellGateDecoder = MakeDecoder<float>(cellGateInfo, cellGateData.data());
415 numUnits, *cellGateDecoder, numBatches, *cellGateEncoder);
417 cellGateInfo.SetQuantizationScale(1.f / 4096);
418 cellGateEncoder = MakeEncoder<float>(cellGateInfo, cellGateData.data());
421 numUnits, *cellGateDecoder, numBatches, *cellGateEncoder);
423 cellGateDecoder = MakeDecoder<float>(cellGateInfo, cellGateData.data());
426 cellGateInfo.SetQuantizationScale(cellStateOutInfo.GetQuantizationScale());
427 cellGateEncoder = MakeEncoder<float>(cellGateInfo, cellGateData.data());
430 Activation(*cellGateDecoder, *cellGateEncoder,
431 TensorInfo({numUnits, numBatches}, internalType),
434 cellGateDecoder = MakeDecoder<float>(cellGateInfo, cellGateData.data());
440 Sub1Vector(*forgetGateDecoder, stateTensorSize, *forgetGateEncoder);
442 *cellGateDecoder, *forgetGateDecoder, stateTensorSize, *cellStateOutEncoder);
447 *cellGateDecoder, *inputGateDecoder, stateTensorSize, *cellStateOutEncoder);
460 numUnits, *cellStateOutDecoder, numBatches, *outputGateEncoder);
463 if (layerNormEnabled)
465 outputGateInfo.SetQuantizationScale(inputInfo.GetQuantizationScale() *
466 m_OutputLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() *
468 outputGateEncoder = MakeEncoder<float>(outputGateInfo, outputGateData.data());
472 outputGateDecoder = MakeDecoder<float>(outputGateInfo, outputGateData.data());
475 numBatches, *outputGateEncoder);
477 outputGateInfo.SetQuantizationScale(1.f / 4096);
478 outputGateEncoder = MakeEncoder<float>(outputGateInfo, outputGateData.data());
480 VectorBatchVectorAdd(*outputGateBiasDecoder, numUnits, *outputGateDecoder, numBatches, *outputGateEncoder);
482 outputGateDecoder = MakeDecoder<float>(outputGateInfo, outputGateData.data());
485 outputGateInfo.SetQuantizationScale(cellStateOutInfo.GetQuantizationScale());
486 outputGateEncoder = MakeEncoder<float>(outputGateInfo, outputGateData.data());
489 Activation(*outputGateDecoder, *outputGateEncoder,
490 TensorInfo({numUnits, numBatches}, internalType),
493 outputGateDecoder = MakeDecoder<float>(outputGateInfo, outputGateData.data());
496 Activation(*cellStateOutDecoder, *cellGateEncoder,
497 TensorInfo({numUnits, numBatches}, internalType),
506 if (m_ProjectionBiasTensor)
512 numBatches, *outputInt16Encoder);
514 CopyVector(*outputInt16Decoder, numBatches * outputSize, *outputEncoder);
524 CopyVector(*hiddenStateDecoder, numBatches * outputSize, *outputEncoder);
528 CopyVector(*outputDecoder, numBatches * outputSize, *outputStateOutEncoder);
void MeanStddevNormalization(armnn::Decoder< float > &input_vector, armnn::Encoder< float > &output_vector, uint32_t v_size, uint32_t n_batch, float normalization_epsilon)
void VectorBatchVectorAdd(armnn::Decoder< float > &vector, uint32_t vSize, armnn::Decoder< float > &batchVector, uint32_t nBatch, armnn::Encoder< float > &outResult)
const TensorShape & GetShape() const
virtual void Execute() const override
void ClipVector(armnn::Decoder< float > &vector, uint32_t vSize, float absLimit, armnn::Encoder< float > &outResult)
bool m_PeepholeEnabled
Enable/disable peephole.
void Sub1Vector(armnn::Decoder< float > &vector, uint32_t vSize, armnn::Encoder< float > &result)
float m_HiddenStateScale
Hidden State quantization scale.
const QLstmQueueDescriptor m_Data
float m_OutputIntermediateScale
Output intermediate quantization scale.
void CopyVector(armnn::Decoder< float > &vector, uint32_t vSize, armnn::Encoder< float > &outResult)
void VectorBatchVectorCwiseProductAccumulate(armnn::Decoder< float > &vector, uint32_t vSize, armnn::Decoder< float > &batchVector, uint32_t nBatch, armnn::Encoder< float > &outResult)
void ZeroVector(armnn::Encoder< float > &vector, uint32_t vSize)
Copyright (c) 2020 ARM Limited.
void VectorVectorCwiseProduct(armnn::Decoder< float > &vector1, armnn::Decoder< float > &vector2, uint32_t vSize, armnn::Encoder< float > &outResult)
LayerDescriptor m_Parameters
void VectorBatchVectorCwiseProduct(armnn::Decoder< float > &vector, uint32_t vSize, armnn::Decoder< float > &batchVector, uint32_t nBatch, armnn::Encoder< float > &outResult)
void MatrixBatchVectorMultiplyAccumulate(armnn::Decoder< float > &matrix, uint32_t mRows, uint32_t mCols, armnn::Decoder< float > &vector, uint32_t nBatch, armnn::Encoder< float > &outResult)
bool m_LayerNormEnabled
Enable/disable layer normalization.
RefQLstmWorkload(const QLstmQueueDescriptor &descriptor, const WorkloadInfo &info)
float m_ProjectionClip
Clipping threshold value for the projection.
float m_InputIntermediateScale
Input intermediate quantization scale.
void VectorVectorCwiseProductAccumulate(armnn::Decoder< float > &vector1, armnn::Decoder< float > &vector2, uint32_t vSize, armnn::Encoder< float > &outResult)
void VectorBatchVectorAssign(armnn::Decoder< float > &vector, uint32_t vSize, uint32_t nBatch, armnn::Encoder< float > &outBatchVector)
float m_ForgetIntermediateScale
Forget intermediate quantization scale.
float m_CellClip
Clipping threshold value for the cell state.
std::vector< ITensorHandle * > m_Outputs
bool m_ProjectionEnabled
Enable/disable the projection layer.
Contains information about inputs and outputs to a layer.
std::vector< ITensorHandle * > m_Inputs
const TensorInfo & GetTensorInfo(const ITensorHandle *tensorHandle)
float32 helpers
float m_CellIntermediateScale
Cell intermediate quantization scale.
bool m_CifgEnabled
Enable/disable CIFG (coupled input & forget gate).
std::unique_ptr< armnn::ScopedCpuTensorHandle > AssignScopedCpuTensorHandle(const armnn::ConstCpuTensorHandle *ptr)
int32_t m_HiddenStateZeroPoint
Hidden State zero point.