From 04a729708f986b1a69c1efc42d5cf18271cfae1e Mon Sep 17 00:00:00 2001 From: Sadik Armagan Date: Mon, 14 Sep 2020 15:44:18 +0100 Subject: IVGCVSW-5157 'Pipe ModelOption through Network::LoadNetwork() to Workload factory' * Pass ModelOptions to WorkloadFactory * Updated signature of CL and NEON Convolution2d workloads added FastMathEnabled param. Signed-off-by: Sadik Armagan Change-Id: I536178be8e4dd4083489e69febadaf0feeba46d2 --- src/backends/neon/test/NeonCreateWorkloadTests.cpp | 29 ++++++++++++++++++++++ 1 file changed, 29 insertions(+) (limited to 'src/backends/neon/test/NeonCreateWorkloadTests.cpp') diff --git a/src/backends/neon/test/NeonCreateWorkloadTests.cpp b/src/backends/neon/test/NeonCreateWorkloadTests.cpp index 37d026f107..99ff9ae8b8 100644 --- a/src/backends/neon/test/NeonCreateWorkloadTests.cpp +++ b/src/backends/neon/test/NeonCreateWorkloadTests.cpp @@ -6,6 +6,8 @@ #include "NeonWorkloadFactoryHelper.hpp" #include +#include +#include #include #include @@ -276,6 +278,33 @@ BOOST_AUTO_TEST_CASE(CreateConvolution2dFloatNhwcWorkload) NeonCreateConvolution2dWorkloadTest(DataLayout::NHWC); } +BOOST_AUTO_TEST_CASE(CreateConvolution2dFastMathEnabledWorkload) +{ + Graph graph; + using ModelOptions = std::vector; + ModelOptions modelOptions = {}; + BackendOptions cpuAcc("CpuAcc", + { + { "FastMathEnabled", true } + }); + modelOptions.push_back(cpuAcc); + NeonWorkloadFactory factory = + NeonWorkloadFactoryHelper::GetFactory(NeonWorkloadFactoryHelper::GetMemoryManager(), modelOptions); + + auto workload = + CreateConvolution2dWorkloadTest(factory, + graph, + DataLayout::NCHW, + modelOptions); + + ARMNN_ASSERT(workload != nullptr); + auto conv2dWorkload = PolymorphicDowncast(workload.get()); + IgnoreUnused(conv2dWorkload); + ARMNN_ASSERT(conv2dWorkload != nullptr); + // fast_math enabled but configuration does not match with WINOGRAD + ARMNN_ASSERT(conv2dWorkload->GetConvolutionMethod() == arm_compute::ConvolutionMethod::GEMM); +} + template static void NeonCreateDepthWiseConvolutionWorkloadTest(DataLayout dataLayout) { -- cgit v1.2.1