aboutsummaryrefslogtreecommitdiff
path: root/src/backends/test/CreateWorkloadCl.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/test/CreateWorkloadCl.cpp')
-rw-r--r--src/backends/test/CreateWorkloadCl.cpp27
1 files changed, 24 insertions, 3 deletions
diff --git a/src/backends/test/CreateWorkloadCl.cpp b/src/backends/test/CreateWorkloadCl.cpp
index 39bc259940..cc0e12d202 100644
--- a/src/backends/test/CreateWorkloadCl.cpp
+++ b/src/backends/test/CreateWorkloadCl.cpp
@@ -524,13 +524,14 @@ BOOST_AUTO_TEST_CASE(CreateMemCopyWorkloadsCl)
CreateMemCopyWorkloads<IClTensorHandle>(factory);
}
-BOOST_AUTO_TEST_CASE(CreateL2NormalizationWorkload)
+template <typename L2NormalizationWorkloadType, typename armnn::DataType DataType>
+static void ClL2NormalizationWorkloadTest(DataLayout dataLayout)
{
Graph graph;
ClWorkloadFactory factory;
- auto workload = CreateL2NormalizationWorkloadTest<ClL2NormalizationFloatWorkload, armnn::DataType::Float32>
- (factory, graph);
+ auto workload = CreateL2NormalizationWorkloadTest<L2NormalizationWorkloadType, DataType>
+ (factory, graph, dataLayout);
// Checks that inputs/outputs are as we expect them (see definition of CreateNormalizationWorkloadTest).
L2NormalizationQueueDescriptor queueDescriptor = workload->GetData();
@@ -541,6 +542,26 @@ BOOST_AUTO_TEST_CASE(CreateL2NormalizationWorkload)
BOOST_TEST(CompareIClTensorHandleShape(outputHandle, { 5, 20, 50, 67 }));
}
+BOOST_AUTO_TEST_CASE(CreateL2NormalizationFloatNchwWorkload)
+{
+ ClL2NormalizationWorkloadTest<ClL2NormalizationFloatWorkload, armnn::DataType::Float32>(DataLayout::NCHW);
+}
+
+BOOST_AUTO_TEST_CASE(CreateL2NormalizationFloatNhwcWorkload)
+{
+ ClL2NormalizationWorkloadTest<ClL2NormalizationFloatWorkload, armnn::DataType::Float32>(DataLayout::NHWC);
+}
+
+BOOST_AUTO_TEST_CASE(CreateL2NormalizationFloat16NchwWorkload)
+{
+ ClL2NormalizationWorkloadTest<ClL2NormalizationFloatWorkload, armnn::DataType::Float16>(DataLayout::NCHW);
+}
+
+BOOST_AUTO_TEST_CASE(CreateL2NormalizationFloat16NhwcWorkload)
+{
+ ClL2NormalizationWorkloadTest<ClL2NormalizationFloatWorkload, armnn::DataType::Float16>(DataLayout::NHWC);
+}
+
template <typename LstmWorkloadType>
static void ClCreateLstmWorkloadTest()
{