aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/test/CreateWorkload.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/test/CreateWorkload.hpp')
-rw-r--r--src/armnn/test/CreateWorkload.hpp12
1 files changed, 8 insertions, 4 deletions
diff --git a/src/armnn/test/CreateWorkload.hpp b/src/armnn/test/CreateWorkload.hpp
index f2c8b5a20a..b63e95d4cb 100644
--- a/src/armnn/test/CreateWorkload.hpp
+++ b/src/armnn/test/CreateWorkload.hpp
@@ -162,7 +162,6 @@ std::unique_ptr<BatchNormalizationFloat32Workload> CreateBatchNormalizationWorkl
// Makes the workload and checks it.
auto workload = MakeAndCheckWorkload<BatchNormalizationFloat32Workload>(*layer, graph, factory);
-
BatchNormalizationQueueDescriptor queueDescriptor = workload->GetData();
BOOST_TEST(queueDescriptor.m_Parameters.m_Eps == 0.05f);
BOOST_TEST(queueDescriptor.m_Inputs.size() == 1);
@@ -532,7 +531,8 @@ std::unique_ptr<NormalizationWorkload> CreateNormalizationWorkloadTest(armnn::IW
template <typename Pooling2dWorkload, armnn::DataType DataType>
std::unique_ptr<Pooling2dWorkload> CreatePooling2dWorkloadTest(armnn::IWorkloadFactory& factory,
- armnn::Graph& graph)
+ armnn::Graph& graph,
+ DataLayout dataLayout = DataLayout::NCHW)
{
// Creates the layer we're testing.
Pooling2dDescriptor layerDesc;
@@ -546,6 +546,7 @@ std::unique_ptr<Pooling2dWorkload> CreatePooling2dWorkloadTest(armnn::IWorkloadF
layerDesc.m_StrideX = 2;
layerDesc.m_StrideY = 3;
layerDesc.m_OutputShapeRounding = OutputShapeRounding::Floor;
+ layerDesc.m_DataLayout = dataLayout;
Pooling2dLayer* const layer = graph.AddLayer<Pooling2dLayer>(layerDesc, "layer");
@@ -553,9 +554,12 @@ std::unique_ptr<Pooling2dWorkload> CreatePooling2dWorkloadTest(armnn::IWorkloadF
Layer* const input = graph.AddLayer<InputLayer>(0, "input");
Layer* const output = graph.AddLayer<OutputLayer>(0, "output");
+ TensorShape inputShape = (dataLayout == DataLayout::NCHW) ? TensorShape{3, 2, 5, 5} : TensorShape{3, 5, 5, 2};
+ TensorShape outputShape = (dataLayout == DataLayout::NCHW) ? TensorShape{3, 2, 2, 4} : TensorShape{3, 2, 4, 2};
+
// Connect up
- Connect(input, layer, TensorInfo({3, 2, 5, 5}, DataType));
- Connect(layer, output, TensorInfo({3, 2, 2, 4}, DataType));
+ Connect(input, layer, TensorInfo(inputShape, DataType));
+ Connect(layer, output, TensorInfo(outputShape, DataType));
CreateTensorHandles(graph, factory);
// Make the workload and checks it