From e5339e7013cf24e5a34509fb0a60377e5f8a244e Mon Sep 17 00:00:00 2001 From: Narumol Prangnawarat Date: Wed, 28 Jul 2021 17:33:28 +0100 Subject: MLCE-530 Add support for UnidirectionalSequenceLstm to RefWorkload * Add implementation of IsUnidirectionalSequenceLstmSupported to RefLayerSupport * Add RefUnidirectionalSequenceLstmWorkload * Refactor Lstm to be able to use for Lstm and SequenceLstm * Unit tests Signed-off-by: Narumol Prangnawarat Change-Id: Ibc066d213213a11b955dfefbe518de643298ba0c --- include/armnn/Descriptors.hpp | 2 +- src/backends/backendsCommon/WorkloadData.cpp | 4 +- src/backends/backendsCommon/common.mk | 3 +- src/backends/backendsCommon/test/CMakeLists.txt | 2 + src/backends/backendsCommon/test/LayerTests.hpp | 1 + .../UnidirectionalSequenceLstmTestImpl.cpp | 1030 ++++++++++++++++++++ .../UnidirectionalSequenceLstmTestImpl.hpp | 36 + src/backends/reference/RefLayerSupport.cpp | 147 +++ src/backends/reference/RefLayerSupport.hpp | 11 + src/backends/reference/RefWorkloadFactory.cpp | 7 + src/backends/reference/RefWorkloadFactory.hpp | 4 + src/backends/reference/backend.mk | 2 + src/backends/reference/test/RefLayerTests.cpp | 12 + src/backends/reference/workloads/CMakeLists.txt | 4 + src/backends/reference/workloads/Lstm.cpp | 259 +++++ src/backends/reference/workloads/Lstm.hpp | 61 ++ .../reference/workloads/RefLstmWorkload.cpp | 235 +---- .../RefUnidirectionalSequenceLstmWorkload.cpp | 307 ++++++ .../RefUnidirectionalSequenceLstmWorkload.hpp | 56 ++ src/backends/reference/workloads/RefWorkloads.hpp | 1 + 20 files changed, 1991 insertions(+), 193 deletions(-) create mode 100644 src/backends/backendsCommon/test/layerTests/UnidirectionalSequenceLstmTestImpl.cpp create mode 100644 src/backends/backendsCommon/test/layerTests/UnidirectionalSequenceLstmTestImpl.hpp create mode 100644 src/backends/reference/workloads/Lstm.cpp create mode 100644 src/backends/reference/workloads/Lstm.hpp create mode 100644 
src/backends/reference/workloads/RefUnidirectionalSequenceLstmWorkload.cpp create mode 100644 src/backends/reference/workloads/RefUnidirectionalSequenceLstmWorkload.hpp diff --git a/include/armnn/Descriptors.hpp b/include/armnn/Descriptors.hpp index 7188a7bd3a..f4a5482768 100644 --- a/include/armnn/Descriptors.hpp +++ b/include/armnn/Descriptors.hpp @@ -926,7 +926,7 @@ struct LstmDescriptor : BaseDescriptor , m_PeepholeEnabled(false) , m_ProjectionEnabled(false) , m_LayerNormEnabled(false) - , m_TimeMajor(true) + , m_TimeMajor(false) {} bool operator ==(const LstmDescriptor& rhs) const diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp index 319cdb106b..d87f858601 100644 --- a/src/backends/backendsCommon/WorkloadData.cpp +++ b/src/backends/backendsCommon/WorkloadData.cpp @@ -3734,9 +3734,7 @@ void UnidirectionalSequenceLstmQueueDescriptor::Validate(const WorkloadInfo& wor std::vector supportedTypes = { - DataType::Float16, - DataType::Float32, - DataType::QAsymmS8 + DataType::Float32 }; // check for supported type of one input and match them with all the other input and output diff --git a/src/backends/backendsCommon/common.mk b/src/backends/backendsCommon/common.mk index ff9375dec1..5d339477d5 100644 --- a/src/backends/backendsCommon/common.mk +++ b/src/backends/backendsCommon/common.mk @@ -93,7 +93,8 @@ COMMON_TEST_SOURCES := \ test/layerTests/StackTestImpl.cpp \ test/layerTests/StridedSliceTestImpl.cpp \ test/layerTests/SubtractionTestImpl.cpp \ - test/layerTests/TransposeConvolution2dTestImpl.cpp + test/layerTests/TransposeConvolution2dTestImpl.cpp \ + test/layerTests/UnidirectionalSequenceLstmTestImpl.cpp ifeq ($(ARMNN_REF_ENABLED),1) COMMON_TEST_SOURCES += \ diff --git a/src/backends/backendsCommon/test/CMakeLists.txt b/src/backends/backendsCommon/test/CMakeLists.txt index 755cd21683..4561fd7739 100644 --- a/src/backends/backendsCommon/test/CMakeLists.txt +++ 
b/src/backends/backendsCommon/test/CMakeLists.txt @@ -169,6 +169,8 @@ list(APPEND armnnBackendsCommonUnitTests_sources layerTests/SubtractionTestImpl.hpp layerTests/TransposeConvolution2dTestImpl.cpp layerTests/TransposeConvolution2dTestImpl.hpp + layerTests/UnidirectionalSequenceLstmTestImpl.cpp + layerTests/UnidirectionalSequenceLstmTestImpl.hpp ) if (ARMNNREF) diff --git a/src/backends/backendsCommon/test/LayerTests.hpp b/src/backends/backendsCommon/test/LayerTests.hpp index 46eb6ee2a5..fcb1f71436 100644 --- a/src/backends/backendsCommon/test/LayerTests.hpp +++ b/src/backends/backendsCommon/test/LayerTests.hpp @@ -67,3 +67,4 @@ #include #include #include +#include diff --git a/src/backends/backendsCommon/test/layerTests/UnidirectionalSequenceLstmTestImpl.cpp b/src/backends/backendsCommon/test/layerTests/UnidirectionalSequenceLstmTestImpl.cpp new file mode 100644 index 0000000000..ac22d5df48 --- /dev/null +++ b/src/backends/backendsCommon/test/layerTests/UnidirectionalSequenceLstmTestImpl.cpp @@ -0,0 +1,1030 @@ +// +// Copyright © 2021 Arm Ltd and Contributors. All rights reserved. 
+// SPDX-License-Identifier: MIT +// + +#include "UnidirectionalSequenceLstmTestImpl.hpp" + +#include + +#include + +#include +#include + +#include + +namespace { + +template> +LayerTestResult UnidirectionalSequenceLstmLayerFloat32TestImpl( + armnn::IWorkloadFactory& workloadFactory, + const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, + const armnn::ITensorHandleFactory& tensorHandleFactory, + const std::vector& input, + const std::vector& outputExpected, + const armnn::TensorShape& inputShape, + const armnn::TensorShape& outputExpectedShape, + float qScale = 0.0f, + int32_t qOffset = 0, + armnn::DataType constantDataType = armnn::DataType::Float32) { + IgnoreUnused(memoryManager); + unsigned int batchSize = armnn::numeric_cast(inputShape[0]); + unsigned int timeSize = armnn::numeric_cast(inputShape[1]); + unsigned int inputSize = armnn::numeric_cast(inputShape[2]); + unsigned int outputSize = armnn::numeric_cast(outputExpectedShape[2]); + unsigned numUnits = outputSize; + + armnn::TensorInfo inputTensorInfo({batchSize, timeSize, inputSize}, ArmnnType, qScale, qOffset); + armnn::TensorInfo cellStateInTensorInfo({batchSize, numUnits}, ArmnnType, qScale, qOffset); + armnn::TensorInfo outputStateInTensorInfo({batchSize, outputSize}, ArmnnType, qScale, qOffset); + + armnn::TensorInfo outputTensorInfo({batchSize, timeSize, outputSize}, ArmnnType, qScale, qOffset); + + std::vector inputVector; + inputVector.assign(input.data(), input.data() + (batchSize * timeSize * inputSize)); + + std::vector cellStateInVector(batchSize * numUnits, T()); + std::vector outputStateInVector(batchSize * outputSize, T()); + + std::vector actualOutput(outputTensorInfo.GetNumElements()); + + std::vector outputVector; + outputVector.assign(outputExpected.data(), outputExpected.data() + (batchSize * timeSize * outputSize)); + + std::unique_ptr inputHandle = tensorHandleFactory.CreateTensorHandle(inputTensorInfo); + std::unique_ptr cellStateInHandle = + 
tensorHandleFactory.CreateTensorHandle(cellStateInTensorInfo); + std::unique_ptr outputStateInHandle = + tensorHandleFactory.CreateTensorHandle(outputStateInTensorInfo); + + std::unique_ptr outputHandle = tensorHandleFactory.CreateTensorHandle(outputTensorInfo); + + armnn::UnidirectionalSequenceLstmQueueDescriptor data; + armnn::WorkloadInfo info; + + AddInputToWorkload(data, info, inputTensorInfo, inputHandle.get()); + AddInputToWorkload(data, info, outputStateInTensorInfo, outputStateInHandle.get()); + AddInputToWorkload(data, info, cellStateInTensorInfo, cellStateInHandle.get()); + + AddOutputToWorkload(data, info, outputTensorInfo, outputHandle.get()); + + armnn::TensorInfo tensorInfo4({numUnits}, constantDataType, qScale, qOffset); + armnn::TensorInfo tensorInfo12({numUnits, 3}, constantDataType, qScale, qOffset); + armnn::TensorInfo tensorInfo16({numUnits, 4}, constantDataType, qScale, qOffset); + + std::vector inputToInputWeights = { -0.49536117f, -0.0556083915f, -0.102400711f, + -0.117484632f, 0.3298470976f, -0.1179017122f, + 0.214305695f, 0.42135173085f, 0.003878414626f, + -0.348303917f, -0.1881275477f, 0.0343011027f }; + + std::vector inputToForgetWeights = { 0.2415594226f, 0.15400093799f, 0.4566498398f, + -0.3810434485f, 0.268383264f, -0.009807467424f, + -0.3522925403f, -0.24275735512f, -0.28344226125f, + 0.13512269116f, -0.4932442977f, -0.10039821991f }; + + std::vector inputToCellWeights = { -0.2504855627f, 0.184490025045f, -0.2480507493f, + 0.386399507f, -0.259465157985f, -0.16545993089f, + -0.4230232555f, 0.341664791103f, -0.18127849691f, + -0.2277662414f, -0.55275535589f, 0.34184026718f }; + + std::vector inputToOutputWeights = { 0.2303854227f, 0.5218806862f, -0.4865379333f, + 0.53969591851f, 0.23393625035f, -0.27140527306f, + 0.50009280443f, 0.07511717046f, 0.3998299249f, + -0.51717478049f, 0.1889653282f, -0.367323637f }; + + std::vector recurrentToInputWeights = { -0.128009796112f, 0.1995525098f, -0.07745539397f, 0.1558421701f, + -0.265254765766f, 
-0.38837709614f, -0.05636804124f, 0.4259087456f, + 0.17628988623f, 0.3877420127f, 0.53300309181f, -0.0959980934f, + 0.00302857416f, 0.3266998827f, -0.142509296562f, -0.04433270756f }; + + std::vector recurrentToForgetWeights = { -0.09499983487f, -0.08814888417f, -0.04834804721f, 0.1516668247f, + -0.3967529535f, -0.06463699788f, 0.4952811002f, 0.003274492938f, + -0.0968840941f, 0.17928104102f, 0.0031281141592f, -0.3387276584f, + -0.3587934076f, 0.06705895066f, 0.22463923692f, 0.1961955726f }; + + std::vector recurrentToCellWeights = { -0.21938985582f, -0.3023648226f, -0.1170005202f, -0.3509177422f, + -0.4286288613f, 0.2726137042f, 0.09216640889f, -0.06551410215f, + 0.20453298098f, 0.2393476665f, 0.11846517771f, 0.2630801796f, + 0.3954237699f, -0.19407111404f, 0.30412107706f, -0.27342408554f }; + + std::vector recurrentToOutputWeights = { -0.32921677827f, 0.32624614238f, -0.1388191282f, -0.17879831790f, + -0.15185534954f, -0.16918526583f, -0.10087361183f, -0.5436913968f, + 0.016758225858f, 0.30454617738f, -0.41493862867f, -0.005565764375f, + -0.12584099173f, -0.12319286912f, 0.2407919466f, -0.08879069983f }; + + std::vector inputGateBias = { 0., 0., 0., 0. }; + + std::vector forgetGateBias = { 1., 1., 1., 1. }; + + std::vector cellBias = { 0., 0., 0., 0. }; + + std::vector outputGateBias = { 0., 0., 0., 0. 
}; + + armnn::ScopedTensorHandle inputToInputWeightsTensor(tensorInfo12); + armnn::ScopedTensorHandle inputToForgetWeightsTensor(tensorInfo12); + armnn::ScopedTensorHandle inputToCellWeightsTensor(tensorInfo12); + armnn::ScopedTensorHandle inputToOutputWeightsTensor(tensorInfo12); + armnn::ScopedTensorHandle recurrentToInputWeightsTensor(tensorInfo16); + armnn::ScopedTensorHandle recurrentToForgetWeightsTensor(tensorInfo16); + armnn::ScopedTensorHandle recurrentToCellWeightsTensor(tensorInfo16); + armnn::ScopedTensorHandle recurrentToOutputWeightsTensor(tensorInfo16); + armnn::ScopedTensorHandle inputGateBiasTensor(tensorInfo4); + armnn::ScopedTensorHandle forgetGateBiasTensor(tensorInfo4); + armnn::ScopedTensorHandle cellBiasTensor(tensorInfo4); + armnn::ScopedTensorHandle outputGateBiasTensor(tensorInfo4); + + AllocateAndCopyDataToITensorHandle(&inputToInputWeightsTensor, inputToInputWeights.data()); + AllocateAndCopyDataToITensorHandle(&inputToForgetWeightsTensor, inputToForgetWeights.data()); + AllocateAndCopyDataToITensorHandle(&inputToCellWeightsTensor, inputToCellWeights.data()); + AllocateAndCopyDataToITensorHandle(&inputToOutputWeightsTensor, inputToOutputWeights.data()); + AllocateAndCopyDataToITensorHandle(&recurrentToInputWeightsTensor, recurrentToInputWeights.data()); + AllocateAndCopyDataToITensorHandle(&recurrentToForgetWeightsTensor, recurrentToForgetWeights.data()); + AllocateAndCopyDataToITensorHandle(&recurrentToCellWeightsTensor, recurrentToCellWeights.data()); + AllocateAndCopyDataToITensorHandle(&recurrentToOutputWeightsTensor, recurrentToOutputWeights.data()); + AllocateAndCopyDataToITensorHandle(&inputGateBiasTensor, inputGateBias.data()); + AllocateAndCopyDataToITensorHandle(&forgetGateBiasTensor, forgetGateBias.data()); + AllocateAndCopyDataToITensorHandle(&cellBiasTensor, cellBias.data()); + AllocateAndCopyDataToITensorHandle(&outputGateBiasTensor, outputGateBias.data()); + + data.m_InputToInputWeights = &inputToInputWeightsTensor; + 
data.m_InputToForgetWeights = &inputToForgetWeightsTensor; + data.m_InputToCellWeights = &inputToCellWeightsTensor; + data.m_InputToOutputWeights = &inputToOutputWeightsTensor; + data.m_RecurrentToInputWeights = &recurrentToInputWeightsTensor; + data.m_RecurrentToForgetWeights = &recurrentToForgetWeightsTensor; + data.m_RecurrentToCellWeights = &recurrentToCellWeightsTensor; + data.m_RecurrentToOutputWeights = &recurrentToOutputWeightsTensor; + data.m_InputGateBias = &inputGateBiasTensor; + data.m_ForgetGateBias = &forgetGateBiasTensor; + data.m_CellBias = &cellBiasTensor; + data.m_OutputGateBias = &outputGateBiasTensor; + + // Flags to set test configuration + data.m_Parameters.m_ClippingThresCell = 10; + data.m_Parameters.m_ClippingThresProj = 0; + data.m_Parameters.m_ActivationFunc = 4; + data.m_Parameters.m_CifgEnabled = false; + data.m_Parameters.m_PeepholeEnabled = false; + data.m_Parameters.m_ProjectionEnabled = false; + data.m_Parameters.m_TimeMajor = false; + + std::unique_ptr workload = workloadFactory.CreateUnidirectionalSequenceLstm(data, info); + inputHandle->Allocate(); + outputStateInHandle->Allocate(); + cellStateInHandle->Allocate(); + + outputHandle->Allocate(); + + CopyDataToITensorHandle(inputHandle.get(), inputVector.data()); + CopyDataToITensorHandle(outputStateInHandle.get(), outputStateInVector.data()); + CopyDataToITensorHandle(cellStateInHandle.get(), cellStateInVector.data()); + + workload->Execute(); + + CopyDataFromITensorHandle(actualOutput.data(), outputHandle.get()); + + return LayerTestResult(actualOutput, + outputVector, + outputHandle->GetShape(), + outputTensorInfo.GetShape()); +} + +template> +LayerTestResult +UnidirectionalSequenceLstmLayerFloat32TimeMajorTestImpl( + armnn::IWorkloadFactory& workloadFactory, + const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, + const armnn::ITensorHandleFactory& tensorHandleFactory, + const std::vector& input, + const std::vector& outputExpected, + const armnn::TensorShape& 
inputShape, + const armnn::TensorShape& outputExpectedShape, + float qScale = 0.0f, + int32_t qOffset = 0, + armnn::DataType constantDataType = armnn::DataType::Float32) { + IgnoreUnused(memoryManager); + unsigned int batchSize = armnn::numeric_cast(inputShape[1]); + unsigned int timeSize = armnn::numeric_cast(inputShape[0]); + unsigned int inputSize = armnn::numeric_cast(inputShape[2]); + unsigned int outputSize = armnn::numeric_cast(outputExpectedShape[2]); + unsigned numUnits = outputSize; + + armnn::TensorInfo inputTensorInfo({timeSize, batchSize, inputSize}, ArmnnType, qScale, qOffset); + armnn::TensorInfo cellStateInTensorInfo({batchSize, numUnits}, ArmnnType, qScale, qOffset); + armnn::TensorInfo outputStateInTensorInfo({batchSize, outputSize}, ArmnnType, qScale, qOffset); + + armnn::TensorInfo outputTensorInfo({timeSize, batchSize, outputSize}, ArmnnType, qScale, qOffset); + + std::vector inputVector; + inputVector.assign(input.data(), input.data() + (batchSize * timeSize * inputSize)); + + std::vector cellStateInVector(batchSize * numUnits, T()); + std::vector outputStateInVector(batchSize * outputSize, T()); + + std::vector actualOutput(outputTensorInfo.GetNumElements()); + + std::vector outputVector; + outputVector.assign(outputExpected.data(), outputExpected.data() + (batchSize * timeSize * outputSize)); + + std::unique_ptr inputHandle = tensorHandleFactory.CreateTensorHandle(inputTensorInfo); + std::unique_ptr cellStateInHandle = + tensorHandleFactory.CreateTensorHandle(cellStateInTensorInfo); + std::unique_ptr outputStateInHandle = + tensorHandleFactory.CreateTensorHandle(outputStateInTensorInfo); + + std::unique_ptr outputHandle = tensorHandleFactory.CreateTensorHandle(outputTensorInfo); + + armnn::UnidirectionalSequenceLstmQueueDescriptor data; + armnn::WorkloadInfo info; + + AddInputToWorkload(data, info, inputTensorInfo, inputHandle.get()); + AddInputToWorkload(data, info, outputStateInTensorInfo, outputStateInHandle.get()); + 
AddInputToWorkload(data, info, cellStateInTensorInfo, cellStateInHandle.get()); + + AddOutputToWorkload(data, info, outputTensorInfo, outputHandle.get()); + + armnn::TensorInfo tensorInfo4({numUnits}, constantDataType, qScale, qOffset); + armnn::TensorInfo tensorInfo12({numUnits, 3}, constantDataType, qScale, qOffset); + armnn::TensorInfo tensorInfo16({numUnits, 4}, constantDataType, qScale, qOffset); + + std::vector inputToInputWeights = { 0.27277296781539917f, 0.3813590407371521f, -0.394489049911499f, + 0.2782636880874634f, -0.3793870210647583f, -0.018918335437774658f, + 0.2724653482437134f, -0.19314253330230713f, -0.2947450876235962f, + -0.30253493785858154f, 0.4241350293159485f, -0.22560018301010132f }; + + std::vector inputToForgetWeights = { -0.2667974531650543f, -0.05505800247192383f, -0.20932340621948242f, + -0.14345619082450867f, 0.09666192531585693f, -0.2604355812072754f, + -0.2681812047958374f, -0.3314584493637085f, 0.4485899806022644f, + -0.23467743396759033f, 0.5072842240333557f, -0.4192768931388855f }; + + std::vector inputToCellWeights = { -0.15782442688941956f, -0.027530014514923096f, 0.4789854884147644f, + 0.23227906227111816f, 0.28259342908859253f, -0.030095696449279785f, + 0.10071521997451782f, -0.08535495400428772f, 0.18563997745513916f, + -0.3049069046974182f, -0.478048175573349f, 0.025234103202819824f }; + + std::vector inputToOutputWeights = { -0.04584759473800659f, -0.2716066539287567f, 0.012970447540283203f, + -0.4729190170764923f, -0.37422770261764526f, 0.49352723360061646f, + 0.3163864016532898f, -0.436781644821167f, -0.33074596524238586f, + -0.32885751128196716f, -0.40959352254867554f, -0.2124689817428589f }; + + std::vector recurrentToInputWeights = { 0.23788475990f, -0.24948765337f, 0.50044941902f, 0.14431896805f, + -0.115940228137f, -0.717082679f, -0.17208620906f, 0.17850610617f, + -0.16702319684f, -0.11384502053f, -0.309785276245f, -0.3316611672f, + 0.52380162477f, -0.06839632987f, -0.391478359627f, -0.10756178963f }; + + std::vector 
recurrentToForgetWeights = { 0.11383482068f, 0.1676601767f, -0.08550968004f, 0.03399394089f, + 0.08042152225f, -0.2133381964f, 0.05182432704f, 0.38161808255f, + -0.5018365979f, -0.08043262364f, 0.07894329014f, -0.07547105155f, + 0.12047368288f, 0.2986997961f, 0.0485043078f, -0.13372567296f }; + + std::vector recurrentToCellWeights = { 0.0433832928545f, 0.07587072294f, -0.120520234107f, 0.604576051f, + -0.434353142986f, 0.009314475068f, 0.005085289478f, 0.08488202038f, + -0.00025437487886f, 0.15245915082f, -0.1936587542f, 0.004754020f, + -0.1582719236f, 0.3307867646f, 0.0236605107784f, 0.307716339826f }; + + std::vector recurrentToOutputWeights = { -0.079031050201f, 0.041414566286f, -0.583727357285f, 0.1025384515f, + -0.172372072937f, 0.09214124082f, 0.178184121827f, -0.2439443916f, + 0.104485116899f, 0.2600405514f, 0.064414866268f, 0.24141204357f, + 0.281875759363f, -0.14234502664f, 0.15126448862f, -0.24421440064f }; + + std::vector inputGateBias = { 0., 0., 0., 0. }; + + std::vector forgetGateBias = { 1., 1., 1., 1. }; + + std::vector cellBias = { 0., 0., 0., 0. }; + + std::vector outputGateBias = { 0., 0., 0., 0. 
}; + + armnn::ScopedTensorHandle inputToInputWeightsTensor(tensorInfo12); + armnn::ScopedTensorHandle inputToForgetWeightsTensor(tensorInfo12); + armnn::ScopedTensorHandle inputToCellWeightsTensor(tensorInfo12); + armnn::ScopedTensorHandle inputToOutputWeightsTensor(tensorInfo12); + armnn::ScopedTensorHandle recurrentToInputWeightsTensor(tensorInfo16); + armnn::ScopedTensorHandle recurrentToForgetWeightsTensor(tensorInfo16); + armnn::ScopedTensorHandle recurrentToCellWeightsTensor(tensorInfo16); + armnn::ScopedTensorHandle recurrentToOutputWeightsTensor(tensorInfo16); + armnn::ScopedTensorHandle inputGateBiasTensor(tensorInfo4); + armnn::ScopedTensorHandle forgetGateBiasTensor(tensorInfo4); + armnn::ScopedTensorHandle cellBiasTensor(tensorInfo4); + armnn::ScopedTensorHandle outputGateBiasTensor(tensorInfo4); + + AllocateAndCopyDataToITensorHandle(&inputToInputWeightsTensor, inputToInputWeights.data()); + AllocateAndCopyDataToITensorHandle(&inputToForgetWeightsTensor, inputToForgetWeights.data()); + AllocateAndCopyDataToITensorHandle(&inputToCellWeightsTensor, inputToCellWeights.data()); + AllocateAndCopyDataToITensorHandle(&inputToOutputWeightsTensor, inputToOutputWeights.data()); + AllocateAndCopyDataToITensorHandle(&recurrentToInputWeightsTensor, recurrentToInputWeights.data()); + AllocateAndCopyDataToITensorHandle(&recurrentToForgetWeightsTensor, recurrentToForgetWeights.data()); + AllocateAndCopyDataToITensorHandle(&recurrentToCellWeightsTensor, recurrentToCellWeights.data()); + AllocateAndCopyDataToITensorHandle(&recurrentToOutputWeightsTensor, recurrentToOutputWeights.data()); + AllocateAndCopyDataToITensorHandle(&inputGateBiasTensor, inputGateBias.data()); + AllocateAndCopyDataToITensorHandle(&forgetGateBiasTensor, forgetGateBias.data()); + AllocateAndCopyDataToITensorHandle(&cellBiasTensor, cellBias.data()); + AllocateAndCopyDataToITensorHandle(&outputGateBiasTensor, outputGateBias.data()); + + data.m_InputToInputWeights = &inputToInputWeightsTensor; + 
data.m_InputToForgetWeights = &inputToForgetWeightsTensor; + data.m_InputToCellWeights = &inputToCellWeightsTensor; + data.m_InputToOutputWeights = &inputToOutputWeightsTensor; + data.m_RecurrentToInputWeights = &recurrentToInputWeightsTensor; + data.m_RecurrentToForgetWeights = &recurrentToForgetWeightsTensor; + data.m_RecurrentToCellWeights = &recurrentToCellWeightsTensor; + data.m_RecurrentToOutputWeights = &recurrentToOutputWeightsTensor; + data.m_InputGateBias = &inputGateBiasTensor; + data.m_ForgetGateBias = &forgetGateBiasTensor; + data.m_CellBias = &cellBiasTensor; + data.m_OutputGateBias = &outputGateBiasTensor; + + // Flags to set test configuration + data.m_Parameters.m_ClippingThresCell = 10; + data.m_Parameters.m_ClippingThresProj = 0; + data.m_Parameters.m_ActivationFunc = 4; + data.m_Parameters.m_CifgEnabled = false; + data.m_Parameters.m_PeepholeEnabled = false; + data.m_Parameters.m_ProjectionEnabled = false; + data.m_Parameters.m_TimeMajor = true; + + std::unique_ptr workload = workloadFactory.CreateUnidirectionalSequenceLstm(data, info); + inputHandle->Allocate(); + outputStateInHandle->Allocate(); + cellStateInHandle->Allocate(); + + outputHandle->Allocate(); + + CopyDataToITensorHandle(inputHandle.get(), inputVector.data()); + CopyDataToITensorHandle(outputStateInHandle.get(), outputStateInVector.data()); + CopyDataToITensorHandle(cellStateInHandle.get(), cellStateInVector.data()); + + workload->Execute(); + + CopyDataFromITensorHandle(actualOutput.data(), outputHandle.get()); + + return LayerTestResult(actualOutput, + outputVector, + outputHandle->GetShape(), + outputTensorInfo.GetShape()); +} + +} // anonymous namespace + +LayerTestResult UnidirectionalSequenceLstmLayerFloat32Test( + armnn::IWorkloadFactory& workloadFactory, + const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, + const armnn::ITensorHandleFactory& tensorHandleFactory) { + armnn::TensorInfo inputInfo({3, 2, 3}, armnn::DataType::Float32); + std::vector input 
= { 1., 2., 3., 4., 5., 4., + 3., 2., 1., 2., 3., 4., + 5., 4., 3., 2., 1., 2. }; + + armnn::TensorInfo outputInfo({3, 2, 4}, armnn::DataType::Float32); + std::vector expectedOutput = { -0.07149004f, -0.1621171f, -0.17516759f, -0.0232934225f, + -0.16810727f, -0.41412935f, -0.5498753f, -0.00803578f, + -0.06687349f, 0.204077631f, -0.4276504f, -0.03123213f, + -0.12000261f, -0.0941918f, -0.45639035f, -0.02870186f, + -0.03429216f, 0.20824050f, -0.6569892f, -0.004152651f, + -0.10493034f, 0.14210969f, -0.58347696f, -0.03297536f }; + return UnidirectionalSequenceLstmLayerFloat32TestImpl( + workloadFactory, memoryManager, tensorHandleFactory, + input, expectedOutput, inputInfo.GetShape(), outputInfo.GetShape()); +} + +LayerTestResult UnidirectionalSequenceLstmLayerFloat32TimeMajorTest( + armnn::IWorkloadFactory& workloadFactory, + const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, + const armnn::ITensorHandleFactory& tensorHandleFactory) { + armnn::TensorInfo inputInfo({2, 3, 3}, armnn::DataType::Float32); + std::vector input = { 1., 2., 3., 4., 5., 4., + 3., 2., 1., 2., 3., 4., + 5., 4., 3., 2., 1., 2. 
}; + + armnn::TensorInfo outputInfo({2, 3, 4}, armnn::DataType::Float32); + std::vector expectedOutput = { 0.135657698f, 0.124672532f, 0.0212090332f, -0.0530203655f, + 0.106138252f, 0.0404792242f, 0.0151643595f, -0.00675163185f, + -0.0128514022f, 0.0644884035f, 0.0709072053f, -0.0454045124f, + 0.16288602f, 0.16649379f, 0.02770456f, -0.03698075f, + 0.11171641f, 0.043119f , 0.0762981f , -0.01228541f, + 0.10439701f, 0.21439962f, 0.11919238f, -0.08390583f }; + return UnidirectionalSequenceLstmLayerFloat32TimeMajorTestImpl( + workloadFactory, memoryManager, tensorHandleFactory, + input, expectedOutput, inputInfo.GetShape(), outputInfo.GetShape()); +} + +LayerTestResult UnidirectionalSequenceLstmLayerNoCifgWithPeepholeWithProjectionTest( + armnn::IWorkloadFactory& workloadFactory, + const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, + const armnn::ITensorHandleFactory& tensorHandleFactory) +{ + IgnoreUnused(memoryManager); + unsigned int batchSize = 2; + unsigned int timeSize = 3; + unsigned int outputSize = 5; + unsigned int inputSize = 4; + unsigned numUnits = 6; + + armnn::TensorInfo inputTensorInfo({batchSize, timeSize, inputSize}, armnn::DataType::Float32); + armnn::TensorInfo cellStateInTensorInfo({batchSize , numUnits}, armnn::DataType::Float32); + armnn::TensorInfo outputStateInTensorInfo({batchSize , outputSize}, armnn::DataType::Float32); + armnn::TensorInfo outputTensorInfo({batchSize, timeSize, outputSize}, armnn::DataType::Float32); + + const std::vector inputVector = { 1., 2., 3., 4., 5., 4., + 3., 2., 1., 2., 3., 4., + 5., 4., 3., 2., 1., 2., + 1., 2., 3., 4., 5., 4.}; + + std::vector cellStateInVector(batchSize * numUnits, 0.f); + std::vector outputStateInVector(batchSize * outputSize, 0.f); + + std::vector actualOutput(outputTensorInfo.GetNumElements()); + + const std::vector expectedOutput = { -0.0135612f, -0.0263441f, 0.0314008f, -0.00883455f, 0.00763052f, + -0.00126877f, -0.0292959f, 0.0449957f, -0.00976195f, -0.00492338f, + 
-0.0175702f, -0.0431753f, 0.0597117f, -0.0169154f, 0.0142087f, + 0.00472515f, -0.0196355f, 0.0342524f, -0.00407936f, -0.0253189f, + -0.00512944f, -0.0293754f, 0.0512771f, -0.0151874f, -0.0246433f, + -0.00744986f, -0.0345103f, 0.0450666f, -0.00944991f, 0.0127171f }; + + std::unique_ptr inputHandle = tensorHandleFactory.CreateTensorHandle(inputTensorInfo); + std::unique_ptr cellStateInHandle = + tensorHandleFactory.CreateTensorHandle(cellStateInTensorInfo); + std::unique_ptr outputStateInHandle = + tensorHandleFactory.CreateTensorHandle(outputStateInTensorInfo); + std::unique_ptr outputHandle = tensorHandleFactory.CreateTensorHandle(outputTensorInfo); + + armnn::UnidirectionalSequenceLstmQueueDescriptor data; + armnn::WorkloadInfo info; + + AddInputToWorkload(data, info, inputTensorInfo, inputHandle.get()); + AddInputToWorkload(data, info, outputStateInTensorInfo, outputStateInHandle.get()); + AddInputToWorkload(data, info, cellStateInTensorInfo, cellStateInHandle.get()); + AddOutputToWorkload(data, info, outputTensorInfo, outputHandle.get()); + + armnn::TensorInfo tensorInfo5({outputSize}, armnn::DataType::Float32); + armnn::TensorInfo tensorInfo6({numUnits}, armnn::DataType::Float32); + armnn::TensorInfo tensorInfo6x4({numUnits, inputSize}, armnn::DataType::Float32); + armnn::TensorInfo tensorInfo6x5({numUnits, outputSize}, armnn::DataType::Float32); + armnn::TensorInfo tensorInfo5x6({outputSize, numUnits}, armnn::DataType::Float32); + + std::vector inputToInputWeights = { 0.021393683f, 0.06124551f, 0.046905167f, -0.014657677f, + -0.03149463f, 0.09171803f, 0.14647801f, 0.10797193f, + -0.0057968358f, 0.0019193048f, -0.2726754f, 0.10154029f, + -0.018539885f, 0.080349885f, -0.10262385f, -0.022599787f, + -0.09121155f, -0.008675967f, -0.045206103f, -0.0821282f, + -0.008045952f, 0.015478081f, 0.055217247f, 0.038719587f }; + + std::vector inputToForgetWeights = { -0.0018401089f, -0.004852237f, 0.03698424f, 0.014181704f, + 0.028273236f, -0.016726194f, -0.05249759f, 
-0.10204261f, + 0.00861066f, -0.040979505f, -0.009899187f, 0.01923892f, + -0.028177269f, -0.08535103f, -0.14585495f, 0.10662567f, + -0.01909731f, -0.017883534f, -0.0047269356f, -0.045103323f, + 0.0030784295f, 0.076784775f, 0.07463696f, 0.094531395f}; + + std::vector inputToCellWeights = { -0.04580283f, -0.09549462f, -0.032418985f, -0.06454633f, + -0.043528453f, 0.043018587f, -0.049152344f, -0.12418144f, + -0.078985475f, -0.07596889f, 0.019484362f, -0.11434962f, + -0.0074034138f, -0.06314844f, -0.092981495f, 0.0062155537f, + -0.025034338f, -0.0028890965f, 0.048929527f, 0.06235075f, + 0.10665918f, -0.032036792f, -0.08505916f, -0.10843358f }; + + std::vector inputToOutputWeights = { -0.0998932f, -0.07201956f, -0.052803773f, -0.15629593f, + -0.15001918f, -0.07650751f, 0.02359855f, -0.075155355f, + -0.08037709f, -0.15093534f, 0.029517552f, -0.04751393f, + 0.010350531f, -0.02664851f, -0.016839722f, -0.023121163f, + 0.0077019283f, 0.012851257f, -0.05040649f, -0.0129761f, + -0.021737747f, -0.038305793f, -0.06870586f, -0.01481247f }; + + std::vector inputGateBias = { 0.02234832f, 0.14757581f, 0.18176508f, + 0.10380666f, 0.053110216f, -0.06928846f }; + + std::vector forgetGateBias = { 0.035185695f, -0.042891346f, -0.03032477f, + 0.23027696f, 0.11098921f, 0.08989442f }; + + std::vector cellBias = { -0.024379363f, 0.0055531194f, 0.23377132f, + 0.033463873f, -0.1483596f, 0.029460307f }; + + std::vector outputGateBias = { 0.046159424f, -0.0012809046f, 0.03563469f, + 0.12648113f, 0.027195795f, 0.35373217f }; + + std::vector recurrentToInputWeights = { -0.001374326f, -0.078856036f, 0.10672688f, 0.029162422f, + -0.11585556f, 0.02557986f, -0.13446963f, -0.035785314f, + -0.01244275f, 0.025961924f, -0.02337298f, -0.044228926f, + -0.055839065f, -0.046598054f, -0.010546039f, -0.06900766f, + 0.027239809f, 0.022582639f, -0.013296484f, -0.05459212f, + 0.08981f, -0.045407712f, 0.08682226f, -0.06867011f, + -0.14390695f, -0.02916037f, 0.000996957f, 0.091420636f, + 0.14283475f, -0.07390571f }; 
+ + std::vector recurrentToCellWeights = { -0.037322544f, 0.018592842f, 0.0056175636f, -0.06253426f, + 0.055647098f, -0.05713207f, -0.05626563f, 0.005559383f, + 0.03375411f, -0.025757805f, -0.088049285f, 0.06017052f, + -0.06570978f, 0.007384076f, 0.035123326f, -0.07920549f, + 0.053676967f, 0.044480428f, -0.07663568f, 0.0071805613f, + 0.08089997f, 0.05143358f, 0.038261272f, 0.03339287f, + -0.027673481f, 0.044746667f, 0.028349208f, 0.020090483f, + -0.019443132f, -0.030755889f }; + + std::vector recurrentToForgetWeights = { -0.057784554f, -0.026057621f, -0.068447545f, -0.022581743f, + 0.14811787f, 0.10826372f, 0.09471067f, 0.03987225f, + -0.0039523416f, 0.00030638507f, 0.053185795f, 0.10572994f, + 0.08414449f, -0.022036452f, -0.00066928595f, -0.09203576f, + 0.032950465f, -0.10985798f, -0.023809856f, 0.0021431844f, + -0.02196096f, -0.00326074f, 0.00058621005f, -0.074678116f, + -0.06193199f, 0.055729095f, 0.03736828f, 0.020123724f, + 0.061878487f, -0.04729229f }; + + std::vector recurrentToOutputWeights = { 0.025825322f, -0.05813119f, 0.09495884f, + -0.045984812f,-0.01255415f, -0.0026479573f, + -0.08196161f, -0.054914974f, -0.0046604523f, + -0.029587349f, -0.044576716f, -0.07480124f, + -0.082868785f, 0.023254942f, 0.027502948f, + -0.0039728214f, -0.08683098f, -0.08116779f, + -0.014675607f, -0.037924774f, -0.023314456f, + -0.007401714f, -0.09255757f, 0.029460307f, + -0.08829125f, -0.005139627f, -0.08989442f, + -0.0555066f, 0.13596267f, 0.025062224f }; + + std::vector cellToInputWeights = { 0.040369894f, 0.030746894f, 0.24704495f, + 0.018586371f, -0.037586458f, -0.15312155f }; + + std::vector cellToForgetWeights = { -0.01998659f, -0.15568835f, -0.24248174f, + -0.012770197f, 0.041331276f, -0.072311886f }; + + std::vector cellToOutputWeights = { 0.08286371f, -0.08261836f, -0.51210177f, + 0.002913762f, 0.17764764f, -0.5495371f }; + + std::vector projectionWeights = { -0.009802181f, 0.09401916f, 0.0717386f, -0.13895074f, 0.09641832f, + 0.060420845f, 0.08539281f, 0.054285463f, 
0.061395317f, 0.034448683f, + -0.042991187f, 0.019801661f, -0.16840284f, -0.015726732f, -0.23041931f, + -0.024478018f, -0.10959692f, -0.013875541f, 0.18600968f, -0.061274476f, + 0.0138165f, -0.08160894f, -0.07661644f, 0.032372914f, 0.16169067f, + 0.22465782f, -0.03993472f, -0.004017731f, 0.08633481f, -0.28869787f }; + + std::vector projectionBiasVector(outputSize, 0.f); //{outputSize} + + armnn::ScopedTensorHandle inputToInputWeightsTensor(tensorInfo6x4); + armnn::ScopedTensorHandle inputToForgetWeightsTensor(tensorInfo6x4); + armnn::ScopedTensorHandle inputToCellWeightsTensor(tensorInfo6x4); + armnn::ScopedTensorHandle inputToOutputWeightsTensor(tensorInfo6x4); + armnn::ScopedTensorHandle recurrentToForgetWeightsTensor(tensorInfo6x5); + armnn::ScopedTensorHandle recurrentToInputWeightsTensor(tensorInfo6x5); + armnn::ScopedTensorHandle recurrentToCellWeightsTensor(tensorInfo6x5); + armnn::ScopedTensorHandle recurrentToOutputWeightsTensor(tensorInfo6x5); + armnn::ScopedTensorHandle cellToInputWeightsTensor(tensorInfo6); + armnn::ScopedTensorHandle inputGateBiasTensor(tensorInfo6); + armnn::ScopedTensorHandle forgetGateBiasTensor(tensorInfo6); + armnn::ScopedTensorHandle cellBiasTensor(tensorInfo6); + armnn::ScopedTensorHandle outputGateBiasTensor(tensorInfo6); + armnn::ScopedTensorHandle cellToForgetWeightsTensor(tensorInfo6); + armnn::ScopedTensorHandle cellToOutputWeightsTensor(tensorInfo6); + armnn::ScopedTensorHandle projectionWeightsTensor(tensorInfo5x6); + armnn::ScopedTensorHandle projectionBiasTensor(tensorInfo5); + + AllocateAndCopyDataToITensorHandle(&inputToInputWeightsTensor, inputToInputWeights.data()); + AllocateAndCopyDataToITensorHandle(&inputToForgetWeightsTensor, inputToForgetWeights.data()); + AllocateAndCopyDataToITensorHandle(&inputToCellWeightsTensor, inputToCellWeights.data()); + AllocateAndCopyDataToITensorHandle(&inputToOutputWeightsTensor, inputToOutputWeights.data()); + AllocateAndCopyDataToITensorHandle(&recurrentToInputWeightsTensor, 
recurrentToInputWeights.data()); + AllocateAndCopyDataToITensorHandle(&recurrentToForgetWeightsTensor, recurrentToForgetWeights.data()); + AllocateAndCopyDataToITensorHandle(&recurrentToCellWeightsTensor, recurrentToCellWeights.data()); + AllocateAndCopyDataToITensorHandle(&recurrentToOutputWeightsTensor, recurrentToOutputWeights.data()); + AllocateAndCopyDataToITensorHandle(&cellToInputWeightsTensor, cellToInputWeights.data()); + AllocateAndCopyDataToITensorHandle(&inputGateBiasTensor, inputGateBias.data()); + AllocateAndCopyDataToITensorHandle(&forgetGateBiasTensor, forgetGateBias.data()); + AllocateAndCopyDataToITensorHandle(&cellBiasTensor, cellBias.data()); + AllocateAndCopyDataToITensorHandle(&outputGateBiasTensor, outputGateBias.data()); + AllocateAndCopyDataToITensorHandle(&cellToForgetWeightsTensor, cellToForgetWeights.data()); + AllocateAndCopyDataToITensorHandle(&cellToOutputWeightsTensor, cellToOutputWeights.data()); + AllocateAndCopyDataToITensorHandle(&projectionWeightsTensor, projectionWeights.data()); + AllocateAndCopyDataToITensorHandle(&projectionBiasTensor, projectionBiasVector.data()); + + data.m_InputToInputWeights = &inputToInputWeightsTensor; + data.m_InputToForgetWeights = &inputToForgetWeightsTensor; + data.m_InputToCellWeights = &inputToCellWeightsTensor; + data.m_InputToOutputWeights = &inputToOutputWeightsTensor; + data.m_RecurrentToInputWeights = &recurrentToInputWeightsTensor; + data.m_RecurrentToForgetWeights = &recurrentToForgetWeightsTensor; + data.m_RecurrentToCellWeights = &recurrentToCellWeightsTensor; + data.m_RecurrentToOutputWeights = &recurrentToOutputWeightsTensor; + data.m_CellToInputWeights = &cellToInputWeightsTensor; + data.m_InputGateBias = &inputGateBiasTensor; + data.m_ForgetGateBias = &forgetGateBiasTensor; + data.m_CellBias = &cellBiasTensor; + data.m_OutputGateBias = &outputGateBiasTensor; + data.m_CellToForgetWeights = &cellToForgetWeightsTensor; + data.m_CellToOutputWeights = &cellToOutputWeightsTensor; + 
data.m_ProjectionWeights = &projectionWeightsTensor; + data.m_ProjectionBias = &projectionBiasTensor; + + // Flags to set test configuration + data.m_Parameters.m_ActivationFunc = 4; + data.m_Parameters.m_CifgEnabled = false; + data.m_Parameters.m_PeepholeEnabled = true; + data.m_Parameters.m_ProjectionEnabled = true; + data.m_Parameters.m_LayerNormEnabled = false; + data.m_Parameters.m_TimeMajor = false; + data.m_Parameters.m_ClippingThresCell = 10.0f; + + + std::unique_ptr workload = workloadFactory.CreateUnidirectionalSequenceLstm(data, info); + inputHandle->Allocate(); + outputStateInHandle->Allocate(); + cellStateInHandle->Allocate(); + outputHandle->Allocate(); + + CopyDataToITensorHandle(inputHandle.get(), inputVector.data()); + CopyDataToITensorHandle(outputStateInHandle.get(), outputStateInVector.data()); + CopyDataToITensorHandle(cellStateInHandle.get(), cellStateInVector.data()); + + workload->Execute(); + + CopyDataFromITensorHandle(actualOutput.data(), outputHandle.get()); + + return LayerTestResult(actualOutput, + expectedOutput, + outputHandle->GetShape(), + outputTensorInfo.GetShape()); +} + +LayerTestResult UnidirectionalSequenceLstmLayerNoCifgWithPeepholeWithProjectionWithLayerNormTest( + armnn::IWorkloadFactory& workloadFactory, + const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, + const armnn::ITensorHandleFactory& tensorHandleFactory) +{ + IgnoreUnused(memoryManager); + unsigned int batchSize = 3; + unsigned int timeSize = 2; + unsigned int outputSize = 4; + unsigned int inputSize = 3; + unsigned numUnits = 5; + + armnn::TensorInfo inputTensorInfo({batchSize, timeSize, inputSize}, armnn::DataType::Float32); + armnn::TensorInfo cellStateInTensorInfo({batchSize , numUnits}, armnn::DataType::Float32); + armnn::TensorInfo outputStateInTensorInfo({batchSize , outputSize}, armnn::DataType::Float32); + armnn::TensorInfo outputTensorInfo({batchSize, timeSize, outputSize}, armnn::DataType::Float32); + + const std::vector inputVector 
= { 1., 2., 3., 4., 5., 4., + 3., 2., 1., 2., 3., 4., + 5., 4., 3., 2., 1., 2. }; + + std::vector cellStateInVector(batchSize * numUnits, 0.f); + std::vector outputStateInVector(batchSize * outputSize, 0.f); + + std::vector actualOutput(outputTensorInfo.GetNumElements()); + + const std::vector expectedOutput = { 0.0642256f, 0.0343966f, 0.184122f, 0.114717f, + 0.11458f, 0.0407109f, 0.300327f, 0.174301f, + 0.0864761f, 0.0362912f, 0.178635f, 0.115689f, + 0.108008f, 0.0386623f, 0.273471f, 0.167115f, + 0.0859545f, 0.0331481f, 0.186051f, 0.11888f, + 0.106649f, 0.0276847f, 0.229863f, 0.166958f }; + + std::unique_ptr inputHandle = tensorHandleFactory.CreateTensorHandle(inputTensorInfo); + std::unique_ptr cellStateInHandle = + tensorHandleFactory.CreateTensorHandle(cellStateInTensorInfo); + std::unique_ptr outputStateInHandle = + tensorHandleFactory.CreateTensorHandle(outputStateInTensorInfo); + + std::unique_ptr outputHandle = tensorHandleFactory.CreateTensorHandle(outputTensorInfo); + + armnn::UnidirectionalSequenceLstmQueueDescriptor data; + armnn::WorkloadInfo info; + + AddInputToWorkload(data, info, inputTensorInfo, inputHandle.get()); + AddInputToWorkload(data, info, outputStateInTensorInfo, outputStateInHandle.get()); + AddInputToWorkload(data, info, cellStateInTensorInfo, cellStateInHandle.get()); + + AddOutputToWorkload(data, info, outputTensorInfo, outputHandle.get()); + + armnn::TensorInfo tensorInfo4({outputSize}, armnn::DataType::Float32); + armnn::TensorInfo tensorInfo5({numUnits}, armnn::DataType::Float32); + armnn::TensorInfo tensorInfo5x3({numUnits, inputSize}, armnn::DataType::Float32); + armnn::TensorInfo tensorInfo5x4({numUnits, outputSize}, armnn::DataType::Float32); + armnn::TensorInfo tensorInfo4x5({outputSize, numUnits}, armnn::DataType::Float32); + + std::vector inputToInputWeights = { -0.49536117f, -0.0556083915f, -0.102400711f, + -0.117484632f, 0.3298470976f, -0.1179017122f, + 0.214305695f, 0.42135173085f, 0.003878414626f, + -0.348303917f, 
-0.1881275477f, 0.0343011027f, + -0.38837709614f, -0.05636804124f, 0.4259087456f}; + + std::vector inputToForgetWeights = { 0.2415594226f, 0.15400093799f, 0.4566498398f, + -0.3810434485f, 0.268383264f, -0.009807467424f, + -0.3522925403f, -0.24275735512f, -0.28344226125f, + 0.13512269116f, -0.4932442977f, -0.10039821991f, + 0.2726137042f, 0.09216640889f, -0.06551410215f}; + + std::vector inputToCellWeights = { -0.2504855627f, 0.184490025045f, -0.2480507493f, + 0.386399507f, -0.259465157985f, -0.16545993089f, + -0.4230232555f, 0.341664791103f, -0.18127849691f, + -0.2277662414f, -0.55275535589f, 0.34184026718f, + 0.3954237699f, -0.19407111404f, 0.30412107706f}; + + std::vector inputToOutputWeights = { 0.2303854227f, 0.5218806862f, -0.4865379333f, + 0.53969591851f, 0.23393625035f, -0.27140527306f, + 0.50009280443f, 0.07511717046f, 0.3998299249f, + -0.51717478049f, 0.1889653282f, -0.367323637f, + -0.12584099173f, -0.12319286912f, 0.2407919466f}; + + std::vector inputGateBias{ 0.03f, 0.15f, 0.22f, 0.38f, 0.05f }; + std::vector forgetGateBias{ 0.1f, -0.3f, -0.2f, 0.1f, 0.4f }; + std::vector cellBias{ -0.05f, 0.72f, 0.25f, 0.08f, 0.1f }; + std::vector outputGateBias{ 0.05f, -0.01f, 0.2f, 0.1f, -0.2f }; + + std::vector recurrentToInputWeights = { -0.128009796112f, 0.1995525098f, -0.07745539397f, 0.1558421701f, + -0.265254765766f, -0.38837709614f, -0.05636804124f, 0.4259087456f, + 0.17628988623f, 0.3877420127f, 0.53300309181f, -0.0959980934f, + 0.00302857416f, 0.3266998827f, -0.142509296562f, -0.04433270756f, + 0.54066205f, -0.32668582f, -0.43562764f, -0.56094903f }; + + std::vector recurrentToForgetWeights = { -0.09499983487f, -0.08814888417f, -0.04834804721f, 0.1516668247f, + -0.3967529535f, -0.06463699788f, 0.4952811002f, 0.003274492938f, + -0.0968840941f, 0.17928104102f, 0.0031281141592f, -0.3387276584f, + -0.3587934076f, 0.06705895066f, 0.22463923692f, 0.1961955726f, + 0.01841056f, -0.32764608f, -0.33027974f, -0.10826075f }; + + std::vector recurrentToCellWeights = { 
-0.21938985582f, -0.3023648226f, -0.1170005202f, -0.3509177422f, + -0.4286288613f, 0.2726137042f, 0.09216640889f, -0.06551410215f, + 0.20453298098f, 0.2393476665f, 0.11846517771f, 0.2630801796f, + 0.3954237699f, -0.19407111404f, 0.30412107706f, -0.27342408554f, + 0.19069612f, -0.03026325f, -0.54532051f, 0.33003211f }; + + std::vector recurrentToOutputWeights = { -0.32921677827f, 0.32624614238f, -0.1388191282f, -0.17879831790f, + -0.15185534954f, -0.16918526583f, -0.10087361183f, -0.5436913968f, + 0.016758225858f, 0.30454617738f, -0.41493862867f, -0.005565764375f, + -0.12584099173f, -0.12319286912f, 0.2407919466f, -0.08879069983f, + 0.11178309f, 0.09481031f, -0.26424935f, 0.46261835f }; + + std::vector cellToInputWeights { 0.05f, 0.1f, 0.25f, 0.15f, -0.02f }; + std::vector cellToForgetWeights { -0.02f, -0.15f, -0.25f, -0.03f, 0.15f }; + std::vector cellToOutputWeights { 0.1f, -0.1f, -0.5f, 0.05f, 0.01f }; + + std::vector projectionWeights{ -0.1f, 0.2f, 0.01f, -0.2f, + 0.1f, 0.5f, 0.3f, 0.08f, + 0.07f, 0.2f, -0.4f, 0.2f, + 0.5f, -0.4f, 0.3f, -0.2f, + 0.3f, 0.08f, -0.07f, 0.2f}; + + std::vector projectionBiasVector(outputSize, 0.f); //{outputSize} + + std::vector inputLayerNormWeights{ 0.1f, 0.2f, 0.3f, 0.5f, 0.8f }; + std::vector forgetLayerNormWeights{ 0.1f, 0.2f, 0.3f, 0.5f, 0.2f }; + std::vector cellLayerNormWeights{ 0.7f, 0.2f, 0.3f, 0.8f, 0.5f }; + std::vector outputLayerNormWeights{ 0.6f, 0.2f, 0.2f, 0.5f, 0.1f }; + + armnn::ScopedTensorHandle inputToInputWeightsTensor(tensorInfo5x3); + armnn::ScopedTensorHandle inputToForgetWeightsTensor(tensorInfo5x3); + armnn::ScopedTensorHandle inputToCellWeightsTensor(tensorInfo5x3); + armnn::ScopedTensorHandle inputToOutputWeightsTensor(tensorInfo5x3); + armnn::ScopedTensorHandle recurrentToForgetWeightsTensor(tensorInfo5x4); + armnn::ScopedTensorHandle recurrentToInputWeightsTensor(tensorInfo5x4); + armnn::ScopedTensorHandle recurrentToCellWeightsTensor(tensorInfo5x4); + armnn::ScopedTensorHandle 
recurrentToOutputWeightsTensor(tensorInfo5x4); + armnn::ScopedTensorHandle cellToInputWeightsTensor(tensorInfo5); + armnn::ScopedTensorHandle inputGateBiasTensor(tensorInfo5); + armnn::ScopedTensorHandle forgetGateBiasTensor(tensorInfo5); + armnn::ScopedTensorHandle cellBiasTensor(tensorInfo5); + armnn::ScopedTensorHandle outputGateBiasTensor(tensorInfo5); + armnn::ScopedTensorHandle cellToForgetWeightsTensor(tensorInfo5); + armnn::ScopedTensorHandle cellToOutputWeightsTensor(tensorInfo5); + armnn::ScopedTensorHandle projectionWeightsTensor(tensorInfo4x5); + armnn::ScopedTensorHandle projectionBiasTensor(tensorInfo4); + + armnn::ScopedTensorHandle inputLayerNormWeightsTensor(tensorInfo5); + armnn::ScopedTensorHandle forgetLayerNormWeightsTensor(tensorInfo5); + armnn::ScopedTensorHandle cellLayerNormWeightsTensor(tensorInfo5); + armnn::ScopedTensorHandle outputLayerNormWeightsTensor(tensorInfo5); + + AllocateAndCopyDataToITensorHandle(&inputToInputWeightsTensor, inputToInputWeights.data()); + AllocateAndCopyDataToITensorHandle(&inputToForgetWeightsTensor, inputToForgetWeights.data()); + AllocateAndCopyDataToITensorHandle(&inputToCellWeightsTensor, inputToCellWeights.data()); + AllocateAndCopyDataToITensorHandle(&inputToOutputWeightsTensor, inputToOutputWeights.data()); + AllocateAndCopyDataToITensorHandle(&recurrentToInputWeightsTensor, recurrentToInputWeights.data()); + AllocateAndCopyDataToITensorHandle(&recurrentToForgetWeightsTensor, recurrentToForgetWeights.data()); + AllocateAndCopyDataToITensorHandle(&recurrentToCellWeightsTensor, recurrentToCellWeights.data()); + AllocateAndCopyDataToITensorHandle(&recurrentToOutputWeightsTensor, recurrentToOutputWeights.data()); + AllocateAndCopyDataToITensorHandle(&cellToInputWeightsTensor, cellToInputWeights.data()); + AllocateAndCopyDataToITensorHandle(&inputGateBiasTensor, inputGateBias.data()); + AllocateAndCopyDataToITensorHandle(&forgetGateBiasTensor, forgetGateBias.data()); + 
AllocateAndCopyDataToITensorHandle(&cellBiasTensor, cellBias.data()); + AllocateAndCopyDataToITensorHandle(&outputGateBiasTensor, outputGateBias.data()); + AllocateAndCopyDataToITensorHandle(&cellToForgetWeightsTensor, cellToForgetWeights.data()); + AllocateAndCopyDataToITensorHandle(&cellToOutputWeightsTensor, cellToOutputWeights.data()); + AllocateAndCopyDataToITensorHandle(&projectionWeightsTensor, projectionWeights.data()); + AllocateAndCopyDataToITensorHandle(&projectionBiasTensor, projectionBiasVector.data()); + + AllocateAndCopyDataToITensorHandle(&inputLayerNormWeightsTensor, inputLayerNormWeights.data()); + AllocateAndCopyDataToITensorHandle(&forgetLayerNormWeightsTensor, forgetLayerNormWeights.data()); + AllocateAndCopyDataToITensorHandle(&cellLayerNormWeightsTensor, cellLayerNormWeights.data()); + AllocateAndCopyDataToITensorHandle(&outputLayerNormWeightsTensor, outputLayerNormWeights.data()); + + data.m_InputToInputWeights = &inputToInputWeightsTensor; + data.m_InputToForgetWeights = &inputToForgetWeightsTensor; + data.m_InputToCellWeights = &inputToCellWeightsTensor; + data.m_InputToOutputWeights = &inputToOutputWeightsTensor; + data.m_RecurrentToInputWeights = &recurrentToInputWeightsTensor; + data.m_RecurrentToForgetWeights = &recurrentToForgetWeightsTensor; + data.m_RecurrentToCellWeights = &recurrentToCellWeightsTensor; + data.m_RecurrentToOutputWeights = &recurrentToOutputWeightsTensor; + data.m_CellToInputWeights = &cellToInputWeightsTensor; + data.m_InputGateBias = &inputGateBiasTensor; + data.m_ForgetGateBias = &forgetGateBiasTensor; + data.m_CellBias = &cellBiasTensor; + data.m_OutputGateBias = &outputGateBiasTensor; + data.m_CellToForgetWeights = &cellToForgetWeightsTensor; + data.m_CellToOutputWeights = &cellToOutputWeightsTensor; + data.m_ProjectionWeights = &projectionWeightsTensor; + data.m_ProjectionBias = &projectionBiasTensor; + + data.m_InputLayerNormWeights = &inputLayerNormWeightsTensor; + data.m_ForgetLayerNormWeights = 
&forgetLayerNormWeightsTensor; + data.m_CellLayerNormWeights = &cellLayerNormWeightsTensor; + data.m_OutputLayerNormWeights = &outputLayerNormWeightsTensor; + + // Flags to set test configuration + data.m_Parameters.m_ActivationFunc = 4; + data.m_Parameters.m_CifgEnabled = false; + data.m_Parameters.m_PeepholeEnabled = true; + data.m_Parameters.m_ProjectionEnabled = true; + data.m_Parameters.m_LayerNormEnabled = true; + data.m_Parameters.m_TimeMajor = false; + data.m_Parameters.m_ClippingThresCell = 10.0f; + + std::unique_ptr workload = workloadFactory.CreateUnidirectionalSequenceLstm(data, info); + inputHandle->Allocate(); + outputStateInHandle->Allocate(); + cellStateInHandle->Allocate(); + outputHandle->Allocate(); + + CopyDataToITensorHandle(inputHandle.get(), inputVector.data()); + CopyDataToITensorHandle(outputStateInHandle.get(), outputStateInVector.data()); + CopyDataToITensorHandle(cellStateInHandle.get(), cellStateInVector.data()); + + workload->Execute(); + + CopyDataFromITensorHandle(actualOutput.data(), outputHandle.get()); + + return LayerTestResult(actualOutput, + expectedOutput, + outputHandle->GetShape(), + outputTensorInfo.GetShape()); +} + +LayerTestResult UnidirectionalSequenceLstmWithCifgWithPeepholeNoProjectionTest( + armnn::IWorkloadFactory& workloadFactory, + const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, + const armnn::ITensorHandleFactory& tensorHandleFactory) +{ + IgnoreUnused(memoryManager); + unsigned int batchSize = 3; + unsigned int timeSize = 2; + unsigned int inputSize = 3; + unsigned int outputSize = 4; + unsigned numUnits = outputSize; + + armnn::TensorInfo inputTensorInfo({batchSize, timeSize, inputSize}, armnn::DataType::Float32); + armnn::TensorInfo cellStateInTensorInfo({batchSize, numUnits}, armnn::DataType::Float32); + armnn::TensorInfo outputStateInTensorInfo({batchSize, outputSize}, armnn::DataType::Float32); + + armnn::TensorInfo outputTensorInfo({batchSize, timeSize, outputSize}, 
armnn::DataType::Float32); + + std::vector inputVector = { 1., 2., 3., 4., 5., 4., + 3., 2., 1., 2., 3., 4., + 5., 4., 3., 2., 1., 2. }; + + std::vector cellStateInVector(batchSize * numUnits, 0.f); + std::vector outputStateInVector(batchSize * outputSize, 0.f); + + std::vector actualOutput(outputTensorInfo.GetNumElements()); + + std::vector outputVector = { -0.0129257f, -0.070531f, -0.153508f, -0.0392391f, + -0.0300169f, -0.195717f, -0.528679f, -0.0818106f, + -0.0332748f, 0.155429f, -0.353966f, -0.0801505f, + -0.032312f, -0.0407911f, -0.435053f, -0.0932317f, + -0.0108233f, 0.165584f, -0.640424f, -0.0447535f, + -0.031675f, 0.125987f, -0.526695f, -0.110093f }; + + std::unique_ptr inputHandle = tensorHandleFactory.CreateTensorHandle(inputTensorInfo); + std::unique_ptr cellStateInHandle = + tensorHandleFactory.CreateTensorHandle(cellStateInTensorInfo); + std::unique_ptr outputStateInHandle = + tensorHandleFactory.CreateTensorHandle(outputStateInTensorInfo); + + std::unique_ptr outputHandle = tensorHandleFactory.CreateTensorHandle(outputTensorInfo); + + armnn::UnidirectionalSequenceLstmQueueDescriptor data; + armnn::WorkloadInfo info; + + AddInputToWorkload(data, info, inputTensorInfo, inputHandle.get()); + AddInputToWorkload(data, info, outputStateInTensorInfo, outputStateInHandle.get()); + AddInputToWorkload(data, info, cellStateInTensorInfo, cellStateInHandle.get()); + + AddOutputToWorkload(data, info, outputTensorInfo, outputHandle.get()); + + armnn::TensorInfo tensorInfo4({numUnits}, armnn::DataType::Float32); + armnn::TensorInfo tensorInfo12({numUnits, 3}, armnn::DataType::Float32); + armnn::TensorInfo tensorInfo16({numUnits, 4}, armnn::DataType::Float32); + + std::vector inputToForgetWeights = { 0.2415594226f, 0.15400093799f, 0.4566498398f, + -0.3810434485f, 0.268383264f, -0.009807467424f, + -0.3522925403f, -0.24275735512f, -0.28344226125f, + 0.13512269116f, -0.4932442977f, -0.10039821991f }; + + std::vector inputToCellWeights = { -0.2504855627f, 
0.184490025045f, -0.2480507493f, + 0.386399507f, -0.259465157985f, -0.16545993089f, + -0.4230232555f, 0.341664791103f, -0.18127849691f, + -0.2277662414f, -0.55275535589f, 0.34184026718f }; + + std::vector inputToOutputWeights = { 0.2303854227f, 0.5218806862f, -0.4865379333f, + 0.53969591851f, 0.23393625035f, -0.27140527306f, + 0.50009280443f, 0.07511717046f, 0.3998299249f, + -0.51717478049f, 0.1889653282f, -0.367323637f }; + + std::vector recurrentToForgetWeights = { -0.09499983487f, -0.08814888417f, -0.04834804721f, 0.1516668247f, + -0.3967529535f, -0.06463699788f, 0.4952811002f, 0.003274492938f, + -0.0968840941f, 0.17928104102f, 0.0031281141592f, -0.3387276584f, + -0.3587934076f, 0.06705895066f, 0.22463923692f, 0.1961955726f }; + + std::vector recurrentToCellWeights = { -0.21938985582f, -0.3023648226f, -0.1170005202f, -0.3509177422f, + -0.4286288613f, 0.2726137042f, 0.09216640889f, -0.06551410215f, + 0.20453298098f, 0.2393476665f, 0.11846517771f, 0.2630801796f, + 0.3954237699f, -0.19407111404f, 0.30412107706f, -0.27342408554f }; + + std::vector recurrentToOutputWeights = { -0.32921677827f, 0.32624614238f, -0.1388191282f, -0.17879831790f, + -0.15185534954f, -0.16918526583f, -0.10087361183f, -0.5436913968f, + 0.016758225858f, 0.30454617738f, -0.41493862867f, -0.005565764375f, + -0.12584099173f, -0.12319286912f, 0.2407919466f, -0.08879069983f }; + + std::vector cellToForgetWeights{ 0.47485286f, -0.51955009f, -0.24458408f, 0.31544167f }; + + std::vector cellToOutputWeights{ -0.17135078f, 0.82760304f, 0.85573703f, -0.77109635f }; + + std::vector forgetGateBias = { 1., 1., 1., 1. }; + + std::vector cellBias = { 0., 0., 0., 0. }; + + std::vector outputGateBias = { 0., 0., 0., 0. 
}; + + armnn::ScopedTensorHandle inputToForgetWeightsTensor(tensorInfo12); + armnn::ScopedTensorHandle inputToCellWeightsTensor(tensorInfo12); + armnn::ScopedTensorHandle inputToOutputWeightsTensor(tensorInfo12); + armnn::ScopedTensorHandle recurrentToForgetWeightsTensor(tensorInfo16); + armnn::ScopedTensorHandle recurrentToCellWeightsTensor(tensorInfo16); + armnn::ScopedTensorHandle recurrentToOutputWeightsTensor(tensorInfo16); + armnn::ScopedTensorHandle cellToForgetWeightsTensor(tensorInfo4); + armnn::ScopedTensorHandle cellToOutputWeightsTensor(tensorInfo4); + armnn::ScopedTensorHandle forgetGateBiasTensor(tensorInfo4); + armnn::ScopedTensorHandle cellBiasTensor(tensorInfo4); + armnn::ScopedTensorHandle outputGateBiasTensor(tensorInfo4); + + AllocateAndCopyDataToITensorHandle(&inputToForgetWeightsTensor, inputToForgetWeights.data()); + AllocateAndCopyDataToITensorHandle(&inputToCellWeightsTensor, inputToCellWeights.data()); + AllocateAndCopyDataToITensorHandle(&inputToOutputWeightsTensor, inputToOutputWeights.data()); + AllocateAndCopyDataToITensorHandle(&recurrentToForgetWeightsTensor, recurrentToForgetWeights.data()); + AllocateAndCopyDataToITensorHandle(&recurrentToCellWeightsTensor, recurrentToCellWeights.data()); + AllocateAndCopyDataToITensorHandle(&recurrentToOutputWeightsTensor, recurrentToOutputWeights.data()); + AllocateAndCopyDataToITensorHandle(&cellToForgetWeightsTensor, cellToForgetWeights.data()); + AllocateAndCopyDataToITensorHandle(&cellToOutputWeightsTensor, cellToOutputWeights.data()); + AllocateAndCopyDataToITensorHandle(&forgetGateBiasTensor, forgetGateBias.data()); + AllocateAndCopyDataToITensorHandle(&cellBiasTensor, cellBias.data()); + AllocateAndCopyDataToITensorHandle(&outputGateBiasTensor, outputGateBias.data()); + + data.m_InputToForgetWeights = &inputToForgetWeightsTensor; + data.m_InputToCellWeights = &inputToCellWeightsTensor; + data.m_InputToOutputWeights = &inputToOutputWeightsTensor; + data.m_RecurrentToForgetWeights = 
&recurrentToForgetWeightsTensor; + data.m_RecurrentToCellWeights = &recurrentToCellWeightsTensor; + data.m_RecurrentToOutputWeights = &recurrentToOutputWeightsTensor; + data.m_CellToForgetWeights = &cellToForgetWeightsTensor; + data.m_CellToOutputWeights = &cellToOutputWeightsTensor; + data.m_ForgetGateBias = &forgetGateBiasTensor; + data.m_CellBias = &cellBiasTensor; + data.m_OutputGateBias = &outputGateBiasTensor; + + // Flags to set test configuration + data.m_Parameters.m_ClippingThresCell = 10; + data.m_Parameters.m_ClippingThresProj = 0; + data.m_Parameters.m_ActivationFunc = 4; + data.m_Parameters.m_CifgEnabled = true; + data.m_Parameters.m_PeepholeEnabled = true; + data.m_Parameters.m_ProjectionEnabled = false; + data.m_Parameters.m_TimeMajor = false; + + std::unique_ptr workload = workloadFactory.CreateUnidirectionalSequenceLstm(data, info); + inputHandle->Allocate(); + outputStateInHandle->Allocate(); + cellStateInHandle->Allocate(); + + outputHandle->Allocate(); + + CopyDataToITensorHandle(inputHandle.get(), inputVector.data()); + CopyDataToITensorHandle(outputStateInHandle.get(), outputStateInVector.data()); + CopyDataToITensorHandle(cellStateInHandle.get(), cellStateInVector.data()); + + workload->Execute(); + + CopyDataFromITensorHandle(actualOutput.data(), outputHandle.get()); + + return LayerTestResult(actualOutput, + outputVector, + outputHandle->GetShape(), + outputTensorInfo.GetShape()); +} diff --git a/src/backends/backendsCommon/test/layerTests/UnidirectionalSequenceLstmTestImpl.hpp b/src/backends/backendsCommon/test/layerTests/UnidirectionalSequenceLstmTestImpl.hpp new file mode 100644 index 0000000000..7b14065728 --- /dev/null +++ b/src/backends/backendsCommon/test/layerTests/UnidirectionalSequenceLstmTestImpl.hpp @@ -0,0 +1,36 @@ +// +// Copyright © 2021 Arm Ltd and Contributors. All rights reserved. 
+// SPDX-License-Identifier: MIT +// + +#pragma once + +#include "LayerTestResult.hpp" + +#include +#include + +LayerTestResult UnidirectionalSequenceLstmLayerFloat32Test( + armnn::IWorkloadFactory& workloadFactory, + const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, + const armnn::ITensorHandleFactory& tensorHandleFactory); + +LayerTestResult UnidirectionalSequenceLstmLayerFloat32TimeMajorTest( + armnn::IWorkloadFactory& workloadFactory, + const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, + const armnn::ITensorHandleFactory& tensorHandleFactory); + +LayerTestResult UnidirectionalSequenceLstmLayerNoCifgWithPeepholeWithProjectionTest( + armnn::IWorkloadFactory& workloadFactory, + const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, + const armnn::ITensorHandleFactory& tensorHandleFactory); + +LayerTestResult UnidirectionalSequenceLstmLayerNoCifgWithPeepholeWithProjectionWithLayerNormTest( + armnn::IWorkloadFactory& workloadFactory, + const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, + const armnn::ITensorHandleFactory& tensorHandleFactory); + +LayerTestResult UnidirectionalSequenceLstmWithCifgWithPeepholeNoProjectionTest( + armnn::IWorkloadFactory& workloadFactory, + const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, + const armnn::ITensorHandleFactory& tensorHandleFactory); \ No newline at end of file diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp index 1b05c4e0f4..2603371927 100644 --- a/src/backends/reference/RefLayerSupport.cpp +++ b/src/backends/reference/RefLayerSupport.cpp @@ -1242,6 +1242,7 @@ bool RefLayerSupport::IsLstmSupported(const TensorInfo& input, "Reference Lstm: input and outputStateOut types are mismatched"); supported &= CheckSupportRule(TypesAreEqual(input, cellStateOut), reasonIfUnsupported, "Reference Lstm: input and cellStateOut types are mismatched"); + supported &= 
CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported, "Reference Lstm: input and output types are mismatched"); // check layer parameters @@ -2288,4 +2289,159 @@ bool RefLayerSupport::IsTransposeSupported(const TensorInfo& input, return supported; } +// Reference-backend support check for UnidirectionalSequenceLstm: the input must be Float32 and the +// state, output, weight and bias tensors checked below must be Float32 or match the input type; +// gate-specific weights are only checked when the matching descriptor flag requires them. +bool RefLayerSupport::IsUnidirectionalSequenceLstmSupported( + const TensorInfo& input, + const TensorInfo& outputStateIn, + const TensorInfo& cellStateIn, + const TensorInfo& output, + const Optional& hiddenStateOutput, + const Optional& cellStateOutput, + const UnidirectionalSequenceLstmDescriptor& descriptor, + const LstmInputParamsInfo& paramsInfo, + Optional reasonIfUnsupported) const +{ + IgnoreUnused(descriptor); + IgnoreUnused(paramsInfo); + IgnoreUnused(outputStateIn); + IgnoreUnused(cellStateIn); + bool supported = true; + + if (hiddenStateOutput.has_value() || cellStateOutput.has_value()) + { + // Separate hidden/cell state outputs are not implemented: report the layer as unsupported, + // and only touch the reason string when the caller actually supplied one. + supported = false; + if (reasonIfUnsupported.has_value()) + { + reasonIfUnsupported.value() += "Reference UnidirectionalSequenceLstm: hidden state output " + "and cell state output are not supported at the moment."; + } + } + + std::array supportedTypes = + { + DataType::Float32 + }; + + std::array supportedWeightTypes = + { + DataType::Float32 + }; + + // check inputs and outputs + supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported, + "Reference UnidirectionalSequenceLstm: input is not a supported type."); + supported &= CheckSupportRule(TypesAreEqual(input, outputStateIn), reasonIfUnsupported, + "Reference UnidirectionalSequenceLstm: input and outputStateIn types are mismatched"); + supported &= CheckSupportRule(TypesAreEqual(input, cellStateIn), reasonIfUnsupported, + "Reference UnidirectionalSequenceLstm: input and cellStateIn types are mismatched"); + + supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported, + "Reference UnidirectionalSequenceLstm: input and output types are mismatched"); + // check layer parameters + supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputToForgetWeights(),
supportedWeightTypes), + reasonIfUnsupported, + "Reference UnidirectionalSequenceLstm: InputToForgetWeights " + "is not a supported type."); + supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputToCellWeights(), supportedWeightTypes), + reasonIfUnsupported, + "Reference UnidirectionalSequenceLstm: InputToCellWeights is not a supported type."); + supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputToOutputWeights(), supportedWeightTypes), + reasonIfUnsupported, + "Reference UnidirectionalSequenceLstm: InputToOutputWeights " + "is not a supported type."); + supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetRecurrentToForgetWeights(), supportedWeightTypes), + reasonIfUnsupported, + "Reference UnidirectionalSequenceLstm: RecurrentToForgetWeights " + "is not a supported type."); + supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetRecurrentToCellWeights(), supportedWeightTypes), + reasonIfUnsupported, + "Reference UnidirectionalSequenceLstm: RecurrentToCellWeights " + "is not a supported type."); + supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetRecurrentToOutputWeights(), supportedWeightTypes), + reasonIfUnsupported, + "Reference UnidirectionalSequenceLstm: RecurrentToOutputWeights " + "is not a supported type."); + supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetForgetGateBias()), reasonIfUnsupported, + "Reference UnidirectionalSequenceLstm: input and ForgetGateBias types " + "are mismatched"); + supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellBias()), reasonIfUnsupported, + "Reference UnidirectionalSequenceLstm: input and CellBias types are mismatched"); + supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetOutputGateBias()), reasonIfUnsupported, + "Reference UnidirectionalSequenceLstm: input and OutputGateBias types " + "are mismatched"); + if (!descriptor.m_CifgEnabled) + { + supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputToInputWeights(), supportedWeightTypes), + 
reasonIfUnsupported, + "Reference UnidirectionalSequenceLstm: InputToInputWeights " + "is not a supported type."); + supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetRecurrentToInputWeights(), supportedWeightTypes), + reasonIfUnsupported, + "Reference UnidirectionalSequenceLstm: RecurrentToInputWeights " + "is not a supported type."); + supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputGateBias()), reasonIfUnsupported, + "Reference UnidirectionalSequenceLstm: input and InputGateBias types " + "are mismatched"); + if (descriptor.m_PeepholeEnabled) + { + supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellToInputWeights(), supportedWeightTypes), + reasonIfUnsupported, + "Reference UnidirectionalSequenceLstm: CellToInputWeights " + "is not a supported type."); + } + } + if (descriptor.m_PeepholeEnabled) + { + supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellToForgetWeights(), supportedWeightTypes), + reasonIfUnsupported, + "Reference UnidirectionalSequenceLstm: CellToForgetWeights " + "is not a supported type."); + supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellToOutputWeights(), supportedWeightTypes), + reasonIfUnsupported, + "Reference UnidirectionalSequenceLstm: CellToOutputWeights " + "is not a supported type."); + } + if (descriptor.m_ProjectionEnabled) + { + supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetProjectionWeights(), supportedWeightTypes), + reasonIfUnsupported, + "Reference UnidirectionalSequenceLstm: ProjectionWeights " + "is not a supported type."); + if (paramsInfo.m_ProjectionBias != nullptr) + { + supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetProjectionBias()), reasonIfUnsupported, + "Reference UnidirectionalSequenceLstm: input and ProjectionBias types " + "are mismatched"); + } + } + if (descriptor.m_LayerNormEnabled) + { + if (!descriptor.m_CifgEnabled) + { + supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputLayerNormWeights(), supportedWeightTypes), + 
reasonIfUnsupported, + "Reference UnidirectionalSequenceLstm: InputLayerNormWeights " + "is not a supported type."); + } + supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetForgetLayerNormWeights(), supportedWeightTypes), + reasonIfUnsupported, + "Reference UnidirectionalSequenceLstm: ForgetLayerNormWeights " + "is not a supported type."); + supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellLayerNormWeights(), supportedWeightTypes), + reasonIfUnsupported, + "Reference UnidirectionalSequenceLstm: CellLayerNormWeights " + "is not a supported type."); + supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetOutputLayerNormWeights(), supportedWeightTypes), + reasonIfUnsupported, + "Reference UnidirectionalSequenceLstm: OutputLayerNormWeights " + "is not a supported type."); + } + + return supported; +} + } // namespace armnn diff --git a/src/backends/reference/RefLayerSupport.hpp b/src/backends/reference/RefLayerSupport.hpp index c060f79b5a..a1b4dc7f47 100644 --- a/src/backends/reference/RefLayerSupport.hpp +++ b/src/backends/reference/RefLayerSupport.hpp @@ -370,6 +370,17 @@ public: const TensorInfo& output, const TransposeDescriptor& descriptor, Optional reasonIfUnsupported = EmptyOptional()) const override; + + bool IsUnidirectionalSequenceLstmSupported( + const TensorInfo& input, + const TensorInfo& outputStateIn, + const TensorInfo& cellStateIn, + const TensorInfo& output, + const Optional& hiddenStateOutput, + const Optional& cellStateOutput, + const UnidirectionalSequenceLstmDescriptor& descriptor, + const LstmInputParamsInfo& paramsInfo, + Optional reasonIfUnsupported = EmptyOptional()) const override; }; } // namespace armnn diff --git a/src/backends/reference/RefWorkloadFactory.cpp b/src/backends/reference/RefWorkloadFactory.cpp index 606f531630..16cf17cc79 100644 --- a/src/backends/reference/RefWorkloadFactory.cpp +++ b/src/backends/reference/RefWorkloadFactory.cpp @@ -712,4 +712,11 @@ std::unique_ptr 
RefWorkloadFactory::CreateTransposeConvolution2d(
     return std::make_unique(descriptor, info);
 }
 
