From ad5293a86e315049de36afd723dcd1a7e70681a7 Mon Sep 17 00:00:00 2001
From: Jan Eilers
Date: Mon, 8 Jul 2019 09:57:55 +0100
Subject: IVGCVSW-3337 Add Neon backend support for LSTM layer normalisation

* Update neon lstm workload
* Add unit tests
* Add isLstmSupported

Change-Id: I493c159137f6544b0f2532d16d4fafd7a7e587e5
Signed-off-by: Jan Eilers
---
 src/backends/neon/NeonLayerSupport.cpp             |  25 ++++
 src/backends/neon/NeonLayerSupport.hpp             |  11 ++
 src/backends/neon/test/NeonCreateWorkloadTests.cpp |  23 ++++
 src/backends/neon/test/NeonLayerTests.cpp          |   2 +
 .../neon/workloads/NeonLstmFloatWorkload.cpp       | 148 ++++++++++++++-------
 .../neon/workloads/NeonLstmFloatWorkload.hpp       |  22 +--
 6 files changed, 164 insertions(+), 67 deletions(-)

diff --git a/src/backends/neon/NeonLayerSupport.cpp b/src/backends/neon/NeonLayerSupport.cpp
index 4fee53f51f..ea875f6926 100644
--- a/src/backends/neon/NeonLayerSupport.cpp
+++ b/src/backends/neon/NeonLayerSupport.cpp
@@ -26,6 +26,7 @@
 #include "workloads/NeonDequantizeWorkload.hpp"
 #include "workloads/NeonGreaterWorkload.hpp"
 #include "workloads/NeonL2NormalizationFloatWorkload.hpp"
+#include "workloads/NeonLstmFloatWorkload.hpp"
 #include "workloads/NeonMaximumWorkload.hpp"
 #include "workloads/NeonMeanWorkload.hpp"
 #include "workloads/NeonConcatWorkload.hpp"
@@ -334,6 +335,30 @@ bool NeonLayerSupport::IsL2NormalizationSupported(const TensorInfo& input,
     FORWARD_WORKLOAD_VALIDATE_FUNC(NeonL2NormalizationWorkloadValidate, reasonIfUnsupported, input, output, descriptor);
 }
 
+bool NeonLayerSupport::IsLstmSupported(const TensorInfo& input,
+                                       const TensorInfo& outputStateIn,
+                                       const TensorInfo& cellStateIn,
+                                       const TensorInfo& scratchBuffer,
+                                       const TensorInfo& outputStateOut,
+                                       const TensorInfo& cellStateOut,
+                                       const TensorInfo& output,
+                                       const LstmDescriptor& descriptor,
+                                       const LstmInputParamsInfo& paramsInfo,
+                                       Optional<std::string&> reasonIfUnsupported) const
+{
+    FORWARD_WORKLOAD_VALIDATE_FUNC(NeonLstmFloatWorkloadValidate,
+                                   reasonIfUnsupported,
+                                   input,
+                                   outputStateIn,
+                                   cellStateIn,
+                                   scratchBuffer,
+                                   outputStateOut,
+                                   cellStateOut,
+                                   output,
+                                   descriptor,
+                                   paramsInfo);
+}
+
 bool NeonLayerSupport::IsMaximumSupported(const TensorInfo& input0,
                                           const TensorInfo& input1,
                                           const TensorInfo& output,
diff --git a/src/backends/neon/NeonLayerSupport.hpp b/src/backends/neon/NeonLayerSupport.hpp
index 315248c79d..318cad7424 100644
--- a/src/backends/neon/NeonLayerSupport.hpp
+++ b/src/backends/neon/NeonLayerSupport.hpp
@@ -96,6 +96,17 @@ public:
                                      const L2NormalizationDescriptor& descriptor,
                                      Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
 
+    bool IsLstmSupported(const TensorInfo& input,
+                         const TensorInfo& outputStateIn,
+                         const TensorInfo& cellStateIn,
+                         const TensorInfo& scratchBuffer,
+                         const TensorInfo& outputStateOut,
+                         const TensorInfo& cellStateOut,
+                         const TensorInfo& output,
+                         const LstmDescriptor& descriptor,
+                         const LstmInputParamsInfo& paramsInfo,
+                         Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
+
     bool IsMaximumSupported(const TensorInfo& input0,
                             const TensorInfo& input1,
                             const TensorInfo& output,
diff --git a/src/backends/neon/test/NeonCreateWorkloadTests.cpp b/src/backends/neon/test/NeonCreateWorkloadTests.cpp
index 4968d0ed90..49c5a72a90 100644
--- a/src/backends/neon/test/NeonCreateWorkloadTests.cpp
+++ b/src/backends/neon/test/NeonCreateWorkloadTests.cpp
@@ -710,6 +710,29 @@ BOOST_AUTO_TEST_CASE(CreateL2NormalizationNhwcWorkload)
     NeonCreateL2NormalizationWorkloadTest<NeonL2NormalizationFloatWorkload, DataType::Float32>(DataLayout::NHWC);
 }
 
+template <typename LstmWorkloadType>
+static void NeonCreateLstmWorkloadTest()
+{
+    Graph graph;
+    NeonWorkloadFactory factory =
+        NeonWorkloadFactoryHelper::GetFactory(NeonWorkloadFactoryHelper::GetMemoryManager());
+
+    auto workload = CreateLstmWorkloadTest<LstmWorkloadType>(factory, graph);
+
+    LstmQueueDescriptor queueDescriptor = workload->GetData();
+
+    auto inputHandle  = boost::polymorphic_downcast<IAclTensorHandle*>(queueDescriptor.m_Inputs[0]);
+    auto outputHandle = boost::polymorphic_downcast<IAclTensorHandle*>(queueDescriptor.m_Outputs[1]);
+
+    BOOST_TEST(TestNeonTensorHandleInfo(inputHandle, TensorInfo({ 2, 2 }, DataType::Float32)));
+    BOOST_TEST(TestNeonTensorHandleInfo(outputHandle, TensorInfo({ 2, 4 }, DataType::Float32)));
+}
+
+BOOST_AUTO_TEST_CASE(CreateLSTMWorkloadFloatWorkload)
+{
+    NeonCreateLstmWorkloadTest<NeonLstmFloatWorkload>();
+}
+
 template <typename ConcatWorkloadType, armnn::DataType DataType>
 static void NeonCreateConcatWorkloadTest(std::initializer_list<unsigned int> outputShape,
                                          unsigned int concatAxis)
diff --git a/src/backends/neon/test/NeonLayerTests.cpp b/src/backends/neon/test/NeonLayerTests.cpp
index 51fd219365..049680aafe 100644
--- a/src/backends/neon/test/NeonLayerTests.cpp
+++ b/src/backends/neon/test/NeonLayerTests.cpp
@@ -469,6 +469,8 @@ ARMNN_AUTO_TEST_CASE(LstmLayerFloat32NoCifgNoPeepholeNoProjection,
                      LstmLayerFloat32NoCifgNoPeepholeNoProjectionTest)
 ARMNN_AUTO_TEST_CASE(LstmLayerFloat32NoCifgWithPeepholeWithProjection,
                      LstmLayerFloat32NoCifgWithPeepholeWithProjectionTest)
+ARMNN_AUTO_TEST_CASE(LstmLayerFloat32NoCifgWithPeepholeWithProjectionWithLayerNorm,
+                     LstmLayerFloat32NoCifgWithPeepholeWithProjectionWithLayerNormTest)
 
 // Mean
 ARMNN_AUTO_TEST_CASE(MeanSimpleFloat32, MeanSimpleTest<armnn::DataType::Float32>)
diff --git a/src/backends/neon/workloads/NeonLstmFloatWorkload.cpp b/src/backends/neon/workloads/NeonLstmFloatWorkload.cpp
index c7f5f090ce..6dd9f4f698 100644
--- a/src/backends/neon/workloads/NeonLstmFloatWorkload.cpp
+++ b/src/backends/neon/workloads/NeonLstmFloatWorkload.cpp
@@ -97,6 +97,30 @@ NeonLstmFloatWorkload::NeonLstmFloatWorkload(const LstmQueueDescriptor &descript
         lstm_param.set_peephole_params(m_CellToForgetWeightsTensor.get(), m_CellToOutputWeightsTensor.get());
     }
 
+    if (m_Data.m_Parameters.m_LayerNormEnabled)
+    {
+        m_InputLayerNormWeightsTensor = std::make_unique<arm_compute::Tensor>();
+        if (!m_Data.m_Parameters.m_CifgEnabled)
+        {
+            BuildArmComputeTensor(*m_InputLayerNormWeightsTensor, m_Data.m_InputLayerNormWeights->GetTensorInfo());
+        }
+
+        m_ForgetLayerNormWeightsTensor = std::make_unique<arm_compute::Tensor>();
+        BuildArmComputeTensor(*m_ForgetLayerNormWeightsTensor, m_Data.m_ForgetLayerNormWeights->GetTensorInfo());
+
+        m_CellLayerNormWeightsTensor = std::make_unique<arm_compute::Tensor>();
+        BuildArmComputeTensor(*m_CellLayerNormWeightsTensor, m_Data.m_CellLayerNormWeights->GetTensorInfo());
+
+        m_OutputLayerNormWeightsTensor = std::make_unique<arm_compute::Tensor>();
+        BuildArmComputeTensor(*m_OutputLayerNormWeightsTensor, m_Data.m_OutputLayerNormWeights->GetTensorInfo());
+
+        lstm_param.set_layer_normalization_params(m_Data.m_Parameters.m_CifgEnabled ?
+                                                  nullptr : m_InputLayerNormWeightsTensor.get(),
+                                                  m_ForgetLayerNormWeightsTensor.get(),
+                                                  m_CellLayerNormWeightsTensor.get(),
+                                                  m_OutputLayerNormWeightsTensor.get());
+    }
+
     const arm_compute::ITensor& input = static_cast<IAclTensorHandle*>(m_Data.m_Inputs[0])->GetTensor();
     const arm_compute::ITensor& output_state_in = static_cast<IAclTensorHandle*>(m_Data.m_Inputs[1])->GetTensor();
     const arm_compute::ITensor& cell_state_in = static_cast<IAclTensorHandle*>(m_Data.m_Inputs[2])->GetTensor();
@@ -113,13 +137,13 @@ NeonLstmFloatWorkload::NeonLstmFloatWorkload(const LstmQueueDescriptor &descript
     m_ScratchBuffer = std::make_unique<arm_compute::Tensor>();
     if (m_Data.m_Parameters.m_CifgEnabled)
     {
-        // 2D tensor with dimensions [num_units * 4, batch_size] with CIFG
+        // 2D tensor with dimensions [num_units * 3, batch_size] with CIFG
         armnn::TensorInfo scratchBuffer1({ batch_size, num_units * 3 }, DataType::Float32);
         BuildArmComputeTensor(*m_ScratchBuffer, scratchBuffer1);
     }
     else
     {
-        // scratch_buffer [num_units * 3, batch_size] without CIFG
+        // scratch_buffer [num_units * 4, batch_size] without CIFG
         armnn::TensorInfo scratchBuffer2({ batch_size, num_units * 4 }, DataType::Float32);
         BuildArmComputeTensor(*m_ScratchBuffer, scratchBuffer2);
     }
@@ -222,6 +246,17 @@ NeonLstmFloatWorkload::NeonLstmFloatWorkload(const LstmQueueDescriptor &descript
                                        m_Data.m_CellToOutputWeights);
     }
 
+    if (m_Data.m_Parameters.m_LayerNormEnabled)
+    {
+        if (!m_Data.m_Parameters.m_CifgEnabled)
+        {
+            InitializeArmComputeTensorData(*m_InputLayerNormWeightsTensor, m_Data.m_InputLayerNormWeights);
+        }
+        InitializeArmComputeTensorData(*m_ForgetLayerNormWeightsTensor, m_Data.m_ForgetLayerNormWeights);
+        InitializeArmComputeTensorData(*m_CellLayerNormWeightsTensor, m_Data.m_CellLayerNormWeights);
+        InitializeArmComputeTensorData(*m_OutputLayerNormWeightsTensor, m_Data.m_OutputLayerNormWeights);
+    }
+
     // Force Compute Library to perform the necessary copying and reshaping, after which
     // delete all the input tensors that will no longer be needed
     m_LstmLayer.prepare();
@@ -241,27 +276,11 @@ arm_compute::Status NeonLstmFloatWorkloadValidate(const TensorInfo& input,
                                                   const TensorInfo& cellStateOut,
                                                   const TensorInfo& output,
                                                   const LstmDescriptor& descriptor,
-                                                  const TensorInfo& inputToForgetWeights,
-                                                  const TensorInfo& inputToCellWeights,
-                                                  const TensorInfo& inputToOutputWeights,
-                                                  const TensorInfo& recurrentToForgetWeights,
-                                                  const TensorInfo& recurrentToCellWeights,
-                                                  const TensorInfo& recurrentToOutputWeights,
-                                                  const TensorInfo& forgetGateBias,
-                                                  const TensorInfo& cellBias,
-                                                  const TensorInfo& outputGateBias,
-                                                  const TensorInfo* inputToInputWeights,
-                                                  const TensorInfo* recurrentToInputWeights,
-                                                  const TensorInfo* cellToInputWeights,
-                                                  const TensorInfo* inputGateBias,
-                                                  const TensorInfo* projectionWeights,
-                                                  const TensorInfo* projectionBias,
-                                                  const TensorInfo* cellToForgetWeights,
-                                                  const TensorInfo* cellToOutputWeights)
+                                                  const LstmInputParamsInfo& paramsInfo)
 {
     arm_compute::LSTMParams<arm_compute::ITensorInfo> lstm_params_info;
 
-    // The inputs and the outputs
+    // The inputs and outputs
     const arm_compute::TensorInfo aclInputInfo = BuildArmComputeTensorInfo(input);
     const arm_compute::TensorInfo aclOutputStateInInfo = BuildArmComputeTensorInfo(outputStateIn);
     const arm_compute::TensorInfo aclCellStateInInfo = BuildArmComputeTensorInfo(cellStateIn);
@@ -271,18 +290,24 @@ arm_compute::Status NeonLstmFloatWorkloadValidate(const TensorInfo& input,
     const arm_compute::TensorInfo aclScratchBufferInfo = BuildArmComputeTensorInfo(scratchBuffer);
     const arm_compute::TensorInfo aclOutputStateOutInfo = BuildArmComputeTensorInfo(outputStateOut);
     const arm_compute::TensorInfo aclCellStateOutInfo = BuildArmComputeTensorInfo(cellStateOut);
     const arm_compute::TensorInfo aclOutputInfo = BuildArmComputeTensorInfo(output);
 
     // Basic parameters
-    const arm_compute::TensorInfo aclInputToForgetWeightsInfo = BuildArmComputeTensorInfo(inputToForgetWeights);
-    const arm_compute::TensorInfo aclInputToCellWeightsInfo = BuildArmComputeTensorInfo(inputToCellWeights);
-    const arm_compute::TensorInfo aclInputToOutputWeightsInfo = BuildArmComputeTensorInfo(inputToOutputWeights);
+    const arm_compute::TensorInfo aclInputToForgetWeightsInfo
+        = BuildArmComputeTensorInfo(paramsInfo.get_InputToForgetWeights());
+    const arm_compute::TensorInfo aclInputToCellWeightsInfo
+        = BuildArmComputeTensorInfo(paramsInfo.get_InputToCellWeights());
+    const arm_compute::TensorInfo aclInputToOutputWeightsInfo
+        = BuildArmComputeTensorInfo(paramsInfo.get_InputToOutputWeights());
     const arm_compute::TensorInfo aclRecurrentToForgetWeightsInfo
-        = BuildArmComputeTensorInfo(recurrentToForgetWeights);
+        = BuildArmComputeTensorInfo(paramsInfo.get_RecurrentToForgetWeights());
     const arm_compute::TensorInfo aclRecurrentToCellWeightsInfo
-        = BuildArmComputeTensorInfo(recurrentToCellWeights);
+        = BuildArmComputeTensorInfo(paramsInfo.get_RecurrentToCellWeights());
     const arm_compute::TensorInfo aclRecurrentToOutputWeightsInfo
-        = BuildArmComputeTensorInfo(recurrentToOutputWeights);
-    const arm_compute::TensorInfo aclForgetGateBiasInfo = BuildArmComputeTensorInfo(forgetGateBias);
-    const arm_compute::TensorInfo aclCellBiasInfo = BuildArmComputeTensorInfo(cellBias);
-    const arm_compute::TensorInfo aclOutputGateBiasInfo = BuildArmComputeTensorInfo(outputGateBias);
+        = BuildArmComputeTensorInfo(paramsInfo.get_RecurrentToOutputWeights());
+    const arm_compute::TensorInfo aclForgetGateBiasInfo
+        = BuildArmComputeTensorInfo(paramsInfo.get_ForgetGateBias());
+    const arm_compute::TensorInfo aclCellBiasInfo
+        = BuildArmComputeTensorInfo(paramsInfo.get_CellBias());
+    const arm_compute::TensorInfo aclOutputGateBiasInfo
+        = BuildArmComputeTensorInfo(paramsInfo.get_OutputGateBias());
 
     arm_compute::TensorInfo aclInputToInputWeightsInfo;
     arm_compute::TensorInfo aclRecurrentToInputWeightsInfo;
@@ -293,48 +318,65 @@ arm_compute::Status NeonLstmFloatWorkloadValidate(const TensorInfo& input,
     arm_compute::TensorInfo aclCellToForgetWeightsInfo;
     arm_compute::TensorInfo aclCellToOutputWeightsInfo;
 
+    arm_compute::TensorInfo aclInputLayerNormWeightsInfo;
+    arm_compute::TensorInfo aclForgetLayerNormWeightsInfo;
+    arm_compute::TensorInfo aclCellLayerNormWeightsInfo;
+    arm_compute::TensorInfo aclOutputLayerNormWeightsInfo;
+
+
     if (!descriptor.m_CifgEnabled)
     {
-        armnn::TensorInfo inputToInputWInfo = *inputToInputWeights;
-        aclInputToInputWeightsInfo = BuildArmComputeTensorInfo(inputToInputWInfo);
-        armnn::TensorInfo recurrentToInputWInfo = *recurrentToInputWeights;
-        aclRecurrentToInputWeightsInfo = BuildArmComputeTensorInfo(recurrentToInputWInfo);
-
-        if (cellToInputWeights != nullptr)
+        if (descriptor.m_PeepholeEnabled)
         {
-            armnn::TensorInfo cellToInputWInfo = *cellToInputWeights;
-            aclCellToInputWeightsInfo = BuildArmComputeTensorInfo(cellToInputWInfo);
+            aclCellToInputWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.get_CellToInputWeights());
         }
-        armnn::TensorInfo inputGateBiasInfo = *inputGateBias;
-        aclInputGateBiasInfo = BuildArmComputeTensorInfo(inputGateBiasInfo);
+        aclInputToInputWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.get_InputToInputWeights());
+        aclRecurrentToInputWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.get_RecurrentToInputWeights());
+        aclInputGateBiasInfo = BuildArmComputeTensorInfo(paramsInfo.get_InputGateBias());
+
         lstm_params_info.set_cifg_params(&aclInputToInputWeightsInfo, &aclRecurrentToInputWeightsInfo,
-                                         cellToInputWeights != nullptr ? &aclCellToInputWeightsInfo: nullptr,
+                                         descriptor.m_PeepholeEnabled ? &aclCellToInputWeightsInfo : nullptr,
                                          &aclInputGateBiasInfo);
     }
 
     if (descriptor.m_ProjectionEnabled)
     {
-        const armnn::TensorInfo& projectionWInfo = *projectionWeights;
-        aclProjectionWeightsInfo = BuildArmComputeTensorInfo(projectionWInfo);
-
-        if (projectionBias != nullptr)
+        if (paramsInfo.m_ProjectionBias != nullptr)
         {
-            const armnn::TensorInfo& projectionBiasInfo = *projectionBias;
-            aclProjectionBiasInfo = BuildArmComputeTensorInfo(projectionBiasInfo);
+            aclProjectionBiasInfo = BuildArmComputeTensorInfo(paramsInfo.get_ProjectionBias());
         }
+        aclProjectionWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.get_ProjectionWeights());
+
         lstm_params_info.set_projection_params(&aclProjectionWeightsInfo,
-                                               projectionBias != nullptr ? &aclProjectionBiasInfo: nullptr);
+                                               paramsInfo.m_ProjectionBias != nullptr ?
+                                               &aclProjectionBiasInfo : nullptr);
     }
 
     if (descriptor.m_PeepholeEnabled)
     {
-        const armnn::TensorInfo& cellToForgetWInfo = *cellToForgetWeights;
-        aclCellToForgetWeightsInfo = BuildArmComputeTensorInfo(cellToForgetWInfo);
-        const armnn::TensorInfo& cellToOutputWInfo = *cellToOutputWeights;
-        aclCellToOutputWeightsInfo = BuildArmComputeTensorInfo(cellToOutputWInfo);
+        aclCellToForgetWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.get_CellToForgetWeights());
+        aclCellToOutputWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.get_CellToOutputWeights());
+
         lstm_params_info.set_peephole_params(&aclCellToForgetWeightsInfo, &aclCellToOutputWeightsInfo);
     }
 
+    if (descriptor.m_LayerNormEnabled)
+    {
+        if (!descriptor.m_CifgEnabled)
+        {
+            aclInputLayerNormWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.get_InputLayerNormWeights());
+        }
+        aclForgetLayerNormWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.get_ForgetLayerNormWeights());
+        aclCellLayerNormWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.get_CellLayerNormWeights());
+        aclOutputLayerNormWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.get_OutputLayerNormWeights());
+
+        lstm_params_info.set_layer_normalization_params(descriptor.m_CifgEnabled ?
+                                                        nullptr : &aclInputLayerNormWeightsInfo,
+                                                        &aclForgetLayerNormWeightsInfo,
+                                                        &aclCellLayerNormWeightsInfo,
+                                                        &aclOutputLayerNormWeightsInfo);
+    }
+
     float cell_threshold = descriptor.m_ClippingThresCell;
     float projection_threshold = descriptor.m_ClippingThresProj;
 
@@ -407,6 +449,10 @@ void NeonLstmFloatWorkload::FreeUnusedTensors()
     FreeTensorIfUnused(m_ProjectionWeightsTensor);
     FreeTensorIfUnused(m_ProjectionBiasTensor);
     FreeTensorIfUnused(m_ScratchBuffer);
+    FreeTensorIfUnused(m_InputLayerNormWeightsTensor);
+    FreeTensorIfUnused(m_ForgetLayerNormWeightsTensor);
+    FreeTensorIfUnused(m_CellLayerNormWeightsTensor);
+    FreeTensorIfUnused(m_OutputLayerNormWeightsTensor);
 }
 
 } //namespace armnn
diff --git a/src/backends/neon/workloads/NeonLstmFloatWorkload.hpp b/src/backends/neon/workloads/NeonLstmFloatWorkload.hpp
index f87f24d88a..c116cdd967 100644
--- a/src/backends/neon/workloads/NeonLstmFloatWorkload.hpp
+++ b/src/backends/neon/workloads/NeonLstmFloatWorkload.hpp
@@ -43,6 +43,11 @@ private:
 
     std::unique_ptr<arm_compute::Tensor> m_ScratchBuffer;
 
+    std::unique_ptr<arm_compute::Tensor> m_InputLayerNormWeightsTensor;
+    std::unique_ptr<arm_compute::Tensor> m_ForgetLayerNormWeightsTensor;
+    std::unique_ptr<arm_compute::Tensor> m_CellLayerNormWeightsTensor;
+    std::unique_ptr<arm_compute::Tensor> m_OutputLayerNormWeightsTensor;
+
     void FreeUnusedTensors();
 };
 
@@ -50,21 +55,6 @@ arm_compute::Status NeonLstmFloatWorkloadValidate(const TensorInfo& input, const
                                                   const TensorInfo& cellStateIn, const TensorInfo& scratchBuffer,
                                                   const TensorInfo& outputStateOut, const TensorInfo& cellStateOut,
                                                   const TensorInfo& output, const LstmDescriptor &descriptor,
-                                                  const TensorInfo& inputToForgetWeights,
-                                                  const TensorInfo& inputToCellWeights,
-                                                  const TensorInfo& inputToOutputWeights,
-                                                  const TensorInfo& recurrentToForgetWeights,
-                                                  const TensorInfo& recurrentToCellWeights,
-                                                  const TensorInfo& recurrentToOutputWeights,
-                                                  const TensorInfo& forgetGateBias, const TensorInfo& cellBias,
-                                                  const TensorInfo& outputGateBias,
-                                                  const TensorInfo* inputToInputWeights,
-                                                  const TensorInfo* recurrentToInputWeights,
-                                                  const TensorInfo* cellToInputWeights,
-                                                  const TensorInfo* inputGateBias,
-                                                  const TensorInfo* projectionWeights,
-                                                  const TensorInfo* projectionBias,
-                                                  const TensorInfo* cellToForgetWeights,
-                                                  const TensorInfo* cellToOutputWeights);
+                                                  const LstmInputParamsInfo& paramsInfo);
 
 } //namespace armnn
-- 
cgit v1.2.1
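
Reviewer note, not part of the patch: below is a minimal sketch of how the new
IsLstmSupported() override might be exercised for a layer-normalised LSTM with
CIFG enabled. The LstmInputParamsInfo member names are inferred from the
get_*() accessors used in the patch (which also reads paramsInfo.m_ProjectionBias
directly), and the tensor shapes follow the usual ArmNN LSTM conventions; treat
both as assumptions rather than API documentation.

// check_neon_lstm.cpp -- illustrative only; assumes it is compiled inside the
// ArmNN source tree, since NeonLayerSupport.hpp is an internal backend header.
#include <iostream>
#include <string>

#include <armnn/Descriptors.hpp>
#include <armnn/LstmParams.hpp>   // assumed location of LstmInputParamsInfo
#include <armnn/Optional.hpp>
#include <armnn/Tensor.hpp>

#include <neon/NeonLayerSupport.hpp> // i.e. src/backends/neon/NeonLayerSupport.hpp

int main()
{
    using namespace armnn;

    const unsigned int batchSize = 2;
    const unsigned int inputSize = 2;
    const unsigned int numUnits  = 4; // no projection, so output size == numUnits

    LstmDescriptor descriptor;
    descriptor.m_ActivationFunc    = 4;    // tanh
    descriptor.m_CifgEnabled       = true; // scratch buffer is numUnits * 3 wide
    descriptor.m_PeepholeEnabled   = false;
    descriptor.m_ProjectionEnabled = false;
    descriptor.m_LayerNormEnabled  = true; // the code path added by this patch

    const TensorInfo input({ batchSize, inputSize }, DataType::Float32);
    const TensorInfo outputState({ batchSize, numUnits }, DataType::Float32);
    const TensorInfo cellState({ batchSize, numUnits }, DataType::Float32);
    const TensorInfo scratchBuffer({ batchSize, numUnits * 3 }, DataType::Float32);
    const TensorInfo output({ batchSize, numUnits }, DataType::Float32);

    const TensorInfo inputWeights({ numUnits, inputSize }, DataType::Float32);
    const TensorInfo recurrentWeights({ numUnits, numUnits }, DataType::Float32);
    const TensorInfo bias({ numUnits }, DataType::Float32); // shared here for brevity
    const TensorInfo layerNormWeights({ numUnits }, DataType::Float32);

    LstmInputParamsInfo paramsInfo;
    paramsInfo.m_InputToForgetWeights     = &inputWeights;
    paramsInfo.m_InputToCellWeights       = &inputWeights;
    paramsInfo.m_InputToOutputWeights     = &inputWeights;
    paramsInfo.m_RecurrentToForgetWeights = &recurrentWeights;
    paramsInfo.m_RecurrentToCellWeights   = &recurrentWeights;
    paramsInfo.m_RecurrentToOutputWeights = &recurrentWeights;
    paramsInfo.m_ForgetGateBias           = &bias;
    paramsInfo.m_CellBias                 = &bias;
    paramsInfo.m_OutputGateBias           = &bias;
    // CIFG is enabled, so m_InputLayerNormWeights stays null -- this mirrors
    // the nullptr passed to set_layer_normalization_params() in the workload.
    paramsInfo.m_ForgetLayerNormWeights   = &layerNormWeights;
    paramsInfo.m_CellLayerNormWeights     = &layerNormWeights;
    paramsInfo.m_OutputLayerNormWeights   = &layerNormWeights;

    NeonLayerSupport layerSupport;
    std::string reason;
    const bool supported = layerSupport.IsLstmSupported(input, outputState, cellState,
                                                        scratchBuffer, outputState, cellState,
                                                        output, descriptor, paramsInfo,
                                                        Optional<std::string&>(reason));

    std::cout << "Neon layer-norm LSTM supported: " << (supported ? "yes" : "no");
    if (!supported)
    {
        std::cout << " (" << reason << ")";
    }
    std::cout << std::endl;
    return 0;
}

If validation fails, NeonLstmFloatWorkloadValidate() forwards Compute Library's
reason string through reasonIfUnsupported, which is why the sketch passes a
named std::string rather than the default EmptyOptional().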