Diffstat (limited to 'src/armnn/test/CreateWorkload.hpp')
-rw-r--r--  src/armnn/test/CreateWorkload.hpp | 26
1 file changed, 18 insertions(+), 8 deletions(-)
diff --git a/src/armnn/test/CreateWorkload.hpp b/src/armnn/test/CreateWorkload.hpp
index db1773a0ce..01c5e9f689 100644
--- a/src/armnn/test/CreateWorkload.hpp
+++ b/src/armnn/test/CreateWorkload.hpp
@@ -133,8 +133,20 @@ std::unique_ptr<WorkloadType> CreateArithmeticWorkloadTest(armnn::IWorkloadFacto
template <typename BatchNormalizationFloat32Workload, armnn::DataType DataType>
std::unique_ptr<BatchNormalizationFloat32Workload> CreateBatchNormalizationWorkloadTest(
- armnn::IWorkloadFactory& factory, armnn::Graph& graph, DataLayout dataLayout = DataLayout::NCHW)
+ armnn::IWorkloadFactory& factory, armnn::Graph& graph, DataLayoutIndexed dataLayout = DataLayout::NCHW)
{
+
+    TensorShape tensorShape;
+    switch (dataLayout.GetDataLayout())
+    {
+        case DataLayout::NHWC:
+            tensorShape = { 2, 4, 4, 3 };
+            break;
+        case DataLayout::NCHW:
+        default:
+            tensorShape = { 2, 3, 4, 4 };
+    }
+
// Creates the layer we're testing.
BatchNormalizationDescriptor layerDesc;
layerDesc.m_Eps = 0.05f;
@@ -156,25 +168,23 @@ std::unique_ptr<BatchNormalizationFloat32Workload> CreateBatchNormalizationWorkl
Layer* const input = graph.AddLayer<InputLayer>(0, "input");
Layer* const output = graph.AddLayer<OutputLayer>(0, "output");
- TensorShape inputShape = (dataLayout == DataLayout::NCHW) ? TensorShape{ 2, 3, 1, 1 } : TensorShape{ 2, 1, 1, 3 };
- TensorShape outputShape = (dataLayout == DataLayout::NCHW) ? TensorShape{ 2, 3, 1, 1 } : TensorShape{ 2, 1, 1, 3 };
-
- // Connects up.
- Connect(input, layer, TensorInfo(inputShape, DataType));
- Connect(layer, output, TensorInfo(outputShape, DataType));
+ //Connects up.
+ armnn::TensorInfo tensorInfo(tensorShape, DataType);
+ Connect(input, layer, tensorInfo);
+ Connect(layer, output, tensorInfo);
CreateTensorHandles(graph, factory);
// Makes the workload and checks it.
auto workload = MakeAndCheckWorkload<BatchNormalizationFloat32Workload>(*layer, graph, factory);
BatchNormalizationQueueDescriptor queueDescriptor = workload->GetData();
BOOST_TEST(queueDescriptor.m_Parameters.m_Eps == 0.05f);
- BOOST_TEST((queueDescriptor.m_Parameters.m_DataLayout == dataLayout));
BOOST_TEST(queueDescriptor.m_Inputs.size() == 1);
BOOST_TEST(queueDescriptor.m_Outputs.size() == 1);
BOOST_TEST((queueDescriptor.m_Mean->GetTensorInfo() == TensorInfo({3}, DataType)));
BOOST_TEST((queueDescriptor.m_Variance->GetTensorInfo() == TensorInfo({3}, DataType)));
BOOST_TEST((queueDescriptor.m_Gamma->GetTensorInfo() == TensorInfo({3}, DataType)));
BOOST_TEST((queueDescriptor.m_Beta->GetTensorInfo() == TensorInfo({3}, DataType)));
+ BOOST_TEST((queueDescriptor.m_Parameters.m_DataLayout.GetDataLayout() == dataLayout));
// Returns so we can do extra, backend-specific tests.
    return workload;
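
For context, the sketch below shows how a backend test might drive the updated helper for both layouts now that it takes a DataLayoutIndexed parameter. It is illustrative only: BackendWorkloadFactory and BackendBatchNormalizationFloat32Workload are placeholder names standing in for a concrete backend's factory and workload types, and the usual Boost.Test fixture from the backend test files is assumed.

template <typename WorkloadType, typename FactoryType>
static void BatchNormalizationWorkloadTestHelper(armnn::DataLayout dataLayout)
{
    // Placeholder backend factory; the backend factories used by these tests
    // are assumed to be default-constructible for this sketch.
    FactoryType factory;
    armnn::Graph graph;

    // DataLayoutIndexed is implicitly constructed from DataLayout (as the
    // default argument in the helper shows), so callers that pass a plain
    // DataLayout keep compiling unchanged.
    auto workload = CreateBatchNormalizationWorkloadTest<WorkloadType, armnn::DataType::Float32>(
        factory, graph, dataLayout);

    // The helper has already checked the queue descriptor (eps, data layout,
    // mean/variance/gamma/beta tensor infos); backend-specific checks on the
    // returned workload would follow here.
    BOOST_TEST((workload != nullptr));
}

BOOST_AUTO_TEST_CASE(CreateBatchNormalizationNchwWorkload)
{
    BatchNormalizationWorkloadTestHelper<BackendBatchNormalizationFloat32Workload,
                                         BackendWorkloadFactory>(armnn::DataLayout::NCHW);
}

BOOST_AUTO_TEST_CASE(CreateBatchNormalizationNhwcWorkload)
{
    BatchNormalizationWorkloadTestHelper<BackendBatchNormalizationFloat32Workload,
                                         BackendWorkloadFactory>(armnn::DataLayout::NHWC);
}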