57 std::vector<ITensorHandle*> outputs)
const 65 auto inputTensor =
reinterpret_cast<float*
>(inputs[0]->Map());
71 std::vector<float> inputValue(inputTensor, inputTensor + inputInfo.
GetNumElements());
79 unsigned int maxTime = inputShape[0];
80 unsigned int batchSize = inputShape[1];
81 unsigned int outputSize = outputShape[2];
82 unsigned int inputSize = inputShape[2];
87 std::vector<float> inputGateScratchBuffer;
88 std::vector<float> cellScratchBuffer(scratchInfo.
GetNumElements(), 0.);
89 std::vector<float> forgetGateScratchBuffer(scratchInfo.
GetNumElements(), 0.);
90 std::vector<float> outputGateScratchBuffer(scratchInfo.
GetNumElements(), 0.);
92 std::vector<float> outputStateOutBuffer(outputStateInfo.
GetNumElements(), 0.);
93 std::vector<float> cellStateOutBuffer(cellStateInfo.
GetNumElements(), 0.);
95 void* outputStateOutData = outputStateOutBuffer.data();
96 void* cellStateOutData = cellStateOutBuffer.data();
98 std::unique_ptr<Encoder<float>> inputGateScratch;
99 std::unique_ptr<Encoder<float>> cellScratch = MakeEncoder<float>(scratchInfo, cellScratchBuffer.data());
100 std::unique_ptr<Encoder<float>> forgetGateScratch = MakeEncoder<float>(scratchInfo, forgetGateScratchBuffer.data());
101 std::unique_ptr<Encoder<float>> outputGateScratch = MakeEncoder<float>(scratchInfo, outputGateScratchBuffer.data());
103 std::unique_ptr<Decoder<float>> inputGateScratchDecoder;
104 std::unique_ptr<Decoder<float>> cellScratchDecoder = MakeDecoder<float>(scratchInfo, cellScratchBuffer.data());
105 std::unique_ptr<Decoder<float>> forgetGateScratchDecoder = MakeDecoder<float>(scratchInfo,
106 forgetGateScratchBuffer.data());
107 std::unique_ptr<Decoder<float>> outputGateScratchDecoder = MakeDecoder<float>(scratchInfo,
108 outputGateScratchBuffer.data());
116 inputGateScratchBuffer.resize(scratchInfo.GetNumElements(), 0.);
117 inputGateScratch = MakeEncoder<float>(scratchInfo, inputGateScratchBuffer.data());
118 inputGateScratchDecoder = MakeDecoder<float>(scratchInfo, inputGateScratchBuffer.data());
121 std::unique_ptr<Encoder<float>> outputStateOut = MakeEncoder<float>(outputStateInfo, outputStateOutData);
122 std::unique_ptr<Encoder<float>> cellStateOut = MakeEncoder<float>(cellStateInfo, cellStateOutData);
123 std::unique_ptr<Decoder<float>> cellStateOutDecoder = MakeDecoder<float>(cellStateInfo, cellStateOutData);
127 lstmInputInfo.SetShape(batchInputShape);
130 lstmOutputInfo.
SetShape({batchSize, outputSize});
132 const TensorShape& inputToOutputWeightsShape = m_InputToOutputWeightsTensor->GetShape();
133 const TensorShape& recurrentToOutputWeightsShape = m_RecurrentToOutputWeightsTensor->GetShape();
134 unsigned int nOutput = recurrentToOutputWeightsShape[1];
135 auto outputStateInData = inputs[1]->Map();
136 std::unique_ptr<Decoder<float>> outputStateIn = MakeDecoder<float>(outputStateInfo, outputStateInData);
138 auto cellStateInData = inputs[2]->Map();
139 std::unique_ptr<Decoder<float>> cellStateIn = MakeDecoder<float>(cellStateInfo, cellStateInData);
141 auto currentInputData =
reinterpret_cast<float*
>(inputs[0]->Map());
142 std::unique_ptr<Decoder<float>> inputData = MakeDecoder<float>(lstmInputInfo, currentInputData);
143 auto currentOutputData =
reinterpret_cast<float*
>(outputs[0]->Map());
144 std::unique_ptr<Encoder<float>> output = MakeEncoder<float>(lstmOutputInfo, currentOutputData);
145 std::unique_ptr<Decoder<float>> outputDecoder = MakeDecoder<float>(lstmOutputInfo, currentOutputData);
147 std::unique_ptr<Decoder<float>> inputToInputWeightsTensor;
148 std::unique_ptr<Decoder<float>> inputToForgetWeightsTensor = MakeDecoder<float>(
149 m_InputToForgetWeightsTensor->GetTensorInfo(), m_InputToForgetWeightsTensor->GetConstTensor<
void>());
150 std::unique_ptr<Decoder<float>> inputToCellWeightsTensor = MakeDecoder<float>(
151 m_InputToCellWeightsTensor->GetTensorInfo(), m_InputToCellWeightsTensor->GetConstTensor<
void>());
152 std::unique_ptr<Decoder<float>> inputToOutputWeightsTensor = MakeDecoder<float>(
153 m_InputToOutputWeightsTensor->GetTensorInfo(), m_InputToOutputWeightsTensor->GetConstTensor<
void>());
155 std::unique_ptr<Decoder<float>> recurrentToInputWeightsTensor;
156 std::unique_ptr<Decoder<float>> recurrentToForgetWeightsTensor = MakeDecoder<float>(
157 m_RecurrentToForgetWeightsTensor->GetTensorInfo(), m_RecurrentToForgetWeightsTensor->GetConstTensor<
void>());
158 std::unique_ptr<Decoder<float>> recurrentToCellWeightsTensor = MakeDecoder<float>(
159 m_RecurrentToCellWeightsTensor->GetTensorInfo(), m_RecurrentToCellWeightsTensor->GetConstTensor<
void>());
160 std::unique_ptr<Decoder<float>> recurrentToOutputWeightsTensor = MakeDecoder<float>(
161 m_RecurrentToOutputWeightsTensor->GetTensorInfo(), m_RecurrentToOutputWeightsTensor->GetConstTensor<
void>());
163 std::unique_ptr<Decoder<float>> inputGateBiasTensor;
164 std::unique_ptr<Decoder<float>> forgetGateBiasTensor = MakeDecoder<float>(
165 m_ForgetGateBiasTensor->GetTensorInfo(), m_ForgetGateBiasTensor->GetConstTensor<
void>());
166 std::unique_ptr<Decoder<float>> cellBiasTensor = MakeDecoder<float>(
167 m_CellBiasTensor->GetTensorInfo(), m_CellBiasTensor->GetConstTensor<
void>());
168 std::unique_ptr<Decoder<float>> outputGateBiasTensor = MakeDecoder<float>(
169 m_OutputGateBiasTensor->GetTensorInfo(), m_OutputGateBiasTensor->GetConstTensor<
void>());
171 std::unique_ptr<Decoder<float>> cellToInputWeightsTensor;
172 std::unique_ptr<Decoder<float>> cellToForgetWeightsTensor;
173 std::unique_ptr<Decoder<float>> cellToOutputWeightsTensor;
175 std::unique_ptr<Decoder<float>> projectionWeightsTensor;
176 std::unique_ptr<Decoder<float>> projectionBiasTensor;
178 std::unique_ptr<Decoder<float>> inputLayerNormWeights;
179 std::unique_ptr<Decoder<float>> forgetLayerNormWeights;
180 std::unique_ptr<Decoder<float>> cellLayerNormWeights;
181 std::unique_ptr<Decoder<float>> outputLayerNormWeights;
187 inputLayerNormWeights = MakeDecoder<float>(
188 m_InputLayerNormWeights->GetTensorInfo(), m_InputLayerNormWeights->GetConstTensor<
void>());
190 forgetLayerNormWeights = MakeDecoder<float>(
191 m_ForgetLayerNormWeights->GetTensorInfo(), m_ForgetLayerNormWeights->GetConstTensor<
void>());
192 cellLayerNormWeights = MakeDecoder<float>(
193 m_CellLayerNormWeights->GetTensorInfo(), m_CellLayerNormWeights->GetConstTensor<
void>());
194 outputLayerNormWeights = MakeDecoder<float>(
195 m_OutputLayerNormWeights->GetTensorInfo(), m_OutputLayerNormWeights->GetConstTensor<
void>());
200 inputToInputWeightsTensor = MakeDecoder<float>(
201 m_InputToInputWeightsTensor->GetTensorInfo(), m_InputToInputWeightsTensor->GetConstTensor<
void>());
202 inputGateBiasTensor = MakeDecoder<float>(
203 m_InputGateBiasTensor->GetTensorInfo(), m_InputGateBiasTensor->GetConstTensor<
void>());
204 recurrentToInputWeightsTensor = MakeDecoder<float>(
205 m_RecurrentToInputWeightsTensor->GetTensorInfo(), m_RecurrentToInputWeightsTensor->GetConstTensor<
void>());
210 cellToForgetWeightsTensor = MakeDecoder<float>(
211 m_CellToForgetWeightsTensor->GetTensorInfo(), m_CellToForgetWeightsTensor->GetConstTensor<
void>());
212 cellToOutputWeightsTensor = MakeDecoder<float>(
213 m_CellToOutputWeightsTensor->GetTensorInfo(), m_CellToOutputWeightsTensor->GetConstTensor<
void>());
216 if (!useCifg && usePeephole)
218 cellToInputWeightsTensor = MakeDecoder<float>(
219 m_CellToInputWeightsTensor->GetTensorInfo(), m_CellToInputWeightsTensor->GetConstTensor<
void>());
224 projectionWeightsTensor = MakeDecoder<float>(
225 m_ProjectionWeightsTensor->GetTensorInfo(), m_ProjectionWeightsTensor->GetConstTensor<
void>());
226 if (m_ProjectionBiasTensor)
228 projectionBiasTensor = MakeDecoder<float>(
229 m_ProjectionBiasTensor->GetTensorInfo(), m_ProjectionBiasTensor->GetConstTensor<
void>());
233 unsigned int batchInputSize = batchSize * inputSize;
234 unsigned int batchOutputSize = batchSize * nOutput;
236 for (
unsigned int t = 0; t < maxTime; ++t)
241 inputToOutputWeightsShape,
242 recurrentToOutputWeightsShape,
251 inputToInputWeightsTensor,
252 inputToForgetWeightsTensor,
253 inputToCellWeightsTensor,
254 inputToOutputWeightsTensor,
255 recurrentToInputWeightsTensor,
256 recurrentToForgetWeightsTensor,
257 recurrentToCellWeightsTensor,
258 recurrentToOutputWeightsTensor,
259 cellToInputWeightsTensor,
260 cellToForgetWeightsTensor,
261 cellToOutputWeightsTensor,
263 forgetGateBiasTensor,
265 outputGateBiasTensor,
266 projectionWeightsTensor,
267 projectionBiasTensor,
268 inputLayerNormWeights,
269 forgetLayerNormWeights,
270 cellLayerNormWeights,
271 outputLayerNormWeights,
276 inputGateScratchDecoder,
278 forgetGateScratchDecoder,
279 outputGateScratchDecoder,
282 currentInputData += batchInputSize;
283 inputData = MakeDecoder<float>(lstmInputInfo, currentInputData);
284 currentOutputData += batchOutputSize;
285 output = MakeEncoder<float>(lstmOutputInfo, currentOutputData);
286 outputDecoder = MakeDecoder<float>(lstmOutputInfo, currentOutputData);
289 outputStateIn = MakeDecoder<float>(outputStateInfo, outputStateOutData);
292 cellStateIn = MakeDecoder<float>(cellStateInfo, cellStateOutData);
299 auto outputData =
reinterpret_cast<float*
>(outputs[0]->Map());
300 std::vector<float> outputValue(outputData, outputData + outputInfo.
GetNumElements());
bool m_ProjectionEnabled
Enable/disable the projection layer.
const TensorShape & GetShape() const
void LstmImpl(const LstmDescriptor &descriptor, const TensorInfo &inputInfo, const TensorInfo &outputInfo, const TensorShape &inputToOutputWeightsShape, const TensorShape &recurrentToOutputWeightsShape, std::unique_ptr< Decoder< float >> &inputData, std::unique_ptr< Decoder< float >> &outputStateIn, std::unique_ptr< Decoder< float >> &cellStateIn, std::unique_ptr< Encoder< float >> &outputStateOut, std::unique_ptr< Encoder< float >> &cellStateOut, std::unique_ptr< Encoder< float >> &output, std::unique_ptr< Decoder< float >> &cellStateOutDecoder, std::unique_ptr< Decoder< float >> &outputDecoder, std::unique_ptr< Decoder< float >> &inputToInputWeightsTensor, std::unique_ptr< Decoder< float >> &inputToForgetWeightsTensor, std::unique_ptr< Decoder< float >> &inputToCellWeightsTensor, std::unique_ptr< Decoder< float >> &inputToOutputWeightsTensor, std::unique_ptr< Decoder< float >> &recurrentToInputWeightsTensor, std::unique_ptr< Decoder< float >> &recurrentToForgetWeightsTensor, std::unique_ptr< Decoder< float >> &recurrentToCellWeightsTensor, std::unique_ptr< Decoder< float >> &recurrentToOutputWeightsTensor, std::unique_ptr< Decoder< float >> &cellToInputWeightsTensor, std::unique_ptr< Decoder< float >> &cellToForgetWeightsTensor, std::unique_ptr< Decoder< float >> &cellToOutputWeightsTensor, std::unique_ptr< Decoder< float >> &inputGateBiasTensor, std::unique_ptr< Decoder< float >> &forgetGateBiasTensor, std::unique_ptr< Decoder< float >> &cellBiasTensor, std::unique_ptr< Decoder< float >> &outputGateBiasTensor, std::unique_ptr< Decoder< float >> &projectionWeightsTensor, std::unique_ptr< Decoder< float >> &projectionBiasTensor, std::unique_ptr< Decoder< float >> &inputLayerNormWeights, std::unique_ptr< Decoder< float >> &forgetLayerNormWeights, std::unique_ptr< Decoder< float >> &cellLayerNormWeights, std::unique_ptr< Decoder< float >> &outputLayerNormWeights, std::unique_ptr< Encoder< float >> &inputGateScratch, std::unique_ptr< Encoder< float >> &cellScratch, std::unique_ptr< Encoder< float >> &forgetGateScratch, std::unique_ptr< Encoder< float >> &outputGateScratch, std::unique_ptr< Decoder< float >> &inputGateScratchDecoder, std::unique_ptr< Decoder< float >> &cellScratchDecoder, std::unique_ptr< Decoder< float >> &forgetGateScratchDecoder, std::unique_ptr< Decoder< float >> &outputGateScratchDecoder, float layerNormEpsilon)
std::unique_ptr< armnn::ScopedTensorHandle > AssignScopedTensorHandle(const armnn::ConstTensorHandle *ptr)
bool m_TimeMajor
Enable/disable time major.
Copyright (c) 2021 ARM Limited and Contributors.
LayerDescriptor m_Parameters
void SetShape(const TensorShape &newShape)
std::vector< ITensorHandle * > m_Inputs
UnidirectionalSequenceLstmQueueDescriptor m_Data
bool m_PeepholeEnabled
Enable/disable peephole.
bool m_CifgEnabled
Enable/disable cifg (coupled input & forget gate).
std::vector< ITensorHandle * > m_Outputs
void ExecuteAsync(WorkingMemDescriptor &workingMemDescriptor) override
std::vector< ITensorHandle * > m_Outputs
bool m_LayerNormEnabled
Enable/disable layer normalization.
Contains information about TensorInfos of a layer.
std::vector< ITensorHandle * > m_Inputs
const TensorInfo & GetTensorInfo(const ITensorHandle *tensorHandle)
float32 helpers
void Execute() const override
armnn::TensorShape Permuted(const armnn::TensorShape &srcShape, const armnn::PermutationVector &mappings)
RefUnidirectionalSequenceLstmWorkload(const UnidirectionalSequenceLstmQueueDescriptor &descriptor, const WorkloadInfo &info)
unsigned int GetNumElements() const