63 std::unique_ptr<Encoder<float>> outputStateOut = MakeEncoder<float>(outputInfo, outputs[1]->Map());
64 std::unique_ptr<Encoder<float>> cellStateOut = MakeEncoder<float>(outputInfo, outputs[2]->Map());
65 std::unique_ptr<Encoder<float>> output = MakeEncoder<float>(outputInfo, outputs[3]->Map());
67 std::unique_ptr<Decoder<float>> cellStateOutDecoder = MakeDecoder<float>(outputInfo, outputs[2]->Map());
68 std::unique_ptr<Decoder<float>> outputDecoder = MakeDecoder<float>(outputInfo, outputs[3]->Map());
70 std::unique_ptr<Decoder<float>> inputData = MakeDecoder<float>(inputInfo, inputs[0]->Map());
71 std::unique_ptr<Decoder<float>> outputStateIn = MakeDecoder<float>(inputInfo, inputs[1]->Map());
72 std::unique_ptr<Decoder<float>> cellStateIn = MakeDecoder<float>(inputInfo, inputs[2]->Map());
74 const uint32_t nBatch = inputShape[0];
75 const uint32_t nCell = m_InputToOutputWeightsTensor->GetShape()[0];
82 std::unique_ptr<Encoder<float>> inputGateScratch = MakeEncoder<float>(outputInfo, outputs[0]->Map());
83 std::unique_ptr<Encoder<float>> cellScratch = MakeEncoder<float>(outputInfo, outputs[0]->Map());
84 std::unique_ptr<Encoder<float>> forgetGateScratch = MakeEncoder<float>(outputInfo, outputs[0]->Map());
85 std::unique_ptr<Encoder<float>> outputGateScratch = MakeEncoder<float>(outputInfo, outputs[0]->Map());
87 std::unique_ptr<Decoder<float>> inputGateScratchDecoder =
88 MakeDecoder<float>(outputInfo, outputs[0]->Map());
89 std::unique_ptr<Decoder<float>> cellScratchDecoder =
90 MakeDecoder<float>(outputInfo, outputs[0]->Map());
91 std::unique_ptr<Decoder<float>> forgetGateScratchDecoder =
92 MakeDecoder<float>(outputInfo, outputs[0]->Map());
93 std::unique_ptr<Decoder<float>> outputGateScratchDecoder =
94 MakeDecoder<float>(outputInfo, outputs[0]->Map());
98 *cellScratch += (0 * nCell * nBatch);
99 *forgetGateScratch += (1 * nCell * nBatch);
100 *outputGateScratch += (2 * nCell * nBatch);
102 *cellScratchDecoder += (0 * nCell * nBatch);
103 *forgetGateScratchDecoder += (1 * nCell * nBatch);
104 *outputGateScratchDecoder += (2 * nCell * nBatch);
108 *inputGateScratch += (0 * nCell * nBatch);
109 *cellScratch += (1 * nCell * nBatch);
110 *forgetGateScratch += (2 * nCell * nBatch);
111 *outputGateScratch += (3 * nCell * nBatch);
113 *inputGateScratchDecoder += (0 * nCell * nBatch);
114 *cellScratchDecoder += (1 * nCell * nBatch);
115 *forgetGateScratchDecoder += (2 * nCell * nBatch);
116 *outputGateScratchDecoder += (3 * nCell * nBatch);
119 std::unique_ptr<Decoder<float>> inputToInputWeightsTensor;
120 std::unique_ptr<Decoder<float>> inputToForgetWeightsTensor = MakeDecoder<float>(
121 m_InputToForgetWeightsTensor->GetTensorInfo(), m_InputToForgetWeightsTensor->GetConstTensor<
void>());
122 std::unique_ptr<Decoder<float>> inputToCellWeightsTensor = MakeDecoder<float>(
123 m_InputToCellWeightsTensor->GetTensorInfo(), m_InputToCellWeightsTensor->GetConstTensor<
void>());
124 std::unique_ptr<Decoder<float>> inputToOutputWeightsTensor = MakeDecoder<float>(
125 m_InputToOutputWeightsTensor->GetTensorInfo(), m_InputToOutputWeightsTensor->GetConstTensor<
void>());
127 std::unique_ptr<Decoder<float>> recurrentToInputWeightsTensor;
128 std::unique_ptr<Decoder<float>> recurrentToForgetWeightsTensor = MakeDecoder<float>(
129 m_RecurrentToForgetWeightsTensor->GetTensorInfo(), m_RecurrentToForgetWeightsTensor->GetConstTensor<
void>());
130 std::unique_ptr<Decoder<float>> recurrentToCellWeightsTensor = MakeDecoder<float>(
131 m_RecurrentToCellWeightsTensor->GetTensorInfo(), m_RecurrentToCellWeightsTensor->GetConstTensor<
void>());
132 std::unique_ptr<Decoder<float>> recurrentToOutputWeightsTensor = MakeDecoder<float>(
133 m_RecurrentToOutputWeightsTensor->GetTensorInfo(), m_RecurrentToOutputWeightsTensor->GetConstTensor<
void>());
135 std::unique_ptr<Decoder<float>> inputGateBiasTensor;
136 std::unique_ptr<Decoder<float>> forgetGateBiasTensor = MakeDecoder<float>(
137 m_ForgetGateBiasTensor->GetTensorInfo(), m_ForgetGateBiasTensor->GetConstTensor<
void>());
138 std::unique_ptr<Decoder<float>> cellBiasTensor = MakeDecoder<float>(
139 m_CellBiasTensor->GetTensorInfo(), m_CellBiasTensor->GetConstTensor<
void>());
140 std::unique_ptr<Decoder<float>> outputGateBiasTensor = MakeDecoder<float>(
141 m_OutputGateBiasTensor->GetTensorInfo(), m_OutputGateBiasTensor->GetConstTensor<
void>());
143 std::unique_ptr<Decoder<float>> cellToInputWeightsTensor;
144 std::unique_ptr<Decoder<float>> cellToForgetWeightsTensor;
145 std::unique_ptr<Decoder<float>> cellToOutputWeightsTensor;
147 std::unique_ptr<Decoder<float>> projectionWeightsTensor;
148 std::unique_ptr<Decoder<float>> projectionBiasTensor;
150 std::unique_ptr<Decoder<float>> inputLayerNormWeights;
151 std::unique_ptr<Decoder<float>> forgetLayerNormWeights;
152 std::unique_ptr<Decoder<float>> cellLayerNormWeights;
153 std::unique_ptr<Decoder<float>> outputLayerNormWeights;
155 const TensorShape& inputToOutputWeightsShape = m_InputToOutputWeightsTensor->GetShape();
156 const TensorShape& recurrentToOutputWeightsShape = m_RecurrentToOutputWeightsTensor->GetShape();
162 inputLayerNormWeights = MakeDecoder<float>(
163 m_InputLayerNormWeights->GetTensorInfo(), m_InputLayerNormWeights->GetConstTensor<
void>());
165 forgetLayerNormWeights = MakeDecoder<float>(
166 m_ForgetLayerNormWeights->GetTensorInfo(), m_ForgetLayerNormWeights->GetConstTensor<
void>());
167 cellLayerNormWeights = MakeDecoder<float>(
168 m_CellLayerNormWeights->GetTensorInfo(), m_CellLayerNormWeights->GetConstTensor<
void>());
169 outputLayerNormWeights = MakeDecoder<float>(
170 m_OutputLayerNormWeights->GetTensorInfo(), m_OutputLayerNormWeights->GetConstTensor<
void>());
175 inputToInputWeightsTensor = MakeDecoder<float>(
176 m_InputToInputWeightsTensor->GetTensorInfo(), m_InputToInputWeightsTensor->GetConstTensor<
void>());
177 inputGateBiasTensor = MakeDecoder<float>(
178 m_InputGateBiasTensor->GetTensorInfo(), m_InputGateBiasTensor->GetConstTensor<
void>());
179 recurrentToInputWeightsTensor = MakeDecoder<float>(
180 m_RecurrentToInputWeightsTensor->GetTensorInfo(), m_RecurrentToInputWeightsTensor->GetConstTensor<
void>());
185 cellToForgetWeightsTensor = MakeDecoder<float>(
186 m_CellToForgetWeightsTensor->GetTensorInfo(), m_CellToForgetWeightsTensor->GetConstTensor<
void>());
187 cellToOutputWeightsTensor = MakeDecoder<float>(
188 m_CellToOutputWeightsTensor->GetTensorInfo(), m_CellToOutputWeightsTensor->GetConstTensor<
void>());
191 if (!useCifg && usePeephole)
193 cellToInputWeightsTensor = MakeDecoder<float>(
194 m_CellToInputWeightsTensor->GetTensorInfo(), m_CellToInputWeightsTensor->GetConstTensor<
void>());
199 projectionWeightsTensor = MakeDecoder<float>(
200 m_ProjectionWeightsTensor->GetTensorInfo(), m_ProjectionWeightsTensor->GetConstTensor<
void>());
201 if (m_ProjectionBiasTensor)
203 projectionBiasTensor = MakeDecoder<float>(
204 m_ProjectionBiasTensor->GetTensorInfo(), m_ProjectionBiasTensor->GetConstTensor<
void>());
211 inputToOutputWeightsShape,
212 recurrentToOutputWeightsShape,
221 inputToInputWeightsTensor,
222 inputToForgetWeightsTensor,
223 inputToCellWeightsTensor,
224 inputToOutputWeightsTensor,
225 recurrentToInputWeightsTensor,
226 recurrentToForgetWeightsTensor,
227 recurrentToCellWeightsTensor,
228 recurrentToOutputWeightsTensor,
229 cellToInputWeightsTensor,
230 cellToForgetWeightsTensor,
231 cellToOutputWeightsTensor,
233 forgetGateBiasTensor,
235 outputGateBiasTensor,
236 projectionWeightsTensor,
237 projectionBiasTensor,
238 inputLayerNormWeights,
239 forgetLayerNormWeights,
240 cellLayerNormWeights,
241 outputLayerNormWeights,
246 inputGateScratchDecoder,
248 forgetGateScratchDecoder,
249 outputGateScratchDecoder,
bool m_ProjectionEnabled
Enable/disable the projection layer.
const TensorShape & GetShape() const
void LstmImpl(const LstmDescriptor &descriptor, const TensorInfo &inputInfo, const TensorInfo &outputInfo, const TensorShape &inputToOutputWeightsShape, const TensorShape &recurrentToOutputWeightsShape, std::unique_ptr< Decoder< float >> &inputData, std::unique_ptr< Decoder< float >> &outputStateIn, std::unique_ptr< Decoder< float >> &cellStateIn, std::unique_ptr< Encoder< float >> &outputStateOut, std::unique_ptr< Encoder< float >> &cellStateOut, std::unique_ptr< Encoder< float >> &output, std::unique_ptr< Decoder< float >> &cellStateOutDecoder, std::unique_ptr< Decoder< float >> &outputDecoder, std::unique_ptr< Decoder< float >> &inputToInputWeightsTensor, std::unique_ptr< Decoder< float >> &inputToForgetWeightsTensor, std::unique_ptr< Decoder< float >> &inputToCellWeightsTensor, std::unique_ptr< Decoder< float >> &inputToOutputWeightsTensor, std::unique_ptr< Decoder< float >> &recurrentToInputWeightsTensor, std::unique_ptr< Decoder< float >> &recurrentToForgetWeightsTensor, std::unique_ptr< Decoder< float >> &recurrentToCellWeightsTensor, std::unique_ptr< Decoder< float >> &recurrentToOutputWeightsTensor, std::unique_ptr< Decoder< float >> &cellToInputWeightsTensor, std::unique_ptr< Decoder< float >> &cellToForgetWeightsTensor, std::unique_ptr< Decoder< float >> &cellToOutputWeightsTensor, std::unique_ptr< Decoder< float >> &inputGateBiasTensor, std::unique_ptr< Decoder< float >> &forgetGateBiasTensor, std::unique_ptr< Decoder< float >> &cellBiasTensor, std::unique_ptr< Decoder< float >> &outputGateBiasTensor, std::unique_ptr< Decoder< float >> &projectionWeightsTensor, std::unique_ptr< Decoder< float >> &projectionBiasTensor, std::unique_ptr< Decoder< float >> &inputLayerNormWeights, std::unique_ptr< Decoder< float >> &forgetLayerNormWeights, std::unique_ptr< Decoder< float >> &cellLayerNormWeights, std::unique_ptr< Decoder< float >> &outputLayerNormWeights, std::unique_ptr< Encoder< float >> &inputGateScratch, std::unique_ptr< Encoder< float >> &cellScratch, std::unique_ptr< Encoder< float >> &forgetGateScratch, std::unique_ptr< Encoder< float >> &outputGateScratch, std::unique_ptr< Decoder< float >> &inputGateScratchDecoder, std::unique_ptr< Decoder< float >> &cellScratchDecoder, std::unique_ptr< Decoder< float >> &forgetGateScratchDecoder, std::unique_ptr< Decoder< float >> &outputGateScratchDecoder, float layerNormEpsilon)
RefLstmWorkload(const LstmQueueDescriptor &descriptor, const WorkloadInfo &info)
std::unique_ptr< armnn::ScopedTensorHandle > AssignScopedTensorHandle(const armnn::ConstTensorHandle *ptr)
Copyright (c) 2021 ARM Limited and Contributors.
LayerDescriptor m_Parameters
std::vector< ITensorHandle * > m_Inputs
LstmQueueDescriptor m_Data
bool m_PeepholeEnabled
Enable/disable peephole.
void Execute() const override
bool m_CifgEnabled
Enable/disable cifg (coupled input & forget gate).
std::vector< ITensorHandle * > m_Outputs
std::vector< ITensorHandle * > m_Outputs
bool m_LayerNormEnabled
Enable/disable layer normalization.
void ExecuteAsync(ExecutionData &executionData) override
Contains information about TensorInfos of a layer.
std::vector< ITensorHandle * > m_Inputs
const TensorInfo & GetTensorInfo(const ITensorHandle *tensorHandle)
float32 helpers