From 737d9ff58b348b11234b6c2363390607d576177d Mon Sep 17 00:00:00 2001 From: Ferran Balaguer Date: Thu, 1 Aug 2019 09:58:08 +0100 Subject: IVGCVSW-3342 Add CL backend support for Quantized_LSTM (16bit cell state) !android-nn-driver:1685 Signed-off-by: Ferran Balaguer Signed-off-by: Matthew Bentham Change-Id: I17278562f72d4b77e22c3af25bf7199b9150a765 --- src/armnn/test/CreateWorkload.hpp | 1 - src/backends/backendsCommon/WorkloadFactory.cpp | 80 ++++------- src/backends/cl/ClLayerSupport.cpp | 19 +++ src/backends/cl/ClLayerSupport.hpp | 8 ++ src/backends/cl/ClWorkloadFactory.cpp | 6 + src/backends/cl/ClWorkloadFactory.hpp | 3 + src/backends/cl/backend.mk | 1 + src/backends/cl/test/ClCreateWorkloadTests.cpp | 41 ++++++ src/backends/cl/test/ClLayerTests.cpp | 2 + src/backends/cl/workloads/CMakeLists.txt | 2 + .../cl/workloads/ClQuantizedLstmWorkload.cpp | 158 +++++++++++++++++++++ .../cl/workloads/ClQuantizedLstmWorkload.hpp | 48 +++++++ src/backends/cl/workloads/ClWorkloads.hpp | 1 + 13 files changed, 320 insertions(+), 50 deletions(-) create mode 100644 src/backends/cl/workloads/ClQuantizedLstmWorkload.cpp create mode 100644 src/backends/cl/workloads/ClQuantizedLstmWorkload.hpp diff --git a/src/armnn/test/CreateWorkload.hpp b/src/armnn/test/CreateWorkload.hpp index 3ec7e8e673..b576f12c22 100644 --- a/src/armnn/test/CreateWorkload.hpp +++ b/src/armnn/test/CreateWorkload.hpp @@ -351,7 +351,6 @@ template std::unique_ptr CreateQuantizedLstmWorkloadTest(armnn::IWorkloadFactory& factory, armnn::Graph& graph) { - auto layer = graph.AddLayer("quantizedLstmlayer"); unsigned int numBatches = 2; unsigned int inputSize = 2; diff --git a/src/backends/backendsCommon/WorkloadFactory.cpp b/src/backends/backendsCommon/WorkloadFactory.cpp index dca5778e0e..1f616f0b18 100644 --- a/src/backends/backendsCommon/WorkloadFactory.cpp +++ b/src/backends/backendsCommon/WorkloadFactory.cpp @@ -639,61 +639,43 @@ bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId, auto cLayer = boost::polymorphic_downcast(&layer); // Inputs - const TensorInfo& input = OverrideDataType( - layer.GetInputSlot(0).GetConnection()->GetTensorInfo(), dataType); - const TensorInfo& previousCellStateIn = OverrideDataType( - layer.GetInputSlot(1).GetConnection()->GetTensorInfo(), dataType); - const TensorInfo& previousOutputIn = OverrideDataType( - layer.GetInputSlot(2).GetConnection()->GetTensorInfo(), dataType); + const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo(); + const TensorInfo& previousCellStateIn = layer.GetInputSlot(1).GetConnection()->GetTensorInfo(); + const TensorInfo& previousOutputIn = layer.GetInputSlot(2).GetConnection()->GetTensorInfo(); // Outputs - const TensorInfo& cellStateOut = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType); - const TensorInfo& output = OverrideDataType(layer.GetOutputSlot(1).GetTensorInfo(), dataType); + const TensorInfo& cellStateOut = layer.GetOutputSlot(0).GetTensorInfo(); + const TensorInfo& output = layer.GetOutputSlot(1).GetTensorInfo(); // QuantizedLstm parameters - const TensorInfo& inputToInputWeights = OverrideDataType( - cLayer->m_QuantizedLstmParameters.m_InputToInputWeights->GetTensorInfo(), dataType); - const TensorInfo& inputToForgetWeights = OverrideDataType( - cLayer->m_QuantizedLstmParameters.m_InputToForgetWeights->GetTensorInfo(), dataType); - const TensorInfo& inputToCellWeights = OverrideDataType( - cLayer->m_QuantizedLstmParameters.m_InputToCellWeights->GetTensorInfo(), dataType); - const TensorInfo& inputToOutputWeights = OverrideDataType( - cLayer->m_QuantizedLstmParameters.m_InputToOutputWeights->GetTensorInfo(), dataType); - - const TensorInfo& recurrentToInputWeights = OverrideDataType( - cLayer->m_QuantizedLstmParameters.m_RecurrentToInputWeights->GetTensorInfo(), dataType); - const TensorInfo& recurrentToForgetWeights = OverrideDataType( - cLayer->m_QuantizedLstmParameters.m_RecurrentToForgetWeights->GetTensorInfo(), dataType); - const TensorInfo& recurrentToCellWeights = OverrideDataType( - cLayer->m_QuantizedLstmParameters.m_RecurrentToCellWeights->GetTensorInfo(), dataType); - const TensorInfo& recurrentToOutputWeights = OverrideDataType( - cLayer->m_QuantizedLstmParameters.m_RecurrentToOutputWeights->GetTensorInfo(), dataType); - - const TensorInfo& inputGateBias = OverrideDataType( - cLayer->m_QuantizedLstmParameters.m_InputGateBias->GetTensorInfo(), dataType); - const TensorInfo& forgetGateBias = OverrideDataType( - cLayer->m_QuantizedLstmParameters.m_ForgetGateBias->GetTensorInfo(), dataType); - const TensorInfo& cellBias = OverrideDataType( - cLayer->m_QuantizedLstmParameters.m_CellBias->GetTensorInfo(), dataType); - const TensorInfo& outputGateBias = OverrideDataType( - cLayer->m_QuantizedLstmParameters.m_OutputGateBias->GetTensorInfo(), dataType); - QuantizedLstmInputParamsInfo paramsInfo; - paramsInfo.m_InputToInputWeights = &inputToInputWeights; - paramsInfo.m_InputToForgetWeights = &inputToForgetWeights; - paramsInfo.m_InputToCellWeights = &inputToCellWeights; - paramsInfo.m_InputToOutputWeights = &inputToOutputWeights; - - paramsInfo.m_RecurrentToInputWeights = &recurrentToInputWeights; - paramsInfo.m_RecurrentToForgetWeights = &recurrentToForgetWeights; - paramsInfo.m_RecurrentToCellWeights = &recurrentToCellWeights; - paramsInfo.m_RecurrentToOutputWeights = &recurrentToOutputWeights; - - paramsInfo.m_InputGateBias = &inputGateBias; - paramsInfo.m_ForgetGateBias = &forgetGateBias; - paramsInfo.m_CellBias = &cellBias; - paramsInfo.m_OutputGateBias = &outputGateBias; + paramsInfo.m_InputToInputWeights = + &cLayer->m_QuantizedLstmParameters.m_InputToInputWeights->GetTensorInfo(); + paramsInfo.m_InputToForgetWeights = + &cLayer->m_QuantizedLstmParameters.m_InputToForgetWeights->GetTensorInfo(); + paramsInfo.m_InputToCellWeights = + &cLayer->m_QuantizedLstmParameters.m_InputToCellWeights->GetTensorInfo(); + paramsInfo.m_InputToOutputWeights = + &cLayer->m_QuantizedLstmParameters.m_InputToOutputWeights->GetTensorInfo(); + + paramsInfo.m_RecurrentToInputWeights = + &cLayer->m_QuantizedLstmParameters.m_RecurrentToInputWeights->GetTensorInfo(); + paramsInfo.m_RecurrentToForgetWeights = + &cLayer->m_QuantizedLstmParameters.m_RecurrentToForgetWeights->GetTensorInfo(); + paramsInfo.m_RecurrentToCellWeights = + &cLayer->m_QuantizedLstmParameters.m_RecurrentToCellWeights->GetTensorInfo(); + paramsInfo.m_RecurrentToOutputWeights = + &cLayer->m_QuantizedLstmParameters.m_RecurrentToOutputWeights->GetTensorInfo(); + + paramsInfo.m_InputGateBias = + &cLayer->m_QuantizedLstmParameters.m_InputGateBias->GetTensorInfo(); + paramsInfo.m_ForgetGateBias = + &cLayer->m_QuantizedLstmParameters.m_ForgetGateBias->GetTensorInfo(); + paramsInfo.m_CellBias = + &cLayer->m_QuantizedLstmParameters.m_CellBias->GetTensorInfo(); + paramsInfo.m_OutputGateBias = + &cLayer->m_QuantizedLstmParameters.m_OutputGateBias->GetTensorInfo();; result = layerSupportObject->IsQuantizedLstmSupported(input, previousCellStateIn, diff --git a/src/backends/cl/ClLayerSupport.cpp b/src/backends/cl/ClLayerSupport.cpp index 05539623a5..4ea6f2db3a 100644 --- a/src/backends/cl/ClLayerSupport.cpp +++ b/src/backends/cl/ClLayerSupport.cpp @@ -41,6 +41,7 @@ #include "workloads/ClPooling2dWorkload.hpp" #include "workloads/ClPreluWorkload.hpp" #include "workloads/ClResizeWorkload.hpp" +#include "workloads/ClQuantizedLstmWorkload.hpp" #include "workloads/ClQuantizeWorkload.hpp" #include "workloads/ClSoftmaxBaseWorkload.hpp" #include "workloads/ClSpaceToBatchNdWorkload.hpp" @@ -547,6 +548,24 @@ bool ClLayerSupport::IsPreluSupported(const armnn::TensorInfo &input, FORWARD_WORKLOAD_VALIDATE_FUNC(ClPreluWorkloadValidate, reasonIfUnsupported, input, alpha, output); } +bool ClLayerSupport::IsQuantizedLstmSupported(const TensorInfo& input, + const TensorInfo& previousCellStateIn, + const TensorInfo& previousOutputIn, + const TensorInfo& cellStateOut, + const TensorInfo& output, + const QuantizedLstmInputParamsInfo& paramsInfo, + Optional reasonIfUnsupported) const +{ + FORWARD_WORKLOAD_VALIDATE_FUNC(ClQuantizedLstmWorkloadValidate, + reasonIfUnsupported, + input, + previousCellStateIn, + previousOutputIn, + cellStateOut, + output, + paramsInfo); +} + bool ClLayerSupport::IsQuantizeSupported(const TensorInfo& input, const TensorInfo& output, Optional reasonIfUnsupported) const diff --git a/src/backends/cl/ClLayerSupport.hpp b/src/backends/cl/ClLayerSupport.hpp index 4879e8b4b8..a367085eef 100644 --- a/src/backends/cl/ClLayerSupport.hpp +++ b/src/backends/cl/ClLayerSupport.hpp @@ -175,6 +175,14 @@ public: const TensorInfo& output, Optional reasonIfUnsupported = EmptyOptional()) const override; + bool IsQuantizedLstmSupported(const TensorInfo& input, + const TensorInfo& previousCellStateIn, + const TensorInfo& previousOutputIn, + const TensorInfo& cellStateOut, + const TensorInfo& output, + const QuantizedLstmInputParamsInfo& paramsInfo, + Optional reasonIfUnsupported = EmptyOptional()) const override; + bool IsQuantizeSupported(const TensorInfo& input, const TensorInfo& output, Optional reasonIfUnsupported = EmptyOptional()) const override; diff --git a/src/backends/cl/ClWorkloadFactory.cpp b/src/backends/cl/ClWorkloadFactory.cpp index 4a593aac63..d72fa92a30 100644 --- a/src/backends/cl/ClWorkloadFactory.cpp +++ b/src/backends/cl/ClWorkloadFactory.cpp @@ -433,6 +433,12 @@ std::unique_ptr ClWorkloadFactory::CreateSpaceToDepth(const SpaceToDe return MakeWorkload(descriptor, info); } +std::unique_ptr ClWorkloadFactory::CreateQuantizedLstm(const QuantizedLstmQueueDescriptor& descriptor, + const WorkloadInfo& info) const +{ + return MakeWorkload(descriptor, info); +} + std::unique_ptr ClWorkloadFactory::CreateStack(const StackQueueDescriptor& descriptor, const WorkloadInfo& info) const { diff --git a/src/backends/cl/ClWorkloadFactory.hpp b/src/backends/cl/ClWorkloadFactory.hpp index 8586435481..01bfb8db9f 100644 --- a/src/backends/cl/ClWorkloadFactory.hpp +++ b/src/backends/cl/ClWorkloadFactory.hpp @@ -182,6 +182,9 @@ public: std::unique_ptr CreateSpaceToDepth(const SpaceToDepthQueueDescriptor& descriptor, const WorkloadInfo& info) const override; + std::unique_ptr CreateQuantizedLstm(const QuantizedLstmQueueDescriptor& descriptor, + const WorkloadInfo& info) const override; + std::unique_ptr CreateStack(const StackQueueDescriptor& descriptor, const WorkloadInfo& info) const override; diff --git a/src/backends/cl/backend.mk b/src/backends/cl/backend.mk index ee6447f340..8c34e62705 100644 --- a/src/backends/cl/backend.mk +++ b/src/backends/cl/backend.mk @@ -46,6 +46,7 @@ BACKEND_SOURCES := \ workloads/ClPermuteWorkload.cpp \ workloads/ClPooling2dWorkload.cpp \ workloads/ClPreluWorkload.cpp \ + workloads/ClQuantizedLstmWorkload.cpp \ workloads/ClQuantizeWorkload.cpp \ workloads/ClReshapeWorkload.cpp \ workloads/ClResizeWorkload.cpp \ diff --git a/src/backends/cl/test/ClCreateWorkloadTests.cpp b/src/backends/cl/test/ClCreateWorkloadTests.cpp index f453ccc9fd..bb36504214 100644 --- a/src/backends/cl/test/ClCreateWorkloadTests.cpp +++ b/src/backends/cl/test/ClCreateWorkloadTests.cpp @@ -9,6 +9,7 @@ #include #include +#include #include #include @@ -974,4 +975,44 @@ BOOST_AUTO_TEST_CASE(CreateStackUint8Workload) ClCreateStackWorkloadTest({ 3, 4, 5 }, { 3, 4, 2, 5 }, 2, 2); } +template +static void ClCreateQuantizedLstmWorkloadTest() +{ + using namespace armnn::armcomputetensorutils; + using boost::polymorphic_downcast; + + Graph graph; + ClWorkloadFactory factory = + ClWorkloadFactoryHelper::GetFactory(ClWorkloadFactoryHelper::GetMemoryManager()); + + auto workload = CreateQuantizedLstmWorkloadTest(factory, graph); + + QuantizedLstmQueueDescriptor queueDescriptor = workload->GetData(); + + IAclTensorHandle* inputHandle = polymorphic_downcast(queueDescriptor.m_Inputs[0]); + BOOST_TEST((inputHandle->GetShape() == TensorShape({2, 2}))); + BOOST_TEST((inputHandle->GetDataType() == arm_compute::DataType::QASYMM8)); + + IAclTensorHandle* cellStateInHandle = polymorphic_downcast(queueDescriptor.m_Inputs[1]); + BOOST_TEST((cellStateInHandle->GetShape() == TensorShape({2, 4}))); + BOOST_TEST((cellStateInHandle->GetDataType() == arm_compute::DataType::QSYMM16)); + + IAclTensorHandle* outputStateInHandle = polymorphic_downcast(queueDescriptor.m_Inputs[2]); + BOOST_TEST((outputStateInHandle->GetShape() == TensorShape({2, 4}))); + BOOST_TEST((outputStateInHandle->GetDataType() == arm_compute::DataType::QASYMM8)); + + IAclTensorHandle* cellStateOutHandle = polymorphic_downcast(queueDescriptor.m_Outputs[0]); + BOOST_TEST((cellStateOutHandle->GetShape() == TensorShape({2, 4}))); + BOOST_TEST((cellStateOutHandle->GetDataType() == arm_compute::DataType::QSYMM16)); + + IAclTensorHandle* outputStateOutHandle = polymorphic_downcast(queueDescriptor.m_Outputs[1]); + BOOST_TEST((outputStateOutHandle->GetShape() == TensorShape({2, 4}))); + BOOST_TEST((outputStateOutHandle->GetDataType() == arm_compute::DataType::QASYMM8)); +} + +BOOST_AUTO_TEST_CASE(CreateQuantizedLstmWorkload) +{ + ClCreateQuantizedLstmWorkloadTest(); +} + BOOST_AUTO_TEST_SUITE_END() diff --git a/src/backends/cl/test/ClLayerTests.cpp b/src/backends/cl/test/ClLayerTests.cpp index dd4c16edf4..c9114b9ac4 100644 --- a/src/backends/cl/test/ClLayerTests.cpp +++ b/src/backends/cl/test/ClLayerTests.cpp @@ -401,6 +401,8 @@ ARMNN_AUTO_TEST_CASE(LstmLayerFloat32NoCifgWithPeepholeWithProjection, ARMNN_AUTO_TEST_CASE(LstmLayerFloat32NoCifgWithPeepholeWithProjectionWithLayerNorm, LstmLayerFloat32NoCifgWithPeepholeWithProjectionWithLayerNormTest) +ARMNN_AUTO_TEST_CASE(QuantizedLstm, QuantizedLstmTest) + // Convert from Float16 to Float32 ARMNN_AUTO_TEST_CASE(SimpleConvertFp16ToFp32, SimpleConvertFp16ToFp32Test) // Convert from Float32 to Float16 diff --git a/src/backends/cl/workloads/CMakeLists.txt b/src/backends/cl/workloads/CMakeLists.txt index 49a8b177d0..f62600b983 100644 --- a/src/backends/cl/workloads/CMakeLists.txt +++ b/src/backends/cl/workloads/CMakeLists.txt @@ -56,6 +56,8 @@ list(APPEND armnnClBackendWorkloads_sources ClPooling2dWorkload.hpp ClPreluWorkload.cpp ClPreluWorkload.hpp + ClQuantizedLstmWorkload.cpp + ClQuantizedLstmWorkload.hpp ClQuantizeWorkload.cpp ClQuantizeWorkload.hpp ClReshapeWorkload.cpp diff --git a/src/backends/cl/workloads/ClQuantizedLstmWorkload.cpp b/src/backends/cl/workloads/ClQuantizedLstmWorkload.cpp new file mode 100644 index 0000000000..76a6694153 --- /dev/null +++ b/src/backends/cl/workloads/ClQuantizedLstmWorkload.cpp @@ -0,0 +1,158 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#include "ClQuantizedLstmWorkload.hpp" +#include "ClWorkloadUtils.hpp" + +#include +#include +#include + +namespace armnn +{ + +using namespace armcomputetensorutils; + +arm_compute::Status ClQuantizedLstmWorkloadValidate(const TensorInfo& input, const TensorInfo& previousCellStateIn, + const TensorInfo& previousOutputIn, const TensorInfo& cellStateOut, + const TensorInfo& output, + const QuantizedLstmInputParamsInfo& paramsInfo) +{ + // Inputs + const arm_compute::TensorInfo aclInputInfo = BuildArmComputeTensorInfo(input); + const arm_compute::TensorInfo aclPreviousCellStateInInfo = BuildArmComputeTensorInfo(previousCellStateIn); + const arm_compute::TensorInfo aclPreviousOutputInInfo = BuildArmComputeTensorInfo(previousOutputIn); + + // Outputs + const arm_compute::TensorInfo aclCellStateOutInfo = BuildArmComputeTensorInfo(cellStateOut); + const arm_compute::TensorInfo aclOutputInfo = BuildArmComputeTensorInfo(output); + + // Basic parameters + const arm_compute::TensorInfo aclInputToInputWeightsInfo + = BuildArmComputeTensorInfo(paramsInfo.get_InputToInputWeights()); + const arm_compute::TensorInfo aclInputToForgetWeightsInfo + = BuildArmComputeTensorInfo(paramsInfo.get_InputToForgetWeights()); + const arm_compute::TensorInfo aclInputToCellWeightsInfo + = BuildArmComputeTensorInfo(paramsInfo.get_InputToCellWeights()); + const arm_compute::TensorInfo aclInputToOutputWeightsInfo + = BuildArmComputeTensorInfo(paramsInfo.get_InputToOutputWeights()); + const arm_compute::TensorInfo aclRecurrentToInputWeightsInfo + = BuildArmComputeTensorInfo(paramsInfo.get_RecurrentToInputWeights()); + const arm_compute::TensorInfo aclRecurrentToForgetWeightsInfo + = BuildArmComputeTensorInfo(paramsInfo.get_RecurrentToForgetWeights()); + const arm_compute::TensorInfo aclRecurrentToCellWeightsInfo + = BuildArmComputeTensorInfo(paramsInfo.get_RecurrentToCellWeights()); + const arm_compute::TensorInfo aclRecurrentToOutputWeightsInfo + = BuildArmComputeTensorInfo(paramsInfo.get_RecurrentToOutputWeights()); + const arm_compute::TensorInfo aclInputGateBiasInfo = BuildArmComputeTensorInfo(paramsInfo.get_InputGateBias()); + const arm_compute::TensorInfo aclForgetGateBiasInfo = BuildArmComputeTensorInfo(paramsInfo.get_ForgetGateBias()); + const arm_compute::TensorInfo aclCellBiasInfo = BuildArmComputeTensorInfo(paramsInfo.get_CellBias()); + const arm_compute::TensorInfo aclOutputGateBiasInfo = BuildArmComputeTensorInfo(paramsInfo.get_OutputGateBias()); + + return arm_compute::CLLSTMLayerQuantized::validate(&aclInputInfo, &aclInputToInputWeightsInfo, + &aclInputToForgetWeightsInfo, &aclInputToCellWeightsInfo, + &aclInputToOutputWeightsInfo, &aclRecurrentToInputWeightsInfo, + &aclRecurrentToForgetWeightsInfo, &aclRecurrentToCellWeightsInfo, + &aclRecurrentToOutputWeightsInfo, &aclInputGateBiasInfo, + &aclForgetGateBiasInfo, &aclCellBiasInfo, &aclOutputGateBiasInfo, + &aclPreviousCellStateInInfo, &aclPreviousOutputInInfo, + &aclCellStateOutInfo, &aclOutputInfo); +} + +ClQuantizedLstmWorkload::ClQuantizedLstmWorkload(const QuantizedLstmQueueDescriptor &descriptor, + const WorkloadInfo &info): + BaseWorkload(descriptor, info) +{ + m_InputToInputWeightsTensor = std::make_unique(); + BuildArmComputeTensor(*m_InputToInputWeightsTensor, m_Data.m_InputToInputWeights->GetTensorInfo()); + + m_InputToForgetWeightsTensor = std::make_unique(); + BuildArmComputeTensor(*m_InputToForgetWeightsTensor, m_Data.m_InputToForgetWeights->GetTensorInfo()); + + m_InputToCellWeightsTensor = std::make_unique(); + BuildArmComputeTensor(*m_InputToCellWeightsTensor, m_Data.m_InputToCellWeights->GetTensorInfo()); + + m_InputToOutputWeightsTensor = std::make_unique(); + BuildArmComputeTensor(*m_InputToOutputWeightsTensor, m_Data.m_InputToOutputWeights->GetTensorInfo()); + + m_RecurrentToInputWeightsTensor = std::make_unique(); + BuildArmComputeTensor(*m_RecurrentToInputWeightsTensor, m_Data.m_RecurrentToInputWeights->GetTensorInfo()); + + m_RecurrentToForgetWeightsTensor = std::make_unique(); + BuildArmComputeTensor(*m_RecurrentToForgetWeightsTensor, m_Data.m_RecurrentToForgetWeights->GetTensorInfo()); + + m_RecurrentToCellWeightsTensor = std::make_unique(); + BuildArmComputeTensor(*m_RecurrentToCellWeightsTensor, m_Data.m_RecurrentToCellWeights->GetTensorInfo()); + + m_RecurrentToOutputWeightsTensor = std::make_unique(); + BuildArmComputeTensor(*m_RecurrentToOutputWeightsTensor, m_Data.m_RecurrentToOutputWeights->GetTensorInfo()); + + m_InputGateBiasTensor = std::make_unique(); + BuildArmComputeTensor(*m_InputGateBiasTensor, m_Data.m_InputGateBias->GetTensorInfo()); + + m_ForgetGateBiasTensor = std::make_unique(); + BuildArmComputeTensor(*m_ForgetGateBiasTensor, m_Data.m_ForgetGateBias->GetTensorInfo()); + + m_CellBiasTensor = std::make_unique(); + BuildArmComputeTensor(*m_CellBiasTensor, m_Data.m_CellBias->GetTensorInfo()); + + m_OutputGateBiasTensor = std::make_unique(); + BuildArmComputeTensor(*m_OutputGateBiasTensor, m_Data.m_OutputGateBias->GetTensorInfo()); + + const arm_compute::ICLTensor& inputTensor = static_cast(m_Data.m_Inputs[0])->GetTensor(); + arm_compute::ICLTensor& cellStateInTensor = static_cast(m_Data.m_Inputs[1])->GetTensor(); + const arm_compute::ICLTensor& outputStateInTensor = static_cast(m_Data.m_Inputs[2])->GetTensor(); + + arm_compute::ICLTensor& cellStateOutTensor = static_cast(m_Data.m_Outputs[0])->GetTensor(); + arm_compute::ICLTensor& outputStateOutTensor = static_cast(m_Data.m_Outputs[1])->GetTensor(); + + m_QuantizedLstmLayer.configure(&inputTensor, m_InputToInputWeightsTensor.get(), m_InputToForgetWeightsTensor.get(), + m_InputToCellWeightsTensor.get(), m_InputToOutputWeightsTensor.get(), + m_RecurrentToInputWeightsTensor.get(), m_RecurrentToForgetWeightsTensor.get(), + m_RecurrentToCellWeightsTensor.get(), m_RecurrentToOutputWeightsTensor.get(), + m_InputGateBiasTensor.get(), m_ForgetGateBiasTensor.get(), m_CellBiasTensor.get(), + m_OutputGateBiasTensor.get(), &cellStateInTensor, &outputStateInTensor, + &cellStateOutTensor, &outputStateOutTensor); + + InitializeArmComputeClTensorData(*m_InputToInputWeightsTensor, m_Data.m_InputToInputWeights); + InitializeArmComputeClTensorData(*m_InputToForgetWeightsTensor, m_Data.m_InputToForgetWeights); + InitializeArmComputeClTensorData(*m_InputToCellWeightsTensor, m_Data.m_InputToCellWeights); + InitializeArmComputeClTensorData(*m_InputToOutputWeightsTensor, m_Data.m_InputToOutputWeights); + InitializeArmComputeClTensorData(*m_RecurrentToInputWeightsTensor, m_Data.m_RecurrentToInputWeights); + InitializeArmComputeClTensorData(*m_RecurrentToForgetWeightsTensor, m_Data.m_RecurrentToForgetWeights); + InitializeArmComputeClTensorData(*m_RecurrentToCellWeightsTensor, m_Data.m_RecurrentToCellWeights); + InitializeArmComputeClTensorData(*m_RecurrentToOutputWeightsTensor, m_Data.m_RecurrentToOutputWeights); + InitializeArmComputeClTensorData(*m_InputGateBiasTensor, m_Data.m_InputGateBias); + InitializeArmComputeClTensorData(*m_ForgetGateBiasTensor, m_Data.m_ForgetGateBias); + InitializeArmComputeClTensorData(*m_CellBiasTensor, m_Data.m_CellBias); + InitializeArmComputeClTensorData(*m_OutputGateBiasTensor, m_Data.m_OutputGateBias); + + m_QuantizedLstmLayer.prepare(); + FreeUnusedTensors(); +} + +void ClQuantizedLstmWorkload::Execute() const +{ + ARMNN_SCOPED_PROFILING_EVENT_CL("ClQuantizedLstmWorkload_Execute"); + RunClFunction(m_QuantizedLstmLayer, CHECK_LOCATION()); +} + +void ClQuantizedLstmWorkload::FreeUnusedTensors() +{ + FreeTensorIfUnused(m_InputToInputWeightsTensor); + FreeTensorIfUnused(m_InputToForgetWeightsTensor); + FreeTensorIfUnused(m_InputToCellWeightsTensor); + FreeTensorIfUnused(m_InputToOutputWeightsTensor); + FreeTensorIfUnused(m_RecurrentToInputWeightsTensor); + FreeTensorIfUnused(m_RecurrentToForgetWeightsTensor); + FreeTensorIfUnused(m_RecurrentToCellWeightsTensor); + FreeTensorIfUnused(m_RecurrentToOutputWeightsTensor); + FreeTensorIfUnused(m_InputGateBiasTensor); + FreeTensorIfUnused(m_ForgetGateBiasTensor); + FreeTensorIfUnused(m_CellBiasTensor); + FreeTensorIfUnused(m_OutputGateBiasTensor); +} + +} // namespace armnn \ No newline at end of file diff --git a/src/backends/cl/workloads/ClQuantizedLstmWorkload.hpp b/src/backends/cl/workloads/ClQuantizedLstmWorkload.hpp new file mode 100644 index 0000000000..c7d83755c7 --- /dev/null +++ b/src/backends/cl/workloads/ClQuantizedLstmWorkload.hpp @@ -0,0 +1,48 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#pragma once + +#include +#include + +#include + +namespace armnn +{ + +arm_compute::Status ClQuantizedLstmWorkloadValidate(const TensorInfo& input, const TensorInfo& previousCellStateIn, + const TensorInfo& previousOutputIn, const TensorInfo& cellStateOut, + const TensorInfo& output, + const QuantizedLstmInputParamsInfo& paramsInfo); + +class ClQuantizedLstmWorkload : public BaseWorkload +{ +public: + ClQuantizedLstmWorkload(const QuantizedLstmQueueDescriptor& descriptor, const WorkloadInfo& info); + void Execute() const override; + +private: + mutable arm_compute::CLLSTMLayerQuantized m_QuantizedLstmLayer; + + std::unique_ptr m_InputToInputWeightsTensor; + std::unique_ptr m_InputToForgetWeightsTensor; + std::unique_ptr m_InputToCellWeightsTensor; + std::unique_ptr m_InputToOutputWeightsTensor; + std::unique_ptr m_RecurrentToInputWeightsTensor; + std::unique_ptr m_RecurrentToForgetWeightsTensor; + std::unique_ptr m_RecurrentToCellWeightsTensor; + std::unique_ptr m_RecurrentToOutputWeightsTensor; + std::unique_ptr m_InputGateBiasTensor; + std::unique_ptr m_ForgetGateBiasTensor; + std::unique_ptr m_CellBiasTensor; + std::unique_ptr m_OutputGateBiasTensor; + + void FreeUnusedTensors(); +}; + +} //namespace armnn + + diff --git a/src/backends/cl/workloads/ClWorkloads.hpp b/src/backends/cl/workloads/ClWorkloads.hpp index 03dffc4edc..1af30ffb34 100644 --- a/src/backends/cl/workloads/ClWorkloads.hpp +++ b/src/backends/cl/workloads/ClWorkloads.hpp @@ -29,6 +29,7 @@ #include "ClPooling2dWorkload.hpp" #include "ClPreluWorkload.hpp" #include "ClQuantizeWorkload.hpp" +#include "ClQuantizedLstmWorkload.hpp" #include "ClReshapeWorkload.hpp" #include "ClResizeWorkload.hpp" #include "ClSoftmaxFloatWorkload.hpp" -- cgit v1.2.1