From 7bdcc2e3073accbd6e7884a4b7d0d016253e7412 Mon Sep 17 00:00:00 2001 From: Ryan OShea Date: Wed, 13 May 2020 16:36:19 +0100 Subject: IVGCVSW-4450 Add CL Enhanced Quantized LSTM Workload * Adds QLstm CL workload * Added Layer and CreateWorkload tests Signed-off-by: Ryan OShea Signed-off-by: James Conroy Change-Id: I32335e528467bfd619edb249d2971705ac2a6163 --- src/backends/cl/ClLayerSupport.cpp | 39 ++- src/backends/cl/ClLayerSupport.hpp | 10 + src/backends/cl/ClWorkloadFactory.cpp | 6 + src/backends/cl/ClWorkloadFactory.hpp | 3 + src/backends/cl/backend.mk | 1 + src/backends/cl/test/ClCreateWorkloadTests.cpp | 28 ++ src/backends/cl/test/ClLayerTests.cpp | 4 + src/backends/cl/workloads/CMakeLists.txt | 2 + src/backends/cl/workloads/ClQLstmWorkload.cpp | 401 +++++++++++++++++++++++++ src/backends/cl/workloads/ClQLstmWorkload.hpp | 66 ++++ src/backends/cl/workloads/ClWorkloadUtils.hpp | 3 + src/backends/cl/workloads/ClWorkloads.hpp | 1 + 12 files changed, 562 insertions(+), 2 deletions(-) create mode 100644 src/backends/cl/workloads/ClQLstmWorkload.cpp create mode 100644 src/backends/cl/workloads/ClQLstmWorkload.hpp diff --git a/src/backends/cl/ClLayerSupport.cpp b/src/backends/cl/ClLayerSupport.cpp index eb68a80765..7418dbd9e4 100644 --- a/src/backends/cl/ClLayerSupport.cpp +++ b/src/backends/cl/ClLayerSupport.cpp @@ -47,11 +47,12 @@ #include "workloads/ClPermuteWorkload.hpp" #include "workloads/ClPooling2dWorkload.hpp" #include "workloads/ClPreluWorkload.hpp" +#include "workloads/ClQLstmWorkload.hpp" +#include "workloads/ClQuantizedLstmWorkload.hpp" +#include "workloads/ClQuantizeWorkload.hpp" #include "workloads/ClReshapeWorkload.hpp" #include "workloads/ClResizeWorkload.hpp" #include "workloads/ClRsqrtWorkload.hpp" -#include "workloads/ClQuantizedLstmWorkload.hpp" -#include "workloads/ClQuantizeWorkload.hpp" #include "workloads/ClSliceWorkload.hpp" #include "workloads/ClSoftmaxWorkload.hpp" #include "workloads/ClSpaceToBatchNdWorkload.hpp" @@ -618,6 +619,40 @@ bool ClLayerSupport::IsPreluSupported(const armnn::TensorInfo &input, FORWARD_WORKLOAD_VALIDATE_FUNC(ClPreluWorkloadValidate, reasonIfUnsupported, input, alpha, output); } +bool ClLayerSupport::IsQLstmSupported(const TensorInfo& input, + const TensorInfo& previousOutputIn, + const TensorInfo& previousCellStateIn, + const TensorInfo& outputStateOut, + const TensorInfo& cellStateOut, + const TensorInfo& output, + const QLstmDescriptor& descriptor, + const LstmInputParamsInfo& paramsInfo, + Optional reasonIfUnsupported) const +{ + if (input.GetDataType() == armnn::DataType::QAsymmS8 && + previousOutputIn.GetDataType() == armnn::DataType::QAsymmS8 && + previousCellStateIn.GetDataType() == armnn::DataType::QSymmS16 && + outputStateOut.GetDataType() == armnn::DataType::QAsymmS8 && + cellStateOut.GetDataType() == armnn::DataType::QSymmS16 && + output.GetDataType() == armnn::DataType::QAsymmS8) + { + FORWARD_WORKLOAD_VALIDATE_FUNC(ClQLstmWorkloadValidate, + reasonIfUnsupported, + input, + previousCellStateIn, + previousOutputIn, + cellStateOut, + outputStateOut, + output, + descriptor, + paramsInfo); + } + else + { + return false; + } +} + bool ClLayerSupport::IsQuantizedLstmSupported(const TensorInfo& input, const TensorInfo& previousCellStateIn, const TensorInfo& previousOutputIn, diff --git a/src/backends/cl/ClLayerSupport.hpp b/src/backends/cl/ClLayerSupport.hpp index 60899d0596..d785f54387 100644 --- a/src/backends/cl/ClLayerSupport.hpp +++ b/src/backends/cl/ClLayerSupport.hpp @@ -203,6 +203,16 @@ public: const TensorInfo& output, Optional reasonIfUnsupported = EmptyOptional()) const override; + bool IsQLstmSupported(const TensorInfo& input, + const TensorInfo& previousOutputIn, + const TensorInfo& previousCellStateIn, + const TensorInfo& outputStateOut, + const TensorInfo& cellStateOut, + const TensorInfo& output, + const QLstmDescriptor& descriptor, + const LstmInputParamsInfo& paramsInfo, + Optional reasonIfUnsupported = EmptyOptional()) const override; + bool IsQuantizedLstmSupported(const TensorInfo& input, const TensorInfo& previousCellStateIn, const TensorInfo& previousOutputIn, diff --git a/src/backends/cl/ClWorkloadFactory.cpp b/src/backends/cl/ClWorkloadFactory.cpp index f584272e10..50a867ca2c 100644 --- a/src/backends/cl/ClWorkloadFactory.cpp +++ b/src/backends/cl/ClWorkloadFactory.cpp @@ -431,6 +431,12 @@ std::unique_ptr ClWorkloadFactory::CreatePrelu(const PreluQueueDescri return MakeWorkload(descriptor, info); } +std::unique_ptr ClWorkloadFactory::CreateQLstm(const QLstmQueueDescriptor& descriptor, + const WorkloadInfo& info) const +{ + return std::make_unique(descriptor, info); +} + std::unique_ptr ClWorkloadFactory::CreateQuantize(const QuantizeQueueDescriptor& descriptor, const WorkloadInfo& info) const { diff --git a/src/backends/cl/ClWorkloadFactory.hpp b/src/backends/cl/ClWorkloadFactory.hpp index b1c63211d6..3f92e86478 100644 --- a/src/backends/cl/ClWorkloadFactory.hpp +++ b/src/backends/cl/ClWorkloadFactory.hpp @@ -168,6 +168,9 @@ public: std::unique_ptr CreatePrelu(const PreluQueueDescriptor& descriptor, const WorkloadInfo& info) const override; + std::unique_ptr CreateQLstm(const QLstmQueueDescriptor& descriptor, + const WorkloadInfo& info) const override; + std::unique_ptr CreateQuantize(const QuantizeQueueDescriptor& descriptor, const WorkloadInfo& info) const override; diff --git a/src/backends/cl/backend.mk b/src/backends/cl/backend.mk index 8b464bb1d6..f18f8f03a2 100644 --- a/src/backends/cl/backend.mk +++ b/src/backends/cl/backend.mk @@ -53,6 +53,7 @@ BACKEND_SOURCES := \ workloads/ClPermuteWorkload.cpp \ workloads/ClPooling2dWorkload.cpp \ workloads/ClPreluWorkload.cpp \ + workloads/ClQLstmWorkload.cpp \ workloads/ClQuantizedLstmWorkload.cpp \ workloads/ClQuantizeWorkload.cpp \ workloads/ClReshapeWorkload.cpp \ diff --git a/src/backends/cl/test/ClCreateWorkloadTests.cpp b/src/backends/cl/test/ClCreateWorkloadTests.cpp index b7522547d4..896d486ebf 100644 --- a/src/backends/cl/test/ClCreateWorkloadTests.cpp +++ b/src/backends/cl/test/ClCreateWorkloadTests.cpp @@ -1035,6 +1035,34 @@ BOOST_AUTO_TEST_CASE(CreateStackUint8Workload) ClCreateStackWorkloadTest({ 3, 4, 5 }, { 3, 4, 2, 5 }, 2, 2); } + +template +static void ClCreateQLstmWorkloadTest() +{ + Graph graph; + ClWorkloadFactory factory = ClWorkloadFactoryHelper::GetFactory(ClWorkloadFactoryHelper::GetMemoryManager()); + + auto workload = CreateQLstmWorkloadTest(factory, graph); + QLstmQueueDescriptor queueDescriptor = workload->GetData(); + + IAclTensorHandle* inputHandle = PolymorphicDowncast(queueDescriptor.m_Inputs[0]); + BOOST_TEST((inputHandle->GetShape() == TensorShape({2, 4}))); + BOOST_TEST((inputHandle->GetDataType() == arm_compute::DataType::QASYMM8_SIGNED)); + + IAclTensorHandle* cellStateOutHandle = PolymorphicDowncast(queueDescriptor.m_Outputs[1]); + BOOST_TEST((cellStateOutHandle->GetShape() == TensorShape({2, 4}))); + BOOST_TEST((cellStateOutHandle->GetDataType() == arm_compute::DataType::QSYMM16)); + + IAclTensorHandle* outputHandle = PolymorphicDowncast(queueDescriptor.m_Outputs[2]); + BOOST_TEST((outputHandle->GetShape() == TensorShape({2, 4}))); + BOOST_TEST((outputHandle->GetDataType() == arm_compute::DataType::QASYMM8_SIGNED)); +} + +BOOST_AUTO_TEST_CASE(CreateQLstmWorkloadTest) +{ + ClCreateQLstmWorkloadTest(); +} + template static void ClCreateQuantizedLstmWorkloadTest() { diff --git a/src/backends/cl/test/ClLayerTests.cpp b/src/backends/cl/test/ClLayerTests.cpp index ce4496e71b..1c433a85c5 100644 --- a/src/backends/cl/test/ClLayerTests.cpp +++ b/src/backends/cl/test/ClLayerTests.cpp @@ -516,6 +516,10 @@ ARMNN_AUTO_TEST_CASE(LstmLayerFloat32NoCifgWithPeepholeWithProjection, ARMNN_AUTO_TEST_CASE(LstmLayerFloat32NoCifgWithPeepholeWithProjectionWithLayerNorm, LstmLayerFloat32NoCifgWithPeepholeWithProjectionWithLayerNormTest) +// QLstm +ARMNN_AUTO_TEST_CASE(QLstm, QLstmTest) + +// QuantizedLstm ARMNN_AUTO_TEST_CASE(QuantizedLstm, QuantizedLstmTest) // Convert from Float16 to Float32 diff --git a/src/backends/cl/workloads/CMakeLists.txt b/src/backends/cl/workloads/CMakeLists.txt index 1d8ac3c321..6d0aa792a1 100644 --- a/src/backends/cl/workloads/CMakeLists.txt +++ b/src/backends/cl/workloads/CMakeLists.txt @@ -66,6 +66,8 @@ list(APPEND armnnClBackendWorkloads_sources ClPooling2dWorkload.hpp ClPreluWorkload.cpp ClPreluWorkload.hpp + ClQLstmWorkload.cpp + ClQLstmWorkload.hpp ClQuantizedLstmWorkload.cpp ClQuantizedLstmWorkload.hpp ClQuantizeWorkload.cpp diff --git a/src/backends/cl/workloads/ClQLstmWorkload.cpp b/src/backends/cl/workloads/ClQLstmWorkload.cpp new file mode 100644 index 0000000000..ea51f3717c --- /dev/null +++ b/src/backends/cl/workloads/ClQLstmWorkload.cpp @@ -0,0 +1,401 @@ +// +// Copyright © 2020 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#include "ClQLstmWorkload.hpp" +#include "ClWorkloadUtils.hpp" + +#include "aclCommon/ArmComputeTensorUtils.hpp" + +#include "cl/ClTensorHandle.hpp" + +namespace armnn +{ +using namespace armcomputetensorutils; + +ClQLstmWorkload::ClQLstmWorkload(const QLstmQueueDescriptor &descriptor, const WorkloadInfo &info) + : BaseWorkload(descriptor, info) +{ + arm_compute::LSTMParams qLstmParams; + + // Mandatory params + m_InputToForgetWeightsTensor = std::make_unique(); + BuildArmComputeTensor(*m_InputToForgetWeightsTensor, m_Data.m_InputToForgetWeights->GetTensorInfo()); + + m_InputToCellWeightsTensor = std::make_unique(); + BuildArmComputeTensor(*m_InputToCellWeightsTensor, m_Data.m_InputToCellWeights->GetTensorInfo()); + + m_InputToOutputWeightsTensor = std::make_unique(); + BuildArmComputeTensor(*m_InputToOutputWeightsTensor, m_Data.m_InputToOutputWeights->GetTensorInfo()); + + m_RecurrentToForgetWeightsTensor = std::make_unique(); + BuildArmComputeTensor(*m_RecurrentToForgetWeightsTensor, m_Data.m_RecurrentToForgetWeights->GetTensorInfo()); + + m_RecurrentToCellWeightsTensor = std::make_unique(); + BuildArmComputeTensor(*m_RecurrentToCellWeightsTensor, m_Data.m_RecurrentToCellWeights->GetTensorInfo()); + + m_RecurrentToOutputWeightsTensor = std::make_unique(); + BuildArmComputeTensor(*m_RecurrentToOutputWeightsTensor, m_Data.m_RecurrentToOutputWeights->GetTensorInfo()); + + m_ForgetGateBiasTensor = std::make_unique(); + BuildArmComputeTensor(*m_ForgetGateBiasTensor, m_Data.m_ForgetGateBias->GetTensorInfo()); + + m_CellBiasTensor = std::make_unique(); + BuildArmComputeTensor(*m_CellBiasTensor, m_Data.m_CellBias->GetTensorInfo()); + + m_OutputGateBiasTensor = std::make_unique(); + BuildArmComputeTensor(*m_OutputGateBiasTensor, m_Data.m_OutputGateBias->GetTensorInfo()); + + // Create tensors for optional params if they are enabled + if (m_Data.m_Parameters.m_PeepholeEnabled) + { + m_CellToInputWeightsTensor = std::make_unique(); + + if (!m_Data.m_Parameters.m_CifgEnabled) + { + // In ACL this is categorised as a CIFG param and not a Peephole param + BuildArmComputeTensor(*m_CellToInputWeightsTensor, m_Data.m_CellToInputWeights->GetTensorInfo()); + } + + m_CellToForgetWeightsTensor = std::make_unique(); + BuildArmComputeTensor(*m_CellToForgetWeightsTensor, m_Data.m_CellToForgetWeights->GetTensorInfo()); + + m_CellToOutputWeightsTensor = std::make_unique(); + BuildArmComputeTensor(*m_CellToOutputWeightsTensor, m_Data.m_CellToOutputWeights->GetTensorInfo()); + + // Set Peephole params + qLstmParams.set_peephole_params(m_CellToForgetWeightsTensor.get(), + m_CellToOutputWeightsTensor.get()); + } + + if (m_Data.m_Parameters.m_ProjectionEnabled) + { + m_ProjectionWeightsTensor = std::make_unique(); + BuildArmComputeTensor(*m_ProjectionWeightsTensor, m_Data.m_ProjectionWeights->GetTensorInfo()); + + m_ProjectionBiasTensor = std::make_unique(); + if (m_Data.m_ProjectionBias != nullptr) + { + BuildArmComputeTensor(*m_ProjectionBiasTensor, m_Data.m_ProjectionBias->GetTensorInfo()); + } + + // Set projection params + qLstmParams.set_projection_params( + m_ProjectionWeightsTensor.get(), + m_Data.m_ProjectionBias != nullptr ? m_ProjectionBiasTensor.get() : nullptr); + } + + if (m_Data.m_Parameters.m_LayerNormEnabled) + { + m_InputLayerNormWeightsTensor = std::make_unique(); + + if (!m_Data.m_Parameters.m_CifgEnabled) + { + BuildArmComputeTensor(*m_InputLayerNormWeightsTensor, m_Data.m_InputLayerNormWeights->GetTensorInfo()); + } + + m_ForgetLayerNormWeightsTensor = std::make_unique(); + BuildArmComputeTensor(*m_ForgetLayerNormWeightsTensor, m_Data.m_ForgetLayerNormWeights->GetTensorInfo()); + + m_CellLayerNormWeightsTensor = std::make_unique(); + BuildArmComputeTensor(*m_CellLayerNormWeightsTensor, m_Data.m_CellLayerNormWeights->GetTensorInfo()); + + m_OutputLayerNormWeightsTensor = std::make_unique(); + BuildArmComputeTensor(*m_OutputLayerNormWeightsTensor, m_Data.m_OutputLayerNormWeights->GetTensorInfo()); + + qLstmParams.set_layer_normalization_params( + m_Data.m_InputLayerNormWeights != nullptr ? m_InputLayerNormWeightsTensor.get() : nullptr, + m_ForgetLayerNormWeightsTensor.get(), + m_CellLayerNormWeightsTensor.get(), + m_OutputLayerNormWeightsTensor.get()); + } + + if (!m_Data.m_Parameters.m_CifgEnabled) + { + m_InputToInputWeightsTensor = std::make_unique(); + BuildArmComputeTensor(*m_InputToInputWeightsTensor, m_Data.m_InputToInputWeights->GetTensorInfo()); + + m_RecurrentToInputWeightsTensor = std::make_unique(); + BuildArmComputeTensor(*m_RecurrentToInputWeightsTensor, m_Data.m_RecurrentToInputWeights->GetTensorInfo()); + + m_InputGateBiasTensor = std::make_unique(); + BuildArmComputeTensor(*m_InputGateBiasTensor, m_Data.m_InputGateBias->GetTensorInfo()); + + qLstmParams.set_cifg_params( + m_InputToInputWeightsTensor.get(), + m_RecurrentToInputWeightsTensor.get(), + m_Data.m_CellToInputWeights != nullptr ? m_CellToInputWeightsTensor.get() : nullptr, + m_InputGateBiasTensor.get()); + } + + // Input/Output tensors + const arm_compute::ICLTensor& input = static_cast(m_Data.m_Inputs[0])->GetTensor(); + const arm_compute::ICLTensor& outputStateIn = static_cast(m_Data.m_Inputs[1])->GetTensor(); + const arm_compute::ICLTensor& cellStateIn = static_cast(m_Data.m_Inputs[2])->GetTensor(); + + arm_compute::ICLTensor& outputStateOut = static_cast(m_Data.m_Outputs[0])->GetTensor(); + arm_compute::ICLTensor& cellStateOut = static_cast(m_Data.m_Outputs[1])->GetTensor(); + arm_compute::ICLTensor& output = static_cast(m_Data.m_Outputs[2])->GetTensor(); + + // Set scalar descriptor params + qLstmParams.set_cell_clip_params(m_Data.m_Parameters.m_CellClip); + qLstmParams.set_projection_clip_params(m_Data.m_Parameters.m_ProjectionClip); + qLstmParams.set_hidden_state_params(m_Data.m_Parameters.m_HiddenStateZeroPoint, + m_Data.m_Parameters.m_HiddenStateScale); + qLstmParams.set_matmul_scale_params(m_Data.m_Parameters.m_InputIntermediateScale, + m_Data.m_Parameters.m_ForgetIntermediateScale, + m_Data.m_Parameters.m_CellIntermediateScale, + m_Data.m_Parameters.m_OutputIntermediateScale); + + m_QLstmLayer.configure(&input, + m_InputToForgetWeightsTensor.get(), + m_InputToCellWeightsTensor.get(), + m_InputToOutputWeightsTensor.get(), + m_RecurrentToForgetWeightsTensor.get(), + m_RecurrentToCellWeightsTensor.get(), + m_RecurrentToOutputWeightsTensor.get(), + m_ForgetGateBiasTensor.get(), + m_CellBiasTensor.get(), + m_OutputGateBiasTensor.get(), + &cellStateIn, + &outputStateIn, + &cellStateOut, + &outputStateOut, + &output, + qLstmParams); + + // InitializeArmComputeTensorData for mandatory params + InitializeArmComputeClTensorData(*m_InputToForgetWeightsTensor, m_Data.m_InputToForgetWeights); + InitializeArmComputeClTensorData(*m_InputToCellWeightsTensor, m_Data.m_InputToCellWeights); + InitializeArmComputeClTensorData(*m_InputToOutputWeightsTensor, m_Data.m_InputToOutputWeights); + + InitializeArmComputeClTensorData(*m_RecurrentToForgetWeightsTensor, m_Data.m_RecurrentToForgetWeights); + InitializeArmComputeClTensorData(*m_RecurrentToCellWeightsTensor, m_Data.m_RecurrentToCellWeights); + InitializeArmComputeClTensorData(*m_RecurrentToOutputWeightsTensor, m_Data.m_RecurrentToOutputWeights); + + InitializeArmComputeClTensorData(*m_ForgetGateBiasTensor, m_Data.m_ForgetGateBias); + InitializeArmComputeClTensorData(*m_CellBiasTensor, m_Data.m_CellBias); + InitializeArmComputeClTensorData(*m_OutputGateBiasTensor, m_Data.m_OutputGateBias); + + if (!m_Data.m_Parameters.m_CifgEnabled) + { + InitializeArmComputeClTensorData(*m_InputToInputWeightsTensor, m_Data.m_InputToInputWeights); + InitializeArmComputeClTensorData(*m_RecurrentToInputWeightsTensor, m_Data.m_RecurrentToInputWeights); + InitializeArmComputeClTensorData(*m_InputGateBiasTensor, m_Data.m_InputGateBias); + } + + if (m_Data.m_Parameters.m_ProjectionEnabled) + { + InitializeArmComputeClTensorData(*m_ProjectionWeightsTensor, m_Data.m_ProjectionWeights); + + if (m_Data.m_ProjectionBias != nullptr) + { + InitializeArmComputeClTensorData(*m_ProjectionBiasTensor, m_Data.m_ProjectionBias); + } + } + + if (m_Data.m_Parameters.m_PeepholeEnabled) + { + if (!m_Data.m_Parameters.m_CifgEnabled) + { + InitializeArmComputeClTensorData(*m_CellToInputWeightsTensor, m_Data.m_CellToInputWeights); + } + + InitializeArmComputeClTensorData(*m_CellToForgetWeightsTensor, m_Data.m_CellToForgetWeights); + InitializeArmComputeClTensorData(*m_CellToOutputWeightsTensor, m_Data.m_CellToOutputWeights); + } + + if (m_Data.m_Parameters.m_LayerNormEnabled) + { + if (!m_Data.m_Parameters.m_CifgEnabled) + { + InitializeArmComputeClTensorData(*m_InputLayerNormWeightsTensor, m_Data.m_InputLayerNormWeights); + } + InitializeArmComputeClTensorData(*m_ForgetLayerNormWeightsTensor, m_Data.m_ForgetLayerNormWeights); + InitializeArmComputeClTensorData(*m_CellLayerNormWeightsTensor, m_Data.m_CellLayerNormWeights); + InitializeArmComputeClTensorData(*m_OutputLayerNormWeightsTensor, m_Data.m_OutputLayerNormWeights); + } + + m_QLstmLayer.prepare(); + + FreeUnusedTensors(); +} + +void ClQLstmWorkload::Execute() const +{ + m_QLstmLayer.run(); +} + +arm_compute::Status ClQLstmWorkloadValidate(const TensorInfo& input, + const TensorInfo& cellStateIn, + const TensorInfo& outputStateIn, + const TensorInfo& cellStateOut, + const TensorInfo& outputStateOut, + const TensorInfo& output, + const QLstmDescriptor& descriptor, + const LstmInputParamsInfo& paramsInfo) +{ + arm_compute::LSTMParams aclParamsInfo; + + // The inputs and outputs + const arm_compute::TensorInfo aclInputInfo = BuildArmComputeTensorInfo(input); + const arm_compute::TensorInfo aclOutputStateInInfo = BuildArmComputeTensorInfo(outputStateIn); + const arm_compute::TensorInfo aclCellStateInInfo = BuildArmComputeTensorInfo(cellStateIn); + + const arm_compute::TensorInfo aclOutputStateOutInfo = BuildArmComputeTensorInfo(outputStateOut); + const arm_compute::TensorInfo aclCellStateOutInfo = BuildArmComputeTensorInfo(cellStateOut); + const arm_compute::TensorInfo aclOutputInfo = BuildArmComputeTensorInfo(output); + + // Mandatory tensor info + const arm_compute::TensorInfo aclInputToForgetWeightsInfo + = BuildArmComputeTensorInfo(paramsInfo.GetInputToForgetWeights()); + const arm_compute::TensorInfo aclInputToCellWeightsInfo + = BuildArmComputeTensorInfo(paramsInfo.GetInputToCellWeights()); + const arm_compute::TensorInfo aclInputToOutputWeightsInfo + = BuildArmComputeTensorInfo(paramsInfo.GetInputToOutputWeights()); + const arm_compute::TensorInfo aclRecurrentToForgetWeightsInfo + = BuildArmComputeTensorInfo(paramsInfo.GetRecurrentToForgetWeights()); + const arm_compute::TensorInfo aclRecurrentToCellWeightsInfo + = BuildArmComputeTensorInfo(paramsInfo.GetRecurrentToCellWeights()); + const arm_compute::TensorInfo aclRecurrentToOutputWeightsInfo + = BuildArmComputeTensorInfo(paramsInfo.GetRecurrentToOutputWeights()); + const arm_compute::TensorInfo aclForgetGateBiasInfo + = BuildArmComputeTensorInfo(paramsInfo.GetForgetGateBias()); + const arm_compute::TensorInfo aclCellBiasInfo + = BuildArmComputeTensorInfo(paramsInfo.GetCellBias()); + const arm_compute::TensorInfo aclOutputGateBiasInfo + = BuildArmComputeTensorInfo(paramsInfo.GetOutputGateBias()); + + // Optional tensor info + arm_compute::TensorInfo aclInputToInputWeightsInfo; + arm_compute::TensorInfo aclRecurrentToInputWeightsInfo; + arm_compute::TensorInfo aclCellToInputWeightsInfo; + arm_compute::TensorInfo aclCellToForgetWeightsInfo; + arm_compute::TensorInfo aclCellToOutputWeightsInfo; + arm_compute::TensorInfo aclInputGateBiasInfo; + arm_compute::TensorInfo aclProjectionWeightsInfo; + arm_compute::TensorInfo aclProjectionBiasInfo; + arm_compute::TensorInfo aclInputLayerNormWeightsInfo; + arm_compute::TensorInfo aclForgetLayerNormWeightsInfo; + arm_compute::TensorInfo aclCellLayerNormWeightsInfo; + arm_compute::TensorInfo aclOutputLayerNormWeightsInfo; + + + if (descriptor.m_PeepholeEnabled) + { + if (!descriptor.m_CifgEnabled) + { + aclCellToInputWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.GetCellToInputWeights()); + } + + aclCellToForgetWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.GetCellToForgetWeights()); + aclCellToOutputWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.GetCellToOutputWeights()); + + aclParamsInfo.set_peephole_params(&aclCellToForgetWeightsInfo, + &aclCellToOutputWeightsInfo); + } + + if (descriptor.m_ProjectionEnabled) + { + aclProjectionWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.GetProjectionWeights()); + + if (paramsInfo.m_ProjectionBias != nullptr) + { + aclProjectionBiasInfo = BuildArmComputeTensorInfo(paramsInfo.GetProjectionBias()); + } + + aclParamsInfo.set_projection_params( + &aclProjectionWeightsInfo, + paramsInfo.m_ProjectionBias != nullptr ? &aclProjectionBiasInfo : nullptr); + } + + if (descriptor.m_LayerNormEnabled) + { + if (!descriptor.m_CifgEnabled) + { + aclInputLayerNormWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.GetInputLayerNormWeights()); + } + + aclForgetLayerNormWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.GetForgetLayerNormWeights()); + aclCellLayerNormWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.GetCellLayerNormWeights()); + aclOutputLayerNormWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.GetOutputLayerNormWeights()); + + aclParamsInfo.set_layer_normalization_params( + paramsInfo.m_InputLayerNormWeights != nullptr ? &aclInputLayerNormWeightsInfo : nullptr, + &aclForgetLayerNormWeightsInfo, + &aclCellLayerNormWeightsInfo, + &aclOutputLayerNormWeightsInfo); + } + + if (!descriptor.m_CifgEnabled) + { + aclInputToInputWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.GetInputToInputWeights()); + aclRecurrentToInputWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.GetRecurrentToInputWeights()); + aclInputGateBiasInfo = BuildArmComputeTensorInfo(paramsInfo.GetInputGateBias()); + + + aclParamsInfo.set_cifg_params( + &aclInputToInputWeightsInfo, + &aclRecurrentToInputWeightsInfo, + paramsInfo.m_CellToInputWeights != nullptr ? &aclCellToInputWeightsInfo : nullptr, + &aclInputGateBiasInfo); + } + + aclParamsInfo.set_cell_clip_params(descriptor.m_CellClip); + aclParamsInfo.set_projection_clip_params(descriptor.m_ProjectionClip); + aclParamsInfo.set_hidden_state_params(descriptor.m_HiddenStateZeroPoint, descriptor.m_HiddenStateScale); + aclParamsInfo.set_matmul_scale_params(descriptor.m_InputIntermediateScale, + descriptor.m_ForgetIntermediateScale, + descriptor.m_CellIntermediateScale, + descriptor.m_OutputIntermediateScale); + + return arm_compute::CLQLSTMLayer::validate(&aclInputInfo, + &aclInputToForgetWeightsInfo, + &aclInputToCellWeightsInfo, + &aclInputToOutputWeightsInfo, + &aclRecurrentToForgetWeightsInfo, + &aclRecurrentToCellWeightsInfo, + &aclRecurrentToOutputWeightsInfo, + &aclForgetGateBiasInfo, + &aclCellBiasInfo, + &aclOutputGateBiasInfo, + &aclCellStateInInfo, + &aclOutputStateInInfo, + &aclCellStateOutInfo, + &aclOutputStateOutInfo, + &aclOutputInfo, + aclParamsInfo); +} + +void ClQLstmWorkload::FreeUnusedTensors() +{ + FreeTensorIfUnused(m_InputToInputWeightsTensor); + FreeTensorIfUnused(m_InputToForgetWeightsTensor); + FreeTensorIfUnused(m_InputToCellWeightsTensor); + FreeTensorIfUnused(m_InputToOutputWeightsTensor); + + FreeTensorIfUnused(m_RecurrentToInputWeightsTensor); + FreeTensorIfUnused(m_RecurrentToForgetWeightsTensor); + FreeTensorIfUnused(m_RecurrentToCellWeightsTensor); + FreeTensorIfUnused(m_RecurrentToOutputWeightsTensor); + + FreeTensorIfUnused(m_CellToInputWeightsTensor); + FreeTensorIfUnused(m_CellToForgetWeightsTensor); + FreeTensorIfUnused(m_CellToOutputWeightsTensor); + + FreeTensorIfUnused(m_InputGateBiasTensor); + FreeTensorIfUnused(m_ForgetGateBiasTensor); + FreeTensorIfUnused(m_CellBiasTensor); + FreeTensorIfUnused(m_OutputGateBiasTensor); + + FreeTensorIfUnused(m_ProjectionWeightsTensor); + FreeTensorIfUnused(m_ProjectionBiasTensor); + + FreeTensorIfUnused(m_InputLayerNormWeightsTensor); + FreeTensorIfUnused(m_ForgetLayerNormWeightsTensor); + FreeTensorIfUnused(m_CellLayerNormWeightsTensor); + FreeTensorIfUnused(m_OutputLayerNormWeightsTensor); +} + +} //namespace armnn diff --git a/src/backends/cl/workloads/ClQLstmWorkload.hpp b/src/backends/cl/workloads/ClQLstmWorkload.hpp new file mode 100644 index 0000000000..f98c9b3f9a --- /dev/null +++ b/src/backends/cl/workloads/ClQLstmWorkload.hpp @@ -0,0 +1,66 @@ +// +// Copyright © 2020 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#pragma once + +#include +#include +#include +#include + +#include "arm_compute/graph/Tensor.h" +#include "arm_compute/runtime/CL/functions/CLQLSTMLayer.h" + +namespace armnn +{ + +class ClQLstmWorkload : public BaseWorkload +{ +public: + ClQLstmWorkload(const QLstmQueueDescriptor& descriptor, const WorkloadInfo& info); + virtual void Execute() const override; + +private: + mutable arm_compute::CLQLSTMLayer m_QLstmLayer; + + std::unique_ptr m_InputToInputWeightsTensor; + std::unique_ptr m_InputToForgetWeightsTensor; + std::unique_ptr m_InputToCellWeightsTensor; + std::unique_ptr m_InputToOutputWeightsTensor; + + std::unique_ptr m_RecurrentToInputWeightsTensor; + std::unique_ptr m_RecurrentToForgetWeightsTensor; + std::unique_ptr m_RecurrentToCellWeightsTensor; + std::unique_ptr m_RecurrentToOutputWeightsTensor; + + std::unique_ptr m_CellToInputWeightsTensor; + std::unique_ptr m_CellToForgetWeightsTensor; + std::unique_ptr m_CellToOutputWeightsTensor; + + std::unique_ptr m_InputGateBiasTensor; + std::unique_ptr m_ForgetGateBiasTensor; + std::unique_ptr m_CellBiasTensor; + std::unique_ptr m_OutputGateBiasTensor; + + std::unique_ptr m_ProjectionWeightsTensor; + std::unique_ptr m_ProjectionBiasTensor; + + std::unique_ptr m_InputLayerNormWeightsTensor; + std::unique_ptr m_ForgetLayerNormWeightsTensor; + std::unique_ptr m_CellLayerNormWeightsTensor; + std::unique_ptr m_OutputLayerNormWeightsTensor; + + void FreeUnusedTensors(); +}; + +arm_compute::Status ClQLstmWorkloadValidate(const TensorInfo& input, + const TensorInfo& cellStateIn, + const TensorInfo& outputStateIn, + const TensorInfo& cellStateOut, + const TensorInfo& outputStateOut, + const TensorInfo& output, + const QLstmDescriptor& descriptor, + const LstmInputParamsInfo& paramsInfo); +} //namespace armnn diff --git a/src/backends/cl/workloads/ClWorkloadUtils.hpp b/src/backends/cl/workloads/ClWorkloadUtils.hpp index 54e7717b7d..f2fb80a84d 100644 --- a/src/backends/cl/workloads/ClWorkloadUtils.hpp +++ b/src/backends/cl/workloads/ClWorkloadUtils.hpp @@ -111,6 +111,9 @@ inline void InitializeArmComputeClTensorData(arm_compute::CLTensor& clTensor, case DataType::QSymmS8: CopyArmComputeClTensorData(clTensor, handle->GetConstTensor()); break; + case DataType::QSymmS16: + CopyArmComputeClTensorData(clTensor, handle->GetConstTensor()); + break; ARMNN_NO_DEPRECATE_WARN_END case DataType::Signed32: CopyArmComputeClTensorData(clTensor, handle->GetConstTensor()); diff --git a/src/backends/cl/workloads/ClWorkloads.hpp b/src/backends/cl/workloads/ClWorkloads.hpp index 62b73daef3..7b3ce439be 100644 --- a/src/backends/cl/workloads/ClWorkloads.hpp +++ b/src/backends/cl/workloads/ClWorkloads.hpp @@ -33,6 +33,7 @@ #include "ClPadWorkload.hpp" #include "ClPooling2dWorkload.hpp" #include "ClPreluWorkload.hpp" +#include "ClQLstmWorkload.hpp" #include "ClQuantizeWorkload.hpp" #include "ClQuantizedLstmWorkload.hpp" #include "ClReshapeWorkload.hpp" -- cgit v1.2.1