// NOTE(review): this chunk is a lossy extraction — the original file's line numbers are
// fused into the text and several lines (per the embedded numbering: 34-40, 44-47, 50-51)
// are missing, including the call under test and the closing brace. Code is kept verbatim;
// comments only.
//
// Checks the LstmUtils VectorBatchVectorAdd helper: batchVec is compared in-place against
// expectedOutput, then the encoder is used to verify it still writes through to batchVec.
23 #include <doctest/doctest.h> 27 template<armnn::DataType ArmnnType,
typename T = armnn::ResolveType<ArmnnType>>
28 void LstmUtilsVectorBatchVectorAddTestImpl(
29 std::vector<float>& vec,
30 std::vector<float>& batchVec,
// (parameter list continues in elided lines 31-32)
33 std::vector<float>& expectedOutput,
// Decoders read vec/batchVec; the encoder writes results back into batchVec in place.
41 std::unique_ptr<armnn::Decoder<float>> vecDecoder = armnn::MakeDecoder<float>(tensorInfo, vec.data());
42 std::unique_ptr<armnn::Decoder<float>> batchVecDecoder = armnn::MakeDecoder<float>(tensorInfo, batchVec.data());
43 std::unique_ptr<armnn::Encoder<float>> batchVecEncoder = armnn::MakeEncoder<float>(tensorInfo, batchVec.data());
// The VectorBatchVectorAdd invocation itself sits in elided lines 44-47 — TODO confirm.
48 auto result =
CompareTensors(batchVec, expectedOutput, expectedShape, expectedShape);
49 CHECK_MESSAGE(result.m_Result, result.m_Message.str());
// Sanity check: writing 1.0f through the encoder must land in batchVec[0],
// i.e. the encoder still aliases the batchVec buffer after the operation.
52 batchVecEncoder->Set(1.0f);
53 CHECK(batchVec[0] == 1.0f);
// Checks the LstmUtils ZeroVector helper: input is zeroed in place (the call itself is in
// lines elided from this extraction — per the embedded numbering, lines 59, 61-68, 70-74
// are missing), then compared against expectedOutput.
// NOTE(review): code kept verbatim; this fragment's closing brace is also elided.
56 template<armnn::DataType ArmnnType,
typename T = armnn::ResolveType<ArmnnType>>
57 void LstmUtilsZeroVectorTestImpl(
58 std::vector<float>& input,
60 std::vector<float>& expectedOutput,
// Encoder writes directly into the caller's input buffer.
69 std::unique_ptr<armnn::Encoder<float>> outputEncoder = armnn::MakeEncoder<float>(tensorInfo, input.data());
75 auto result =
CompareTensors(input, expectedOutput, expectedShape, expectedShape);
76 CHECK_MESSAGE(result.m_Result, result.m_Message.str());
// Sanity check: the encoder must still alias input after the operation.
79 outputEncoder->Set(1.0f);
80 CHECK(input[0] == 1.0f);
// Checks the LstmUtils MeanStddevNormalization helper: input is normalized in place
// (the invocation is in lines elided from this extraction — embedded numbering jumps
// 98 -> 103) and compared against expectedOutput.
// NOTE(review): code kept verbatim; closing brace elided.
84 template<armnn::DataType ArmnnType,
typename T = armnn::ResolveType<ArmnnType>>
85 void LstmUtilsMeanStddevNormalizationTestImpl(
86 std::vector<float>& input,
89 std::vector<float>& expectedOutput,
// Decoder reads input; encoder writes the normalized values back over the same buffer.
97 std::unique_ptr<armnn::Decoder<float>> inputDecoder = armnn::MakeDecoder<float>(tensorInfo, input.data());
98 std::unique_ptr<armnn::Encoder<float>> outputEncoder = armnn::MakeEncoder<float>(tensorInfo, input.data());
103 auto result =
CompareTensors(input, expectedOutput, expectedShape, expectedShape);
104 CHECK_MESSAGE(result.m_Result, result.m_Message.str());
// Sanity check: the encoder must still alias input after the operation.
107 outputEncoder->Set(1.0f);
108 CHECK(input[0] == 1.0f);
// Checks the LstmUtils VectorBatchVectorCwiseProduct helper: batchVec is multiplied
// element-wise by vec in place (the invocation sits in lines elided from this extraction —
// embedded numbering jumps 127 -> 132) and compared against expectedOutput.
// NOTE(review): code kept verbatim; closing brace elided.
111 template<armnn::DataType ArmnnType,
typename T = armnn::ResolveType<ArmnnType>>
112 void LstmUtilsVectorBatchVectorCwiseProductTestImpl(
113 std::vector<float>& vec,
114 std::vector<float>& batchVec,
117 std::vector<float>& expectedOutput,
// Decoders read vec/batchVec; the encoder writes results back into batchVec in place.
125 std::unique_ptr<armnn::Decoder<float>> vecDecoder = armnn::MakeDecoder<float>(tensorInfo, vec.data());
126 std::unique_ptr<armnn::Decoder<float>> batchVecDecoder = armnn::MakeDecoder<float>(tensorInfo, batchVec.data());
127 std::unique_ptr<armnn::Encoder<float>> batchVecEncoder = armnn::MakeEncoder<float>(tensorInfo, batchVec.data());
132 auto result =
CompareTensors(batchVec, expectedOutput, expectedShape, expectedShape);
133 CHECK_MESSAGE(result.m_Result, result.m_Message.str());
// Sanity check: the encoder must still alias batchVec after the operation.
136 batchVecEncoder->Set(1.0f);
137 CHECK(batchVec[0] == 1.0f);
// Reference test for an LSTM workload configured without CIFG, peephole, or projection.
// NOTE(review): heavily elided extraction — per the embedded numbering, lines 145-147,
// 150-160, 162, 199-202, and crucially 269-317 (weight upload / descriptor setup /
// workload creation & execution) plus the closing lines 339-340 are missing.
// Code below is preserved verbatim; comments only.
142 template<armnn::DataType ArmnnType,
typename T = armnn::ResolveType<ArmnnType>>
144 LstmNoCifgNoPeepholeNoProjectionTestImpl(
148 const std::vector<T>& input,
149 const std::vector<T>& outputExpected,
// In this configuration the cell size equals the output size.
161 unsigned numUnits = outputSize;
// Tensor descriptions for the three LSTM inputs and four LSTM outputs.
// The scratch buffer is 4*numUnits wide — presumably one slice per gate; verify upstream.
163 armnn::TensorInfo inputTensorInfo({batchSize , inputSize}, ArmnnType, qScale, qOffset );
164 armnn::TensorInfo cellStateInTensorInfo({batchSize , numUnits}, ArmnnType, qScale, qOffset);
165 armnn::TensorInfo outputStateInTensorInfo({batchSize , outputSize}, ArmnnType, qScale, qOffset);
167 armnn::TensorInfo scratchBufferTensorInfo({batchSize, numUnits * 4}, ArmnnType, qScale, qOffset);
168 armnn::TensorInfo cellStateOutTensorInfo({batchSize, numUnits}, ArmnnType, qScale, qOffset);
169 armnn::TensorInfo outputStateOutTensorInfo({batchSize, outputSize}, ArmnnType, qScale, qOffset);
170 armnn::TensorInfo outputTensorInfo({batchSize, outputSize}, ArmnnType, qScale, qOffset);
172 std::vector<T> inputVector;
173 inputVector.assign(input.data(), input.data() + (batchSize * inputSize));
// Initial cell/output state and all output buffers start zero-initialised (T()).
175 std::vector<T> cellStateInVector(batchSize * numUnits, T());
176 std::vector<T> outputStateInVector(batchSize * outputSize, T());
177 std::vector<T> scratchBufferVector(batchSize * numUnits * 4, T());
178 std::vector<T> outputStateOutVector(batchSize * outputSize, T());
179 std::vector<T> cellStateOutVector(batchSize * numUnits, T());
181 std::vector<T> actualOutput(outputTensorInfo.GetNumElements());
183 std::vector<T> outputVector;
184 outputVector.assign(outputExpected.data(), outputExpected.data() + (batchSize * outputSize));
// Backend tensor handles (creation expressions for most handles are in elided lines).
186 std::unique_ptr<armnn::ITensorHandle> inputHandle = tensorHandleFactory.
CreateTensorHandle(inputTensorInfo);
187 std::unique_ptr<armnn::ITensorHandle> cellStateInHandle =
189 std::unique_ptr<armnn::ITensorHandle> outputStateInHandle =
192 std::unique_ptr<armnn::ITensorHandle> scratchHandle =
194 std::unique_ptr<armnn::ITensorHandle> outputStateOutHandle =
196 std::unique_ptr<armnn::ITensorHandle> cellStateOutHandle =
198 std::unique_ptr<armnn::ITensorHandle> outputHandle = tensorHandleFactory.
CreateTensorHandle(outputTensorInfo);
// Wire the 3 inputs and 4 outputs into the queue descriptor / workload info,
// in the order the LSTM workload expects.
203 AddInputToWorkload(data, info, inputTensorInfo, inputHandle.get());
204 AddInputToWorkload(data, info, outputStateInTensorInfo, outputStateInHandle.get());
205 AddInputToWorkload(data, info, cellStateInTensorInfo, cellStateInHandle.get());
207 AddOutputToWorkload(data, info, scratchBufferTensorInfo, scratchHandle.get());
208 AddOutputToWorkload(data, info, outputStateOutTensorInfo, outputStateOutHandle.get());
209 AddOutputToWorkload(data, info, cellStateOutTensorInfo, cellStateOutHandle.get());
210 AddOutputToWorkload(data, info, outputTensorInfo, outputHandle.get());
214 armnn::TensorInfo tensorInfo16({numUnits, 4}, constantDataType, qScale, qOffset);
// Hard-coded reference weights/biases for the 4-unit LSTM (several constant lists are
// truncated by the extraction — e.g. inputToCellWeights ends mid-list).
216 std::vector<float> inputToInputWeights = {-0.45018822f, -0.02338299f, -0.0870589f,
217 -0.34550029f, 0.04266912f, -0.15680569f,
218 -0.34856534f, 0.43890524f};
220 std::vector<float> inputToForgetWeights = { 0.09701663f, 0.20334584f, -0.50592935f,
221 -0.31343272f, -0.40032279f, 0.44781327f,
222 0.01387155f, -0.35593212f};
224 std::vector<float> inputToCellWeights = { -0.50013041f, 0.1370284f, 0.11810488f, 0.2013163f,
225 -0.20583314f, 0.44344562f, 0.22077113f,
228 std::vector<float> inputToOutputWeights = { -0.25065863f, -0.28290087f, 0.04613829f,
229 0.40525138f, 0.44272184f, 0.03897077f,
230 -0.1556896f, 0.19487578f};
232 std::vector<float> recurrentToInputWeights = {-0.0063535f, -0.2042388f, 0.31454784f,
233 -0.35746509f, 0.28902304f, 0.08183324f,
234 -0.16555229f, 0.02286911f, -0.13566875f,
235 0.03034258f, 0.48091322f, -0.12528998f,
236 0.24077177f, -0.51332325f, -0.33502164f,
239 std::vector<float> recurrentToForgetWeights = { -0.48684245f, -0.06655136f, 0.42224967f,
240 0.2112639f, 0.27654213f, 0.20864892f,
241 -0.07646349f, 0.45877004f, 0.00141793f,
242 -0.14609534f, 0.36447752f, 0.09196436f,
243 0.28053468f, 0.01560611f, -0.20127171f,
246 std::vector<float> recurrentToCellWeights = { -0.3407414f, 0.24443203f, -0.2078532f,
247 0.26320225f, 0.05695659f, -0.00123841f,
248 -0.4744786f, -0.35869038f, -0.06418842f,
249 -0.13502428f, -0.501764f, 0.22830659f,
250 -0.46367589f, 0.26016325f, -0.03894562f,
253 std::vector<float> recurrentToOutputWeights = { 0.43385774f, -0.17194885f, 0.2718237f,
254 0.09215671f, 0.24107647f, -0.39835793f,
255 0.18212086f, 0.01301402f, 0.48572797f,
256 -0.50656658f, 0.20047462f, -0.20607421f,
257 -0.51818722f, -0.15390486f, 0.0468148f,
260 std::vector<float> cellToInputWeights = {0., 0., 0., 0.};
262 std::vector<float> inputGateBias = {0., 0., 0., 0.};
// Forget-gate bias of 1.0 is the only non-zero bias here.
264 std::vector<float> forgetGateBias = {1., 1., 1., 1.};
266 std::vector<float> cellBias = {0., 0., 0., 0.};
268 std::vector<float> outputGateBias = {0., 0., 0., 0.};
// Allocate backing memory for every handle. The descriptor population, workload
// creation, and Execute() call occupied the elided lines 269-317 — TODO confirm.
318 inputHandle->Allocate();
319 outputStateInHandle->Allocate();
320 cellStateInHandle->Allocate();
322 scratchHandle->Allocate();
323 outputStateOutHandle->Allocate();
324 cellStateOutHandle->Allocate();
325 outputHandle->Allocate();
// Tail of the result construction/comparison (its head is elided; the function's
// return and closing brace — original lines 339-340 — are also elided).
337 outputHandle->GetShape(),
338 outputTensorInfo.GetShape());
341 template<armnn::DataType ArmnnType,
typename T = armnn::ResolveType<ArmnnType>>
346 const std::vector<T>& input,
347 const std::vector<T>& outputExpected,
353 unsigned int batchSize = 2;
354 unsigned int outputSize = 16;
355 unsigned int inputSize = 5;
356 unsigned numUnits = 20;
358 armnn::TensorInfo inputTensorInfo({batchSize , inputSize}, ArmnnType, qScale, qOffset);
359 armnn::TensorInfo cellStateInTensorInfo({batchSize , numUnits}, ArmnnType, qScale, qOffset);
360 armnn::TensorInfo outputStateInTensorInfo({batchSize , outputSize}, ArmnnType, qScale, qOffset);
363 armnn::TensorInfo scratchBufferTensorInfo({batchSize, numUnits * 4}, ArmnnType, qScale, qOffset);
364 armnn::TensorInfo cellStateOutTensorInfo({batchSize, numUnits}, ArmnnType, qScale, qOffset);
365 armnn::TensorInfo outputStateOutTensorInfo({batchSize, outputSize}, ArmnnType, qScale, qOffset);
366 armnn::TensorInfo outputTensorInfo({batchSize, outputSize}, ArmnnType, qScale, qOffset);
368 std::vector<T> inputVector;
369 inputVector.assign(input.data(), input.data() + (batchSize * inputSize));
371 std::vector<T> cellStateInVector(batchSize * numUnits, T());
372 std::vector<T> outputStateInVector(batchSize * outputSize, T());
373 std::vector<T> scratchBufferVector(batchSize * numUnits * 4, T());
374 std::vector<T> outputStateOutVector(batchSize * outputSize, T());
375 std::vector<T> cellStateOutVector(batchSize * numUnits, T());
377 std::vector<T> actualOutput(outputTensorInfo.GetNumElements());
379 std::vector<T> outputVector;
380 outputVector.assign(outputExpected.data(), outputExpected.data() + (batchSize * outputSize));
382 std::unique_ptr<armnn::ITensorHandle> inputHandle = tensorHandleFactory.
CreateTensorHandle(inputTensorInfo);
383 std::unique_ptr<armnn::ITensorHandle> cellStateInHandle =
385 std::unique_ptr<armnn::ITensorHandle> outputStateInHandle =
388 std::unique_ptr<armnn::ITensorHandle> scratchHandle =
390 std::unique_ptr<armnn::ITensorHandle> outputStateOutHandle =
392 std::unique_ptr<armnn::ITensorHandle> cellStateOutHandle =
394 std::unique_ptr<armnn::ITensorHandle> outputHandle = tensorHandleFactory.
CreateTensorHandle(outputTensorInfo);
399 AddInputToWorkload(data, info, inputTensorInfo, inputHandle.get());
400 AddInputToWorkload(data, info, outputStateInTensorInfo, outputStateInHandle.get());
401 AddInputToWorkload(data, info, cellStateInTensorInfo, cellStateInHandle.get());
403 AddOutputToWorkload(data, info, scratchBufferTensorInfo, scratchHandle.get());
404 AddOutputToWorkload(data, info, outputStateOutTensorInfo, outputStateOutHandle.get());
405 AddOutputToWorkload(data, info, cellStateOutTensorInfo, cellStateOutHandle.get());
406 AddOutputToWorkload(data, info, outputTensorInfo, outputHandle.get());
410 armnn::TensorInfo tensorInfo20x5({numUnits, inputSize}, constantDataType, qScale, qOffset);
411 armnn::TensorInfo tensorInfo20x16({numUnits, outputSize}, constantDataType, qScale, qOffset);
412 armnn::TensorInfo tensorInfo16x20({outputSize, numUnits}, constantDataType, qScale, qOffset);
414 std::vector<float> inputToInputWeights = {0.021393683f,0.06124551f, 0.046905167f,-0.014657677f,-0.03149463f,
415 0.09171803f, 0.14647801f,0.10797193f, -0.0057968358f,0.0019193048f,
416 -0.2726754f, 0.10154029f, -0.018539885f, 0.080349885f, -0.10262385f,
417 -0.022599787f,-0.09121155f, -0.008675967f, -0.045206103f,-0.0821282f,
418 -0.008045952f,0.015478081f, 0.055217247f, 0.038719587f, 0.044153627f,
419 -0.06453243f,0.05031825f, -0.046935108f, -0.008164439f, 0.014574226f,
420 -0.1671009f, -0.15519552f, -0.16819797f,-0.13971269f,-0.11953059f,
421 0.25005487f, -0.22790983f, 0.009855087f, -0.028140958f, -0.11200698f,
422 0.11295408f, -0.0035217577f, 0.054485075f, 0.05184695f, 0.064711206f,
423 0.10989193f, 0.11674786f, 0.03490607f, 0.07727357f, 0.11390585f,
424 -0.1863375f, -0.1034451f, -0.13945189f, -0.049401227f, -0.18767063f,
425 0.042483903f, 0.14233552f, 0.13832581f, 0.18350165f, 0.14545603f,
426 -0.028545704f,0.024939531f,0.050929718f,0.0076203286f,-0.0029723682f,
427 -0.042484224f, -0.11827596f, -0.09171104f, -0.10808628f,-0.16327988f,
428 -0.2273378f, -0.0993647f, -0.017155107f,0.0023917493f,0.049272764f,
429 0.0038534778f, 0.054764505f, 0.089753784f, 0.06947234f, 0.08014476f,
430 -0.04544234f, -0.0497073f,-0.07135631f, -0.048929106f,-0.004042012f,
431 -0.009284026f, 0.018042054f, 0.0036860977f,-0.07427302f, -0.11434604f,
432 -0.018995456f, 0.031487543f, 0.012834908f,0.019977754f,0.044256654f,
433 -0.39292613f, -0.18519334f, -0.11651281f,-0.06809892f, 0.011373677f };
435 std::vector<float> inputToForgetWeights = {-0.0018401089f, -0.004852237f,0.03698424f, 0.014181704f,0.028273236f,
436 -0.016726194f, -0.05249759f,-0.10204261f, 0.00861066f,-0.040979505f,
437 -0.009899187f,0.01923892f,-0.028177269f, -0.08535103f,-0.14585495f,
438 0.10662567f,-0.01909731f,-0.017883534f,-0.0047269356f,-0.045103323f,
439 0.0030784295f,0.076784775f,0.07463696f, 0.094531395f,0.0814421f,
440 -0.12257899f, -0.033945758f,-0.031303465f, 0.045630626f,0.06843887f,
441 -0.13492945f, -0.012480007f,-0.0811829f, -0.07224499f,-0.09628791f,
442 0.045100946f,0.0012300825f, 0.013964662f, 0.099372394f,0.02543059f,
443 0.06958324f, 0.034257296f, 0.0482646f, 0.06267997f,0.052625068f,
444 0.12784666f, 0.07077897f, 0.025725935f, 0.04165009f,0.07241905f,
445 0.018668644f, -0.037377294f,-0.06277783f,-0.08833636f,-0.040120605f,
446 -0.011405586f,-0.007808335f,-0.010301386f,-0.005102167f,0.027717464f,
447 0.05483423f, 0.11449111f, 0.11289652f,0.10939839f, 0.13396506f,
448 -0.08402166f,-0.01901462f, -0.044678304f,-0.07720565f,0.014350063f,
449 -0.11757958f, -0.0652038f, -0.08185733f,-0.076754324f,-0.092614375f,
450 0.10405491f, 0.052960336f, 0.035755895f,0.035839386f,-0.012540553f,
451 0.036881298f, 0.02913376f, 0.03420159f,0.05448447f,-0.054523353f,
452 0.02582715f, 0.02327355f, -0.011857179f,-0.0011980024f,-0.034641717f,
453 -0.026125094f,-0.17582615f,-0.15923657f,-0.27486774f,-0.0006143371f,
454 0.0001771948f, -8.470171e-05f, 0.02651807f,0.045790765f,0.06956496f };
456 std::vector<float> inputToCellWeights = { -0.04580283f, -0.09549462f, -0.032418985f, -0.06454633f,
457 -0.043528453f, 0.043018587f, -0.049152344f, -0.12418144f,
458 -0.078985475f, -0.07596889f, 0.019484362f, -0.11434962f,
459 -0.0074034138f, -0.06314844f, -0.092981495f, 0.0062155537f,
460 -0.025034338f, -0.0028890965f, 0.048929527f, 0.06235075f,
461 0.10665918f, -0.032036792f, -0.08505916f, -0.10843358f,
462 -0.13002433f, -0.036816437f, -0.02130134f, -0.016518239f,
463 0.0047691227f, -0.0025825808f, 0.066017866f, 0.029991534f,
464 -0.10652836f, -0.1037554f, -0.13056071f, -0.03266643f,
465 -0.033702414f, -0.006473424f, -0.04611692f, 0.014419339f,
466 -0.025174323f, 0.0396852f, 0.081777506f, 0.06157468f,
467 0.10210095f, -0.009658194f, 0.046511717f, 0.03603906f,
468 0.0069369148f, 0.015960095f, -0.06507666f, 0.09551598f,
469 0.053568836f, 0.06408714f, 0.12835667f, -0.008714329f,
470 -0.20211966f, -0.12093674f, 0.029450472f, 0.2849013f,
471 -0.029227901f, 0.1164364f, -0.08560263f, 0.09941786f,
472 -0.036999565f, -0.028842626f, -0.0033637602f, -0.017012902f,
473 -0.09720865f, -0.11193351f, -0.029155117f, -0.017936034f,
474 -0.009768936f, -0.04223324f, -0.036159635f, 0.06505112f,
475 -0.021742892f, -0.023377212f, -0.07221364f, -0.06430552f,
476 0.05453865f, 0.091149814f, 0.06387331f, 0.007518393f,
477 0.055960953f, 0.069779344f, 0.046411168f, 0.10509911f,
478 0.07463894f, 0.0075130584f, 0.012850982f, 0.04555431f,
479 0.056955688f, 0.06555285f, 0.050801456f, -0.009862683f,
480 0.00826772f, -0.026555609f, -0.0073611983f, -0.0014897042f };
482 std::vector<float> inputToOutputWeights ={-0.0998932f, -0.07201956f, -0.052803773f,-0.15629593f,-0.15001918f,
483 -0.07650751f,0.02359855f, -0.075155355f, -0.08037709f, -0.15093534f,
484 0.029517552f, -0.04751393f, 0.010350531f,-0.02664851f, -0.016839722f,
485 -0.023121163f, 0.0077019283f, 0.012851257f, -0.05040649f,-0.0129761f,
486 -0.021737747f,-0.038305793f,-0.06870586f, -0.01481247f,-0.001285394f,
487 0.10124236f, 0.083122835f, 0.053313006f,-0.062235646f,-0.075637154f,
488 -0.027833903f, 0.029774971f, 0.1130802f, 0.09218906f, 0.09506135f,
489 -0.086665764f,-0.037162706f,-0.038880914f,-0.035832845f,-0.014481564f,
490 -0.09825003f,-0.12048569f,-0.097665586f,-0.05287633f, -0.0964047f,
491 -0.11366429f, 0.035777505f, 0.13568819f, 0.052451383f,0.050649304f,
492 0.05798951f, -0.021852335f,-0.099848844f,0.014740475f,-0.078897946f,
493 0.04974699f, 0.014160473f, 0.06973932f, 0.04964942f, 0.033364646f,
494 0.08190124f, 0.025535367f, 0.050893165f, 0.048514254f,0.06945813f,
495 -0.078907564f,-0.06707616f, -0.11844508f, -0.09986688f,-0.07509403f,
496 0.06263226f, 0.14925587f, 0.20188436f, 0.12098451f,0.14639415f,
497 0.0015017595f, -0.014267382f, -0.03417257f,0.012711468f,0.0028300495f,
498 -0.024758482f, -0.05098548f,-0.0821182f, 0.014225672f, 0.021544158f,
499 0.08949725f, 0.07505268f, -0.0020780868f, 0.04908258f,0.06476295f,
500 -0.022907063f,0.027562456f,0.040185735f, 0.019567577f,-0.015598739f,
501 -0.049097303f, -0.017121866f, -0.083368234f,-0.02332002f,-0.0840956f };
503 std::vector<float> inputGateBias = {0.02234832f, 0.14757581f, 0.18176508f, 0.10380666f, 0.053110216f,
504 -0.06928846f, -0.13942584f, -0.11816189f, 0.19483899f, 0.03652339f,
505 -0.10250295f, 0.036714908f, -0.18426876f, 0.036065217f, 0.21810818f,
506 0.02383196f, -0.043370757f, 0.08690144f, -0.04444982f, 0.00030581196f };
508 std::vector<float> forgetGateBias ={0.035185695f, -0.042891346f, -0.03032477f, 0.23027696f,
509 0.11098921f, 0.15378423f, 0.09263801f, 0.09790885f,
510 0.09508917f, 0.061199076f, 0.07665568f, -0.015443159f,
511 -0.03499149f, 0.046190713f, 0.08895977f, 0.10899629f,
512 0.40694186f, 0.06030037f, 0.012413437f, -0.06108739f };
514 std::vector<float> cellBias = { -0.024379363f, 0.0055531194f, 0.23377132f, 0.033463873f,
515 -0.1483596f, -0.10639995f, -0.091433935f, 0.058573797f,
516 -0.06809782f, -0.07889636f, -0.043246906f, -0.09829136f,
517 -0.4279842f, 0.034901652f, 0.18797937f, 0.0075234566f,
518 0.016178843f, 0.1749513f, 0.13975595f, 0.92058027f };
520 std::vector<float> outputGateBias ={0.046159424f, -0.0012809046f, 0.03563469f, 0.12648113f, 0.027195795f,
521 0.35373217f, -0.018957434f, 0.008907322f, -0.0762701f, 0.12018895f,
522 0.04216877f, 0.0022856654f, 0.040952638f, 0.3147856f, 0.08225149f,
523 -0.057416286f, -0.14995944f, -0.008040261f, 0.13208859f, 0.029760877f};
525 std::vector<float> recurrentToInputWeights = { -0.001374326f, -0.078856036f, 0.10672688f, 0.029162422f,
526 -0.11585556f, 0.02557986f, -0.13446963f, -0.035785314f,
527 -0.01244275f, 0.025961924f, -0.02337298f, -0.044228926f,
528 -0.055839065f, -0.046598054f, -0.010546039f, -0.06900766f,
529 0.027239809f, 0.022582639f, -0.013296484f, -0.05459212f,
530 0.08981f, -0.045407712f, 0.08682226f, -0.06867011f,
531 -0.14390695f, -0.02916037f, 0.000996957f, 0.091420636f,
532 0.14283475f, -0.07390571f, -0.06402044f, 0.062524505f,
533 -0.093129106f, 0.04860203f, -0.08364217f, -0.08119002f,
534 0.009352075f, 0.22920375f, 0.0016303885f, 0.11583097f,
535 -0.13732095f, 0.012405723f, -0.07551853f, 0.06343048f,
536 0.12162708f, -0.031923793f, -0.014335606f, 0.01790974f,
537 -0.10650317f, -0.0724401f, 0.08554849f, -0.05727212f,
538 0.06556731f, -0.042729504f, -0.043227166f, 0.011683251f,
539 -0.013082158f, -0.029302018f, -0.010899579f, -0.062036745f,
540 -0.022509435f, -0.00964907f, -0.01567329f, 0.04260106f,
541 -0.07787477f, -0.11576462f, 0.017356863f, 0.048673786f,
542 -0.017577527f, -0.05527947f, -0.082487635f, -0.040137455f,
543 -0.10820036f, -0.04666372f, 0.022746278f, -0.07851417f,
544 0.01068115f, 0.032956902f, 0.022433773f, 0.0026891115f,
545 0.08944216f, -0.0685835f, 0.010513544f, 0.07228705f,
546 0.02032331f, -0.059686817f, -0.0005566496f, -0.086984694f,
547 0.040414046f, -0.1380399f, 0.094208956f, -0.05722982f,
548 0.012092817f, -0.04989123f, -0.086576f, -0.003399834f,
549 -0.04696032f, -0.045747425f, 0.10091314f, 0.048676282f,
550 -0.029037097f, 0.031399418f, -0.0040285117f, 0.047237843f,
551 0.09504992f, 0.041799378f, -0.049185462f, -0.031518843f,
552 -0.10516937f, 0.026374253f, 0.10058866f, -0.0033195973f,
553 -0.041975245f, 0.0073591834f, 0.0033782164f, -0.004325073f,
554 -0.10167381f, 0.042500053f, -0.01447153f, 0.06464186f,
555 -0.017142897f, 0.03312627f, 0.009205989f, 0.024138335f,
556 -0.011337001f, 0.035530265f, -0.010912711f, 0.0706555f,
557 -0.005894094f, 0.051841937f, -0.1401738f, -0.02351249f,
558 0.0365468f, 0.07590991f, 0.08838724f, 0.021681072f,
559 -0.10086113f, 0.019608743f, -0.06195883f, 0.077335775f,
560 0.023646897f, -0.095322326f, 0.02233014f, 0.09756986f,
561 -0.048691444f, -0.009579111f, 0.07595467f, 0.11480546f,
562 -0.09801813f, 0.019894179f, 0.08502348f, 0.004032281f,
563 0.037211012f, 0.068537936f, -0.048005626f, -0.091520436f,
564 -0.028379958f, -0.01556313f, 0.06554592f, -0.045599163f,
565 -0.01672207f, -0.020169014f, -0.011877351f, -0.20212261f,
566 0.010889619f, 0.0047078193f, 0.038385306f, 0.08540671f,
567 -0.017140968f, -0.0035865551f, 0.016678626f, 0.005633034f,
568 0.015963363f, 0.00871737f, 0.060130805f, 0.028611384f,
569 0.10109069f, -0.015060172f, -0.07894427f, 0.06401885f,
570 0.011584063f, -0.024466386f, 0.0047652307f, -0.09041358f,
571 0.030737216f, -0.0046374933f, 0.14215417f, -0.11823516f,
572 0.019899689f, 0.006106124f, -0.027092824f, 0.0786356f,
573 0.05052217f, -0.058925f, -0.011402121f, -0.024987547f,
574 -0.0013661642f, -0.06832946f, -0.015667673f, -0.1083353f,
575 -0.00096863037f, -0.06988685f, -0.053350925f, -0.027275559f,
576 -0.033664223f, -0.07978348f, -0.025200296f, -0.017207067f,
577 -0.058403496f, -0.055697463f, 0.005798788f, 0.12965427f,
578 -0.062582195f, 0.0013350133f, -0.10482091f, 0.0379771f,
579 0.072521195f, -0.0029455067f, -0.13797039f, -0.03628521f,
580 0.013806405f, -0.017858358f, -0.01008298f, -0.07700066f,
581 -0.017081132f, 0.019358726f, 0.0027079724f, 0.004635139f,
582 0.062634714f, -0.02338735f, -0.039547626f, -0.02050681f,
583 0.03385117f, -0.083611414f, 0.002862572f, -0.09421313f,
584 0.058618143f, -0.08598433f, 0.00972939f, 0.023867095f,
585 -0.053934585f, -0.023203006f, 0.07452513f, -0.048767887f,
586 -0.07314807f, -0.056307215f, -0.10433547f, -0.06440842f,
587 0.04328182f, 0.04389765f, -0.020006588f, -0.09076438f,
588 -0.11652589f, -0.021705797f, 0.03345259f, -0.010329105f,
589 -0.025767034f, 0.013057034f, -0.07316461f, -0.10145612f,
590 0.06358255f, 0.18531723f, 0.07759293f, 0.12006465f,
591 0.1305557f, 0.058638252f, -0.03393652f, 0.09622831f,
592 -0.16253184f, -2.4580743e-06f, 0.079869635f, -0.070196845f,
593 -0.005644518f, 0.06857898f, -0.12598175f, -0.035084512f,
594 0.03156317f, -0.12794146f, -0.031963028f, 0.04692781f,
595 0.030070418f, 0.0071660685f, -0.095516115f, -0.004643372f,
596 0.040170413f, -0.062104587f, -0.0037324072f, 0.0554317f,
597 0.08184801f, -0.019164372f, 0.06791302f, 0.034257166f,
598 -0.10307039f, 0.021943003f, 0.046745934f, 0.0790918f,
599 -0.0265588f, -0.007824208f, 0.042546265f, -0.00977924f,
600 -0.0002440307f, -0.017384544f, -0.017990116f, 0.12252321f,
601 -0.014512694f, -0.08251313f, 0.08861942f, 0.13589665f,
602 0.026351685f, 0.012641483f, 0.07466548f, 0.044301085f,
603 -0.045414884f, -0.051112458f, 0.03444247f, -0.08502782f,
604 -0.04106223f, -0.028126027f, 0.028473156f, 0.10467447f };
606 std::vector<float> recurrentToForgetWeights = {-0.057784554f, -0.026057621f, -0.068447545f, -0.022581743f,
607 0.14811787f, 0.10826372f, 0.09471067f, 0.03987225f,
608 -0.0039523416f, 0.00030638507f, 0.053185795f, 0.10572994f,
609 0.08414449f, -0.022036452f, -0.00066928595f, -0.09203576f,
610 0.032950465f, -0.10985798f, -0.023809856f, 0.0021431844f,
611 -0.02196096f, -0.00326074f, 0.00058621005f, -0.074678116f,
612 -0.06193199f, 0.055729095f, 0.03736828f, 0.020123724f,
613 0.061878487f, -0.04729229f, 0.034919553f, -0.07585433f,
614 -0.04421272f, -0.044019096f, 0.085488975f, 0.04058006f,
615 -0.06890133f, -0.030951202f, -0.024628663f, -0.07672815f,
616 0.034293607f, 0.08556707f, -0.05293577f, -0.033561368f,
617 -0.04899627f, 0.0241671f, 0.015736353f, -0.095442444f,
618 -0.029564252f, 0.016493602f, -0.035026584f, 0.022337519f,
619 -0.026871363f, 0.004780428f, 0.0077918363f, -0.03601621f,
620 0.016435321f, -0.03263031f, -0.09543275f, -0.047392778f,
621 0.013454138f, 0.028934088f, 0.01685226f, -0.086110644f,
622 -0.046250615f, -0.01847454f, 0.047608484f, 0.07339695f,
623 0.034546845f, -0.04881143f, 0.009128804f, -0.08802852f,
624 0.03761666f, 0.008096139f, -0.014454086f, 0.014361001f,
625 -0.023502491f, -0.0011840804f, -0.07607001f, 0.001856849f,
626 -0.06509276f, -0.006021153f, -0.08570962f, -0.1451793f,
627 0.060212336f, 0.055259194f, 0.06974018f, 0.049454916f,
628 -0.027794661f, -0.08077226f, -0.016179763f, 0.1169753f,
629 0.17213494f, -0.0056326236f, -0.053934924f, -0.0124349f,
630 -0.11520337f, 0.05409887f, 0.088759385f, 0.0019655675f,
631 0.0042065294f, 0.03881498f, 0.019844765f, 0.041858196f,
632 -0.05695512f, 0.047233116f, 0.038937137f, -0.06542224f,
633 0.014429736f, -0.09719407f, 0.13908425f, -0.05379757f,
634 0.012321099f, 0.082840554f, -0.029899208f, 0.044217527f,
635 0.059855383f, 0.07711018f, -0.045319796f, 0.0948846f,
636 -0.011724666f, -0.0033288454f, -0.033542685f, -0.04764985f,
637 -0.13873616f, 0.040668588f, 0.034832682f, -0.015319203f,
638 -0.018715994f, 0.046002675f, 0.0599172f, -0.043107376f,
639 0.0294216f, -0.002314414f, -0.022424703f, 0.0030315618f,
640 0.0014641669f, 0.0029166266f, -0.11878115f, 0.013738511f,
641 0.12375372f, -0.0006038222f, 0.029104086f, 0.087442465f,
642 0.052958444f, 0.07558703f, 0.04817258f, 0.044462286f,
643 -0.015213451f, -0.08783778f, -0.0561384f, -0.003008196f,
644 0.047060397f, -0.002058388f, 0.03429439f, -0.018839769f,
645 0.024734668f, 0.024614193f, -0.042046934f, 0.09597743f,
646 -0.0043254104f, 0.04320769f, 0.0064070094f, -0.0019131786f,
647 -0.02558259f, -0.022822596f, -0.023273505f, -0.02464396f,
648 -0.10991725f, -0.006240552f, 0.0074488563f, 0.024044557f,
649 0.04383914f, -0.046476185f, 0.028658995f, 0.060410924f,
650 0.050786525f, 0.009452605f, -0.0073054377f, -0.024810238f,
651 0.0052906186f, 0.0066939713f, -0.0020913032f, 0.014515517f,
652 0.015898481f, 0.021362653f, -0.030262267f, 0.016587038f,
653 -0.011442813f, 0.041154444f, -0.007631438f, -0.03423484f,
654 -0.010977775f, 0.036152758f, 0.0066366293f, 0.11915515f,
655 0.02318443f, -0.041350313f, 0.021485701f, -0.10906167f,
656 -0.028218046f, -0.00954771f, 0.020531068f, -0.11995105f,
657 -0.03672871f, 0.024019798f, 0.014255957f, -0.05221243f,
658 -0.00661567f, -0.04630967f, 0.033188973f, 0.10107534f,
659 -0.014027541f, 0.030796422f, -0.10270911f, -0.035999842f,
660 0.15443139f, 0.07684145f, 0.036571592f, -0.035900835f,
661 -0.0034699554f, 0.06209149f, 0.015920248f, -0.031122351f,
662 -0.03858649f, 0.01849943f, 0.13872518f, 0.01503974f,
663 0.069941424f, -0.06948533f, -0.0088794185f, 0.061282158f,
664 -0.047401894f, 0.03100163f, -0.041533746f, -0.10430945f,
665 0.044574402f, -0.01425562f, -0.024290353f, 0.034563623f,
666 0.05866852f, 0.023947537f, -0.09445152f, 0.035450947f,
667 0.02247216f, -0.0042998926f, 0.061146557f, -0.10250651f,
668 0.020881841f, -0.06747029f, 0.10062043f, -0.0023941975f,
669 0.03532124f, -0.016341697f, 0.09685456f, -0.016764693f,
670 0.051808182f, 0.05875331f, -0.04536488f, 0.001626336f,
671 -0.028892258f, -0.01048663f, -0.009793449f, -0.017093895f,
672 0.010987891f, 0.02357273f, -0.00010856845f, 0.0099760275f,
673 -0.001845119f, -0.03551521f, 0.0018358806f, 0.05763657f,
674 -0.01769146f, 0.040995963f, 0.02235177f, -0.060430344f,
675 0.11475477f, -0.023854522f, 0.10071741f, 0.0686208f,
676 -0.014250481f, 0.034261297f, 0.047418304f, 0.08562733f,
677 -0.030519066f, 0.0060542435f, 0.014653856f, -0.038836084f,
678 0.04096551f, 0.032249358f, -0.08355519f, -0.026823482f,
679 0.056386515f, -0.010401743f, -0.028396193f, 0.08507674f,
680 0.014410365f, 0.020995233f, 0.17040324f, 0.11511526f,
681 0.02459721f, 0.0066619175f, 0.025853224f, -0.023133837f,
682 -0.081302024f, 0.017264642f, -0.009585969f, 0.09491168f,
683 -0.051313367f, 0.054532815f, -0.014298593f, 0.10657464f,
684 0.007076659f, 0.10964551f, 0.0409152f, 0.008275321f,
685 -0.07283536f, 0.07937492f, 0.04192024f, -0.1075027f };
687 std::vector<float> recurrentToCellWeights = { -0.037322544f, 0.018592842f, 0.0056175636f, -0.06253426f,
688 0.055647098f, -0.05713207f, -0.05626563f, 0.005559383f,
689 0.03375411f, -0.025757805f, -0.088049285f, 0.06017052f,
690 -0.06570978f, 0.007384076f, 0.035123326f, -0.07920549f,
691 0.053676967f, 0.044480428f, -0.07663568f, 0.0071805613f,
692 0.08089997f, 0.05143358f, 0.038261272f, 0.03339287f,
693 -0.027673481f, 0.044746667f, 0.028349208f, 0.020090483f,
694 -0.019443132f, -0.030755889f, -0.0040000007f, 0.04465846f,
695 -0.021585021f, 0.0031670958f, 0.0053199246f, -0.056117613f,
696 -0.10893326f, 0.076739706f, -0.08509834f, -0.027997585f,
697 0.037871376f, 0.01449768f, -0.09002357f, -0.06111149f,
698 -0.046195522f, 0.0422062f, -0.005683705f, -0.1253618f,
699 -0.012925729f, -0.04890792f, 0.06985068f, 0.037654128f,
700 0.03398274f, -0.004781977f, 0.007032333f, -0.031787455f,
701 0.010868644f, -0.031489216f, 0.09525667f, 0.013939797f,
702 0.0058680447f, 0.0167067f, 0.02668468f, -0.04797466f,
703 -0.048885044f, -0.12722108f, 0.035304096f, 0.06554885f,
704 0.00972396f, -0.039238118f, -0.05159735f, -0.11329045f,
705 0.1613692f, -0.03750952f, 0.06529313f, -0.071974665f,
706 -0.11769596f, 0.015524369f, -0.0013754242f, -0.12446318f,
707 0.02786344f, -0.014179351f, 0.005264273f, 0.14376344f,
708 0.015983658f, 0.03406988f, -0.06939408f, 0.040699873f,
709 0.02111075f, 0.09669095f, 0.041345075f, -0.08316494f,
710 -0.07684199f, -0.045768797f, 0.032298047f, -0.041805092f,
711 0.0119405f, 0.0061010392f, 0.12652606f, 0.0064572375f,
712 -0.024950314f, 0.11574242f, 0.04508852f, -0.04335324f,
713 0.06760663f, -0.027437469f, 0.07216407f, 0.06977076f,
714 -0.05438599f, 0.034033038f, -0.028602652f, 0.05346137f,
715 0.043184172f, -0.037189785f, 0.10420091f, 0.00882477f,
716 -0.054019816f, -0.074273005f, -0.030617684f, -0.0028467078f,
717 0.024302477f, -0.0038869337f, 0.005332455f, 0.0013399826f,
718 0.04361412f, -0.007001822f, 0.09631092f, -0.06702025f,
719 -0.042049985f, -0.035070654f, -0.04103342f, -0.10273396f,
720 0.0544271f, 0.037184782f, -0.13150354f, -0.0058036847f,
721 -0.008264958f, 0.042035464f, 0.05891794f, 0.029673764f,
722 0.0063542654f, 0.044788733f, 0.054816857f, 0.062257513f,
723 -0.00093483756f, 0.048938446f, -0.004952862f, -0.007730018f,
724 -0.04043371f, -0.017094059f, 0.07229206f, -0.023670016f,
725 -0.052195564f, -0.025616996f, -0.01520939f, 0.045104615f,
726 -0.007376126f, 0.003533447f, 0.006570588f, 0.056037236f,
727 0.12436656f, 0.051817212f, 0.028532185f, -0.08686856f,
728 0.11868599f, 0.07663395f, -0.07323171f, 0.03463402f,
729 -0.050708205f, -0.04458982f, -0.11590894f, 0.021273347f,
730 0.1251325f, -0.15313013f, -0.12224372f, 0.17228661f,
731 0.023029093f, 0.086124025f, 0.006445803f, -0.03496501f,
732 0.028332196f, 0.04449512f, -0.042436164f, -0.026587414f,
733 -0.006041347f, -0.09292539f, -0.05678812f, 0.03897832f,
734 0.09465633f, 0.008115513f, -0.02171956f, 0.08304309f,
735 0.071401566f, 0.019622514f, 0.032163795f, -0.004167056f,
736 0.02295182f, 0.030739572f, 0.056506045f, 0.004612461f,
737 0.06524936f, 0.059999723f, 0.046395954f, -0.0045512207f,
738 -0.1335546f, -0.030136576f, 0.11584653f, -0.014678886f,
739 0.0020118146f, -0.09688814f, -0.0790206f, 0.039770417f,
740 -0.0329582f, 0.07922767f, 0.029322514f, 0.026405897f,
741 0.04207835f, -0.07073373f, 0.063781224f, 0.0859677f,
742 -0.10925287f, -0.07011058f, 0.048005477f, 0.03438226f,
743 -0.09606514f, -0.006669445f, -0.043381985f, 0.04240257f,
744 -0.06955775f, -0.06769346f, 0.043903265f, -0.026784198f,
745 -0.017840602f, 0.024307009f, -0.040079936f, -0.019946516f,
746 0.045318738f, -0.12233574f, 0.026170589f, 0.0074471775f,
747 0.15978073f, 0.10185836f, 0.10298046f, -0.015476589f,
748 -0.039390966f, -0.072174534f, 0.0739445f, -0.1211869f,
749 -0.0347889f, -0.07943156f, 0.014809798f, -0.12412325f,
750 -0.0030663363f, 0.039695457f, 0.0647603f, -0.08291318f,
751 -0.018529687f, -0.004423833f, 0.0037507233f, 0.084633216f,
752 -0.01514876f, -0.056505352f, -0.012800942f, -0.06994386f,
753 0.012962922f, -0.031234352f, 0.07029052f, 0.016418684f,
754 0.03618972f, 0.055686004f, -0.08663945f, -0.017404709f,
755 -0.054761406f, 0.029065743f, 0.052404847f, 0.020238016f,
756 0.0048197987f, -0.0214882f, 0.07078733f, 0.013016777f,
757 0.06262858f, 0.009184685f, 0.020785125f, -0.043904778f,
758 -0.0270329f, -0.03299152f, -0.060088247f, -0.015162964f,
759 -0.001828936f, 0.12642565f, -0.056757294f, 0.013586685f,
760 0.09232601f, -0.035886683f, 0.06000002f, 0.05229691f,
761 -0.052580316f, -0.082029596f, -0.010794592f, 0.012947712f,
762 -0.036429964f, -0.085508935f, -0.13127148f, -0.017744139f,
763 0.031502828f, 0.036232427f, -0.031581745f, 0.023051167f,
764 -0.05325106f, -0.03421577f, 0.028793324f, -0.034633752f,
765 -0.009881397f, -0.043551125f, -0.018609839f, 0.0019097115f,
766 -0.008799762f, 0.056595087f, 0.0022273948f, 0.055752404f };
// Recurrent-to-output-gate weight matrix for the float LSTM (peephole + projection) test.
// Flat row-major reference data — presumably [numUnits x outputSize]; confirm against the
// tensor info set up for this test (not visible in this chunk). Values must not be altered:
// they pin the expected numerical output of the workload.
768 std::vector<float> recurrentToOutputWeights = { 0.025825322f, -0.05813119f, 0.09495884f,-0.045984812f, -0.01255415f,
769 -0.0026479573f,-0.08196161f,-0.054914974f,-0.0046604523f,
770 -0.029587349f, -0.044576716f, -0.07480124f, -0.082868785f,
771 0.023254942f, 0.027502948f, -0.0039728214f, -0.08683098f,
772 -0.08116779f, -0.014675607f, -0.037924774f, -0.023314456f,
773 -0.007401714f, -0.09255757f, 0.029460307f, -0.08829125f,
774 -0.005139627f, -0.08989442f, -0.0555066f, 0.13596267f,
775 -0.025062224f, -0.048351806f, -0.03850004f, 0.07266485f,
776 -0.022414139f, 0.05940088f, 0.075114764f, 0.09597592f,
777 -0.010211725f, -0.0049794707f, -0.011523867f, -0.025980417f,
778 0.072999895f, 0.11091378f, -0.081685916f, 0.014416728f,
779 0.043229222f, 0.034178585f, -0.07530371f, 0.035837382f,
780 -0.085607f, -0.007721233f, -0.03287832f, -0.043848954f,
781 -0.06404588f, -0.06632928f, -0.073643476f, 0.008214239f,
782 -0.045984086f, 0.039764922f, 0.03474462f, 0.060612556f,
783 -0.080590084f, 0.049127717f, 0.04151091f, -0.030063879f,
784 0.008801774f, -0.023021035f, -0.019558564f, 0.05158114f,
785 -0.010947698f, -0.011825728f, 0.0075720972f, 0.0699727f,
786 -0.0039981045f, 0.069350146f, 0.08799282f, 0.016156472f,
787 0.035502106f, 0.11695009f, 0.006217345f, 0.13392477f,
788 -0.037875112f, 0.025745004f, 0.08940699f, -0.00924166f,
789 0.0046702605f, -0.036598757f, -0.08811812f, 0.10522024f,
790 -0.032441203f, 0.008176899f, -0.04454919f, 0.07058152f,
791 0.0067963637f, 0.039206743f, 0.03259838f, 0.03725492f,
792 -0.09515802f, 0.013326398f, -0.052055415f, -0.025676316f,
793 0.03198509f, -0.015951829f, -0.058556724f, 0.036879618f,
794 0.043357447f, 0.028362012f, -0.05908629f, 0.0059240665f,
795 -0.04995891f, -0.019187413f,0.0276265f, -0.01628143f, 0.0025863599f,
796 0.08800015f, 0.035250366f, -0.022165963f, -0.07328642f,
797 -0.009415526f, -0.07455109f, 0.11690406f, 0.0363299f,
798 0.07411125f, 0.042103454f, -0.009660886f, 0.019076364f,
799 0.018299393f, -0.046004917f, 0.08891175f,0.0431396f, -0.026327137f,
800 -0.051502608f, 0.08979574f, -0.051670972f, 0.04940282f,
801 -0.07491107f, -0.021240504f, 0.022596184f, -0.034280192f,
802 0.060163025f, -0.058211457f, -0.051837247f, -0.01349775f,
803 -0.04639988f, -0.035936575f, -0.011681591f, 0.064818054f,
804 0.0073146066f, -0.021745546f, -0.043124277f, -0.06471268f,
805 -0.07053354f, -0.029321948f, -0.05330136f, 0.016933719f,
806 -0.053782392f, 0.13747959f, -0.1361751f, -0.11569455f,
807 0.0033329215f, 0.05693899f, -0.053219706f, 0.063698f,
808 0.07977434f, -0.07924483f, 0.06936997f, 0.0034815092f,
809 -0.007305279f, -0.037325785f, -0.07251102f, -0.033633437f,
810 -0.08677009f, 0.091591336f, -0.14165086f, 0.021752775f,
811 0.019683983f, 0.0011612234f, -0.058154266f, 0.049996935f,
812 0.0288841f, -0.0024567875f, -0.14345716f, 0.010955264f,-0.10234828f,
813 0.1183656f, -0.0010731248f, -0.023590032f,-0.072285876f,-0.0724771f,
814 -0.026382286f, -0.0014920527f, 0.042667855f, 0.0018776858f,
815 0.02986552f, 0.009814309f, 0.0733756f, 0.12289186f,
816 0.018043943f, -0.0458958f, 0.049412545f, 0.033632483f,
817 0.05495232f, 0.036686596f, -0.013781798f, -0.010036754f,
818 0.02576849f, -0.08307328f, 0.010112348f, 0.042521734f,
819 -0.05869831f, -0.071689695f, 0.03876447f, -0.13275425f, -0.0352966f,
820 -0.023077697f, 0.10285965f, 0.084736146f, 0.15568255f,
821 -0.00040734606f, 0.027835453f, -0.10292561f, -0.032401145f,
822 0.10053256f, -0.026142767f, -0.08271222f, -0.0030240538f,
823 -0.016368777f, 0.1070414f, 0.042672627f, 0.013456989f,
824 -0.0437609f, -0.022309763f, 0.11576483f, 0.04108048f,
825 0.061026827f, -0.0190714f, -0.0869359f, 0.037901703f, 0.0610107f,
826 0.07202949f, 0.01675338f, 0.086139716f, -0.08795751f,
827 -0.014898893f, -0.023771819f, -0.01965048f, 0.007955471f,
828 -0.043740474f, 0.03346837f, -0.10549954f, 0.090567775f,
829 0.042013682f, -0.03176985f, 0.12569028f, -0.02421228f,
830 -0.029526481f, 0.023851605f, 0.031539805f, 0.05292009f,
831 -0.02344001f, -0.07811758f, -0.08834428f, 0.10094801f,
832 0.16594367f, -0.06861939f, -0.021256343f, -0.041093912f,
833 -0.06669611f, 0.035498552f, 0.021757556f, -0.09302526f,
834 -0.015403468f, -0.06614931f, -0.051798206f, -0.013874718f,
835 0.03630673f, 0.010412845f, -0.08077351f, 0.046185967f,
836 0.0035662893f, 0.03541868f, -0.094149634f, -0.034814864f,
837 0.003128424f, -0.020674974f, -0.03944324f, -0.008110165f,
838 -0.11113267f, 0.08484226f, 0.043586485f, 0.040582247f,
839 0.0968012f, -0.065249965f, -0.028036479f, 0.0050708856f,
840 0.0017462453f, 0.0326779f, 0.041296225f, 0.09164146f,
841 -0.047743853f, -0.015952192f, -0.034451712f, 0.084197424f,
842 -0.05347844f, -0.11768019f, 0.085926116f, -0.08251791f,
843 -0.045081906f, 0.0948852f, 0.068401024f, 0.024856757f,
844 0.06978981f, -0.057309967f, -0.012775832f, -0.0032452994f,
845 0.01977615f, -0.041040014f, -0.024264973f,0.063464895f, 0.05431621f};
// Peephole (cell-to-gate) weight vectors for the float LSTM test.
// Each holds 20 values — one per cell unit (presumably numUnits == 20 here; the
// tensorInfo lines that would confirm it are outside this chunk).
847 std::vector<float> cellToInputWeights = {0.040369894f, 0.030746894f, 0.24704495f, 0.018586371f, -0.037586458f,
848 -0.15312155f, -0.11812848f, -0.11465643f, 0.20259799f, 0.11418174f,
849 -0.10116027f, -0.011334949f, 0.12411352f, -0.076769054f,-0.052169047f,
850 0.21198851f, -0.38871562f, -0.09061183f, -0.09683246f, -0.21929175f};
853 std::vector<float> cellToForgetWeights = {-0.01998659f,-0.15568835f,-0.24248174f, -0.012770197f, 0.041331276f,
854 -0.072311886f, -0.052123554f,-0.0066330447f,-0.043891653f,0.036225766f,
855 -0.047248036f, 0.021479502f,0.033189066f, 0.11952997f, -0.020432774f,
856 0.64658105f, -0.06650122f, -0.03467612f, 0.095340036f, 0.23647355f};
858 std::vector<float> cellToOutputWeights = { 0.08286371f, -0.08261836f, -0.51210177f, 0.002913762f, 0.17764764f,
859 -0.5495371f, -0.08460716f, -0.24552552f, 0.030037103f, 0.04123544f,
860 -0.11940523f, 0.007358328f, 0.1890978f, 0.4833202f, -0.34441817f,
861 0.36312827f, -0.26375428f, 0.1457655f, -0.19724406f, 0.15548733f};
// Projection-layer weights for the float LSTM test (flat initializer, reference data).
// Layout is presumably [outputSize x numUnits] row-major — confirm against the projection
// tensor info (not visible in this chunk). The projection bias is zero-filled with
// length outputSize (constructed, not listed).
863 std::vector<float> projectionWeights={-0.009802181f, 0.09401916f, 0.0717386f, -0.13895074f, 0.09641832f,
864 0.060420845f, 0.08539281f, 0.054285463f, 0.061395317f, 0.034448683f,
865 -0.042991187f, 0.019801661f, -0.16840284f, -0.015726732f, -0.23041931f,
866 -0.024478018f, -0.10959692f, -0.013875541f, 0.18600968f, -0.061274476f,
867 0.0138165f, -0.08160894f, -0.07661644f, 0.032372914f, 0.16169067f,
868 0.22465782f, -0.03993472f, -0.004017731f, 0.08633481f, -0.28869787f,
869 0.08682067f, 0.17240396f, 0.014975425f, 0.056431185f, 0.031037588f,
870 0.16702051f, 0.0077946745f, 0.15140012f, 0.29405436f, 0.120285f,
871 -0.188994f, -0.027265169f, 0.043389652f, -0.022061434f, 0.014777949f,
872 -0.20203483f, 0.094781205f, 0.19100232f, 0.13987629f, -0.036132768f,
873 -0.06426278f, -0.05108664f, 0.13221376f, 0.009441198f, -0.16715929f,
874 0.15859416f, -0.040437475f, 0.050779544f, -0.022187516f, 0.012166504f,
875 0.027685808f, -0.07675938f, -0.0055694645f, -0.09444123f, 0.0046453946f,
876 0.050794356f, 0.10770313f, -0.20790008f, -0.07149004f, -0.11425117f,
877 0.008225835f, -0.035802525f, 0.14374903f, 0.15262283f, 0.048710253f,
878 0.1847461f, -0.007487823f, 0.11000021f, -0.09542012f, 0.22619456f,
879 -0.029149994f, 0.08527916f, 0.009043713f, 0.0042746216f, 0.016261552f,
880 0.022461696f, 0.12689082f, -0.043589946f, -0.12035478f, -0.08361797f,
881 -0.050666027f, -0.1248618f, -0.1275799f, -0.071875185f, 0.07377272f,
882 0.09944291f, -0.18897448f, -0.1593054f, -0.06526116f, -0.040107165f,
883 -0.004618631f, -0.067624845f, -0.007576253f, 0.10727444f, 0.041546922f,
884 -0.20424393f, 0.06907816f, 0.050412357f, 0.00724631f, 0.039827548f,
885 0.12449835f, 0.10747581f, 0.13708383f, 0.09134148f, -0.12617786f,
886 -0.06428341f, 0.09956831f, 0.1208086f, -0.14676677f, -0.0727722f,
887 0.1126304f, 0.010139365f, 0.015571211f, -0.038128063f, 0.022913318f,
888 -0.042050496f, 0.16842307f, -0.060597885f, 0.10531834f, -0.06411776f,
889 -0.07451711f, -0.03410368f, -0.13393489f, 0.06534304f, 0.003620307f,
890 0.04490757f, 0.05970546f, 0.05197996f, 0.02839995f, 0.10434969f,
891 -0.013699693f, -0.028353551f, -0.07260381f, 0.047201227f, -0.024575593f,
892 -0.036445823f, 0.07155557f, 0.009672501f, -0.02328883f, 0.009533515f,
893 -0.03606021f, -0.07421458f, -0.028082801f, -0.2678904f, -0.13221288f,
894 0.18419984f, -0.13012612f, -0.014588381f, -0.035059117f, -0.04824723f,
895 0.07830115f, -0.056184657f, 0.03277091f, 0.025466874f, 0.14494097f,
896 -0.12522776f, -0.098633975f, -0.10766018f, -0.08317623f, 0.08594209f,
897 0.07749552f, 0.039474737f, 0.1776665f, -0.07409566f, -0.0477268f,
898 0.29323658f, 0.10801441f, 0.1154011f, 0.013952499f, 0.10739139f,
899 0.10708251f, -0.051456142f, 0.0074137426f, -0.10430189f, 0.10034707f,
900 0.045594677f, 0.0635285f, -0.0715442f, -0.089667566f, -0.10811871f,
901 0.00026344223f, 0.08298446f, -0.009525053f, 0.006585689f, -0.24567553f,
902 -0.09450807f, 0.09648481f, 0.026996298f, -0.06419476f, -0.04752702f,
903 -0.11063944f, -0.23441927f, -0.17608605f, -0.052156363f, 0.067035615f,
904 0.19271925f, -0.0032889997f, -0.043264326f, 0.09663576f, -0.057112187f,
905 -0.10100678f, 0.0628376f, 0.04447668f, 0.017961001f, -0.10094388f,
906 -0.10190601f, 0.18335468f, 0.10494553f, -0.052095775f, -0.0026118709f,
907 0.10539724f, -0.04383912f, -0.042349473f, 0.08438151f, -0.1947263f,
908 0.02251204f, 0.11216432f, -0.10307853f, 0.17351969f, -0.039091777f,
909 0.08066188f, -0.00561982f, 0.12633002f, 0.11335965f, -0.0088127935f,
910 -0.019777594f, 0.06864014f, -0.059751723f, 0.016233567f, -0.06894641f,
911 -0.28651384f, -0.004228674f, 0.019708522f, -0.16305895f, -0.07468996f,
912 -0.0855457f, 0.099339016f, -0.07580735f, -0.13775392f, 0.08434318f,
913 0.08330512f, -0.12131499f, 0.031935584f, 0.09180414f, -0.08876437f,
914 -0.08049874f, 0.008753825f, 0.03498998f, 0.030215185f, 0.03907079f,
915 0.089751154f, 0.029194152f, -0.03337423f, -0.019092513f, 0.04331237f,
916 0.04299654f, -0.036394123f, -0.12915532f, 0.09793732f, 0.07512415f,
917 -0.11319543f, -0.032502122f, 0.15661901f, 0.07671967f, -0.005491124f,
918 -0.19379048f, -0.218606f, 0.21448623f, 0.017840758f, 0.1416943f,
919 -0.07051762f, 0.19488361f, 0.02664691f, -0.18104725f, -0.09334311f,
920 0.15026465f, -0.15493552f, -0.057762887f, -0.11604192f, -0.262013f,
921 -0.01391798f, 0.012185008f, 0.11156489f, -0.07483202f, 0.06693364f,
922 -0.26151478f, 0.046425626f, 0.036540434f, -0.16435726f, 0.17338543f,
923 -0.21401681f, -0.11385144f, -0.08283257f, -0.069031075f, 0.030635102f,
924 0.010969227f, 0.11109743f, 0.010919218f, 0.027526086f, 0.13519906f,
925 0.01891392f, -0.046839405f, -0.040167913f, 0.017953383f, -0.09700955f,
926 0.0061885654f, -0.07000971f, 0.026893595f, -0.038844477f, 0.14543656f};
// Projection bias: zero vector of length outputSize.
928 std::vector<float> projectionBiasVector(outputSize, 0.f);
991 inputHandle->Allocate();
992 outputStateInHandle->Allocate();
993 cellStateInHandle->Allocate();
995 scratchHandle->Allocate();
996 outputStateOutHandle->Allocate();
997 cellStateOutHandle->Allocate();
998 outputHandle->Allocate();
1004 workload->Execute();
1010 outputHandle->GetShape(),
1011 outputTensorInfo.GetShape());
1014 template<armnn::DataType ArmnnType,
typename T = armnn::ResolveType<ArmnnType>>
1019 const std::vector<T>& input,
1020 const std::vector<T>& outputExpected,
1023 float qScale = 0.0f,
1024 int32_t qOffset = 0,
1028 bool cifgEnabled =
true;
1029 bool peepholeEnabled =
true;
1030 bool projectionEnabled =
false;
1037 const unsigned int cellSize = outputSize;
1040 armnn::TensorInfo inputTensorInfo({batchSize , inputSize}, ArmnnType, qScale, qOffset);
1041 armnn::TensorInfo outputStateInTensorInfo({batchSize, outputSize}, ArmnnType, qScale, qOffset);
1042 armnn::TensorInfo cellStateInTensorInfo({batchSize, cellSize}, ArmnnType, qScale, qOffset);
1044 unsigned int scratchBufferSize = cifgEnabled ? cellSize * 3 : cellSize * 4;
1045 armnn::TensorInfo scratchBufferTensorInfo({batchSize, scratchBufferSize}, ArmnnType, qScale, qOffset);
1046 armnn::TensorInfo outputStateOutTensorInfo({batchSize, outputSize}, ArmnnType, qScale, qOffset);
1047 armnn::TensorInfo cellStateOutTensorInfo({batchSize, cellSize}, ArmnnType, qScale, qOffset);
1048 armnn::TensorInfo outputTensorInfo({batchSize, outputSize}, ArmnnType, qScale, qOffset);
1051 std::vector<float> inputData;
1052 inputData.assign(input.data(), input.data() + batchSize*inputSize);
1054 std::vector<float> outputStateInVector(batchSize * outputSize, 0.f);
1056 std::vector<float> cellStateInVector(batchSize * cellSize, 0.f);
1060 armnn::TensorInfo tensorInfoInput({cellSize, inputSize}, constantDataType, qScale, qOffset);
1061 armnn::TensorInfo tensorInfoOutput({cellSize, outputSize}, constantDataType, qScale, qOffset);
1062 armnn::TensorInfo tensorInfoNumUnits({cellSize}, constantDataType, qScale, qOffset);
1064 std::vector<float> inputToCellWeights =
1066 -0.49770179f, -0.27711356f, -0.09624726f, 0.05100781f,
1067 0.04717243f, 0.48944736f, -0.38535351f,
1070 std::vector<float> inputToForgetWeights =
1072 -0.55291498f, -0.42866567f, 0.13056988f,
1073 -0.3633365f, -0.22755712f, 0.28253698f, 0.24407166f,
1076 std::vector<float> inputToOutputWeights =
1078 0.10725588f, -0.02335852f, -0.55932593f,
1079 -0.09426838f, -0.44257352f, 0.54939759f,
1080 0.01533556f, 0.42751634f
// Gate biases for the CIFG + peephole float LSTM test (per cell unit).
// The forget gate is biased to 1.f (cell/output gates start at zero).
1082 std::vector<float> cellBias = {0.f, 0.f, 0.f, 0.f};
1083 std::vector<float> forgetGateBias = {1.f, 1.f, 1.f, 1.f};
1084 std::vector<float> outputGateBias = {0.f, 0.f, 0.f, 0.f};
1086 std::vector<float> recurrentToCellWeights =
1088 0.54066205f, -0.32668582f, -0.43562764f, -0.56094903f, 0.42957711f,
1089 0.01841056f, -0.32764608f, -0.33027974f, -0.10826075f, 0.20675004f,
1090 0.19069612f, -0.03026325f, -0.54532051f, 0.33003211f, 0.44901288f,
1093 std::vector<float> recurrentToForgetWeights =
1095 -0.13832897f, -0.0515101f, -0.2359007f, -0.16661474f, -0.14340827f,
1096 0.36986142f, 0.23414481f, 0.55899f, 0.10798943f, -0.41174671f, 0.17751795f,
1097 -0.34484994f, -0.35874045f, -0.11352962f, 0.27268326f, 0.54058349f
1100 std::vector<float> recurrentToOutputWeights =
1102 0.41613156f, 0.42610586f, -0.16495961f, -0.5663873f, 0.30579174f, -0.05115908f,
1103 -0.33941799f, 0.23364776f, 0.11178309f, 0.09481031f, -0.26424935f, 0.46261835f,
1104 0.50248802f, 0.26114327f, -0.43736315f, 0.33149987f
// Peephole weights, one value per cell unit. Only forget/output peepholes are
// defined here — cifgEnabled is true for this test, so there is no input gate.
1107 std::vector<float> cellToForgetWeights = {0.47485286f, -0.51955009f, -0.24458408f, 0.31544167f};
1108 std::vector<float> cellToOutputWeights = {-0.17135078f, 0.82760304f, 0.85573703f, -0.77109635f};
1165 std::vector<T> scratchBufferVector(batchSize * scratchBufferSize, T());
1169 std::vector<T> outputStateOutVector(batchSize * outputSize, T());
1173 std::vector<T> cellStateOutVector(batchSize * cellSize, T());
1177 std::vector<T> outputData;
1178 outputData.assign(outputExpected.data(), outputExpected.data() + batchSize*outputSize);
1180 ret3.m_ExpectedData = outputData;
1182 std::vector<T> actualScratchBufferOutput(scratchBufferTensorInfo.GetNumElements());
1183 std::vector<T> actualOutputStateOutput(outputStateOutTensorInfo.GetNumElements());
1184 std::vector<T> actualCellStateOutput(cellStateOutTensorInfo.GetNumElements());
1185 std::vector<T> actualOutput(outputTensorInfo.GetNumElements());
1188 std::unique_ptr<armnn::ITensorHandle> inputHandle =
1190 std::unique_ptr<armnn::ITensorHandle> outputStateInHandle =
1192 std::unique_ptr<armnn::ITensorHandle> cellStateInHandle =
1195 std::unique_ptr<armnn::ITensorHandle> scratchBufferHandle =
1197 std::unique_ptr<armnn::ITensorHandle> outputStateOutHandle =
1199 std::unique_ptr<armnn::ITensorHandle> cellStateOutHandle =
1201 std::unique_ptr<armnn::ITensorHandle> outputHandle =
1205 AddInputToWorkload(data, info, inputTensorInfo, inputHandle.get());
1206 AddInputToWorkload(data, info, outputStateInTensorInfo, outputStateInHandle.get());
1207 AddInputToWorkload(data, info, cellStateInTensorInfo, cellStateInHandle.get());
1209 AddOutputToWorkload(data, info, scratchBufferTensorInfo, scratchBufferHandle.get());
1210 AddOutputToWorkload(data, info, outputStateOutTensorInfo, outputStateOutHandle.get());
1211 AddOutputToWorkload(data, info, cellStateOutTensorInfo, cellStateOutHandle.get());
1212 AddOutputToWorkload(data, info, outputTensorInfo, outputHandle.get());
1216 inputHandle->Allocate();
1217 outputStateInHandle->Allocate();
1218 cellStateInHandle->Allocate();
1220 scratchBufferHandle->Allocate();
1221 outputStateOutHandle->Allocate();
1222 cellStateOutHandle->Allocate();
1223 outputHandle->Allocate();
1233 workload->Execute();
1240 ret0.m_ActualData = actualScratchBufferOutput;
1241 ret1.m_ActualData = actualOutputStateOutput;
1242 ret2.m_ActualData = actualCellStateOutput;
1243 ret3.m_ActualData = actualOutput;
1248 template<armnn::DataType ArmnnType,
typename T = armnn::ResolveType<ArmnnType>>
1253 const std::vector<T>& input,
1254 const std::vector<T>& outputExpected,
1255 float qScale = 0.0f,
1256 int32_t qOffset = 0,
1260 unsigned int batchSize = 2;
1261 unsigned int outputSize = 3;
1262 unsigned int inputSize = 5;
1263 unsigned numUnits = 4;
1265 armnn::TensorInfo inputTensorInfo({batchSize , inputSize}, ArmnnType, qScale, qOffset);
1266 armnn::TensorInfo cellStateInTensorInfo({batchSize , numUnits}, ArmnnType, qScale, qOffset);
1267 armnn::TensorInfo outputStateInTensorInfo({batchSize , outputSize}, ArmnnType, qScale, qOffset);
1270 armnn::TensorInfo scratchBufferTensorInfo({batchSize, numUnits * 4}, ArmnnType, qScale, qOffset);
1271 armnn::TensorInfo cellStateOutTensorInfo({batchSize, numUnits}, ArmnnType, qScale, qOffset);
1272 armnn::TensorInfo outputStateOutTensorInfo({batchSize, outputSize}, ArmnnType, qScale, qOffset);
1273 armnn::TensorInfo outputTensorInfo({batchSize, outputSize}, ArmnnType, qScale, qOffset);
1275 std::vector<float> inputVector;
1276 inputVector.assign(input.data(), input.data() + (batchSize * inputSize));
1278 std::vector<float> cellStateInVector(batchSize * numUnits, 0.f);
1279 std::vector<float> outputStateInVector(batchSize * outputSize, 0.f);
1280 std::vector<float> scratchBufferVector(batchSize * numUnits * 4, 0.f);
1281 std::vector<float> outputStateOutVector(batchSize * outputSize, 0.f);
1282 std::vector<float> cellStateOutVector(batchSize * numUnits, 0.f);
1284 std::vector<float> actualOutput(outputTensorInfo.GetNumElements());
1286 std::vector<float> outputVector;
1287 outputVector.assign(outputExpected.data(), outputExpected.data() + (batchSize * outputSize));
1289 std::unique_ptr<armnn::ITensorHandle> inputHandle = tensorHandleFactory.
CreateTensorHandle(inputTensorInfo);
1290 std::unique_ptr<armnn::ITensorHandle> cellStateInHandle =
1292 std::unique_ptr<armnn::ITensorHandle> outputStateInHandle =
1295 std::unique_ptr<armnn::ITensorHandle> scratchHandle =
1297 std::unique_ptr<armnn::ITensorHandle> outputStateOutHandle =
1299 std::unique_ptr<armnn::ITensorHandle> cellStateOutHandle =
1301 std::unique_ptr<armnn::ITensorHandle> outputHandle = tensorHandleFactory.
CreateTensorHandle(outputTensorInfo);
1306 AddInputToWorkload(data, info, inputTensorInfo, inputHandle.get());
1307 AddInputToWorkload(data, info, outputStateInTensorInfo, outputStateInHandle.get());
1308 AddInputToWorkload(data, info, cellStateInTensorInfo, cellStateInHandle.get());
1310 AddOutputToWorkload(data, info, scratchBufferTensorInfo, scratchHandle.get());
1311 AddOutputToWorkload(data, info, outputStateOutTensorInfo, outputStateOutHandle.get());
1312 AddOutputToWorkload(data, info, cellStateOutTensorInfo, cellStateOutHandle.get());
1313 AddOutputToWorkload(data, info, outputTensorInfo, outputHandle.get());
1317 armnn::TensorInfo tensorInfo4x5({numUnits, inputSize}, constantDataType, qScale, qOffset);
1318 armnn::TensorInfo tensorInfo4x3({numUnits, outputSize}, constantDataType, qScale, qOffset);
1319 armnn::TensorInfo tensorInfo3x4({outputSize, numUnits}, constantDataType, qScale, qOffset);
// Input-to-gate weight matrices for the layer-norm + projection float LSTM test.
// Each is numUnits x inputSize = 4x5 = 20 values (see tensorInfo4x5 above),
// followed by the per-unit gate biases (length numUnits = 4).
1321 std::vector<float> inputToInputWeights = {0.5f, 0.6f, 0.7f, -0.8f, -0.9f,
1322 0.1f, 0.2f, 0.3f, -0.4f, 0.5f,
1323 -0.8f, 0.7f, -0.6f, 0.5f, -0.4f,
1324 -0.5f, -0.4f, -0.3f, -0.2f, -0.1f};
1326 std::vector<float> inputToForgetWeights = { -0.6f, -0.1f, 0.3f, 0.2f, 0.9f,
1327 -0.5f, -0.2f, -0.4f, 0.3f, -0.8f,
1328 -0.4f, 0.3f, -0.5f, -0.4f, -0.6f,
1329 0.3f, -0.4f, -0.6f, -0.5f, -0.5f};
1331 std::vector<float> inputToCellWeights = {-0.4f, -0.3f, -0.2f, -0.1f, -0.5f,
1332 0.5f, -0.2f, -0.3f, -0.2f, -0.6f,
1333 0.6f, -0.1f, -0.4f, -0.3f, -0.7f,
1334 0.7f, -0.9f, -0.5f, 0.8f, 0.6f};
1336 std::vector<float> inputToOutputWeights = {-0.8f, -0.4f, -0.2f, -0.9f, -0.1f,
1337 -0.7f, 0.3f, -0.3f, -0.8f, -0.2f,
1338 0.6f, -0.2f, 0.4f, -0.7f, -0.3f,
1339 -0.5f, 0.1f, 0.5f, -0.6f, -0.4f};
// Per-unit gate biases (numUnits = 4).
1341 std::vector<float> inputGateBias = {0.03f, 0.15f, 0.22f, 0.38f};
1343 std::vector<float> forgetGateBias = {0.1f, -0.3f, -0.2f, 0.1f};
1345 std::vector<float> cellBias = {-0.05f, 0.72f, 0.25f, 0.08f};
1347 std::vector<float> outputGateBias = {0.05f, -0.01f, 0.2f, 0.1f};
1349 std::vector<float> recurrentToInputWeights ={-0.2f, -0.3f, 0.4f,
1351 -0.2f, -0.3f, -0.7f,
1352 0.05f, -0.2f, -0.6f};
1354 std::vector<float> recurrentToCellWeights = {-0.3f, 0.2f, 0.1f,
1355 -0.3f, 0.8f, -0.08f,
1357 -0.6f, -0.1f, 0.2f};
1359 std::vector<float> recurrentToForgetWeights = { -0.5f, -0.3f, -0.5f,
// Recurrent-to-output weights: numUnits x outputSize = 4x3 = 12 values.
1364 std::vector<float> recurrentToOutputWeights = { 0.3f, -0.1f, 0.1f,
1365 -0.2f, -0.5f, -0.7f,
1366 -0.2f, -0.6f, -0.1f,
1367 -0.4f, -0.7f, -0.2f};
// Peephole (cell-to-gate) weights, one value per cell unit (numUnits = 4).
1369 std::vector<float> cellToInputWeights = {0.05f, 0.1f, 0.25f, 0.15f};
1371 std::vector<float> cellToForgetWeights = {-0.02f, -0.15f, -0.25f, -0.03f};
1373 std::vector<float> cellToOutputWeights = {0.1f, -0.1f, -0.5f, 0.05f};
// Projection weights: outputSize x numUnits = 3x4 = 12 values (see tensorInfo3x4
// above); projection bias is zero-filled with length outputSize.
1375 std::vector<float> projectionWeights = {-0.1f, 0.2f, 0.01f, -0.2f,
1376 0.1f, 0.5f, 0.3f, 0.08f,
1377 0.07f, 0.2f, -0.4f, 0.2f};
1379 std::vector<float> projectionBiasVector(outputSize, 0.f);
// Layer-normalisation scale weights, one per cell unit, for each gate.
1381 std::vector<float> inputLayerNormWeights = {0.1f, 0.2f, 0.3f, 0.5f};
1383 std::vector<float> forgetLayerNormWeights = {0.2f, 0.2f, 0.4f, 0.3f};
1385 std::vector<float> cellLayerNormWeights = {0.7f, 0.2f, 0.3f, 0.8f};
1387 std::vector<float> outputLayerNormWeights = {0.6f, 0.2f, 0.2f, 0.5f};
1468 inputHandle->Allocate();
1469 outputStateInHandle->Allocate();
1470 cellStateInHandle->Allocate();
1472 scratchHandle->Allocate();
1473 outputStateOutHandle->Allocate();
1474 cellStateOutHandle->Allocate();
1475 outputHandle->Allocate();
1481 workload->Execute();
1487 outputHandle->GetShape(),
1488 outputTensorInfo.GetShape());
1495 const std::vector<uint8_t>& input,
1496 const std::vector<uint8_t>& outputExpected,
1506 float inputOutputScale = 0.0078125f;
1507 int32_t inputOutputOffset = 128;
1509 float cellStateScale = 0.00048828125f;
1510 int32_t cellStateOffset = 0;
1512 float weightsScale = 0.00408021f;
1513 int32_t weightsOffset = 100;
1515 float biasScale = 3.1876640625e-05f;
1516 int32_t biasOffset = 0;
1535 std::vector<uint8_t> inputVector;
1536 inputVector.assign(input.data(), input.data() + (numBatches * inputSize));
1539 std::vector<int16_t> cellStateInVector = {876, 1034, 955, -909, 761, 1029, 796, -1036};
1541 std::vector<uint8_t> outputStateInVector = {136, 150, 140, 115, 135, 152, 138, 112};
1544 std::vector<int16_t> cellStateOutVector = {1485, 1177, 1373, -1023, 1019, 1355, 1097, -1235};
1547 std::vector<uint8_t> outputVector;
1548 outputVector.assign(outputExpected.data(), outputExpected.data() + (numBatches * outputSize));
1550 std::vector<uint8_t> actualOutput(outputStateInfo.GetNumElements());
1553 std::unique_ptr<armnn::ITensorHandle> inputHandle = tensorHandleFactory.
CreateTensorHandle(inputInfo);
1554 std::unique_ptr<armnn::ITensorHandle> cellStateInHandle =
1556 std::unique_ptr<armnn::ITensorHandle> outputStateInHandle =
1559 std::unique_ptr<armnn::ITensorHandle> cellStateOutHandle =
1561 std::unique_ptr<armnn::ITensorHandle> outputHandle = tensorHandleFactory.
CreateTensorHandle(outputStateInfo);
1567 AddInputToWorkload(data, info, inputInfo, inputHandle.get());
1568 AddInputToWorkload(data, info, cellStateInfo, cellStateInHandle.get());
1569 AddInputToWorkload(data, info, outputStateInfo, outputStateInHandle.get());
1571 AddOutputToWorkload(data, info, cellStateInfo, cellStateOutHandle.get());
1572 AddOutputToWorkload(data, info, outputStateInfo, outputHandle.get());
// QuantizedLstm reference weights: raw uint8 quantized values (quantization
// parameters come from weightsScale = 0.00408021f, weightsOffset = 100 above).
// Input weights presumably [numUnits x inputSize], recurrent [numUnits x outputSize]
// — confirm against the weight tensor infos (not visible in this chunk).
1588 std::vector<uint8_t> inputToInputWeights = {146, 250, 235, 171, 10, 218, 171, 108};
1589 std::vector<uint8_t> inputToForgetWeights = {24, 50, 132, 179, 158, 110, 3, 169};
1590 std::vector<uint8_t> inputToCellWeights = {133, 34, 29, 49, 206, 109, 54, 183};
1591 std::vector<uint8_t> inputToOutputWeights = {195, 187, 11, 99, 109, 10, 218, 48};
1593 std::vector<uint8_t> recurrentToInputWeights =
1594 {254, 206, 77, 168, 71, 20, 215, 6, 223, 7, 118, 225, 59, 130, 174, 26};
1595 std::vector<uint8_t> recurrentToForgetWeights =
1596 {137, 240, 103, 52, 68, 51, 237, 112, 0, 220, 89, 23, 69, 4, 207, 253};
1597 std::vector<uint8_t> recurrentToCellWeights =
1598 {172, 60, 205, 65, 14, 0, 140, 168, 240, 223, 133, 56, 142, 64, 246, 216};
1599 std::vector<uint8_t> recurrentToOutputWeights =
1600 {106, 214, 67, 23, 59, 158, 45, 3, 119, 132, 49, 205, 129, 218, 11, 98};
// Gate biases: int32 quantized values (biasScale = 3.1876640625e-05f above).
1602 std::vector<int32_t> inputGateBias = {-7876, 13488, -726, 32839};
1603 std::vector<int32_t> forgetGateBias = {9206, -46884, -11693, -38724};
1604 std::vector<int32_t> cellBias = {39481, 48624, 48976, -21419};
1605 std::vector<int32_t> outputGateBias = {-58999, -17050, -41852, -40538};
1659 inputHandle->Allocate();
1660 outputStateInHandle->Allocate();
1661 cellStateInHandle->Allocate();
1663 cellStateOutHandle->Allocate();
1664 outputHandle->Allocate();
1670 workload->Execute();
1676 outputHandle->GetShape(),
1677 outputStateInfo.GetShape());
1685 const std::vector<int8_t>& input,
1686 const std::vector<int8_t>& outputExpected)
1689 unsigned int numBatches = 2;
1690 unsigned int inputSize = 5;
1691 unsigned int outputSize = 4;
1692 unsigned int numUnits = 4;
1694 bool cifgEnabled =
true;
1695 bool peepholeEnabled =
false;
1696 bool projectionEnabled =
false;
1697 bool layerNormEnabled =
true;
1700 float inputScale = 0.0078125f;
1701 int32_t inputOffset = 0;
1703 int32_t hiddenStateZeroPoint = 0;
1704 float hiddenStateScale = 0.007f;
1707 float outputScale = hiddenStateScale;
1708 int32_t outputOffset = hiddenStateZeroPoint;
1710 float cellStateScale = 3.05176e-05f;
1711 int32_t cellStateOffset = 0;
1713 float weightsScale = 0.00784314f;
1714 int32_t weightsOffset = 0;
1716 float layerNormScale = 3.05182e-05f;
1717 int32_t layerNormOffset = 0;
1719 float biasScale = layerNormScale / 1024;
1720 int32_t biasOffset = 0;
1722 float inputIntermediateScale = 0.007059f;
1723 float forgetIntermediateScale = 0.007812f;
1724 float cellIntermediateScale = inputIntermediateScale;
1725 float outputIntermediateScale = forgetIntermediateScale;
1727 float cellClip = 0.0f;
1728 float projectionClip = 0.0f;
1749 std::vector<int8_t> inputVector;
1750 inputVector.assign(input.data(), input.data() + (numBatches * inputSize));
1752 std::vector<int16_t> cellStateInVector = {0, 0, 0, 0, 0, 0, 0, 0};
1754 std::vector<int8_t> outputStateInVector = {0, 0, 0, 0, 0, 0, 0, 0};
1757 std::vector<int16_t> cellStateOutVector = {-11692, 9960, 5491, 8861, -9422, 7726, 2056, 13149};
1759 std::vector<int8_t> outputVector;
1760 outputVector.assign(outputExpected.data(), outputExpected.data() + (numBatches * outputSize));
1762 std::vector<int8_t> actualOutput(outputStateInfo.GetNumElements());
1765 std::unique_ptr<armnn::ITensorHandle> inputHandle = tensorHandleFactory.
CreateTensorHandle(inputInfo);
1766 std::unique_ptr<armnn::ITensorHandle> cellStateInHandle =
1768 std::unique_ptr<armnn::ITensorHandle> outputStateInHandle =
1771 std::unique_ptr<armnn::ITensorHandle> outputStateOutHandle =
1773 std::unique_ptr<armnn::ITensorHandle> cellStateOutHandle =
1775 std::unique_ptr<armnn::ITensorHandle> outputHandle = tensorHandleFactory.
CreateTensorHandle(outputStateInfo);
1781 AddInputToWorkload(data, info, inputInfo, inputHandle.get());
1782 AddInputToWorkload(data, info, outputStateInfo, outputStateInHandle.get());
1783 AddInputToWorkload(data, info, cellStateInfo, cellStateInHandle.get());
1785 AddOutputToWorkload(data, info, outputStateInfo, outputStateOutHandle.get());
1786 AddOutputToWorkload(data, info, cellStateInfo, cellStateOutHandle.get());
1787 AddOutputToWorkload(data, info, outputStateInfo, outputHandle.get());
// QLstm (int8) reference weights for the CIFG + layer-norm test. cifgEnabled is
// true, so no input-gate tensors are defined. Input weights are
// numUnits x inputSize = 4x5 = 20 values; recurrent weights are
// numUnits x outputSize = 4x4 = 16 values (weightsScale = 0.00784314f, offset 0).
1805 std::vector<int8_t> inputToForgetWeights =
1806 {-77, -13, 38, 25, 115, -64, -25, -51, 38, -102, -51, 38, -64, -51, -77, 38, -51, -77, -64, -64};
1807 std::vector<int8_t> inputToCellWeights =
1808 {-51, -38, -25, -13, -64, 64, -25, -38, -25, -77, 77, -13, -51, -38, -89, 89, -115, -64, 102, 77};
1809 std::vector<int8_t> inputToOutputWeights =
1810 {-102, -51, -25, -115, -13, -89, 38, -38, -102, -25, 77, -25, 51, -89, -38, -64, 13, 64, -77, -51};
1812 std::vector<int8_t> recurrentToForgetWeights =
1813 {-64, -38, -64, -25, 77, 51, 115, 38, -13, 25, 64, 25, 25, 38, -13, 51};
1814 std::vector<int8_t> recurrentToCellWeights =
1815 {-38, 25, 13, -38, 102, -10, -25, 38, 102, -77, -13, 25, 38, -13, 25, 64};
1816 std::vector<int8_t> recurrentToOutputWeights =
1817 {38, -13, 13, -25, -64, -89, -25, -77, -13, -51, -89, -25, 13, 64, 25, -38};
// Gate biases: int32 quantized with biasScale = layerNormScale / 1024 (see above).
1819 std::vector<int32_t> forgetGateBias = {2147484, -6442451, -4294968, 2147484};
1820 std::vector<int32_t> cellBias = {-1073742, 15461883, 5368709, 1717987};
1821 std::vector<int32_t> outputGateBias = {1073742, -214748, 4294968, 2147484};
// Layer-normalisation weights: int16 quantized (layerNormScale = 3.05182e-05f).
1823 std::vector<int16_t> forgetLayerNormWeights = {6553, 6553, 13107, 9830};
1824 std::vector<int16_t> cellLayerNormWeights = {22937, 6553, 9830, 26214};
1825 std::vector<int16_t> outputLayerNormWeights = {19660, 6553, 6553, 16384};
1896 inputHandle->Allocate();
1897 outputStateInHandle->Allocate();
1898 cellStateInHandle->Allocate();
1900 outputStateOutHandle->Allocate();
1901 cellStateOutHandle->Allocate();
1902 outputHandle->Allocate();
1908 workload->Execute();
1914 outputHandle->GetShape(),
1915 outputStateInfo.GetShape());
1923 const std::vector<int8_t>& input,
1924 const std::vector<int8_t>& outputExpected)
1927 unsigned int numBatches = 2;
1928 unsigned int inputSize = 5;
1929 unsigned int outputSize = 3;
1930 unsigned int numUnits = 4;
1932 bool cifgEnabled =
false;
1933 bool peepholeEnabled =
false;
1934 bool projectionEnabled =
true;
1935 bool layerNormEnabled =
true;
1938 float inputScale = 0.0078125f;
1939 int32_t inputOffset = 0;
1941 int32_t hiddenStateZeroPoint = 0;
1942 float hiddenStateScale = 0.007f;
1945 float outputScale = 3.05176e-05f;
1946 int32_t outputOffset = 0;
1948 float cellStateScale = 3.05176e-05f;
1949 int32_t cellStateOffset = 0;
1951 float weightsScale = 0.00784314f;
1952 int32_t weightsOffset = 0;
1954 float layerNormScale = 3.05182e-05f;
1955 int32_t layerNormOffset = 0;
1957 float biasScale = layerNormScale / 1024;
1958 int32_t biasOffset = 0;
1960 float projectionWeightsScale = 0.00392157f;
1962 float inputIntermediateScale = 0.007059f;
1963 float forgetIntermediateScale = 0.007812f;
1964 float cellIntermediateScale = inputIntermediateScale;
1965 float outputIntermediateScale = forgetIntermediateScale;
1967 float cellClip = 0.0f;
1968 float projectionClip = 0.0f;
1987 std::vector<int8_t> inputVector;
1988 inputVector.assign(input.data(), input.data() + (numBatches * inputSize));
1990 std::vector<int16_t> cellStateInVector = {0, 0, 0, 0, 0, 0, 0, 0};
1992 std::vector<int8_t> outputStateInVector = {0, 0, 0, 0, 0, 0};
1995 std::vector<int16_t> cellStateOutVector = {-14650, 8939, 5771, 6715, -11843, 7847, 1508, 12939};
1997 std::vector<int8_t> outputVector;
1998 outputVector.assign(outputExpected.data(), outputExpected.data() + (numBatches * outputSize));
2000 std::vector<int8_t> actualOutput(outputStateInfo.GetNumElements());
2003 std::unique_ptr<armnn::ITensorHandle> inputHandle = tensorHandleFactory.
CreateTensorHandle(inputInfo);
2004 std::unique_ptr<armnn::ITensorHandle> cellStateInHandle =
2006 std::unique_ptr<armnn::ITensorHandle> outputStateInHandle =
2009 std::unique_ptr<armnn::ITensorHandle> outputStateOutHandle =
2011 std::unique_ptr<armnn::ITensorHandle> cellStateOutHandle =
2013 std::unique_ptr<armnn::ITensorHandle> outputHandle = tensorHandleFactory.
CreateTensorHandle(outputStateInfo);
2019 AddInputToWorkload(data, info, inputInfo, inputHandle.get());
2020 AddInputToWorkload(data, info, outputStateInfo, outputStateInHandle.get());
2021 AddInputToWorkload(data, info, cellStateInfo, cellStateInHandle.get());
2023 AddOutputToWorkload(data, info, outputStateInfo, outputStateOutHandle.get());
2024 AddOutputToWorkload(data, info, cellStateInfo, cellStateOutHandle.get());
2025 AddOutputToWorkload(data, info, outputStateInfo, outputHandle.get());
2044 projectionWeightsScale,
2048 std::vector<int8_t> inputToInputWeights =
2049 {64, 77, 89, -102, -115, 13, 25, 38, -51, 64, -102, 89, -77, 64, -51, -64, -51, -38, -25, -13};
2050 std::vector<int8_t> inputToForgetWeights =
2051 {-77, -13, 38, 25, 115, -64, -25, -51, 38, -102, -51, 38, -64, -51, -77, 38, -51, -77, -64, -64};
2052 std::vector<int8_t> inputToCellWeights =
2053 {-51, -38, -25, -13, -64, 64, -25, -38, -25, -77, 77, -13, -51, -38, -89, 89, -115, -64, 102, 77};
2054 std::vector<int8_t> inputToOutputWeights =
2055 {-102, -51, -25, -115, -13, -89, 38, -38, -102, -25, 77, -25, 51, -89, -38, -64, 13, 64, -77, -51};
2057 std::vector<int8_t> recurrentToInputWeights = {-25, -38, 51, 13, -64, 115, -25, -38, -89, 6, -25, -77};
2058 std::vector<int8_t> recurrentToForgetWeights = {-64, -38, -64, -25, 77, 51, 115, 38, -13, 25, 64, 25};
2059 std::vector<int8_t> recurrentToCellWeights = {-38, 25, 13, -38, 102, -10, -25, 38, 102, -77, -13, 25};
2060 std::vector<int8_t> recurrentToOutputWeights = {38, -13, 13, -25, -64, -89, -25, -77, -13, -51, -89, -25};
2062 std::vector<int32_t> inputGateBias = {644245, 3221226, 4724464, 8160438};
2063 std::vector<int32_t> forgetGateBias = {2147484, -6442451, -4294968, 2147484};
2064 std::vector<int32_t> cellBias = {-1073742, 15461883, 5368709, 1717987};
2065 std::vector<int32_t> outputGateBias = {1073742, -214748, 4294968, 2147484};
2067 std::vector<int16_t> inputLayerNormWeights = {3277, 6553, 9830, 16384};
2068 std::vector<int16_t> forgetLayerNormWeights = {6553, 6553, 13107, 9830};
2069 std::vector<int16_t> cellLayerNormWeights = {22937, 6553, 9830, 26214};
2070 std::vector<int16_t> outputLayerNormWeights = {19660, 6553, 6553, 16384};
2072 std::vector<int8_t> projectionWeights = {-25, 51, 3, -51, 25, 127, 77, 20, 18, 51, -102, 51};
2161 inputHandle->Allocate();
2162 outputStateInHandle->Allocate();
2163 cellStateInHandle->Allocate();
2165 outputStateOutHandle->Allocate();
2166 cellStateOutHandle->Allocate();
2167 outputHandle->Allocate();
2173 workload->Execute();
2179 outputHandle->GetShape(),
2180 outputStateInfo.GetShape());
2188 const std::vector<int8_t>& input,
2189 const std::vector<int8_t>& outputExpected)
2192 unsigned int numBatches = 2;
2193 unsigned int inputSize = 5;
2194 unsigned int outputSize = 3;
2195 unsigned int numUnits = 4;
2197 bool cifgEnabled =
true;
2198 bool peepholeEnabled =
false;
2199 bool projectionEnabled =
true;
2200 bool layerNormEnabled =
true;
2203 float inputScale = 0.0078125f;
2204 int32_t inputOffset = 0;
2206 int32_t hiddenStateZeroPoint = 0;
2207 float hiddenStateScale = 0.007f;
2210 float outputScale = 3.05176e-05f;
2211 int32_t outputOffset = 0;
2213 float cellStateScale = 3.05176e-05f;
2214 int32_t cellStateOffset = 0;
2216 float weightsScale = 0.00784314f;
2217 int32_t weightsOffset = 0;
2219 float layerNormScale = 3.05182e-05f;
2220 int32_t layerNormOffset = 0;
2222 float biasScale = layerNormScale / 1024;
2223 int32_t biasOffset = 0;
2225 float projectionWeightsScale = 0.00392157f;
2227 float inputIntermediateScale = 0.007059f;
2228 float forgetIntermediateScale = 0.007812f;
2229 float cellIntermediateScale = inputIntermediateScale;
2230 float outputIntermediateScale = forgetIntermediateScale;
2232 float cellClip = 0.0f;
2233 float projectionClip = 0.0f;
2252 std::vector<int8_t> inputVector;
2253 inputVector.assign(input.data(), input.data() + (numBatches * inputSize));
2255 std::vector<int16_t> cellStateInVector = {0, 0, 0, 0, 0, 0, 0, 0};
2257 std::vector<int8_t> outputStateInVector = {0, 0, 0, 0, 0, 0};
2260 std::vector<int16_t> cellStateOutVector = {-14650, 8939, 5771, 6715, -11843, 7847, 1508, 12939};
2262 std::vector<int8_t> outputVector;
2263 outputVector.assign(outputExpected.data(), outputExpected.data() + (numBatches * outputSize));
2265 std::vector<int8_t> actualOutput(outputStateInfo.GetNumElements());
2268 std::unique_ptr<armnn::ITensorHandle> inputHandle = tensorHandleFactory.
CreateTensorHandle(inputInfo);
2269 std::unique_ptr<armnn::ITensorHandle> cellStateInHandle =
2271 std::unique_ptr<armnn::ITensorHandle> outputStateInHandle =
2274 std::unique_ptr<armnn::ITensorHandle> outputStateOutHandle =
2276 std::unique_ptr<armnn::ITensorHandle> cellStateOutHandle =
2278 std::unique_ptr<armnn::ITensorHandle> outputHandle = tensorHandleFactory.
CreateTensorHandle(outputStateInfo);
2284 AddInputToWorkload(data, info, inputInfo, inputHandle.get());
2285 AddInputToWorkload(data, info, outputStateInfo, outputStateInHandle.get());
2286 AddInputToWorkload(data, info, cellStateInfo, cellStateInHandle.get());
2288 AddOutputToWorkload(data, info, outputStateInfo, outputStateOutHandle.get());
2289 AddOutputToWorkload(data, info, cellStateInfo, cellStateOutHandle.get());
2290 AddOutputToWorkload(data, info, outputStateInfo, outputHandle.get());
2309 projectionWeightsScale,
2313 std::vector<int8_t> inputToForgetWeights =
2314 {-77, -13, 38, 25, 115, -64, -25, -51, 38, -102, -51, 38, -64, -51, -77, 38, -51, -77, -64, -64};
2315 std::vector<int8_t> inputToCellWeights =
2316 {-51, -38, -25, -13, -64, 64, -25, -38, -25, -77, 77, -13, -51, -38, -89, 89, -115, -64, 102, 77};
2317 std::vector<int8_t> inputToOutputWeights =
2318 {-102, -51, -25, -115, -13, -89, 38, -38, -102, -25, 77, -25, 51, -89, -38, -64, 13, 64, -77, -51};
2320 std::vector<int8_t> recurrentToForgetWeights =
2321 {-64, -38, -64, -25, 77, 51, 115, 38, -13, 25, 64, 25};
2322 std::vector<int8_t> recurrentToCellWeights =
2323 {-38, 25, 13, -38, 102, -10, -25, 38, 102, -77, -13, 25};
2324 std::vector<int8_t> recurrentToOutputWeights =
2325 {38, -13, 13, -25, -64, -89, -25, -77, -13, -51, -89, -25};
2327 std::vector<int32_t> forgetGateBias = {2147484, -6442451, -4294968, 2147484};
2328 std::vector<int32_t> cellBias = {-1073742, 15461883, 5368709, 1717987};
2329 std::vector<int32_t> outputGateBias = {1073742, -214748, 4294968, 2147484};
2331 std::vector<int16_t> forgetLayerNormWeights = {6553, 6553, 13107, 9830};
2332 std::vector<int16_t> cellLayerNormWeights = {22937, 6553, 9830, 26214};
2333 std::vector<int16_t> outputLayerNormWeights = {19660, 6553, 6553, 16384};
2335 std::vector<int8_t> projectionWeights = {-25, 51, 3, -51, 25, 127, 77, 20, 18, 51, -102, 51};
2412 inputHandle->Allocate();
2413 outputStateInHandle->Allocate();
2414 cellStateInHandle->Allocate();
2416 outputStateOutHandle->Allocate();
2417 cellStateOutHandle->Allocate();
2418 outputHandle->Allocate();
2424 workload->Execute();
2430 outputHandle->GetShape(),
2431 outputStateInfo.GetShape());
2437 #if defined(ARMNNREF_ENABLED) 2441 void LstmUtilsZeroVectorTest()
2444 std::vector<float> input = {2., 3., 3., 4.};
2445 std::vector<float> expectedOutput = {0., 0., 0., 0.};
2447 return LstmUtilsZeroVectorTestImpl<armnn::DataType::Float32>(input, 4, expectedOutput, inputDesc.GetShape());
2450 void LstmUtilsMeanStddevNormalizationNoneZeroInputTest()
2452 uint32_t batchSize = 2;
2453 uint32_t vecSize = 4;
2455 std::vector<float> input =
2456 { 0.1f, 0.2f, 0.3f, 0.4f,
2457 0.9f, 1.0f, 1.1f, 1.2f };
2459 std::vector<float> expectedOutput =
2460 { -1.34164071f, -0.447213531f, 0.44721365f, 1.34164071f,
2461 -1.34163153f, -0.447210163f, 0.447211236f, 1.3416326f };
2463 return LstmUtilsMeanStddevNormalizationTestImpl<armnn::DataType::Float32>(input,
2464 vecSize, batchSize, expectedOutput, inputDesc.GetShape());
2467 void LstmUtilsMeanStddevNormalizationAllZeroInputTest()
2469 uint32_t batchSize = 2;
2470 uint32_t vecSize = 4;
2472 std::vector<float> input =
2473 { 0.0f, 0.0f, 0.0f, 0.0f,
2474 0.0f, 0.0f, 0.0f, 0.0f };
2476 std::vector<float> expectedOutput =
2477 { 0.0f, 0.0f, 0.0f, 0.0f,
2478 0.0f, 0.0f, 0.0f, 0.0f };
2480 return LstmUtilsMeanStddevNormalizationTestImpl<armnn::DataType::Float32>(input,
2481 vecSize, batchSize, expectedOutput, inputDesc.GetShape());
2484 void LstmUtilsMeanStddevNormalizationMixedZeroInputTest()
2486 uint32_t batchSize = 2;
2487 uint32_t vecSize = 4;
2489 std::vector<float> input =
2490 { 0.0f, 0.0f, 0.0f, 0.0f,
2491 0.1f, 0.2f, 0.3f, 0.4f };
2493 std::vector<float> expectedOutput =
2494 { 0.0f, 0.0f, 0.0f, 0.0f,
2495 -1.34164071f, -0.447213531f, 0.44721365f, 1.34164071f };
2497 return LstmUtilsMeanStddevNormalizationTestImpl<armnn::DataType::Float32>(input,
2498 vecSize, batchSize, expectedOutput, inputDesc.GetShape());
2501 void LstmUtilsVectorBatchVectorCwiseProductTest()
2503 uint32_t batchSize = 4;
2504 uint32_t vecSize = 29;
2506 std::vector<float> vector =
2507 { 1.1f, 2.2f, 3.3f, 4.4f, 5.5f, 6.6f, 7.7f, 8.8f, 9.9f, 10.1f,
2508 11.11f, 12.12f, 13.13f, 14.14f, 15.15f, 16.16f, 17.17f, 18.18f, 19.19f, 20.2f,
2509 21.21f, 22.22f, 23.23f, 24.24f, 25.25f, 26.26f, 27.27f, 28.28f, 0.0f};
2512 std::vector<float> batchVector =
2514 1.1f, 2.2f, 3.3f, 4.4f, 5.5f, 6.6f, 7.7f, 8.8f, 9.9f, 10.1f,
2515 11.11f, 12.12f, 13.13f, 14.14f, 15.15f, 16.16f, 17.17f, 18.18f, 19.19f, 20.2f,
2516 21.21f, 22.22f, 23.23f, 24.24f, 25.25f, 26.26f, 27.27f, 28.28f, 0.0f,
2518 -1.1f, -2.2f, -3.3f, -4.4f, -5.5f, -6.6f, -7.7f, -8.8f, -9.9f, -10.1f,
2519 -11.11f, -12.12f, -13.13f, -14.14f, -15.15f, -16.16f, -17.17f, -18.18f, -19.19f, -20.2f,
2520 -21.21f, -22.22f, -23.23f, -24.24f, -25.25f, -26.26f, -27.27f, -28.28f, 0.0f,
2522 1.1f, -2.2f, 3.3f, -4.4f, 5.5f, -6.6f, 7.7f, -8.8f, 9.9f, -10.1f,
2523 11.11f, -12.12f, 13.13f, -14.14f, 15.15f, -16.16f, 17.17f, -18.18f, 19.19f, -20.2f,
2524 21.21f, -22.22f, 23.23f, -24.24f, 25.25f, -26.26f, 27.27f, -28.28f, 0.0f,
2526 -1.1f, 2.2f, -3.3f, 4.4f, -5.5f, 6.6f, -7.7f, 8.8f, -9.9f, 10.1f,
2527 -11.11f, 12.12f, -13.13f, 14.14f, -15.15f, 16.16f, -17.17f, 18.18f, -19.19f, 20.2f,
2528 -21.21f, 22.22f, -23.23f, 24.24f, -25.25f, 26.26f, -27.27f, 28.28f, 0.0f};
2531 std::vector<float> expectedOutput =
2533 1.210000f, 4.840000f, 10.889999f, 19.360001f, 30.250000f, 43.559998f,
2534 59.289997f, 77.440002f, 98.009995f, 102.010010f, 123.432091f, 146.894394f,
2535 172.396896f, 199.939606f, 229.522491f, 261.145599f, 294.808899f, 330.512421f,
2536 368.256134f, 408.040039f, 449.864075f, 493.728363f, 539.632874f, 587.577576f,
2537 637.562500f, 689.587585f, 743.652954f, 799.758423f, 0.000000f,
2539 -1.210000f, -4.840000f, -10.889999f, -19.360001f, -30.250000f, -43.559998f,
2540 -59.289997f, -77.440002f, -98.009995f, -102.010010f, -123.432091f, -146.894394f,
2541 -172.396896f, -199.939606f, -229.522491f, -261.145599f, -294.808899f, -330.512421f,
2542 -368.256134f, -408.040039f, -449.864075f, -493.728363f, -539.632874f, -587.577576f,
2543 -637.562500f, -689.587585f, -743.652954f, -799.758423f, 0.000000f,
2545 1.210000f, -4.840000f, 10.889999f, -19.360001f, 30.250000f, -43.559998f,
2546 59.289997f, -77.440002f, 98.009995f, -102.010010f, 123.432091f, -146.894394f,
2547 172.396896f, -199.939606f, 229.522491f, -261.145599f, 294.808899f, -330.512421f,
2548 368.256134f, -408.040039f, 449.864075f, -493.728363f, 539.632874f, -587.577576f,
2549 637.562500f, -689.587585f, 743.652954f, -799.758423f, 0.000000f,
2551 -1.210000f, 4.840000f, -10.889999f, 19.360001f, -30.250000f, 43.559998f,
2552 -59.289997f, 77.440002f, -98.009995f, 102.010010f, -123.432091f, 146.894394f,
2553 -172.396896f, 199.939606f, -229.522491f, 261.145599f, -294.808899f, 330.512421f,
2554 -368.256134f, 408.040039f, -449.864075f, 493.728363f, -539.632874f, 587.577576f,
2555 -637.562500f, 689.587585f, -743.652954f, 799.758423f, 0.000000f};
2557 return LstmUtilsVectorBatchVectorCwiseProductTestImpl<armnn::DataType::Float32>(vector, batchVector,
2558 vecSize, batchSize, expectedOutput, vecDesc.GetShape());
2561 void LstmUtilsVectorBatchVectorAddTest()
2563 uint32_t batchSize = 2;
2564 uint32_t vecSize = 3;
2566 std::vector<float> vector = { 0.0f, -0.5f, 1.0f};
2569 std::vector<float> batchVector =
2575 std::vector<float> expectedOutput =
2581 return LstmUtilsVectorBatchVectorAddTestImpl<armnn::DataType::Float32>(vector, batchVector,
2582 vecSize, batchSize, expectedOutput, batchVecDesc.GetShape());
2593 std::vector<float> input = { 2., 3., 3., 4. };
2596 std::vector<float> expectedOutput =
2597 {-0.36444446f, -0.00352185f, 0.12886585f, -0.05163646f,
2598 -0.42734814f, -0.00478661f, 0.13455015f, -0.03560682f};
2599 return LstmLayerWithCifgWithPeepholeNoProjectionTestImpl<armnn::DataType::Float32>(
2600 workloadFactory, memoryManager, tensorHandleFactory,
2601 input, expectedOutput, inputDesc.GetShape(), outputDesc.GetShape());
2610 std::vector<float> input =
2611 {0.787926f, 0.151646f, 0.071352f, 0.118426f, 0.458058f,
2612 0.295743f, 0.544053f, 0.690064f, 0.858138f, 0.497181f};
2615 std::vector<float> expectedOutput =
2616 {-0.00396806f, 0.029352f, -0.00279226f, 0.0159977f, -0.00835576f,
2617 -0.0211779f, 0.0283512f, -0.0114597f, 0.00907307f, -0.0244004f,
2618 -0.0152191f, -0.0259063f, 0.00914318f, 0.00415118f, 0.017147f,
2619 0.0134203f, -0.013869f, 0.0287268f, -0.00334693f, 0.00733398f, -0.0287926f,
2620 -0.0186926f, 0.0193662f, -0.0115437f, 0.00422612f, -0.0345232f,
2621 0.00223253f, -0.00957321f, 0.0210624f, 0.013331f, 0.0150954f, 0.02168f};
2622 return LstmLayerNoCifgWithPeepholeWithProjectionTestImpl<armnn::DataType::Float32>(
2623 workloadFactory, memoryManager, tensorHandleFactory, input, expectedOutput);
2632 std::vector<float> input = {2., 3., 3., 4.};
2635 std::vector<float> expectedOutput =
2636 {-0.02973187f, 0.1229473f, 0.20885126f, -0.15358765f,
2637 -0.0185422f, 0.11281417f, 0.24466537f, -0.1826292f};
2639 return LstmNoCifgNoPeepholeNoProjectionTestImpl<armnn::DataType::Float32>(
2640 workloadFactory, memoryManager, tensorHandleFactory,
2641 input, expectedOutput, inputDesc.GetShape(), outputDesc.GetShape());
2650 std::vector<float> input =
2651 {0.7f, 0.8f, 0.1f, 0.2f, 0.3f,
2652 0.3f, 0.2f, 0.9f, 0.8f, 0.1f};
2655 std::vector<float> expectedOutput =
2656 { 0.0244077f, 0.128027f, -0.00170918f,
2657 -0.00692428f, 0.0848741f, 0.063445f};
2658 return LstmLayerNoCifgWithPeepholeWithProjectionWithLayerNormTestImpl<armnn::DataType::Float32>(
2659 workloadFactory, memoryManager, tensorHandleFactory, input, expectedOutput);
2667 const float qScale = 1.0f;
2668 const int32_t qOffset = 0;
2674 std::vector<int16_t> input = armnnUtils::QuantizedVector<int16_t>({ 2.f, 3.f, 3.f, 4.f }, qScale, qOffset);
2677 std::vector<int16_t> expectedOutput = armnnUtils::QuantizedVector<int16_t>(
2679 -0.02973187f, 0.12294730f, 0.20885126f, -0.15358765f,
2680 -0.01854220f, 0.11281417f, 0.24466537f, -0.18262920f
2684 return LstmNoCifgNoPeepholeNoProjectionTestImpl<datatype>(
2685 workloadFactory, memoryManager, tensorHandleFactory,
2686 input, expectedOutput, inputDesc.GetShape(), outputDesc.GetShape(),
2687 qScale, qOffset, constantDatatype);
2696 const float qScale = 1.0f;
2697 const int32_t qOffset = 0;
2703 std::vector<int16_t> input = armnnUtils::QuantizedVector<int16_t>({ 2.f, 3.f, 3.f, 4.f }, qScale, qOffset);
2706 std::vector<int16_t> expectedOutput = armnnUtils::QuantizedVector<int16_t>(
2708 -0.36444446f, -0.00352185f, 0.12886585f, -0.05163646f,
2709 -0.42734814f, -0.00478661f, 0.13455015f, -0.03560682f
2713 return LstmLayerWithCifgWithPeepholeNoProjectionTestImpl<datatype>(
2714 workloadFactory, memoryManager, tensorHandleFactory,
2715 input, expectedOutput, inputDesc.GetShape(), outputDesc.GetShape(),
2716 qScale, qOffset, constantDatatype);
2724 const float qScale = 2.0f;
2725 const int32_t qOffset = 0;
2731 std::vector<int16_t> input = armnnUtils::QuantizedVector<int16_t>(
2733 0.787926f, 0.151646f, 0.071352f, 0.118426f, 0.458058f,
2734 0.295743f, 0.544053f, 0.690064f, 0.858138f, 0.497181f
2739 std::vector<int16_t> expectedOutput = armnnUtils::QuantizedVector<int16_t>(
2741 -0.00396806f, 0.02935200f, -0.00279226f, 0.01599770f,
2742 -0.00835576f, -0.02117790f, 0.02835120f, -0.01145970f,
2743 0.00907307f, -0.02440040f, -0.01521910f, -0.02590630f,
2744 0.00914318f, 0.00415118f, 0.01714700f, 0.01342030f,
2745 -0.01386900f, 0.02872680f, -0.00334693f, 0.00733398f,
2746 -0.02879260f, -0.01869260f, 0.01936620f, -0.01154370f,
2747 0.00422612f, -0.03452320f, 0.00223253f, -0.00957321f,
2748 0.02106240f, 0.01333100f, 0.01509540f, 0.02168000f
2752 return LstmLayerNoCifgWithPeepholeWithProjectionTestImpl<datatype>(
2753 workloadFactory, memoryManager, tensorHandleFactory, input, expectedOutput, qScale, qOffset, constantDatatype);
2761 const float qScale = 1.0f;
2762 const int32_t qOffset = 0;
2767 std::vector<int16_t> input = armnnUtils::QuantizedVector<int16_t>({ 2.f, 3.f, 3.f, 4.f }, qScale, qOffset);
2770 std::vector<int16_t> expectedOutput = armnnUtils::QuantizedVector<int16_t>(
2772 -0.02973187f, 0.12294730f, 0.20885126f, -0.15358765f,
2773 -0.01854220f, 0.11281417f, 0.24466537f, -0.18262920f
2777 return LstmNoCifgNoPeepholeNoProjectionTestImpl<datatype>(
2778 workloadFactory, memoryManager, tensorHandleFactory,
2779 input, expectedOutput, inputDesc.GetShape(), outputDesc.GetShape(),
2780 qScale, qOffset, datatype);
2793 std::vector<uint8_t> input = {166, 179, 50, 150};
2796 std::vector<uint8_t> expectedOutput = {140, 151, 146, 112, 136, 156, 142, 112 };
2798 return QuantizedLstmTestImpl(workloadFactory, memoryManager, tensorHandleFactory,
2799 input, expectedOutput, inputDesc.GetShape(), outputDesc.GetShape());
2809 std::vector<int8_t> input = {90, 102, 13, 26, 38, 102, 13, 26, 51, 64};
2812 std::vector<int8_t> expectedOutput = {-15, 21, 14, 20, -15, 15, 5, 27};
2814 return QLstmTestImpl(workloadFactory, memoryManager, tensorHandleFactory, input, expectedOutput);
2823 std::vector<int8_t> input = {90, 102, 13, 26, 38, 102, 13, 26, 51, 64};
2826 std::vector<int8_t> expectedOutput = {127, 127, -108, -67, 127, 127};
2828 return QLstmTestImpl1(workloadFactory, memoryManager, tensorHandleFactory, input, expectedOutput);
2837 std::vector<int8_t> input = {90, 102, 13, 26, 38, 102, 13, 26, 51, 64};
2840 std::vector<int8_t> expectedOutput = {127, 127, 127, -128, 127, 127};
2842 return QLstmTestImpl2(workloadFactory, memoryManager, tensorHandleFactory, input, expectedOutput);
const ConstTensorHandle * m_CellLayerNormWeights
void MeanStddevNormalization(armnn::Decoder< float > &input_vector, armnn::Encoder< float > &output_vector, uint32_t v_size, uint32_t n_batch, float normalization_epsilon)
void VectorBatchVectorAdd(armnn::Decoder< float > &vector, uint32_t vSize, armnn::Decoder< float > &batchVector, uint32_t nBatch, armnn::Encoder< float > &outResult)
LayerTestResult< float, 2 > LstmLayerFloat32NoCifgWithPeepholeWithProjectionTest(armnn::IWorkloadFactory &workloadFactory, const armnn::IBackendInternal::IMemoryManagerSharedPtr &memoryManager, const armnn::ITensorHandleFactory &tensorHandleFactory)
const ConstTensorHandle * m_ProjectionWeights
LayerTestResult< int16_t, 2 > LstmLayerInt16WithCifgWithPeepholeNoProjectionTest(armnn::IWorkloadFactory &workloadFactory, const armnn::IBackendInternal::IMemoryManagerSharedPtr &memoryManager, const armnn::ITensorHandleFactory &tensorHandleFactory)
bool m_ProjectionEnabled
Enable/disable the projection layer.
const ConstTensorHandle * m_ProjectionWeights
const ConstTensorHandle * m_RecurrentToForgetWeights
const ConstTensorHandle * m_ForgetGateBias
const ConstTensorHandle * m_InputToOutputWeights
float m_ClippingThresProj
Clipping threshold value for the projection.
const ConstTensorHandle * m_InputGateBias
LayerTestResult< int8_t, 2 > QLstmTest1(armnn::IWorkloadFactory &workloadFactory, const armnn::IBackendInternal::IMemoryManagerSharedPtr &memoryManager, const armnn::ITensorHandleFactory &tensorHandleFactory)
const ConstTensorHandle * m_RecurrentToCellWeights
bool m_PeepholeEnabled
Enable/disable peephole.
const ConstTensorHandle * m_CellBias
float m_HiddenStateScale
Hidden State quantization scale.
LayerTestResult< int16_t, 2 > LstmLayerInt16NoCifgWithPeepholeWithProjectionTest(armnn::IWorkloadFactory &workloadFactory, const armnn::IBackendInternal::IMemoryManagerSharedPtr &memoryManager, const armnn::ITensorHandleFactory &tensorHandleFactory)
const ConstTensorHandle * m_InputToInputWeights
float m_OutputIntermediateScale
Output intermediate quantization scale.
LayerTestResult< int8_t, 2 > QLstmTest(armnn::IWorkloadFactory &workloadFactory, const armnn::IBackendInternal::IMemoryManagerSharedPtr &memoryManager, const armnn::ITensorHandleFactory &tensorHandleFactory)
const ConstTensorHandle * m_InputGateBias
const ConstTensorHandle * m_InputToOutputWeights
const ConstTensorHandle * m_OutputLayerNormWeights
void ZeroVector(armnn::Encoder< float > &vector, uint32_t vSize)
void IgnoreUnused(Ts &&...)
const ConstTensorHandle * m_RecurrentToInputWeights
LayerDescriptor m_Parameters
void AllocateAndCopyDataToITensorHandle(armnn::ITensorHandle *tensorHandle, const void *memory)
void VectorBatchVectorCwiseProduct(armnn::Decoder< float > &vector, uint32_t vSize, armnn::Decoder< float > &batchVector, uint32_t nBatch, armnn::Encoder< float > &outResult)
const ConstTensorHandle * m_ForgetLayerNormWeights
const ConstTensorHandle * m_OutputGateBias
LayerTestResult< float, 2 > LstmLayerFloat32WithCifgWithPeepholeNoProjectionTest(armnn::IWorkloadFactory &workloadFactory, const armnn::IBackendInternal::IMemoryManagerSharedPtr &memoryManager, const armnn::ITensorHandleFactory &tensorHandleFactory)
const ConstTensorHandle * m_InputToForgetWeights
bool m_LayerNormEnabled
Enable/disable layer normalization.
const ConstTensorHandle * m_CellLayerNormWeights
std::shared_ptr< IMemoryManager > IMemoryManagerSharedPtr
float m_ProjectionClip
Clipping threshold value for the projection.
const ConstTensorHandle * m_InputToForgetWeights
const ConstTensorHandle * m_RecurrentToCellWeights
const ConstTensorHandle * m_ForgetGateBias
float m_InputIntermediateScale
Input intermediate quantization scale.
bool m_PeepholeEnabled
Enable/disable peephole.
void CopyDataFromITensorHandle(void *mem, const armnn::ITensorHandle *tensorHandle)
const ConstTensorHandle * m_CellBias
const ConstTensorHandle * m_CellToOutputWeights
LayerTestResult< uint8_t, 2 > QuantizedLstmTest(armnn::IWorkloadFactory &workloadFactory, const armnn::IBackendInternal::IMemoryManagerSharedPtr &memoryManager, const armnn::ITensorHandleFactory &tensorHandleFactory)
const ConstTensorHandle * m_RecurrentToOutputWeights
const ConstTensorHandle * m_InputLayerNormWeights
const ConstTensorHandle * m_InputToCellWeights
const ConstTensorHandle * m_OutputGateBias
const ConstTensorHandle * m_InputToForgetWeights
uint32_t m_ActivationFunc
The activation function to use.
const ConstTensorHandle * m_RecurrentToForgetWeights
LayerTestResult< float, 2 > LstmLayerFloat32NoCifgNoPeepholeNoProjectionTest(armnn::IWorkloadFactory &workloadFactory, const armnn::IBackendInternal::IMemoryManagerSharedPtr &memoryManager, const armnn::ITensorHandleFactory &tensorHandleFactory)
float m_ClippingThresCell
Clipping threshold value for the cell state.
armnn::PredicateResult CompareTensors(const std::vector< T > &actualData, const std::vector< T > &expectedData, const armnn::TensorShape &actualShape, const armnn::TensorShape &expectedShape, bool compareBoolean=false, bool isDynamic=false)
const ConstTensorHandle * m_RecurrentToInputWeights
float m_ForgetIntermediateScale
Forget intermediate quantization scale.
const ConstTensorHandle * m_InputToCellWeights
const ConstTensorHandle * m_ForgetGateBias
float m_CellClip
Clipping threshold value for the cell state.
bool m_CifgEnabled
Enable/disable cifg (coupled input & forget gate).
LayerTestResult< int16_t, 2 > LstmLayerInt16NoCifgNoPeepholeNoProjectionTest(armnn::IWorkloadFactory &workloadFactory, const armnn::IBackendInternal::IMemoryManagerSharedPtr &memoryManager, const armnn::ITensorHandleFactory &tensorHandleFactory)
const ConstTensorHandle * m_InputToOutputWeights
LayerTestResult< int16_t, 2 > LstmLayerInt16NoCifgNoPeepholeNoProjectionInt16ConstantTest(armnn::IWorkloadFactory &workloadFactory, const armnn::IBackendInternal::IMemoryManagerSharedPtr &memoryManager, const armnn::ITensorHandleFactory &tensorHandleFactory)
const ConstTensorHandle * m_CellToForgetWeights
void CopyDataToITensorHandle(armnn::ITensorHandle *tensorHandle, const void *memory)
const ConstTensorHandle * m_RecurrentToCellWeights
bool m_ProjectionEnabled
Enable/disable the projection layer.
const ConstTensorHandle * m_InputGateBias
const ConstTensorHandle * m_InputToInputWeights
const ConstTensorHandle * m_CellBias
const ConstTensorHandle * m_ProjectionBias
bool m_LayerNormEnabled
Enable/disable layer normalization.
std::enable_if_t< std::is_unsigned< Source >::value &&std::is_unsigned< Dest >::value, Dest > numeric_cast(Source source)
const ConstTensorHandle * m_ForgetLayerNormWeights
Contains information about TensorInfos of a layer.
const ConstTensorHandle * m_InputLayerNormWeights
const ConstTensorHandle * m_OutputGateBias
LayerTestResult< float, 2 > LstmLayerFloat32NoCifgWithPeepholeWithProjectionWithLayerNormTest(armnn::IWorkloadFactory &workloadFactory, const armnn::IBackendInternal::IMemoryManagerSharedPtr &memoryManager, const armnn::ITensorHandleFactory &tensorHandleFactory)
const ConstTensorHandle * m_OutputLayerNormWeights
const ConstTensorHandle * m_InputToCellWeights
const ConstTensorHandle * m_CellToInputWeights
virtual std::unique_ptr< IWorkload > CreateWorkload(LayerType type, const QueueDescriptor &descriptor, const WorkloadInfo &info) const
float m_CellIntermediateScale
Cell intermediate quantization scale.
virtual std::unique_ptr< ITensorHandle > CreateTensorHandle(const TensorInfo &tensorInfo) const =0
const ConstTensorHandle * m_InputToInputWeights
const ConstTensorHandle * m_RecurrentToOutputWeights
const ConstTensorHandle * m_RecurrentToInputWeights
bool m_CifgEnabled
Enable/disable CIFG (coupled input & forget gate).
LayerTestResult< int8_t, 2 > QLstmTest2(armnn::IWorkloadFactory &workloadFactory, const armnn::IBackendInternal::IMemoryManagerSharedPtr &memoryManager, const armnn::ITensorHandleFactory &tensorHandleFactory)
const ConstTensorHandle * m_RecurrentToForgetWeights
int32_t m_HiddenStateZeroPoint
Hidden State zero point.
const ConstTensorHandle * m_RecurrentToOutputWeights