aboutsummaryrefslogtreecommitdiff
path: root/src/backends/cl/test/ClCreateWorkloadTests.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/cl/test/ClCreateWorkloadTests.cpp')
-rw-r--r--src/backends/cl/test/ClCreateWorkloadTests.cpp38
1 files changed, 38 insertions, 0 deletions
diff --git a/src/backends/cl/test/ClCreateWorkloadTests.cpp b/src/backends/cl/test/ClCreateWorkloadTests.cpp
index de13390f87..f453ccc9fd 100644
--- a/src/backends/cl/test/ClCreateWorkloadTests.cpp
+++ b/src/backends/cl/test/ClCreateWorkloadTests.cpp
@@ -936,4 +936,42 @@ BOOST_AUTO_TEST_CASE(CreateSpaceToDepthQSymm16Workload)
ClSpaceToDepthWorkloadTest<ClSpaceToDepthWorkload, armnn::DataType::QuantisedSymm16>();
}
+template <armnn::DataType DataType>
+static void ClCreateStackWorkloadTest(const std::initializer_list<unsigned int>& inputShape,
+ const std::initializer_list<unsigned int>& outputShape,
+ unsigned int axis,
+ unsigned int numInputs)
+{
+ armnn::Graph graph;
+ ClWorkloadFactory factory =
+ ClWorkloadFactoryHelper::GetFactory(ClWorkloadFactoryHelper::GetMemoryManager());
+
+ auto workload = CreateStackWorkloadTest<ClStackWorkload, DataType>(factory,
+ graph,
+ TensorShape(inputShape),
+ TensorShape(outputShape),
+ axis,
+ numInputs);
+
+ // Check inputs and output are as expected
+ StackQueueDescriptor queueDescriptor = workload->GetData();
+ for (unsigned int i = 0; i < numInputs; ++i)
+ {
+ auto inputHandle = boost::polymorphic_downcast<IClTensorHandle*>(queueDescriptor.m_Inputs[i]);
+ BOOST_TEST(CompareIClTensorHandleShape(inputHandle, inputShape));
+ }
+ auto outputHandle = boost::polymorphic_downcast<IClTensorHandle*>(queueDescriptor.m_Outputs[0]);
+ BOOST_TEST(CompareIClTensorHandleShape(outputHandle, outputShape));
+}
+
+BOOST_AUTO_TEST_CASE(CreateStackFloat32Workload)
+{
+ ClCreateStackWorkloadTest<armnn::DataType::Float32>({ 3, 4, 5 }, { 3, 4, 2, 5 }, 2, 2);
+}
+
+BOOST_AUTO_TEST_CASE(CreateStackUint8Workload)
+{
+ ClCreateStackWorkloadTest<armnn::DataType::QuantisedAsymm8>({ 3, 4, 5 }, { 3, 4, 2, 5 }, 2, 2);
+}
+
BOOST_AUTO_TEST_SUITE_END()