57 std::vector<ITensorHandle*> outputs)
const 67 auto inputTensor =
reinterpret_cast<float*
>(inputs[0]->Map());
73 std::vector<float> inputValue(inputTensor, inputTensor + inputInfo.
GetNumElements());
81 unsigned int maxTime = inputShape[0];
82 unsigned int batchSize = inputShape[1];
83 unsigned int outputSize = outputShape[2];
84 unsigned int inputSize = inputShape[2];
89 std::vector<float> inputGateScratchBuffer;
90 std::vector<float> cellScratchBuffer(scratchInfo.
GetNumElements(), 0.);
91 std::vector<float> forgetGateScratchBuffer(scratchInfo.
GetNumElements(), 0.);
92 std::vector<float> outputGateScratchBuffer(scratchInfo.
GetNumElements(), 0.);
94 std::vector<float> outputStateOutBuffer(outputStateInfo.
GetNumElements(), 0.);
95 std::vector<float> cellStateOutBuffer(cellStateInfo.
GetNumElements(), 0.);
97 void* outputStateOutData = outputStateOutBuffer.data();
98 void* cellStateOutData = cellStateOutBuffer.data();
100 std::unique_ptr<Encoder<float>> inputGateScratch;
101 std::unique_ptr<Encoder<float>> cellScratch = MakeEncoder<float>(scratchInfo, cellScratchBuffer.data());
102 std::unique_ptr<Encoder<float>> forgetGateScratch = MakeEncoder<float>(scratchInfo, forgetGateScratchBuffer.data());
103 std::unique_ptr<Encoder<float>> outputGateScratch = MakeEncoder<float>(scratchInfo, outputGateScratchBuffer.data());
105 std::unique_ptr<Decoder<float>> inputGateScratchDecoder;
106 std::unique_ptr<Decoder<float>> cellScratchDecoder = MakeDecoder<float>(scratchInfo, cellScratchBuffer.data());
107 std::unique_ptr<Decoder<float>> forgetGateScratchDecoder = MakeDecoder<float>(scratchInfo,
108 forgetGateScratchBuffer.data());
109 std::unique_ptr<Decoder<float>> outputGateScratchDecoder = MakeDecoder<float>(scratchInfo,
110 outputGateScratchBuffer.data());
118 inputGateScratchBuffer.resize(scratchInfo.GetNumElements(), 0.);
119 inputGateScratch = MakeEncoder<float>(scratchInfo, inputGateScratchBuffer.data());
120 inputGateScratchDecoder = MakeDecoder<float>(scratchInfo, inputGateScratchBuffer.data());
123 std::unique_ptr<Encoder<float>> outputStateOut = MakeEncoder<float>(outputStateInfo, outputStateOutData);
124 std::unique_ptr<Encoder<float>> cellStateOut = MakeEncoder<float>(cellStateInfo, cellStateOutData);
125 std::unique_ptr<Decoder<float>> cellStateOutDecoder = MakeDecoder<float>(cellStateInfo, cellStateOutData);
129 lstmInputInfo.SetShape(batchInputShape);
132 lstmOutputInfo.
SetShape({batchSize, outputSize});
134 const TensorShape& inputToOutputWeightsShape = m_InputToOutputWeightsTensor->GetShape();
135 const TensorShape& recurrentToOutputWeightsShape = m_RecurrentToOutputWeightsTensor->GetShape();
136 unsigned int nOutput = recurrentToOutputWeightsShape[1];
137 auto outputStateInData = inputs[1]->Map();
138 std::unique_ptr<Decoder<float>> outputStateIn = MakeDecoder<float>(outputStateInfo, outputStateInData);
140 auto cellStateInData = inputs[2]->Map();
141 std::unique_ptr<Decoder<float>> cellStateIn = MakeDecoder<float>(cellStateInfo, cellStateInData);
143 auto currentInputData =
reinterpret_cast<float*
>(inputs[0]->Map());
144 std::unique_ptr<Decoder<float>> inputData = MakeDecoder<float>(lstmInputInfo, currentInputData);
145 auto currentOutputData =
reinterpret_cast<float*
>(outputs[2]->Map());
146 std::unique_ptr<Encoder<float>> output = MakeEncoder<float>(lstmOutputInfo, currentOutputData);
147 std::unique_ptr<Decoder<float>> outputDecoder = MakeDecoder<float>(lstmOutputInfo, currentOutputData);
149 std::unique_ptr<Decoder<float>> inputToInputWeightsTensor;
150 std::unique_ptr<Decoder<float>> inputToForgetWeightsTensor = MakeDecoder<float>(
151 m_InputToForgetWeightsTensor->GetTensorInfo(), m_InputToForgetWeightsTensor->GetConstTensor<
void>());
152 std::unique_ptr<Decoder<float>> inputToCellWeightsTensor = MakeDecoder<float>(
153 m_InputToCellWeightsTensor->GetTensorInfo(), m_InputToCellWeightsTensor->GetConstTensor<
void>());
154 std::unique_ptr<Decoder<float>> inputToOutputWeightsTensor = MakeDecoder<float>(
155 m_InputToOutputWeightsTensor->GetTensorInfo(), m_InputToOutputWeightsTensor->GetConstTensor<
void>());
157 std::unique_ptr<Decoder<float>> recurrentToInputWeightsTensor;
158 std::unique_ptr<Decoder<float>> recurrentToForgetWeightsTensor = MakeDecoder<float>(
159 m_RecurrentToForgetWeightsTensor->GetTensorInfo(), m_RecurrentToForgetWeightsTensor->GetConstTensor<
void>());
160 std::unique_ptr<Decoder<float>> recurrentToCellWeightsTensor = MakeDecoder<float>(
161 m_RecurrentToCellWeightsTensor->GetTensorInfo(), m_RecurrentToCellWeightsTensor->GetConstTensor<
void>());
162 std::unique_ptr<Decoder<float>> recurrentToOutputWeightsTensor = MakeDecoder<float>(
163 m_RecurrentToOutputWeightsTensor->GetTensorInfo(), m_RecurrentToOutputWeightsTensor->GetConstTensor<
void>());
165 std::unique_ptr<Decoder<float>> inputGateBiasTensor;
166 std::unique_ptr<Decoder<float>> forgetGateBiasTensor = MakeDecoder<float>(
167 m_ForgetGateBiasTensor->GetTensorInfo(), m_ForgetGateBiasTensor->GetConstTensor<
void>());
168 std::unique_ptr<Decoder<float>> cellBiasTensor = MakeDecoder<float>(
169 m_CellBiasTensor->GetTensorInfo(), m_CellBiasTensor->GetConstTensor<
void>());
170 std::unique_ptr<Decoder<float>> outputGateBiasTensor = MakeDecoder<float>(
171 m_OutputGateBiasTensor->GetTensorInfo(), m_OutputGateBiasTensor->GetConstTensor<
void>());
173 std::unique_ptr<Decoder<float>> cellToInputWeightsTensor;
174 std::unique_ptr<Decoder<float>> cellToForgetWeightsTensor;
175 std::unique_ptr<Decoder<float>> cellToOutputWeightsTensor;
177 std::unique_ptr<Decoder<float>> projectionWeightsTensor;
178 std::unique_ptr<Decoder<float>> projectionBiasTensor;
180 std::unique_ptr<Decoder<float>> inputLayerNormWeights;
181 std::unique_ptr<Decoder<float>> forgetLayerNormWeights;
182 std::unique_ptr<Decoder<float>> cellLayerNormWeights;
183 std::unique_ptr<Decoder<float>> outputLayerNormWeights;
189 inputLayerNormWeights = MakeDecoder<float>(
190 m_InputLayerNormWeights->GetTensorInfo(), m_InputLayerNormWeights->GetConstTensor<
void>());
192 forgetLayerNormWeights = MakeDecoder<float>(
193 m_ForgetLayerNormWeights->GetTensorInfo(), m_ForgetLayerNormWeights->GetConstTensor<
void>());
194 cellLayerNormWeights = MakeDecoder<float>(
195 m_CellLayerNormWeights->GetTensorInfo(), m_CellLayerNormWeights->GetConstTensor<
void>());
196 outputLayerNormWeights = MakeDecoder<float>(
197 m_OutputLayerNormWeights->GetTensorInfo(), m_OutputLayerNormWeights->GetConstTensor<
void>());
202 inputToInputWeightsTensor = MakeDecoder<float>(
203 m_InputToInputWeightsTensor->GetTensorInfo(), m_InputToInputWeightsTensor->GetConstTensor<
void>());
204 inputGateBiasTensor = MakeDecoder<float>(
205 m_InputGateBiasTensor->GetTensorInfo(), m_InputGateBiasTensor->GetConstTensor<
void>());
206 recurrentToInputWeightsTensor = MakeDecoder<float>(
207 m_RecurrentToInputWeightsTensor->GetTensorInfo(), m_RecurrentToInputWeightsTensor->GetConstTensor<
void>());
212 cellToForgetWeightsTensor = MakeDecoder<float>(
213 m_CellToForgetWeightsTensor->GetTensorInfo(), m_CellToForgetWeightsTensor->GetConstTensor<
void>());
214 cellToOutputWeightsTensor = MakeDecoder<float>(
215 m_CellToOutputWeightsTensor->GetTensorInfo(), m_CellToOutputWeightsTensor->GetConstTensor<
void>());
218 if (!useCifg && usePeephole)
220 cellToInputWeightsTensor = MakeDecoder<float>(
221 m_CellToInputWeightsTensor->GetTensorInfo(), m_CellToInputWeightsTensor->GetConstTensor<
void>());
226 projectionWeightsTensor = MakeDecoder<float>(
227 m_ProjectionWeightsTensor->GetTensorInfo(), m_ProjectionWeightsTensor->GetConstTensor<
void>());
228 if (m_ProjectionBiasTensor)
230 projectionBiasTensor = MakeDecoder<float>(
231 m_ProjectionBiasTensor->GetTensorInfo(), m_ProjectionBiasTensor->GetConstTensor<
void>());
235 unsigned int batchInputSize = batchSize * inputSize;
236 unsigned int batchOutputSize = batchSize * nOutput;
238 for (
unsigned int t = 0; t < maxTime; ++t)
243 inputToOutputWeightsShape,
244 recurrentToOutputWeightsShape,
253 inputToInputWeightsTensor,
254 inputToForgetWeightsTensor,
255 inputToCellWeightsTensor,
256 inputToOutputWeightsTensor,
257 recurrentToInputWeightsTensor,
258 recurrentToForgetWeightsTensor,
259 recurrentToCellWeightsTensor,
260 recurrentToOutputWeightsTensor,
261 cellToInputWeightsTensor,
262 cellToForgetWeightsTensor,
263 cellToOutputWeightsTensor,
265 forgetGateBiasTensor,
267 outputGateBiasTensor,
268 projectionWeightsTensor,
269 projectionBiasTensor,
270 inputLayerNormWeights,
271 forgetLayerNormWeights,
272 cellLayerNormWeights,
273 outputLayerNormWeights,
278 inputGateScratchDecoder,
280 forgetGateScratchDecoder,
281 outputGateScratchDecoder,
284 currentInputData += batchInputSize;
285 inputData = MakeDecoder<float>(lstmInputInfo, currentInputData);
286 currentOutputData += batchOutputSize;
287 output = MakeEncoder<float>(lstmOutputInfo, currentOutputData);
288 outputDecoder = MakeDecoder<float>(lstmOutputInfo, currentOutputData);
291 outputStateIn = MakeDecoder<float>(outputStateInfo, outputStateOutData);
294 cellStateIn = MakeDecoder<float>(cellStateInfo, cellStateOutData);
301 auto outputData =
reinterpret_cast<float*
>(outputs[2]->Map());
302 std::vector<float> outputValue(outputData, outputData + outputInfo.
GetNumElements());
bool m_ProjectionEnabled
Enable/disable the projection layer.
const TensorShape & GetShape() const
void LstmImpl(const LstmDescriptor &descriptor, const TensorInfo &inputInfo, const TensorInfo &outputInfo, const TensorShape &inputToOutputWeightsShape, const TensorShape &recurrentToOutputWeightsShape, std::unique_ptr< Decoder< float >> &inputData, std::unique_ptr< Decoder< float >> &outputStateIn, std::unique_ptr< Decoder< float >> &cellStateIn, std::unique_ptr< Encoder< float >> &outputStateOut, std::unique_ptr< Encoder< float >> &cellStateOut, std::unique_ptr< Encoder< float >> &output, std::unique_ptr< Decoder< float >> &cellStateOutDecoder, std::unique_ptr< Decoder< float >> &outputDecoder, std::unique_ptr< Decoder< float >> &inputToInputWeightsTensor, std::unique_ptr< Decoder< float >> &inputToForgetWeightsTensor, std::unique_ptr< Decoder< float >> &inputToCellWeightsTensor, std::unique_ptr< Decoder< float >> &inputToOutputWeightsTensor, std::unique_ptr< Decoder< float >> &recurrentToInputWeightsTensor, std::unique_ptr< Decoder< float >> &recurrentToForgetWeightsTensor, std::unique_ptr< Decoder< float >> &recurrentToCellWeightsTensor, std::unique_ptr< Decoder< float >> &recurrentToOutputWeightsTensor, std::unique_ptr< Decoder< float >> &cellToInputWeightsTensor, std::unique_ptr< Decoder< float >> &cellToForgetWeightsTensor, std::unique_ptr< Decoder< float >> &cellToOutputWeightsTensor, std::unique_ptr< Decoder< float >> &inputGateBiasTensor, std::unique_ptr< Decoder< float >> &forgetGateBiasTensor, std::unique_ptr< Decoder< float >> &cellBiasTensor, std::unique_ptr< Decoder< float >> &outputGateBiasTensor, std::unique_ptr< Decoder< float >> &projectionWeightsTensor, std::unique_ptr< Decoder< float >> &projectionBiasTensor, std::unique_ptr< Decoder< float >> &inputLayerNormWeights, std::unique_ptr< Decoder< float >> &forgetLayerNormWeights, std::unique_ptr< Decoder< float >> &cellLayerNormWeights, std::unique_ptr< Decoder< float >> &outputLayerNormWeights, std::unique_ptr< Encoder< float >> &inputGateScratch, std::unique_ptr< Encoder< float >> &cellScratch, std::unique_ptr< Encoder< float >> &forgetGateScratch, std::unique_ptr< Encoder< float >> &outputGateScratch, std::unique_ptr< Decoder< float >> &inputGateScratchDecoder, std::unique_ptr< Decoder< float >> &cellScratchDecoder, std::unique_ptr< Decoder< float >> &forgetGateScratchDecoder, std::unique_ptr< Decoder< float >> &outputGateScratchDecoder, float layerNormEpsilon)
std::unique_ptr< armnn::ScopedTensorHandle > AssignScopedTensorHandle(const armnn::ConstTensorHandle *ptr)
bool m_TimeMajor
Enable/disable time major.
Copyright (c) 2021 ARM Limited and Contributors.
LayerDescriptor m_Parameters
void SetShape(const TensorShape &newShape)
std::vector< ITensorHandle * > m_Inputs
UnidirectionalSequenceLstmQueueDescriptor m_Data
bool m_PeepholeEnabled
Enable/disable peephole.
bool m_CifgEnabled
Enable/disable cifg (coupled input & forget gate).
std::vector< ITensorHandle * > m_Outputs
void ExecuteAsync(WorkingMemDescriptor &workingMemDescriptor) override
std::vector< ITensorHandle * > m_Outputs
bool m_LayerNormEnabled
Enable/disable layer normalization.
Contains information about TensorInfos of a layer.
std::vector< ITensorHandle * > m_Inputs
const TensorInfo & GetTensorInfo(const ITensorHandle *tensorHandle)
float32 helpers
void Execute() const override
armnn::TensorShape Permuted(const armnn::TensorShape &srcShape, const armnn::PermutationVector &mappings)
RefUnidirectionalSequenceLstmWorkload(const UnidirectionalSequenceLstmQueueDescriptor &descriptor, const WorkloadInfo &info)
unsigned int GetNumElements() const