+std::unique_ptr RefWorkloadFactory::CreateUnidirectionalSequenceLstm(
+    const UnidirectionalSequenceLstmQueueDescriptor& descriptor,
+    const WorkloadInfo& info) const
+{
+    return std::make_unique(descriptor, info); // descriptor and info are forwarded unchanged
+}
+
 } // namespace armnn
diff --git a/src/backends/reference/RefWorkloadFactory.hpp b/src/backends/reference/RefWorkloadFactory.hpp
index 2beffa77f3..113aca70ef 100644
--- a/src/backends/reference/RefWorkloadFactory.hpp
+++ b/src/backends/reference/RefWorkloadFactory.hpp
@@ -276,6 +276,10 @@ public:
     std::unique_ptr CreateTransposeConvolution2d(const TransposeConvolution2dQueueDescriptor& descriptor,
                                                  const WorkloadInfo& info) const override;
 
+    std::unique_ptr CreateUnidirectionalSequenceLstm(
+        const UnidirectionalSequenceLstmQueueDescriptor& descriptor,
+        const WorkloadInfo& info) const override;
+
 private:
     template
     std::unique_ptr MakeWorkload(const QueueDescriptorType& descriptor, const WorkloadInfo& info) const;
diff --git a/src/backends/reference/backend.mk b/src/backends/reference/backend.mk
index bf18284143..17ddbe0df1 100644
--- a/src/backends/reference/backend.mk
+++ b/src/backends/reference/backend.mk
@@ -37,6 +37,7 @@ BACKEND_SOURCES := \
         workloads/Gather.cpp \
         workloads/InstanceNorm.cpp \
         workloads/LogSoftmax.cpp \
