aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/test/CreateWorkload.hpp
diff options
context:
space:
mode:
authornarpra01 <narumol.prangnawarat@arm.com>2018-10-02 14:35:53 +0100
committerMatthew Bentham <matthew.bentham@arm.com>2018-10-10 16:16:58 +0100
commit55a97bc2605fc1246a9a1f7ee89cde415496a1ba (patch)
tree28043aa8cbe684f978d46c690b100000e9517312 /src/armnn/test/CreateWorkload.hpp
parentee9e7665a5922f7ec0c5ec24d6ab2ecd88fbcfd6 (diff)
downloadarmnn-55a97bc2605fc1246a9a1f7ee89cde415496a1ba.tar.gz
IVGCVSW-1920 Unittests for NHWC Normalization Workloads and Layer
Change-Id: Iea941c1747454f5a4342351e4e82b10ffb9ccbbd
Diffstat (limited to 'src/armnn/test/CreateWorkload.hpp')
-rw-r--r--src/armnn/test/CreateWorkload.hpp11
1 files changed, 7 insertions, 4 deletions
diff --git a/src/armnn/test/CreateWorkload.hpp b/src/armnn/test/CreateWorkload.hpp
index c111fe6016..66f62820d9 100644
--- a/src/armnn/test/CreateWorkload.hpp
+++ b/src/armnn/test/CreateWorkload.hpp
@@ -485,9 +485,10 @@ std::unique_ptr<FullyConnectedWorkload> CreateFullyConnectedWorkloadTest(armnn::
return workload;
}
-template <typename NormalizationFloat32Workload, armnn::DataType DataType>
-std::unique_ptr<NormalizationFloat32Workload> CreateNormalizationWorkloadTest(armnn::IWorkloadFactory& factory,
- armnn::Graph& graph)
+template <typename NormalizationWorkload, armnn::DataType DataType>
+std::unique_ptr<NormalizationWorkload> CreateNormalizationWorkloadTest(armnn::IWorkloadFactory& factory,
+ armnn::Graph& graph,
+ DataLayout dataLayout = DataLayout::NCHW)
{
// Creates the layer we're testing.
NormalizationDescriptor layerDesc;
@@ -497,6 +498,7 @@ std::unique_ptr<NormalizationFloat32Workload> CreateNormalizationWorkloadTest(ar
layerDesc.m_Alpha = 0.5f;
layerDesc.m_Beta = -1.0f;
layerDesc.m_K = 0.2f;
+ layerDesc.m_DataLayout = dataLayout;
NormalizationLayer* layer = graph.AddLayer<NormalizationLayer>(layerDesc, "layer");
@@ -510,7 +512,7 @@ std::unique_ptr<NormalizationFloat32Workload> CreateNormalizationWorkloadTest(ar
CreateTensorHandles(graph, factory);
// Makes the workload and checks it.
- auto workload = MakeAndCheckWorkload<NormalizationFloat32Workload>(*layer, graph, factory);
+ auto workload = MakeAndCheckWorkload<NormalizationWorkload>(*layer, graph, factory);
NormalizationQueueDescriptor queueDescriptor = workload->GetData();
BOOST_TEST((queueDescriptor.m_Parameters.m_NormChannelType == NormalizationAlgorithmChannel::Across));
@@ -519,6 +521,7 @@ std::unique_ptr<NormalizationFloat32Workload> CreateNormalizationWorkloadTest(ar
BOOST_TEST(queueDescriptor.m_Parameters.m_Alpha == 0.5f);
BOOST_TEST(queueDescriptor.m_Parameters.m_Beta == -1.0f);
BOOST_TEST(queueDescriptor.m_Parameters.m_K == 0.2f);
+ BOOST_TEST((queueDescriptor.m_Parameters.m_DataLayout == dataLayout));
BOOST_TEST(queueDescriptor.m_Inputs.size() == 1);
BOOST_TEST(queueDescriptor.m_Outputs.size() == 1);