diff options
author | narpra01 <narumol.prangnawarat@arm.com> | 2018-10-02 14:35:53 +0100 |
---|---|---|
committer | Matthew Bentham <matthew.bentham@arm.com> | 2018-10-10 16:16:58 +0100 |
commit | 55a97bc2605fc1246a9a1f7ee89cde415496a1ba (patch) | |
tree | 28043aa8cbe684f978d46c690b100000e9517312 /src/armnn/test/CreateWorkload.hpp | |
parent | ee9e7665a5922f7ec0c5ec24d6ab2ecd88fbcfd6 (diff) | |
download | armnn-55a97bc2605fc1246a9a1f7ee89cde415496a1ba.tar.gz |
IVGCVSW-1920 Unittests for NHWC Normalization Workloads and Layer
Change-Id: Iea941c1747454f5a4342351e4e82b10ffb9ccbbd
Diffstat (limited to 'src/armnn/test/CreateWorkload.hpp')
-rw-r--r-- | src/armnn/test/CreateWorkload.hpp | 11 |
1 files changed, 7 insertions, 4 deletions
diff --git a/src/armnn/test/CreateWorkload.hpp b/src/armnn/test/CreateWorkload.hpp index c111fe6016..66f62820d9 100644 --- a/src/armnn/test/CreateWorkload.hpp +++ b/src/armnn/test/CreateWorkload.hpp @@ -485,9 +485,10 @@ std::unique_ptr<FullyConnectedWorkload> CreateFullyConnectedWorkloadTest(armnn:: return workload; } -template <typename NormalizationFloat32Workload, armnn::DataType DataType> -std::unique_ptr<NormalizationFloat32Workload> CreateNormalizationWorkloadTest(armnn::IWorkloadFactory& factory, - armnn::Graph& graph) +template <typename NormalizationWorkload, armnn::DataType DataType> +std::unique_ptr<NormalizationWorkload> CreateNormalizationWorkloadTest(armnn::IWorkloadFactory& factory, + armnn::Graph& graph, + DataLayout dataLayout = DataLayout::NCHW) { // Creates the layer we're testing. NormalizationDescriptor layerDesc; @@ -497,6 +498,7 @@ std::unique_ptr<NormalizationFloat32Workload> CreateNormalizationWorkloadTest(ar layerDesc.m_Alpha = 0.5f; layerDesc.m_Beta = -1.0f; layerDesc.m_K = 0.2f; + layerDesc.m_DataLayout = dataLayout; NormalizationLayer* layer = graph.AddLayer<NormalizationLayer>(layerDesc, "layer"); @@ -510,7 +512,7 @@ std::unique_ptr<NormalizationFloat32Workload> CreateNormalizationWorkloadTest(ar CreateTensorHandles(graph, factory); // Makes the workload and checks it. - auto workload = MakeAndCheckWorkload<NormalizationFloat32Workload>(*layer, graph, factory); + auto workload = MakeAndCheckWorkload<NormalizationWorkload>(*layer, graph, factory); NormalizationQueueDescriptor queueDescriptor = workload->GetData(); BOOST_TEST((queueDescriptor.m_Parameters.m_NormChannelType == NormalizationAlgorithmChannel::Across)); @@ -519,6 +521,7 @@ std::unique_ptr<NormalizationFloat32Workload> CreateNormalizationWorkloadTest(ar BOOST_TEST(queueDescriptor.m_Parameters.m_Alpha == 0.5f); BOOST_TEST(queueDescriptor.m_Parameters.m_Beta == -1.0f); BOOST_TEST(queueDescriptor.m_Parameters.m_K == 0.2f); + BOOST_TEST((queueDescriptor.m_Parameters.m_DataLayout == dataLayout)); BOOST_TEST(queueDescriptor.m_Inputs.size() == 1); BOOST_TEST(queueDescriptor.m_Outputs.size() == 1); |