From 737d9ff58b348b11234b6c2363390607d576177d Mon Sep 17 00:00:00 2001 From: Ferran Balaguer Date: Thu, 1 Aug 2019 09:58:08 +0100 Subject: IVGCVSW-3342 Add CL backend support for Quantized_LSTM (16bit cell state) !android-nn-driver:1685 Signed-off-by: Ferran Balaguer Signed-off-by: Matthew Bentham Change-Id: I17278562f72d4b77e22c3af25bf7199b9150a765 --- src/backends/cl/test/ClCreateWorkloadTests.cpp | 41 ++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) (limited to 'src/backends/cl/test/ClCreateWorkloadTests.cpp') diff --git a/src/backends/cl/test/ClCreateWorkloadTests.cpp b/src/backends/cl/test/ClCreateWorkloadTests.cpp index f453ccc9fd..bb36504214 100644 --- a/src/backends/cl/test/ClCreateWorkloadTests.cpp +++ b/src/backends/cl/test/ClCreateWorkloadTests.cpp @@ -9,6 +9,7 @@ #include #include +#include #include #include @@ -974,4 +975,44 @@ BOOST_AUTO_TEST_CASE(CreateStackUint8Workload) ClCreateStackWorkloadTest({ 3, 4, 5 }, { 3, 4, 2, 5 }, 2, 2); } +template +static void ClCreateQuantizedLstmWorkloadTest() +{ + using namespace armnn::armcomputetensorutils; + using boost::polymorphic_downcast; + + Graph graph; + ClWorkloadFactory factory = + ClWorkloadFactoryHelper::GetFactory(ClWorkloadFactoryHelper::GetMemoryManager()); + + auto workload = CreateQuantizedLstmWorkloadTest(factory, graph); + + QuantizedLstmQueueDescriptor queueDescriptor = workload->GetData(); + + IAclTensorHandle* inputHandle = polymorphic_downcast(queueDescriptor.m_Inputs[0]); + BOOST_TEST((inputHandle->GetShape() == TensorShape({2, 2}))); + BOOST_TEST((inputHandle->GetDataType() == arm_compute::DataType::QASYMM8)); + + IAclTensorHandle* cellStateInHandle = polymorphic_downcast(queueDescriptor.m_Inputs[1]); + BOOST_TEST((cellStateInHandle->GetShape() == TensorShape({2, 4}))); + BOOST_TEST((cellStateInHandle->GetDataType() == arm_compute::DataType::QSYMM16)); + + IAclTensorHandle* outputStateInHandle = polymorphic_downcast(queueDescriptor.m_Inputs[2]); + BOOST_TEST((outputStateInHandle->GetShape() == TensorShape({2, 4}))); + BOOST_TEST((outputStateInHandle->GetDataType() == arm_compute::DataType::QASYMM8)); + + IAclTensorHandle* cellStateOutHandle = polymorphic_downcast(queueDescriptor.m_Outputs[0]); + BOOST_TEST((cellStateOutHandle->GetShape() == TensorShape({2, 4}))); + BOOST_TEST((cellStateOutHandle->GetDataType() == arm_compute::DataType::QSYMM16)); + + IAclTensorHandle* outputStateOutHandle = polymorphic_downcast(queueDescriptor.m_Outputs[1]); + BOOST_TEST((outputStateOutHandle->GetShape() == TensorShape({2, 4}))); + BOOST_TEST((outputStateOutHandle->GetDataType() == arm_compute::DataType::QASYMM8)); +} + +BOOST_AUTO_TEST_CASE(CreateQuantizedLstmWorkload) +{ + ClCreateQuantizedLstmWorkloadTest(); +} + BOOST_AUTO_TEST_SUITE_END() -- cgit v1.2.1