58 std::vector<ITensorHandle*> outputs)
const
70 auto inputTensor =
reinterpret_cast<float*
>(inputs[0]->Map());
76 std::vector<float> inputValue(inputTensor, inputTensor + inputInfo.
GetNumElements());
84 unsigned int maxTime = inputShape[0];
85 unsigned int batchSize = inputShape[1];
86 unsigned int outputSize = outputShape[2];
87 unsigned int inputSize = inputShape[2];
89 TensorInfo scratchInfo = outputInfo;
92 std::vector<float> inputGateScratchBuffer;
93 std::vector<float> cellScratchBuffer(scratchInfo.GetNumElements(), 0.);
94 std::vector<float> forgetGateScratchBuffer(scratchInfo.GetNumElements(), 0.);
95 std::vector<float> outputGateScratchBuffer(scratchInfo.GetNumElements(), 0.);
97 std::vector<float> outputStateOutBuffer(outputStateInfo.
GetNumElements(), 0.);
98 std::vector<float> cellStateOutBuffer(cellStateInfo.
GetNumElements(), 0.);
100 void* outputStateOutData = outputStateOutBuffer.data();
101 void* cellStateOutData = cellStateOutBuffer.data();
103 std::unique_ptr<Encoder<float>> inputGateScratch;
104 std::unique_ptr<Encoder<float>> cellScratch = MakeEncoder<float>(scratchInfo, cellScratchBuffer.data());
105 std::unique_ptr<Encoder<float>> forgetGateScratch = MakeEncoder<float>(scratchInfo, forgetGateScratchBuffer.data());
106 std::unique_ptr<Encoder<float>> outputGateScratch = MakeEncoder<float>(scratchInfo, outputGateScratchBuffer.data());
108 std::unique_ptr<Decoder<float>> inputGateScratchDecoder;
109 std::unique_ptr<Decoder<float>> cellScratchDecoder = MakeDecoder<float>(scratchInfo, cellScratchBuffer.data());
110 std::unique_ptr<Decoder<float>> forgetGateScratchDecoder = MakeDecoder<float>(scratchInfo,
111 forgetGateScratchBuffer.data());
112 std::unique_ptr<Decoder<float>> outputGateScratchDecoder = MakeDecoder<float>(scratchInfo,
113 outputGateScratchBuffer.data());
121 inputGateScratchBuffer.resize(scratchInfo.GetNumElements(), 0.);
122 inputGateScratch = MakeEncoder<float>(scratchInfo, inputGateScratchBuffer.data());
123 inputGateScratchDecoder = MakeDecoder<float>(scratchInfo, inputGateScratchBuffer.data());
126 std::unique_ptr<Encoder<float>> outputStateOut = MakeEncoder<float>(outputStateInfo, outputStateOutData);
127 std::unique_ptr<Encoder<float>> cellStateOut = MakeEncoder<float>(cellStateInfo, cellStateOutData);
128 std::unique_ptr<Decoder<float>> cellStateOutDecoder = MakeDecoder<float>(cellStateInfo, cellStateOutData);
130 TensorInfo lstmInputInfo = inputInfo;
131 TensorShape batchInputShape = TensorShape({batchSize, inputSize});
132 lstmInputInfo.
SetShape(batchInputShape);
134 TensorInfo lstmOutputInfo = outputInfo;
135 lstmOutputInfo.
SetShape({batchSize, outputSize});
137 const TensorShape& inputToOutputWeightsShape = m_InputToOutputWeightsTensor->GetShape();
138 const TensorShape& recurrentToOutputWeightsShape = m_RecurrentToOutputWeightsTensor->GetShape();
139 unsigned int nOutput = recurrentToOutputWeightsShape[1];
140 auto outputStateInData = inputs[1]->Map();
141 std::unique_ptr<Decoder<float>> outputStateIn = MakeDecoder<float>(outputStateInfo, outputStateInData);
143 auto cellStateInData = inputs[2]->Map();
144 std::unique_ptr<Decoder<float>> cellStateIn = MakeDecoder<float>(cellStateInfo, cellStateInData);
146 auto currentInputData =
reinterpret_cast<float*
>(inputs[0]->Map());
147 std::unique_ptr<Decoder<float>> inputData = MakeDecoder<float>(lstmInputInfo, currentInputData);
148 auto currentOutputData =
reinterpret_cast<float*
>(outputs[2]->Map());
149 std::unique_ptr<Encoder<float>> output = MakeEncoder<float>(lstmOutputInfo, currentOutputData);
150 std::unique_ptr<Decoder<float>> outputDecoder = MakeDecoder<float>(lstmOutputInfo, currentOutputData);
152 std::unique_ptr<Decoder<float>> inputToInputWeightsTensor;
153 std::unique_ptr<Decoder<float>> inputToForgetWeightsTensor = MakeDecoder<float>(
154 m_InputToForgetWeightsTensor->GetTensorInfo(), m_InputToForgetWeightsTensor->GetConstTensor<
void>());
155 std::unique_ptr<Decoder<float>> inputToCellWeightsTensor = MakeDecoder<float>(
156 m_InputToCellWeightsTensor->GetTensorInfo(), m_InputToCellWeightsTensor->GetConstTensor<
void>());
157 std::unique_ptr<Decoder<float>> inputToOutputWeightsTensor = MakeDecoder<float>(
158 m_InputToOutputWeightsTensor->GetTensorInfo(), m_InputToOutputWeightsTensor->GetConstTensor<
void>());
160 std::unique_ptr<Decoder<float>> recurrentToInputWeightsTensor;
161 std::unique_ptr<Decoder<float>> recurrentToForgetWeightsTensor = MakeDecoder<float>(
162 m_RecurrentToForgetWeightsTensor->GetTensorInfo(), m_RecurrentToForgetWeightsTensor->GetConstTensor<
void>());
163 std::unique_ptr<Decoder<float>> recurrentToCellWeightsTensor = MakeDecoder<float>(
164 m_RecurrentToCellWeightsTensor->GetTensorInfo(), m_RecurrentToCellWeightsTensor->GetConstTensor<
void>());
165 std::unique_ptr<Decoder<float>> recurrentToOutputWeightsTensor = MakeDecoder<float>(
166 m_RecurrentToOutputWeightsTensor->GetTensorInfo(), m_RecurrentToOutputWeightsTensor->GetConstTensor<
void>());
168 std::unique_ptr<Decoder<float>> inputGateBiasTensor;
169 std::unique_ptr<Decoder<float>> forgetGateBiasTensor = MakeDecoder<float>(
170 m_ForgetGateBiasTensor->GetTensorInfo(), m_ForgetGateBiasTensor->GetConstTensor<
void>());
171 std::unique_ptr<Decoder<float>> cellBiasTensor = MakeDecoder<float>(
172 m_CellBiasTensor->GetTensorInfo(), m_CellBiasTensor->GetConstTensor<
void>());
173 std::unique_ptr<Decoder<float>> outputGateBiasTensor = MakeDecoder<float>(
174 m_OutputGateBiasTensor->GetTensorInfo(), m_OutputGateBiasTensor->GetConstTensor<
void>());
176 std::unique_ptr<Decoder<float>> cellToInputWeightsTensor;
177 std::unique_ptr<Decoder<float>> cellToForgetWeightsTensor;
178 std::unique_ptr<Decoder<float>> cellToOutputWeightsTensor;
180 std::unique_ptr<Decoder<float>> projectionWeightsTensor;
181 std::unique_ptr<Decoder<float>> projectionBiasTensor;
183 std::unique_ptr<Decoder<float>> inputLayerNormWeights;
184 std::unique_ptr<Decoder<float>> forgetLayerNormWeights;
185 std::unique_ptr<Decoder<float>> cellLayerNormWeights;
186 std::unique_ptr<Decoder<float>> outputLayerNormWeights;
192 inputLayerNormWeights = MakeDecoder<float>(
193 m_InputLayerNormWeights->GetTensorInfo(), m_InputLayerNormWeights->GetConstTensor<
void>());
195 forgetLayerNormWeights = MakeDecoder<float>(
196 m_ForgetLayerNormWeights->GetTensorInfo(), m_ForgetLayerNormWeights->GetConstTensor<
void>());
197 cellLayerNormWeights = MakeDecoder<float>(
198 m_CellLayerNormWeights->GetTensorInfo(), m_CellLayerNormWeights->GetConstTensor<
void>());
199 outputLayerNormWeights = MakeDecoder<float>(
200 m_OutputLayerNormWeights->GetTensorInfo(), m_OutputLayerNormWeights->GetConstTensor<
void>());
205 inputToInputWeightsTensor = MakeDecoder<float>(
206 m_InputToInputWeightsTensor->GetTensorInfo(), m_InputToInputWeightsTensor->GetConstTensor<
void>());
207 inputGateBiasTensor = MakeDecoder<float>(
208 m_InputGateBiasTensor->GetTensorInfo(), m_InputGateBiasTensor->GetConstTensor<
void>());
209 recurrentToInputWeightsTensor = MakeDecoder<float>(
210 m_RecurrentToInputWeightsTensor->GetTensorInfo(), m_RecurrentToInputWeightsTensor->GetConstTensor<
void>());
215 cellToForgetWeightsTensor = MakeDecoder<float>(
216 m_CellToForgetWeightsTensor->GetTensorInfo(), m_CellToForgetWeightsTensor->GetConstTensor<
void>());
217 cellToOutputWeightsTensor = MakeDecoder<float>(
218 m_CellToOutputWeightsTensor->GetTensorInfo(), m_CellToOutputWeightsTensor->GetConstTensor<
void>());
221 if (!useCifg && usePeephole)
223 cellToInputWeightsTensor = MakeDecoder<float>(
224 m_CellToInputWeightsTensor->GetTensorInfo(), m_CellToInputWeightsTensor->GetConstTensor<
void>());
229 projectionWeightsTensor = MakeDecoder<float>(
230 m_ProjectionWeightsTensor->GetTensorInfo(), m_ProjectionWeightsTensor->GetConstTensor<
void>());
231 if (m_ProjectionBiasTensor)
233 projectionBiasTensor = MakeDecoder<float>(
234 m_ProjectionBiasTensor->GetTensorInfo(), m_ProjectionBiasTensor->GetConstTensor<
void>());
238 unsigned int batchInputSize = batchSize * inputSize;
239 unsigned int batchOutputSize = batchSize * nOutput;
241 for (
unsigned int t = 0; t < maxTime; ++t)
246 inputToOutputWeightsShape,
247 recurrentToOutputWeightsShape,
256 inputToInputWeightsTensor,
257 inputToForgetWeightsTensor,
258 inputToCellWeightsTensor,
259 inputToOutputWeightsTensor,
260 recurrentToInputWeightsTensor,
261 recurrentToForgetWeightsTensor,
262 recurrentToCellWeightsTensor,
263 recurrentToOutputWeightsTensor,
264 cellToInputWeightsTensor,
265 cellToForgetWeightsTensor,
266 cellToOutputWeightsTensor,
268 forgetGateBiasTensor,
270 outputGateBiasTensor,
271 projectionWeightsTensor,
272 projectionBiasTensor,
273 inputLayerNormWeights,
274 forgetLayerNormWeights,
275 cellLayerNormWeights,
276 outputLayerNormWeights,
281 inputGateScratchDecoder,
283 forgetGateScratchDecoder,
284 outputGateScratchDecoder,
287 currentInputData += batchInputSize;
288 inputData = MakeDecoder<float>(lstmInputInfo, currentInputData);
289 currentOutputData += batchOutputSize;
290 output = MakeEncoder<float>(lstmOutputInfo, currentOutputData);
291 outputDecoder = MakeDecoder<float>(lstmOutputInfo, currentOutputData);
294 outputStateIn = MakeDecoder<float>(outputStateInfo, outputStateOutData);
297 cellStateIn = MakeDecoder<float>(cellStateInfo, cellStateOutData);
303 const PermutationVector& mappings = {1U, 0U, 2U};
304 auto outputData =
reinterpret_cast<float*
>(outputs[2]->Map());
305 std::vector<float> outputValue(outputData, outputData + outputInfo.
GetNumElements());