aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorFrancis Murtagh <francis.murtagh@arm.com>2019-08-02 13:20:54 +0100
committerÁron Virginás-Tar <aron.virginas-tar@arm.com>2019-08-05 15:19:07 +0000
commit4fc3c48c2d230d8c55aa01aa98e32b6df7cafc0c (patch)
tree3ee1e3e59ed6cdd3c86377d260374e21d4fde923
parentf0a0a9ec1e8188e6494d57160341b5bb8a4c3bd7 (diff)
downloadarmnn-4fc3c48c2d230d8c55aa01aa98e32b6df7cafc0c.tar.gz
IVGCVSW-3341 Add Neon backend support for Quantized_LSTM (16bit cell state)
* Add Neon Workload * Update NeonWorkloads.hpp * Update NeonWorkloadFactory * Update NeonLayerSupport * Update backends.mk and CMakeLists.txt * Add NeonCreateWorkload test * Enable LayerTest !android-nn-driver:1685 Change-Id: Idd799bbf039acf0d59084d02c3b57766ce3691b5 Signed-off-by: Francis Murtagh <francis.murtagh@arm.com> Signed-off-by: Matthew Bentham <Matthew.Bentham@arm.com>
-rw-r--r--src/backends/neon/NeonLayerSupport.cpp19
-rw-r--r--src/backends/neon/NeonLayerSupport.hpp8
-rw-r--r--src/backends/neon/NeonWorkloadFactory.cpp6
-rw-r--r--src/backends/neon/NeonWorkloadFactory.hpp3
-rw-r--r--src/backends/neon/backend.mk1
-rw-r--r--src/backends/neon/test/NeonCreateWorkloadTests.cpp39
-rw-r--r--src/backends/neon/test/NeonLayerTests.cpp2
-rw-r--r--src/backends/neon/workloads/CMakeLists.txt2
-rw-r--r--src/backends/neon/workloads/NeonQuantizedLstmWorkload.cpp210
-rw-r--r--src/backends/neon/workloads/NeonQuantizedLstmWorkload.hpp52
-rw-r--r--src/backends/neon/workloads/NeonWorkloads.hpp1
11 files changed, 343 insertions, 0 deletions
diff --git a/src/backends/neon/NeonLayerSupport.cpp b/src/backends/neon/NeonLayerSupport.cpp
index dac3525f60..b61279c133 100644
--- a/src/backends/neon/NeonLayerSupport.cpp
+++ b/src/backends/neon/NeonLayerSupport.cpp
@@ -39,6 +39,7 @@
#include "workloads/NeonPooling2dWorkload.hpp"
#include "workloads/NeonPreluWorkload.hpp"
#include "workloads/NeonQuantizeWorkload.hpp"
+#include "workloads/NeonQuantizedLstmWorkload.hpp"
#include "workloads/NeonResizeWorkload.hpp"
#include "workloads/NeonSoftmaxBaseWorkload.hpp"
#include "workloads/NeonSpaceToDepthWorkload.hpp"
@@ -487,6 +488,24 @@ bool NeonLayerSupport::IsQuantizeSupported(const TensorInfo& input,
output);
}
+bool NeonLayerSupport::IsQuantizedLstmSupported(const TensorInfo& input,
+ const TensorInfo& cellStateIn,
+ const TensorInfo& outputStateIn,
+ const TensorInfo& cellStateOut,
+ const TensorInfo& outputStateOut,
+ const QuantizedLstmInputParamsInfo& paramsInfo,
+ Optional<std::string&> reasonIfUnsupported) const
+{
+ FORWARD_WORKLOAD_VALIDATE_FUNC(NeonQuantizedLstmWorkloadValidate,
+ reasonIfUnsupported,
+ input,
+ cellStateIn,
+ outputStateIn,
+ cellStateOut,
+ outputStateOut,
+ paramsInfo);
+}
+
bool NeonLayerSupport::IsReshapeSupported(const TensorInfo& input,
const ReshapeDescriptor& descriptor,
Optional<std::string&> reasonIfUnsupported) const
diff --git a/src/backends/neon/NeonLayerSupport.hpp b/src/backends/neon/NeonLayerSupport.hpp
index 078d2f619b..acaebc4c58 100644
--- a/src/backends/neon/NeonLayerSupport.hpp
+++ b/src/backends/neon/NeonLayerSupport.hpp
@@ -165,6 +165,14 @@ public:
const TensorInfo& output,
Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
+ bool IsQuantizedLstmSupported(const TensorInfo& input,
+ const TensorInfo& cellStateIn,
+ const TensorInfo& outputStateIn,
+ const TensorInfo& cellStateOut,
+ const TensorInfo& outputStateOut,
+ const QuantizedLstmInputParamsInfo& paramsInfo,
+ Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
+
bool IsReshapeSupported(const TensorInfo& input,
const ReshapeDescriptor& descriptor,
Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
diff --git a/src/backends/neon/NeonWorkloadFactory.cpp b/src/backends/neon/NeonWorkloadFactory.cpp
index fd0381c26d..0e66bfc757 100644
--- a/src/backends/neon/NeonWorkloadFactory.cpp
+++ b/src/backends/neon/NeonWorkloadFactory.cpp
@@ -319,6 +319,12 @@ std::unique_ptr<IWorkload> NeonWorkloadFactory::CreateLstm(const LstmQueueDescri
return MakeWorkloadHelper<NeonLstmFloatWorkload, NullWorkload>(descriptor, info);
}
+std::unique_ptr<IWorkload> NeonWorkloadFactory::CreateQuantizedLstm(const QuantizedLstmQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const
+{
+ return std::make_unique<NeonQuantizedLstmWorkload>(descriptor, info);
+}
+
std::unique_ptr<IWorkload> NeonWorkloadFactory::CreateConvertFp16ToFp32(
const ConvertFp16ToFp32QueueDescriptor& descriptor,
const WorkloadInfo& info) const
diff --git a/src/backends/neon/NeonWorkloadFactory.hpp b/src/backends/neon/NeonWorkloadFactory.hpp
index 360dc7c61b..b9995d8b4b 100644
--- a/src/backends/neon/NeonWorkloadFactory.hpp
+++ b/src/backends/neon/NeonWorkloadFactory.hpp
@@ -135,6 +135,9 @@ public:
std::unique_ptr<IWorkload> CreateLstm(const LstmQueueDescriptor& descriptor,
const WorkloadInfo& info) const override;
+ std::unique_ptr<IWorkload> CreateQuantizedLstm(const QuantizedLstmQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const override;
+
std::unique_ptr<IWorkload> CreateConvertFp16ToFp32(const ConvertFp16ToFp32QueueDescriptor& descriptor,
const WorkloadInfo& info) const override;
diff --git a/src/backends/neon/backend.mk b/src/backends/neon/backend.mk
index 98755e9b0a..d5483b0c7d 100644
--- a/src/backends/neon/backend.mk
+++ b/src/backends/neon/backend.mk
@@ -43,6 +43,7 @@ BACKEND_SOURCES := \
workloads/NeonPermuteWorkload.cpp \
workloads/NeonPooling2dWorkload.cpp \
workloads/NeonPreluWorkload.cpp \
+ workloads/NeonQuantizedLstmWorkload.cpp \
workloads/NeonQuantizeWorkload.cpp \
workloads/NeonReshapeWorkload.cpp \
workloads/NeonResizeWorkload.cpp \
diff --git a/src/backends/neon/test/NeonCreateWorkloadTests.cpp b/src/backends/neon/test/NeonCreateWorkloadTests.cpp
index 848af1285f..056bfb283f 100644
--- a/src/backends/neon/test/NeonCreateWorkloadTests.cpp
+++ b/src/backends/neon/test/NeonCreateWorkloadTests.cpp
@@ -5,6 +5,7 @@
#include "NeonWorkloadFactoryHelper.hpp"
+#include <aclCommon/ArmComputeTensorUtils.hpp>
#include <backendsCommon/MemCopyWorkload.hpp>
#include <aclCommon/test/CreateWorkloadClNeon.hpp>
@@ -873,4 +874,42 @@ BOOST_AUTO_TEST_CASE(CreateStackUint8Workload)
NeonCreateStackWorkloadTest<armnn::DataType::QuantisedAsymm8>({ 3, 4, 5 }, { 3, 4, 2, 5 }, 2, 2);
}
+template <typename QuantizedLstmWorkloadType>
+static void NeonCreateQuantizedLstmWorkloadTest()
+{
+ using boost::polymorphic_downcast;
+
+ Graph graph;
+ NeonWorkloadFactory factory = NeonWorkloadFactoryHelper::GetFactory(NeonWorkloadFactoryHelper::GetMemoryManager());
+
+ auto workload = CreateQuantizedLstmWorkloadTest<QuantizedLstmWorkloadType>(factory, graph);
+
+ QuantizedLstmQueueDescriptor queueDescriptor = workload->GetData();
+
+ IAclTensorHandle* inputHandle = polymorphic_downcast<IAclTensorHandle*>(queueDescriptor.m_Inputs[0]);
+ BOOST_TEST((inputHandle->GetShape() == TensorShape({2, 2})));
+ BOOST_TEST((inputHandle->GetDataType() == arm_compute::DataType::QASYMM8));
+
+ IAclTensorHandle* cellStateInHandle = polymorphic_downcast<IAclTensorHandle*>(queueDescriptor.m_Inputs[1]);
+ BOOST_TEST((cellStateInHandle->GetShape() == TensorShape({2, 4})));
+ BOOST_TEST((cellStateInHandle->GetDataType() == arm_compute::DataType::QSYMM16));
+
+ IAclTensorHandle* outputStateInHandle = polymorphic_downcast<IAclTensorHandle*>(queueDescriptor.m_Inputs[2]);
+ BOOST_TEST((outputStateInHandle->GetShape() == TensorShape({2, 4})));
+ BOOST_TEST((outputStateInHandle->GetDataType() == arm_compute::DataType::QASYMM8));
+
+ IAclTensorHandle* cellStateOutHandle = polymorphic_downcast<IAclTensorHandle*>(queueDescriptor.m_Outputs[0]);
+ BOOST_TEST((cellStateOutHandle->GetShape() == TensorShape({2, 4})));
+ BOOST_TEST((cellStateOutHandle->GetDataType() == arm_compute::DataType::QSYMM16));
+
+ IAclTensorHandle* outputStateOutHandle = polymorphic_downcast<IAclTensorHandle*>(queueDescriptor.m_Outputs[1]);
+ BOOST_TEST((outputStateOutHandle->GetShape() == TensorShape({2, 4})));
+ BOOST_TEST((outputStateOutHandle->GetDataType() == arm_compute::DataType::QASYMM8));
+}
+
+BOOST_AUTO_TEST_CASE(CreateQuantizedLstmWorkload)
+{
+ NeonCreateQuantizedLstmWorkloadTest<NeonQuantizedLstmWorkload>();
+}
+
BOOST_AUTO_TEST_SUITE_END()
diff --git a/src/backends/neon/test/NeonLayerTests.cpp b/src/backends/neon/test/NeonLayerTests.cpp
index dd30536ac9..ed99461b31 100644
--- a/src/backends/neon/test/NeonLayerTests.cpp
+++ b/src/backends/neon/test/NeonLayerTests.cpp
@@ -529,6 +529,8 @@ ARMNN_AUTO_TEST_CASE(LstmLayerFloat32NoCifgWithPeepholeWithProjection,
ARMNN_AUTO_TEST_CASE(LstmLayerFloat32NoCifgWithPeepholeWithProjectionWithLayerNorm,
LstmLayerFloat32NoCifgWithPeepholeWithProjectionWithLayerNormTest)
+ARMNN_AUTO_TEST_CASE(QuantizedLstm, QuantizedLstmTest)
+
// Mean
ARMNN_AUTO_TEST_CASE(MeanSimpleFloat32, MeanSimpleTest<armnn::DataType::Float32>)
ARMNN_AUTO_TEST_CASE(MeanSimpleAxisFloat32, MeanSimpleAxisTest<armnn::DataType::Float32>)
diff --git a/src/backends/neon/workloads/CMakeLists.txt b/src/backends/neon/workloads/CMakeLists.txt
index dea0228377..34fe0723af 100644
--- a/src/backends/neon/workloads/CMakeLists.txt
+++ b/src/backends/neon/workloads/CMakeLists.txt
@@ -52,6 +52,8 @@ list(APPEND armnnNeonBackendWorkloads_sources
NeonPooling2dWorkload.hpp
NeonPreluWorkload.cpp
NeonPreluWorkload.hpp
+ NeonQuantizedLstmWorkload.cpp
+ NeonQuantizedLstmWorkload.hpp
NeonQuantizeWorkload.cpp
NeonQuantizeWorkload.hpp
NeonReshapeWorkload.cpp
diff --git a/src/backends/neon/workloads/NeonQuantizedLstmWorkload.cpp b/src/backends/neon/workloads/NeonQuantizedLstmWorkload.cpp
new file mode 100644
index 0000000000..d4319d414d
--- /dev/null
+++ b/src/backends/neon/workloads/NeonQuantizedLstmWorkload.cpp
@@ -0,0 +1,210 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "NeonQuantizedLstmWorkload.hpp"
+#include "NeonWorkloadUtils.hpp"
+
+#include <aclCommon/ArmComputeTensorUtils.hpp>
+#include <backendsCommon/CpuTensorHandle.hpp>
+#include <neon/NeonTensorHandle.hpp>
+
+namespace armnn
+{
+using namespace armcomputetensorutils;
+
+NeonQuantizedLstmWorkload::NeonQuantizedLstmWorkload(const QuantizedLstmQueueDescriptor &descriptor,
+ const WorkloadInfo &info)
+ : BaseWorkload<QuantizedLstmQueueDescriptor>(descriptor, info)
+{
+ // Basic parameters
+ m_InputToInputWeightsTensor = std::make_unique<arm_compute::Tensor>();
+ BuildArmComputeTensor(*m_InputToInputWeightsTensor, m_Data.m_InputToInputWeights->GetTensorInfo());
+
+ m_InputToForgetWeightsTensor = std::make_unique<arm_compute::Tensor>();
+ BuildArmComputeTensor(*m_InputToForgetWeightsTensor, m_Data.m_InputToForgetWeights->GetTensorInfo());
+
+ m_InputToCellWeightsTensor = std::make_unique<arm_compute::Tensor>();
+ BuildArmComputeTensor(*m_InputToCellWeightsTensor, m_Data.m_InputToCellWeights->GetTensorInfo());
+
+ m_InputToOutputWeightsTensor = std::make_unique<arm_compute::Tensor>();
+ BuildArmComputeTensor(*m_InputToOutputWeightsTensor, m_Data.m_InputToOutputWeights->GetTensorInfo());
+
+ m_RecurrentToInputWeightsTensor = std::make_unique<arm_compute::Tensor>();
+ BuildArmComputeTensor(*m_RecurrentToInputWeightsTensor, m_Data.m_RecurrentToInputWeights->GetTensorInfo());
+
+ m_RecurrentToForgetWeightsTensor = std::make_unique<arm_compute::Tensor>();
+ BuildArmComputeTensor(*m_RecurrentToForgetWeightsTensor, m_Data.m_RecurrentToForgetWeights->GetTensorInfo());
+
+ m_RecurrentToCellWeightsTensor = std::make_unique<arm_compute::Tensor>();
+ BuildArmComputeTensor(*m_RecurrentToCellWeightsTensor, m_Data.m_RecurrentToCellWeights->GetTensorInfo());
+
+ m_RecurrentToOutputWeightsTensor = std::make_unique<arm_compute::Tensor>();
+ BuildArmComputeTensor(*m_RecurrentToOutputWeightsTensor, m_Data.m_RecurrentToOutputWeights->GetTensorInfo());
+
+ m_InputGateBiasTensor = std::make_unique<arm_compute::Tensor>();
+ BuildArmComputeTensor(*m_InputGateBiasTensor, m_Data.m_InputGateBias->GetTensorInfo());
+
+ m_ForgetGateBiasTensor = std::make_unique<arm_compute::Tensor>();
+ BuildArmComputeTensor(*m_ForgetGateBiasTensor, m_Data.m_ForgetGateBias->GetTensorInfo());
+
+ m_CellBiasTensor = std::make_unique<arm_compute::Tensor>();
+ BuildArmComputeTensor(*m_CellBiasTensor, m_Data.m_CellBias->GetTensorInfo());
+
+ m_OutputGateBiasTensor = std::make_unique<arm_compute::Tensor>();
+ BuildArmComputeTensor(*m_OutputGateBiasTensor, m_Data.m_OutputGateBias->GetTensorInfo());
+
+ const arm_compute::ITensor& input = static_cast<IAclTensorHandle*>(m_Data.m_Inputs[0])->GetTensor();
+ arm_compute::ITensor& cell_state_in = static_cast<IAclTensorHandle*>(m_Data.m_Inputs[1])->GetTensor();
+ const arm_compute::ITensor& output_state_in = static_cast<IAclTensorHandle*>(m_Data.m_Inputs[2])->GetTensor();
+
+ arm_compute::ITensor& cell_state_out = static_cast<IAclTensorHandle*>(m_Data.m_Outputs[0])->GetTensor();
+ arm_compute::ITensor& output_state_out = static_cast<IAclTensorHandle*>(m_Data.m_Outputs[1])->GetTensor();
+
+ m_QuantizedLstmLayer.configure(&input,
+ m_InputToInputWeightsTensor.get(),
+ m_InputToForgetWeightsTensor.get(),
+ m_InputToCellWeightsTensor.get(),
+ m_InputToOutputWeightsTensor.get(),
+ m_RecurrentToInputWeightsTensor.get(),
+ m_RecurrentToForgetWeightsTensor.get(),
+ m_RecurrentToCellWeightsTensor.get(),
+ m_RecurrentToOutputWeightsTensor.get(),
+ m_InputGateBiasTensor.get(),
+ m_ForgetGateBiasTensor.get(),
+ m_CellBiasTensor.get(),
+ m_OutputGateBiasTensor.get(),
+ &cell_state_in,
+ &output_state_in,
+ &cell_state_out,
+ &output_state_out);
+
+ InitializeArmComputeTensorData(*m_InputToInputWeightsTensor,
+ m_Data.m_InputToInputWeights);
+
+ InitializeArmComputeTensorData(*m_InputToForgetWeightsTensor,
+ m_Data.m_InputToForgetWeights);
+
+ InitializeArmComputeTensorData(*m_InputToCellWeightsTensor,
+ m_Data.m_InputToCellWeights);
+
+ InitializeArmComputeTensorData(*m_InputToOutputWeightsTensor,
+ m_Data.m_InputToOutputWeights);
+
+ InitializeArmComputeTensorData(*m_RecurrentToInputWeightsTensor,
+ m_Data.m_RecurrentToInputWeights);
+
+ InitializeArmComputeTensorData(*m_RecurrentToForgetWeightsTensor,
+ m_Data.m_RecurrentToForgetWeights);
+
+ InitializeArmComputeTensorData(*m_RecurrentToCellWeightsTensor,
+ m_Data.m_RecurrentToCellWeights);
+
+ InitializeArmComputeTensorData(*m_RecurrentToOutputWeightsTensor,
+ m_Data.m_RecurrentToOutputWeights);
+
+ InitializeArmComputeTensorData(*m_InputGateBiasTensor,
+ m_Data.m_InputGateBias);
+
+ InitializeArmComputeTensorData(*m_ForgetGateBiasTensor,
+ m_Data.m_ForgetGateBias);
+
+ InitializeArmComputeTensorData(*m_CellBiasTensor,
+ m_Data.m_CellBias);
+
+ InitializeArmComputeTensorData(*m_OutputGateBiasTensor,
+ m_Data.m_OutputGateBias);
+
+ // Force Compute Library to perform the necessary copying and reshaping, after which
+ // delete all the input tensors that will no longer be needed
+ m_QuantizedLstmLayer.prepare();
+ FreeUnusedTensors();
+}
+
+void NeonQuantizedLstmWorkload::Execute() const
+{
+ m_QuantizedLstmLayer.run();
+}
+
+arm_compute::Status NeonQuantizedLstmWorkloadValidate(const TensorInfo& input,
+ const TensorInfo& cellStateIn,
+ const TensorInfo& outputStateIn,
+ const TensorInfo& cellStateOut,
+ const TensorInfo& outputStateOut,
+ const QuantizedLstmInputParamsInfo& paramsInfo)
+{
+ // The inputs and outputs
+ const arm_compute::TensorInfo aclInputInfo = BuildArmComputeTensorInfo(input);
+ const arm_compute::TensorInfo aclCellStateInInfo = BuildArmComputeTensorInfo(cellStateIn);
+ const arm_compute::TensorInfo aclOutputStateInInfo = BuildArmComputeTensorInfo(outputStateIn);
+ const arm_compute::TensorInfo aclCellStateOutInfo = BuildArmComputeTensorInfo(cellStateOut);
+ const arm_compute::TensorInfo aclOutputStateOutInfo = BuildArmComputeTensorInfo(outputStateOut);
+
+ // Basic parameters
+ const arm_compute::TensorInfo aclInputToInputWeightsInfo
+ = BuildArmComputeTensorInfo(paramsInfo.get_InputToInputWeights());
+ const arm_compute::TensorInfo aclInputToForgetWeightsInfo
+ = BuildArmComputeTensorInfo(paramsInfo.get_InputToForgetWeights());
+ const arm_compute::TensorInfo aclInputToCellWeightsInfo
+ = BuildArmComputeTensorInfo(paramsInfo.get_InputToCellWeights());
+ const arm_compute::TensorInfo aclInputToOutputWeightsInfo
+ = BuildArmComputeTensorInfo(paramsInfo.get_InputToOutputWeights());
+
+ const arm_compute::TensorInfo aclRecurrentToInputWeightsInfo
+ = BuildArmComputeTensorInfo(paramsInfo.get_RecurrentToInputWeights());
+ const arm_compute::TensorInfo aclRecurrentToForgetWeightsInfo
+ = BuildArmComputeTensorInfo(paramsInfo.get_RecurrentToForgetWeights());
+ const arm_compute::TensorInfo aclRecurrentToCellWeightsInfo
+ = BuildArmComputeTensorInfo(paramsInfo.get_RecurrentToCellWeights());
+ const arm_compute::TensorInfo aclRecurrentToOutputWeightsInfo
+ = BuildArmComputeTensorInfo(paramsInfo.get_RecurrentToOutputWeights());
+
+ const arm_compute::TensorInfo aclInputGateBiasInfo
+ = BuildArmComputeTensorInfo(paramsInfo.get_InputGateBias());
+ const arm_compute::TensorInfo aclForgetGateBiasInfo
+ = BuildArmComputeTensorInfo(paramsInfo.get_ForgetGateBias());
+ const arm_compute::TensorInfo aclCellBiasInfo
+ = BuildArmComputeTensorInfo(paramsInfo.get_CellBias());
+ const arm_compute::TensorInfo aclOutputGateBiasInfo
+ = BuildArmComputeTensorInfo(paramsInfo.get_OutputGateBias());
+
+ return arm_compute::NELSTMLayerQuantized::validate(&aclInputInfo,
+ &aclInputToInputWeightsInfo,
+ &aclInputToForgetWeightsInfo,
+ &aclInputToCellWeightsInfo,
+ &aclInputToOutputWeightsInfo,
+ &aclRecurrentToInputWeightsInfo,
+ &aclRecurrentToForgetWeightsInfo,
+ &aclRecurrentToCellWeightsInfo,
+ &aclRecurrentToOutputWeightsInfo,
+ &aclInputGateBiasInfo,
+ &aclForgetGateBiasInfo,
+ &aclCellBiasInfo,
+ &aclOutputGateBiasInfo,
+ &aclCellStateInInfo,
+ &aclOutputStateInInfo,
+ &aclCellStateOutInfo,
+ &aclOutputStateOutInfo);
+}
+
+void NeonQuantizedLstmWorkload::FreeUnusedTensors()
+{
+ FreeTensorIfUnused(m_InputToInputWeightsTensor);
+ FreeTensorIfUnused(m_InputToForgetWeightsTensor);
+ FreeTensorIfUnused(m_InputToCellWeightsTensor);
+ FreeTensorIfUnused(m_InputToOutputWeightsTensor);
+ FreeTensorIfUnused(m_RecurrentToInputWeightsTensor);
+ FreeTensorIfUnused(m_RecurrentToForgetWeightsTensor);
+ FreeTensorIfUnused(m_RecurrentToCellWeightsTensor);
+ FreeTensorIfUnused(m_RecurrentToOutputWeightsTensor);
+ FreeTensorIfUnused(m_InputGateBiasTensor);
+ FreeTensorIfUnused(m_ForgetGateBiasTensor);
+ FreeTensorIfUnused(m_CellBiasTensor);
+ FreeTensorIfUnused(m_OutputGateBiasTensor);
+ FreeTensorIfUnused(m_CellStateInTensor);
+ FreeTensorIfUnused(m_OutputStateInTensor);
+ FreeTensorIfUnused(m_CellStateOutTensor);
+}
+
+} //namespace armnn
diff --git a/src/backends/neon/workloads/NeonQuantizedLstmWorkload.hpp b/src/backends/neon/workloads/NeonQuantizedLstmWorkload.hpp
new file mode 100644
index 0000000000..ab8ea71437
--- /dev/null
+++ b/src/backends/neon/workloads/NeonQuantizedLstmWorkload.hpp
@@ -0,0 +1,52 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include <backendsCommon/Workload.hpp>
+#include <backendsCommon/WorkloadData.hpp>
+
+#include <arm_compute/graph/Tensor.h>
+#include <arm_compute/runtime/NEON/functions/NELSTMLayerQuantized.h>
+
+namespace armnn
+{
+
+class NeonQuantizedLstmWorkload : public BaseWorkload<QuantizedLstmQueueDescriptor>
+{
+public:
+ NeonQuantizedLstmWorkload(const QuantizedLstmQueueDescriptor& descriptor, const WorkloadInfo& info);
+ virtual void Execute() const override;
+
+private:
+ mutable arm_compute::NELSTMLayerQuantized m_QuantizedLstmLayer;
+
+ std::unique_ptr<arm_compute::Tensor> m_InputToInputWeightsTensor;
+ std::unique_ptr<arm_compute::Tensor> m_InputToForgetWeightsTensor;
+ std::unique_ptr<arm_compute::Tensor> m_InputToCellWeightsTensor;
+ std::unique_ptr<arm_compute::Tensor> m_InputToOutputWeightsTensor;
+ std::unique_ptr<arm_compute::Tensor> m_RecurrentToInputWeightsTensor;
+ std::unique_ptr<arm_compute::Tensor> m_RecurrentToForgetWeightsTensor;
+ std::unique_ptr<arm_compute::Tensor> m_RecurrentToCellWeightsTensor;
+ std::unique_ptr<arm_compute::Tensor> m_RecurrentToOutputWeightsTensor;
+ std::unique_ptr<arm_compute::Tensor> m_InputGateBiasTensor;
+ std::unique_ptr<arm_compute::Tensor> m_ForgetGateBiasTensor;
+ std::unique_ptr<arm_compute::Tensor> m_CellBiasTensor;
+ std::unique_ptr<arm_compute::Tensor> m_OutputGateBiasTensor;
+ std::unique_ptr<arm_compute::Tensor> m_CellStateInTensor;
+ std::unique_ptr<arm_compute::Tensor> m_OutputStateInTensor;
+ std::unique_ptr<arm_compute::Tensor> m_CellStateOutTensor;
+
+ void FreeUnusedTensors();
+};
+
+arm_compute::Status NeonQuantizedLstmWorkloadValidate(const TensorInfo& input,
+ const TensorInfo& outputStateIn,
+ const TensorInfo& cellStateIn,
+ const TensorInfo& outputStateOut,
+ const TensorInfo& cellStateOut,
+ const QuantizedLstmInputParamsInfo& paramsInfo);
+
+} //namespace armnn
diff --git a/src/backends/neon/workloads/NeonWorkloads.hpp b/src/backends/neon/workloads/NeonWorkloads.hpp
index 7cb6c3b7b1..8fc684e3e9 100644
--- a/src/backends/neon/workloads/NeonWorkloads.hpp
+++ b/src/backends/neon/workloads/NeonWorkloads.hpp
@@ -18,6 +18,7 @@
#include "NeonGreaterWorkload.hpp"
#include "NeonL2NormalizationFloatWorkload.hpp"
#include "NeonLstmFloatWorkload.hpp"
+#include "NeonQuantizedLstmWorkload.hpp"
#include "NeonMaximumWorkload.hpp"
#include "NeonMeanWorkload.hpp"
#include "NeonConcatWorkload.hpp"