diff options
Diffstat (limited to 'src/backends/cl')
-rw-r--r-- | src/backends/cl/CMakeLists.txt | 2 | ||||
-rw-r--r-- | src/backends/cl/ClBackend.cpp | 21 | ||||
-rw-r--r-- | src/backends/cl/ClBackend.hpp | 7 | ||||
-rw-r--r-- | src/backends/cl/ClWorkloadFactory.cpp | 22 | ||||
-rw-r--r-- | src/backends/cl/ClWorkloadFactory.hpp | 7 | ||||
-rw-r--r-- | src/backends/cl/test/CMakeLists.txt | 2 | ||||
-rw-r--r-- | src/backends/cl/test/ClCreateWorkloadTests.cpp | 43 | ||||
-rw-r--r-- | src/backends/cl/test/ClLayerSupportTests.cpp | 8 | ||||
-rwxr-xr-x | src/backends/cl/test/ClLayerTests.cpp | 1 | ||||
-rw-r--r-- | src/backends/cl/test/ClMemCopyTests.cpp | 6 | ||||
-rw-r--r-- | src/backends/cl/test/ClOptimizedNetworkTests.cpp | 4 | ||||
-rw-r--r-- | src/backends/cl/test/ClWorkloadFactoryHelper.hpp | 35 | ||||
-rw-r--r-- | src/backends/cl/test/OpenClTimerTest.cpp | 4 |
13 files changed, 111 insertions, 51 deletions
diff --git a/src/backends/cl/CMakeLists.txt b/src/backends/cl/CMakeLists.txt index dd2a4a12b1..7593f098da 100644 --- a/src/backends/cl/CMakeLists.txt +++ b/src/backends/cl/CMakeLists.txt @@ -34,8 +34,6 @@ else() ClContextControl.hpp ClLayerSupport.cpp ClLayerSupport.hpp - ClWorkloadFactory.cpp - ClWorkloadFactory.hpp ) endif() diff --git a/src/backends/cl/ClBackend.cpp b/src/backends/cl/ClBackend.cpp index 4ef8d90dfc..2b82c185f0 100644 --- a/src/backends/cl/ClBackend.cpp +++ b/src/backends/cl/ClBackend.cpp @@ -9,10 +9,18 @@ #include "ClBackendContext.hpp" #include "ClLayerSupport.hpp" -#include <backendsCommon/IBackendContext.hpp> +#include <aclCommon/BaseMemoryManager.hpp> + #include <backendsCommon/BackendRegistry.hpp> +#include <backendsCommon/IBackendContext.hpp> +#include <backendsCommon/IMemoryManager.hpp> + #include <Optimizer.hpp> +#include <arm_compute/runtime/CL/CLBufferAllocator.h> + +#include <boost/polymorphic_pointer_cast.hpp> + namespace armnn { @@ -37,9 +45,16 @@ const BackendId& ClBackend::GetIdStatic() return s_Id; } -IBackendInternal::IWorkloadFactoryPtr ClBackend::CreateWorkloadFactory() const +IBackendInternal::IMemoryManagerUniquePtr ClBackend::CreateMemoryManager() const +{ + return std::make_unique<ClMemoryManager>(std::make_unique<arm_compute::CLBufferAllocator>()); +} + +IBackendInternal::IWorkloadFactoryPtr ClBackend::CreateWorkloadFactory( + const IBackendInternal::IMemoryManagerSharedPtr& memoryManager) const { - return std::make_unique<ClWorkloadFactory>(); + return std::make_unique<ClWorkloadFactory>( + boost::polymorphic_pointer_downcast<ClMemoryManager>(memoryManager)); } IBackendInternal::IBackendContextPtr diff --git a/src/backends/cl/ClBackend.hpp b/src/backends/cl/ClBackend.hpp index 7ee85980a3..ef98da08a4 100644 --- a/src/backends/cl/ClBackend.hpp +++ b/src/backends/cl/ClBackend.hpp @@ -18,8 +18,13 @@ public: static const BackendId& GetIdStatic(); const BackendId& GetId() const override { return GetIdStatic(); } - IBackendInternal::IWorkloadFactoryPtr CreateWorkloadFactory() const override; + IBackendInternal::IMemoryManagerUniquePtr CreateMemoryManager() const override; + + IBackendInternal::IWorkloadFactoryPtr CreateWorkloadFactory( + const IBackendInternal::IMemoryManagerSharedPtr& memoryManager = nullptr) const override; + IBackendInternal::IBackendContextPtr CreateBackendContext(const IRuntime::CreationOptions&) const override; + IBackendInternal::Optimizations GetOptimizations() const override; IBackendInternal::ILayerSupportSharedPtr GetLayerSupport() const override; }; diff --git a/src/backends/cl/ClWorkloadFactory.cpp b/src/backends/cl/ClWorkloadFactory.cpp index 0862ea163e..567954919d 100644 --- a/src/backends/cl/ClWorkloadFactory.cpp +++ b/src/backends/cl/ClWorkloadFactory.cpp @@ -82,15 +82,15 @@ std::unique_ptr<IWorkload> ClWorkloadFactory::MakeWorkload(const QueueDescriptor } } -ClWorkloadFactory::ClWorkloadFactory() -: m_MemoryManager(std::make_unique<arm_compute::CLBufferAllocator>()) +ClWorkloadFactory::ClWorkloadFactory(const std::shared_ptr<ClMemoryManager>& memoryManager) + : m_MemoryManager(memoryManager) { } std::unique_ptr<ITensorHandle> ClWorkloadFactory::CreateTensorHandle(const TensorInfo& tensorInfo) const { std::unique_ptr<ClTensorHandle> tensorHandle = std::make_unique<ClTensorHandle>(tensorInfo); - tensorHandle->SetMemoryGroup(m_MemoryManager.GetInterLayerMemoryGroup()); + tensorHandle->SetMemoryGroup(m_MemoryManager->GetInterLayerMemoryGroup()); return tensorHandle; } @@ -99,7 +99,7 @@ std::unique_ptr<ITensorHandle> ClWorkloadFactory::CreateTensorHandle(const Tenso DataLayout dataLayout) const { std::unique_ptr<ClTensorHandle> tensorHandle = std::make_unique<ClTensorHandle>(tensorInfo, dataLayout); - tensorHandle->SetMemoryGroup(m_MemoryManager.GetInterLayerMemoryGroup()); + tensorHandle->SetMemoryGroup(m_MemoryManager->GetInterLayerMemoryGroup()); return tensorHandle; } @@ -145,7 +145,7 @@ std::unique_ptr<IWorkload> ClWorkloadFactory::CreateSoftmax(const SoftmaxQueueDe const WorkloadInfo& info) const { return MakeWorkload<ClSoftmaxFloatWorkload, ClSoftmaxUint8Workload>(descriptor, info, - m_MemoryManager.GetIntraLayerManager()); + m_MemoryManager->GetIntraLayerManager()); } std::unique_ptr<IWorkload> ClWorkloadFactory::CreateSplitter(const SplitterQueueDescriptor& descriptor, @@ -164,7 +164,7 @@ std::unique_ptr<armnn::IWorkload> ClWorkloadFactory::CreateFullyConnected( const FullyConnectedQueueDescriptor& descriptor, const WorkloadInfo& info) const { return MakeWorkload<ClFullyConnectedWorkload, ClFullyConnectedWorkload>(descriptor, info, - m_MemoryManager.GetIntraLayerManager()); + m_MemoryManager->GetIntraLayerManager()); } std::unique_ptr<armnn::IWorkload> ClWorkloadFactory::CreatePermute(const PermuteQueueDescriptor& descriptor, @@ -182,7 +182,7 @@ std::unique_ptr<armnn::IWorkload> ClWorkloadFactory::CreatePooling2d(const Pooli std::unique_ptr<armnn::IWorkload> ClWorkloadFactory::CreateConvolution2d(const Convolution2dQueueDescriptor& descriptor, const WorkloadInfo& info) const { - return MakeWorkload<ClConvolution2dWorkload>(descriptor, info, m_MemoryManager.GetIntraLayerManager()); + return MakeWorkload<ClConvolution2dWorkload>(descriptor, info, m_MemoryManager->GetIntraLayerManager()); } std::unique_ptr<IWorkload> ClWorkloadFactory::CreateDepthwiseConvolution2d( @@ -322,20 +322,16 @@ std::unique_ptr<IWorkload> ClWorkloadFactory::CreateBatchToSpaceNd(const BatchTo void ClWorkloadFactory::Release() { - m_MemoryManager.Release(); + m_MemoryManager->Release(); } void ClWorkloadFactory::Acquire() { - m_MemoryManager.Acquire(); + m_MemoryManager->Acquire(); } #else // #if ARMCOMPUTECL_ENABLED -ClWorkloadFactory::ClWorkloadFactory() -{ -} - std::unique_ptr<ITensorHandle> ClWorkloadFactory::CreateTensorHandle(const TensorInfo& tensorInfo) const { return nullptr; diff --git a/src/backends/cl/ClWorkloadFactory.hpp b/src/backends/cl/ClWorkloadFactory.hpp index 6a928dbbfc..cb715e1db9 100644 --- a/src/backends/cl/ClWorkloadFactory.hpp +++ b/src/backends/cl/ClWorkloadFactory.hpp @@ -17,7 +17,7 @@ namespace armnn class ClWorkloadFactory : public IWorkloadFactory { public: - ClWorkloadFactory(); + ClWorkloadFactory(const std::shared_ptr<ClMemoryManager>& memoryManager); const BackendId& GetBackendId() const override; @@ -134,8 +134,6 @@ public: virtual void Acquire() override; private: - -#ifdef ARMCOMPUTECL_ENABLED template<typename FloatWorkload, typename Uint8Workload, typename QueueDescriptorType, typename... Args> static std::unique_ptr<IWorkload> MakeWorkload(const QueueDescriptorType& descriptor, const WorkloadInfo& info, @@ -146,8 +144,7 @@ private: const WorkloadInfo& info, Args&&... args); - mutable ClMemoryManager m_MemoryManager; -#endif + mutable std::shared_ptr<ClMemoryManager> m_MemoryManager; }; } // namespace armnn diff --git a/src/backends/cl/test/CMakeLists.txt b/src/backends/cl/test/CMakeLists.txt index 574edf4f58..206cf5a9dd 100644 --- a/src/backends/cl/test/CMakeLists.txt +++ b/src/backends/cl/test/CMakeLists.txt @@ -13,6 +13,8 @@ list(APPEND armnnClBackendUnitTests_sources ClMemCopyTests.cpp ClOptimizedNetworkTests.cpp ClRuntimeTests.cpp + ClWorkloadFactoryHelper.hpp + Fp16SupportTest.cpp OpenClTimerTest.cpp ) diff --git a/src/backends/cl/test/ClCreateWorkloadTests.cpp b/src/backends/cl/test/ClCreateWorkloadTests.cpp index 72a2eb27e1..978b3bce9a 100644 --- a/src/backends/cl/test/ClCreateWorkloadTests.cpp +++ b/src/backends/cl/test/ClCreateWorkloadTests.cpp @@ -4,6 +4,7 @@ // #include "ClContextControlFixture.hpp" +#include "ClWorkloadFactoryHelper.hpp" #include <backendsCommon/MemCopyWorkload.hpp> @@ -26,7 +27,7 @@ template <armnn::DataType DataType> static void ClCreateActivationWorkloadTest() { Graph graph; - ClWorkloadFactory factory; + ClWorkloadFactory factory = ClWorkloadFactoryHelper::GetFactory(); auto workload = CreateActivationWorkloadTest<ClActivationWorkload, DataType>(factory, graph); @@ -56,7 +57,7 @@ template <typename WorkloadType, static void ClCreateArithmethicWorkloadTest() { Graph graph; - ClWorkloadFactory factory; + ClWorkloadFactory factory = ClWorkloadFactoryHelper::GetFactory(); auto workload = CreateArithmeticWorkloadTest<WorkloadType, DescriptorType, LayerType, DataType>(factory, graph); // Checks that inputs/outputs are as we expect them (see definition of CreateArithmeticWorkloadTest). @@ -145,7 +146,7 @@ template <typename BatchNormalizationWorkloadType, armnn::DataType DataType> static void ClCreateBatchNormalizationWorkloadTest(DataLayout dataLayout) { Graph graph; - ClWorkloadFactory factory; + ClWorkloadFactory factory = ClWorkloadFactoryHelper::GetFactory(); auto workload = CreateBatchNormalizationWorkloadTest<BatchNormalizationWorkloadType, DataType> (factory, graph, dataLayout); @@ -194,7 +195,7 @@ BOOST_AUTO_TEST_CASE(CreateBatchNormalizationNhwcFloat16NhwcWorkload) BOOST_AUTO_TEST_CASE(CreateConvertFp16ToFp32Workload) { Graph graph; - ClWorkloadFactory factory; + ClWorkloadFactory factory = ClWorkloadFactoryHelper::GetFactory(); auto workload = CreateConvertFp16ToFp32WorkloadTest<ClConvertFp16ToFp32Workload>(factory, graph); ConvertFp16ToFp32QueueDescriptor queueDescriptor = workload->GetData(); @@ -210,7 +211,7 @@ BOOST_AUTO_TEST_CASE(CreateConvertFp16ToFp32Workload) BOOST_AUTO_TEST_CASE(CreateConvertFp32ToFp16Workload) { Graph graph; - ClWorkloadFactory factory; + ClWorkloadFactory factory = ClWorkloadFactoryHelper::GetFactory(); auto workload = CreateConvertFp32ToFp16WorkloadTest<ClConvertFp32ToFp16Workload>(factory, graph); ConvertFp32ToFp16QueueDescriptor queueDescriptor = workload->GetData(); @@ -227,7 +228,7 @@ template <typename Convolution2dWorkloadType, typename armnn::DataType DataType> static void ClConvolution2dWorkloadTest(DataLayout dataLayout) { Graph graph; - ClWorkloadFactory factory; + ClWorkloadFactory factory = ClWorkloadFactoryHelper::GetFactory(); auto workload = CreateConvolution2dWorkloadTest<ClConvolution2dWorkload, DataType>(factory, graph, dataLayout); @@ -269,7 +270,7 @@ template <typename DepthwiseConvolutionWorkloadType, typename armnn::DataType Da static void ClDepthwiseConvolutionWorkloadTest(DataLayout dataLayout) { Graph graph; - ClWorkloadFactory factory; + ClWorkloadFactory factory = ClWorkloadFactoryHelper::GetFactory(); auto workload = CreateDepthwiseConvolution2dWorkloadTest<DepthwiseConvolutionWorkloadType, DataType> (factory, graph, dataLayout); @@ -299,7 +300,7 @@ template <typename Convolution2dWorkloadType, typename armnn::DataType DataType> static void ClDirectConvolution2dWorkloadTest() { Graph graph; - ClWorkloadFactory factory; + ClWorkloadFactory factory = ClWorkloadFactoryHelper::GetFactory(); auto workload = CreateDirectConvolution2dWorkloadTest<ClConvolution2dWorkload, DataType>(factory, graph); // Checks that outputs and inputs are as we expect them (see definition of CreateDirectConvolution2dWorkloadTest). @@ -329,7 +330,7 @@ template <typename FullyConnectedWorkloadType, typename armnn::DataType DataType static void ClCreateFullyConnectedWorkloadTest() { Graph graph; - ClWorkloadFactory factory; + ClWorkloadFactory factory = ClWorkloadFactoryHelper::GetFactory(); auto workload = CreateFullyConnectedWorkloadTest<FullyConnectedWorkloadType, DataType>(factory, graph); @@ -356,7 +357,7 @@ template <typename NormalizationWorkloadType, typename armnn::DataType DataType> static void ClNormalizationWorkloadTest(DataLayout dataLayout) { Graph graph; - ClWorkloadFactory factory; + ClWorkloadFactory factory = ClWorkloadFactoryHelper::GetFactory(); auto workload = CreateNormalizationWorkloadTest<NormalizationWorkloadType, DataType>(factory, graph, dataLayout); // Checks that inputs/outputs are as we expect them (see definition of CreateNormalizationWorkloadTest). @@ -397,7 +398,7 @@ template <typename armnn::DataType DataType> static void ClPooling2dWorkloadTest(DataLayout dataLayout) { Graph graph; - ClWorkloadFactory factory; + ClWorkloadFactory factory = ClWorkloadFactoryHelper::GetFactory(); auto workload = CreatePooling2dWorkloadTest<ClPooling2dWorkload, DataType>(factory, graph, dataLayout); @@ -439,7 +440,7 @@ template <typename armnn::DataType DataType> static void ClCreateReshapeWorkloadTest() { Graph graph; - ClWorkloadFactory factory; + ClWorkloadFactory factory = ClWorkloadFactoryHelper::GetFactory(); auto workload = CreateReshapeWorkloadTest<ClReshapeWorkload, DataType>(factory, graph); @@ -471,7 +472,7 @@ template <typename SoftmaxWorkloadType, typename armnn::DataType DataType> static void ClSoftmaxWorkloadTest() { Graph graph; - ClWorkloadFactory factory; + ClWorkloadFactory factory = ClWorkloadFactoryHelper::GetFactory(); auto workload = CreateSoftmaxWorkloadTest<SoftmaxWorkloadType, DataType>(factory, graph); @@ -499,7 +500,7 @@ template <typename armnn::DataType DataType> static void ClSplitterWorkloadTest() { Graph graph; - ClWorkloadFactory factory; + ClWorkloadFactory factory = ClWorkloadFactoryHelper::GetFactory(); auto workload = CreateSplitterWorkloadTest<ClSplitterWorkload, DataType>(factory, graph); @@ -540,7 +541,7 @@ static void ClSplitterMergerTest() // of the merger. Graph graph; - ClWorkloadFactory factory; + ClWorkloadFactory factory = ClWorkloadFactoryHelper::GetFactory(); auto workloads = CreateSplitterMergerWorkloadTest<ClSplitterWorkload, ClMergerWorkload, DataType> @@ -589,7 +590,7 @@ BOOST_AUTO_TEST_CASE(CreateSingleOutputMultipleInputs) // We create a splitter with two outputs. That each of those outputs is used by two different activation layers. Graph graph; - ClWorkloadFactory factory; + ClWorkloadFactory factory = ClWorkloadFactoryHelper::GetFactory(); std::unique_ptr<ClSplitterWorkload> wlSplitter; std::unique_ptr<ClActivationWorkload> wlActiv0_0; std::unique_ptr<ClActivationWorkload> wlActiv0_1; @@ -624,7 +625,7 @@ BOOST_AUTO_TEST_CASE(CreateSingleOutputMultipleInputs) BOOST_AUTO_TEST_CASE(CreateMemCopyWorkloadsCl) { - ClWorkloadFactory factory; + ClWorkloadFactory factory = ClWorkloadFactoryHelper::GetFactory(); CreateMemCopyWorkloads<IClTensorHandle>(factory); } @@ -632,7 +633,7 @@ template <typename L2NormalizationWorkloadType, typename armnn::DataType DataTyp static void ClL2NormalizationWorkloadTest(DataLayout dataLayout) { Graph graph; - ClWorkloadFactory factory; + ClWorkloadFactory factory = ClWorkloadFactoryHelper::GetFactory(); auto workload = CreateL2NormalizationWorkloadTest<L2NormalizationWorkloadType, DataType>(factory, graph, dataLayout); @@ -676,7 +677,7 @@ template <typename LstmWorkloadType> static void ClCreateLstmWorkloadTest() { Graph graph; - ClWorkloadFactory factory; + ClWorkloadFactory factory = ClWorkloadFactoryHelper::GetFactory(); auto workload = CreateLstmWorkloadTest<LstmWorkloadType>(factory, graph); LstmQueueDescriptor queueDescriptor = workload->GetData(); @@ -695,7 +696,7 @@ template <typename ResizeBilinearWorkloadType, typename armnn::DataType DataType static void ClResizeBilinearWorkloadTest(DataLayout dataLayout) { Graph graph; - ClWorkloadFactory factory; + ClWorkloadFactory factory = ClWorkloadFactoryHelper::GetFactory(); auto workload = CreateResizeBilinearWorkloadTest<ResizeBilinearWorkloadType, DataType>(factory, graph, dataLayout); @@ -741,7 +742,7 @@ template <typename MeanWorkloadType, typename armnn::DataType DataType> static void ClMeanWorkloadTest() { Graph graph; - ClWorkloadFactory factory; + ClWorkloadFactory factory = ClWorkloadFactoryHelper::GetFactory(); auto workload = CreateMeanWorkloadTest<MeanWorkloadType, DataType>(factory, graph); // Checks that inputs/outputs are as we expect them (see definition of CreateMeanWorkloadTest). diff --git a/src/backends/cl/test/ClLayerSupportTests.cpp b/src/backends/cl/test/ClLayerSupportTests.cpp index 0019afed6b..2218d821ef 100644 --- a/src/backends/cl/test/ClLayerSupportTests.cpp +++ b/src/backends/cl/test/ClLayerSupportTests.cpp @@ -3,6 +3,8 @@ // SPDX-License-Identifier: MIT // +#include "ClWorkloadFactoryHelper.hpp" + #include <layers/ConvertFp16ToFp32Layer.hpp> #include <layers/ConvertFp32ToFp16Layer.hpp> #include <test/TensorHelpers.hpp> @@ -21,19 +23,19 @@ BOOST_AUTO_TEST_SUITE(ClLayerSupport) BOOST_FIXTURE_TEST_CASE(IsLayerSupportedFloat16Cl, ClContextControlFixture) { - armnn::ClWorkloadFactory factory; + armnn::ClWorkloadFactory factory = ClWorkloadFactoryHelper::GetFactory(); IsLayerSupportedTests<armnn::ClWorkloadFactory, armnn::DataType::Float16>(&factory); } BOOST_FIXTURE_TEST_CASE(IsLayerSupportedFloat32Cl, ClContextControlFixture) { - armnn::ClWorkloadFactory factory; + armnn::ClWorkloadFactory factory = ClWorkloadFactoryHelper::GetFactory(); IsLayerSupportedTests<armnn::ClWorkloadFactory, armnn::DataType::Float32>(&factory); } BOOST_FIXTURE_TEST_CASE(IsLayerSupportedUint8Cl, ClContextControlFixture) { - armnn::ClWorkloadFactory factory; + armnn::ClWorkloadFactory factory = ClWorkloadFactoryHelper::GetFactory(); IsLayerSupportedTests<armnn::ClWorkloadFactory, armnn::DataType::QuantisedAsymm8>(&factory); } diff --git a/src/backends/cl/test/ClLayerTests.cpp b/src/backends/cl/test/ClLayerTests.cpp index ade0790894..c7d64ef607 100755 --- a/src/backends/cl/test/ClLayerTests.cpp +++ b/src/backends/cl/test/ClLayerTests.cpp @@ -4,6 +4,7 @@ // #include "ClContextControlFixture.hpp" +#include "ClWorkloadFactoryHelper.hpp" #include "test/TensorHelpers.hpp" #include "test/UnitTests.hpp" diff --git a/src/backends/cl/test/ClMemCopyTests.cpp b/src/backends/cl/test/ClMemCopyTests.cpp index 93b8df17bf..93d8dd5662 100644 --- a/src/backends/cl/test/ClMemCopyTests.cpp +++ b/src/backends/cl/test/ClMemCopyTests.cpp @@ -3,10 +3,14 @@ // SPDX-License-Identifier: MIT // +#include "ClWorkloadFactoryHelper.hpp" + #include <cl/ClWorkloadFactory.hpp> -#include <reference/RefWorkloadFactory.hpp> #include <aclCommon/test/MemCopyTestImpl.hpp> +#include <reference/RefWorkloadFactory.hpp> +#include <reference/test/RefWorkloadFactoryHelper.hpp> + #include <boost/test/unit_test.hpp> BOOST_AUTO_TEST_SUITE(ClMemCopy) diff --git a/src/backends/cl/test/ClOptimizedNetworkTests.cpp b/src/backends/cl/test/ClOptimizedNetworkTests.cpp index cd8a770812..7e321470c1 100644 --- a/src/backends/cl/test/ClOptimizedNetworkTests.cpp +++ b/src/backends/cl/test/ClOptimizedNetworkTests.cpp @@ -3,6 +3,8 @@ // SPDX-License-Identifier: MIT // +#include "ClWorkloadFactoryHelper.hpp" + #include <armnn/ArmNN.hpp> #include <Network.hpp> @@ -32,7 +34,7 @@ BOOST_AUTO_TEST_CASE(OptimizeValidateGpuDeviceSupportLayerNoFallback) armnn::IOptimizedNetworkPtr optNet = armnn::Optimize(*net, backends, runtime->GetDeviceSpec()); BOOST_CHECK(optNet); // validate workloads - armnn::ClWorkloadFactory fact; + armnn::ClWorkloadFactory fact = ClWorkloadFactoryHelper::GetFactory(); for (auto&& layer : static_cast<armnn::OptimizedNetwork*>(optNet.get())->GetGraph()) { BOOST_CHECK(layer->GetBackendId() == armnn::Compute::GpuAcc); diff --git a/src/backends/cl/test/ClWorkloadFactoryHelper.hpp b/src/backends/cl/test/ClWorkloadFactoryHelper.hpp new file mode 100644 index 0000000000..7b60b8ad15 --- /dev/null +++ b/src/backends/cl/test/ClWorkloadFactoryHelper.hpp @@ -0,0 +1,35 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#pragma once + +#include <backendsCommon/IBackendInternal.hpp> +#include <backendsCommon/IMemoryManager.hpp> +#include <backendsCommon/test/WorkloadFactoryHelper.hpp> + +#include <cl/ClWorkloadFactory.hpp> + +#include <arm_compute/runtime/CL/CLBufferAllocator.h> + +#include <boost/polymorphic_pointer_cast.hpp> + +namespace +{ + +template<> +struct WorkloadFactoryHelper<armnn::ClWorkloadFactory> +{ + static armnn::ClWorkloadFactory GetFactory() + { + armnn::IBackendInternal::IMemoryManagerSharedPtr memoryManager = + std::make_shared<armnn::ClMemoryManager>(std::make_unique<arm_compute::CLBufferAllocator>()); + + return armnn::ClWorkloadFactory(boost::polymorphic_pointer_downcast<armnn::ClMemoryManager>(memoryManager)); + } +}; + +using ClWorkloadFactoryHelper = WorkloadFactoryHelper<armnn::ClWorkloadFactory>; + +} // anonymous namespace diff --git a/src/backends/cl/test/OpenClTimerTest.cpp b/src/backends/cl/test/OpenClTimerTest.cpp index 0c40a868eb..6e55be6c3d 100644 --- a/src/backends/cl/test/OpenClTimerTest.cpp +++ b/src/backends/cl/test/OpenClTimerTest.cpp @@ -5,6 +5,8 @@ #if (defined(__aarch64__)) || (defined(__x86_64__)) // disable test failing on FireFly/Armv7 +#include "ClWorkloadFactoryHelper.hpp" + #include <test/TensorHelpers.hpp> #include <backendsCommon/CpuTensorHandle.hpp> @@ -42,7 +44,7 @@ using FactoryType = ClWorkloadFactory; BOOST_AUTO_TEST_CASE(OpenClTimerBatchNorm) { - ClWorkloadFactory workloadFactory; + ClWorkloadFactory workloadFactory = ClWorkloadFactoryHelper::GetFactory(); const unsigned int width = 2; const unsigned int height = 3; |