62 std::unique_ptr<Encoder<float>> outputStateOut = MakeEncoder<float>(outputInfo, outputs[1]->Map());
63 std::unique_ptr<Encoder<float>> cellStateOut = MakeEncoder<float>(outputInfo, outputs[2]->Map());
64 std::unique_ptr<Encoder<float>> output = MakeEncoder<float>(outputInfo, outputs[3]->Map());
66 std::unique_ptr<Decoder<float>> cellStateOutDecoder = MakeDecoder<float>(outputInfo, outputs[2]->Map());
67 std::unique_ptr<Decoder<float>> outputDecoder = MakeDecoder<float>(outputInfo, outputs[3]->Map());
69 std::unique_ptr<Decoder<float>> inputData = MakeDecoder<float>(inputInfo, inputs[0]->Map());
70 std::unique_ptr<Decoder<float>> outputStateIn = MakeDecoder<float>(inputInfo, inputs[1]->Map());
71 std::unique_ptr<Decoder<float>> cellStateIn = MakeDecoder<float>(inputInfo, inputs[2]->Map());
73 const uint32_t nBatch = inputShape[0];
74 const uint32_t nCell = m_InputToOutputWeightsTensor->GetShape()[0];
81 std::unique_ptr<Encoder<float>> inputGateScratch = MakeEncoder<float>(outputInfo, outputs[0]->Map());
82 std::unique_ptr<Encoder<float>> cellScratch = MakeEncoder<float>(outputInfo, outputs[0]->Map());
83 std::unique_ptr<Encoder<float>> forgetGateScratch = MakeEncoder<float>(outputInfo, outputs[0]->Map());
84 std::unique_ptr<Encoder<float>> outputGateScratch = MakeEncoder<float>(outputInfo, outputs[0]->Map());
86 std::unique_ptr<Decoder<float>> inputGateScratchDecoder =
87 MakeDecoder<float>(outputInfo, outputs[0]->Map());
88 std::unique_ptr<Decoder<float>> cellScratchDecoder =
89 MakeDecoder<float>(outputInfo, outputs[0]->Map());
90 std::unique_ptr<Decoder<float>> forgetGateScratchDecoder =
91 MakeDecoder<float>(outputInfo, outputs[0]->Map());
92 std::unique_ptr<Decoder<float>> outputGateScratchDecoder =
93 MakeDecoder<float>(outputInfo, outputs[0]->Map());
97 *cellScratch += (0 * nCell * nBatch);
98 *forgetGateScratch += (1 * nCell * nBatch);
99 *outputGateScratch += (2 * nCell * nBatch);
101 *cellScratchDecoder += (0 * nCell * nBatch);
102 *forgetGateScratchDecoder += (1 * nCell * nBatch);
103 *outputGateScratchDecoder += (2 * nCell * nBatch);
107 *inputGateScratch += (0 * nCell * nBatch);
108 *cellScratch += (1 * nCell * nBatch);
109 *forgetGateScratch += (2 * nCell * nBatch);
110 *outputGateScratch += (3 * nCell * nBatch);
112 *inputGateScratchDecoder += (0 * nCell * nBatch);
113 *cellScratchDecoder += (1 * nCell * nBatch);
114 *forgetGateScratchDecoder += (2 * nCell * nBatch);
115 *outputGateScratchDecoder += (3 * nCell * nBatch);
118 std::unique_ptr<Decoder<float>> inputToInputWeightsTensor;
119 std::unique_ptr<Decoder<float>> inputToForgetWeightsTensor = MakeDecoder<float>(
120 m_InputToForgetWeightsTensor->GetTensorInfo(), m_InputToForgetWeightsTensor->GetConstTensor<
void>());
121 std::unique_ptr<Decoder<float>> inputToCellWeightsTensor = MakeDecoder<float>(
122 m_InputToCellWeightsTensor->GetTensorInfo(), m_InputToCellWeightsTensor->GetConstTensor<
void>());
123 std::unique_ptr<Decoder<float>> inputToOutputWeightsTensor = MakeDecoder<float>(
124 m_InputToOutputWeightsTensor->GetTensorInfo(), m_InputToOutputWeightsTensor->GetConstTensor<
void>());
126 std::unique_ptr<Decoder<float>> recurrentToInputWeightsTensor;
127 std::unique_ptr<Decoder<float>> recurrentToForgetWeightsTensor = MakeDecoder<float>(
128 m_RecurrentToForgetWeightsTensor->GetTensorInfo(), m_RecurrentToForgetWeightsTensor->GetConstTensor<
void>());
129 std::unique_ptr<Decoder<float>> recurrentToCellWeightsTensor = MakeDecoder<float>(
130 m_RecurrentToCellWeightsTensor->GetTensorInfo(), m_RecurrentToCellWeightsTensor->GetConstTensor<
void>());
131 std::unique_ptr<Decoder<float>> recurrentToOutputWeightsTensor = MakeDecoder<float>(
132 m_RecurrentToOutputWeightsTensor->GetTensorInfo(), m_RecurrentToOutputWeightsTensor->GetConstTensor<
void>());
134 std::unique_ptr<Decoder<float>> inputGateBiasTensor;
135 std::unique_ptr<Decoder<float>> forgetGateBiasTensor = MakeDecoder<float>(
136 m_ForgetGateBiasTensor->GetTensorInfo(), m_ForgetGateBiasTensor->GetConstTensor<
void>());
137 std::unique_ptr<Decoder<float>> cellBiasTensor = MakeDecoder<float>(
138 m_CellBiasTensor->GetTensorInfo(), m_CellBiasTensor->GetConstTensor<
void>());
139 std::unique_ptr<Decoder<float>> outputGateBiasTensor = MakeDecoder<float>(
140 m_OutputGateBiasTensor->GetTensorInfo(), m_OutputGateBiasTensor->GetConstTensor<
void>());
142 std::unique_ptr<Decoder<float>> cellToInputWeightsTensor;
143 std::unique_ptr<Decoder<float>> cellToForgetWeightsTensor;
144 std::unique_ptr<Decoder<float>> cellToOutputWeightsTensor;
146 std::unique_ptr<Decoder<float>> projectionWeightsTensor;
147 std::unique_ptr<Decoder<float>> projectionBiasTensor;
149 std::unique_ptr<Decoder<float>> inputLayerNormWeights;
150 std::unique_ptr<Decoder<float>> forgetLayerNormWeights;
151 std::unique_ptr<Decoder<float>> cellLayerNormWeights;
152 std::unique_ptr<Decoder<float>> outputLayerNormWeights;
154 const TensorShape& inputToOutputWeightsShape = m_InputToOutputWeightsTensor->GetShape();
155 const TensorShape& recurrentToOutputWeightsShape = m_RecurrentToOutputWeightsTensor->GetShape();
161 inputLayerNormWeights = MakeDecoder<float>(
162 m_InputLayerNormWeights->GetTensorInfo(), m_InputLayerNormWeights->GetConstTensor<
void>());
164 forgetLayerNormWeights = MakeDecoder<float>(
165 m_ForgetLayerNormWeights->GetTensorInfo(), m_ForgetLayerNormWeights->GetConstTensor<
void>());
166 cellLayerNormWeights = MakeDecoder<float>(
167 m_CellLayerNormWeights->GetTensorInfo(), m_CellLayerNormWeights->GetConstTensor<
void>());
168 outputLayerNormWeights = MakeDecoder<float>(
169 m_OutputLayerNormWeights->GetTensorInfo(), m_OutputLayerNormWeights->GetConstTensor<
void>());
174 inputToInputWeightsTensor = MakeDecoder<float>(
175 m_InputToInputWeightsTensor->GetTensorInfo(), m_InputToInputWeightsTensor->GetConstTensor<
void>());
176 inputGateBiasTensor = MakeDecoder<float>(
177 m_InputGateBiasTensor->GetTensorInfo(), m_InputGateBiasTensor->GetConstTensor<
void>());
178 recurrentToInputWeightsTensor = MakeDecoder<float>(
179 m_RecurrentToInputWeightsTensor->GetTensorInfo(), m_RecurrentToInputWeightsTensor->GetConstTensor<
void>());
184 cellToForgetWeightsTensor = MakeDecoder<float>(
185 m_CellToForgetWeightsTensor->GetTensorInfo(), m_CellToForgetWeightsTensor->GetConstTensor<
void>());
186 cellToOutputWeightsTensor = MakeDecoder<float>(
187 m_CellToOutputWeightsTensor->GetTensorInfo(), m_CellToOutputWeightsTensor->GetConstTensor<
void>());
190 if (!useCifg && usePeephole)
192 cellToInputWeightsTensor = MakeDecoder<float>(
193 m_CellToInputWeightsTensor->GetTensorInfo(), m_CellToInputWeightsTensor->GetConstTensor<
void>());
198 projectionWeightsTensor = MakeDecoder<float>(
199 m_ProjectionWeightsTensor->GetTensorInfo(), m_ProjectionWeightsTensor->GetConstTensor<
void>());
200 if (m_ProjectionBiasTensor)
202 projectionBiasTensor = MakeDecoder<float>(
203 m_ProjectionBiasTensor->GetTensorInfo(), m_ProjectionBiasTensor->GetConstTensor<
void>());
210 inputToOutputWeightsShape,
211 recurrentToOutputWeightsShape,
220 inputToInputWeightsTensor,
221 inputToForgetWeightsTensor,
222 inputToCellWeightsTensor,
223 inputToOutputWeightsTensor,
224 recurrentToInputWeightsTensor,
225 recurrentToForgetWeightsTensor,
226 recurrentToCellWeightsTensor,
227 recurrentToOutputWeightsTensor,
228 cellToInputWeightsTensor,
229 cellToForgetWeightsTensor,
230 cellToOutputWeightsTensor,
232 forgetGateBiasTensor,
234 outputGateBiasTensor,
235 projectionWeightsTensor,
236 projectionBiasTensor,
237 inputLayerNormWeights,
238 forgetLayerNormWeights,
239 cellLayerNormWeights,
240 outputLayerNormWeights,
245 inputGateScratchDecoder,
247 forgetGateScratchDecoder,
248 outputGateScratchDecoder,
bool m_ProjectionEnabled
Enable/disable the projection layer.
const TensorShape & GetShape() const
void LstmImpl(const LstmDescriptor &descriptor, const TensorInfo &inputInfo, const TensorInfo &outputInfo, const TensorShape &inputToOutputWeightsShape, const TensorShape &recurrentToOutputWeightsShape, std::unique_ptr< Decoder< float >> &inputData, std::unique_ptr< Decoder< float >> &outputStateIn, std::unique_ptr< Decoder< float >> &cellStateIn, std::unique_ptr< Encoder< float >> &outputStateOut, std::unique_ptr< Encoder< float >> &cellStateOut, std::unique_ptr< Encoder< float >> &output, std::unique_ptr< Decoder< float >> &cellStateOutDecoder, std::unique_ptr< Decoder< float >> &outputDecoder, std::unique_ptr< Decoder< float >> &inputToInputWeightsTensor, std::unique_ptr< Decoder< float >> &inputToForgetWeightsTensor, std::unique_ptr< Decoder< float >> &inputToCellWeightsTensor, std::unique_ptr< Decoder< float >> &inputToOutputWeightsTensor, std::unique_ptr< Decoder< float >> &recurrentToInputWeightsTensor, std::unique_ptr< Decoder< float >> &recurrentToForgetWeightsTensor, std::unique_ptr< Decoder< float >> &recurrentToCellWeightsTensor, std::unique_ptr< Decoder< float >> &recurrentToOutputWeightsTensor, std::unique_ptr< Decoder< float >> &cellToInputWeightsTensor, std::unique_ptr< Decoder< float >> &cellToForgetWeightsTensor, std::unique_ptr< Decoder< float >> &cellToOutputWeightsTensor, std::unique_ptr< Decoder< float >> &inputGateBiasTensor, std::unique_ptr< Decoder< float >> &forgetGateBiasTensor, std::unique_ptr< Decoder< float >> &cellBiasTensor, std::unique_ptr< Decoder< float >> &outputGateBiasTensor, std::unique_ptr< Decoder< float >> &projectionWeightsTensor, std::unique_ptr< Decoder< float >> &projectionBiasTensor, std::unique_ptr< Decoder< float >> &inputLayerNormWeights, std::unique_ptr< Decoder< float >> &forgetLayerNormWeights, std::unique_ptr< Decoder< float >> &cellLayerNormWeights, std::unique_ptr< Decoder< float >> &outputLayerNormWeights, std::unique_ptr< Encoder< float >> &inputGateScratch, std::unique_ptr< Encoder< float >> &cellScratch, std::unique_ptr< Encoder< float >> &forgetGateScratch, std::unique_ptr< Encoder< float >> &outputGateScratch, std::unique_ptr< Decoder< float >> &inputGateScratchDecoder, std::unique_ptr< Decoder< float >> &cellScratchDecoder, std::unique_ptr< Decoder< float >> &forgetGateScratchDecoder, std::unique_ptr< Decoder< float >> &outputGateScratchDecoder, float layerNormEpsilon)
RefLstmWorkload(const LstmQueueDescriptor &descriptor, const WorkloadInfo &info)
std::unique_ptr< armnn::ScopedTensorHandle > AssignScopedTensorHandle(const armnn::ConstTensorHandle *ptr)
Copyright (c) 2021 ARM Limited and Contributors.
LayerDescriptor m_Parameters
std::vector< ITensorHandle * > m_Inputs
LstmQueueDescriptor m_Data
void ExecuteAsync(WorkingMemDescriptor &workingMemDescriptor) override
bool m_PeepholeEnabled
Enable/disable peephole.
void Execute() const override
bool m_CifgEnabled
Enable/disable cifg (coupled input & forget gate).
std::vector< ITensorHandle * > m_Outputs
std::vector< ITensorHandle * > m_Outputs
bool m_LayerNormEnabled
Enable/disable layer normalization.
Contains information about TensorInfos of a layer.
std::vector< ITensorHandle * > m_Inputs
const TensorInfo & GetTensorInfo(const ITensorHandle *tensorHandle)
float32 helpers