// Test helper for LstmUtils::VectorBatchVectorAdd: adds `vec` to every batch
// row of `batchVec` in place (via Decoder/Encoder views over the raw float
// buffers) and compares the mutated `batchVec` against `expectedOutput`.
// NOTE(review): this capture is damaged — the leading "NN " tokens are embedded
// original-file line numbers, and original lines 31-32, 34-40 and 44-47
// (remaining parameters, tensorInfo construction, and the VectorBatchVectorAdd
// call itself) are missing here. Do not treat this fragment as compilable.
23 #include <doctest/doctest.h> 27 template<armnn::DataType ArmnnType,
typename T = armnn::ResolveType<ArmnnType>>
28 void LstmUtilsVectorBatchVectorAddTestImpl(
29 std::vector<float>& vec,
30 std::vector<float>& batchVec,
33 std::vector<float>& expectedOutput,
// Decoder/Encoder wrappers alias the vectors' backing storage — no copies, so
// writes through batchVecEncoder are visible directly in batchVec.
41 std::unique_ptr<armnn::Decoder<float>> vecDecoder = armnn::MakeDecoder<float>(tensorInfo, vec.data());
42 std::unique_ptr<armnn::Decoder<float>> batchVecDecoder = armnn::MakeDecoder<float>(tensorInfo, batchVec.data());
43 std::unique_ptr<armnn::Encoder<float>> batchVecEncoder = armnn::MakeEncoder<float>(tensorInfo, batchVec.data());
// Element-wise comparison of the in-place result against the expected tensor.
48 auto result =
CompareTensors(batchVec, expectedOutput, expectedShape, expectedShape);
49 CHECK_MESSAGE(result.m_Result, result.m_Message.str());
// Sanity check that the encoder really aliases batchVec: a write through the
// encoder must show up at element 0 of the vector.
52 batchVecEncoder->Set(1.0f);
53 CHECK(batchVec[0] == 1.0f);
// Test helper for LstmUtils::ZeroVector: zeroes `input` in place through an
// Encoder view and compares the result against `expectedOutput`.
// NOTE(review): damaged capture — "NN " prefixes are embedded original-file
// line numbers; original lines 59, 61-68, 70-74 (vector-size parameter,
// tensorInfo construction, and the ZeroVector call itself) are missing.
56 template<armnn::DataType ArmnnType,
typename T = armnn::ResolveType<ArmnnType>>
57 void LstmUtilsZeroVectorTestImpl(
58 std::vector<float>& input,
60 std::vector<float>& expectedOutput,
// Encoder aliases input's backing storage — writes land directly in `input`.
69 std::unique_ptr<armnn::Encoder<float>> outputEncoder = armnn::MakeEncoder<float>(tensorInfo, input.data());
// Compare the zeroed buffer against the expected tensor.
75 auto result =
CompareTensors(input, expectedOutput, expectedShape, expectedShape);
76 CHECK_MESSAGE(result.m_Result, result.m_Message.str());
// Confirm the encoder aliases `input`: write 1.0f through it and observe it
// at element 0.
79 outputEncoder->Set(1.0f);
80 CHECK(input[0] == 1.0f);
// Test helper for LstmUtils::MeanStddevNormalization: normalizes `input` in
// place (decoder reads, encoder writes back into the same buffer) and compares
// the result against `expectedOutput`.
// NOTE(review): damaged capture — "NN " prefixes are embedded original-file
// line numbers; original lines 87-88, 90-96 and 99-102 (remaining parameters,
// tensorInfo construction, and the MeanStddevNormalization call) are missing.
84 template<armnn::DataType ArmnnType,
typename T = armnn::ResolveType<ArmnnType>>
85 void LstmUtilsMeanStddevNormalizationTestImpl(
86 std::vector<float>& input,
89 std::vector<float>& expectedOutput,
// Both views alias input.data(): the op reads and writes the same buffer.
97 std::unique_ptr<armnn::Decoder<float>> inputDecoder = armnn::MakeDecoder<float>(tensorInfo, input.data());
98 std::unique_ptr<armnn::Encoder<float>> outputEncoder = armnn::MakeEncoder<float>(tensorInfo, input.data());
// Compare the normalized buffer against the expected tensor.
103 auto result =
CompareTensors(input, expectedOutput, expectedShape, expectedShape);
104 CHECK_MESSAGE(result.m_Result, result.m_Message.str());
// Confirm the encoder aliases `input` by writing through it.
107 outputEncoder->Set(1.0f);
108 CHECK(input[0] == 1.0f);
// Test helper for LstmUtils::VectorBatchVectorCwiseProduct: multiplies each
// batch row of `batchVec` element-wise by `vec`, in place, then compares the
// mutated `batchVec` against `expectedOutput`. Structure mirrors
// LstmUtilsVectorBatchVectorAddTestImpl above.
// NOTE(review): damaged capture — "NN " prefixes are embedded original-file
// line numbers; original lines 115-116, 118-124 and 128-131 (remaining
// parameters, tensorInfo construction, and the CwiseProduct call) are missing.
111 template<armnn::DataType ArmnnType,
typename T = armnn::ResolveType<ArmnnType>>
112 void LstmUtilsVectorBatchVectorCwiseProductTestImpl(
113 std::vector<float>& vec,
114 std::vector<float>& batchVec,
117 std::vector<float>& expectedOutput,
// Decoder/Encoder views alias the raw buffers; the encoder writes the product
// back into batchVec.
125 std::unique_ptr<armnn::Decoder<float>> vecDecoder = armnn::MakeDecoder<float>(tensorInfo, vec.data());
126 std::unique_ptr<armnn::Decoder<float>> batchVecDecoder = armnn::MakeDecoder<float>(tensorInfo, batchVec.data());
127 std::unique_ptr<armnn::Encoder<float>> batchVecEncoder = armnn::MakeEncoder<float>(tensorInfo, batchVec.data());
// Compare the in-place result against the expected tensor.
132 auto result =
CompareTensors(batchVec, expectedOutput, expectedShape, expectedShape);
133 CHECK_MESSAGE(result.m_Result, result.m_Message.str());
// Confirm the encoder aliases batchVec by writing through it.
136 batchVecEncoder->Set(1.0f);
137 CHECK(batchVec[0] == 1.0f);
// End-to-end workload test for an LSTM layer configured with no CIFG, no
// peephole and no projection: builds input/state/output TensorInfos and
// tensor handles, wires them into an LstmQueueDescriptor via
// AddInput/AddOutputToWorkload, supplies hard-coded weight/bias constants,
// creates the workload and (in the missing tail) executes it and compares the
// output shape/values.
// NOTE(review): heavily damaged capture — "NN " prefixes are embedded
// original-file line numbers, and large spans are missing: the return
// type/remaining parameters (orig 143-160), batchSize/inputSize/outputSize
// definitions (orig ~150-160), the CreateTensorHandle calls for the
// state/scratch handles (orig 188-199), descriptor/info declarations
// (orig 200-202), most weight tensor-info setup (orig 211-215), tails of
// several weight initializer lists, all descriptor weight wiring
// (orig 269-316), and the Execute/CopyData/CompareTensors tail (orig 326-336).
// Treat every comment below as describing intent, to be confirmed against the
// full source.
142 template<armnn::DataType ArmnnType,
typename T = armnn::ResolveType<ArmnnType>>
144 LstmNoCifgNoPeepholeNoProjectionTestImpl(
148 const std::vector<T>& input,
149 const std::vector<T>& outputExpected,
// With no projection layer, the cell size equals the output size.
161 unsigned numUnits = outputSize;
// Shapes: input [batch, inputSize]; both states [batch, numUnits/outputSize];
// scratch buffer is 4*numUnits wide (one slice per LSTM gate).
163 armnn::TensorInfo inputTensorInfo({batchSize , inputSize}, ArmnnType, qScale, qOffset );
164 armnn::TensorInfo cellStateInTensorInfo({batchSize , numUnits}, ArmnnType, qScale, qOffset);
165 armnn::TensorInfo outputStateInTensorInfo({batchSize , outputSize}, ArmnnType, qScale, qOffset);
167 armnn::TensorInfo scratchBufferTensorInfo({batchSize, numUnits * 4}, ArmnnType, qScale, qOffset);
168 armnn::TensorInfo cellStateOutTensorInfo({batchSize, numUnits}, ArmnnType, qScale, qOffset);
169 armnn::TensorInfo outputStateOutTensorInfo({batchSize, outputSize}, ArmnnType, qScale, qOffset);
170 armnn::TensorInfo outputTensorInfo({batchSize, outputSize}, ArmnnType, qScale, qOffset);
// Host-side buffers: the caller's input is copied; state/scratch buffers are
// zero-initialized (T() value-init), matching a fresh LSTM state.
172 std::vector<T> inputVector;
173 inputVector.assign(input.data(), input.data() + (batchSize * inputSize));
175 std::vector<T> cellStateInVector(batchSize * numUnits, T());
176 std::vector<T> outputStateInVector(batchSize * outputSize, T());
177 std::vector<T> scratchBufferVector(batchSize * numUnits * 4, T());
178 std::vector<T> outputStateOutVector(batchSize * outputSize, T());
179 std::vector<T> cellStateOutVector(batchSize * numUnits, T());
181 std::vector<T> actualOutput(outputTensorInfo.GetNumElements());
183 std::vector<T> outputVector;
184 outputVector.assign(outputExpected.data(), outputExpected.data() + (batchSize * outputSize));
// Backend tensor handles for every workload input/output. The
// CreateTensorHandle(...) calls for the state/scratch handles (orig 188-197)
// are missing from this capture.
186 std::unique_ptr<armnn::ITensorHandle> inputHandle = tensorHandleFactory.
CreateTensorHandle(inputTensorInfo);
187 std::unique_ptr<armnn::ITensorHandle> cellStateInHandle =
189 std::unique_ptr<armnn::ITensorHandle> outputStateInHandle =
192 std::unique_ptr<armnn::ITensorHandle> scratchHandle =
194 std::unique_ptr<armnn::ITensorHandle> outputStateOutHandle =
196 std::unique_ptr<armnn::ITensorHandle> cellStateOutHandle =
198 std::unique_ptr<armnn::ITensorHandle> outputHandle = tensorHandleFactory.
CreateTensorHandle(outputTensorInfo);
// Wire the descriptor: input order is input / outputStateIn / cellStateIn;
// output order is scratch / outputStateOut / cellStateOut / output.
203 AddInputToWorkload(data, info, inputTensorInfo, inputHandle.get());
204 AddInputToWorkload(data, info, outputStateInTensorInfo, outputStateInHandle.get());
205 AddInputToWorkload(data, info, cellStateInTensorInfo, cellStateInHandle.get());
207 AddOutputToWorkload(data, info, scratchBufferTensorInfo, scratchHandle.get());
208 AddOutputToWorkload(data, info, outputStateOutTensorInfo, outputStateOutHandle.get());
209 AddOutputToWorkload(data, info, cellStateOutTensorInfo, cellStateOutHandle.get());
210 AddOutputToWorkload(data, info, outputTensorInfo, outputHandle.get());
// Weight tensor info; name suggests a {numUnits, 4} (16-element) shape.
214 armnn::TensorInfo tensorInfo16({numUnits, 4}, constantDataType, qScale, qOffset);
// Hard-coded reference weights/biases — presumably taken from the upstream
// TensorFlow Lite LSTM unit test; verify against the full ArmNN source.
// NOTE(review): several initializer lists below are truncated by the capture
// (inputToCellWeights and all recurrentTo* lists lack their final values and
// closing braces).
216 std::vector<float> inputToInputWeights = {-0.45018822f, -0.02338299f, -0.0870589f,
217 -0.34550029f, 0.04266912f, -0.15680569f,
218 -0.34856534f, 0.43890524f};
220 std::vector<float> inputToForgetWeights = { 0.09701663f, 0.20334584f, -0.50592935f,
221 -0.31343272f, -0.40032279f, 0.44781327f,
222 0.01387155f, -0.35593212f};
224 std::vector<float> inputToCellWeights = { -0.50013041f, 0.1370284f, 0.11810488f, 0.2013163f,
225 -0.20583314f, 0.44344562f, 0.22077113f,
228 std::vector<float> inputToOutputWeights = { -0.25065863f, -0.28290087f, 0.04613829f,
229 0.40525138f, 0.44272184f, 0.03897077f,
230 -0.1556896f, 0.19487578f};
232 std::vector<float> recurrentToInputWeights = {-0.0063535f, -0.2042388f, 0.31454784f,
233 -0.35746509f, 0.28902304f, 0.08183324f,
234 -0.16555229f, 0.02286911f, -0.13566875f,
235 0.03034258f, 0.48091322f, -0.12528998f,
236 0.24077177f, -0.51332325f, -0.33502164f,
239 std::vector<float> recurrentToForgetWeights = { -0.48684245f, -0.06655136f, 0.42224967f,
240 0.2112639f, 0.27654213f, 0.20864892f,
241 -0.07646349f, 0.45877004f, 0.00141793f,
242 -0.14609534f, 0.36447752f, 0.09196436f,
243 0.28053468f, 0.01560611f, -0.20127171f,
246 std::vector<float> recurrentToCellWeights = { -0.3407414f, 0.24443203f, -0.2078532f,
247 0.26320225f, 0.05695659f, -0.00123841f,
248 -0.4744786f, -0.35869038f, -0.06418842f,
249 -0.13502428f, -0.501764f, 0.22830659f,
250 -0.46367589f, 0.26016325f, -0.03894562f,
253 std::vector<float> recurrentToOutputWeights = { 0.43385774f, -0.17194885f, 0.2718237f,
254 0.09215671f, 0.24107647f, -0.39835793f,
255 0.18212086f, 0.01301402f, 0.48572797f,
256 -0.50656658f, 0.20047462f, -0.20607421f,
257 -0.51818722f, -0.15390486f, 0.0468148f,
// Zero cell-to-input weights and gate biases; forget-gate bias of 1.0 is the
// conventional LSTM initialization (keeps the forget gate open initially).
260 std::vector<float> cellToInputWeights = {0., 0., 0., 0.};
262 std::vector<float> inputGateBias = {0., 0., 0., 0.};
264 std::vector<float> forgetGateBias = {1., 1., 1., 1.};
266 std::vector<float> cellBias = {0., 0., 0., 0.};
268 std::vector<float> outputGateBias = {0., 0., 0., 0.};
// Create the workload from the populated descriptor and allocate all handles.
// (Descriptor weight/bias assignment, orig 269-316, is missing here.)
317 std::unique_ptr<armnn::IWorkload> workload = workloadFactory.
CreateLstm(data, info);
318 inputHandle->Allocate();
319 outputStateInHandle->Allocate();
320 cellStateInHandle->Allocate();
322 scratchHandle->Allocate();
323 outputStateOutHandle->Allocate();
324 cellStateOutHandle->Allocate();
325 outputHandle->Allocate();
// Tail of the final comparison call (presumably CompareTensors of actualOutput
// vs outputVector — the preceding arguments, orig 326-336, are missing).
337 outputHandle->GetShape(),
338 outputTensorInfo.GetShape());
341 template<armnn::DataType ArmnnType,
typename T = armnn::ResolveType<ArmnnType>>
346 const std::vector<T>& input,
347 const std::vector<T>& outputExpected,
353 unsigned int batchSize = 2;
354 unsigned int outputSize = 16;
355 unsigned int inputSize = 5;
356 unsigned numUnits = 20;
358 armnn::TensorInfo inputTensorInfo({batchSize , inputSize}, ArmnnType, qScale, qOffset);
359 armnn::TensorInfo cellStateInTensorInfo({batchSize , numUnits}, ArmnnType, qScale, qOffset);
360 armnn::TensorInfo outputStateInTensorInfo({batchSize , outputSize}, ArmnnType, qScale, qOffset);
363 armnn::TensorInfo scratchBufferTensorInfo({batchSize, numUnits * 4}, ArmnnType, qScale, qOffset);
364 armnn::TensorInfo cellStateOutTensorInfo({batchSize, numUnits}, ArmnnType, qScale, qOffset);
365 armnn::TensorInfo outputStateOutTensorInfo({batchSize, outputSize}, ArmnnType, qScale, qOffset);
366 armnn::TensorInfo outputTensorInfo({batchSize, outputSize}, ArmnnType, qScale, qOffset);
368 std::vector<T> inputVector;
369 inputVector.assign(input.data(), input.data() + (batchSize * inputSize));
371 std::vector<T> cellStateInVector(batchSize * numUnits, T());
372 std::vector<T> outputStateInVector(batchSize * outputSize, T());
373 std::vector<T> scratchBufferVector(batchSize * numUnits * 4, T());
374 std::vector<T> outputStateOutVector(batchSize * outputSize, T());
375 std::vector<T> cellStateOutVector(batchSize * numUnits, T());
377 std::vector<T> actualOutput(outputTensorInfo.GetNumElements());
379 std::vector<T> outputVector;
380 outputVector.assign(outputExpected.data(), outputExpected.data() + (batchSize * outputSize));
382 std::unique_ptr<armnn::ITensorHandle> inputHandle = tensorHandleFactory.
CreateTensorHandle(inputTensorInfo);
383 std::unique_ptr<armnn::ITensorHandle> cellStateInHandle =
385 std::unique_ptr<armnn::ITensorHandle> outputStateInHandle =
388 std::unique_ptr<armnn::ITensorHandle> scratchHandle =
390 std::unique_ptr<armnn::ITensorHandle> outputStateOutHandle =
392 std::unique_ptr<armnn::ITensorHandle> cellStateOutHandle =
394 std::unique_ptr<armnn::ITensorHandle> outputHandle = tensorHandleFactory.
CreateTensorHandle(outputTensorInfo);
399 AddInputToWorkload(data, info, inputTensorInfo, inputHandle.get());
400 AddInputToWorkload(data, info, outputStateInTensorInfo, outputStateInHandle.get());
401 AddInputToWorkload(data, info, cellStateInTensorInfo, cellStateInHandle.get());
403 AddOutputToWorkload(data, info, scratchBufferTensorInfo, scratchHandle.get());
404 AddOutputToWorkload(data, info, outputStateOutTensorInfo, outputStateOutHandle.get());
405 AddOutputToWorkload(data, info, cellStateOutTensorInfo, cellStateOutHandle.get());
406 AddOutputToWorkload(data, info, outputTensorInfo, outputHandle.get());
410 armnn::TensorInfo tensorInfo20x5({numUnits, inputSize}, constantDataType, qScale, qOffset);
411 armnn::TensorInfo tensorInfo20x16({numUnits, outputSize}, constantDataType, qScale, qOffset);
412 armnn::TensorInfo tensorInfo16x20({outputSize, numUnits}, constantDataType, qScale, qOffset);
414 std::vector<float> inputToInputWeights = {0.021393683f,0.06124551f, 0.046905167f,-0.014657677f,-0.03149463f,
415 0.09171803f, 0.14647801f,0.10797193f, -0.0057968358f,0.0019193048f,
416 -0.2726754f, 0.10154029f, -0.018539885f, 0.080349885f, -0.10262385f,
417 -0.022599787f,-0.09121155f, -0.008675967f, -0.045206103f,-0.0821282f,
418 -0.008045952f,0.015478081f, 0.055217247f, 0.038719587f, 0.044153627f,
419 -0.06453243f,0.05031825f, -0.046935108f, -0.008164439f, 0.014574226f,
420 -0.1671009f, -0.15519552f, -0.16819797f,-0.13971269f,-0.11953059f,
421 0.25005487f, -0.22790983f, 0.009855087f, -0.028140958f, -0.11200698f,
422 0.11295408f, -0.0035217577f, 0.054485075f, 0.05184695f, 0.064711206f,
423 0.10989193f, 0.11674786f, 0.03490607f, 0.07727357f, 0.11390585f,
424 -0.1863375f, -0.1034451f, -0.13945189f, -0.049401227f, -0.18767063f,
425 0.042483903f, 0.14233552f, 0.13832581f, 0.18350165f, 0.14545603f,
426 -0.028545704f,0.024939531f,0.050929718f,0.0076203286f,-0.0029723682f,
427 -0.042484224f, -0.11827596f, -0.09171104f, -0.10808628f,-0.16327988f,
428 -0.2273378f, -0.0993647f, -0.017155107f,0.0023917493f,0.049272764f,
429 0.0038534778f, 0.054764505f, 0.089753784f, 0.06947234f, 0.08014476f,
430 -0.04544234f, -0.0497073f,-0.07135631f, -0.048929106f,-0.004042012f,
431 -0.009284026f, 0.018042054f, 0.0036860977f,-0.07427302f, -0.11434604f,
432 -0.018995456f, 0.031487543f, 0.012834908f,0.019977754f,0.044256654f,
433 -0.39292613f, -0.18519334f, -0.11651281f,-0.06809892f, 0.011373677f };
435 std::vector<float> inputToForgetWeights = {-0.0018401089f, -0.004852237f,0.03698424f, 0.014181704f,0.028273236f,
436 -0.016726194f, -0.05249759f,-0.10204261f, 0.00861066f,-0.040979505f,
437 -0.009899187f,0.01923892f,-0.028177269f, -0.08535103f,-0.14585495f,
438 0.10662567f,-0.01909731f,-0.017883534f,-0.0047269356f,-0.045103323f,
439 0.0030784295f,0.076784775f,0.07463696f, 0.094531395f,0.0814421f,
440 -0.12257899f, -0.033945758f,-0.031303465f, 0.045630626f,0.06843887f,
441 -0.13492945f, -0.012480007f,-0.0811829f, -0.07224499f,-0.09628791f,
442 0.045100946f,0.0012300825f, 0.013964662f, 0.099372394f,0.02543059f,
443 0.06958324f, 0.034257296f, 0.0482646f, 0.06267997f,0.052625068f,
444 0.12784666f, 0.07077897f, 0.025725935f, 0.04165009f,0.07241905f,
445 0.018668644f, -0.037377294f,-0.06277783f,-0.08833636f,-0.040120605f,
446 -0.011405586f,-0.007808335f,-0.010301386f,-0.005102167f,0.027717464f,
447 0.05483423f, 0.11449111f, 0.11289652f,0.10939839f, 0.13396506f,
448 -0.08402166f,-0.01901462f, -0.044678304f,-0.07720565f,0.014350063f,
449 -0.11757958f, -0.0652038f, -0.08185733f,-0.076754324f,-0.092614375f,
450 0.10405491f, 0.052960336f, 0.035755895f,0.035839386f,-0.012540553f,
451 0.036881298f, 0.02913376f, 0.03420159f,0.05448447f,-0.054523353f,
452 0.02582715f, 0.02327355f, -0.011857179f,-0.0011980024f,-0.034641717f,
453 -0.026125094f,-0.17582615f,-0.15923657f,-0.27486774f,-0.0006143371f,
454 0.0001771948f, -8.470171e-05f, 0.02651807f,0.045790765f,0.06956496f };
456 std::vector<float> inputToCellWeights = { -0.04580283f, -0.09549462f, -0.032418985f, -0.06454633f,
457 -0.043528453f, 0.043018587f, -0.049152344f, -0.12418144f,
458 -0.078985475f, -0.07596889f, 0.019484362f, -0.11434962f,
459 -0.0074034138f, -0.06314844f, -0.092981495f, 0.0062155537f,
460 -0.025034338f, -0.0028890965f, 0.048929527f, 0.06235075f,
461 0.10665918f, -0.032036792f, -0.08505916f, -0.10843358f,
462 -0.13002433f, -0.036816437f, -0.02130134f, -0.016518239f,
463 0.0047691227f, -0.0025825808f, 0.066017866f, 0.029991534f,
464 -0.10652836f, -0.1037554f, -0.13056071f, -0.03266643f,
465 -0.033702414f, -0.006473424f, -0.04611692f, 0.014419339f,
466 -0.025174323f, 0.0396852f, 0.081777506f, 0.06157468f,
467 0.10210095f, -0.009658194f, 0.046511717f, 0.03603906f,
468 0.0069369148f, 0.015960095f, -0.06507666f, 0.09551598f,
469 0.053568836f, 0.06408714f, 0.12835667f, -0.008714329f,
470 -0.20211966f, -0.12093674f, 0.029450472f, 0.2849013f,
471 -0.029227901f, 0.1164364f, -0.08560263f, 0.09941786f,
472 -0.036999565f, -0.028842626f, -0.0033637602f, -0.017012902f,
473 -0.09720865f, -0.11193351f, -0.029155117f, -0.017936034f,
474 -0.009768936f, -0.04223324f, -0.036159635f, 0.06505112f,
475 -0.021742892f, -0.023377212f, -0.07221364f, -0.06430552f,
476 0.05453865f, 0.091149814f, 0.06387331f, 0.007518393f,
477 0.055960953f, 0.069779344f, 0.046411168f, 0.10509911f,
478 0.07463894f, 0.0075130584f, 0.012850982f, 0.04555431f,
479 0.056955688f, 0.06555285f, 0.050801456f, -0.009862683f,
480 0.00826772f, -0.026555609f, -0.0073611983f, -0.0014897042f };
482 std::vector<float> inputToOutputWeights ={-0.0998932f, -0.07201956f, -0.052803773f,-0.15629593f,-0.15001918f,
483 -0.07650751f,0.02359855f, -0.075155355f, -0.08037709f, -0.15093534f,
484 0.029517552f, -0.04751393f, 0.010350531f,-0.02664851f, -0.016839722f,
485 -0.023121163f, 0.0077019283f, 0.012851257f, -0.05040649f,-0.0129761f,
486 -0.021737747f,-0.038305793f,-0.06870586f, -0.01481247f,-0.001285394f,
487 0.10124236f, 0.083122835f, 0.053313006f,-0.062235646f,-0.075637154f,
488 -0.027833903f, 0.029774971f, 0.1130802f, 0.09218906f, 0.09506135f,
489 -0.086665764f,-0.037162706f,-0.038880914f,-0.035832845f,-0.014481564f,
490 -0.09825003f,-0.12048569f,-0.097665586f,-0.05287633f, -0.0964047f,
491 -0.11366429f, 0.035777505f, 0.13568819f, 0.052451383f,0.050649304f,
492 0.05798951f, -0.021852335f,-0.099848844f,0.014740475f,-0.078897946f,
493 0.04974699f, 0.014160473f, 0.06973932f, 0.04964942f, 0.033364646f,
494 0.08190124f, 0.025535367f, 0.050893165f, 0.048514254f,0.06945813f,
495 -0.078907564f,-0.06707616f, -0.11844508f, -0.09986688f,-0.07509403f,
496 0.06263226f, 0.14925587f, 0.20188436f, 0.12098451f,0.14639415f,
497 0.0015017595f, -0.014267382f, -0.03417257f,0.012711468f,0.0028300495f,
498 -0.024758482f, -0.05098548f,-0.0821182f, 0.014225672f, 0.021544158f,
499 0.08949725f, 0.07505268f, -0.0020780868f, 0.04908258f,0.06476295f,
500 -0.022907063f,0.027562456f,0.040185735f, 0.019567577f,-0.015598739f,
501 -0.049097303f, -0.017121866f, -0.083368234f,-0.02332002f,-0.0840956f };
503 std::vector<float> inputGateBias = {0.02234832f, 0.14757581f, 0.18176508f, 0.10380666f, 0.053110216f,
504 -0.06928846f, -0.13942584f, -0.11816189f, 0.19483899f, 0.03652339f,
505 -0.10250295f, 0.036714908f, -0.18426876f, 0.036065217f, 0.21810818f,
506 0.02383196f, -0.043370757f, 0.08690144f, -0.04444982f, 0.00030581196f };
508 std::vector<float> forgetGateBias ={0.035185695f, -0.042891346f, -0.03032477f, 0.23027696f,
509 0.11098921f, 0.15378423f, 0.09263801f, 0.09790885f,
510 0.09508917f, 0.061199076f, 0.07665568f, -0.015443159f,
511 -0.03499149f, 0.046190713f, 0.08895977f, 0.10899629f,
512 0.40694186f, 0.06030037f, 0.012413437f, -0.06108739f };
514 std::vector<float> cellBias = { -0.024379363f, 0.0055531194f, 0.23377132f, 0.033463873f,
515 -0.1483596f, -0.10639995f, -0.091433935f, 0.058573797f,
516 -0.06809782f, -0.07889636f, -0.043246906f, -0.09829136f,
517 -0.4279842f, 0.034901652f, 0.18797937f, 0.0075234566f,
518 0.016178843f, 0.1749513f, 0.13975595f, 0.92058027f };
520 std::vector<float> outputGateBias ={0.046159424f, -0.0012809046f, 0.03563469f, 0.12648113f, 0.027195795f,
521 0.35373217f, -0.018957434f, 0.008907322f, -0.0762701f, 0.12018895f,
522 0.04216877f, 0.0022856654f, 0.040952638f, 0.3147856f, 0.08225149f,
523 -0.057416286f, -0.14995944f, -0.008040261f, 0.13208859f, 0.029760877f};
525 std::vector<float> recurrentToInputWeights = { -0.001374326f, -0.078856036f, 0.10672688f, 0.029162422f,
526 -0.11585556f, 0.02557986f, -0.13446963f, -0.035785314f,
527 -0.01244275f, 0.025961924f, -0.02337298f, -0.044228926f,
528 -0.055839065f, -0.046598054f, -0.010546039f, -0.06900766f,
529 0.027239809f, 0.022582639f, -0.013296484f, -0.05459212f,
530 0.08981f, -0.045407712f, 0.08682226f, -0.06867011f,
531 -0.14390695f, -0.02916037f, 0.000996957f, 0.091420636f,
532 0.14283475f, -0.07390571f, -0.06402044f, 0.062524505f,
533 -0.093129106f, 0.04860203f, -0.08364217f, -0.08119002f,
534 0.009352075f, 0.22920375f, 0.0016303885f, 0.11583097f,
535 -0.13732095f, 0.012405723f, -0.07551853f, 0.06343048f,
536 0.12162708f, -0.031923793f, -0.014335606f, 0.01790974f,
537 -0.10650317f, -0.0724401f, 0.08554849f, -0.05727212f,
538 0.06556731f, -0.042729504f, -0.043227166f, 0.011683251f,
539 -0.013082158f, -0.029302018f, -0.010899579f, -0.062036745f,
540 -0.022509435f, -0.00964907f, -0.01567329f, 0.04260106f,
541 -0.07787477f, -0.11576462f, 0.017356863f, 0.048673786f,
542 -0.017577527f, -0.05527947f, -0.082487635f, -0.040137455f,
543 -0.10820036f, -0.04666372f, 0.022746278f, -0.07851417f,
544 0.01068115f, 0.032956902f, 0.022433773f, 0.0026891115f,
545 0.08944216f, -0.0685835f, 0.010513544f, 0.07228705f,
546 0.02032331f, -0.059686817f, -0.0005566496f, -0.086984694f,
547 0.040414046f, -0.1380399f, 0.094208956f, -0.05722982f,
548 0.012092817f, -0.04989123f, -0.086576f, -0.003399834f,
549 -0.04696032f, -0.045747425f, 0.10091314f, 0.048676282f,
550 -0.029037097f, 0.031399418f, -0.0040285117f, 0.047237843f,
551 0.09504992f, 0.041799378f, -0.049185462f, -0.031518843f,
552 -0.10516937f, 0.026374253f, 0.10058866f, -0.0033195973f,
553 -0.041975245f, 0.0073591834f, 0.0033782164f, -0.004325073f,
554 -0.10167381f, 0.042500053f, -0.01447153f, 0.06464186f,
555 -0.017142897f, 0.03312627f, 0.009205989f, 0.024138335f,
556 -0.011337001f, 0.035530265f, -0.010912711f, 0.0706555f,
557 -0.005894094f, 0.051841937f, -0.1401738f, -0.02351249f,
558 0.0365468f, 0.07590991f, 0.08838724f, 0.021681072f,
559 -0.10086113f, 0.019608743f, -0.06195883f, 0.077335775f,
560 0.023646897f, -0.095322326f, 0.02233014f, 0.09756986f,
561 -0.048691444f, -0.009579111f, 0.07595467f, 0.11480546f,
562 -0.09801813f, 0.019894179f, 0.08502348f, 0.004032281f,
563 0.037211012f, 0.068537936f, -0.048005626f, -0.091520436f,
564 -0.028379958f, -0.01556313f, 0.06554592f, -0.045599163f,
565 -0.01672207f, -0.020169014f, -0.011877351f, -0.20212261f,
566 0.010889619f, 0.0047078193f, 0.038385306f, 0.08540671f,
567 -0.017140968f, -0.0035865551f, 0.016678626f, 0.005633034f,
568 0.015963363f, 0.00871737f, 0.060130805f, 0.028611384f,
569 0.10109069f, -0.015060172f, -0.07894427f, 0.06401885f,
570 0.011584063f, -0.024466386f, 0.0047652307f, -0.09041358f,
571 0.030737216f, -0.0046374933f, 0.14215417f, -0.11823516f,
572 0.019899689f, 0.006106124f, -0.027092824f, 0.0786356f,
573 0.05052217f, -0.058925f, -0.011402121f, -0.024987547f,
574 -0.0013661642f, -0.06832946f, -0.015667673f, -0.1083353f,
575 -0.00096863037f, -0.06988685f, -0.053350925f, -0.027275559f,
576 -0.033664223f, -0.07978348f, -0.025200296f, -0.017207067f,
577 -0.058403496f, -0.055697463f, 0.005798788f, 0.12965427f,
578 -0.062582195f, 0.0013350133f, -0.10482091f, 0.0379771f,
579 0.072521195f, -0.0029455067f, -0.13797039f, -0.03628521f,
580 0.013806405f, -0.017858358f, -0.01008298f, -0.07700066f,
581 -0.017081132f, 0.019358726f, 0.0027079724f, 0.004635139f,
582 0.062634714f, -0.02338735f, -0.039547626f, -0.02050681f,
583 0.03385117f, -0.083611414f, 0.002862572f, -0.09421313f,
584 0.058618143f, -0.08598433f, 0.00972939f, 0.023867095f,
585 -0.053934585f, -0.023203006f, 0.07452513f, -0.048767887f,
586 -0.07314807f, -0.056307215f, -0.10433547f, -0.06440842f,
587 0.04328182f, 0.04389765f, -0.020006588f, -0.09076438f,
588 -0.11652589f, -0.021705797f, 0.03345259f, -0.010329105f,
589 -0.025767034f, 0.013057034f, -0.07316461f, -0.10145612f,
590 0.06358255f, 0.18531723f, 0.07759293f, 0.12006465f,
591 0.1305557f, 0.058638252f, -0.03393652f, 0.09622831f,
592 -0.16253184f, -2.4580743e-06f, 0.079869635f, -0.070196845f,
593 -0.005644518f, 0.06857898f, -0.12598175f, -0.035084512f,
594 0.03156317f, -0.12794146f, -0.031963028f, 0.04692781f,
595 0.030070418f, 0.0071660685f, -0.095516115f, -0.004643372f,
596 0.040170413f, -0.062104587f, -0.0037324072f, 0.0554317f,
597 0.08184801f, -0.019164372f, 0.06791302f, 0.034257166f,
598 -0.10307039f, 0.021943003f, 0.046745934f, 0.0790918f,
599 -0.0265588f, -0.007824208f, 0.042546265f, -0.00977924f,
600 -0.0002440307f, -0.017384544f, -0.017990116f, 0.12252321f,
601 -0.014512694f, -0.08251313f, 0.08861942f, 0.13589665f,
602 0.026351685f, 0.012641483f, 0.07466548f, 0.044301085f,
603 -0.045414884f, -0.051112458f, 0.03444247f, -0.08502782f,
604 -0.04106223f, -0.028126027f, 0.028473156f, 0.10467447f };
606 std::vector<float> recurrentToForgetWeights = {-0.057784554f, -0.026057621f, -0.068447545f, -0.022581743f,
607 0.14811787f, 0.10826372f, 0.09471067f, 0.03987225f,
608 -0.0039523416f, 0.00030638507f, 0.053185795f, 0.10572994f,
609 0.08414449f, -0.022036452f, -0.00066928595f, -0.09203576f,
610 0.032950465f, -0.10985798f, -0.023809856f, 0.0021431844f,
611 -0.02196096f, -0.00326074f, 0.00058621005f, -0.074678116f,
612 -0.06193199f, 0.055729095f, 0.03736828f, 0.020123724f,
613 0.061878487f, -0.04729229f, 0.034919553f, -0.07585433f,
614 -0.04421272f, -0.044019096f, 0.085488975f, 0.04058006f,
615 -0.06890133f, -0.030951202f, -0.024628663f, -0.07672815f,
616 0.034293607f, 0.08556707f, -0.05293577f, -0.033561368f,
617 -0.04899627f, 0.0241671f, 0.015736353f, -0.095442444f,
618 -0.029564252f, 0.016493602f, -0.035026584f, 0.022337519f,
619 -0.026871363f, 0.004780428f, 0.0077918363f, -0.03601621f,
620 0.016435321f, -0.03263031f, -0.09543275f, -0.047392778f,
621 0.013454138f, 0.028934088f, 0.01685226f, -0.086110644f,
622 -0.046250615f, -0.01847454f, 0.047608484f, 0.07339695f,
623 0.034546845f, -0.04881143f, 0.009128804f, -0.08802852f,
624 0.03761666f, 0.008096139f, -0.014454086f, 0.014361001f,
625 -0.023502491f, -0.0011840804f, -0.07607001f, 0.001856849f,
626 -0.06509276f, -0.006021153f, -0.08570962f, -0.1451793f,
627 0.060212336f, 0.055259194f, 0.06974018f, 0.049454916f,
628 -0.027794661f, -0.08077226f, -0.016179763f, 0.1169753f,
629 0.17213494f, -0.0056326236f, -0.053934924f, -0.0124349f,
630 -0.11520337f, 0.05409887f, 0.088759385f, 0.0019655675f,
631 0.0042065294f, 0.03881498f, 0.019844765f, 0.041858196f,
632 -0.05695512f, 0.047233116f, 0.038937137f, -0.06542224f,
633 0.014429736f, -0.09719407f, 0.13908425f, -0.05379757f,
634 0.012321099f, 0.082840554f, -0.029899208f, 0.044217527f,
635 0.059855383f, 0.07711018f, -0.045319796f, 0.0948846f,
636 -0.011724666f, -0.0033288454f, -0.033542685f, -0.04764985f,
637 -0.13873616f, 0.040668588f, 0.034832682f, -0.015319203f,
638 -0.018715994f, 0.046002675f, 0.0599172f, -0.043107376f,
639 0.0294216f, -0.002314414f, -0.022424703f, 0.0030315618f,
640 0.0014641669f, 0.0029166266f, -0.11878115f, 0.013738511f,
641 0.12375372f, -0.0006038222f, 0.029104086f, 0.087442465f,
642 0.052958444f, 0.07558703f, 0.04817258f, 0.044462286f,
643 -0.015213451f, -0.08783778f, -0.0561384f, -0.003008196f,
644 0.047060397f, -0.002058388f, 0.03429439f, -0.018839769f,
645 0.024734668f, 0.024614193f, -0.042046934f, 0.09597743f,
646 -0.0043254104f, 0.04320769f, 0.0064070094f, -0.0019131786f,
647 -0.02558259f, -0.022822596f, -0.023273505f, -0.02464396f,
648 -0.10991725f, -0.006240552f, 0.0074488563f, 0.024044557f,
649 0.04383914f, -0.046476185f, 0.028658995f, 0.060410924f,
650 0.050786525f, 0.009452605f, -0.0073054377f, -0.024810238f,
651 0.0052906186f, 0.0066939713f, -0.0020913032f, 0.014515517f,
652 0.015898481f, 0.021362653f, -0.030262267f, 0.016587038f,
653 -0.011442813f, 0.041154444f, -0.007631438f, -0.03423484f,
654 -0.010977775f, 0.036152758f, 0.0066366293f, 0.11915515f,
655 0.02318443f, -0.041350313f, 0.021485701f, -0.10906167f,
656 -0.028218046f, -0.00954771f, 0.020531068f, -0.11995105f,
657 -0.03672871f, 0.024019798f, 0.014255957f, -0.05221243f,
658 -0.00661567f, -0.04630967f, 0.033188973f, 0.10107534f,
659 -0.014027541f, 0.030796422f, -0.10270911f, -0.035999842f,
660 0.15443139f, 0.07684145f, 0.036571592f, -0.035900835f,
661 -0.0034699554f, 0.06209149f, 0.015920248f, -0.031122351f,
662 -0.03858649f, 0.01849943f, 0.13872518f, 0.01503974f,
663 0.069941424f, -0.06948533f, -0.0088794185f, 0.061282158f,
664 -0.047401894f, 0.03100163f, -0.041533746f, -0.10430945f,
665 0.044574402f, -0.01425562f, -0.024290353f, 0.034563623f,
666 0.05866852f, 0.023947537f, -0.09445152f, 0.035450947f,
667 0.02247216f, -0.0042998926f, 0.061146557f, -0.10250651f,
668 0.020881841f, -0.06747029f, 0.10062043f, -0.0023941975f,
669 0.03532124f, -0.016341697f, 0.09685456f, -0.016764693f,
670 0.051808182f, 0.05875331f, -0.04536488f, 0.001626336f,
671 -0.028892258f, -0.01048663f, -0.009793449f, -0.017093895f,
672 0.010987891f, 0.02357273f, -0.00010856845f, 0.0099760275f,
673 -0.001845119f, -0.03551521f, 0.0018358806f, 0.05763657f,
674 -0.01769146f, 0.040995963f, 0.02235177f, -0.060430344f,
675 0.11475477f, -0.023854522f, 0.10071741f, 0.0686208f,
676 -0.014250481f, 0.034261297f, 0.047418304f, 0.08562733f,
677 -0.030519066f, 0.0060542435f, 0.014653856f, -0.038836084f,
678 0.04096551f, 0.032249358f, -0.08355519f, -0.026823482f,
679 0.056386515f, -0.010401743f, -0.028396193f, 0.08507674f,
680 0.014410365f, 0.020995233f, 0.17040324f, 0.11511526f,
681 0.02459721f, 0.0066619175f, 0.025853224f, -0.023133837f,
682 -0.081302024f, 0.017264642f, -0.009585969f, 0.09491168f,
683 -0.051313367f, 0.054532815f, -0.014298593f, 0.10657464f,
684 0.007076659f, 0.10964551f, 0.0409152f, 0.008275321f,
685 -0.07283536f, 0.07937492f, 0.04192024f, -0.1075027f };
687 std::vector<float> recurrentToCellWeights = { -0.037322544f, 0.018592842f, 0.0056175636f, -0.06253426f,
688 0.055647098f, -0.05713207f, -0.05626563f, 0.005559383f,
689 0.03375411f, -0.025757805f, -0.088049285f, 0.06017052f,
690 -0.06570978f, 0.007384076f, 0.035123326f, -0.07920549f,
691 0.053676967f, 0.044480428f, -0.07663568f, 0.0071805613f,
692 0.08089997f, 0.05143358f, 0.038261272f, 0.03339287f,
693 -0.027673481f, 0.044746667f, 0.028349208f, 0.020090483f,
694 -0.019443132f, -0.030755889f, -0.0040000007f, 0.04465846f,
695 -0.021585021f, 0.0031670958f, 0.0053199246f, -0.056117613f,
696 -0.10893326f, 0.076739706f, -0.08509834f, -0.027997585f,
697 0.037871376f, 0.01449768f, -0.09002357f, -0.06111149f,
698 -0.046195522f, 0.0422062f, -0.005683705f, -0.1253618f,
699 -0.012925729f, -0.04890792f, 0.06985068f, 0.037654128f,
700 0.03398274f, -0.004781977f, 0.007032333f, -0.031787455f,
701 0.010868644f, -0.031489216f, 0.09525667f, 0.013939797f,
702 0.0058680447f, 0.0167067f, 0.02668468f, -0.04797466f,
703 -0.048885044f, -0.12722108f, 0.035304096f, 0.06554885f,
704 0.00972396f, -0.039238118f, -0.05159735f, -0.11329045f,
705 0.1613692f, -0.03750952f, 0.06529313f, -0.071974665f,
706 -0.11769596f, 0.015524369f, -0.0013754242f, -0.12446318f,
707 0.02786344f, -0.014179351f, 0.005264273f, 0.14376344f,
708 0.015983658f, 0.03406988f, -0.06939408f, 0.040699873f,
709 0.02111075f, 0.09669095f, 0.041345075f, -0.08316494f,
710 -0.07684199f, -0.045768797f, 0.032298047f, -0.041805092f,
711 0.0119405f, 0.0061010392f, 0.12652606f, 0.0064572375f,
712 -0.024950314f, 0.11574242f, 0.04508852f, -0.04335324f,
713 0.06760663f, -0.027437469f, 0.07216407f, 0.06977076f,
714 -0.05438599f, 0.034033038f, -0.028602652f, 0.05346137f,
715 0.043184172f, -0.037189785f, 0.10420091f, 0.00882477f,
716 -0.054019816f, -0.074273005f, -0.030617684f, -0.0028467078f,
717 0.024302477f, -0.0038869337f, 0.005332455f, 0.0013399826f,
718 0.04361412f, -0.007001822f, 0.09631092f, -0.06702025f,
719 -0.042049985f, -0.035070654f, -0.04103342f, -0.10273396f,
720 0.0544271f, 0.037184782f, -0.13150354f, -0.0058036847f,
721 -0.008264958f, 0.042035464f, 0.05891794f, 0.029673764f,
722 0.0063542654f, 0.044788733f, 0.054816857f, 0.062257513f,
723 -0.00093483756f, 0.048938446f, -0.004952862f, -0.007730018f,
724 -0.04043371f, -0.017094059f, 0.07229206f, -0.023670016f,
725 -0.052195564f, -0.025616996f, -0.01520939f, 0.045104615f,
726 -0.007376126f, 0.003533447f, 0.006570588f, 0.056037236f,
727 0.12436656f, 0.051817212f, 0.028532185f, -0.08686856f,
728 0.11868599f, 0.07663395f, -0.07323171f, 0.03463402f,
729 -0.050708205f, -0.04458982f, -0.11590894f, 0.021273347f,
730 0.1251325f, -0.15313013f, -0.12224372f, 0.17228661f,
731 0.023029093f, 0.086124025f, 0.006445803f, -0.03496501f,
732 0.028332196f, 0.04449512f, -0.042436164f, -0.026587414f,
733 -0.006041347f, -0.09292539f, -0.05678812f, 0.03897832f,
734 0.09465633f, 0.008115513f, -0.02171956f, 0.08304309f,
735 0.071401566f, 0.019622514f, 0.032163795f, -0.004167056f,
736 0.02295182f, 0.030739572f, 0.056506045f, 0.004612461f,
737 0.06524936f, 0.059999723f, 0.046395954f, -0.0045512207f,
738 -0.1335546f, -0.030136576f, 0.11584653f, -0.014678886f,
739 0.0020118146f, -0.09688814f, -0.0790206f, 0.039770417f,
740 -0.0329582f, 0.07922767f, 0.029322514f, 0.026405897f,
741 0.04207835f, -0.07073373f, 0.063781224f, 0.0859677f,
742 -0.10925287f, -0.07011058f, 0.048005477f, 0.03438226f,
743 -0.09606514f, -0.006669445f, -0.043381985f, 0.04240257f,
744 -0.06955775f, -0.06769346f, 0.043903265f, -0.026784198f,
745 -0.017840602f, 0.024307009f, -0.040079936f, -0.019946516f,
746 0.045318738f, -0.12233574f, 0.026170589f, 0.0074471775f,
747 0.15978073f, 0.10185836f, 0.10298046f, -0.015476589f,
748 -0.039390966f, -0.072174534f, 0.0739445f, -0.1211869f,
749 -0.0347889f, -0.07943156f, 0.014809798f, -0.12412325f,
750 -0.0030663363f, 0.039695457f, 0.0647603f, -0.08291318f,
751 -0.018529687f, -0.004423833f, 0.0037507233f, 0.084633216f,
752 -0.01514876f, -0.056505352f, -0.012800942f, -0.06994386f,
753 0.012962922f, -0.031234352f, 0.07029052f, 0.016418684f,
754 0.03618972f, 0.055686004f, -0.08663945f, -0.017404709f,
755 -0.054761406f, 0.029065743f, 0.052404847f, 0.020238016f,
756 0.0048197987f, -0.0214882f, 0.07078733f, 0.013016777f,
757 0.06262858f, 0.009184685f, 0.020785125f, -0.043904778f,
758 -0.0270329f, -0.03299152f, -0.060088247f, -0.015162964f,
759 -0.001828936f, 0.12642565f, -0.056757294f, 0.013586685f,
760 0.09232601f, -0.035886683f, 0.06000002f, 0.05229691f,
761 -0.052580316f, -0.082029596f, -0.010794592f, 0.012947712f,
762 -0.036429964f, -0.085508935f, -0.13127148f, -0.017744139f,
763 0.031502828f, 0.036232427f, -0.031581745f, 0.023051167f,
764 -0.05325106f, -0.03421577f, 0.028793324f, -0.034633752f,
765 -0.009881397f, -0.043551125f, -0.018609839f, 0.0019097115f,
766 -0.008799762f, 0.056595087f, 0.0022273948f, 0.055752404f };
768 std::vector<float> recurrentToOutputWeights = { 0.025825322f, -0.05813119f, 0.09495884f,-0.045984812f, -0.01255415f,
769 -0.0026479573f,-0.08196161f,-0.054914974f,-0.0046604523f,
770 -0.029587349f, -0.044576716f, -0.07480124f, -0.082868785f,
771 0.023254942f, 0.027502948f, -0.0039728214f, -0.08683098f,
772 -0.08116779f, -0.014675607f, -0.037924774f, -0.023314456f,
773 -0.007401714f, -0.09255757f, 0.029460307f, -0.08829125f,
774 -0.005139627f, -0.08989442f, -0.0555066f, 0.13596267f,
775 -0.025062224f, -0.048351806f, -0.03850004f, 0.07266485f,
776 -0.022414139f, 0.05940088f, 0.075114764f, 0.09597592f,
777 -0.010211725f, -0.0049794707f, -0.011523867f, -0.025980417f,
778 0.072999895f, 0.11091378f, -0.081685916f, 0.014416728f,
779 0.043229222f, 0.034178585f, -0.07530371f, 0.035837382f,
780 -0.085607f, -0.007721233f, -0.03287832f, -0.043848954f,
781 -0.06404588f, -0.06632928f, -0.073643476f, 0.008214239f,
782 -0.045984086f, 0.039764922f, 0.03474462f, 0.060612556f,
783 -0.080590084f, 0.049127717f, 0.04151091f, -0.030063879f,
784 0.008801774f, -0.023021035f, -0.019558564f, 0.05158114f,
785 -0.010947698f, -0.011825728f, 0.0075720972f, 0.0699727f,
786 -0.0039981045f, 0.069350146f, 0.08799282f, 0.016156472f,
787 0.035502106f, 0.11695009f, 0.006217345f, 0.13392477f,
788 -0.037875112f, 0.025745004f, 0.08940699f, -0.00924166f,
789 0.0046702605f, -0.036598757f, -0.08811812f, 0.10522024f,
790 -0.032441203f, 0.008176899f, -0.04454919f, 0.07058152f,
791 0.0067963637f, 0.039206743f, 0.03259838f, 0.03725492f,
792 -0.09515802f, 0.013326398f, -0.052055415f, -0.025676316f,
793 0.03198509f, -0.015951829f, -0.058556724f, 0.036879618f,
794 0.043357447f, 0.028362012f, -0.05908629f, 0.0059240665f,
795 -0.04995891f, -0.019187413f,0.0276265f, -0.01628143f, 0.0025863599f,
796 0.08800015f, 0.035250366f, -0.022165963f, -0.07328642f,
797 -0.009415526f, -0.07455109f, 0.11690406f, 0.0363299f,
798 0.07411125f, 0.042103454f, -0.009660886f, 0.019076364f,
799 0.018299393f, -0.046004917f, 0.08891175f,0.0431396f, -0.026327137f,
800 -0.051502608f, 0.08979574f, -0.051670972f, 0.04940282f,
801 -0.07491107f, -0.021240504f, 0.022596184f, -0.034280192f,
802 0.060163025f, -0.058211457f, -0.051837247f, -0.01349775f,
803 -0.04639988f, -0.035936575f, -0.011681591f, 0.064818054f,
804 0.0073146066f, -0.021745546f, -0.043124277f, -0.06471268f,
805 -0.07053354f, -0.029321948f, -0.05330136f, 0.016933719f,
806 -0.053782392f, 0.13747959f, -0.1361751f, -0.11569455f,
807 0.0033329215f, 0.05693899f, -0.053219706f, 0.063698f,
808 0.07977434f, -0.07924483f, 0.06936997f, 0.0034815092f,
809 -0.007305279f, -0.037325785f, -0.07251102f, -0.033633437f,
810 -0.08677009f, 0.091591336f, -0.14165086f, 0.021752775f,
811 0.019683983f, 0.0011612234f, -0.058154266f, 0.049996935f,
812 0.0288841f, -0.0024567875f, -0.14345716f, 0.010955264f,-0.10234828f,
813 0.1183656f, -0.0010731248f, -0.023590032f,-0.072285876f,-0.0724771f,
814 -0.026382286f, -0.0014920527f, 0.042667855f, 0.0018776858f,
815 0.02986552f, 0.009814309f, 0.0733756f, 0.12289186f,
816 0.018043943f, -0.0458958f, 0.049412545f, 0.033632483f,
817 0.05495232f, 0.036686596f, -0.013781798f, -0.010036754f,
818 0.02576849f, -0.08307328f, 0.010112348f, 0.042521734f,
819 -0.05869831f, -0.071689695f, 0.03876447f, -0.13275425f, -0.0352966f,
820 -0.023077697f, 0.10285965f, 0.084736146f, 0.15568255f,
821 -0.00040734606f, 0.027835453f, -0.10292561f, -0.032401145f,
822 0.10053256f, -0.026142767f, -0.08271222f, -0.0030240538f,
823 -0.016368777f, 0.1070414f, 0.042672627f, 0.013456989f,
824 -0.0437609f, -0.022309763f, 0.11576483f, 0.04108048f,
825 0.061026827f, -0.0190714f, -0.0869359f, 0.037901703f, 0.0610107f,
826 0.07202949f, 0.01675338f, 0.086139716f, -0.08795751f,
827 -0.014898893f, -0.023771819f, -0.01965048f, 0.007955471f,
828 -0.043740474f, 0.03346837f, -0.10549954f, 0.090567775f,
829 0.042013682f, -0.03176985f, 0.12569028f, -0.02421228f,
830 -0.029526481f, 0.023851605f, 0.031539805f, 0.05292009f,
831 -0.02344001f, -0.07811758f, -0.08834428f, 0.10094801f,
832 0.16594367f, -0.06861939f, -0.021256343f, -0.041093912f,
833 -0.06669611f, 0.035498552f, 0.021757556f, -0.09302526f,
834 -0.015403468f, -0.06614931f, -0.051798206f, -0.013874718f,
835 0.03630673f, 0.010412845f, -0.08077351f, 0.046185967f,
836 0.0035662893f, 0.03541868f, -0.094149634f, -0.034814864f,
837 0.003128424f, -0.020674974f, -0.03944324f, -0.008110165f,
838 -0.11113267f, 0.08484226f, 0.043586485f, 0.040582247f,
839 0.0968012f, -0.065249965f, -0.028036479f, 0.0050708856f,
840 0.0017462453f, 0.0326779f, 0.041296225f, 0.09164146f,
841 -0.047743853f, -0.015952192f, -0.034451712f, 0.084197424f,
842 -0.05347844f, -0.11768019f, 0.085926116f, -0.08251791f,
843 -0.045081906f, 0.0948852f, 0.068401024f, 0.024856757f,
844 0.06978981f, -0.057309967f, -0.012775832f, -0.0032452994f,
845 0.01977615f, -0.041040014f, -0.024264973f,0.063464895f, 0.05431621f};
847 std::vector<float> cellToInputWeights = {0.040369894f, 0.030746894f, 0.24704495f, 0.018586371f, -0.037586458f,
848 -0.15312155f, -0.11812848f, -0.11465643f, 0.20259799f, 0.11418174f,
849 -0.10116027f, -0.011334949f, 0.12411352f, -0.076769054f,-0.052169047f,
850 0.21198851f, -0.38871562f, -0.09061183f, -0.09683246f, -0.21929175f};
853 std::vector<float> cellToForgetWeights = {-0.01998659f,-0.15568835f,-0.24248174f, -0.012770197f, 0.041331276f,
854 -0.072311886f, -0.052123554f,-0.0066330447f,-0.043891653f,0.036225766f,
855 -0.047248036f, 0.021479502f,0.033189066f, 0.11952997f, -0.020432774f,
856 0.64658105f, -0.06650122f, -0.03467612f, 0.095340036f, 0.23647355f};
858 std::vector<float> cellToOutputWeights = { 0.08286371f, -0.08261836f, -0.51210177f, 0.002913762f, 0.17764764f,
859 -0.5495371f, -0.08460716f, -0.24552552f, 0.030037103f, 0.04123544f,
860 -0.11940523f, 0.007358328f, 0.1890978f, 0.4833202f, -0.34441817f,
861 0.36312827f, -0.26375428f, 0.1457655f, -0.19724406f, 0.15548733f};
863 std::vector<float> projectionWeights={-0.009802181f, 0.09401916f, 0.0717386f, -0.13895074f, 0.09641832f,
864 0.060420845f, 0.08539281f, 0.054285463f, 0.061395317f, 0.034448683f,
865 -0.042991187f, 0.019801661f, -0.16840284f, -0.015726732f, -0.23041931f,
866 -0.024478018f, -0.10959692f, -0.013875541f, 0.18600968f, -0.061274476f,
867 0.0138165f, -0.08160894f, -0.07661644f, 0.032372914f, 0.16169067f,
868 0.22465782f, -0.03993472f, -0.004017731f, 0.08633481f, -0.28869787f,
869 0.08682067f, 0.17240396f, 0.014975425f, 0.056431185f, 0.031037588f,
870 0.16702051f, 0.0077946745f, 0.15140012f, 0.29405436f, 0.120285f,
871 -0.188994f, -0.027265169f, 0.043389652f, -0.022061434f, 0.014777949f,
872 -0.20203483f, 0.094781205f, 0.19100232f, 0.13987629f, -0.036132768f,
873 -0.06426278f, -0.05108664f, 0.13221376f, 0.009441198f, -0.16715929f,
874 0.15859416f, -0.040437475f, 0.050779544f, -0.022187516f, 0.012166504f,
875 0.027685808f, -0.07675938f, -0.0055694645f, -0.09444123f, 0.0046453946f,
876 0.050794356f, 0.10770313f, -0.20790008f, -0.07149004f, -0.11425117f,
877 0.008225835f, -0.035802525f, 0.14374903f, 0.15262283f, 0.048710253f,
878 0.1847461f, -0.007487823f, 0.11000021f, -0.09542012f, 0.22619456f,
879 -0.029149994f, 0.08527916f, 0.009043713f, 0.0042746216f, 0.016261552f,
880 0.022461696f, 0.12689082f, -0.043589946f, -0.12035478f, -0.08361797f,
881 -0.050666027f, -0.1248618f, -0.1275799f, -0.071875185f, 0.07377272f,
882 0.09944291f, -0.18897448f, -0.1593054f, -0.06526116f, -0.040107165f,
883 -0.004618631f, -0.067624845f, -0.007576253f, 0.10727444f, 0.041546922f,
884 -0.20424393f, 0.06907816f, 0.050412357f, 0.00724631f, 0.039827548f,
885 0.12449835f, 0.10747581f, 0.13708383f, 0.09134148f, -0.12617786f,
886 -0.06428341f, 0.09956831f, 0.1208086f, -0.14676677f, -0.0727722f,
887 0.1126304f, 0.010139365f, 0.015571211f, -0.038128063f, 0.022913318f,
888 -0.042050496f, 0.16842307f, -0.060597885f, 0.10531834f, -0.06411776f,
889 -0.07451711f, -0.03410368f, -0.13393489f, 0.06534304f, 0.003620307f,
890 0.04490757f, 0.05970546f, 0.05197996f, 0.02839995f, 0.10434969f,
891 -0.013699693f, -0.028353551f, -0.07260381f, 0.047201227f, -0.024575593f,
892 -0.036445823f, 0.07155557f, 0.009672501f, -0.02328883f, 0.009533515f,
893 -0.03606021f, -0.07421458f, -0.028082801f, -0.2678904f, -0.13221288f,
894 0.18419984f, -0.13012612f, -0.014588381f, -0.035059117f, -0.04824723f,
895 0.07830115f, -0.056184657f, 0.03277091f, 0.025466874f, 0.14494097f,
896 -0.12522776f, -0.098633975f, -0.10766018f, -0.08317623f, 0.08594209f,
897 0.07749552f, 0.039474737f, 0.1776665f, -0.07409566f, -0.0477268f,
898 0.29323658f, 0.10801441f, 0.1154011f, 0.013952499f, 0.10739139f,
899 0.10708251f, -0.051456142f, 0.0074137426f, -0.10430189f, 0.10034707f,
900 0.045594677f, 0.0635285f, -0.0715442f, -0.089667566f, -0.10811871f,
901 0.00026344223f, 0.08298446f, -0.009525053f, 0.006585689f, -0.24567553f,
902 -0.09450807f, 0.09648481f, 0.026996298f, -0.06419476f, -0.04752702f,
903 -0.11063944f, -0.23441927f, -0.17608605f, -0.052156363f, 0.067035615f,
904 0.19271925f, -0.0032889997f, -0.043264326f, 0.09663576f, -0.057112187f,
905 -0.10100678f, 0.0628376f, 0.04447668f, 0.017961001f, -0.10094388f,
906 -0.10190601f, 0.18335468f, 0.10494553f, -0.052095775f, -0.0026118709f,
907 0.10539724f, -0.04383912f, -0.042349473f, 0.08438151f, -0.1947263f,
908 0.02251204f, 0.11216432f, -0.10307853f, 0.17351969f, -0.039091777f,
909 0.08066188f, -0.00561982f, 0.12633002f, 0.11335965f, -0.0088127935f,
910 -0.019777594f, 0.06864014f, -0.059751723f, 0.016233567f, -0.06894641f,
911 -0.28651384f, -0.004228674f, 0.019708522f, -0.16305895f, -0.07468996f,
912 -0.0855457f, 0.099339016f, -0.07580735f, -0.13775392f, 0.08434318f,
913 0.08330512f, -0.12131499f, 0.031935584f, 0.09180414f, -0.08876437f,
914 -0.08049874f, 0.008753825f, 0.03498998f, 0.030215185f, 0.03907079f,
915 0.089751154f, 0.029194152f, -0.03337423f, -0.019092513f, 0.04331237f,
916 0.04299654f, -0.036394123f, -0.12915532f, 0.09793732f, 0.07512415f,
917 -0.11319543f, -0.032502122f, 0.15661901f, 0.07671967f, -0.005491124f,
918 -0.19379048f, -0.218606f, 0.21448623f, 0.017840758f, 0.1416943f,
919 -0.07051762f, 0.19488361f, 0.02664691f, -0.18104725f, -0.09334311f,
920 0.15026465f, -0.15493552f, -0.057762887f, -0.11604192f, -0.262013f,
921 -0.01391798f, 0.012185008f, 0.11156489f, -0.07483202f, 0.06693364f,
922 -0.26151478f, 0.046425626f, 0.036540434f, -0.16435726f, 0.17338543f,
923 -0.21401681f, -0.11385144f, -0.08283257f, -0.069031075f, 0.030635102f,
924 0.010969227f, 0.11109743f, 0.010919218f, 0.027526086f, 0.13519906f,
925 0.01891392f, -0.046839405f, -0.040167913f, 0.017953383f, -0.09700955f,
926 0.0061885654f, -0.07000971f, 0.026893595f, -0.038844477f, 0.14543656f};
928 std::vector<float> projectionBiasVector(outputSize, 0.f);
990 std::unique_ptr<armnn::IWorkload> workload = workloadFactory.
CreateLstm(data, info);
991 inputHandle->Allocate();
992 outputStateInHandle->Allocate();
993 cellStateInHandle->Allocate();
995 scratchHandle->Allocate();
996 outputStateOutHandle->Allocate();
997 cellStateOutHandle->Allocate();
998 outputHandle->Allocate();
1004 workload->Execute();
1010 outputHandle->GetShape(),
1011 outputTensorInfo.GetShape());
// NOTE(review): fragment of a float LSTM workload test helper (presumably
// LstmLayerWithCifgWithPeepholeNoProjectionTestImpl — the function-name line
// is not visible here; confirm against the full file). The leading digits on
// each line are original-file line numbers left by an extraction; gaps in
// them (e.g. 1024 -> 1028) mean source lines are missing from this view, so
// the fragment is not compilable as-is and must not be restructured.
1014 template<armnn::DataType ArmnnType,
typename T = armnn::ResolveType<ArmnnType>>
// Visible parameters: raw input values and the expected output for the final
// LSTM output tensor; qScale/qOffset quantise the tensor infos below.
1019 const std::vector<T>& input,
1020 const std::vector<T>& outputExpected,
1023 float qScale = 0.0f,
1024 int32_t qOffset = 0,
// Variant under test: CIFG and peephole enabled, projection disabled.
1028 bool cifgEnabled =
true;
1029 bool peepholeEnabled =
true;
1030 bool projectionEnabled =
false;
// With no projection the output size doubles as the cell (numUnits) size.
1037 const unsigned int cellSize = outputSize;
// Tensor infos for the three LSTM inputs: input, outputStateIn, cellStateIn.
1040 armnn::TensorInfo inputTensorInfo({batchSize , inputSize}, ArmnnType, qScale, qOffset);
1041 armnn::TensorInfo outputStateInTensorInfo({batchSize, outputSize}, ArmnnType, qScale, qOffset);
1042 armnn::TensorInfo cellStateInTensorInfo({batchSize, cellSize}, ArmnnType, qScale, qOffset);
// Scratch buffer holds 3 gate buffers under CIFG, otherwise 4 — presumably
// because CIFG fuses away the separate input gate; TODO confirm against the
// LSTM workload implementation.
1044 unsigned int scratchBufferSize = cifgEnabled ? cellSize * 3 : cellSize * 4;
1045 armnn::TensorInfo scratchBufferTensorInfo({batchSize, scratchBufferSize}, ArmnnType, qScale, qOffset);
1046 armnn::TensorInfo outputStateOutTensorInfo({batchSize, outputSize}, ArmnnType, qScale, qOffset);
1047 armnn::TensorInfo cellStateOutTensorInfo({batchSize, cellSize}, ArmnnType, qScale, qOffset);
1048 armnn::TensorInfo outputTensorInfo({batchSize, outputSize}, ArmnnType, qScale, qOffset);
1051 std::vector<float> inputData;
1052 inputData.assign(input.data(), input.data() + batchSize*inputSize);
// Both recurrent state inputs start zeroed.
1054 std::vector<float> outputStateInVector(batchSize * outputSize, 0.f);
1056 std::vector<float> cellStateInVector(batchSize * cellSize, 0.f);
// Weight/bias tensor infos: {cell, input} for input weights, {cell, output}
// for recurrent weights, {cell} for per-unit vectors (biases, peepholes).
1060 armnn::TensorInfo tensorInfoInput({cellSize, inputSize}, constantDataType, qScale, qOffset);
1061 armnn::TensorInfo tensorInfoOutput({cellSize, outputSize}, constantDataType, qScale, qOffset);
1062 armnn::TensorInfo tensorInfoNumUnits({cellSize}, constantDataType, qScale, qOffset);
// Hard-coded reference weights/biases (values from the reference model this
// test was generated against; some rows elided in this extract).
1064 std::vector<float> inputToCellWeights =
1066 -0.49770179f, -0.27711356f, -0.09624726f, 0.05100781f,
1067 0.04717243f, 0.48944736f, -0.38535351f,
1070 std::vector<float> inputToForgetWeights =
1072 -0.55291498f, -0.42866567f, 0.13056988f,
1073 -0.3633365f, -0.22755712f, 0.28253698f, 0.24407166f,
1076 std::vector<float> inputToOutputWeights =
1078 0.10725588f, -0.02335852f, -0.55932593f,
1079 -0.09426838f, -0.44257352f, 0.54939759f,
1080 0.01533556f, 0.42751634f
1082 std::vector<float> cellBias = {0.f, 0.f, 0.f, 0.f};
1083 std::vector<float> forgetGateBias = {1.f, 1.f, 1.f, 1.f};
1084 std::vector<float> outputGateBias = {0.f, 0.f, 0.f, 0.f};
1086 std::vector<float> recurrentToCellWeights =
1088 0.54066205f, -0.32668582f, -0.43562764f, -0.56094903f, 0.42957711f,
1089 0.01841056f, -0.32764608f, -0.33027974f, -0.10826075f, 0.20675004f,
1090 0.19069612f, -0.03026325f, -0.54532051f, 0.33003211f, 0.44901288f,
1093 std::vector<float> recurrentToForgetWeights =
1095 -0.13832897f, -0.0515101f, -0.2359007f, -0.16661474f, -0.14340827f,
1096 0.36986142f, 0.23414481f, 0.55899f, 0.10798943f, -0.41174671f, 0.17751795f,
1097 -0.34484994f, -0.35874045f, -0.11352962f, 0.27268326f, 0.54058349f
1100 std::vector<float> recurrentToOutputWeights =
1102 0.41613156f, 0.42610586f, -0.16495961f, -0.5663873f, 0.30579174f, -0.05115908f,
1103 -0.33941799f, 0.23364776f, 0.11178309f, 0.09481031f, -0.26424935f, 0.46261835f,
1104 0.50248802f, 0.26114327f, -0.43736315f, 0.33149987f
// Peephole (cell-to-gate) weights, one value per cell unit.
1107 std::vector<float> cellToForgetWeights = {0.47485286f, -0.51955009f, -0.24458408f, 0.31544167f};
1108 std::vector<float> cellToOutputWeights = {-0.17135078f, 0.82760304f, 0.85573703f, -0.77109635f};
// Zero-filled placeholders for the four LSTM outputs, plus buffers that will
// receive the actual results after execution.
1165 std::vector<T> scratchBufferVector(batchSize * scratchBufferSize, T());
1169 std::vector<T> outputStateOutVector(batchSize * outputSize, T());
1173 std::vector<T> cellStateOutVector(batchSize * cellSize, T());
1177 std::vector<T> outputData;
1178 outputData.assign(outputExpected.data(), outputExpected.data() + batchSize*outputSize);
1180 ret3.m_ExpectedData = outputData;
1182 std::vector<T> actualScratchBufferOutput(scratchBufferTensorInfo.GetNumElements());
1183 std::vector<T> actualOutputStateOutput(outputStateOutTensorInfo.GetNumElements());
1184 std::vector<T> actualCellStateOutput(cellStateOutTensorInfo.GetNumElements());
1185 std::vector<T> actualOutput(outputTensorInfo.GetNumElements());
// Tensor handles for 3 inputs and 4 outputs (factory calls elided in this
// extract).
1188 std::unique_ptr<armnn::ITensorHandle> inputHandle =
1190 std::unique_ptr<armnn::ITensorHandle> outputStateInHandle =
1192 std::unique_ptr<armnn::ITensorHandle> cellStateInHandle =
1195 std::unique_ptr<armnn::ITensorHandle> scratchBufferHandle =
1197 std::unique_ptr<armnn::ITensorHandle> outputStateOutHandle =
1199 std::unique_ptr<armnn::ITensorHandle> cellStateOutHandle =
1201 std::unique_ptr<armnn::ITensorHandle> outputHandle =
// Wire the queue descriptor: input order is input/outputStateIn/cellStateIn,
// output order is scratch/outputStateOut/cellStateOut/output.
1205 AddInputToWorkload(data, info, inputTensorInfo, inputHandle.get());
1206 AddInputToWorkload(data, info, outputStateInTensorInfo, outputStateInHandle.get());
1207 AddInputToWorkload(data, info, cellStateInTensorInfo, cellStateInHandle.get());
1209 AddOutputToWorkload(data, info, scratchBufferTensorInfo, scratchBufferHandle.get());
1210 AddOutputToWorkload(data, info, outputStateOutTensorInfo, outputStateOutHandle.get());
1211 AddOutputToWorkload(data, info, cellStateOutTensorInfo, cellStateOutHandle.get());
1212 AddOutputToWorkload(data, info, outputTensorInfo, outputHandle.get());
1214 std::unique_ptr<armnn::IWorkload> workload = workloadFactory.
CreateLstm(data, info);
1216 inputHandle->Allocate();
1217 outputStateInHandle->Allocate();
1218 cellStateInHandle->Allocate();
1220 scratchBufferHandle->Allocate();
1221 outputStateOutHandle->Allocate();
1222 cellStateOutHandle->Allocate();
1223 outputHandle->Allocate();
1233 workload->Execute();
// Copy-back of results into the LayerTestResult structs (copy calls elided).
1240 ret0.m_ActualData = actualScratchBufferOutput;
1241 ret1.m_ActualData = actualOutputStateOutput;
1242 ret2.m_ActualData = actualCellStateOutput;
1243 ret3.m_ActualData = actualOutput;
// NOTE(review): fragment of a float LSTM test helper exercising projection
// and layer normalisation (input/forget/cell/output layer-norm weights are
// all set below; the name and descriptor-setup lines are elided in this
// extract — confirm the exact variant against the full file). Leading digits
// are leftover original line numbers; numbering gaps mark missing lines.
1248 template<armnn::DataType ArmnnType,
typename T = armnn::ResolveType<ArmnnType>>
1253 const std::vector<T>& input,
1254 const std::vector<T>& outputExpected,
1255 float qScale = 0.0f,
1256 int32_t qOffset = 0,
// Fixed network geometry for this test: 2 batches, 5 inputs, 4 cell units
// projected down to 3 outputs.
1260 unsigned int batchSize = 2;
1261 unsigned int outputSize = 3;
1262 unsigned int inputSize = 5;
1263 unsigned numUnits = 4;
1265 armnn::TensorInfo inputTensorInfo({batchSize , inputSize}, ArmnnType, qScale, qOffset);
1266 armnn::TensorInfo cellStateInTensorInfo({batchSize , numUnits}, ArmnnType, qScale, qOffset);
1267 armnn::TensorInfo outputStateInTensorInfo({batchSize , outputSize}, ArmnnType, qScale, qOffset);
// Scratch buffer is numUnits * 4 here (non-CIFG path: four gate buffers).
1270 armnn::TensorInfo scratchBufferTensorInfo({batchSize, numUnits * 4}, ArmnnType, qScale, qOffset);
1271 armnn::TensorInfo cellStateOutTensorInfo({batchSize, numUnits}, ArmnnType, qScale, qOffset);
1272 armnn::TensorInfo outputStateOutTensorInfo({batchSize, outputSize}, ArmnnType, qScale, qOffset);
1273 armnn::TensorInfo outputTensorInfo({batchSize, outputSize}, ArmnnType, qScale, qOffset);
1275 std::vector<float> inputVector;
1276 inputVector.assign(input.data(), input.data() + (batchSize * inputSize));
// All state inputs and output placeholders start zeroed.
1278 std::vector<float> cellStateInVector(batchSize * numUnits, 0.f);
1279 std::vector<float> outputStateInVector(batchSize * outputSize, 0.f);
1280 std::vector<float> scratchBufferVector(batchSize * numUnits * 4, 0.f);
1281 std::vector<float> outputStateOutVector(batchSize * outputSize, 0.f);
1282 std::vector<float> cellStateOutVector(batchSize * numUnits, 0.f);
1284 std::vector<float> actualOutput(outputTensorInfo.GetNumElements());
1286 std::vector<float> outputVector;
1287 outputVector.assign(outputExpected.data(), outputExpected.data() + (batchSize * outputSize));
// Tensor handles for the 3 inputs and 4 outputs (some factory calls elided).
1289 std::unique_ptr<armnn::ITensorHandle> inputHandle = tensorHandleFactory.
CreateTensorHandle(inputTensorInfo);
1290 std::unique_ptr<armnn::ITensorHandle> cellStateInHandle =
1292 std::unique_ptr<armnn::ITensorHandle> outputStateInHandle =
1295 std::unique_ptr<armnn::ITensorHandle> scratchHandle =
1297 std::unique_ptr<armnn::ITensorHandle> outputStateOutHandle =
1299 std::unique_ptr<armnn::ITensorHandle> cellStateOutHandle =
1301 std::unique_ptr<armnn::ITensorHandle> outputHandle = tensorHandleFactory.
CreateTensorHandle(outputTensorInfo);
1306 AddInputToWorkload(data, info, inputTensorInfo, inputHandle.get());
1307 AddInputToWorkload(data, info, outputStateInTensorInfo, outputStateInHandle.get());
1308 AddInputToWorkload(data, info, cellStateInTensorInfo, cellStateInHandle.get());
1310 AddOutputToWorkload(data, info, scratchBufferTensorInfo, scratchHandle.get());
1311 AddOutputToWorkload(data, info, outputStateOutTensorInfo, outputStateOutHandle.get());
1312 AddOutputToWorkload(data, info, cellStateOutTensorInfo, cellStateOutHandle.get());
1313 AddOutputToWorkload(data, info, outputTensorInfo, outputHandle.get());
// Weight tensor infos named after their shapes: 4x5 input weights,
// 4x3 recurrent weights, 3x4 projection weights.
1317 armnn::TensorInfo tensorInfo4x5({numUnits, inputSize}, constantDataType, qScale, qOffset);
1318 armnn::TensorInfo tensorInfo4x3({numUnits, outputSize}, constantDataType, qScale, qOffset);
1319 armnn::TensorInfo tensorInfo3x4({outputSize, numUnits}, constantDataType, qScale, qOffset);
// Hard-coded reference weights/biases (some rows elided in this extract).
1321 std::vector<float> inputToInputWeights = {0.5f, 0.6f, 0.7f, -0.8f, -0.9f,
1322 0.1f, 0.2f, 0.3f, -0.4f, 0.5f,
1323 -0.8f, 0.7f, -0.6f, 0.5f, -0.4f,
1324 -0.5f, -0.4f, -0.3f, -0.2f, -0.1f};
1326 std::vector<float> inputToForgetWeights = { -0.6f, -0.1f, 0.3f, 0.2f, 0.9f,
1327 -0.5f, -0.2f, -0.4f, 0.3f, -0.8f,
1328 -0.4f, 0.3f, -0.5f, -0.4f, -0.6f,
1329 0.3f, -0.4f, -0.6f, -0.5f, -0.5f};
1331 std::vector<float> inputToCellWeights = {-0.4f, -0.3f, -0.2f, -0.1f, -0.5f,
1332 0.5f, -0.2f, -0.3f, -0.2f, -0.6f,
1333 0.6f, -0.1f, -0.4f, -0.3f, -0.7f,
1334 0.7f, -0.9f, -0.5f, 0.8f, 0.6f};
1336 std::vector<float> inputToOutputWeights = {-0.8f, -0.4f, -0.2f, -0.9f, -0.1f,
1337 -0.7f, 0.3f, -0.3f, -0.8f, -0.2f,
1338 0.6f, -0.2f, 0.4f, -0.7f, -0.3f,
1339 -0.5f, 0.1f, 0.5f, -0.6f, -0.4f};
1341 std::vector<float> inputGateBias = {0.03f, 0.15f, 0.22f, 0.38f};
1343 std::vector<float> forgetGateBias = {0.1f, -0.3f, -0.2f, 0.1f};
1345 std::vector<float> cellBias = {-0.05f, 0.72f, 0.25f, 0.08f};
1347 std::vector<float> outputGateBias = {0.05f, -0.01f, 0.2f, 0.1f};
1349 std::vector<float> recurrentToInputWeights ={-0.2f, -0.3f, 0.4f,
1351 -0.2f, -0.3f, -0.7f,
1352 0.05f, -0.2f, -0.6f};
1354 std::vector<float> recurrentToCellWeights = {-0.3f, 0.2f, 0.1f,
1355 -0.3f, 0.8f, -0.08f,
1357 -0.6f, -0.1f, 0.2f};
1359 std::vector<float> recurrentToForgetWeights = { -0.5f, -0.3f, -0.5f,
1364 std::vector<float> recurrentToOutputWeights = { 0.3f, -0.1f, 0.1f,
1365 -0.2f, -0.5f, -0.7f,
1366 -0.2f, -0.6f, -0.1f,
1367 -0.4f, -0.7f, -0.2f};
// Peephole weights (one per cell unit).
1369 std::vector<float> cellToInputWeights = {0.05f, 0.1f, 0.25f, 0.15f};
1371 std::vector<float> cellToForgetWeights = {-0.02f, -0.15f, -0.25f, -0.03f};
1373 std::vector<float> cellToOutputWeights = {0.1f, -0.1f, -0.5f, 0.05f};
// Projection from 4 cell units down to 3 outputs; projection bias is zero.
1375 std::vector<float> projectionWeights = {-0.1f, 0.2f, 0.01f, -0.2f,
1376 0.1f, 0.5f, 0.3f, 0.08f,
1377 0.07f, 0.2f, -0.4f, 0.2f};
1379 std::vector<float> projectionBiasVector(outputSize, 0.f);
// Per-gate layer-normalisation weights.
1381 std::vector<float> inputLayerNormWeights = {0.1f, 0.2f, 0.3f, 0.5f};
1383 std::vector<float> forgetLayerNormWeights = {0.2f, 0.2f, 0.4f, 0.3f};
1385 std::vector<float> cellLayerNormWeights = {0.7f, 0.2f, 0.3f, 0.8f};
1387 std::vector<float> outputLayerNormWeights = {0.6f, 0.2f, 0.2f, 0.5f};
1467 std::unique_ptr<armnn::IWorkload> workload = workloadFactory.
CreateLstm(data, info);
1468 inputHandle->Allocate();
1469 outputStateInHandle->Allocate();
1470 cellStateInHandle->Allocate();
1472 scratchHandle->Allocate();
1473 outputStateOutHandle->Allocate();
1474 cellStateOutHandle->Allocate();
1475 outputHandle->Allocate();
1481 workload->Execute();
// Final shape check: output tensor shape must match the declared info.
1487 outputHandle->GetShape(),
1488 outputTensorInfo.GetShape());
// NOTE(review): fragment of a QuantizedLstm test helper (uint8 activations,
// int16 cell state, int32 biases — creates its workload via
// CreateQuantizedLstm below). The function name/template header lines are
// missing from this extract; leading digits are leftover original line
// numbers and numbering gaps mark elided lines.
1495 const std::vector<uint8_t>& input,
1496 const std::vector<uint8_t>& outputExpected,
// Quantisation parameters. Offset 128 centres the uint8 activation range.
1506 float inputOutputScale = 0.0078125f;
1507 int32_t inputOutputOffset = 128;
1509 float cellStateScale = 0.00048828125f;
1510 int32_t cellStateOffset = 0;
1512 float weightsScale = 0.00408021f;
1513 int32_t weightsOffset = 100;
// 3.1876640625e-05 == inputOutputScale * weightsScale — bias scale is the
// product of the input and weight scales.
1515 float biasScale = 3.1876640625e-05f;
1516 int32_t biasOffset = 0;
1535 std::vector<uint8_t> inputVector;
1536 inputVector.assign(input.data(), input.data() + (numBatches * inputSize));
// Hard-coded quantised starting state and the expected quantised cell state
// after one step.
1539 std::vector<int16_t> cellStateInVector = {876, 1034, 955, -909, 761, 1029, 796, -1036};
1541 std::vector<uint8_t> outputStateInVector = {136, 150, 140, 115, 135, 152, 138, 112};
1544 std::vector<int16_t> cellStateOutVector = {1485, 1177, 1373, -1023, 1019, 1355, 1097, -1235};
1547 std::vector<uint8_t> outputVector;
1548 outputVector.assign(outputExpected.data(), outputExpected.data() + (numBatches * outputSize));
1550 std::vector<uint8_t> actualOutput(outputStateInfo.GetNumElements());
// Handles: 3 inputs (input, cellStateIn, outputStateIn) and 2 outputs
// (cellStateOut, output) — QuantizedLstm has no scratch buffer output.
1553 std::unique_ptr<armnn::ITensorHandle> inputHandle = tensorHandleFactory.
CreateTensorHandle(inputInfo);
1554 std::unique_ptr<armnn::ITensorHandle> cellStateInHandle =
1556 std::unique_ptr<armnn::ITensorHandle> outputStateInHandle =
1559 std::unique_ptr<armnn::ITensorHandle> cellStateOutHandle =
1561 std::unique_ptr<armnn::ITensorHandle> outputHandle = tensorHandleFactory.
CreateTensorHandle(outputStateInfo);
1567 AddInputToWorkload(data, info, inputInfo, inputHandle.get());
1568 AddInputToWorkload(data, info, cellStateInfo, cellStateInHandle.get());
1569 AddInputToWorkload(data, info, outputStateInfo, outputStateInHandle.get());
1571 AddOutputToWorkload(data, info, cellStateInfo, cellStateOutHandle.get());
1572 AddOutputToWorkload(data, info, outputStateInfo, outputHandle.get());
// Quantised reference weights (uint8, zero point 100 per weightsOffset) and
// int32 biases.
1588 std::vector<uint8_t> inputToInputWeights = {146, 250, 235, 171, 10, 218, 171, 108};
1589 std::vector<uint8_t> inputToForgetWeights = {24, 50, 132, 179, 158, 110, 3, 169};
1590 std::vector<uint8_t> inputToCellWeights = {133, 34, 29, 49, 206, 109, 54, 183};
1591 std::vector<uint8_t> inputToOutputWeights = {195, 187, 11, 99, 109, 10, 218, 48};
1593 std::vector<uint8_t> recurrentToInputWeights =
1594 {254, 206, 77, 168, 71, 20, 215, 6, 223, 7, 118, 225, 59, 130, 174, 26};
1595 std::vector<uint8_t> recurrentToForgetWeights =
1596 {137, 240, 103, 52, 68, 51, 237, 112, 0, 220, 89, 23, 69, 4, 207, 253};
1597 std::vector<uint8_t> recurrentToCellWeights =
1598 {172, 60, 205, 65, 14, 0, 140, 168, 240, 223, 133, 56, 142, 64, 246, 216};
1599 std::vector<uint8_t> recurrentToOutputWeights =
1600 {106, 214, 67, 23, 59, 158, 45, 3, 119, 132, 49, 205, 129, 218, 11, 98};
1602 std::vector<int32_t> inputGateBias = {-7876, 13488, -726, 32839};
1603 std::vector<int32_t> forgetGateBias = {9206, -46884, -11693, -38724};
1604 std::vector<int32_t> cellBias = {39481, 48624, 48976, -21419};
1605 std::vector<int32_t> outputGateBias = {-58999, -17050, -41852, -40538};
1656 std::unique_ptr<armnn::IWorkload> workload = workloadFactory.
CreateQuantizedLstm(data, info);
1657 inputHandle->Allocate();
1658 outputStateInHandle->Allocate();
1659 cellStateInHandle->Allocate();
1661 cellStateOutHandle->Allocate();
1662 outputHandle->Allocate();
1668 workload->Execute();
// Final shape check against the declared output-state info.
1674 outputHandle->GetShape(),
1675 outputStateInfo.GetShape());
// NOTE(review): fragment of a QLstm test helper (int8 activations, int16
// cell state and layer-norm weights, int32 biases — creates its workload via
// CreateQLstm below). CIFG and layer norm enabled; peephole and projection
// disabled. The name/template header lines are missing from this extract;
// leading digits are leftover original line numbers, gaps mark elided lines.
1683 const std::vector<int8_t>& input,
1684 const std::vector<int8_t>& outputExpected)
// Fixed geometry: 2 batches, 5 inputs, 4 units, 4 outputs (no projection).
1687 unsigned int numBatches = 2;
1688 unsigned int inputSize = 5;
1689 unsigned int outputSize = 4;
1690 unsigned int numUnits = 4;
1692 bool cifgEnabled =
true;
1693 bool peepholeEnabled =
false;
1694 bool projectionEnabled =
false;
1695 bool layerNormEnabled =
true;
// Quantisation parameters for the QLstm tensors.
1698 float inputScale = 0.0078125f;
1699 int32_t inputOffset = 0;
1701 int32_t hiddenStateZeroPoint = 0;
1702 float hiddenStateScale = 0.007f;
// With no projection the output reuses the hidden-state quantisation.
1705 float outputScale = hiddenStateScale;
1706 int32_t outputOffset = hiddenStateZeroPoint;
1708 float cellStateScale = 3.05176e-05f;
1709 int32_t cellStateOffset = 0;
1711 float weightsScale = 0.00784314f;
1712 int32_t weightsOffset = 0;
1714 float layerNormScale = 3.05182e-05f;
1715 int32_t layerNormOffset = 0;
// Bias scale is derived from the layer-norm scale by a fixed 2^10 factor.
1717 float biasScale = layerNormScale / 1024;
1718 int32_t biasOffset = 0;
// Per-gate intermediate scales; cell reuses input's, output reuses forget's.
1720 float inputIntermediateScale = 0.007059f;
1721 float forgetIntermediateScale = 0.007812f;
1722 float cellIntermediateScale = inputIntermediateScale;
1723 float outputIntermediateScale = forgetIntermediateScale;
// Zero clip values disable cell/projection clipping.
1725 float cellClip = 0.0f;
1726 float projectionClip = 0.0f;
1747 std::vector<int8_t> inputVector;
1748 inputVector.assign(input.data(), input.data() + (numBatches * inputSize));
// Both recurrent states start at zero; expected quantised cell state after
// one step is hard-coded.
1750 std::vector<int16_t> cellStateInVector = {0, 0, 0, 0, 0, 0, 0, 0};
1752 std::vector<int8_t> outputStateInVector = {0, 0, 0, 0, 0, 0, 0, 0};
1755 std::vector<int16_t> cellStateOutVector = {-11692, 9960, 5491, 8861, -9422, 7726, 2056, 13149};
1757 std::vector<int8_t> outputVector;
1758 outputVector.assign(outputExpected.data(), outputExpected.data() + (numBatches * outputSize));
1760 std::vector<int8_t> actualOutput(outputStateInfo.GetNumElements());
// Handles: 3 inputs and 3 outputs (outputStateOut, cellStateOut, output).
1763 std::unique_ptr<armnn::ITensorHandle> inputHandle = tensorHandleFactory.
CreateTensorHandle(inputInfo);
1764 std::unique_ptr<armnn::ITensorHandle> cellStateInHandle =
1766 std::unique_ptr<armnn::ITensorHandle> outputStateInHandle =
1769 std::unique_ptr<armnn::ITensorHandle> outputStateOutHandle =
1771 std::unique_ptr<armnn::ITensorHandle> cellStateOutHandle =
1773 std::unique_ptr<armnn::ITensorHandle> outputHandle = tensorHandleFactory.
CreateTensorHandle(outputStateInfo);
1779 AddInputToWorkload(data, info, inputInfo, inputHandle.get());
1780 AddInputToWorkload(data, info, outputStateInfo, outputStateInHandle.get());
1781 AddInputToWorkload(data, info, cellStateInfo, cellStateInHandle.get());
1783 AddOutputToWorkload(data, info, outputStateInfo, outputStateOutHandle.get());
1784 AddOutputToWorkload(data, info, cellStateInfo, cellStateOutHandle.get());
1785 AddOutputToWorkload(data, info, outputStateInfo, outputHandle.get());
// Quantised reference weights. No inputToInputWeights/recurrentToInput
// weights appear here — consistent with cifgEnabled = true above.
1803 std::vector<int8_t> inputToForgetWeights =
1804 {-77, -13, 38, 25, 115, -64, -25, -51, 38, -102, -51, 38, -64, -51, -77, 38, -51, -77, -64, -64};
1805 std::vector<int8_t> inputToCellWeights =
1806 {-51, -38, -25, -13, -64, 64, -25, -38, -25, -77, 77, -13, -51, -38, -89, 89, -115, -64, 102, 77};
1807 std::vector<int8_t> inputToOutputWeights =
1808 {-102, -51, -25, -115, -13, -89, 38, -38, -102, -25, 77, -25, 51, -89, -38, -64, 13, 64, -77, -51};
1810 std::vector<int8_t> recurrentToForgetWeights =
1811 {-64, -38, -64, -25, 77, 51, 115, 38, -13, 25, 64, 25, 25, 38, -13, 51};
1812 std::vector<int8_t> recurrentToCellWeights =
1813 {-38, 25, 13, -38, 102, -10, -25, 38, 102, -77, -13, 25, 38, -13, 25, 64};
1814 std::vector<int8_t> recurrentToOutputWeights =
1815 {38, -13, 13, -25, -64, -89, -25, -77, -13, -51, -89, -25, 13, 64, 25, -38};
1817 std::vector<int32_t> forgetGateBias = {2147484, -6442451, -4294968, 2147484};
1818 std::vector<int32_t> cellBias = {-1073742, 15461883, 5368709, 1717987};
1819 std::vector<int32_t> outputGateBias = {1073742, -214748, 4294968, 2147484};
// int16 layer-norm weights for the three active (non-CIFG) gates.
1821 std::vector<int16_t> forgetLayerNormWeights = {6553, 6553, 13107, 9830};
1822 std::vector<int16_t> cellLayerNormWeights = {22937, 6553, 9830, 26214};
1823 std::vector<int16_t> outputLayerNormWeights = {19660, 6553, 6553, 16384};
1893 std::unique_ptr<armnn::IWorkload> workload = workloadFactory.
CreateQLstm(data, info);
1894 inputHandle->Allocate();
1895 outputStateInHandle->Allocate();
1896 cellStateInHandle->Allocate();
1898 outputStateOutHandle->Allocate();
1899 cellStateOutHandle->Allocate();
1900 outputHandle->Allocate();
1906 workload->Execute();
// Final shape check against the declared output-state info.
1912 outputHandle->GetShape(),
1913 outputStateInfo.GetShape());
1921 const std::vector<int8_t>& input,
1922 const std::vector<int8_t>& outputExpected)
1925 unsigned int numBatches = 2;
1926 unsigned int inputSize = 5;
1927 unsigned int outputSize = 3;
1928 unsigned int numUnits = 4;
1930 bool cifgEnabled =
false;
1931 bool peepholeEnabled =
false;
1932 bool projectionEnabled =
true;
1933 bool layerNormEnabled =
true;
1936 float inputScale = 0.0078125f;
1937 int32_t inputOffset = 0;
1939 int32_t hiddenStateZeroPoint = 0;
1940 float hiddenStateScale = 0.007f;
1943 float outputScale = 3.05176e-05f;
1944 int32_t outputOffset = 0;
1946 float cellStateScale = 3.05176e-05f;
1947 int32_t cellStateOffset = 0;
1949 float weightsScale = 0.00784314f;
1950 int32_t weightsOffset = 0;
1952 float layerNormScale = 3.05182e-05f;
1953 int32_t layerNormOffset = 0;
1955 float biasScale = layerNormScale / 1024;
1956 int32_t biasOffset = 0;
1958 float projectionWeightsScale = 0.00392157f;
1960 float inputIntermediateScale = 0.007059f;
1961 float forgetIntermediateScale = 0.007812f;
1962 float cellIntermediateScale = inputIntermediateScale;
1963 float outputIntermediateScale = forgetIntermediateScale;
1965 float cellClip = 0.0f;
1966 float projectionClip = 0.0f;
1985 std::vector<int8_t> inputVector;
1986 inputVector.assign(input.data(), input.data() + (numBatches * inputSize));
1988 std::vector<int16_t> cellStateInVector = {0, 0, 0, 0, 0, 0, 0, 0};
1990 std::vector<int8_t> outputStateInVector = {0, 0, 0, 0, 0, 0};
1993 std::vector<int16_t> cellStateOutVector = {-14650, 8939, 5771, 6715, -11843, 7847, 1508, 12939};
1995 std::vector<int8_t> outputVector;
1996 outputVector.assign(outputExpected.data(), outputExpected.data() + (numBatches * outputSize));
1998 std::vector<int8_t> actualOutput(outputStateInfo.GetNumElements());
2001 std::unique_ptr<armnn::ITensorHandle> inputHandle = tensorHandleFactory.
CreateTensorHandle(inputInfo);
2002 std::unique_ptr<armnn::ITensorHandle> cellStateInHandle =
2004 std::unique_ptr<armnn::ITensorHandle> outputStateInHandle =
2007 std::unique_ptr<armnn::ITensorHandle> outputStateOutHandle =
2009 std::unique_ptr<armnn::ITensorHandle> cellStateOutHandle =
2011 std::unique_ptr<armnn::ITensorHandle> outputHandle = tensorHandleFactory.
CreateTensorHandle(outputStateInfo);
2017 AddInputToWorkload(data, info, inputInfo, inputHandle.get());
2018 AddInputToWorkload(data, info, outputStateInfo, outputStateInHandle.get());
2019 AddInputToWorkload(data, info, cellStateInfo, cellStateInHandle.get());
2021 AddOutputToWorkload(data, info, outputStateInfo, outputStateOutHandle.get());
2022 AddOutputToWorkload(data, info, cellStateInfo, cellStateOutHandle.get());
2023 AddOutputToWorkload(data, info, outputStateInfo, outputHandle.get());
2042 projectionWeightsScale,
2046 std::vector<int8_t> inputToInputWeights =
2047 {64, 77, 89, -102, -115, 13, 25, 38, -51, 64, -102, 89, -77, 64, -51, -64, -51, -38, -25, -13};
2048 std::vector<int8_t> inputToForgetWeights =
2049 {-77, -13, 38, 25, 115, -64, -25, -51, 38, -102, -51, 38, -64, -51, -77, 38, -51, -77, -64, -64};
2050 std::vector<int8_t> inputToCellWeights =
2051 {-51, -38, -25, -13, -64, 64, -25, -38, -25, -77, 77, -13, -51, -38, -89, 89, -115, -64, 102, 77};
2052 std::vector<int8_t> inputToOutputWeights =
2053 {-102, -51, -25, -115, -13, -89, 38, -38, -102, -25, 77, -25, 51, -89, -38, -64, 13, 64, -77, -51};
2055 std::vector<int8_t> recurrentToInputWeights = {-25, -38, 51, 13, -64, 115, -25, -38, -89, 6, -25, -77};
2056 std::vector<int8_t> recurrentToForgetWeights = {-64, -38, -64, -25, 77, 51, 115, 38, -13, 25, 64, 25};
2057 std::vector<int8_t> recurrentToCellWeights = {-38, 25, 13, -38, 102, -10, -25, 38, 102, -77, -13, 25};
2058 std::vector<int8_t> recurrentToOutputWeights = {38, -13, 13, -25, -64, -89, -25, -77, -13, -51, -89, -25};
2060 std::vector<int32_t> inputGateBias = {644245, 3221226, 4724464, 8160438};
2061 std::vector<int32_t> forgetGateBias = {2147484, -6442451, -4294968, 2147484};
2062 std::vector<int32_t> cellBias = {-1073742, 15461883, 5368709, 1717987};
2063 std::vector<int32_t> outputGateBias = {1073742, -214748, 4294968, 2147484};
2065 std::vector<int16_t> inputLayerNormWeights = {3277, 6553, 9830, 16384};
2066 std::vector<int16_t> forgetLayerNormWeights = {6553, 6553, 13107, 9830};
2067 std::vector<int16_t> cellLayerNormWeights = {22937, 6553, 9830, 26214};
2068 std::vector<int16_t> outputLayerNormWeights = {19660, 6553, 6553, 16384};
2070 std::vector<int8_t> projectionWeights = {-25, 51, 3, -51, 25, 127, 77, 20, 18, 51, -102, 51};
2158 std::unique_ptr<armnn::IWorkload> workload = workloadFactory.
CreateQLstm(data, info);
2159 inputHandle->Allocate();
2160 outputStateInHandle->Allocate();
2161 cellStateInHandle->Allocate();
2163 outputStateOutHandle->Allocate();
2164 cellStateOutHandle->Allocate();
2165 outputHandle->Allocate();
2171 workload->Execute();
2177 outputHandle->GetShape(),
2178 outputStateInfo.GetShape());
2186 const std::vector<int8_t>& input,
2187 const std::vector<int8_t>& outputExpected)
2190 unsigned int numBatches = 2;
2191 unsigned int inputSize = 5;
2192 unsigned int outputSize = 3;
2193 unsigned int numUnits = 4;
2195 bool cifgEnabled =
true;
2196 bool peepholeEnabled =
false;
2197 bool projectionEnabled =
true;
2198 bool layerNormEnabled =
true;
2201 float inputScale = 0.0078125f;
2202 int32_t inputOffset = 0;
2204 int32_t hiddenStateZeroPoint = 0;
2205 float hiddenStateScale = 0.007f;
2208 float outputScale = 3.05176e-05f;
2209 int32_t outputOffset = 0;
2211 float cellStateScale = 3.05176e-05f;
2212 int32_t cellStateOffset = 0;
2214 float weightsScale = 0.00784314f;
2215 int32_t weightsOffset = 0;
2217 float layerNormScale = 3.05182e-05f;
2218 int32_t layerNormOffset = 0;
2220 float biasScale = layerNormScale / 1024;
2221 int32_t biasOffset = 0;
2223 float projectionWeightsScale = 0.00392157f;
2225 float inputIntermediateScale = 0.007059f;
2226 float forgetIntermediateScale = 0.007812f;
2227 float cellIntermediateScale = inputIntermediateScale;
2228 float outputIntermediateScale = forgetIntermediateScale;
2230 float cellClip = 0.0f;
2231 float projectionClip = 0.0f;
2250 std::vector<int8_t> inputVector;
2251 inputVector.assign(input.data(), input.data() + (numBatches * inputSize));
2253 std::vector<int16_t> cellStateInVector = {0, 0, 0, 0, 0, 0, 0, 0};
2255 std::vector<int8_t> outputStateInVector = {0, 0, 0, 0, 0, 0};
2258 std::vector<int16_t> cellStateOutVector = {-14650, 8939, 5771, 6715, -11843, 7847, 1508, 12939};
2260 std::vector<int8_t> outputVector;
2261 outputVector.assign(outputExpected.data(), outputExpected.data() + (numBatches * outputSize));
2263 std::vector<int8_t> actualOutput(outputStateInfo.GetNumElements());
2266 std::unique_ptr<armnn::ITensorHandle> inputHandle = tensorHandleFactory.
CreateTensorHandle(inputInfo);
2267 std::unique_ptr<armnn::ITensorHandle> cellStateInHandle =
2269 std::unique_ptr<armnn::ITensorHandle> outputStateInHandle =
2272 std::unique_ptr<armnn::ITensorHandle> outputStateOutHandle =
2274 std::unique_ptr<armnn::ITensorHandle> cellStateOutHandle =
2276 std::unique_ptr<armnn::ITensorHandle> outputHandle = tensorHandleFactory.
CreateTensorHandle(outputStateInfo);
2282 AddInputToWorkload(data, info, inputInfo, inputHandle.get());
2283 AddInputToWorkload(data, info, outputStateInfo, outputStateInHandle.get());
2284 AddInputToWorkload(data, info, cellStateInfo, cellStateInHandle.get());
2286 AddOutputToWorkload(data, info, outputStateInfo, outputStateOutHandle.get());
2287 AddOutputToWorkload(data, info, cellStateInfo, cellStateOutHandle.get());
2288 AddOutputToWorkload(data, info, outputStateInfo, outputHandle.get());
2307 projectionWeightsScale,
2311 std::vector<int8_t> inputToForgetWeights =
2312 {-77, -13, 38, 25, 115, -64, -25, -51, 38, -102, -51, 38, -64, -51, -77, 38, -51, -77, -64, -64};
2313 std::vector<int8_t> inputToCellWeights =
2314 {-51, -38, -25, -13, -64, 64, -25, -38, -25, -77, 77, -13, -51, -38, -89, 89, -115, -64, 102, 77};
2315 std::vector<int8_t> inputToOutputWeights =
2316 {-102, -51, -25, -115, -13, -89, 38, -38, -102, -25, 77, -25, 51, -89, -38, -64, 13, 64, -77, -51};
2318 std::vector<int8_t> recurrentToForgetWeights =
2319 {-64, -38, -64, -25, 77, 51, 115, 38, -13, 25, 64, 25};
2320 std::vector<int8_t> recurrentToCellWeights =
2321 {-38, 25, 13, -38, 102, -10, -25, 38, 102, -77, -13, 25};
2322 std::vector<int8_t> recurrentToOutputWeights =
2323 {38, -13, 13, -25, -64, -89, -25, -77, -13, -51, -89, -25};
2325 std::vector<int32_t> forgetGateBias = {2147484, -6442451, -4294968, 2147484};
2326 std::vector<int32_t> cellBias = {-1073742, 15461883, 5368709, 1717987};
2327 std::vector<int32_t> outputGateBias = {1073742, -214748, 4294968, 2147484};
2329 std::vector<int16_t> forgetLayerNormWeights = {6553, 6553, 13107, 9830};
2330 std::vector<int16_t> cellLayerNormWeights = {22937, 6553, 9830, 26214};
2331 std::vector<int16_t> outputLayerNormWeights = {19660, 6553, 6553, 16384};
2333 std::vector<int8_t> projectionWeights = {-25, 51, 3, -51, 25, 127, 77, 20, 18, 51, -102, 51};
2409 std::unique_ptr<armnn::IWorkload> workload = workloadFactory.
CreateQLstm(data, info);
2410 inputHandle->Allocate();
2411 outputStateInHandle->Allocate();
2412 cellStateInHandle->Allocate();
2414 outputStateOutHandle->Allocate();
2415 cellStateOutHandle->Allocate();
2416 outputHandle->Allocate();
2422 workload->Execute();
2428 outputHandle->GetShape(),
2429 outputStateInfo.GetShape());
2435 #if defined(ARMNNREF_ENABLED) 2439 void LstmUtilsZeroVectorTest()
2442 std::vector<float> input = {2., 3., 3., 4.};
2443 std::vector<float> expectedOutput = {0., 0., 0., 0.};
2445 return LstmUtilsZeroVectorTestImpl<armnn::DataType::Float32>(input, 4, expectedOutput, inputDesc.GetShape());
2448 void LstmUtilsMeanStddevNormalizationNoneZeroInputTest()
2450 uint32_t batchSize = 2;
2451 uint32_t vecSize = 4;
2453 std::vector<float> input =
2454 { 0.1f, 0.2f, 0.3f, 0.4f,
2455 0.9f, 1.0f, 1.1f, 1.2f };
2457 std::vector<float> expectedOutput =
2458 { -1.34164071f, -0.447213531f, 0.44721365f, 1.34164071f,
2459 -1.34163153f, -0.447210163f, 0.447211236f, 1.3416326f };
2461 return LstmUtilsMeanStddevNormalizationTestImpl<armnn::DataType::Float32>(input,
2462 vecSize, batchSize, expectedOutput, inputDesc.GetShape());
2465 void LstmUtilsMeanStddevNormalizationAllZeroInputTest()
2467 uint32_t batchSize = 2;
2468 uint32_t vecSize = 4;
2470 std::vector<float> input =
2471 { 0.0f, 0.0f, 0.0f, 0.0f,
2472 0.0f, 0.0f, 0.0f, 0.0f };
2474 std::vector<float> expectedOutput =
2475 { 0.0f, 0.0f, 0.0f, 0.0f,
2476 0.0f, 0.0f, 0.0f, 0.0f };
2478 return LstmUtilsMeanStddevNormalizationTestImpl<armnn::DataType::Float32>(input,
2479 vecSize, batchSize, expectedOutput, inputDesc.GetShape());
2482 void LstmUtilsMeanStddevNormalizationMixedZeroInputTest()
2484 uint32_t batchSize = 2;
2485 uint32_t vecSize = 4;
2487 std::vector<float> input =
2488 { 0.0f, 0.0f, 0.0f, 0.0f,
2489 0.1f, 0.2f, 0.3f, 0.4f };
2491 std::vector<float> expectedOutput =
2492 { 0.0f, 0.0f, 0.0f, 0.0f,
2493 -1.34164071f, -0.447213531f, 0.44721365f, 1.34164071f };
2495 return LstmUtilsMeanStddevNormalizationTestImpl<armnn::DataType::Float32>(input,
2496 vecSize, batchSize, expectedOutput, inputDesc.GetShape());
2499 void LstmUtilsVectorBatchVectorCwiseProductTest()
2501 uint32_t batchSize = 4;
2502 uint32_t vecSize = 29;
2504 std::vector<float> vector =
2505 { 1.1f, 2.2f, 3.3f, 4.4f, 5.5f, 6.6f, 7.7f, 8.8f, 9.9f, 10.1f,
2506 11.11f, 12.12f, 13.13f, 14.14f, 15.15f, 16.16f, 17.17f, 18.18f, 19.19f, 20.2f,
2507 21.21f, 22.22f, 23.23f, 24.24f, 25.25f, 26.26f, 27.27f, 28.28f, 0.0f};
2510 std::vector<float> batchVector =
2512 1.1f, 2.2f, 3.3f, 4.4f, 5.5f, 6.6f, 7.7f, 8.8f, 9.9f, 10.1f,
2513 11.11f, 12.12f, 13.13f, 14.14f, 15.15f, 16.16f, 17.17f, 18.18f, 19.19f, 20.2f,
2514 21.21f, 22.22f, 23.23f, 24.24f, 25.25f, 26.26f, 27.27f, 28.28f, 0.0f,
2516 -1.1f, -2.2f, -3.3f, -4.4f, -5.5f, -6.6f, -7.7f, -8.8f, -9.9f, -10.1f,
2517 -11.11f, -12.12f, -13.13f, -14.14f, -15.15f, -16.16f, -17.17f, -18.18f, -19.19f, -20.2f,
2518 -21.21f, -22.22f, -23.23f, -24.24f, -25.25f, -26.26f, -27.27f, -28.28f, 0.0f,
2520 1.1f, -2.2f, 3.3f, -4.4f, 5.5f, -6.6f, 7.7f, -8.8f, 9.9f, -10.1f,
2521 11.11f, -12.12f, 13.13f, -14.14f, 15.15f, -16.16f, 17.17f, -18.18f, 19.19f, -20.2f,
2522 21.21f, -22.22f, 23.23f, -24.24f, 25.25f, -26.26f, 27.27f, -28.28f, 0.0f,
2524 -1.1f, 2.2f, -3.3f, 4.4f, -5.5f, 6.6f, -7.7f, 8.8f, -9.9f, 10.1f,
2525 -11.11f, 12.12f, -13.13f, 14.14f, -15.15f, 16.16f, -17.17f, 18.18f, -19.19f, 20.2f,
2526 -21.21f, 22.22f, -23.23f, 24.24f, -25.25f, 26.26f, -27.27f, 28.28f, 0.0f};
2529 std::vector<float> expectedOutput =
2531 1.210000f, 4.840000f, 10.889999f, 19.360001f, 30.250000f, 43.559998f,
2532 59.289997f, 77.440002f, 98.009995f, 102.010010f, 123.432091f, 146.894394f,
2533 172.396896f, 199.939606f, 229.522491f, 261.145599f, 294.808899f, 330.512421f,
2534 368.256134f, 408.040039f, 449.864075f, 493.728363f, 539.632874f, 587.577576f,
2535 637.562500f, 689.587585f, 743.652954f, 799.758423f, 0.000000f,
2537 -1.210000f, -4.840000f, -10.889999f, -19.360001f, -30.250000f, -43.559998f,
2538 -59.289997f, -77.440002f, -98.009995f, -102.010010f, -123.432091f, -146.894394f,
2539 -172.396896f, -199.939606f, -229.522491f, -261.145599f, -294.808899f, -330.512421f,
2540 -368.256134f, -408.040039f, -449.864075f, -493.728363f, -539.632874f, -587.577576f,
2541 -637.562500f, -689.587585f, -743.652954f, -799.758423f, 0.000000f,
2543 1.210000f, -4.840000f, 10.889999f, -19.360001f, 30.250000f, -43.559998f,
2544 59.289997f, -77.440002f, 98.009995f, -102.010010f, 123.432091f, -146.894394f,
2545 172.396896f, -199.939606f, 229.522491f, -261.145599f, 294.808899f, -330.512421f,
2546 368.256134f, -408.040039f, 449.864075f, -493.728363f, 539.632874f, -587.577576f,
2547 637.562500f, -689.587585f, 743.652954f, -799.758423f, 0.000000f,
2549 -1.210000f, 4.840000f, -10.889999f, 19.360001f, -30.250000f, 43.559998f,
2550 -59.289997f, 77.440002f, -98.009995f, 102.010010f, -123.432091f, 146.894394f,
2551 -172.396896f, 199.939606f, -229.522491f, 261.145599f, -294.808899f, 330.512421f,
2552 -368.256134f, 408.040039f, -449.864075f, 493.728363f, -539.632874f, 587.577576f,
2553 -637.562500f, 689.587585f, -743.652954f, 799.758423f, 0.000000f};
2555 return LstmUtilsVectorBatchVectorCwiseProductTestImpl<armnn::DataType::Float32>(vector, batchVector,
2556 vecSize, batchSize, expectedOutput, vecDesc.GetShape());
2559 void LstmUtilsVectorBatchVectorAddTest()
2561 uint32_t batchSize = 2;
2562 uint32_t vecSize = 3;
2564 std::vector<float> vector = { 0.0f, -0.5f, 1.0f};
2567 std::vector<float> batchVector =
2573 std::vector<float> expectedOutput =
2579 return LstmUtilsVectorBatchVectorAddTestImpl<armnn::DataType::Float32>(vector, batchVector,
2580 vecSize, batchSize, expectedOutput, batchVecDesc.GetShape());
2591 std::vector<float> input = { 2., 3., 3., 4. };
2594 std::vector<float> expectedOutput =
2595 {-0.36444446f, -0.00352185f, 0.12886585f, -0.05163646f,
2596 -0.42734814f, -0.00478661f, 0.13455015f, -0.03560682f};
2597 return LstmLayerWithCifgWithPeepholeNoProjectionTestImpl<armnn::DataType::Float32>(
2598 workloadFactory, memoryManager, tensorHandleFactory,
2599 input, expectedOutput, inputDesc.GetShape(), outputDesc.GetShape());
2608 std::vector<float> input =
2609 {0.787926f, 0.151646f, 0.071352f, 0.118426f, 0.458058f,
2610 0.295743f, 0.544053f, 0.690064f, 0.858138f, 0.497181f};
2613 std::vector<float> expectedOutput =
2614 {-0.00396806f, 0.029352f, -0.00279226f, 0.0159977f, -0.00835576f,
2615 -0.0211779f, 0.0283512f, -0.0114597f, 0.00907307f, -0.0244004f,
2616 -0.0152191f, -0.0259063f, 0.00914318f, 0.00415118f, 0.017147f,
2617 0.0134203f, -0.013869f, 0.0287268f, -0.00334693f, 0.00733398f, -0.0287926f,
2618 -0.0186926f, 0.0193662f, -0.0115437f, 0.00422612f, -0.0345232f,
2619 0.00223253f, -0.00957321f, 0.0210624f, 0.013331f, 0.0150954f, 0.02168f};
2620 return LstmLayerNoCifgWithPeepholeWithProjectionTestImpl<armnn::DataType::Float32>(
2621 workloadFactory, memoryManager, tensorHandleFactory, input, expectedOutput);
2630 std::vector<float> input = {2., 3., 3., 4.};
2633 std::vector<float> expectedOutput =
2634 {-0.02973187f, 0.1229473f, 0.20885126f, -0.15358765f,
2635 -0.0185422f, 0.11281417f, 0.24466537f, -0.1826292f};
2637 return LstmNoCifgNoPeepholeNoProjectionTestImpl<armnn::DataType::Float32>(
2638 workloadFactory, memoryManager, tensorHandleFactory,
2639 input, expectedOutput, inputDesc.GetShape(), outputDesc.GetShape());
2648 std::vector<float> input =
2649 {0.7f, 0.8f, 0.1f, 0.2f, 0.3f,
2650 0.3f, 0.2f, 0.9f, 0.8f, 0.1f};
2653 std::vector<float> expectedOutput =
2654 { 0.0244077f, 0.128027f, -0.00170918f,
2655 -0.00692428f, 0.0848741f, 0.063445f};
2656 return LstmLayerNoCifgWithPeepholeWithProjectionWithLayerNormTestImpl<armnn::DataType::Float32>(
2657 workloadFactory, memoryManager, tensorHandleFactory, input, expectedOutput);
2665 const float qScale = 1.0f;
2666 const int32_t qOffset = 0;
2672 std::vector<int16_t> input = armnnUtils::QuantizedVector<int16_t>({ 2.f, 3.f, 3.f, 4.f }, qScale, qOffset);
2675 std::vector<int16_t> expectedOutput = armnnUtils::QuantizedVector<int16_t>(
2677 -0.02973187f, 0.12294730f, 0.20885126f, -0.15358765f,
2678 -0.01854220f, 0.11281417f, 0.24466537f, -0.18262920f
2682 return LstmNoCifgNoPeepholeNoProjectionTestImpl<datatype>(
2683 workloadFactory, memoryManager, tensorHandleFactory,
2684 input, expectedOutput, inputDesc.GetShape(), outputDesc.GetShape(),
2685 qScale, qOffset, constantDatatype);
2694 const float qScale = 1.0f;
2695 const int32_t qOffset = 0;
2701 std::vector<int16_t> input = armnnUtils::QuantizedVector<int16_t>({ 2.f, 3.f, 3.f, 4.f }, qScale, qOffset);
2704 std::vector<int16_t> expectedOutput = armnnUtils::QuantizedVector<int16_t>(
2706 -0.36444446f, -0.00352185f, 0.12886585f, -0.05163646f,
2707 -0.42734814f, -0.00478661f, 0.13455015f, -0.03560682f
2711 return LstmLayerWithCifgWithPeepholeNoProjectionTestImpl<datatype>(
2712 workloadFactory, memoryManager, tensorHandleFactory,
2713 input, expectedOutput, inputDesc.GetShape(), outputDesc.GetShape(),
2714 qScale, qOffset, constantDatatype);
2722 const float qScale = 2.0f;
2723 const int32_t qOffset = 0;
2729 std::vector<int16_t> input = armnnUtils::QuantizedVector<int16_t>(
2731 0.787926f, 0.151646f, 0.071352f, 0.118426f, 0.458058f,
2732 0.295743f, 0.544053f, 0.690064f, 0.858138f, 0.497181f
2737 std::vector<int16_t> expectedOutput = armnnUtils::QuantizedVector<int16_t>(
2739 -0.00396806f, 0.02935200f, -0.00279226f, 0.01599770f,
2740 -0.00835576f, -0.02117790f, 0.02835120f, -0.01145970f,
2741 0.00907307f, -0.02440040f, -0.01521910f, -0.02590630f,
2742 0.00914318f, 0.00415118f, 0.01714700f, 0.01342030f,
2743 -0.01386900f, 0.02872680f, -0.00334693f, 0.00733398f,
2744 -0.02879260f, -0.01869260f, 0.01936620f, -0.01154370f,
2745 0.00422612f, -0.03452320f, 0.00223253f, -0.00957321f,
2746 0.02106240f, 0.01333100f, 0.01509540f, 0.02168000f
2750 return LstmLayerNoCifgWithPeepholeWithProjectionTestImpl<datatype>(
2751 workloadFactory, memoryManager, tensorHandleFactory, input, expectedOutput, qScale, qOffset, constantDatatype);
2759 const float qScale = 1.0f;
2760 const int32_t qOffset = 0;
2765 std::vector<int16_t> input = armnnUtils::QuantizedVector<int16_t>({ 2.f, 3.f, 3.f, 4.f }, qScale, qOffset);
2768 std::vector<int16_t> expectedOutput = armnnUtils::QuantizedVector<int16_t>(
2770 -0.02973187f, 0.12294730f, 0.20885126f, -0.15358765f,
2771 -0.01854220f, 0.11281417f, 0.24466537f, -0.18262920f
2775 return LstmNoCifgNoPeepholeNoProjectionTestImpl<datatype>(
2776 workloadFactory, memoryManager, tensorHandleFactory,
2777 input, expectedOutput, inputDesc.GetShape(), outputDesc.GetShape(),
2778 qScale, qOffset, datatype);
2791 std::vector<uint8_t> input = {166, 179, 50, 150};
2794 std::vector<uint8_t> expectedOutput = {140, 151, 146, 112, 136, 156, 142, 112 };
2796 return QuantizedLstmTestImpl(workloadFactory, memoryManager, tensorHandleFactory,
2797 input, expectedOutput, inputDesc.GetShape(), outputDesc.GetShape());
2807 std::vector<int8_t> input = {90, 102, 13, 26, 38, 102, 13, 26, 51, 64};
2810 std::vector<int8_t> expectedOutput = {-15, 21, 14, 20, -15, 15, 5, 27};
2812 return QLstmTestImpl(workloadFactory, memoryManager, tensorHandleFactory, input, expectedOutput);
2821 std::vector<int8_t> input = {90, 102, 13, 26, 38, 102, 13, 26, 51, 64};
2824 std::vector<int8_t> expectedOutput = {127, 127, -108, -67, 127, 127};
2826 return QLstmTestImpl1(workloadFactory, memoryManager, tensorHandleFactory, input, expectedOutput);
2835 std::vector<int8_t> input = {90, 102, 13, 26, 38, 102, 13, 26, 51, 64};
2838 std::vector<int8_t> expectedOutput = {127, 127, 127, -128, 127, 127};
2840 return QLstmTestImpl2(workloadFactory, memoryManager, tensorHandleFactory, input, expectedOutput);
const ConstTensorHandle * m_CellLayerNormWeights
void MeanStddevNormalization(armnn::Decoder< float > &input_vector, armnn::Encoder< float > &output_vector, uint32_t v_size, uint32_t n_batch, float normalization_epsilon)
void VectorBatchVectorAdd(armnn::Decoder< float > &vector, uint32_t vSize, armnn::Decoder< float > &batchVector, uint32_t nBatch, armnn::Encoder< float > &outResult)
LayerTestResult< float, 2 > LstmLayerFloat32NoCifgWithPeepholeWithProjectionTest(armnn::IWorkloadFactory &workloadFactory, const armnn::IBackendInternal::IMemoryManagerSharedPtr &memoryManager, const armnn::ITensorHandleFactory &tensorHandleFactory)
const ConstTensorHandle * m_ProjectionWeights
LayerTestResult< int16_t, 2 > LstmLayerInt16WithCifgWithPeepholeNoProjectionTest(armnn::IWorkloadFactory &workloadFactory, const armnn::IBackendInternal::IMemoryManagerSharedPtr &memoryManager, const armnn::ITensorHandleFactory &tensorHandleFactory)
bool m_ProjectionEnabled
Enable/disable the projection layer.
const ConstTensorHandle * m_ProjectionWeights
const ConstTensorHandle * m_RecurrentToForgetWeights
const ConstTensorHandle * m_ForgetGateBias
const ConstTensorHandle * m_InputToOutputWeights
float m_ClippingThresProj
Clipping threshold value for the projection.
const ConstTensorHandle * m_InputGateBias
LayerTestResult< int8_t, 2 > QLstmTest1(armnn::IWorkloadFactory &workloadFactory, const armnn::IBackendInternal::IMemoryManagerSharedPtr &memoryManager, const armnn::ITensorHandleFactory &tensorHandleFactory)
const ConstTensorHandle * m_RecurrentToCellWeights
bool m_PeepholeEnabled
Enable/disable peephole.
const ConstTensorHandle * m_CellBias
float m_HiddenStateScale
Hidden State quantization scale.
LayerTestResult< int16_t, 2 > LstmLayerInt16NoCifgWithPeepholeWithProjectionTest(armnn::IWorkloadFactory &workloadFactory, const armnn::IBackendInternal::IMemoryManagerSharedPtr &memoryManager, const armnn::ITensorHandleFactory &tensorHandleFactory)
virtual std::unique_ptr< IWorkload > CreateLstm(const LstmQueueDescriptor &descriptor, const WorkloadInfo &info) const
const ConstTensorHandle * m_InputToInputWeights
float m_OutputIntermediateScale
Output intermediate quantization scale.
LayerTestResult< int8_t, 2 > QLstmTest(armnn::IWorkloadFactory &workloadFactory, const armnn::IBackendInternal::IMemoryManagerSharedPtr &memoryManager, const armnn::ITensorHandleFactory &tensorHandleFactory)
virtual std::unique_ptr< IWorkload > CreateQuantizedLstm(const QuantizedLstmQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateQLstm(const QLstmQueueDescriptor &descriptor, const WorkloadInfo &info) const
const ConstTensorHandle * m_InputGateBias
const ConstTensorHandle * m_InputToOutputWeights
armnn::PredicateResult CompareTensors(const std::vector< T > &actualData, const std::vector< T > &expectedData, const armnn::TensorShape &actualShape, const armnn::TensorShape &expectedShape, bool compareBoolean=false, bool isDynamic=false)
const ConstTensorHandle * m_OutputLayerNormWeights
void ZeroVector(armnn::Encoder< float > &vector, uint32_t vSize)
void IgnoreUnused(Ts &&...)
const ConstTensorHandle * m_RecurrentToInputWeights
LayerDescriptor m_Parameters
void VectorBatchVectorCwiseProduct(armnn::Decoder< float > &vector, uint32_t vSize, armnn::Decoder< float > &batchVector, uint32_t nBatch, armnn::Encoder< float > &outResult)
const ConstTensorHandle * m_ForgetLayerNormWeights
const ConstTensorHandle * m_OutputGateBias
LayerTestResult< float, 2 > LstmLayerFloat32WithCifgWithPeepholeNoProjectionTest(armnn::IWorkloadFactory &workloadFactory, const armnn::IBackendInternal::IMemoryManagerSharedPtr &memoryManager, const armnn::ITensorHandleFactory &tensorHandleFactory)
const ConstTensorHandle * m_InputToForgetWeights
bool m_LayerNormEnabled
Enable/disable layer normalization.
const ConstTensorHandle * m_CellLayerNormWeights
std::shared_ptr< IMemoryManager > IMemoryManagerSharedPtr
float m_ProjectionClip
Clipping threshold value for the projection.
const ConstTensorHandle * m_InputToForgetWeights
const ConstTensorHandle * m_RecurrentToCellWeights
const ConstTensorHandle * m_ForgetGateBias
float m_InputIntermediateScale
Input intermediate quantization scale.
bool m_PeepholeEnabled
Enable/disable peephole.
const ConstTensorHandle * m_CellBias
const ConstTensorHandle * m_CellToOutputWeights
LayerTestResult< uint8_t, 2 > QuantizedLstmTest(armnn::IWorkloadFactory &workloadFactory, const armnn::IBackendInternal::IMemoryManagerSharedPtr &memoryManager, const armnn::ITensorHandleFactory &tensorHandleFactory)
const ConstTensorHandle * m_RecurrentToOutputWeights
const ConstTensorHandle * m_InputLayerNormWeights
const ConstTensorHandle * m_InputToCellWeights
void AllocateAndCopyDataToITensorHandle(armnn::ITensorHandle *tensorHandle, const void *memory)
void CopyDataFromITensorHandle(void *memory, const armnn::ITensorHandle *tensorHandle)
const ConstTensorHandle * m_OutputGateBias
const ConstTensorHandle * m_InputToForgetWeights
uint32_t m_ActivationFunc
The activation function to use.
const ConstTensorHandle * m_RecurrentToForgetWeights
LayerTestResult< float, 2 > LstmLayerFloat32NoCifgNoPeepholeNoProjectionTest(armnn::IWorkloadFactory &workloadFactory, const armnn::IBackendInternal::IMemoryManagerSharedPtr &memoryManager, const armnn::ITensorHandleFactory &tensorHandleFactory)
float m_ClippingThresCell
Clipping threshold value for the cell state.
const ConstTensorHandle * m_RecurrentToInputWeights
float m_ForgetIntermediateScale
Forget intermediate quantization scale.
const ConstTensorHandle * m_InputToCellWeights
const ConstTensorHandle * m_ForgetGateBias
float m_CellClip
Clipping threshold value for the cell state.
bool m_CifgEnabled
Enable/disable cifg (coupled input & forget gate).
LayerTestResult< int16_t, 2 > LstmLayerInt16NoCifgNoPeepholeNoProjectionTest(armnn::IWorkloadFactory &workloadFactory, const armnn::IBackendInternal::IMemoryManagerSharedPtr &memoryManager, const armnn::ITensorHandleFactory &tensorHandleFactory)
const ConstTensorHandle * m_InputToOutputWeights
LayerTestResult< int16_t, 2 > LstmLayerInt16NoCifgNoPeepholeNoProjectionInt16ConstantTest(armnn::IWorkloadFactory &workloadFactory, const armnn::IBackendInternal::IMemoryManagerSharedPtr &memoryManager, const armnn::ITensorHandleFactory &tensorHandleFactory)
const ConstTensorHandle * m_CellToForgetWeights
const ConstTensorHandle * m_RecurrentToCellWeights
bool m_ProjectionEnabled
Enable/disable the projection layer.
const ConstTensorHandle * m_InputGateBias
const ConstTensorHandle * m_InputToInputWeights
const ConstTensorHandle * m_CellBias
const ConstTensorHandle * m_ProjectionBias
bool m_LayerNormEnabled
Enable/disable layer normalization.
std::enable_if_t< std::is_unsigned< Source >::value &&std::is_unsigned< Dest >::value, Dest > numeric_cast(Source source)
const ConstTensorHandle * m_ForgetLayerNormWeights
Contains information about TensorInfos of a layer.
const ConstTensorHandle * m_InputLayerNormWeights
const ConstTensorHandle * m_OutputGateBias
LayerTestResult< float, 2 > LstmLayerFloat32NoCifgWithPeepholeWithProjectionWithLayerNormTest(armnn::IWorkloadFactory &workloadFactory, const armnn::IBackendInternal::IMemoryManagerSharedPtr &memoryManager, const armnn::ITensorHandleFactory &tensorHandleFactory)
const ConstTensorHandle * m_OutputLayerNormWeights
const ConstTensorHandle * m_InputToCellWeights
const ConstTensorHandle * m_CellToInputWeights
float m_CellIntermediateScale
Cell intermediate quantization scale.
virtual std::unique_ptr< ITensorHandle > CreateTensorHandle(const TensorInfo &tensorInfo) const =0
const ConstTensorHandle * m_InputToInputWeights
const ConstTensorHandle * m_RecurrentToOutputWeights
const ConstTensorHandle * m_RecurrentToInputWeights
bool m_CifgEnabled
Enable/disable CIFG (coupled input & forget gate).
LayerTestResult< int8_t, 2 > QLstmTest2(armnn::IWorkloadFactory &workloadFactory, const armnn::IBackendInternal::IMemoryManagerSharedPtr &memoryManager, const armnn::ITensorHandleFactory &tensorHandleFactory)
const ConstTensorHandle * m_RecurrentToForgetWeights
void CopyDataToITensorHandle(armnn::ITensorHandle *tensorHandle, const void *memory)
int32_t m_HiddenStateZeroPoint
Hidden State zero point.
const ConstTensorHandle * m_RecurrentToOutputWeights