58 std::vector<ITensorHandle*> outputs)
const
68 auto inputTensor =
reinterpret_cast<float*
>(inputs[0]->Map());
74 std::vector<float> inputValue(inputTensor, inputTensor + inputInfo.
GetNumElements());
82 unsigned int maxTime = inputShape[0];
83 unsigned int batchSize = inputShape[1];
84 unsigned int outputSize = outputShape[2];
85 unsigned int inputSize = inputShape[2];
87 TensorInfo scratchInfo = outputInfo;
90 std::vector<float> inputGateScratchBuffer;
91 std::vector<float> cellScratchBuffer(scratchInfo.GetNumElements(), 0.);
92 std::vector<float> forgetGateScratchBuffer(scratchInfo.GetNumElements(), 0.);
93 std::vector<float> outputGateScratchBuffer(scratchInfo.GetNumElements(), 0.);
95 std::vector<float> outputStateOutBuffer(outputStateInfo.
GetNumElements(), 0.);
96 std::vector<float> cellStateOutBuffer(cellStateInfo.
GetNumElements(), 0.);
98 void* outputStateOutData = outputStateOutBuffer.data();
99 void* cellStateOutData = cellStateOutBuffer.data();
101 std::unique_ptr<Encoder<float>> inputGateScratch;
102 std::unique_ptr<Encoder<float>> cellScratch = MakeEncoder<float>(scratchInfo, cellScratchBuffer.data());
103 std::unique_ptr<Encoder<float>> forgetGateScratch = MakeEncoder<float>(scratchInfo, forgetGateScratchBuffer.data());
104 std::unique_ptr<Encoder<float>> outputGateScratch = MakeEncoder<float>(scratchInfo, outputGateScratchBuffer.data());
106 std::unique_ptr<Decoder<float>> inputGateScratchDecoder;
107 std::unique_ptr<Decoder<float>> cellScratchDecoder = MakeDecoder<float>(scratchInfo, cellScratchBuffer.data());
108 std::unique_ptr<Decoder<float>> forgetGateScratchDecoder = MakeDecoder<float>(scratchInfo,
109 forgetGateScratchBuffer.data());
110 std::unique_ptr<Decoder<float>> outputGateScratchDecoder = MakeDecoder<float>(scratchInfo,
111 outputGateScratchBuffer.data());
119 inputGateScratchBuffer.resize(scratchInfo.GetNumElements(), 0.);
120 inputGateScratch = MakeEncoder<float>(scratchInfo, inputGateScratchBuffer.data());
121 inputGateScratchDecoder = MakeDecoder<float>(scratchInfo, inputGateScratchBuffer.data());
124 std::unique_ptr<Encoder<float>> outputStateOut = MakeEncoder<float>(outputStateInfo, outputStateOutData);
125 std::unique_ptr<Encoder<float>> cellStateOut = MakeEncoder<float>(cellStateInfo, cellStateOutData);
126 std::unique_ptr<Decoder<float>> cellStateOutDecoder = MakeDecoder<float>(cellStateInfo, cellStateOutData);
128 TensorInfo lstmInputInfo = inputInfo;
129 TensorShape batchInputShape = TensorShape({batchSize, inputSize});
130 lstmInputInfo.
SetShape(batchInputShape);
132 TensorInfo lstmOutputInfo = outputInfo;
133 lstmOutputInfo.
SetShape({batchSize, outputSize});
135 const TensorShape& inputToOutputWeightsShape = m_InputToOutputWeightsTensor->GetShape();
136 const TensorShape& recurrentToOutputWeightsShape = m_RecurrentToOutputWeightsTensor->GetShape();
137 unsigned int nOutput = recurrentToOutputWeightsShape[1];
138 auto outputStateInData = inputs[1]->Map();
139 std::unique_ptr<Decoder<float>> outputStateIn = MakeDecoder<float>(outputStateInfo, outputStateInData);
141 auto cellStateInData = inputs[2]->Map();
142 std::unique_ptr<Decoder<float>> cellStateIn = MakeDecoder<float>(cellStateInfo, cellStateInData);
144 auto currentInputData =
reinterpret_cast<float*
>(inputs[0]->Map());
145 std::unique_ptr<Decoder<float>> inputData = MakeDecoder<float>(lstmInputInfo, currentInputData);
146 auto currentOutputData =
reinterpret_cast<float*
>(outputs[2]->Map());
147 std::unique_ptr<Encoder<float>> output = MakeEncoder<float>(lstmOutputInfo, currentOutputData);
148 std::unique_ptr<Decoder<float>> outputDecoder = MakeDecoder<float>(lstmOutputInfo, currentOutputData);
150 std::unique_ptr<Decoder<float>> inputToInputWeightsTensor;
151 std::unique_ptr<Decoder<float>> inputToForgetWeightsTensor = MakeDecoder<float>(
152 m_InputToForgetWeightsTensor->GetTensorInfo(), m_InputToForgetWeightsTensor->GetConstTensor<
void>());
153 std::unique_ptr<Decoder<float>> inputToCellWeightsTensor = MakeDecoder<float>(
154 m_InputToCellWeightsTensor->GetTensorInfo(), m_InputToCellWeightsTensor->GetConstTensor<
void>());
155 std::unique_ptr<Decoder<float>> inputToOutputWeightsTensor = MakeDecoder<float>(
156 m_InputToOutputWeightsTensor->GetTensorInfo(), m_InputToOutputWeightsTensor->GetConstTensor<
void>());
158 std::unique_ptr<Decoder<float>> recurrentToInputWeightsTensor;
159 std::unique_ptr<Decoder<float>> recurrentToForgetWeightsTensor = MakeDecoder<float>(
160 m_RecurrentToForgetWeightsTensor->GetTensorInfo(), m_RecurrentToForgetWeightsTensor->GetConstTensor<
void>());
161 std::unique_ptr<Decoder<float>> recurrentToCellWeightsTensor = MakeDecoder<float>(
162 m_RecurrentToCellWeightsTensor->GetTensorInfo(), m_RecurrentToCellWeightsTensor->GetConstTensor<
void>());
163 std::unique_ptr<Decoder<float>> recurrentToOutputWeightsTensor = MakeDecoder<float>(
164 m_RecurrentToOutputWeightsTensor->GetTensorInfo(), m_RecurrentToOutputWeightsTensor->GetConstTensor<
void>());
166 std::unique_ptr<Decoder<float>> inputGateBiasTensor;
167 std::unique_ptr<Decoder<float>> forgetGateBiasTensor = MakeDecoder<float>(
168 m_ForgetGateBiasTensor->GetTensorInfo(), m_ForgetGateBiasTensor->GetConstTensor<
void>());
169 std::unique_ptr<Decoder<float>> cellBiasTensor = MakeDecoder<float>(
170 m_CellBiasTensor->GetTensorInfo(), m_CellBiasTensor->GetConstTensor<
void>());
171 std::unique_ptr<Decoder<float>> outputGateBiasTensor = MakeDecoder<float>(
172 m_OutputGateBiasTensor->GetTensorInfo(), m_OutputGateBiasTensor->GetConstTensor<
void>());
174 std::unique_ptr<Decoder<float>> cellToInputWeightsTensor;
175 std::unique_ptr<Decoder<float>> cellToForgetWeightsTensor;
176 std::unique_ptr<Decoder<float>> cellToOutputWeightsTensor;
178 std::unique_ptr<Decoder<float>> projectionWeightsTensor;
179 std::unique_ptr<Decoder<float>> projectionBiasTensor;
181 std::unique_ptr<Decoder<float>> inputLayerNormWeights;
182 std::unique_ptr<Decoder<float>> forgetLayerNormWeights;
183 std::unique_ptr<Decoder<float>> cellLayerNormWeights;
184 std::unique_ptr<Decoder<float>> outputLayerNormWeights;
190 inputLayerNormWeights = MakeDecoder<float>(
191 m_InputLayerNormWeights->GetTensorInfo(), m_InputLayerNormWeights->GetConstTensor<
void>());
193 forgetLayerNormWeights = MakeDecoder<float>(
194 m_ForgetLayerNormWeights->GetTensorInfo(), m_ForgetLayerNormWeights->GetConstTensor<
void>());
195 cellLayerNormWeights = MakeDecoder<float>(
196 m_CellLayerNormWeights->GetTensorInfo(), m_CellLayerNormWeights->GetConstTensor<
void>());
197 outputLayerNormWeights = MakeDecoder<float>(
198 m_OutputLayerNormWeights->GetTensorInfo(), m_OutputLayerNormWeights->GetConstTensor<
void>());
203 inputToInputWeightsTensor = MakeDecoder<float>(
204 m_InputToInputWeightsTensor->GetTensorInfo(), m_InputToInputWeightsTensor->GetConstTensor<
void>());
205 inputGateBiasTensor = MakeDecoder<float>(
206 m_InputGateBiasTensor->GetTensorInfo(), m_InputGateBiasTensor->GetConstTensor<
void>());
207 recurrentToInputWeightsTensor = MakeDecoder<float>(
208 m_RecurrentToInputWeightsTensor->GetTensorInfo(), m_RecurrentToInputWeightsTensor->GetConstTensor<
void>());
213 cellToForgetWeightsTensor = MakeDecoder<float>(
214 m_CellToForgetWeightsTensor->GetTensorInfo(), m_CellToForgetWeightsTensor->GetConstTensor<
void>());
215 cellToOutputWeightsTensor = MakeDecoder<float>(
216 m_CellToOutputWeightsTensor->GetTensorInfo(), m_CellToOutputWeightsTensor->GetConstTensor<
void>());
219 if (!useCifg && usePeephole)
221 cellToInputWeightsTensor = MakeDecoder<float>(
222 m_CellToInputWeightsTensor->GetTensorInfo(), m_CellToInputWeightsTensor->GetConstTensor<
void>());
227 projectionWeightsTensor = MakeDecoder<float>(
228 m_ProjectionWeightsTensor->GetTensorInfo(), m_ProjectionWeightsTensor->GetConstTensor<
void>());
229 if (m_ProjectionBiasTensor)
231 projectionBiasTensor = MakeDecoder<float>(
232 m_ProjectionBiasTensor->GetTensorInfo(), m_ProjectionBiasTensor->GetConstTensor<
void>());
236 unsigned int batchInputSize = batchSize * inputSize;
237 unsigned int batchOutputSize = batchSize * nOutput;
239 for (
unsigned int t = 0; t < maxTime; ++t)
244 inputToOutputWeightsShape,
245 recurrentToOutputWeightsShape,
254 inputToInputWeightsTensor,
255 inputToForgetWeightsTensor,
256 inputToCellWeightsTensor,
257 inputToOutputWeightsTensor,
258 recurrentToInputWeightsTensor,
259 recurrentToForgetWeightsTensor,
260 recurrentToCellWeightsTensor,
261 recurrentToOutputWeightsTensor,
262 cellToInputWeightsTensor,
263 cellToForgetWeightsTensor,
264 cellToOutputWeightsTensor,
266 forgetGateBiasTensor,
268 outputGateBiasTensor,
269 projectionWeightsTensor,
270 projectionBiasTensor,
271 inputLayerNormWeights,
272 forgetLayerNormWeights,
273 cellLayerNormWeights,
274 outputLayerNormWeights,
279 inputGateScratchDecoder,
281 forgetGateScratchDecoder,
282 outputGateScratchDecoder,
285 currentInputData += batchInputSize;
286 inputData = MakeDecoder<float>(lstmInputInfo, currentInputData);
287 currentOutputData += batchOutputSize;
288 output = MakeEncoder<float>(lstmOutputInfo, currentOutputData);
289 outputDecoder = MakeDecoder<float>(lstmOutputInfo, currentOutputData);
292 outputStateIn = MakeDecoder<float>(outputStateInfo, outputStateOutData);
295 cellStateIn = MakeDecoder<float>(cellStateInfo, cellStateOutData);
301 const PermutationVector& mappings = {1U, 0U, 2U};
302 auto outputData =
reinterpret_cast<float*
>(outputs[2]->Map());
303 std::vector<float> outputValue(outputData, outputData + outputInfo.
GetNumElements());