aboutsummaryrefslogtreecommitdiff
path: root/src/backends/neon/test/NeonCreateWorkloadTests.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/neon/test/NeonCreateWorkloadTests.cpp')
-rw-r--r--src/backends/neon/test/NeonCreateWorkloadTests.cpp33
1 files changed, 32 insertions, 1 deletions
diff --git a/src/backends/neon/test/NeonCreateWorkloadTests.cpp b/src/backends/neon/test/NeonCreateWorkloadTests.cpp
index 73491c7810..37d026f107 100644
--- a/src/backends/neon/test/NeonCreateWorkloadTests.cpp
+++ b/src/backends/neon/test/NeonCreateWorkloadTests.cpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2017 Arm Ltd. All rights reserved.
+// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
@@ -813,6 +813,37 @@ BOOST_AUTO_TEST_CASE(CreateL2NormalizationNhwcWorkload)
NeonCreateL2NormalizationWorkloadTest<NeonL2NormalizationFloatWorkload, DataType::Float32>(DataLayout::NHWC);
}
+template <typename LogSoftmaxWorkloadType, typename armnn::DataType DataType>
+static void NeonCreateLogSoftmaxWorkloadTest()
+{
+ Graph graph;
+ NeonWorkloadFactory factory =
+ NeonWorkloadFactoryHelper::GetFactory(NeonWorkloadFactoryHelper::GetMemoryManager());
+
+ auto workload = CreateLogSoftmaxWorkloadTest<LogSoftmaxWorkloadType, DataType>(factory, graph);
+
+ // Checks that outputs and inputs are as we expect them (see definition of CreateLogSoftmaxWorkloadTest).
+ LogSoftmaxQueueDescriptor queueDescriptor = workload->GetData();
+ auto inputHandle = PolymorphicDowncast<IAclTensorHandle*>(queueDescriptor.m_Inputs[0]);
+ auto outputHandle = PolymorphicDowncast<IAclTensorHandle*>(queueDescriptor.m_Outputs[0]);
+ armnn::TensorInfo tensorInfo({4, 1}, DataType);
+
+ BOOST_TEST(TestNeonTensorHandleInfo(inputHandle, tensorInfo));
+ BOOST_TEST(TestNeonTensorHandleInfo(outputHandle, tensorInfo));
+}
+
+#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+BOOST_AUTO_TEST_CASE(CreateLogSoftmaxFloat16Workload)
+{
+ NeonCreateLogSoftmaxWorkloadTest<NeonLogSoftmaxWorkload, DataType::Float16>();
+}
+#endif
+
+BOOST_AUTO_TEST_CASE(CreateLogSoftmaxFloatWorkload)
+{
+ NeonCreateLogSoftmaxWorkloadTest<NeonLogSoftmaxWorkload, DataType::Float32>();
+}
+
template <typename LstmWorkloadType>
static void NeonCreateLstmWorkloadTest()
{