diff options
Diffstat (limited to 'src/backends/cl/test/ClCreateWorkloadTests.cpp')
-rw-r--r-- | src/backends/cl/test/ClCreateWorkloadTests.cpp | 25 |
1 files changed, 24 insertions, 1 deletions
diff --git a/src/backends/cl/test/ClCreateWorkloadTests.cpp b/src/backends/cl/test/ClCreateWorkloadTests.cpp index 896d486ebf..1dd0abeadd 100644 --- a/src/backends/cl/test/ClCreateWorkloadTests.cpp +++ b/src/backends/cl/test/ClCreateWorkloadTests.cpp @@ -1,5 +1,5 @@ // -// Copyright © 2017 Arm Ltd. All rights reserved. +// Copyright © 2017 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // @@ -793,6 +793,29 @@ BOOST_AUTO_TEST_CASE(CreateL2NormalizationFloat16NhwcWorkload) ClL2NormalizationWorkloadTest<ClL2NormalizationFloatWorkload, armnn::DataType::Float16>(DataLayout::NHWC); } +template <typename LogSoftmaxWorkloadType, typename armnn::DataType DataType> +static void ClCreateLogSoftmaxWorkloadTest() +{ + Graph graph; + ClWorkloadFactory factory = + ClWorkloadFactoryHelper::GetFactory(ClWorkloadFactoryHelper::GetMemoryManager()); + + auto workload = CreateLogSoftmaxWorkloadTest<LogSoftmaxWorkloadType, DataType>(factory, graph); + + // Checks that outputs and inputs are as we expect them (see definition of CreateLogSoftmaxWorkloadTest). + LogSoftmaxQueueDescriptor queueDescriptor = workload->GetData(); + auto inputHandle = PolymorphicDowncast<IClTensorHandle*>(queueDescriptor.m_Inputs[0]); + auto outputHandle = PolymorphicDowncast<IClTensorHandle*>(queueDescriptor.m_Outputs[0]); + + BOOST_TEST(CompareIClTensorHandleShape(inputHandle, {4, 1})); + BOOST_TEST(CompareIClTensorHandleShape(outputHandle, {4, 1})); +} + +BOOST_AUTO_TEST_CASE(CreateLogSoftmaxFloat32WorkloadTest) +{ + ClCreateLogSoftmaxWorkloadTest<ClLogSoftmaxWorkload, armnn::DataType::Float32>(); +} + template <typename LstmWorkloadType> static void ClCreateLstmWorkloadTest() { |