diff options
author | Francis Murtagh <francis.murtagh@arm.com> | 2019-08-02 13:20:54 +0100 |
---|---|---|
committer | Áron Virginás-Tar <aron.virginas-tar@arm.com> | 2019-08-05 15:19:07 +0000 |
commit | 4fc3c48c2d230d8c55aa01aa98e32b6df7cafc0c (patch) | |
tree | 3ee1e3e59ed6cdd3c86377d260374e21d4fde923 /src/backends/neon | |
parent | f0a0a9ec1e8188e6494d57160341b5bb8a4c3bd7 (diff) | |
download | armnn-4fc3c48c2d230d8c55aa01aa98e32b6df7cafc0c.tar.gz |
IVGCVSW-3341 Add Neon backend support for Quantized_LSTM (16bit cell state)
* Add Neon Workload
* Update NeonWorkloads.hpp
* Update NeonWorkloadFactory
* Update NeonLayerSupport
* Update backends.mk and CMakeLists.txt
* Add NeonCreateWorkload test
* Enable LayerTest
!android-nn-driver:1685
Change-Id: Idd799bbf039acf0d59084d02c3b57766ce3691b5
Signed-off-by: Francis Murtagh <francis.murtagh@arm.com>
Signed-off-by: Matthew Bentham <Matthew.Bentham@arm.com>
Diffstat (limited to 'src/backends/neon')
-rw-r--r-- | src/backends/neon/NeonLayerSupport.cpp | 19 | ||||
-rw-r--r-- | src/backends/neon/NeonLayerSupport.hpp | 8 | ||||
-rw-r--r-- | src/backends/neon/NeonWorkloadFactory.cpp | 6 | ||||
-rw-r--r-- | src/backends/neon/NeonWorkloadFactory.hpp | 3 | ||||
-rw-r--r-- | src/backends/neon/backend.mk | 1 | ||||
-rw-r--r-- | src/backends/neon/test/NeonCreateWorkloadTests.cpp | 39 | ||||
-rw-r--r-- | src/backends/neon/test/NeonLayerTests.cpp | 2 | ||||
-rw-r--r-- | src/backends/neon/workloads/CMakeLists.txt | 2 | ||||
-rw-r--r-- | src/backends/neon/workloads/NeonQuantizedLstmWorkload.cpp | 210 | ||||
-rw-r--r-- | src/backends/neon/workloads/NeonQuantizedLstmWorkload.hpp | 52 | ||||
-rw-r--r-- | src/backends/neon/workloads/NeonWorkloads.hpp | 1 |
11 files changed, 343 insertions, 0 deletions
diff --git a/src/backends/neon/NeonLayerSupport.cpp b/src/backends/neon/NeonLayerSupport.cpp index dac3525f60..b61279c133 100644 --- a/src/backends/neon/NeonLayerSupport.cpp +++ b/src/backends/neon/NeonLayerSupport.cpp @@ -39,6 +39,7 @@ #include "workloads/NeonPooling2dWorkload.hpp" #include "workloads/NeonPreluWorkload.hpp" #include "workloads/NeonQuantizeWorkload.hpp" +#include "workloads/NeonQuantizedLstmWorkload.hpp" #include "workloads/NeonResizeWorkload.hpp" #include "workloads/NeonSoftmaxBaseWorkload.hpp" #include "workloads/NeonSpaceToDepthWorkload.hpp" @@ -487,6 +488,24 @@ bool NeonLayerSupport::IsQuantizeSupported(const TensorInfo& input, output); } +bool NeonLayerSupport::IsQuantizedLstmSupported(const TensorInfo& input, + const TensorInfo& cellStateIn, + const TensorInfo& outputStateIn, + const TensorInfo& cellStateOut, + const TensorInfo& outputStateOut, + const QuantizedLstmInputParamsInfo& paramsInfo, + Optional<std::string&> reasonIfUnsupported) const +{ + FORWARD_WORKLOAD_VALIDATE_FUNC(NeonQuantizedLstmWorkloadValidate, + reasonIfUnsupported, + input, + cellStateIn, + outputStateIn, + cellStateOut, + outputStateOut, + paramsInfo); +} + bool NeonLayerSupport::IsReshapeSupported(const TensorInfo& input, const ReshapeDescriptor& descriptor, Optional<std::string&> reasonIfUnsupported) const diff --git a/src/backends/neon/NeonLayerSupport.hpp b/src/backends/neon/NeonLayerSupport.hpp index 078d2f619b..acaebc4c58 100644 --- a/src/backends/neon/NeonLayerSupport.hpp +++ b/src/backends/neon/NeonLayerSupport.hpp @@ -165,6 +165,14 @@ public: const TensorInfo& output, Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override; + bool IsQuantizedLstmSupported(const TensorInfo& input, + const TensorInfo& cellStateIn, + const TensorInfo& outputStateIn, + const TensorInfo& cellStateOut, + const TensorInfo& outputStateOut, + const QuantizedLstmInputParamsInfo& paramsInfo, + Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override; + bool IsReshapeSupported(const TensorInfo& input, const ReshapeDescriptor& descriptor, Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override; diff --git a/src/backends/neon/NeonWorkloadFactory.cpp b/src/backends/neon/NeonWorkloadFactory.cpp index fd0381c26d..0e66bfc757 100644 --- a/src/backends/neon/NeonWorkloadFactory.cpp +++ b/src/backends/neon/NeonWorkloadFactory.cpp @@ -319,6 +319,12 @@ std::unique_ptr<IWorkload> NeonWorkloadFactory::CreateLstm(const LstmQueueDescri return MakeWorkloadHelper<NeonLstmFloatWorkload, NullWorkload>(descriptor, info); } +std::unique_ptr<IWorkload> NeonWorkloadFactory::CreateQuantizedLstm(const QuantizedLstmQueueDescriptor& descriptor, + const WorkloadInfo& info) const +{ + return std::make_unique<NeonQuantizedLstmWorkload>(descriptor, info); +} + std::unique_ptr<IWorkload> NeonWorkloadFactory::CreateConvertFp16ToFp32( const ConvertFp16ToFp32QueueDescriptor& descriptor, const WorkloadInfo& info) const diff --git a/src/backends/neon/NeonWorkloadFactory.hpp b/src/backends/neon/NeonWorkloadFactory.hpp index 360dc7c61b..b9995d8b4b 100644 --- a/src/backends/neon/NeonWorkloadFactory.hpp +++ b/src/backends/neon/NeonWorkloadFactory.hpp @@ -135,6 +135,9 @@ public: std::unique_ptr<IWorkload> CreateLstm(const LstmQueueDescriptor& descriptor, const WorkloadInfo& info) const override; + std::unique_ptr<IWorkload> CreateQuantizedLstm(const QuantizedLstmQueueDescriptor& descriptor, + const WorkloadInfo& info) const override; + std::unique_ptr<IWorkload> CreateConvertFp16ToFp32(const ConvertFp16ToFp32QueueDescriptor& descriptor, const WorkloadInfo& info) const override; diff --git a/src/backends/neon/backend.mk b/src/backends/neon/backend.mk index 98755e9b0a..d5483b0c7d 100644 --- a/src/backends/neon/backend.mk +++ b/src/backends/neon/backend.mk @@ -43,6 +43,7 @@ BACKEND_SOURCES := \ workloads/NeonPermuteWorkload.cpp \ workloads/NeonPooling2dWorkload.cpp \ workloads/NeonPreluWorkload.cpp \ + workloads/NeonQuantizedLstmWorkload.cpp \ workloads/NeonQuantizeWorkload.cpp \ workloads/NeonReshapeWorkload.cpp \ workloads/NeonResizeWorkload.cpp \ diff --git a/src/backends/neon/test/NeonCreateWorkloadTests.cpp b/src/backends/neon/test/NeonCreateWorkloadTests.cpp index 848af1285f..056bfb283f 100644 --- a/src/backends/neon/test/NeonCreateWorkloadTests.cpp +++ b/src/backends/neon/test/NeonCreateWorkloadTests.cpp @@ -5,6 +5,7 @@ #include "NeonWorkloadFactoryHelper.hpp" +#include <aclCommon/ArmComputeTensorUtils.hpp> #include <backendsCommon/MemCopyWorkload.hpp> #include <aclCommon/test/CreateWorkloadClNeon.hpp> @@ -873,4 +874,42 @@ BOOST_AUTO_TEST_CASE(CreateStackUint8Workload) NeonCreateStackWorkloadTest<armnn::DataType::QuantisedAsymm8>({ 3, 4, 5 }, { 3, 4, 2, 5 }, 2, 2); } +template <typename QuantizedLstmWorkloadType> +static void NeonCreateQuantizedLstmWorkloadTest() +{ + using boost::polymorphic_downcast; + + Graph graph; + NeonWorkloadFactory factory = NeonWorkloadFactoryHelper::GetFactory(NeonWorkloadFactoryHelper::GetMemoryManager()); + + auto workload = CreateQuantizedLstmWorkloadTest<QuantizedLstmWorkloadType>(factory, graph); + + QuantizedLstmQueueDescriptor queueDescriptor = workload->GetData(); + + IAclTensorHandle* inputHandle = polymorphic_downcast<IAclTensorHandle*>(queueDescriptor.m_Inputs[0]); + BOOST_TEST((inputHandle->GetShape() == TensorShape({2, 2}))); + BOOST_TEST((inputHandle->GetDataType() == arm_compute::DataType::QASYMM8)); + + IAclTensorHandle* cellStateInHandle = polymorphic_downcast<IAclTensorHandle*>(queueDescriptor.m_Inputs[1]); + BOOST_TEST((cellStateInHandle->GetShape() == TensorShape({2, 4}))); + BOOST_TEST((cellStateInHandle->GetDataType() == arm_compute::DataType::QSYMM16)); + + IAclTensorHandle* outputStateInHandle = polymorphic_downcast<IAclTensorHandle*>(queueDescriptor.m_Inputs[2]); + BOOST_TEST((outputStateInHandle->GetShape() == TensorShape({2, 4}))); + BOOST_TEST((outputStateInHandle->GetDataType() == arm_compute::DataType::QASYMM8)); + + IAclTensorHandle* cellStateOutHandle = polymorphic_downcast<IAclTensorHandle*>(queueDescriptor.m_Outputs[0]); + BOOST_TEST((cellStateOutHandle->GetShape() == TensorShape({2, 4}))); + BOOST_TEST((cellStateOutHandle->GetDataType() == arm_compute::DataType::QSYMM16)); + + IAclTensorHandle* outputStateOutHandle = polymorphic_downcast<IAclTensorHandle*>(queueDescriptor.m_Outputs[1]); + BOOST_TEST((outputStateOutHandle->GetShape() == TensorShape({2, 4}))); + BOOST_TEST((outputStateOutHandle->GetDataType() == arm_compute::DataType::QASYMM8)); +} + +BOOST_AUTO_TEST_CASE(CreateQuantizedLstmWorkload) +{ + NeonCreateQuantizedLstmWorkloadTest<NeonQuantizedLstmWorkload>(); +} + BOOST_AUTO_TEST_SUITE_END() diff --git a/src/backends/neon/test/NeonLayerTests.cpp b/src/backends/neon/test/NeonLayerTests.cpp index dd30536ac9..ed99461b31 100644 --- a/src/backends/neon/test/NeonLayerTests.cpp +++ b/src/backends/neon/test/NeonLayerTests.cpp @@ -529,6 +529,8 @@ ARMNN_AUTO_TEST_CASE(LstmLayerFloat32NoCifgWithPeepholeWithProjection, ARMNN_AUTO_TEST_CASE(LstmLayerFloat32NoCifgWithPeepholeWithProjectionWithLayerNorm, LstmLayerFloat32NoCifgWithPeepholeWithProjectionWithLayerNormTest) +ARMNN_AUTO_TEST_CASE(QuantizedLstm, QuantizedLstmTest) + // Mean ARMNN_AUTO_TEST_CASE(MeanSimpleFloat32, MeanSimpleTest<armnn::DataType::Float32>) ARMNN_AUTO_TEST_CASE(MeanSimpleAxisFloat32, MeanSimpleAxisTest<armnn::DataType::Float32>) diff --git a/src/backends/neon/workloads/CMakeLists.txt b/src/backends/neon/workloads/CMakeLists.txt index dea0228377..34fe0723af 100644 --- a/src/backends/neon/workloads/CMakeLists.txt +++ b/src/backends/neon/workloads/CMakeLists.txt @@ -52,6 +52,8 @@ list(APPEND armnnNeonBackendWorkloads_sources NeonPooling2dWorkload.hpp NeonPreluWorkload.cpp NeonPreluWorkload.hpp + NeonQuantizedLstmWorkload.cpp + NeonQuantizedLstmWorkload.hpp NeonQuantizeWorkload.cpp NeonQuantizeWorkload.hpp NeonReshapeWorkload.cpp diff --git a/src/backends/neon/workloads/NeonQuantizedLstmWorkload.cpp b/src/backends/neon/workloads/NeonQuantizedLstmWorkload.cpp new file mode 100644 index 0000000000..d4319d414d --- /dev/null +++ b/src/backends/neon/workloads/NeonQuantizedLstmWorkload.cpp @@ -0,0 +1,210 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#include "NeonQuantizedLstmWorkload.hpp" +#include "NeonWorkloadUtils.hpp" + +#include <aclCommon/ArmComputeTensorUtils.hpp> +#include <backendsCommon/CpuTensorHandle.hpp> +#include <neon/NeonTensorHandle.hpp> + +namespace armnn +{ +using namespace armcomputetensorutils; + +NeonQuantizedLstmWorkload::NeonQuantizedLstmWorkload(const QuantizedLstmQueueDescriptor &descriptor, + const WorkloadInfo &info) + : BaseWorkload<QuantizedLstmQueueDescriptor>(descriptor, info) +{ + // Basic parameters + m_InputToInputWeightsTensor = std::make_unique<arm_compute::Tensor>(); + BuildArmComputeTensor(*m_InputToInputWeightsTensor, m_Data.m_InputToInputWeights->GetTensorInfo()); + + m_InputToForgetWeightsTensor = std::make_unique<arm_compute::Tensor>(); + BuildArmComputeTensor(*m_InputToForgetWeightsTensor, m_Data.m_InputToForgetWeights->GetTensorInfo()); + + m_InputToCellWeightsTensor = std::make_unique<arm_compute::Tensor>(); + BuildArmComputeTensor(*m_InputToCellWeightsTensor, m_Data.m_InputToCellWeights->GetTensorInfo()); + + m_InputToOutputWeightsTensor = std::make_unique<arm_compute::Tensor>(); + BuildArmComputeTensor(*m_InputToOutputWeightsTensor, m_Data.m_InputToOutputWeights->GetTensorInfo()); + + m_RecurrentToInputWeightsTensor = std::make_unique<arm_compute::Tensor>(); + BuildArmComputeTensor(*m_RecurrentToInputWeightsTensor, m_Data.m_RecurrentToInputWeights->GetTensorInfo()); + + m_RecurrentToForgetWeightsTensor = std::make_unique<arm_compute::Tensor>(); + BuildArmComputeTensor(*m_RecurrentToForgetWeightsTensor, m_Data.m_RecurrentToForgetWeights->GetTensorInfo()); + + m_RecurrentToCellWeightsTensor = std::make_unique<arm_compute::Tensor>(); + BuildArmComputeTensor(*m_RecurrentToCellWeightsTensor, m_Data.m_RecurrentToCellWeights->GetTensorInfo()); + + m_RecurrentToOutputWeightsTensor = std::make_unique<arm_compute::Tensor>(); + BuildArmComputeTensor(*m_RecurrentToOutputWeightsTensor, m_Data.m_RecurrentToOutputWeights->GetTensorInfo()); + + m_InputGateBiasTensor = std::make_unique<arm_compute::Tensor>(); + BuildArmComputeTensor(*m_InputGateBiasTensor, m_Data.m_InputGateBias->GetTensorInfo()); + + m_ForgetGateBiasTensor = std::make_unique<arm_compute::Tensor>(); + BuildArmComputeTensor(*m_ForgetGateBiasTensor, m_Data.m_ForgetGateBias->GetTensorInfo()); + + m_CellBiasTensor = std::make_unique<arm_compute::Tensor>(); + BuildArmComputeTensor(*m_CellBiasTensor, m_Data.m_CellBias->GetTensorInfo()); + + m_OutputGateBiasTensor = std::make_unique<arm_compute::Tensor>(); + BuildArmComputeTensor(*m_OutputGateBiasTensor, m_Data.m_OutputGateBias->GetTensorInfo()); + + const arm_compute::ITensor& input = static_cast<IAclTensorHandle*>(m_Data.m_Inputs[0])->GetTensor(); + arm_compute::ITensor& cell_state_in = static_cast<IAclTensorHandle*>(m_Data.m_Inputs[1])->GetTensor(); + const arm_compute::ITensor& output_state_in = static_cast<IAclTensorHandle*>(m_Data.m_Inputs[2])->GetTensor(); + + arm_compute::ITensor& cell_state_out = static_cast<IAclTensorHandle*>(m_Data.m_Outputs[0])->GetTensor(); + arm_compute::ITensor& output_state_out = static_cast<IAclTensorHandle*>(m_Data.m_Outputs[1])->GetTensor(); + + m_QuantizedLstmLayer.configure(&input, + m_InputToInputWeightsTensor.get(), + m_InputToForgetWeightsTensor.get(), + m_InputToCellWeightsTensor.get(), + m_InputToOutputWeightsTensor.get(), + m_RecurrentToInputWeightsTensor.get(), + m_RecurrentToForgetWeightsTensor.get(), + m_RecurrentToCellWeightsTensor.get(), + m_RecurrentToOutputWeightsTensor.get(), + m_InputGateBiasTensor.get(), + m_ForgetGateBiasTensor.get(), + m_CellBiasTensor.get(), + m_OutputGateBiasTensor.get(), + &cell_state_in, + &output_state_in, + &cell_state_out, + &output_state_out); + + InitializeArmComputeTensorData(*m_InputToInputWeightsTensor, + m_Data.m_InputToInputWeights); + + InitializeArmComputeTensorData(*m_InputToForgetWeightsTensor, + m_Data.m_InputToForgetWeights); + + InitializeArmComputeTensorData(*m_InputToCellWeightsTensor, + m_Data.m_InputToCellWeights); + + InitializeArmComputeTensorData(*m_InputToOutputWeightsTensor, + m_Data.m_InputToOutputWeights); + + InitializeArmComputeTensorData(*m_RecurrentToInputWeightsTensor, + m_Data.m_RecurrentToInputWeights); + + InitializeArmComputeTensorData(*m_RecurrentToForgetWeightsTensor, + m_Data.m_RecurrentToForgetWeights); + + InitializeArmComputeTensorData(*m_RecurrentToCellWeightsTensor, + m_Data.m_RecurrentToCellWeights); + + InitializeArmComputeTensorData(*m_RecurrentToOutputWeightsTensor, + m_Data.m_RecurrentToOutputWeights); + + InitializeArmComputeTensorData(*m_InputGateBiasTensor, + m_Data.m_InputGateBias); + + InitializeArmComputeTensorData(*m_ForgetGateBiasTensor, + m_Data.m_ForgetGateBias); + + InitializeArmComputeTensorData(*m_CellBiasTensor, + m_Data.m_CellBias); + + InitializeArmComputeTensorData(*m_OutputGateBiasTensor, + m_Data.m_OutputGateBias); + + // Force Compute Library to perform the necessary copying and reshaping, after which + // delete all the input tensors that will no longer be needed + m_QuantizedLstmLayer.prepare(); + FreeUnusedTensors(); +} + +void NeonQuantizedLstmWorkload::Execute() const +{ + m_QuantizedLstmLayer.run(); +} + +arm_compute::Status NeonQuantizedLstmWorkloadValidate(const TensorInfo& input, + const TensorInfo& cellStateIn, + const TensorInfo& outputStateIn, + const TensorInfo& cellStateOut, + const TensorInfo& outputStateOut, + const QuantizedLstmInputParamsInfo& paramsInfo) +{ + // The inputs and outputs + const arm_compute::TensorInfo aclInputInfo = BuildArmComputeTensorInfo(input); + const arm_compute::TensorInfo aclCellStateInInfo = BuildArmComputeTensorInfo(cellStateIn); + const arm_compute::TensorInfo aclOutputStateInInfo = BuildArmComputeTensorInfo(outputStateIn); + const arm_compute::TensorInfo aclCellStateOutInfo = BuildArmComputeTensorInfo(cellStateOut); + const arm_compute::TensorInfo aclOutputStateOutInfo = BuildArmComputeTensorInfo(outputStateOut); + + // Basic parameters + const arm_compute::TensorInfo aclInputToInputWeightsInfo + = BuildArmComputeTensorInfo(paramsInfo.get_InputToInputWeights()); + const arm_compute::TensorInfo aclInputToForgetWeightsInfo + = BuildArmComputeTensorInfo(paramsInfo.get_InputToForgetWeights()); + const arm_compute::TensorInfo aclInputToCellWeightsInfo + = BuildArmComputeTensorInfo(paramsInfo.get_InputToCellWeights()); + const arm_compute::TensorInfo aclInputToOutputWeightsInfo + = BuildArmComputeTensorInfo(paramsInfo.get_InputToOutputWeights()); + + const arm_compute::TensorInfo aclRecurrentToInputWeightsInfo + = BuildArmComputeTensorInfo(paramsInfo.get_RecurrentToInputWeights()); + const arm_compute::TensorInfo aclRecurrentToForgetWeightsInfo + = BuildArmComputeTensorInfo(paramsInfo.get_RecurrentToForgetWeights()); + const arm_compute::TensorInfo aclRecurrentToCellWeightsInfo + = BuildArmComputeTensorInfo(paramsInfo.get_RecurrentToCellWeights()); + const arm_compute::TensorInfo aclRecurrentToOutputWeightsInfo + = BuildArmComputeTensorInfo(paramsInfo.get_RecurrentToOutputWeights()); + + const arm_compute::TensorInfo aclInputGateBiasInfo + = BuildArmComputeTensorInfo(paramsInfo.get_InputGateBias()); + const arm_compute::TensorInfo aclForgetGateBiasInfo + = BuildArmComputeTensorInfo(paramsInfo.get_ForgetGateBias()); + const arm_compute::TensorInfo aclCellBiasInfo + = BuildArmComputeTensorInfo(paramsInfo.get_CellBias()); + const arm_compute::TensorInfo aclOutputGateBiasInfo + = BuildArmComputeTensorInfo(paramsInfo.get_OutputGateBias()); + + return arm_compute::NELSTMLayerQuantized::validate(&aclInputInfo, + &aclInputToInputWeightsInfo, + &aclInputToForgetWeightsInfo, + &aclInputToCellWeightsInfo, + &aclInputToOutputWeightsInfo, + &aclRecurrentToInputWeightsInfo, + &aclRecurrentToForgetWeightsInfo, + &aclRecurrentToCellWeightsInfo, + &aclRecurrentToOutputWeightsInfo, + &aclInputGateBiasInfo, + &aclForgetGateBiasInfo, + &aclCellBiasInfo, + &aclOutputGateBiasInfo, + &aclCellStateInInfo, + &aclOutputStateInInfo, + &aclCellStateOutInfo, + &aclOutputStateOutInfo); +} + +void NeonQuantizedLstmWorkload::FreeUnusedTensors() +{ + FreeTensorIfUnused(m_InputToInputWeightsTensor); + FreeTensorIfUnused(m_InputToForgetWeightsTensor); + FreeTensorIfUnused(m_InputToCellWeightsTensor); + FreeTensorIfUnused(m_InputToOutputWeightsTensor); + FreeTensorIfUnused(m_RecurrentToInputWeightsTensor); + FreeTensorIfUnused(m_RecurrentToForgetWeightsTensor); + FreeTensorIfUnused(m_RecurrentToCellWeightsTensor); + FreeTensorIfUnused(m_RecurrentToOutputWeightsTensor); + FreeTensorIfUnused(m_InputGateBiasTensor); + FreeTensorIfUnused(m_ForgetGateBiasTensor); + FreeTensorIfUnused(m_CellBiasTensor); + FreeTensorIfUnused(m_OutputGateBiasTensor); + FreeTensorIfUnused(m_CellStateInTensor); + FreeTensorIfUnused(m_OutputStateInTensor); + FreeTensorIfUnused(m_CellStateOutTensor); +} + +} //namespace armnn diff --git a/src/backends/neon/workloads/NeonQuantizedLstmWorkload.hpp b/src/backends/neon/workloads/NeonQuantizedLstmWorkload.hpp new file mode 100644 index 0000000000..ab8ea71437 --- /dev/null +++ b/src/backends/neon/workloads/NeonQuantizedLstmWorkload.hpp @@ -0,0 +1,52 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#pragma once + +#include <backendsCommon/Workload.hpp> +#include <backendsCommon/WorkloadData.hpp> + +#include <arm_compute/graph/Tensor.h> +#include <arm_compute/runtime/NEON/functions/NELSTMLayerQuantized.h> + +namespace armnn +{ + +class NeonQuantizedLstmWorkload : public BaseWorkload<QuantizedLstmQueueDescriptor> +{ +public: + NeonQuantizedLstmWorkload(const QuantizedLstmQueueDescriptor& descriptor, const WorkloadInfo& info); + virtual void Execute() const override; + +private: + mutable arm_compute::NELSTMLayerQuantized m_QuantizedLstmLayer; + + std::unique_ptr<arm_compute::Tensor> m_InputToInputWeightsTensor; + std::unique_ptr<arm_compute::Tensor> m_InputToForgetWeightsTensor; + std::unique_ptr<arm_compute::Tensor> m_InputToCellWeightsTensor; + std::unique_ptr<arm_compute::Tensor> m_InputToOutputWeightsTensor; + std::unique_ptr<arm_compute::Tensor> m_RecurrentToInputWeightsTensor; + std::unique_ptr<arm_compute::Tensor> m_RecurrentToForgetWeightsTensor; + std::unique_ptr<arm_compute::Tensor> m_RecurrentToCellWeightsTensor; + std::unique_ptr<arm_compute::Tensor> m_RecurrentToOutputWeightsTensor; + std::unique_ptr<arm_compute::Tensor> m_InputGateBiasTensor; + std::unique_ptr<arm_compute::Tensor> m_ForgetGateBiasTensor; + std::unique_ptr<arm_compute::Tensor> m_CellBiasTensor; + std::unique_ptr<arm_compute::Tensor> m_OutputGateBiasTensor; + std::unique_ptr<arm_compute::Tensor> m_CellStateInTensor; + std::unique_ptr<arm_compute::Tensor> m_OutputStateInTensor; + std::unique_ptr<arm_compute::Tensor> m_CellStateOutTensor; + + void FreeUnusedTensors(); +}; + +arm_compute::Status NeonQuantizedLstmWorkloadValidate(const TensorInfo& input, + const TensorInfo& outputStateIn, + const TensorInfo& cellStateIn, + const TensorInfo& outputStateOut, + const TensorInfo& cellStateOut, + const QuantizedLstmInputParamsInfo& paramsInfo); + +} //namespace armnn diff --git a/src/backends/neon/workloads/NeonWorkloads.hpp b/src/backends/neon/workloads/NeonWorkloads.hpp index 7cb6c3b7b1..8fc684e3e9 100644 --- a/src/backends/neon/workloads/NeonWorkloads.hpp +++ b/src/backends/neon/workloads/NeonWorkloads.hpp @@ -18,6 +18,7 @@ #include "NeonGreaterWorkload.hpp" #include "NeonL2NormalizationFloatWorkload.hpp" #include "NeonLstmFloatWorkload.hpp" +#include "NeonQuantizedLstmWorkload.hpp" #include "NeonMaximumWorkload.hpp" #include "NeonMeanWorkload.hpp" #include "NeonConcatWorkload.hpp" |