diff options
Diffstat (limited to 'src/backends/neon/test/NeonCreateWorkloadTests.cpp')
-rw-r--r-- | src/backends/neon/test/NeonCreateWorkloadTests.cpp | 18 |
1 files changed, 16 insertions, 2 deletions
diff --git a/src/backends/neon/test/NeonCreateWorkloadTests.cpp b/src/backends/neon/test/NeonCreateWorkloadTests.cpp index 447bad155f..a89602db7f 100644 --- a/src/backends/neon/test/NeonCreateWorkloadTests.cpp +++ b/src/backends/neon/test/NeonCreateWorkloadTests.cpp @@ -325,8 +325,12 @@ static void NeonCreateFullyConnectedWorkloadTest() FullyConnectedQueueDescriptor queueDescriptor = workload->GetData(); auto inputHandle = PolymorphicDowncast<IAclTensorHandle*>(queueDescriptor.m_Inputs[0]); auto outputHandle = PolymorphicDowncast<IAclTensorHandle*>(queueDescriptor.m_Outputs[0]); - BOOST_TEST(TestNeonTensorHandleInfo(inputHandle, TensorInfo({3, 1, 4, 5}, DataType))); - BOOST_TEST(TestNeonTensorHandleInfo(outputHandle, TensorInfo({3, 7}, DataType))); + + // Checks that outputs and inputs are as we expect them (see definition of CreateFullyConnectedWorkloadTest). + float inputsQScale = DataType == armnn::DataType::QAsymmU8 ? 1.0f : 0.0; + float outputQScale = DataType == armnn::DataType::QAsymmU8 ? 2.0f : 0.0; + BOOST_TEST(TestNeonTensorHandleInfo(inputHandle, TensorInfo({3, 1, 4, 5}, DataType, inputsQScale))); + BOOST_TEST(TestNeonTensorHandleInfo(outputHandle, TensorInfo({3, 7}, DataType, outputQScale))); } #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC @@ -341,6 +345,16 @@ BOOST_AUTO_TEST_CASE(CreateFullyConnectedFloatWorkload) NeonCreateFullyConnectedWorkloadTest<NeonFullyConnectedWorkload, DataType::Float32>(); } +BOOST_AUTO_TEST_CASE(CreateFullyConnectedQAsymmU8Workload) +{ + NeonCreateFullyConnectedWorkloadTest<NeonFullyConnectedWorkload, DataType::QAsymmU8>(); +} + +BOOST_AUTO_TEST_CASE(CreateFullyConnectedQAsymmS8Workload) +{ + NeonCreateFullyConnectedWorkloadTest<NeonFullyConnectedWorkload, DataType::QAsymmS8>(); +} + template <typename NormalizationWorkloadType, typename armnn::DataType DataType> static void NeonCreateNormalizationWorkloadTest(DataLayout dataLayout) { |