+        workloads/Lstm.cpp \
         workloads/LstmUtils.cpp \
         workloads/Concatenate.cpp \
         workloads/Pad.cpp \
@@ -95,6 +96,7 @@ BACKEND_SOURCES := \
         workloads/RefSplitterWorkload.cpp \
         workloads/RefTransposeConvolution2dWorkload.cpp \
         workloads/RefTransposeWorkload.cpp \
+        workloads/RefUnidirectionalSequenceLstmWorkload.cpp \
         workloads/Resize.cpp \
         workloads/Slice.cpp \
         workloads/SpaceToBatchNd.cpp \
diff --git a/src/backends/reference/test/RefLayerTests.cpp b/src/backends/reference/test/RefLayerTests.cpp
index 45e3717268..0cf36f2c6e 100644
--- a/src/backends/reference/test/RefLayerTests.cpp
+++ b/src/backends/reference/test/RefLayerTests.cpp
@@
-2330,4 +2330,16 @@ ARMNN_AUTO_TEST_CASE_WITH_THF(ReduceMax2Float32, ReduceMaxSimpleTest2) ARMNN_AUTO_TEST_CASE_WITH_THF(ReduceMinNegativeAxisFloat32, ReduceMinNegativeAxisTest) +// Unidirectional Sequence Lstm +ARMNN_AUTO_TEST_CASE_WITH_THF(UnidirectionalSequenceLstmLayerFloat32, + UnidirectionalSequenceLstmLayerFloat32Test) +ARMNN_AUTO_TEST_CASE_WITH_THF(UnidirectionalSequenceLstmLayerFloat32TimeMajor, + UnidirectionalSequenceLstmLayerFloat32TimeMajorTest) +ARMNN_AUTO_TEST_CASE_WITH_THF(UnidirectionalSequenceLstmLayerNoCifgWithPeepholeWithProjection, + UnidirectionalSequenceLstmLayerNoCifgWithPeepholeWithProjectionTest) +ARMNN_AUTO_TEST_CASE_WITH_THF(UnidirectionalSequenceLstmLayerNoCifgWithPeepholeWithProjectionWithLayerNorm, + UnidirectionalSequenceLstmLayerNoCifgWithPeepholeWithProjectionWithLayerNormTest) +ARMNN_AUTO_TEST_CASE_WITH_THF(UnidirectionalSequenceLstmWithCifgWithPeepholeNoProjection, + UnidirectionalSequenceLstmWithCifgWithPeepholeNoProjectionTest) + } \ No newline at end of file diff --git a/src/backends/reference/workloads/CMakeLists.txt b/src/backends/reference/workloads/CMakeLists.txt index 7a769e5246..b9f477cb6d 100644 --- a/src/backends/reference/workloads/CMakeLists.txt +++ b/src/backends/reference/workloads/CMakeLists.txt @@ -42,6 +42,8 @@ list(APPEND armnnRefBackendWorkloads_sources Log.hpp LogSoftmax.cpp LogSoftmax.hpp + Lstm.cpp + Lstm.hpp LstmUtils.hpp LstmUtils.cpp Maximum.hpp @@ -162,6 +164,8 @@ list(APPEND armnnRefBackendWorkloads_sources RefTransposeConvolution2dWorkload.hpp RefTransposeWorkload.cpp RefTransposeWorkload.hpp + RefUnidirectionalSequenceLstmWorkload.cpp + RefUnidirectionalSequenceLstmWorkload.hpp RefWorkloads.hpp RefWorkloadUtils.hpp Resize.cpp diff --git a/src/backends/reference/workloads/Lstm.cpp b/src/backends/reference/workloads/Lstm.cpp new file mode 100644 index 0000000000..c1fb2bf4aa --- /dev/null +++ b/src/backends/reference/workloads/Lstm.cpp @@ -0,0 +1,259 @@ +// +// Copyright © 2021 Arm Ltd and Contributors. 
All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "Activation.hpp"
+#include "Lstm.hpp"
+#include "LstmUtils.hpp"
+
+namespace armnn
+{
+
+// Runs one LSTM step for a whole batch.
+//
+// Sizes come from the arguments: nBatch and nInput are inputInfo's shape[0] and shape[1], nCell is
+// inputToOutputWeightsShape[0] and nOutput is recurrentToOutputWeightsShape[1].
+// Reads inputData, outputStateIn and cellStateIn; writes cellStateOut and output, then copies output
+// into outputStateOut. Each *Scratch encoder is written and then read back through the matching
+// *ScratchDecoder, and cellStateOut/output are read back through cellStateOutDecoder/outputDecoder,
+// so every such pair is expected to view the same buffer (the buffers are created by the caller).
+// Optional tensors are dereferenced only when enabled in descriptor: the input-gate tensors when CIFG
+// is disabled, the cellTo* weights with peephole, the *LayerNormWeights with layer normalisation, and
+// projectionWeightsTensor (plus projectionBiasTensor, if non-null) with projection.
+// layerNormEpsilon is passed to MeanStddevNormalization and is only used with layer normalisation.
+void LstmImpl(const LstmDescriptor& descriptor,
+              const TensorInfo& inputInfo,
+              const TensorInfo& outputInfo,
+              const TensorShape& inputToOutputWeightsShape,
+              const TensorShape& recurrentToOutputWeightsShape,
+              std::unique_ptr>& inputData,
+              std::unique_ptr>& outputStateIn,
+              std::unique_ptr>& cellStateIn,
+              std::unique_ptr>& outputStateOut,
+              std::unique_ptr>& cellStateOut,
+              std::unique_ptr>& output,
+              std::unique_ptr>& cellStateOutDecoder,
+              std::unique_ptr>& outputDecoder,
+              std::unique_ptr>& inputToInputWeightsTensor,
+              std::unique_ptr>& inputToForgetWeightsTensor,
+              std::unique_ptr>& inputToCellWeightsTensor,
+              std::unique_ptr>& inputToOutputWeightsTensor,
+              std::unique_ptr>& recurrentToInputWeightsTensor,
+              std::unique_ptr>& recurrentToForgetWeightsTensor,
+              std::unique_ptr>& recurrentToCellWeightsTensor,
+              std::unique_ptr>& recurrentToOutputWeightsTensor,
+              std::unique_ptr>& cellToInputWeightsTensor,
+              std::unique_ptr>& cellToForgetWeightsTensor,
+              std::unique_ptr>& cellToOutputWeightsTensor,
+              std::unique_ptr>& inputGateBiasTensor,
+              std::unique_ptr>& forgetGateBiasTensor,
+              std::unique_ptr>& cellBiasTensor,
+              std::unique_ptr>& outputGateBiasTensor,
+              std::unique_ptr>& projectionWeightsTensor,
+              std::unique_ptr>& projectionBiasTensor,
+              std::unique_ptr>& inputLayerNormWeights,
+              std::unique_ptr>& forgetLayerNormWeights,
+              std::unique_ptr>& cellLayerNormWeights,
+              std::unique_ptr>& outputLayerNormWeights,
+              std::unique_ptr>& inputGateScratch,
+              std::unique_ptr>& cellScratch,
+              std::unique_ptr>& forgetGateScratch,
+              std::unique_ptr>& outputGateScratch,
+              std::unique_ptr>& inputGateScratchDecoder,
+              std::unique_ptr>& cellScratchDecoder,
+              std::unique_ptr>& forgetGateScratchDecoder,
+              std::unique_ptr>& outputGateScratchDecoder,
+              float layerNormEpsilon)
+{
+    // This is a port of the LSTM::Eval() method in the Android code base
+    // Refer to: android/frameworks/ml/nn/common/operations/LSTM.cpp
+
+    const TensorShape& inputShape = inputInfo.GetShape();
+    const DataType& outputType = outputInfo.GetDataType();
+
+    const uint32_t nBatch = inputShape[0];
+    const uint32_t nInput = inputShape[1];
+
+    const uint32_t nCell   = inputToOutputWeightsShape[0];
+    const uint32_t nOutput = recurrentToOutputWeightsShape[1];
+
+    const bool useCifg      = descriptor.m_CifgEnabled;
+    const bool usePeephole  = descriptor.m_PeepholeEnabled;
+    const bool useLayerNorm = descriptor.m_LayerNormEnabled;
+
+    if (!useLayerNorm)
+    {
+        // Initialize scratch buffers with bias.
+        if (!useCifg)
+        {
+            VectorBatchVectorAssign(*inputGateBiasTensor,
+                                    nCell, nBatch, *inputGateScratch);
+        }
+        VectorBatchVectorAssign(*forgetGateBiasTensor,
+                                nCell, nBatch, *forgetGateScratch);
+        VectorBatchVectorAssign(*cellBiasTensor,
+                                nCell, nBatch, *cellScratch);
+        VectorBatchVectorAssign(*outputGateBiasTensor,
+                                nCell, nBatch, *outputGateScratch);
+    }
+    else
+    {
+        // Initialize scratch buffers with zeroes.
+        // With layer normalisation the bias is added after normalisation instead (see below).
+        if (!useCifg)
+        {
+            ZeroVector(*inputGateScratch, nCell * nBatch);
+        }
+        ZeroVector(*forgetGateScratch, nCell * nBatch);
+        ZeroVector(*cellScratch      , nCell * nBatch);
+        ZeroVector(*outputGateScratch, nCell * nBatch);
+    }
+
+    // For each batch and cell: compute input_weight * input.
+    if (!useCifg)
+    {
+        MatrixBatchVectorMultiplyAccumulate(*inputToInputWeightsTensor,
+                                            nCell, nInput, *inputData, nBatch, *inputGateScratch);
+    }
+    MatrixBatchVectorMultiplyAccumulate(*inputToForgetWeightsTensor,
+                                        nCell, nInput, *inputData, nBatch, *forgetGateScratch);
+    MatrixBatchVectorMultiplyAccumulate(*inputToCellWeightsTensor,
+                                        nCell, nInput, *inputData, nBatch, *cellScratch);
+    MatrixBatchVectorMultiplyAccumulate(*inputToOutputWeightsTensor,
+                                        nCell, nInput, *inputData, nBatch, *outputGateScratch);
+
+    // For each batch and cell: compute recurrent_weight * output_state.
+    if (!useCifg)
+    {
+        MatrixBatchVectorMultiplyAccumulate(*recurrentToInputWeightsTensor,
+                                            nCell, nOutput, *outputStateIn, nBatch, *inputGateScratch);
+    }
+    MatrixBatchVectorMultiplyAccumulate(*recurrentToForgetWeightsTensor,
+                                        nCell, nOutput, *outputStateIn, nBatch, *forgetGateScratch);
+    MatrixBatchVectorMultiplyAccumulate(*recurrentToCellWeightsTensor,
+                                        nCell, nOutput, *outputStateIn, nBatch, *cellScratch);
+    MatrixBatchVectorMultiplyAccumulate(*recurrentToOutputWeightsTensor,
+                                        nCell, nOutput, *outputStateIn, nBatch, *outputGateScratch);
+
+    // For each batch and cell: update input gate.
+    if (!useCifg)
+    {
+        if (usePeephole)
+        {
+            VectorBatchVectorCwiseProductAccumulate(*cellToInputWeightsTensor,
+                                                    nCell, *cellStateIn, nBatch, *inputGateScratch);
+        }
+        if (useLayerNorm)
+        {
+            MeanStddevNormalization(*inputGateScratchDecoder,
+                                    *inputGateScratch, nCell, nBatch, layerNormEpsilon);
+            VectorBatchVectorCwiseProduct(*inputLayerNormWeights,
+                                          nCell, *inputGateScratchDecoder, nBatch, *inputGateScratch);
+            VectorBatchVectorAdd(*inputGateBiasTensor,
+                                 nCell, *inputGateScratchDecoder, nBatch, *inputGateScratch);
+        }
+        Activation(*inputGateScratchDecoder, *inputGateScratch,
+                   TensorInfo({nCell, nBatch}, outputType),
+                   ActivationFunction::Sigmoid, 0, 0);
+    }
+
+    // For each batch and cell: update forget gate.
+    if (usePeephole)
+    {
+        VectorBatchVectorCwiseProductAccumulate(*cellToForgetWeightsTensor, nCell,
+                                                *cellStateIn, nBatch, *forgetGateScratch);
+    }
+    if (useLayerNorm)
+    {
+        MeanStddevNormalization(*forgetGateScratchDecoder,
+                                *forgetGateScratch, nCell, nBatch, layerNormEpsilon);
+        VectorBatchVectorCwiseProduct(*forgetLayerNormWeights,
+                                      nCell, *forgetGateScratchDecoder, nBatch, *forgetGateScratch);
+        VectorBatchVectorAdd(*forgetGateBiasTensor,
+                             nCell, *forgetGateScratchDecoder, nBatch, *forgetGateScratch);
+    }
+    Activation(*forgetGateScratchDecoder, *forgetGateScratch,
+               TensorInfo({nCell, nBatch}, outputType),
+               ActivationFunction::Sigmoid, 0, 0);
+
+    // For each batch and cell: update the cell.
+    if (useLayerNorm)
+    {
+        MeanStddevNormalization(*cellScratchDecoder,
+                                *cellScratch, nCell, nBatch, layerNormEpsilon);
+        VectorBatchVectorCwiseProduct(*cellLayerNormWeights,
+                                      nCell, *cellScratchDecoder, nBatch, *cellScratch);
+        VectorBatchVectorAdd(*cellBiasTensor,
+                             nCell, *cellScratchDecoder, nBatch, *cellScratch);
+    }
+
+    // cellStateOut = forgetGate * cellStateIn; the input contribution is accumulated further down.
+    VectorVectorCwiseProduct(*forgetGateScratchDecoder, *cellStateIn, nBatch * nCell, *cellStateOut);
+
+    ActivationFunction armnnActivationFunc = ActivationFunction::Sigmoid;
+    float a = 0;
+    float b = 0;
+    SetActivationParameters(descriptor.m_ActivationFunc, armnnActivationFunc, a, b);
+
+    if (descriptor.m_ActivationFunc > 0)
+    {
+        Activation(*cellScratchDecoder, *cellScratch,
+                   TensorInfo({nCell, nBatch}, outputType),
+                   armnnActivationFunc, a, b);
+    }
+    if (useCifg)
+    {
+        // CIFG: there is no separate input gate, it is derived from the forget gate via Sub1Vector.
+        Sub1Vector(*forgetGateScratchDecoder, nBatch * nCell, *forgetGateScratch);
+        VectorVectorCwiseProductAccumulate(
+            *cellScratchDecoder, *forgetGateScratchDecoder, nBatch * nCell, *cellStateOut);
+    }
+    else
+    {
+        VectorVectorCwiseProductAccumulate(
+            *cellScratchDecoder, *inputGateScratchDecoder, nBatch * nCell, *cellStateOut);
+    }
+    if (descriptor.m_ClippingThresCell > 0.0)
+    {
+        ClipVector(*cellStateOutDecoder, nBatch * nCell, descriptor.m_ClippingThresCell, *cellStateOut);
+    }
+
+    // For each batch and cell: update the output gate.
+    if (usePeephole)
+    {
+        VectorBatchVectorCwiseProductAccumulate(*cellToOutputWeightsTensor,
+                                                nCell, *cellStateOutDecoder, nBatch, *outputGateScratch);
+    }
+    if (useLayerNorm)
+    {
+        MeanStddevNormalization(*outputGateScratchDecoder,
+                                *outputGateScratch, nCell, nBatch, layerNormEpsilon);
+        VectorBatchVectorCwiseProduct(*outputLayerNormWeights,
+                                      nCell, *outputGateScratchDecoder, nBatch, *outputGateScratch);
+        VectorBatchVectorAdd(*outputGateBiasTensor,
+                             nCell, *outputGateScratchDecoder, nBatch, *outputGateScratch);
+    }
+    Activation(*outputGateScratchDecoder, *outputGateScratch,
+               TensorInfo({nCell, nBatch}, outputType),
+               ActivationFunction::Sigmoid, 0, 0);
+
+    if (descriptor.m_ActivationFunc > 0)
+    {
+        // cellScratch is reused here to hold the activated new cell state.
+        Activation(*cellStateOutDecoder, *cellScratch,
+                   TensorInfo({nCell, nBatch}, outputType),
+                   armnnActivationFunc, a, b);
+    }
+
+    VectorVectorCwiseProduct(*outputGateScratchDecoder, *cellScratchDecoder, nBatch * nCell, *outputGateScratch);
+
+    // For each batch: update the projection and output_state.
+    if (descriptor.m_ProjectionEnabled)
+    {
+        if (projectionBiasTensor)
+        {
+            VectorBatchVectorAssign(*projectionBiasTensor,
+                                    nOutput, nBatch, *output);
+        }
+        else
+        {
+            // No projection bias: clear the output first, otherwise the accumulation below would add
+            // the projection onto whatever the output buffer already holds.
+            ZeroVector(*output, nBatch * nOutput);
+        }
+        MatrixBatchVectorMultiplyAccumulate(*projectionWeightsTensor,
+                                            nOutput, nCell, *outputGateScratchDecoder, nBatch, *output);
+
+        if (descriptor.m_ClippingThresProj > 0.0)
+        {
+            ClipVector(*outputDecoder, nBatch * nOutput, descriptor.m_ClippingThresProj, *output);
+        }
+    }
+    else
+    {
+        CopyVector(*outputGateScratchDecoder, nBatch * nOutput, *output);
+    }
+
+    CopyVector(*outputDecoder, nBatch * nOutput, *outputStateOut);
+}
+
+} //namespace armnn
diff --git a/src/backends/reference/workloads/Lstm.hpp b/src/backends/reference/workloads/Lstm.hpp
new file mode 100644
index 0000000000..7d0a1d436e
--- /dev/null
+++ b/src/backends/reference/workloads/Lstm.hpp
@@ -0,0 +1,61 @@
+//
+// Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include
+#include
+
+#include "Encoders.hpp"
+#include "Decoders.hpp"
+
+namespace armnn
+{
+
+// One LSTM step over a whole batch, called by both RefLstmWorkload::Execute and
+// RefUnidirectionalSequenceLstmWorkload::Execute (the latter once per time step).
+//
+// descriptor                     - feature flags (CIFG, peephole, projection, layer norm), activation
+//                                  function and clipping thresholds.
+// inputInfo / outputInfo         - the implementation takes batch and input size from inputInfo's shape
+//                                  and the data type from outputInfo.
+// inputToOutputWeightsShape,
+// recurrentToOutputWeightsShape  - supply the cell count ([0]) and the output size ([1]) respectively.
+// inputData, outputStateIn,
+// cellStateIn                    - values read for this step.
+// outputStateOut, cellStateOut,
+// output                         - values written for this step; cellStateOutDecoder and outputDecoder
+//                                  read cellStateOut and output back.
+// *WeightsTensor, *BiasTensor,
+// *LayerNormWeights              - model parameters; the ones belonging to a disabled feature are not
+//                                  dereferenced by the implementation.
+// *Scratch / *ScratchDecoder     - per-gate working storage, written through the first and read back
+//                                  through the second.
+// layerNormEpsilon               - epsilon used when normalising with layer norm enabled.
+void LstmImpl(const LstmDescriptor& descriptor,
+              const TensorInfo& inputInfo,
+              const TensorInfo& outputInfo,
+              const TensorShape& inputToOutputWeightsShape,
+              const TensorShape& recurrentToOutputWeightsShape,
+              std::unique_ptr>& inputData,
+              std::unique_ptr>& outputStateIn,
+              std::unique_ptr>& cellStateIn,
+              std::unique_ptr>& outputStateOut,
+              std::unique_ptr>& cellStateOut,
+              std::unique_ptr>& output,
+              std::unique_ptr>& cellStateOutDecoder,
+              std::unique_ptr>& outputDecoder,
+              std::unique_ptr>& inputToInputWeightsTensor,
+              std::unique_ptr>& inputToForgetWeightsTensor,
+              std::unique_ptr>& inputToCellWeightsTensor,
+              std::unique_ptr>& inputToOutputWeightsTensor,
+              std::unique_ptr>& recurrentToInputWeightsTensor,
+              std::unique_ptr>& recurrentToForgetWeightsTensor,
+              std::unique_ptr>& recurrentToCellWeightsTensor,
+              std::unique_ptr>& recurrentToOutputWeightsTensor,
+              std::unique_ptr>& cellToInputWeightsTensor,
+              std::unique_ptr>& cellToForgetWeightsTensor,
+              std::unique_ptr>& cellToOutputWeightsTensor,
+              std::unique_ptr>& inputGateBiasTensor,
+              std::unique_ptr>& forgetGateBiasTensor,
+              std::unique_ptr>& cellBiasTensor,
+              std::unique_ptr>& outputGateBiasTensor,
+              std::unique_ptr>& projectionWeightsTensor,
+              std::unique_ptr>& projectionBiasTensor,
+              std::unique_ptr>& inputLayerNormWeights,
+              std::unique_ptr>& forgetLayerNormWeights,
+              std::unique_ptr>& cellLayerNormWeights,
+              std::unique_ptr>& outputLayerNormWeights,
+              std::unique_ptr>& inputGateScratch,
+              std::unique_ptr>& cellScratch,
+              std::unique_ptr>& forgetGateScratch,
+              std::unique_ptr>& outputGateScratch,
+              std::unique_ptr>& inputGateScratchDecoder,
+              std::unique_ptr>& cellScratchDecoder,
+              std::unique_ptr>& forgetGateScratchDecoder,
+              std::unique_ptr>& outputGateScratchDecoder,
+              float layerNormEpsilon);
+
+} //namespace armnn
diff --git
a/src/backends/reference/workloads/RefLstmWorkload.cpp b/src/backends/reference/workloads/RefLstmWorkload.cpp index 3ddfd334b8..1ff6f50ed5 100644 --- a/src/backends/reference/workloads/RefLstmWorkload.cpp +++ b/src/backends/reference/workloads/RefLstmWorkload.cpp @@ -7,6 +7,7 @@ #include "Activation.hpp" #include "Encoders.hpp" #include "Decoders.hpp" +#include "Lstm.hpp" #include "LstmUtils.hpp" #include "RefWorkloadUtils.hpp" @@ -57,7 +58,6 @@ void RefLstmWorkload::Execute(std::vector inputs, std::vector> outputStateOut = MakeEncoder(outputInfo, outputs[1]->Map()); std::unique_ptr> cellStateOut = MakeEncoder(outputInfo, outputs[2]->Map()); @@ -71,10 +71,7 @@ void RefLstmWorkload::Execute(std::vector inputs, std::vector> cellStateIn = MakeDecoder(inputInfo, inputs[2]->Map()); const uint32_t nBatch = inputShape[0]; - const uint32_t nInput = inputShape[1]; - const uint32_t nCell = m_InputToOutputWeightsTensor->GetShape()[0]; - const uint32_t nOutput = m_RecurrentToOutputWeightsTensor->GetShape()[1]; const bool useCifg = m_Data.m_Parameters.m_CifgEnabled; const bool usePeephole = m_Data.m_Parameters.m_PeepholeEnabled; @@ -154,6 +151,9 @@ void RefLstmWorkload::Execute(std::vector inputs, std::vector> cellLayerNormWeights; std::unique_ptr> outputLayerNormWeights; + const TensorShape& inputToOutputWeightsShape = m_InputToOutputWeightsTensor->GetShape(); + const TensorShape& recurrentToOutputWeightsShape = m_RecurrentToOutputWeightsTensor->GetShape(); + if (useLayerNorm) { if (!useCifg) @@ -204,190 +204,49 @@ void RefLstmWorkload::Execute(std::vector inputs, std::vector 0) - { - Activation(*cellScratchDecoder, *cellScratch, - TensorInfo({nCell, nBatch}, outputType), - armnnActivationFunc, a, b); - } - if (useCifg) - { - Sub1Vector(*forgetGateScratchDecoder, nBatch * nCell, *forgetGateScratch); - VectorVectorCwiseProductAccumulate( - *cellScratchDecoder, *forgetGateScratchDecoder, nBatch * nCell, *cellStateOut); - } - else - { - VectorVectorCwiseProductAccumulate( - 
*cellScratchDecoder, *inputGateScratchDecoder, nBatch * nCell, *cellStateOut); - } - if (m_Data.m_Parameters.m_ClippingThresCell > 0.0) - { - ClipVector(*cellStateOutDecoder, nBatch * nCell, m_Data.m_Parameters.m_ClippingThresCell, *cellStateOut); - } - - // For each batch and cell: update the output gate. - if (usePeephole) - { - VectorBatchVectorCwiseProductAccumulate(*cellToOutputWeightsTensor, - nCell, *cellStateOutDecoder, nBatch, *outputGateScratch); - } - if (useLayerNorm) - { - MeanStddevNormalization(*outputGateScratchDecoder, - *outputGateScratch, nCell, nBatch, m_LayerNormEpsilon); - VectorBatchVectorCwiseProduct(*outputLayerNormWeights, - nCell, *outputGateScratchDecoder, nBatch, *outputGateScratch); - VectorBatchVectorAdd(*outputGateBiasTensor, - nCell, *outputGateScratchDecoder, nBatch, *outputGateScratch); - } - Activation(*outputGateScratchDecoder, *outputGateScratch, - TensorInfo({nCell, nBatch}, outputType), - ActivationFunction::Sigmoid, 0, 0); - - if (m_Data.m_Parameters.m_ActivationFunc > 0) - { - Activation(*cellStateOutDecoder, *cellScratch, - TensorInfo({nCell, nBatch}, outputType), - armnnActivationFunc, a, b); - } - - VectorVectorCwiseProduct(*outputGateScratchDecoder, *cellScratchDecoder, nBatch * nCell, *outputGateScratch); - - // For each batch: update the projection and output_state. 
- if (m_Data.m_Parameters.m_ProjectionEnabled) - { - if (m_ProjectionBiasTensor) - { - VectorBatchVectorAssign(*projectionBiasTensor, - nOutput, nBatch, *output); - } - MatrixBatchVectorMultiplyAccumulate(*projectionWeightsTensor, - nOutput, nCell, *outputGateScratchDecoder, nBatch, *output); - - if (m_Data.m_Parameters.m_ClippingThresProj > 0.0) - { - ClipVector(*outputDecoder, nBatch * nOutput, m_Data.m_Parameters.m_ClippingThresProj, *output); - } - } - else - { - CopyVector(*outputGateScratchDecoder, nBatch * nOutput, *output); - } - - CopyVector(*outputDecoder, nBatch * nOutput, *outputStateOut); + LstmImpl(m_Data.m_Parameters, + inputInfo, + outputInfo, + inputToOutputWeightsShape, + recurrentToOutputWeightsShape, + inputData, + outputStateIn, + cellStateIn, + outputStateOut, + cellStateOut, + output, + cellStateOutDecoder, + outputDecoder, + inputToInputWeightsTensor, + inputToForgetWeightsTensor, + inputToCellWeightsTensor, + inputToOutputWeightsTensor, + recurrentToInputWeightsTensor, + recurrentToForgetWeightsTensor, + recurrentToCellWeightsTensor, + recurrentToOutputWeightsTensor, + cellToInputWeightsTensor, + cellToForgetWeightsTensor, + cellToOutputWeightsTensor, + inputGateBiasTensor, + forgetGateBiasTensor, + cellBiasTensor, + outputGateBiasTensor, + projectionWeightsTensor, + projectionBiasTensor, + inputLayerNormWeights, + forgetLayerNormWeights, + cellLayerNormWeights, + outputLayerNormWeights, + inputGateScratch, + cellScratch, + forgetGateScratch, + outputGateScratch, + inputGateScratchDecoder, + cellScratchDecoder, + forgetGateScratchDecoder, + outputGateScratchDecoder, + m_LayerNormEpsilon); } } //namespace armnn diff --git a/src/backends/reference/workloads/RefUnidirectionalSequenceLstmWorkload.cpp b/src/backends/reference/workloads/RefUnidirectionalSequenceLstmWorkload.cpp new file mode 100644 index 0000000000..311fa18f91 --- /dev/null +++ b/src/backends/reference/workloads/RefUnidirectionalSequenceLstmWorkload.cpp @@ -0,0 +1,307 @@ +// +// 
Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "RefUnidirectionalSequenceLstmWorkload.hpp"
+#include "Activation.hpp"
+#include "Encoders.hpp"
+#include "Decoders.hpp"
+#include "Lstm.hpp"
+#include "LstmUtils.hpp"
+#include "RefWorkloadUtils.hpp"
+
+#include
+
+namespace armnn
+{
+
+// Takes a scoped handle to every constant tensor carried by the queue descriptor, so that
+// Execute can build decoders over them later.
+RefUnidirectionalSequenceLstmWorkload::RefUnidirectionalSequenceLstmWorkload(
+    const UnidirectionalSequenceLstmQueueDescriptor& descriptor,
+    const WorkloadInfo& info)
+    : BaseWorkload(descriptor, info)
+    , m_InputToInputWeightsTensor     (AssignScopedTensorHandle(descriptor.m_InputToInputWeights))
+    , m_InputToForgetWeightsTensor    (AssignScopedTensorHandle(descriptor.m_InputToForgetWeights))
+    , m_InputToCellWeightsTensor      (AssignScopedTensorHandle(descriptor.m_InputToCellWeights))
+    , m_InputToOutputWeightsTensor    (AssignScopedTensorHandle(descriptor.m_InputToOutputWeights))
+    , m_RecurrentToInputWeightsTensor (AssignScopedTensorHandle(descriptor.m_RecurrentToInputWeights))
+    , m_RecurrentToForgetWeightsTensor(AssignScopedTensorHandle(descriptor.m_RecurrentToForgetWeights))
+    , m_RecurrentToCellWeightsTensor  (AssignScopedTensorHandle(descriptor.m_RecurrentToCellWeights))
+    , m_RecurrentToOutputWeightsTensor(AssignScopedTensorHandle(descriptor.m_RecurrentToOutputWeights))
+    , m_CellToInputWeightsTensor      (AssignScopedTensorHandle(descriptor.m_CellToInputWeights))
+    , m_CellToForgetWeightsTensor     (AssignScopedTensorHandle(descriptor.m_CellToForgetWeights))
+    , m_CellToOutputWeightsTensor     (AssignScopedTensorHandle(descriptor.m_CellToOutputWeights))
+    , m_InputGateBiasTensor           (AssignScopedTensorHandle(descriptor.m_InputGateBias))
+    , m_ForgetGateBiasTensor          (AssignScopedTensorHandle(descriptor.m_ForgetGateBias))
+    , m_CellBiasTensor                (AssignScopedTensorHandle(descriptor.m_CellBias))
+    , m_OutputGateBiasTensor          (AssignScopedTensorHandle(descriptor.m_OutputGateBias))
+    , m_ProjectionWeightsTensor       (AssignScopedTensorHandle(descriptor.m_ProjectionWeights))
+    , m_ProjectionBiasTensor          (AssignScopedTensorHandle(descriptor.m_ProjectionBias))
+    , m_InputLayerNormWeights         (AssignScopedTensorHandle(descriptor.m_InputLayerNormWeights))
+    , m_ForgetLayerNormWeights        (AssignScopedTensorHandle(descriptor.m_ForgetLayerNormWeights))
+    , m_CellLayerNormWeights          (AssignScopedTensorHandle(descriptor.m_CellLayerNormWeights))
+    , m_OutputLayerNormWeights        (AssignScopedTensorHandle(descriptor.m_OutputLayerNormWeights))
+{}
+
+// Runs the workload on the tensor handles stored in the queue descriptor.
+void RefUnidirectionalSequenceLstmWorkload::Execute() const
+{
+    Execute(m_Data.m_Inputs, m_Data.m_Outputs);
+}
+
+// Runs the workload on the tensor handles supplied in workingMemDescriptor.
+void RefUnidirectionalSequenceLstmWorkload::ExecuteAsync(WorkingMemDescriptor& workingMemDescriptor)
+{
+    Execute(workingMemDescriptor.m_Inputs, workingMemDescriptor.m_Outputs);
+}
+
+// Feeds the input sequence through LstmImpl one time step at a time.
+//
+// inputs[0] is the sequence, inputs[1] the initial output state and inputs[2] the initial cell state;
+// outputs[0] receives the per-step outputs. The loop works on time-major data: a batch-major input
+// (m_TimeMajor == false) is permuted into a local buffer first and the output is permuted back to
+// batch major at the end. The mapped input tensor itself is never written to.
+void RefUnidirectionalSequenceLstmWorkload::Execute(std::vector inputs,
+                                                    std::vector outputs) const
+{
+    TensorInfo inputInfo = GetTensorInfo(inputs[0]);
+    const TensorInfo& outputStateInfo = GetTensorInfo(inputs[1]);
+    const TensorInfo& cellStateInfo = GetTensorInfo(inputs[2]);
+    TensorInfo outputInfo = GetTensorInfo(outputs[0]);
+    TensorShape& inputShape = inputInfo.GetShape();
+    TensorShape& outputShape= outputInfo.GetShape();
+    auto inputTensor = reinterpret_cast(inputs[0]->Map());
+    // Owns the time-major copy of a batch-major input; stays empty when the input is already time major.
+    std::vector<float> timeMajorInput;
+
+    if (!m_Data.m_Parameters.m_TimeMajor)
+    {
+        // Permute to time major
+        const PermutationVector& mappings = {1U, 0U, 2U};
+        timeMajorInput.resize(inputInfo.GetNumElements());
+        inputShape = armnnUtils::Permuted(inputInfo.GetShape(), mappings);
+        inputInfo.SetShape(inputShape);
+        // Write the permuted data to the local buffer rather than back into the caller's input tensor,
+        // so the input is left intact and running the workload again sees the original layout.
+        armnnUtils::Permute(inputShape, mappings, inputTensor, timeMajorInput.data(), sizeof(float));
+        inputTensor = timeMajorInput.data();
+
+        outputShape = armnnUtils::Permuted(outputInfo.GetShape(), mappings);
+        outputInfo.SetShape(outputShape);
+    }
+    unsigned int maxTime = inputShape[0];
+    unsigned int batchSize = inputShape[1];
+    unsigned int outputSize = outputShape[2];
+    unsigned int inputSize = inputShape[2];
+
+    TensorInfo scratchInfo = outputInfo;
+    scratchInfo.SetShape({batchSize, cellStateInfo.GetShape()[1]});
+
+    std::vector inputGateScratchBuffer;
+    std::vector cellScratchBuffer(scratchInfo.GetNumElements(), 0.);
+    std::vector forgetGateScratchBuffer(scratchInfo.GetNumElements(), 0.);
+    std::vector outputGateScratchBuffer(scratchInfo.GetNumElements(), 0.);
+
+    std::vector outputStateOutBuffer(outputStateInfo.GetNumElements(), 0.);
+    std::vector cellStateOutBuffer(cellStateInfo.GetNumElements(), 0.);
+
+    void* outputStateOutData = outputStateOutBuffer.data();
+    void* cellStateOutData = cellStateOutBuffer.data();
+
+    // Each scratch buffer gets an encoder and a decoder over the same storage.
+    std::unique_ptr> inputGateScratch;
+    std::unique_ptr> cellScratch = MakeEncoder(scratchInfo, cellScratchBuffer.data());
+    std::unique_ptr> forgetGateScratch = MakeEncoder(scratchInfo, forgetGateScratchBuffer.data());
+    std::unique_ptr> outputGateScratch = MakeEncoder(scratchInfo, outputGateScratchBuffer.data());
+
+    std::unique_ptr> inputGateScratchDecoder;
+    std::unique_ptr> cellScratchDecoder = MakeDecoder(scratchInfo, cellScratchBuffer.data());
+    std::unique_ptr> forgetGateScratchDecoder = MakeDecoder(scratchInfo,
+                                                            forgetGateScratchBuffer.data());
+    std::unique_ptr> outputGateScratchDecoder = MakeDecoder(scratchInfo,
+                                                            outputGateScratchBuffer.data());
+
+    const bool useCifg      = m_Data.m_Parameters.m_CifgEnabled;
+    const bool usePeephole  = m_Data.m_Parameters.m_PeepholeEnabled;
+    const bool useLayerNorm = m_Data.m_Parameters.m_LayerNormEnabled;
+
+    if (!useCifg)
+    {
+        inputGateScratchBuffer.resize(scratchInfo.GetNumElements(), 0.);
+        inputGateScratch = MakeEncoder(scratchInfo, inputGateScratchBuffer.data());
+        inputGateScratchDecoder = MakeDecoder(scratchInfo, inputGateScratchBuffer.data());
+    }
+
+    std::unique_ptr> outputStateOut = MakeEncoder(outputStateInfo, outputStateOutData);
+    std::unique_ptr> cellStateOut = MakeEncoder(cellStateInfo, cellStateOutData);
+    std::unique_ptr> cellStateOutDecoder = MakeDecoder(cellStateInfo, cellStateOutData);
+
+    // Per-step views: one time step is a {batchSize, inputSize} / {batchSize, outputSize} slice.
+    TensorInfo lstmInputInfo = inputInfo;
+    TensorShape batchInputShape = TensorShape({batchSize, inputSize});
+    lstmInputInfo.SetShape(batchInputShape);
+
+    TensorInfo lstmOutputInfo = outputInfo;
+    lstmOutputInfo.SetShape({batchSize, outputSize});
+
+    const TensorShape& inputToOutputWeightsShape = m_InputToOutputWeightsTensor->GetShape();
+    const TensorShape& recurrentToOutputWeightsShape = m_RecurrentToOutputWeightsTensor->GetShape();
+    unsigned int nOutput = recurrentToOutputWeightsShape[1];
+    auto outputStateInData = inputs[1]->Map();
+    std::unique_ptr> outputStateIn = MakeDecoder(outputStateInfo, outputStateInData);
+
+    auto cellStateInData = inputs[2]->Map();
+    std::unique_ptr> cellStateIn = MakeDecoder(cellStateInfo, cellStateInData);
+
+    // Time-major view of the input: the mapped tensor itself, or the local permuted copy.
+    auto currentInputData = inputTensor;
+    std::unique_ptr> inputData = MakeDecoder(lstmInputInfo, currentInputData);
+    auto currentOutputData = reinterpret_cast(outputs[0]->Map());
+    std::unique_ptr> output = MakeEncoder(lstmOutputInfo, currentOutputData);
+    std::unique_ptr> outputDecoder = MakeDecoder(lstmOutputInfo, currentOutputData);
+
+    std::unique_ptr> inputToInputWeightsTensor;
+    std::unique_ptr> inputToForgetWeightsTensor = MakeDecoder(
+        m_InputToForgetWeightsTensor->GetTensorInfo(), m_InputToForgetWeightsTensor->GetConstTensor());
+    std::unique_ptr> inputToCellWeightsTensor = MakeDecoder(
+        m_InputToCellWeightsTensor->GetTensorInfo(), m_InputToCellWeightsTensor->GetConstTensor());
+    std::unique_ptr> inputToOutputWeightsTensor = MakeDecoder(
+        m_InputToOutputWeightsTensor->GetTensorInfo(), m_InputToOutputWeightsTensor->GetConstTensor());
+
+    std::unique_ptr> recurrentToInputWeightsTensor;
+    std::unique_ptr> recurrentToForgetWeightsTensor = MakeDecoder(
+        m_RecurrentToForgetWeightsTensor->GetTensorInfo(), m_RecurrentToForgetWeightsTensor->GetConstTensor());
+    std::unique_ptr> recurrentToCellWeightsTensor = MakeDecoder(
+        m_RecurrentToCellWeightsTensor->GetTensorInfo(), m_RecurrentToCellWeightsTensor->GetConstTensor());
+    std::unique_ptr> recurrentToOutputWeightsTensor = MakeDecoder(
+        m_RecurrentToOutputWeightsTensor->GetTensorInfo(), m_RecurrentToOutputWeightsTensor->GetConstTensor());
+
+    std::unique_ptr> inputGateBiasTensor;
+    std::unique_ptr> forgetGateBiasTensor = MakeDecoder(
+        m_ForgetGateBiasTensor->GetTensorInfo(), m_ForgetGateBiasTensor->GetConstTensor());
+    std::unique_ptr> cellBiasTensor = MakeDecoder(
+        m_CellBiasTensor->GetTensorInfo(), m_CellBiasTensor->GetConstTensor());
+    std::unique_ptr> outputGateBiasTensor = MakeDecoder(
+        m_OutputGateBiasTensor->GetTensorInfo(), m_OutputGateBiasTensor->GetConstTensor());
+
+    std::unique_ptr> cellToInputWeightsTensor;
+    std::unique_ptr> cellToForgetWeightsTensor;
+    std::unique_ptr> cellToOutputWeightsTensor;
+
+    std::unique_ptr> projectionWeightsTensor;
+    std::unique_ptr> projectionBiasTensor;
+
+    std::unique_ptr> inputLayerNormWeights;
+    std::unique_ptr> forgetLayerNormWeights;
+    std::unique_ptr> cellLayerNormWeights;
+    std::unique_ptr> outputLayerNormWeights;
+
+    // Decoders for optional tensors are only created when the matching feature is enabled.
+    if (useLayerNorm)
+    {
+        if (!useCifg)
+        {
+            inputLayerNormWeights = MakeDecoder(
+                m_InputLayerNormWeights->GetTensorInfo(), m_InputLayerNormWeights->GetConstTensor());
+        }
+        forgetLayerNormWeights = MakeDecoder(
+            m_ForgetLayerNormWeights->GetTensorInfo(), m_ForgetLayerNormWeights->GetConstTensor());
+        cellLayerNormWeights = MakeDecoder(
+            m_CellLayerNormWeights->GetTensorInfo(), m_CellLayerNormWeights->GetConstTensor());
+        outputLayerNormWeights = MakeDecoder(
+            m_OutputLayerNormWeights->GetTensorInfo(), m_OutputLayerNormWeights->GetConstTensor());
+    }
+
+    if (!useCifg)
+    {
+        inputToInputWeightsTensor = MakeDecoder(
+            m_InputToInputWeightsTensor->GetTensorInfo(), m_InputToInputWeightsTensor->GetConstTensor());
+        inputGateBiasTensor = MakeDecoder(
+            m_InputGateBiasTensor->GetTensorInfo(), m_InputGateBiasTensor->GetConstTensor());
+        recurrentToInputWeightsTensor = MakeDecoder(
+            m_RecurrentToInputWeightsTensor->GetTensorInfo(), m_RecurrentToInputWeightsTensor->GetConstTensor());
+    }
+
+    if (usePeephole)
+    {
+        cellToForgetWeightsTensor = MakeDecoder(
+            m_CellToForgetWeightsTensor->GetTensorInfo(), m_CellToForgetWeightsTensor->GetConstTensor());
+        cellToOutputWeightsTensor = MakeDecoder(
+            m_CellToOutputWeightsTensor->GetTensorInfo(), m_CellToOutputWeightsTensor->GetConstTensor());
+    }
+
+    if (!useCifg && usePeephole)
+    {
+        cellToInputWeightsTensor = MakeDecoder(
+            m_CellToInputWeightsTensor->GetTensorInfo(), m_CellToInputWeightsTensor->GetConstTensor());
+    }
+
+    if (m_Data.m_Parameters.m_ProjectionEnabled)
+    {
+        projectionWeightsTensor = MakeDecoder(
+            m_ProjectionWeightsTensor->GetTensorInfo(), m_ProjectionWeightsTensor->GetConstTensor());
+        if (m_ProjectionBiasTensor)
+        {
+            projectionBiasTensor = MakeDecoder(
+                m_ProjectionBiasTensor->GetTensorInfo(), m_ProjectionBiasTensor->GetConstTensor());
+        }
+    }
+
+    unsigned int batchInputSize = batchSize * inputSize;
+    unsigned int batchOutputSize = batchSize * nOutput;
+
+    for (unsigned int t = 0; t < maxTime; ++t)
+    {
+        LstmImpl(m_Data.m_Parameters,
+                 lstmInputInfo,
+                 lstmOutputInfo,
+                 inputToOutputWeightsShape,
+                 recurrentToOutputWeightsShape,
+                 inputData,
+                 outputStateIn,
+                 cellStateIn,
+                 outputStateOut,
+                 cellStateOut,
+                 output,
+                 cellStateOutDecoder,
+                 outputDecoder,
+                 inputToInputWeightsTensor,
+                 inputToForgetWeightsTensor,
+                 inputToCellWeightsTensor,
+                 inputToOutputWeightsTensor,
+                 recurrentToInputWeightsTensor,
+                 recurrentToForgetWeightsTensor,
+                 recurrentToCellWeightsTensor,
+                 recurrentToOutputWeightsTensor,
+                 cellToInputWeightsTensor,
+                 cellToForgetWeightsTensor,
+                 cellToOutputWeightsTensor,
+                 inputGateBiasTensor,
+                 forgetGateBiasTensor,
+                 cellBiasTensor,
+                 outputGateBiasTensor,
+                 projectionWeightsTensor,
+                 projectionBiasTensor,
+                 inputLayerNormWeights,
+                 forgetLayerNormWeights,
+                 cellLayerNormWeights,
+                 outputLayerNormWeights,
+                 inputGateScratch,
+                 cellScratch,
+                 forgetGateScratch,
+                 outputGateScratch,
+                 inputGateScratchDecoder,
+                 cellScratchDecoder,
+                 forgetGateScratchDecoder,
+                 outputGateScratchDecoder,
+                 m_LayerNormEpsilon);
+
+        // Advance the input and output views to the next time step.
+        currentInputData += batchInputSize;
+        inputData = MakeDecoder(lstmInputInfo, currentInputData);
+        currentOutputData += batchOutputSize;
+        output = MakeEncoder(lstmOutputInfo, currentOutputData);
+        outputDecoder = MakeDecoder(lstmOutputInfo, currentOutputData);
+
+        // Assign output state out to the next output state in
+        outputStateIn = MakeDecoder(outputStateInfo, outputStateOutData);
+
+        // Assign cell state out to the next cell state in
+        cellStateIn = MakeDecoder(cellStateInfo, cellStateOutData);
+    }
+
+    if (!m_Data.m_Parameters.m_TimeMajor)
+    {
+        // Permute Output back to batch major
+        const PermutationVector& mappings = {1U, 0U, 2U};
+        auto outputData = reinterpret_cast(outputs[0]->Map());
+        std::vector outputValue(outputData, outputData + outputInfo.GetNumElements());
+        outputShape = armnnUtils::Permuted(outputInfo.GetShape(), mappings);
+        outputInfo.SetShape(outputShape);
+        armnnUtils::Permute(outputShape, mappings, outputValue.data(), outputData, sizeof(float));
+    }
+}
+
+} //namespace armnn
diff --git a/src/backends/reference/workloads/RefUnidirectionalSequenceLstmWorkload.hpp b/src/backends/reference/workloads/RefUnidirectionalSequenceLstmWorkload.hpp
new file mode 100644
index 0000000000..8ba7bdc0c6
--- /dev/null
+++ b/src/backends/reference/workloads/RefUnidirectionalSequenceLstmWorkload.hpp
@@ -0,0 +1,56 @@
+//
+// Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
+// SPDX-License-Identifier: MIT +// + +#pragma once + +#include + +#include +#include + +#include "Encoders.hpp" +#include "Decoders.hpp" + +namespace armnn +{ + +class RefUnidirectionalSequenceLstmWorkload : public BaseWorkload +{ +public: + explicit RefUnidirectionalSequenceLstmWorkload(const UnidirectionalSequenceLstmQueueDescriptor& descriptor, + const WorkloadInfo& info); + + void Execute() const override; + void ExecuteAsync(WorkingMemDescriptor& workingMemDescriptor) override; + + +private: + void Execute(std::vector inputs, std::vector outputs) const; + std::unique_ptr m_InputToInputWeightsTensor; + std::unique_ptr m_InputToForgetWeightsTensor; + std::unique_ptr m_InputToCellWeightsTensor; + std::unique_ptr m_InputToOutputWeightsTensor; + std::unique_ptr m_RecurrentToInputWeightsTensor; + std::unique_ptr m_RecurrentToForgetWeightsTensor; + std::unique_ptr m_RecurrentToCellWeightsTensor; + std::unique_ptr m_RecurrentToOutputWeightsTensor; + std::unique_ptr m_CellToInputWeightsTensor; + std::unique_ptr m_CellToForgetWeightsTensor; + std::unique_ptr m_CellToOutputWeightsTensor; + std::unique_ptr m_InputGateBiasTensor; + std::unique_ptr m_ForgetGateBiasTensor; + std::unique_ptr m_CellBiasTensor; + std::unique_ptr m_OutputGateBiasTensor; + std::unique_ptr m_ProjectionWeightsTensor; + std::unique_ptr m_ProjectionBiasTensor; + std::unique_ptr m_InputLayerNormWeights; + std::unique_ptr m_ForgetLayerNormWeights; + std::unique_ptr m_CellLayerNormWeights; + std::unique_ptr m_OutputLayerNormWeights; + + float m_LayerNormEpsilon = static_cast(1e-8); +}; + +} //namespace armnn diff --git a/src/backends/reference/workloads/RefWorkloads.hpp b/src/backends/reference/workloads/RefWorkloads.hpp index afe63d13c0..d3ae58ea15 100644 --- a/src/backends/reference/workloads/RefWorkloads.hpp +++ b/src/backends/reference/workloads/RefWorkloads.hpp @@ -69,6 +69,7 @@ #include "RefSpaceToDepthWorkload.hpp" #include "RefTransposeConvolution2dWorkload.hpp" #include 
"RefTransposeWorkload.hpp" +#include "RefUnidirectionalSequenceLstmWorkload.hpp" #include "RefWorkloadUtils.hpp" #include "Resize.hpp" #include "Softmax.hpp" -- cgit v1.2.1