49 const TensorShape& inputShape = inputInfo.GetShape();
50 const DataType& outputType = outputInfo.GetDataType();
52 std::unique_ptr<Encoder<float>> outputStateOut = MakeEncoder<float>(outputInfo,
m_Data.
m_Outputs[1]->Map());
53 std::unique_ptr<Encoder<float>> cellStateOut = MakeEncoder<float>(outputInfo,
m_Data.
m_Outputs[2]->Map());
54 std::unique_ptr<Encoder<float>> output = MakeEncoder<float>(outputInfo,
m_Data.
m_Outputs[3]->Map());
56 std::unique_ptr<Decoder<float>> cellStateOutDecoder = MakeDecoder<float>(outputInfo,
m_Data.
m_Outputs[2]->Map());
57 std::unique_ptr<Decoder<float>> outputDecoder = MakeDecoder<float>(outputInfo,
m_Data.
m_Outputs[3]->Map());
59 std::unique_ptr<Decoder<float>> inputData = MakeDecoder<float>(inputInfo,
m_Data.
m_Inputs[0]->Map());
60 std::unique_ptr<Decoder<float>> outputStateIn = MakeDecoder<float>(inputInfo,
m_Data.
m_Inputs[1]->Map());
61 std::unique_ptr<Decoder<float>> cellStateIn = MakeDecoder<float>(inputInfo,
m_Data.
m_Inputs[2]->Map());
63 const uint32_t nBatch = inputShape[0];
64 const uint32_t nInput = inputShape[1];
66 const uint32_t nCell = m_InputToOutputWeightsTensor->GetShape()[0];
67 const uint32_t nOutput = m_RecurrentToOutputWeightsTensor->GetShape()[1];
74 std::unique_ptr<Encoder<float>> inputGateScratch = MakeEncoder<float>(outputInfo,
m_Data.
m_Outputs[0]->Map());
75 std::unique_ptr<Encoder<float>> cellScratch = MakeEncoder<float>(outputInfo,
m_Data.
m_Outputs[0]->Map());
76 std::unique_ptr<Encoder<float>> forgetGateScratch = MakeEncoder<float>(outputInfo,
m_Data.
m_Outputs[0]->Map());
77 std::unique_ptr<Encoder<float>> outputGateScratch = MakeEncoder<float>(outputInfo,
m_Data.
m_Outputs[0]->Map());
79 std::unique_ptr<Decoder<float>> inputGateScratchDecoder =
81 std::unique_ptr<Decoder<float>> cellScratchDecoder =
83 std::unique_ptr<Decoder<float>> forgetGateScratchDecoder =
85 std::unique_ptr<Decoder<float>> outputGateScratchDecoder =
90 *cellScratch += (0 * nCell * nBatch);
91 *forgetGateScratch += (1 * nCell * nBatch);
92 *outputGateScratch += (2 * nCell * nBatch);
94 *cellScratchDecoder += (0 * nCell * nBatch);
95 *forgetGateScratchDecoder += (1 * nCell * nBatch);
96 *outputGateScratchDecoder += (2 * nCell * nBatch);
100 *inputGateScratch += (0 * nCell * nBatch);
101 *cellScratch += (1 * nCell * nBatch);
102 *forgetGateScratch += (2 * nCell * nBatch);
103 *outputGateScratch += (3 * nCell * nBatch);
105 *inputGateScratchDecoder += (0 * nCell * nBatch);
106 *cellScratchDecoder += (1 * nCell * nBatch);
107 *forgetGateScratchDecoder += (2 * nCell * nBatch);
108 *outputGateScratchDecoder += (3 * nCell * nBatch);
111 std::unique_ptr<Decoder<float>> inputToInputWeightsTensor;
112 std::unique_ptr<Decoder<float>> inputToForgetWeightsTensor = MakeDecoder<float>(
113 m_InputToForgetWeightsTensor->GetTensorInfo(), m_InputToForgetWeightsTensor->GetTensor<
void>());
114 std::unique_ptr<Decoder<float>> inputToCellWeightsTensor = MakeDecoder<float>(
115 m_InputToCellWeightsTensor->GetTensorInfo(), m_InputToCellWeightsTensor->GetTensor<
void>());
116 std::unique_ptr<Decoder<float>> inputToOutputWeightsTensor = MakeDecoder<float>(
117 m_InputToOutputWeightsTensor->GetTensorInfo(), m_InputToOutputWeightsTensor->GetTensor<
void>());
119 std::unique_ptr<Decoder<float>> recurrentToInputWeightsTensor;
120 std::unique_ptr<Decoder<float>> recurrentToForgetWeightsTensor = MakeDecoder<float>(
121 m_RecurrentToForgetWeightsTensor->GetTensorInfo(), m_RecurrentToForgetWeightsTensor->GetTensor<
void>());
122 std::unique_ptr<Decoder<float>> recurrentToCellWeightsTensor = MakeDecoder<float>(
123 m_RecurrentToCellWeightsTensor->GetTensorInfo(), m_RecurrentToCellWeightsTensor->GetTensor<
void>());
124 std::unique_ptr<Decoder<float>> recurrentToOutputWeightsTensor = MakeDecoder<float>(
125 m_RecurrentToOutputWeightsTensor->GetTensorInfo(), m_RecurrentToOutputWeightsTensor->GetTensor<
void>());
127 std::unique_ptr<Decoder<float>> inputGateBiasTensor;
128 std::unique_ptr<Decoder<float>> forgetGateBiasTensor = MakeDecoder<float>(
129 m_ForgetGateBiasTensor->GetTensorInfo(), m_ForgetGateBiasTensor->GetTensor<
void>());
130 std::unique_ptr<Decoder<float>> cellBiasTensor = MakeDecoder<float>(
131 m_CellBiasTensor->GetTensorInfo(), m_CellBiasTensor->GetTensor<
void>());
132 std::unique_ptr<Decoder<float>> outputGateBiasTensor = MakeDecoder<float>(
133 m_OutputGateBiasTensor->GetTensorInfo(), m_OutputGateBiasTensor->GetTensor<
void>());
135 std::unique_ptr<Decoder<float>> cellToInputWeightsTensor;
136 std::unique_ptr<Decoder<float>> cellToForgetWeightsTensor;
137 std::unique_ptr<Decoder<float>> cellToOutputWeightsTensor;
139 std::unique_ptr<Decoder<float>> projectionWeightsTensor;
140 std::unique_ptr<Decoder<float>> projectionBiasTensor;
142 std::unique_ptr<Decoder<float>> inputLayerNormWeights;
143 std::unique_ptr<Decoder<float>> forgetLayerNormWeights;
144 std::unique_ptr<Decoder<float>> cellLayerNormWeights;
145 std::unique_ptr<Decoder<float>> outputLayerNormWeights;
151 inputLayerNormWeights = MakeDecoder<float>(
152 m_InputLayerNormWeights->GetTensorInfo(), m_InputLayerNormWeights->GetTensor<
void>());
154 forgetLayerNormWeights = MakeDecoder<float>(
155 m_ForgetLayerNormWeights->GetTensorInfo(), m_ForgetLayerNormWeights->GetTensor<
void>());
156 cellLayerNormWeights = MakeDecoder<float>(
157 m_CellLayerNormWeights->GetTensorInfo(), m_CellLayerNormWeights->GetTensor<
void>());
158 outputLayerNormWeights = MakeDecoder<float>(
159 m_OutputLayerNormWeights->GetTensorInfo(), m_OutputLayerNormWeights->GetTensor<
void>());
164 inputToInputWeightsTensor = MakeDecoder<float>(
165 m_InputToInputWeightsTensor->GetTensorInfo(), m_InputToInputWeightsTensor->GetTensor<
void>());
166 inputGateBiasTensor = MakeDecoder<float>(
167 m_InputGateBiasTensor->GetTensorInfo(), m_InputGateBiasTensor->GetTensor<
void>());
168 recurrentToInputWeightsTensor = MakeDecoder<float>(
169 m_RecurrentToInputWeightsTensor->GetTensorInfo(), m_RecurrentToInputWeightsTensor->GetTensor<
void>());
174 cellToForgetWeightsTensor = MakeDecoder<float>(
175 m_CellToForgetWeightsTensor->GetTensorInfo(), m_CellToForgetWeightsTensor->GetTensor<
void>());
176 cellToOutputWeightsTensor = MakeDecoder<float>(
177 m_CellToOutputWeightsTensor->GetTensorInfo(), m_CellToOutputWeightsTensor->GetTensor<
void>());
180 if (!useCifg && usePeephole)
182 cellToInputWeightsTensor = MakeDecoder<float>(
183 m_CellToInputWeightsTensor->GetTensorInfo(), m_CellToInputWeightsTensor->GetTensor<
void>());
188 projectionWeightsTensor = MakeDecoder<float>(
189 m_ProjectionWeightsTensor->GetTensorInfo(), m_ProjectionWeightsTensor->GetTensor<
void>());
190 if (m_ProjectionBiasTensor)
192 projectionBiasTensor = MakeDecoder<float>(
193 m_ProjectionBiasTensor->GetTensorInfo(), m_ProjectionBiasTensor->GetTensor<
void>());
203 nCell, nBatch, *inputGateScratch);
206 nCell, nBatch, *forgetGateScratch);
208 nCell, nBatch, *cellScratch);
210 nCell, nBatch, *outputGateScratch);
217 ZeroVector(*inputGateScratch, nCell * nBatch);
219 ZeroVector(*forgetGateScratch, nCell * nBatch);
221 ZeroVector(*outputGateScratch, nCell * nBatch);
228 nCell, nInput, *inputData, nBatch, *inputGateScratch);
231 nCell, nInput, *inputData, nBatch, *forgetGateScratch);
233 nCell, nInput, *inputData, nBatch, *cellScratch);
235 nCell, nInput, *inputData, nBatch, *outputGateScratch);
241 nCell, nOutput, *outputStateIn, nBatch, *inputGateScratch);
244 nCell, nOutput, *outputStateIn, nBatch, *forgetGateScratch);
246 nCell, nOutput, *outputStateIn, nBatch, *cellScratch);
248 nCell, nOutput, *outputStateIn, nBatch, *outputGateScratch);
256 nCell, *cellStateIn, nBatch, *inputGateScratch);
261 *inputGateScratch, nCell, nBatch, m_LayerNormEpsilon);
263 nCell, *inputGateScratchDecoder, nBatch, *inputGateScratch);
265 nCell, *inputGateScratchDecoder, nBatch, *inputGateScratch);
267 Activation(*inputGateScratchDecoder, *inputGateScratch,
268 TensorInfo({nCell, nBatch}, outputType),
276 *cellStateIn, nBatch, *forgetGateScratch);
281 *forgetGateScratch, nCell, nBatch, m_LayerNormEpsilon);
283 nCell, *forgetGateScratchDecoder, nBatch, *forgetGateScratch);
285 nCell, *forgetGateScratchDecoder, nBatch, *forgetGateScratch);
287 Activation(*forgetGateScratchDecoder, *forgetGateScratch,
288 TensorInfo({nCell, nBatch}, outputType),
295 *cellScratch, nCell, nBatch, m_LayerNormEpsilon);
297 nCell, *cellScratchDecoder, nBatch, *cellScratch);
299 nCell, *cellScratchDecoder, nBatch, *cellScratch);
312 TensorInfo({nCell, nBatch}, outputType),
313 armnnActivationFunc, a, b);
317 Sub1Vector(*forgetGateScratchDecoder, nBatch * nCell, *forgetGateScratch);
319 *cellScratchDecoder, *forgetGateScratchDecoder, nBatch * nCell, *cellStateOut);
324 *cellScratchDecoder, *inputGateScratchDecoder, nBatch * nCell, *cellStateOut);
335 nCell, *cellStateOutDecoder, nBatch, *outputGateScratch);
340 *outputGateScratch, nCell, nBatch, m_LayerNormEpsilon);
342 nCell, *outputGateScratchDecoder, nBatch, *outputGateScratch);
344 nCell, *outputGateScratchDecoder, nBatch, *outputGateScratch);
346 Activation(*outputGateScratchDecoder, *outputGateScratch,
347 TensorInfo({nCell, nBatch}, outputType),
352 Activation(*cellStateOutDecoder, *cellScratch,
353 TensorInfo({nCell, nBatch}, outputType),
354 armnnActivationFunc, a, b);
362 if (m_ProjectionBiasTensor)
365 nOutput, nBatch, *output);
368 nOutput, nCell, *outputGateScratchDecoder, nBatch, *output);
377 CopyVector(*outputGateScratchDecoder, nBatch * nOutput, *output);
380 CopyVector(*outputDecoder, nBatch * nOutput, *outputStateOut);
void MeanStddevNormalization(armnn::Decoder< float > &input_vector, armnn::Encoder< float > &output_vector, uint32_t v_size, uint32_t n_batch, float normalization_epsilon)
void VectorBatchVectorAdd(armnn::Decoder< float > &vector, uint32_t vSize, armnn::Decoder< float > &batchVector, uint32_t nBatch, armnn::Encoder< float > &outResult)
bool m_ProjectionEnabled
Enable/disable the projection layer.
float m_ClippingThresProj
Clipping threshold value for the projection.
void ClipVector(armnn::Decoder< float > &vector, uint32_t vSize, float absLimit, armnn::Encoder< float > &outResult)
void Sub1Vector(armnn::Decoder< float > &vector, uint32_t vSize, armnn::Encoder< float > &result)
const LstmQueueDescriptor m_Data
void CopyVector(armnn::Decoder< float > &vector, uint32_t vSize, armnn::Encoder< float > &outResult)
const TensorInfo & GetTensorInfo(const ITensorHandle *tensorHandle)
float32 helpers
void VectorBatchVectorCwiseProductAccumulate(armnn::Decoder< float > &vector, uint32_t vSize, armnn::Decoder< float > &batchVector, uint32_t nBatch, armnn::Encoder< float > &outResult)
void ZeroVector(armnn::Encoder< float > &vector, uint32_t vSize)
void VectorVectorCwiseProduct(armnn::Decoder< float > &vector1, armnn::Decoder< float > &vector2, uint32_t vSize, armnn::Encoder< float > &outResult)
LayerDescriptor m_Parameters
void VectorBatchVectorCwiseProduct(armnn::Decoder< float > &vector, uint32_t vSize, armnn::Decoder< float > &batchVector, uint32_t nBatch, armnn::Encoder< float > &outResult)
void MatrixBatchVectorMultiplyAccumulate(armnn::Decoder< float > &matrix, uint32_t mRows, uint32_t mCols, armnn::Decoder< float > &vector, uint32_t nBatch, armnn::Encoder< float > &outResult)
bool m_PeepholeEnabled
Enable/disable peephole.
void VectorVectorCwiseProductAccumulate(armnn::Decoder< float > &vector1, armnn::Decoder< float > &vector2, uint32_t vSize, armnn::Encoder< float > &outResult)
uint32_t m_ActivationFunc
The activation function to use.
void VectorBatchVectorAssign(armnn::Decoder< float > &vector, uint32_t vSize, uint32_t nBatch, armnn::Encoder< float > &outBatchVector)
float m_ClippingThresCell
Clipping threshold value for the cell state.
bool m_CifgEnabled
Enable/disable cifg (coupled input & forget gate).
std::vector< ITensorHandle * > m_Outputs
bool m_LayerNormEnabled
Enable/disable layer normalization.
std::vector< ITensorHandle * > m_Inputs
void SetActivationParameters(uint32_t activation, armnn::ActivationFunction &outArmnnActivation, float &outA, float &outB)