diff options
Diffstat (limited to 'src/armnn/test/CreateWorkload.hpp')
-rw-r--r-- | src/armnn/test/CreateWorkload.hpp | 11 |
1 files changed, 7 insertions, 4 deletions
diff --git a/src/armnn/test/CreateWorkload.hpp b/src/armnn/test/CreateWorkload.hpp index c111fe6016..66f62820d9 100644 --- a/src/armnn/test/CreateWorkload.hpp +++ b/src/armnn/test/CreateWorkload.hpp @@ -485,9 +485,10 @@ std::unique_ptr<FullyConnectedWorkload> CreateFullyConnectedWorkloadTest(armnn:: return workload; } -template <typename NormalizationFloat32Workload, armnn::DataType DataType> -std::unique_ptr<NormalizationFloat32Workload> CreateNormalizationWorkloadTest(armnn::IWorkloadFactory& factory, - armnn::Graph& graph) +template <typename NormalizationWorkload, armnn::DataType DataType> +std::unique_ptr<NormalizationWorkload> CreateNormalizationWorkloadTest(armnn::IWorkloadFactory& factory, + armnn::Graph& graph, + DataLayout dataLayout = DataLayout::NCHW) { // Creates the layer we're testing. NormalizationDescriptor layerDesc; @@ -497,6 +498,7 @@ std::unique_ptr<NormalizationFloat32Workload> CreateNormalizationWorkloadTest(ar layerDesc.m_Alpha = 0.5f; layerDesc.m_Beta = -1.0f; layerDesc.m_K = 0.2f; + layerDesc.m_DataLayout = dataLayout; NormalizationLayer* layer = graph.AddLayer<NormalizationLayer>(layerDesc, "layer"); @@ -510,7 +512,7 @@ std::unique_ptr<NormalizationFloat32Workload> CreateNormalizationWorkloadTest(ar CreateTensorHandles(graph, factory); // Makes the workload and checks it. - auto workload = MakeAndCheckWorkload<NormalizationFloat32Workload>(*layer, graph, factory); + auto workload = MakeAndCheckWorkload<NormalizationWorkload>(*layer, graph, factory); NormalizationQueueDescriptor queueDescriptor = workload->GetData(); BOOST_TEST((queueDescriptor.m_Parameters.m_NormChannelType == NormalizationAlgorithmChannel::Across)); @@ -519,6 +521,7 @@ std::unique_ptr<NormalizationFloat32Workload> CreateNormalizationWorkloadTest(ar BOOST_TEST(queueDescriptor.m_Parameters.m_Alpha == 0.5f); BOOST_TEST(queueDescriptor.m_Parameters.m_Beta == -1.0f); BOOST_TEST(queueDescriptor.m_Parameters.m_K == 0.2f); + BOOST_TEST((queueDescriptor.m_Parameters.m_DataLayout == dataLayout)); BOOST_TEST(queueDescriptor.m_Inputs.size() == 1); BOOST_TEST(queueDescriptor.m_Outputs.size() == 1); |