diff options
Diffstat (limited to 'src/armnn/test/CreateWorkload.hpp')
-rw-r--r-- | src/armnn/test/CreateWorkload.hpp | 29 |
1 files changed, 15 insertions, 14 deletions
diff --git a/src/armnn/test/CreateWorkload.hpp b/src/armnn/test/CreateWorkload.hpp index 0048646d45..774df6a4bb 100644 --- a/src/armnn/test/CreateWorkload.hpp +++ b/src/armnn/test/CreateWorkload.hpp @@ -836,10 +836,10 @@ void CreateSplitterMultipleInputsOneOutputWorkloadTest(armnn::IWorkloadFactory& wlActiv1_1 = std::move(workloadActiv1_1); } -template <typename ResizeBilinearWorkload, armnn::DataType DataType> -std::unique_ptr<ResizeBilinearWorkload> CreateResizeBilinearWorkloadTest(armnn::IWorkloadFactory& factory, - armnn::Graph& graph, - DataLayout dataLayout = DataLayout::NCHW) +template <typename ResizeWorkload, armnn::DataType DataType> +std::unique_ptr<ResizeWorkload> CreateResizeBilinearWorkloadTest(armnn::IWorkloadFactory& factory, + armnn::Graph& graph, + DataLayout dataLayout = DataLayout::NCHW) { TensorShape inputShape; TensorShape outputShape; @@ -856,15 +856,16 @@ std::unique_ptr<ResizeBilinearWorkload> CreateResizeBilinearWorkloadTest(armnn:: } // Creates the layer we're testing. - ResizeBilinearDescriptor resizeDesc; + ResizeDescriptor resizeDesc; armnnUtils::DataLayoutIndexed dimensionIndices = dataLayout; - resizeDesc.m_TargetWidth = outputShape[dimensionIndices.GetWidthIndex()]; + resizeDesc.m_Method = ResizeMethod::Bilinear; + resizeDesc.m_TargetWidth = outputShape[dimensionIndices.GetWidthIndex()]; resizeDesc.m_TargetHeight = outputShape[dimensionIndices.GetHeightIndex()]; - resizeDesc.m_DataLayout = dataLayout; - Layer* const layer = graph.AddLayer<ResizeBilinearLayer>(resizeDesc, "layer"); + resizeDesc.m_DataLayout = dataLayout; + Layer* const layer = graph.AddLayer<ResizeLayer>(resizeDesc, "resize"); // Creates extra layers. - Layer* const input = graph.AddLayer<InputLayer>(0, "input"); + Layer* const input = graph.AddLayer<InputLayer>(0, "input"); Layer* const output = graph.AddLayer<OutputLayer>(0, "output"); // Connects up. @@ -875,12 +876,12 @@ std::unique_ptr<ResizeBilinearWorkload> CreateResizeBilinearWorkloadTest(armnn:: CreateTensorHandles(graph, factory); // Makes the workload and checks it. - auto workload = MakeAndCheckWorkload<ResizeBilinearWorkload>(*layer, graph, factory); + auto workload = MakeAndCheckWorkload<ResizeWorkload>(*layer, graph, factory); - ResizeBilinearQueueDescriptor queueDescriptor = workload->GetData(); - BOOST_TEST(queueDescriptor.m_Inputs.size() == 1); - BOOST_TEST(queueDescriptor.m_Outputs.size() == 1); - BOOST_TEST((queueDescriptor.m_Parameters.m_DataLayout == dataLayout)); + auto queueDescriptor = workload->GetData(); + BOOST_CHECK(queueDescriptor.m_Inputs.size() == 1); + BOOST_CHECK(queueDescriptor.m_Outputs.size() == 1); + BOOST_CHECK(queueDescriptor.m_Parameters.m_DataLayout == dataLayout); // Returns so we can do extra, backend-specific tests. return workload; |