63 std::unique_ptr<Encoder<float>> outputStateOut = MakeEncoder<float>(outputInfo, outputs[1]->
Map());
64 std::unique_ptr<Encoder<float>> cellStateOut = MakeEncoder<float>(outputInfo, outputs[2]->
Map());
65 std::unique_ptr<Encoder<float>> output = MakeEncoder<float>(outputInfo, outputs[3]->
Map());
67 std::unique_ptr<Decoder<float>> cellStateOutDecoder = MakeDecoder<float>(outputInfo, outputs[2]->
Map());
68 std::unique_ptr<Decoder<float>> outputDecoder = MakeDecoder<float>(outputInfo, outputs[3]->
Map());
70 std::unique_ptr<Decoder<float>> inputData = MakeDecoder<float>(inputInfo, inputs[0]->
Map());
71 std::unique_ptr<Decoder<float>> outputStateIn = MakeDecoder<float>(inputInfo, inputs[1]->
Map());
72 std::unique_ptr<Decoder<float>> cellStateIn = MakeDecoder<float>(inputInfo, inputs[2]->
Map());
74 const uint32_t nBatch = inputShape[0];
75 const uint32_t nCell = m_InputToOutputWeightsTensor->GetShape()[0];
82 std::unique_ptr<Encoder<float>> inputGateScratch = MakeEncoder<float>(outputInfo, outputs[0]->
Map());
83 std::unique_ptr<Encoder<float>> cellScratch = MakeEncoder<float>(outputInfo, outputs[0]->
Map());
84 std::unique_ptr<Encoder<float>> forgetGateScratch = MakeEncoder<float>(outputInfo, outputs[0]->
Map());
85 std::unique_ptr<Encoder<float>> outputGateScratch = MakeEncoder<float>(outputInfo, outputs[0]->
Map());
87 std::unique_ptr<Decoder<float>> inputGateScratchDecoder =
88 MakeDecoder<float>(outputInfo, outputs[0]->
Map());
89 std::unique_ptr<Decoder<float>> cellScratchDecoder =
90 MakeDecoder<float>(outputInfo, outputs[0]->
Map());
91 std::unique_ptr<Decoder<float>> forgetGateScratchDecoder =
92 MakeDecoder<float>(outputInfo, outputs[0]->
Map());
93 std::unique_ptr<Decoder<float>> outputGateScratchDecoder =
94 MakeDecoder<float>(outputInfo, outputs[0]->
Map());
98 *cellScratch += (0 * nCell * nBatch);
99 *forgetGateScratch += (1 * nCell * nBatch);
100 *outputGateScratch += (2 * nCell * nBatch);
102 *cellScratchDecoder += (0 * nCell * nBatch);
103 *forgetGateScratchDecoder += (1 * nCell * nBatch);
104 *outputGateScratchDecoder += (2 * nCell * nBatch);
108 *inputGateScratch += (0 * nCell * nBatch);
109 *cellScratch += (1 * nCell * nBatch);
110 *forgetGateScratch += (2 * nCell * nBatch);
111 *outputGateScratch += (3 * nCell * nBatch);
113 *inputGateScratchDecoder += (0 * nCell * nBatch);
114 *cellScratchDecoder += (1 * nCell * nBatch);
115 *forgetGateScratchDecoder += (2 * nCell * nBatch);
116 *outputGateScratchDecoder += (3 * nCell * nBatch);
119 std::unique_ptr<Decoder<float>> inputToInputWeightsTensor;
120 std::unique_ptr<Decoder<float>> inputToForgetWeightsTensor = MakeDecoder<float>(
121 m_InputToForgetWeightsTensor->GetTensorInfo(), m_InputToForgetWeightsTensor->GetConstTensor<
void>());
122 std::unique_ptr<Decoder<float>> inputToCellWeightsTensor = MakeDecoder<float>(
123 m_InputToCellWeightsTensor->GetTensorInfo(), m_InputToCellWeightsTensor->GetConstTensor<
void>());
124 std::unique_ptr<Decoder<float>> inputToOutputWeightsTensor = MakeDecoder<float>(
125 m_InputToOutputWeightsTensor->GetTensorInfo(), m_InputToOutputWeightsTensor->GetConstTensor<
void>());
127 std::unique_ptr<Decoder<float>> recurrentToInputWeightsTensor;
128 std::unique_ptr<Decoder<float>> recurrentToForgetWeightsTensor = MakeDecoder<float>(
129 m_RecurrentToForgetWeightsTensor->GetTensorInfo(), m_RecurrentToForgetWeightsTensor->GetConstTensor<
void>());
130 std::unique_ptr<Decoder<float>> recurrentToCellWeightsTensor = MakeDecoder<float>(
131 m_RecurrentToCellWeightsTensor->GetTensorInfo(), m_RecurrentToCellWeightsTensor->GetConstTensor<
void>());
132 std::unique_ptr<Decoder<float>> recurrentToOutputWeightsTensor = MakeDecoder<float>(
133 m_RecurrentToOutputWeightsTensor->GetTensorInfo(), m_RecurrentToOutputWeightsTensor->GetConstTensor<
void>());
135 std::unique_ptr<Decoder<float>> inputGateBiasTensor;
136 std::unique_ptr<Decoder<float>> forgetGateBiasTensor = MakeDecoder<float>(
137 m_ForgetGateBiasTensor->GetTensorInfo(), m_ForgetGateBiasTensor->GetConstTensor<
void>());
138 std::unique_ptr<Decoder<float>> cellBiasTensor = MakeDecoder<float>(
139 m_CellBiasTensor->GetTensorInfo(), m_CellBiasTensor->GetConstTensor<
void>());
140 std::unique_ptr<Decoder<float>> outputGateBiasTensor = MakeDecoder<float>(
141 m_OutputGateBiasTensor->GetTensorInfo(), m_OutputGateBiasTensor->GetConstTensor<
void>());
143 std::unique_ptr<Decoder<float>> cellToInputWeightsTensor;
144 std::unique_ptr<Decoder<float>> cellToForgetWeightsTensor;
145 std::unique_ptr<Decoder<float>> cellToOutputWeightsTensor;
147 std::unique_ptr<Decoder<float>> projectionWeightsTensor;
148 std::unique_ptr<Decoder<float>> projectionBiasTensor;
150 std::unique_ptr<Decoder<float>> inputLayerNormWeights;
151 std::unique_ptr<Decoder<float>> forgetLayerNormWeights;
152 std::unique_ptr<Decoder<float>> cellLayerNormWeights;
153 std::unique_ptr<Decoder<float>> outputLayerNormWeights;
155 const TensorShape& inputToOutputWeightsShape = m_InputToOutputWeightsTensor->GetShape();
156 const TensorShape& recurrentToOutputWeightsShape = m_RecurrentToOutputWeightsTensor->GetShape();
162 inputLayerNormWeights = MakeDecoder<float>(
163 m_InputLayerNormWeights->GetTensorInfo(), m_InputLayerNormWeights->GetConstTensor<
void>());
165 forgetLayerNormWeights = MakeDecoder<float>(
166 m_ForgetLayerNormWeights->GetTensorInfo(), m_ForgetLayerNormWeights->GetConstTensor<
void>());
167 cellLayerNormWeights = MakeDecoder<float>(
168 m_CellLayerNormWeights->GetTensorInfo(), m_CellLayerNormWeights->GetConstTensor<
void>());
169 outputLayerNormWeights = MakeDecoder<float>(
170 m_OutputLayerNormWeights->GetTensorInfo(), m_OutputLayerNormWeights->GetConstTensor<
void>());
175 inputToInputWeightsTensor = MakeDecoder<float>(
176 m_InputToInputWeightsTensor->GetTensorInfo(), m_InputToInputWeightsTensor->GetConstTensor<
void>());
177 inputGateBiasTensor = MakeDecoder<float>(
178 m_InputGateBiasTensor->GetTensorInfo(), m_InputGateBiasTensor->GetConstTensor<
void>());
179 recurrentToInputWeightsTensor = MakeDecoder<float>(
180 m_RecurrentToInputWeightsTensor->GetTensorInfo(), m_RecurrentToInputWeightsTensor->GetConstTensor<
void>());
185 cellToForgetWeightsTensor = MakeDecoder<float>(
186 m_CellToForgetWeightsTensor->GetTensorInfo(), m_CellToForgetWeightsTensor->GetConstTensor<
void>());
187 cellToOutputWeightsTensor = MakeDecoder<float>(
188 m_CellToOutputWeightsTensor->GetTensorInfo(), m_CellToOutputWeightsTensor->GetConstTensor<
void>());
191 if (!useCifg && usePeephole)
193 cellToInputWeightsTensor = MakeDecoder<float>(
194 m_CellToInputWeightsTensor->GetTensorInfo(), m_CellToInputWeightsTensor->GetConstTensor<
void>());
199 projectionWeightsTensor = MakeDecoder<float>(
200 m_ProjectionWeightsTensor->GetTensorInfo(), m_ProjectionWeightsTensor->GetConstTensor<
void>());
201 if (m_ProjectionBiasTensor)
203 projectionBiasTensor = MakeDecoder<float>(
204 m_ProjectionBiasTensor->GetTensorInfo(), m_ProjectionBiasTensor->GetConstTensor<
void>());
211 inputToOutputWeightsShape,
212 recurrentToOutputWeightsShape,
221 inputToInputWeightsTensor,
222 inputToForgetWeightsTensor,
223 inputToCellWeightsTensor,
224 inputToOutputWeightsTensor,
225 recurrentToInputWeightsTensor,
226 recurrentToForgetWeightsTensor,
227 recurrentToCellWeightsTensor,
228 recurrentToOutputWeightsTensor,
229 cellToInputWeightsTensor,
230 cellToForgetWeightsTensor,
231 cellToOutputWeightsTensor,
233 forgetGateBiasTensor,
235 outputGateBiasTensor,
236 projectionWeightsTensor,
237 projectionBiasTensor,
238 inputLayerNormWeights,
239 forgetLayerNormWeights,
240 cellLayerNormWeights,
241 outputLayerNormWeights,
246 inputGateScratchDecoder,
248 forgetGateScratchDecoder,
249 outputGateScratchDecoder,