aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/test/CreateWorkload.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/test/CreateWorkload.hpp')
-rw-r--r--src/armnn/test/CreateWorkload.hpp46
1 files changed, 25 insertions, 21 deletions
diff --git a/src/armnn/test/CreateWorkload.hpp b/src/armnn/test/CreateWorkload.hpp
index a33189efeb..f3cf544fa3 100644
--- a/src/armnn/test/CreateWorkload.hpp
+++ b/src/armnn/test/CreateWorkload.hpp
@@ -397,52 +397,56 @@ std::unique_ptr<Convolution2dWorkload> CreateDirectConvolution2dWorkloadTest(arm
return workload;
}
-template <typename DepthwiseConvolution2dFloat32Workload>
+template <typename DepthwiseConvolution2dFloat32Workload, armnn::DataType DataType>
std::unique_ptr<DepthwiseConvolution2dFloat32Workload> CreateDepthwiseConvolution2dWorkloadTest(
- armnn::IWorkloadFactory& factory, armnn::Graph& graph)
+ armnn::IWorkloadFactory& factory, armnn::Graph& graph, DataLayout dataLayout = DataLayout::NCHW)
{
// Creates the layer we're testing.
DepthwiseConvolution2dDescriptor layerDesc;
- layerDesc.m_PadLeft = 3;
- layerDesc.m_PadRight = 3;
+ layerDesc.m_PadLeft = 1;
+ layerDesc.m_PadRight = 2;
layerDesc.m_PadTop = 1;
- layerDesc.m_PadBottom = 1;
- layerDesc.m_StrideX = 2;
- layerDesc.m_StrideY = 4;
- layerDesc.m_BiasEnabled = true;
+ layerDesc.m_PadBottom = 2;
+ layerDesc.m_StrideX = 1;
+ layerDesc.m_StrideY = 1;
+ layerDesc.m_BiasEnabled = false;
+ layerDesc.m_DataLayout = dataLayout;
DepthwiseConvolution2dLayer* const layer = graph.AddLayer<DepthwiseConvolution2dLayer>(layerDesc, "layer");
- layer->m_Weight = std::make_unique<ScopedCpuTensorHandle>(TensorInfo({3, 3, 5, 3}, DataType::Float32));
- layer->m_Bias = std::make_unique<ScopedCpuTensorHandle>(TensorInfo({9}, DataType::Float32));
+ layer->m_Weight = std::make_unique<ScopedCpuTensorHandle>(TensorInfo({1, 4, 4, 2}, DataType));
layer->m_Weight->Allocate();
- layer->m_Bias->Allocate();
// Creates extra layers.
Layer* const input = graph.AddLayer<InputLayer>(0, "input");
Layer* const output = graph.AddLayer<OutputLayer>(0, "output");
+ TensorShape inputShape = (dataLayout == DataLayout::NCHW) ?
+ TensorShape{ 2, 2, 5, 5 } : TensorShape{ 2, 5, 5, 2 };
+ TensorShape outputShape = (dataLayout == DataLayout::NCHW) ?
+ TensorShape{ 2, 2, 5, 5 } : TensorShape{ 2, 5, 5, 2 };
+
// Connects up.
- Connect(input, layer, TensorInfo({2, 3, 8, 16}, armnn::DataType::Float32));
- Connect(layer, output, TensorInfo({2, 9, 2, 10}, armnn::DataType::Float32));
+ Connect(input, layer, TensorInfo(inputShape, DataType));
+ Connect(layer, output, TensorInfo(outputShape, DataType));
CreateTensorHandles(graph, factory);
// Makes the workload and checks it.
auto workload = MakeAndCheckWorkload<DepthwiseConvolution2dFloat32Workload>(*layer, graph, factory);
DepthwiseConvolution2dQueueDescriptor queueDescriptor = workload->GetData();
- BOOST_TEST(queueDescriptor.m_Parameters.m_StrideX == 2);
- BOOST_TEST(queueDescriptor.m_Parameters.m_StrideY == 4);
- BOOST_TEST(queueDescriptor.m_Parameters.m_PadLeft == 3);
- BOOST_TEST(queueDescriptor.m_Parameters.m_PadRight == 3);
+ BOOST_TEST(queueDescriptor.m_Parameters.m_StrideX == 1);
+ BOOST_TEST(queueDescriptor.m_Parameters.m_StrideY == 1);
+ BOOST_TEST(queueDescriptor.m_Parameters.m_PadLeft == 1);
+ BOOST_TEST(queueDescriptor.m_Parameters.m_PadRight == 2);
BOOST_TEST(queueDescriptor.m_Parameters.m_PadTop == 1);
- BOOST_TEST(queueDescriptor.m_Parameters.m_PadBottom == 1);
- BOOST_TEST(queueDescriptor.m_Parameters.m_BiasEnabled == true);
+ BOOST_TEST(queueDescriptor.m_Parameters.m_PadBottom == 2);
+ BOOST_TEST(queueDescriptor.m_Parameters.m_BiasEnabled == false);
+ BOOST_TEST((queueDescriptor.m_Parameters.m_DataLayout == dataLayout));
BOOST_TEST(queueDescriptor.m_Inputs.size() == 1);
BOOST_TEST(queueDescriptor.m_Outputs.size() == 1);
- BOOST_TEST((queueDescriptor.m_Weight->GetTensorInfo() == TensorInfo({3, 3, 5, 3}, DataType::Float32)));
- BOOST_TEST((queueDescriptor.m_Bias->GetTensorInfo() == TensorInfo({9}, DataType::Float32)));
+ BOOST_TEST((queueDescriptor.m_Weight->GetTensorInfo() == TensorInfo({1, 4, 4, 2}, DataType)));
// Returns so we can do extra, backend-specific tests.
return workload;