From e9444751017fe108ce80fd5c270d04fffeb14e1e Mon Sep 17 00:00:00 2001 From: Sadik Armagan Date: Wed, 2 Dec 2020 11:28:58 +0000 Subject: IVGCVSW-5482 'Add a ClCompileContext parameter to each ClWorkload Constructor' * Injected CLCompileContext object to each CL workload. Signed-off-by: Sadik Armagan Change-Id: I4837dbd3d5b56cf743b3b89c944e3cdf8b11a42a --- src/backends/cl/test/ClCreateWorkloadTests.cpp | 94 ++++++++++++++++++++++++++ 1 file changed, 94 insertions(+) (limited to 'src/backends/cl/test/ClCreateWorkloadTests.cpp') diff --git a/src/backends/cl/test/ClCreateWorkloadTests.cpp b/src/backends/cl/test/ClCreateWorkloadTests.cpp index 4bd3d3a33d..765409a426 100644 --- a/src/backends/cl/test/ClCreateWorkloadTests.cpp +++ b/src/backends/cl/test/ClCreateWorkloadTests.cpp @@ -10,6 +10,8 @@ #include #include #include +#include +#include #include #include @@ -334,6 +336,98 @@ BOOST_AUTO_TEST_CASE(CreateConvolution2dFastMathEnabledWorkload) ARMNN_ASSERT(conv2dWorkload->GetConvolutionMethod() == arm_compute::ConvolutionMethod::WINOGRAD); } +BOOST_AUTO_TEST_CASE(CreateConvolution2dClCompiledContextWorkload) +{ + using namespace armnn; + + const DataType inputType = DataType::QAsymmU8; + const DataType kernelType = DataType::QSymmS8; + const DataType biasType = DataType::Signed32; + + TensorInfo inputInfo ({ 1, 3, 1, 2 }, inputType, 0.5f, 128); + TensorInfo outputInfo({ 1, 3, 1, 3 }, inputType, 1.0f, 128); + + const std::vector quantScales{ 0.5f, 0.75f, 1.0f }; + constexpr unsigned int quantDimension = 0; + + TensorInfo kernelInfo({ 3, 1, 1, 2 }, kernelType, quantScales, quantDimension); + + const std::vector biasQuantScales{ 0.25f, 0.375f, 0.5f }; + TensorInfo biasInfo({ 3 }, biasType, biasQuantScales, quantDimension); + + std::vector inputData = + { + 138, 108, 138, 108, 138, 108 + }; + + std::vector kernelData = + { + 1, 2, 1, 2, 1, 2 + }; + + std::vector biasData = + { + 4, 4, 4 + }; + + std::vector expectedOutputData = + { + 121, 118, 115, 121, 118, 115, 121, 118, 115 + }; + + + Convolution2dDescriptor descriptor; + descriptor.m_StrideX = 1; + descriptor.m_StrideY = 1; + descriptor.m_PadLeft = 0; + descriptor.m_PadRight = 0; + descriptor.m_PadTop = 0; + descriptor.m_PadBottom = 0; + descriptor.m_BiasEnabled = true; + descriptor.m_DataLayout = DataLayout::NHWC; + + auto memoryManager = ClWorkloadFactoryHelper::GetMemoryManager(); + auto clMemoryManager = armnn::PolymorphicPointerDowncast(memoryManager); + auto tensorHandleFactory = ClWorkloadFactoryHelper::GetTensorHandleFactory(memoryManager); + + std::unique_ptr inputHandle = tensorHandleFactory.CreateTensorHandle(inputInfo); + std::unique_ptr outputHandle = tensorHandleFactory.CreateTensorHandle(outputInfo); + + + WorkloadInfo workloadInfo; + ScopedCpuTensorHandle weightTensor(kernelInfo); + ScopedCpuTensorHandle biasTensor(biasInfo); + + AllocateAndCopyDataToITensorHandle(&weightTensor, kernelData.data()); + AllocateAndCopyDataToITensorHandle(&biasTensor, biasData.data()); + + Convolution2dQueueDescriptor queueDescriptor; + queueDescriptor.m_Parameters = descriptor; + queueDescriptor.m_Weight = &weightTensor; + queueDescriptor.m_Bias = &biasTensor; + + AddInputToWorkload(queueDescriptor, workloadInfo, inputInfo, inputHandle.get()); + AddOutputToWorkload(queueDescriptor, workloadInfo, outputInfo, outputHandle.get()); + + // Initialize our m_CLCompileContext using default device and context + auto context = arm_compute::CLKernelLibrary::get().context(); + auto device = arm_compute::CLKernelLibrary::get().get_device(); + auto clCompileContext = arm_compute::CLCompileContext(context, device); + + + + // Check built programs are empty in context + BOOST_TEST(clCompileContext.get_built_programs().empty()); + + auto workload = std::make_unique(queueDescriptor, + workloadInfo, + clMemoryManager->GetIntraLayerManager(), + clCompileContext); + ARMNN_ASSERT(workload != nullptr); + // Check built programs are not empty in context + BOOST_TEST(!clCompileContext.get_built_programs().empty()); +} + template static void ClDepthwiseConvolutionWorkloadTest(DataLayout dataLayout) { -- cgit v1.2.1