diff options
Diffstat (limited to 'src/backends/neon/test/NeonCreateWorkloadTests.cpp')
-rw-r--r-- | src/backends/neon/test/NeonCreateWorkloadTests.cpp | 33 |
1 files changed, 32 insertions, 1 deletions
diff --git a/src/backends/neon/test/NeonCreateWorkloadTests.cpp b/src/backends/neon/test/NeonCreateWorkloadTests.cpp index 73491c7810..37d026f107 100644 --- a/src/backends/neon/test/NeonCreateWorkloadTests.cpp +++ b/src/backends/neon/test/NeonCreateWorkloadTests.cpp @@ -1,5 +1,5 @@ // -// Copyright © 2017 Arm Ltd. All rights reserved. +// Copyright © 2017 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // @@ -813,6 +813,37 @@ BOOST_AUTO_TEST_CASE(CreateL2NormalizationNhwcWorkload) NeonCreateL2NormalizationWorkloadTest<NeonL2NormalizationFloatWorkload, DataType::Float32>(DataLayout::NHWC); } +template <typename LogSoftmaxWorkloadType, typename armnn::DataType DataType> +static void NeonCreateLogSoftmaxWorkloadTest() +{ + Graph graph; + NeonWorkloadFactory factory = + NeonWorkloadFactoryHelper::GetFactory(NeonWorkloadFactoryHelper::GetMemoryManager()); + + auto workload = CreateLogSoftmaxWorkloadTest<LogSoftmaxWorkloadType, DataType>(factory, graph); + + // Checks that outputs and inputs are as we expect them (see definition of CreateLogSoftmaxWorkloadTest). + LogSoftmaxQueueDescriptor queueDescriptor = workload->GetData(); + auto inputHandle = PolymorphicDowncast<IAclTensorHandle*>(queueDescriptor.m_Inputs[0]); + auto outputHandle = PolymorphicDowncast<IAclTensorHandle*>(queueDescriptor.m_Outputs[0]); + armnn::TensorInfo tensorInfo({4, 1}, DataType); + + BOOST_TEST(TestNeonTensorHandleInfo(inputHandle, tensorInfo)); + BOOST_TEST(TestNeonTensorHandleInfo(outputHandle, tensorInfo)); +} + +#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +BOOST_AUTO_TEST_CASE(CreateLogSoftmaxFloat16Workload) +{ + NeonCreateLogSoftmaxWorkloadTest<NeonLogSoftmaxWorkload, DataType::Float16>(); +} +#endif + +BOOST_AUTO_TEST_CASE(CreateLogSoftmaxFloatWorkload) +{ + NeonCreateLogSoftmaxWorkloadTest<NeonLogSoftmaxWorkload, DataType::Float32>(); +} + template <typename LstmWorkloadType> static void NeonCreateLstmWorkloadTest() { |