58 std::vector<ITensorHandle*> outputs)
const 68 auto inputTensor =
reinterpret_cast<float*
>(inputs[0]->Map());
74 std::vector<float> inputValue(inputTensor, inputTensor + inputInfo.
GetNumElements());
82 unsigned int maxTime = inputShape[0];
83 unsigned int batchSize = inputShape[1];
84 unsigned int outputSize = outputShape[2];
85 unsigned int inputSize = inputShape[2];
90 std::vector<float> inputGateScratchBuffer;
91 std::vector<float> cellScratchBuffer(scratchInfo.
GetNumElements(), 0.);
92 std::vector<float> forgetGateScratchBuffer(scratchInfo.
GetNumElements(), 0.);
93 std::vector<float> outputGateScratchBuffer(scratchInfo.
GetNumElements(), 0.);
95 std::vector<float> outputStateOutBuffer(outputStateInfo.
GetNumElements(), 0.);
96 std::vector<float> cellStateOutBuffer(cellStateInfo.
GetNumElements(), 0.);
98 void* outputStateOutData = outputStateOutBuffer.data();
99 void* cellStateOutData = cellStateOutBuffer.data();
101 std::unique_ptr<Encoder<float>> inputGateScratch;
102 std::unique_ptr<Encoder<float>> cellScratch = MakeEncoder<float>(scratchInfo, cellScratchBuffer.data());
103 std::unique_ptr<Encoder<float>> forgetGateScratch = MakeEncoder<float>(scratchInfo, forgetGateScratchBuffer.data());
104 std::unique_ptr<Encoder<float>> outputGateScratch = MakeEncoder<float>(scratchInfo, outputGateScratchBuffer.data());
106 std::unique_ptr<Decoder<float>> inputGateScratchDecoder;
107 std::unique_ptr<Decoder<float>> cellScratchDecoder = MakeDecoder<float>(scratchInfo, cellScratchBuffer.data());
108 std::unique_ptr<Decoder<float>> forgetGateScratchDecoder = MakeDecoder<float>(scratchInfo,
109 forgetGateScratchBuffer.data());
110 std::unique_ptr<Decoder<float>> outputGateScratchDecoder = MakeDecoder<float>(scratchInfo,
111 outputGateScratchBuffer.data());
119 inputGateScratchBuffer.resize(scratchInfo.GetNumElements(), 0.);
120 inputGateScratch = MakeEncoder<float>(scratchInfo, inputGateScratchBuffer.data());
121 inputGateScratchDecoder = MakeDecoder<float>(scratchInfo, inputGateScratchBuffer.data());
124 std::unique_ptr<Encoder<float>> outputStateOut = MakeEncoder<float>(outputStateInfo, outputStateOutData);
125 std::unique_ptr<Encoder<float>> cellStateOut = MakeEncoder<float>(cellStateInfo, cellStateOutData);
126 std::unique_ptr<Decoder<float>> cellStateOutDecoder = MakeDecoder<float>(cellStateInfo, cellStateOutData);
130 lstmInputInfo.SetShape(batchInputShape);
133 lstmOutputInfo.
SetShape({batchSize, outputSize});
135 const TensorShape& inputToOutputWeightsShape = m_InputToOutputWeightsTensor->GetShape();
136 const TensorShape& recurrentToOutputWeightsShape = m_RecurrentToOutputWeightsTensor->GetShape();
137 unsigned int nOutput = recurrentToOutputWeightsShape[1];
138 auto outputStateInData = inputs[1]->Map();
139 std::unique_ptr<Decoder<float>> outputStateIn = MakeDecoder<float>(outputStateInfo, outputStateInData);
141 auto cellStateInData = inputs[2]->Map();
142 std::unique_ptr<Decoder<float>> cellStateIn = MakeDecoder<float>(cellStateInfo, cellStateInData);
144 auto currentInputData =
reinterpret_cast<float*
>(inputs[0]->Map());
145 std::unique_ptr<Decoder<float>> inputData = MakeDecoder<float>(lstmInputInfo, currentInputData);
146 auto currentOutputData =
reinterpret_cast<float*
>(outputs[2]->Map());
147 std::unique_ptr<Encoder<float>> output = MakeEncoder<float>(lstmOutputInfo, currentOutputData);
148 std::unique_ptr<Decoder<float>> outputDecoder = MakeDecoder<float>(lstmOutputInfo, currentOutputData);
150 std::unique_ptr<Decoder<float>> inputToInputWeightsTensor;
151 std::unique_ptr<Decoder<float>> inputToForgetWeightsTensor = MakeDecoder<float>(
152 m_InputToForgetWeightsTensor->GetTensorInfo(), m_InputToForgetWeightsTensor->GetConstTensor<
void>());
153 std::unique_ptr<Decoder<float>> inputToCellWeightsTensor = MakeDecoder<float>(
154 m_InputToCellWeightsTensor->GetTensorInfo(), m_InputToCellWeightsTensor->GetConstTensor<
void>());
155 std::unique_ptr<Decoder<float>> inputToOutputWeightsTensor = MakeDecoder<float>(
156 m_InputToOutputWeightsTensor->GetTensorInfo(), m_InputToOutputWeightsTensor->GetConstTensor<
void>());
158 std::unique_ptr<Decoder<float>> recurrentToInputWeightsTensor;
159 std::unique_ptr<Decoder<float>> recurrentToForgetWeightsTensor = MakeDecoder<float>(
160 m_RecurrentToForgetWeightsTensor->GetTensorInfo(), m_RecurrentToForgetWeightsTensor->GetConstTensor<
void>());
161 std::unique_ptr<Decoder<float>> recurrentToCellWeightsTensor = MakeDecoder<float>(
162 m_RecurrentToCellWeightsTensor->GetTensorInfo(), m_RecurrentToCellWeightsTensor->GetConstTensor<
void>());
163 std::unique_ptr<Decoder<float>> recurrentToOutputWeightsTensor = MakeDecoder<float>(
164 m_RecurrentToOutputWeightsTensor->GetTensorInfo(), m_RecurrentToOutputWeightsTensor->GetConstTensor<
void>());
166 std::unique_ptr<Decoder<float>> inputGateBiasTensor;
167 std::unique_ptr<Decoder<float>> forgetGateBiasTensor = MakeDecoder<float>(
168 m_ForgetGateBiasTensor->GetTensorInfo(), m_ForgetGateBiasTensor->GetConstTensor<
void>());
169 std::unique_ptr<Decoder<float>> cellBiasTensor = MakeDecoder<float>(
170 m_CellBiasTensor->GetTensorInfo(), m_CellBiasTensor->GetConstTensor<
void>());
171 std::unique_ptr<Decoder<float>> outputGateBiasTensor = MakeDecoder<float>(
172 m_OutputGateBiasTensor->GetTensorInfo(), m_OutputGateBiasTensor->GetConstTensor<
void>());
174 std::unique_ptr<Decoder<float>> cellToInputWeightsTensor;
175 std::unique_ptr<Decoder<float>> cellToForgetWeightsTensor;
176 std::unique_ptr<Decoder<float>> cellToOutputWeightsTensor;
178 std::unique_ptr<Decoder<float>> projectionWeightsTensor;
179 std::unique_ptr<Decoder<float>> projectionBiasTensor;
181 std::unique_ptr<Decoder<float>> inputLayerNormWeights;
182 std::unique_ptr<Decoder<float>> forgetLayerNormWeights;
183 std::unique_ptr<Decoder<float>> cellLayerNormWeights;
184 std::unique_ptr<Decoder<float>> outputLayerNormWeights;
190 inputLayerNormWeights = MakeDecoder<float>(
191 m_InputLayerNormWeights->GetTensorInfo(), m_InputLayerNormWeights->GetConstTensor<
void>());
193 forgetLayerNormWeights = MakeDecoder<float>(
194 m_ForgetLayerNormWeights->GetTensorInfo(), m_ForgetLayerNormWeights->GetConstTensor<
void>());
195 cellLayerNormWeights = MakeDecoder<float>(
196 m_CellLayerNormWeights->GetTensorInfo(), m_CellLayerNormWeights->GetConstTensor<
void>());
197 outputLayerNormWeights = MakeDecoder<float>(
198 m_OutputLayerNormWeights->GetTensorInfo(), m_OutputLayerNormWeights->GetConstTensor<
void>());
203 inputToInputWeightsTensor = MakeDecoder<float>(
204 m_InputToInputWeightsTensor->GetTensorInfo(), m_InputToInputWeightsTensor->GetConstTensor<
void>());
205 inputGateBiasTensor = MakeDecoder<float>(
206 m_InputGateBiasTensor->GetTensorInfo(), m_InputGateBiasTensor->GetConstTensor<
void>());
207 recurrentToInputWeightsTensor = MakeDecoder<float>(
208 m_RecurrentToInputWeightsTensor->GetTensorInfo(), m_RecurrentToInputWeightsTensor->GetConstTensor<
void>());
213 cellToForgetWeightsTensor = MakeDecoder<float>(
214 m_CellToForgetWeightsTensor->GetTensorInfo(), m_CellToForgetWeightsTensor->GetConstTensor<
void>());
215 cellToOutputWeightsTensor = MakeDecoder<float>(
216 m_CellToOutputWeightsTensor->GetTensorInfo(), m_CellToOutputWeightsTensor->GetConstTensor<
void>());
219 if (!useCifg && usePeephole)
221 cellToInputWeightsTensor = MakeDecoder<float>(
222 m_CellToInputWeightsTensor->GetTensorInfo(), m_CellToInputWeightsTensor->GetConstTensor<
void>());
227 projectionWeightsTensor = MakeDecoder<float>(
228 m_ProjectionWeightsTensor->GetTensorInfo(), m_ProjectionWeightsTensor->GetConstTensor<
void>());
229 if (m_ProjectionBiasTensor)
231 projectionBiasTensor = MakeDecoder<float>(
232 m_ProjectionBiasTensor->GetTensorInfo(), m_ProjectionBiasTensor->GetConstTensor<
void>());
236 unsigned int batchInputSize = batchSize * inputSize;
237 unsigned int batchOutputSize = batchSize * nOutput;
239 for (
unsigned int t = 0; t < maxTime; ++t)
244 inputToOutputWeightsShape,
245 recurrentToOutputWeightsShape,
254 inputToInputWeightsTensor,
255 inputToForgetWeightsTensor,
256 inputToCellWeightsTensor,
257 inputToOutputWeightsTensor,
258 recurrentToInputWeightsTensor,
259 recurrentToForgetWeightsTensor,
260 recurrentToCellWeightsTensor,
261 recurrentToOutputWeightsTensor,
262 cellToInputWeightsTensor,
263 cellToForgetWeightsTensor,
264 cellToOutputWeightsTensor,
266 forgetGateBiasTensor,
268 outputGateBiasTensor,
269 projectionWeightsTensor,
270 projectionBiasTensor,
271 inputLayerNormWeights,
272 forgetLayerNormWeights,
273 cellLayerNormWeights,
274 outputLayerNormWeights,
279 inputGateScratchDecoder,
281 forgetGateScratchDecoder,
282 outputGateScratchDecoder,
285 currentInputData += batchInputSize;
286 inputData = MakeDecoder<float>(lstmInputInfo, currentInputData);
287 currentOutputData += batchOutputSize;
288 output = MakeEncoder<float>(lstmOutputInfo, currentOutputData);
289 outputDecoder = MakeDecoder<float>(lstmOutputInfo, currentOutputData);
292 outputStateIn = MakeDecoder<float>(outputStateInfo, outputStateOutData);
295 cellStateIn = MakeDecoder<float>(cellStateInfo, cellStateOutData);
302 auto outputData =
reinterpret_cast<float*
>(outputs[2]->Map());
303 std::vector<float> outputValue(outputData, outputData + outputInfo.
GetNumElements());
bool m_ProjectionEnabled
Enable/disable the projection layer.
const TensorShape & GetShape() const
void LstmImpl(const LstmDescriptor &descriptor, const TensorInfo &inputInfo, const TensorInfo &outputInfo, const TensorShape &inputToOutputWeightsShape, const TensorShape &recurrentToOutputWeightsShape, std::unique_ptr< Decoder< float >> &inputData, std::unique_ptr< Decoder< float >> &outputStateIn, std::unique_ptr< Decoder< float >> &cellStateIn, std::unique_ptr< Encoder< float >> &outputStateOut, std::unique_ptr< Encoder< float >> &cellStateOut, std::unique_ptr< Encoder< float >> &output, std::unique_ptr< Decoder< float >> &cellStateOutDecoder, std::unique_ptr< Decoder< float >> &outputDecoder, std::unique_ptr< Decoder< float >> &inputToInputWeightsTensor, std::unique_ptr< Decoder< float >> &inputToForgetWeightsTensor, std::unique_ptr< Decoder< float >> &inputToCellWeightsTensor, std::unique_ptr< Decoder< float >> &inputToOutputWeightsTensor, std::unique_ptr< Decoder< float >> &recurrentToInputWeightsTensor, std::unique_ptr< Decoder< float >> &recurrentToForgetWeightsTensor, std::unique_ptr< Decoder< float >> &recurrentToCellWeightsTensor, std::unique_ptr< Decoder< float >> &recurrentToOutputWeightsTensor, std::unique_ptr< Decoder< float >> &cellToInputWeightsTensor, std::unique_ptr< Decoder< float >> &cellToForgetWeightsTensor, std::unique_ptr< Decoder< float >> &cellToOutputWeightsTensor, std::unique_ptr< Decoder< float >> &inputGateBiasTensor, std::unique_ptr< Decoder< float >> &forgetGateBiasTensor, std::unique_ptr< Decoder< float >> &cellBiasTensor, std::unique_ptr< Decoder< float >> &outputGateBiasTensor, std::unique_ptr< Decoder< float >> &projectionWeightsTensor, std::unique_ptr< Decoder< float >> &projectionBiasTensor, std::unique_ptr< Decoder< float >> &inputLayerNormWeights, std::unique_ptr< Decoder< float >> &forgetLayerNormWeights, std::unique_ptr< Decoder< float >> &cellLayerNormWeights, std::unique_ptr< Decoder< float >> &outputLayerNormWeights, std::unique_ptr< Encoder< float >> &inputGateScratch, std::unique_ptr< Encoder< float >> &cellScratch, std::unique_ptr< Encoder< float >> &forgetGateScratch, std::unique_ptr< Encoder< float >> &outputGateScratch, std::unique_ptr< Decoder< float >> &inputGateScratchDecoder, std::unique_ptr< Decoder< float >> &cellScratchDecoder, std::unique_ptr< Decoder< float >> &forgetGateScratchDecoder, std::unique_ptr< Decoder< float >> &outputGateScratchDecoder, float layerNormEpsilon)
std::unique_ptr< armnn::ScopedTensorHandle > AssignScopedTensorHandle(const armnn::ConstTensorHandle *ptr)
bool m_TimeMajor
Enable/disable time major.
Copyright (c) 2021 ARM Limited and Contributors.
LayerDescriptor m_Parameters
void SetShape(const TensorShape &newShape)
std::vector< ITensorHandle * > m_Inputs
UnidirectionalSequenceLstmQueueDescriptor m_Data
bool m_PeepholeEnabled
Enable/disable peephole.
bool m_CifgEnabled
Enable/disable cifg (coupled input & forget gate).
std::vector< ITensorHandle * > m_Outputs
std::vector< ITensorHandle * > m_Outputs
bool m_LayerNormEnabled
Enable/disable layer normalization.
Contains information about TensorInfos of a layer.
std::vector< ITensorHandle * > m_Inputs
void Execute() const override
armnn::TensorShape Permuted(const armnn::TensorShape &srcShape, const armnn::PermutationVector &mappings)
RefUnidirectionalSequenceLstmWorkload(const UnidirectionalSequenceLstmQueueDescriptor &descriptor, const WorkloadInfo &info)
unsigned int GetNumElements() const
const TensorInfo & GetTensorInfo(const ITensorHandle *tensorHandle)
float32 helpers
void ExecuteAsync(ExecutionData &executionData) override