aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMatteo Martincigh <matteo.martincigh@arm.com>2018-10-18 10:55:19 +0100
committerMatthew Bentham <matthew.bentham@arm.com>2018-10-22 16:57:54 +0100
commit3dc4303c94cf3f5976e495233f663ff56089e53a (patch)
treee70c15cf1206576a81a02dfa3b66b3b09f88942a
parenta160b245a5c876d3630651e938a7c45ee30645be (diff)
downloadarmnn-3dc4303c94cf3f5976e495233f663ff56089e53a.tar.gz
IVGCVSW-2040 Add unit tests for the newly implemented NHWC support in
ref BatchNormalization * Added create workload unit tests for the NHWC data layout Change-Id: I03d66c88dc9b0340302b85012cb0152f0ec6fa72
-rw-r--r--src/armnn/test/CreateWorkload.hpp12
-rw-r--r--src/backends/reference/test/RefCreateWorkloadTests.cpp52
2 files changed, 54 insertions, 10 deletions
diff --git a/src/armnn/test/CreateWorkload.hpp b/src/armnn/test/CreateWorkload.hpp
index 51820a425f..21385d7a99 100644
--- a/src/armnn/test/CreateWorkload.hpp
+++ b/src/armnn/test/CreateWorkload.hpp
@@ -133,11 +133,12 @@ std::unique_ptr<WorkloadType> CreateArithmeticWorkloadTest(armnn::IWorkloadFacto
template <typename BatchNormalizationFloat32Workload, armnn::DataType DataType>
std::unique_ptr<BatchNormalizationFloat32Workload> CreateBatchNormalizationWorkloadTest(
- armnn::IWorkloadFactory& factory, armnn::Graph& graph)
+ armnn::IWorkloadFactory& factory, armnn::Graph& graph, DataLayout dataLayout = DataLayout::NCHW)
{
// Creates the layer we're testing.
BatchNormalizationDescriptor layerDesc;
layerDesc.m_Eps = 0.05f;
+ layerDesc.m_DataLayout = dataLayout;
BatchNormalizationLayer* const layer = graph.AddLayer<BatchNormalizationLayer>(layerDesc, "layer");
@@ -155,16 +156,19 @@ std::unique_ptr<BatchNormalizationFloat32Workload> CreateBatchNormalizationWorkl
Layer* const input = graph.AddLayer<InputLayer>(0, "input");
Layer* const output = graph.AddLayer<OutputLayer>(0, "output");
+ TensorShape inputShape = (dataLayout == DataLayout::NCHW) ? TensorShape{ 2, 3, 1, 1 } : TensorShape{ 2, 1, 1, 3 };
+ TensorShape outputShape = (dataLayout == DataLayout::NCHW) ? TensorShape{ 2, 3, 1, 1 } : TensorShape{ 2, 1, 1, 3 };
+
// Connects up.
- armnn::TensorInfo tensorInfo({2, 3, 1, 1}, DataType);
- Connect(input, layer, tensorInfo);
- Connect(layer, output, tensorInfo);
+ Connect(input, layer, TensorInfo(inputShape, DataType));
+ Connect(layer, output, TensorInfo(outputShape, DataType));
CreateTensorHandles(graph, factory);
// Makes the workload and checks it.
auto workload = MakeAndCheckWorkload<BatchNormalizationFloat32Workload>(*layer, graph, factory);
BatchNormalizationQueueDescriptor queueDescriptor = workload->GetData();
BOOST_TEST(queueDescriptor.m_Parameters.m_Eps == 0.05f);
+ BOOST_TEST((queueDescriptor.m_Parameters.m_DataLayout == dataLayout));
BOOST_TEST(queueDescriptor.m_Inputs.size() == 1);
BOOST_TEST(queueDescriptor.m_Outputs.size() == 1);
BOOST_TEST((queueDescriptor.m_Mean->GetTensorInfo() == TensorInfo({3}, DataType)));
diff --git a/src/backends/reference/test/RefCreateWorkloadTests.cpp b/src/backends/reference/test/RefCreateWorkloadTests.cpp
index 1ec7749168..d258b81932 100644
--- a/src/backends/reference/test/RefCreateWorkloadTests.cpp
+++ b/src/backends/reference/test/RefCreateWorkloadTests.cpp
@@ -143,16 +143,56 @@ BOOST_AUTO_TEST_CASE(CreateDivisionUint8Workload)
armnn::DataType::QuantisedAsymm8>();
}
-BOOST_AUTO_TEST_CASE(CreateBatchNormalizationWorkload)
+template <typename BatchNormalizationWorkloadType, armnn::DataType DataType>
+static void RefCreateBatchNormalizationWorkloadTest(DataLayout dataLayout)
{
- Graph graph;
+ Graph graph;
RefWorkloadFactory factory;
- auto workload = CreateBatchNormalizationWorkloadTest<RefBatchNormalizationFloat32Workload, armnn::DataType::Float32>
- (factory, graph);
+ auto workload =
+ CreateBatchNormalizationWorkloadTest<BatchNormalizationWorkloadType, DataType>(factory, graph, dataLayout);
+
+ TensorShape inputShape;
+ TensorShape outputShape;
+
+ switch (dataLayout)
+ {
+ case DataLayout::NHWC:
+ inputShape = { 2, 1, 1, 3 };
+ outputShape = { 2, 1, 1, 3 };
+ break;
+ case DataLayout::NCHW:
+ default:
+ inputShape = { 2, 3, 1, 1 };
+ outputShape = { 2, 3, 1, 1 };
+ break;
+ }
// Checks that outputs and inputs are as we expect them (see definition of CreateBatchNormalizationWorkloadTest).
- CheckInputOutput(
- std::move(workload), TensorInfo({2, 3, 1, 1}, DataType::Float32), TensorInfo({2, 3, 1, 1}, DataType::Float32));
+ CheckInputOutput(std::move(workload), TensorInfo(inputShape, DataType), TensorInfo(outputShape, DataType));
+}
+
+BOOST_AUTO_TEST_CASE(CreateBatchNormalizationFloat32Workload)
+{
+ RefCreateBatchNormalizationWorkloadTest<RefBatchNormalizationFloat32Workload,armnn::DataType::Float32>
+ (DataLayout::NCHW);
+}
+
+BOOST_AUTO_TEST_CASE(CreateBatchNormalizationFloat32WorkloadNhwc)
+{
+ RefCreateBatchNormalizationWorkloadTest<RefBatchNormalizationFloat32Workload, armnn::DataType::Float32>
+ (DataLayout::NHWC);
+}
+
+BOOST_AUTO_TEST_CASE(CreateBatchNormalizationUint8Workload)
+{
+ RefCreateBatchNormalizationWorkloadTest<RefBatchNormalizationUint8Workload, armnn::DataType::QuantisedAsymm8>
+ (DataLayout::NCHW);
+}
+
+BOOST_AUTO_TEST_CASE(CreateBatchNormalizationUint8WorkloadNhwc)
+{
+ RefCreateBatchNormalizationWorkloadTest<RefBatchNormalizationUint8Workload, armnn::DataType::QuantisedAsymm8>
+ (DataLayout::NHWC);
}
BOOST_AUTO_TEST_CASE(CreateConvertFp16ToFp32Float32Workload)