diff options
author | Aron Virginas-Tar <Aron.Virginas-Tar@arm.com> | 2018-11-14 18:35:18 +0000 |
---|---|---|
committer | Aron Virginas-Tar <Aron.Virginas-Tar@arm.com> | 2018-11-15 10:38:19 +0000 |
commit | 5caf907efc31e774f8afde54b17a5596477772f6 (patch) | |
tree | 9fcdfe44ccf7c96e5088a2cef06b7d74dfd3221c /src/backends/cl | |
parent | dd9d8ca997cb6c63677249350247e9f44525104c (diff) | |
download | armnn-5caf907efc31e774f8afde54b17a5596477772f6.tar.gz |
IVGCVSW-2136: Remove memory management methods from workload factories
Change-Id: Idc0f94590566ac362f7e1d1999361d025cc2f67a
Diffstat (limited to 'src/backends/cl')
-rw-r--r-- | src/backends/cl/ClWorkloadFactory.cpp | 246 | ||||
-rw-r--r-- | src/backends/cl/ClWorkloadFactory.hpp | 4 | ||||
-rw-r--r-- | src/backends/cl/test/ClCreateWorkloadTests.cpp | 75 | ||||
-rw-r--r-- | src/backends/cl/test/ClLayerSupportTests.cpp | 9 | ||||
-rw-r--r-- | src/backends/cl/test/ClOptimizedNetworkTests.cpp | 3 | ||||
-rw-r--r-- | src/backends/cl/test/ClWorkloadFactoryHelper.hpp | 13 | ||||
-rw-r--r-- | src/backends/cl/test/OpenClTimerTest.cpp | 3 |
7 files changed, 78 insertions, 275 deletions
diff --git a/src/backends/cl/ClWorkloadFactory.cpp b/src/backends/cl/ClWorkloadFactory.cpp index 567954919d..1f112008c9 100644 --- a/src/backends/cl/ClWorkloadFactory.cpp +++ b/src/backends/cl/ClWorkloadFactory.cpp @@ -5,26 +5,22 @@ #include "ClWorkloadFactory.hpp" #include "ClBackendId.hpp" +#include <Layer.hpp> + #include <armnn/Exceptions.hpp> #include <armnn/Utils.hpp> -#include <string> #include <backendsCommon/CpuTensorHandle.hpp> -#include <Layer.hpp> - -#ifdef ARMCOMPUTECL_ENABLED -#include <arm_compute/core/CL/CLKernelLibrary.h> -#include <arm_compute/runtime/CL/CLBufferAllocator.h> -#include <arm_compute/runtime/CL/CLScheduler.h> - +#include <backendsCommon/MakeWorkloadHelper.hpp> #include <backendsCommon/MemCopyWorkload.hpp> #include <cl/ClTensorHandle.hpp> #include <cl/workloads/ClWorkloads.hpp> #include <cl/workloads/ClWorkloadUtils.hpp> -#endif -#include <backendsCommon/MakeWorkloadHelper.hpp> +#include <arm_compute/core/CL/CLKernelLibrary.h> +#include <arm_compute/runtime/CL/CLBufferAllocator.h> +#include <arm_compute/runtime/CL/CLScheduler.h> #include <boost/polymorphic_cast.hpp> #include <boost/format.hpp> @@ -50,8 +46,6 @@ const BackendId& ClWorkloadFactory::GetBackendId() const return s_Id; } -#ifdef ARMCOMPUTECL_ENABLED - template <typename FloatWorkload, typename Uint8Workload, typename QueueDescriptorType, typename... Args> std::unique_ptr<IWorkload> ClWorkloadFactory::MakeWorkload(const QueueDescriptorType& descriptor, const WorkloadInfo& info, @@ -320,232 +314,4 @@ std::unique_ptr<IWorkload> ClWorkloadFactory::CreateBatchToSpaceNd(const BatchTo return MakeWorkload<NullWorkload, NullWorkload>(descriptor, info); } -void ClWorkloadFactory::Release() -{ - m_MemoryManager->Release(); -} - -void ClWorkloadFactory::Acquire() -{ - m_MemoryManager->Acquire(); -} - -#else // #if ARMCOMPUTECL_ENABLED - -std::unique_ptr<ITensorHandle> ClWorkloadFactory::CreateTensorHandle(const TensorInfo& tensorInfo) const -{ - return nullptr; -} - -std::unique_ptr<ITensorHandle> ClWorkloadFactory::CreateTensorHandle(const TensorInfo& tensorInfo, - DataLayout dataLayout) const -{ - return nullptr; -} - -std::unique_ptr<ITensorHandle> ClWorkloadFactory::CreateSubTensorHandle(ITensorHandle& parent, - TensorShape const& subTensorShape, - unsigned int const* subTensorOrigin) const -{ - return nullptr; -} - -std::unique_ptr<IWorkload> ClWorkloadFactory::CreateInput(const InputQueueDescriptor& descriptor, - const WorkloadInfo& info) const -{ - return nullptr; -} - -std::unique_ptr<IWorkload> ClWorkloadFactory::CreateOutput(const OutputQueueDescriptor& descriptor, - const WorkloadInfo& info) const -{ - return nullptr; -} - -std::unique_ptr<IWorkload> ClWorkloadFactory::CreateActivation(const ActivationQueueDescriptor& descriptor, - const WorkloadInfo& info) const -{ - return nullptr; -} - -std::unique_ptr<IWorkload> ClWorkloadFactory::CreateSoftmax(const SoftmaxQueueDescriptor& descriptor, - const WorkloadInfo& info) const -{ - return nullptr; -} - -std::unique_ptr<IWorkload> ClWorkloadFactory::CreateSplitter(const SplitterQueueDescriptor& descriptor, - const WorkloadInfo& info) const -{ - return nullptr; -} - -std::unique_ptr<IWorkload> ClWorkloadFactory::CreateMerger(const MergerQueueDescriptor& descriptor, - const WorkloadInfo& info) const -{ - return nullptr; -} - -std::unique_ptr<IWorkload> ClWorkloadFactory::CreateFullyConnected(const FullyConnectedQueueDescriptor& descriptor, - const WorkloadInfo& info) const -{ - return nullptr; -} - -std::unique_ptr<armnn::IWorkload> ClWorkloadFactory::CreatePermute(const PermuteQueueDescriptor& descriptor, - const WorkloadInfo& info) const -{ - return nullptr; -} - -std::unique_ptr<IWorkload> ClWorkloadFactory::CreatePooling2d(const Pooling2dQueueDescriptor& descriptor, - const WorkloadInfo& info) const -{ - return nullptr; -} - -std::unique_ptr<IWorkload> ClWorkloadFactory::CreateConvolution2d(const Convolution2dQueueDescriptor& descriptor, - const WorkloadInfo& info) const -{ - return nullptr; -} - -std::unique_ptr<IWorkload> ClWorkloadFactory::CreateDepthwiseConvolution2d( - const DepthwiseConvolution2dQueueDescriptor& descriptor, const WorkloadInfo& info) const -{ - return nullptr; -} - -std::unique_ptr<IWorkload> ClWorkloadFactory::CreateNormalization(const NormalizationQueueDescriptor& descriptor, - const WorkloadInfo& info) const -{ - return nullptr; -} - -std::unique_ptr<IWorkload> ClWorkloadFactory::CreateAddition(const AdditionQueueDescriptor& descriptor, - const WorkloadInfo& info) const -{ - return nullptr; -} - -std::unique_ptr<IWorkload> ClWorkloadFactory::CreateMultiplication(const MultiplicationQueueDescriptor& descriptor, - const WorkloadInfo& info) const -{ - return nullptr; -} - -std::unique_ptr<IWorkload> ClWorkloadFactory::CreateBatchNormalization( - const BatchNormalizationQueueDescriptor& descriptor, const WorkloadInfo& info) const -{ - return nullptr; -} - -std::unique_ptr<IWorkload> ClWorkloadFactory::CreateMemCopy(const MemCopyQueueDescriptor& descriptor, - const WorkloadInfo& info) const -{ - return nullptr; -} - -std::unique_ptr<IWorkload> ClWorkloadFactory::CreateResizeBilinear(const ResizeBilinearQueueDescriptor& descriptor, - const WorkloadInfo& info) const -{ - return nullptr; -} - -std::unique_ptr<IWorkload> ClWorkloadFactory::CreateFakeQuantization(const FakeQuantizationQueueDescriptor& descriptor, - const WorkloadInfo& info) const -{ - return nullptr; -} - -std::unique_ptr<IWorkload> ClWorkloadFactory::CreateL2Normalization(const L2NormalizationQueueDescriptor& descriptor, - const WorkloadInfo& info) const -{ - return nullptr; -} - -std::unique_ptr<IWorkload> ClWorkloadFactory::CreateConstant(const ConstantQueueDescriptor& descriptor, - const WorkloadInfo& info) const -{ - return nullptr; -} - -std::unique_ptr<IWorkload> ClWorkloadFactory::CreateReshape(const ReshapeQueueDescriptor& descriptor, - const WorkloadInfo& info) const -{ - return nullptr; -} - -std::unique_ptr<IWorkload> ClWorkloadFactory::CreateSpaceToBatchNd(const SpaceToBatchNdQueueDescriptor& descriptor, - const WorkloadInfo& info) const -{ - return nullptr; -} - -std::unique_ptr<IWorkload> ClWorkloadFactory::CreateFloor(const FloorQueueDescriptor& descriptor, - const WorkloadInfo& info) const -{ - return nullptr; -} - -std::unique_ptr<IWorkload> ClWorkloadFactory::CreateLstm(const LstmQueueDescriptor& descriptor, - const WorkloadInfo& info) const -{ - return nullptr; -} - -std::unique_ptr<IWorkload> ClWorkloadFactory::CreateConvertFp16ToFp32( - const ConvertFp16ToFp32QueueDescriptor& descriptor, - const WorkloadInfo& info) const -{ - return nullptr; -} - -std::unique_ptr<IWorkload> ClWorkloadFactory::CreateConvertFp32ToFp16( - const ConvertFp32ToFp16QueueDescriptor& descriptor, - const WorkloadInfo& info) const -{ - return nullptr; -} - -std::unique_ptr<IWorkload> ClWorkloadFactory::CreateDivision(const DivisionQueueDescriptor& descriptor, - const WorkloadInfo& info) const -{ - return nullptr; -} - -std::unique_ptr<IWorkload> ClWorkloadFactory::CreateSubtraction(const SubtractionQueueDescriptor& descriptor, - const WorkloadInfo& info) const -{ - return nullptr; -} - -std::unique_ptr<IWorkload> ClWorkloadFactory::CreateMean(const MeanQueueDescriptor& descriptor, - const WorkloadInfo& info) const -{ - return nullptr; -} - -std::unique_ptr<IWorkload> ClWorkloadFactory::CreatePad(const PadQueueDescriptor& descriptor, - const WorkloadInfo& info) const -{ - return nullptr; -} - -std::unique_ptr<IWorkload> ClWorkloadFactory::CreateBatchToSpaceNd(const BatchToSpaceNdQueueDescriptor& descriptor, - const WorkloadInfo& info) const -{ - return nullptr; -} - -void ClWorkloadFactory::Release() -{ -} - -void ClWorkloadFactory::Acquire() -{ -} - -#endif // #if ARMCOMPUTECL_ENABLED - } // namespace armnn diff --git a/src/backends/cl/ClWorkloadFactory.hpp b/src/backends/cl/ClWorkloadFactory.hpp index cb715e1db9..d37a31ffa4 100644 --- a/src/backends/cl/ClWorkloadFactory.hpp +++ b/src/backends/cl/ClWorkloadFactory.hpp @@ -129,10 +129,6 @@ public: virtual std::unique_ptr<IWorkload> CreateBatchToSpaceNd(const BatchToSpaceNdQueueDescriptor& descriptor, const WorkloadInfo& info) const override; - virtual void Release() override; - - virtual void Acquire() override; - private: template<typename FloatWorkload, typename Uint8Workload, typename QueueDescriptorType, typename... Args> static std::unique_ptr<IWorkload> MakeWorkload(const QueueDescriptorType& descriptor, diff --git a/src/backends/cl/test/ClCreateWorkloadTests.cpp b/src/backends/cl/test/ClCreateWorkloadTests.cpp index 978b3bce9a..b243ca8007 100644 --- a/src/backends/cl/test/ClCreateWorkloadTests.cpp +++ b/src/backends/cl/test/ClCreateWorkloadTests.cpp @@ -27,7 +27,8 @@ template <armnn::DataType DataType> static void ClCreateActivationWorkloadTest() { Graph graph; - ClWorkloadFactory factory = ClWorkloadFactoryHelper::GetFactory(); + ClWorkloadFactory factory = + ClWorkloadFactoryHelper::GetFactory(ClWorkloadFactoryHelper::GetMemoryManager()); auto workload = CreateActivationWorkloadTest<ClActivationWorkload, DataType>(factory, graph); @@ -57,7 +58,9 @@ template <typename WorkloadType, static void ClCreateArithmethicWorkloadTest() { Graph graph; - ClWorkloadFactory factory = ClWorkloadFactoryHelper::GetFactory(); + ClWorkloadFactory factory = + ClWorkloadFactoryHelper::GetFactory(ClWorkloadFactoryHelper::GetMemoryManager()); + auto workload = CreateArithmeticWorkloadTest<WorkloadType, DescriptorType, LayerType, DataType>(factory, graph); // Checks that inputs/outputs are as we expect them (see definition of CreateArithmeticWorkloadTest). @@ -146,7 +149,8 @@ template <typename BatchNormalizationWorkloadType, armnn::DataType DataType> static void ClCreateBatchNormalizationWorkloadTest(DataLayout dataLayout) { Graph graph; - ClWorkloadFactory factory = ClWorkloadFactoryHelper::GetFactory(); + ClWorkloadFactory factory = + ClWorkloadFactoryHelper::GetFactory(ClWorkloadFactoryHelper::GetMemoryManager()); auto workload = CreateBatchNormalizationWorkloadTest<BatchNormalizationWorkloadType, DataType> (factory, graph, dataLayout); @@ -195,7 +199,9 @@ BOOST_AUTO_TEST_CASE(CreateBatchNormalizationNhwcFloat16NhwcWorkload) BOOST_AUTO_TEST_CASE(CreateConvertFp16ToFp32Workload) { Graph graph; - ClWorkloadFactory factory = ClWorkloadFactoryHelper::GetFactory(); + ClWorkloadFactory factory = + ClWorkloadFactoryHelper::GetFactory(ClWorkloadFactoryHelper::GetMemoryManager()); + auto workload = CreateConvertFp16ToFp32WorkloadTest<ClConvertFp16ToFp32Workload>(factory, graph); ConvertFp16ToFp32QueueDescriptor queueDescriptor = workload->GetData(); @@ -211,7 +217,9 @@ BOOST_AUTO_TEST_CASE(CreateConvertFp16ToFp32Workload) BOOST_AUTO_TEST_CASE(CreateConvertFp32ToFp16Workload) { Graph graph; - ClWorkloadFactory factory = ClWorkloadFactoryHelper::GetFactory(); + ClWorkloadFactory factory = + ClWorkloadFactoryHelper::GetFactory(ClWorkloadFactoryHelper::GetMemoryManager()); + auto workload = CreateConvertFp32ToFp16WorkloadTest<ClConvertFp32ToFp16Workload>(factory, graph); ConvertFp32ToFp16QueueDescriptor queueDescriptor = workload->GetData(); @@ -228,7 +236,9 @@ template <typename Convolution2dWorkloadType, typename armnn::DataType DataType> static void ClConvolution2dWorkloadTest(DataLayout dataLayout) { Graph graph; - ClWorkloadFactory factory = ClWorkloadFactoryHelper::GetFactory(); + ClWorkloadFactory factory = + ClWorkloadFactoryHelper::GetFactory(ClWorkloadFactoryHelper::GetMemoryManager()); + auto workload = CreateConvolution2dWorkloadTest<ClConvolution2dWorkload, DataType>(factory, graph, dataLayout); @@ -270,7 +280,8 @@ template <typename DepthwiseConvolutionWorkloadType, typename armnn::DataType Da static void ClDepthwiseConvolutionWorkloadTest(DataLayout dataLayout) { Graph graph; - ClWorkloadFactory factory = ClWorkloadFactoryHelper::GetFactory(); + ClWorkloadFactory factory = + ClWorkloadFactoryHelper::GetFactory(ClWorkloadFactoryHelper::GetMemoryManager()); auto workload = CreateDepthwiseConvolution2dWorkloadTest<DepthwiseConvolutionWorkloadType, DataType> (factory, graph, dataLayout); @@ -300,7 +311,9 @@ template <typename Convolution2dWorkloadType, typename armnn::DataType DataType> static void ClDirectConvolution2dWorkloadTest() { Graph graph; - ClWorkloadFactory factory = ClWorkloadFactoryHelper::GetFactory(); + ClWorkloadFactory factory = + ClWorkloadFactoryHelper::GetFactory(ClWorkloadFactoryHelper::GetMemoryManager()); + auto workload = CreateDirectConvolution2dWorkloadTest<ClConvolution2dWorkload, DataType>(factory, graph); // Checks that outputs and inputs are as we expect them (see definition of CreateDirectConvolution2dWorkloadTest). @@ -330,7 +343,9 @@ template <typename FullyConnectedWorkloadType, typename armnn::DataType DataType static void ClCreateFullyConnectedWorkloadTest() { Graph graph; - ClWorkloadFactory factory = ClWorkloadFactoryHelper::GetFactory(); + ClWorkloadFactory factory = + ClWorkloadFactoryHelper::GetFactory(ClWorkloadFactoryHelper::GetMemoryManager()); + auto workload = CreateFullyConnectedWorkloadTest<FullyConnectedWorkloadType, DataType>(factory, graph); @@ -357,7 +372,9 @@ template <typename NormalizationWorkloadType, typename armnn::DataType DataType> static void ClNormalizationWorkloadTest(DataLayout dataLayout) { Graph graph; - ClWorkloadFactory factory = ClWorkloadFactoryHelper::GetFactory(); + ClWorkloadFactory factory = + ClWorkloadFactoryHelper::GetFactory(ClWorkloadFactoryHelper::GetMemoryManager()); + auto workload = CreateNormalizationWorkloadTest<NormalizationWorkloadType, DataType>(factory, graph, dataLayout); // Checks that inputs/outputs are as we expect them (see definition of CreateNormalizationWorkloadTest). @@ -398,7 +415,8 @@ template <typename armnn::DataType DataType> static void ClPooling2dWorkloadTest(DataLayout dataLayout) { Graph graph; - ClWorkloadFactory factory = ClWorkloadFactoryHelper::GetFactory(); + ClWorkloadFactory factory = + ClWorkloadFactoryHelper::GetFactory(ClWorkloadFactoryHelper::GetMemoryManager()); auto workload = CreatePooling2dWorkloadTest<ClPooling2dWorkload, DataType>(factory, graph, dataLayout); @@ -440,7 +458,8 @@ template <typename armnn::DataType DataType> static void ClCreateReshapeWorkloadTest() { Graph graph; - ClWorkloadFactory factory = ClWorkloadFactoryHelper::GetFactory(); + ClWorkloadFactory factory = + ClWorkloadFactoryHelper::GetFactory(ClWorkloadFactoryHelper::GetMemoryManager()); auto workload = CreateReshapeWorkloadTest<ClReshapeWorkload, DataType>(factory, graph); @@ -472,7 +491,8 @@ template <typename SoftmaxWorkloadType, typename armnn::DataType DataType> static void ClSoftmaxWorkloadTest() { Graph graph; - ClWorkloadFactory factory = ClWorkloadFactoryHelper::GetFactory(); + ClWorkloadFactory factory = + ClWorkloadFactoryHelper::GetFactory(ClWorkloadFactoryHelper::GetMemoryManager()); auto workload = CreateSoftmaxWorkloadTest<SoftmaxWorkloadType, DataType>(factory, graph); @@ -500,7 +520,8 @@ template <typename armnn::DataType DataType> static void ClSplitterWorkloadTest() { Graph graph; - ClWorkloadFactory factory = ClWorkloadFactoryHelper::GetFactory(); + ClWorkloadFactory factory = + ClWorkloadFactoryHelper::GetFactory(ClWorkloadFactoryHelper::GetMemoryManager()); auto workload = CreateSplitterWorkloadTest<ClSplitterWorkload, DataType>(factory, graph); @@ -541,7 +562,8 @@ static void ClSplitterMergerTest() // of the merger. Graph graph; - ClWorkloadFactory factory = ClWorkloadFactoryHelper::GetFactory(); + ClWorkloadFactory factory = + ClWorkloadFactoryHelper::GetFactory(ClWorkloadFactoryHelper::GetMemoryManager()); auto workloads = CreateSplitterMergerWorkloadTest<ClSplitterWorkload, ClMergerWorkload, DataType> @@ -590,7 +612,9 @@ BOOST_AUTO_TEST_CASE(CreateSingleOutputMultipleInputs) // We create a splitter with two outputs. That each of those outputs is used by two different activation layers. Graph graph; - ClWorkloadFactory factory = ClWorkloadFactoryHelper::GetFactory(); + ClWorkloadFactory factory = + ClWorkloadFactoryHelper::GetFactory(ClWorkloadFactoryHelper::GetMemoryManager()); + std::unique_ptr<ClSplitterWorkload> wlSplitter; std::unique_ptr<ClActivationWorkload> wlActiv0_0; std::unique_ptr<ClActivationWorkload> wlActiv0_1; @@ -625,7 +649,9 @@ BOOST_AUTO_TEST_CASE(CreateSingleOutputMultipleInputs) BOOST_AUTO_TEST_CASE(CreateMemCopyWorkloadsCl) { - ClWorkloadFactory factory = ClWorkloadFactoryHelper::GetFactory(); + ClWorkloadFactory factory = + ClWorkloadFactoryHelper::GetFactory(ClWorkloadFactoryHelper::GetMemoryManager()); + CreateMemCopyWorkloads<IClTensorHandle>(factory); } @@ -633,7 +659,9 @@ template <typename L2NormalizationWorkloadType, typename armnn::DataType DataTyp static void ClL2NormalizationWorkloadTest(DataLayout dataLayout) { Graph graph; - ClWorkloadFactory factory = ClWorkloadFactoryHelper::GetFactory(); + ClWorkloadFactory factory = + ClWorkloadFactoryHelper::GetFactory(ClWorkloadFactoryHelper::GetMemoryManager()); + auto workload = CreateL2NormalizationWorkloadTest<L2NormalizationWorkloadType, DataType>(factory, graph, dataLayout); @@ -677,7 +705,9 @@ template <typename LstmWorkloadType> static void ClCreateLstmWorkloadTest() { Graph graph; - ClWorkloadFactory factory = ClWorkloadFactoryHelper::GetFactory(); + ClWorkloadFactory factory = + ClWorkloadFactoryHelper::GetFactory(ClWorkloadFactoryHelper::GetMemoryManager()); + auto workload = CreateLstmWorkloadTest<LstmWorkloadType>(factory, graph); LstmQueueDescriptor queueDescriptor = workload->GetData(); @@ -696,7 +726,8 @@ template <typename ResizeBilinearWorkloadType, typename armnn::DataType DataType static void ClResizeBilinearWorkloadTest(DataLayout dataLayout) { Graph graph; - ClWorkloadFactory factory = ClWorkloadFactoryHelper::GetFactory(); + ClWorkloadFactory factory = + ClWorkloadFactoryHelper::GetFactory(ClWorkloadFactoryHelper::GetMemoryManager()); auto workload = CreateResizeBilinearWorkloadTest<ResizeBilinearWorkloadType, DataType>(factory, graph, dataLayout); @@ -742,7 +773,9 @@ template <typename MeanWorkloadType, typename armnn::DataType DataType> static void ClMeanWorkloadTest() { Graph graph; - ClWorkloadFactory factory = ClWorkloadFactoryHelper::GetFactory(); + ClWorkloadFactory factory = + ClWorkloadFactoryHelper::GetFactory(ClWorkloadFactoryHelper::GetMemoryManager()); + auto workload = CreateMeanWorkloadTest<MeanWorkloadType, DataType>(factory, graph); // Checks that inputs/outputs are as we expect them (see definition of CreateMeanWorkloadTest). diff --git a/src/backends/cl/test/ClLayerSupportTests.cpp b/src/backends/cl/test/ClLayerSupportTests.cpp index 2218d821ef..acfd8c3483 100644 --- a/src/backends/cl/test/ClLayerSupportTests.cpp +++ b/src/backends/cl/test/ClLayerSupportTests.cpp @@ -23,19 +23,22 @@ BOOST_AUTO_TEST_SUITE(ClLayerSupport) BOOST_FIXTURE_TEST_CASE(IsLayerSupportedFloat16Cl, ClContextControlFixture) { - armnn::ClWorkloadFactory factory = ClWorkloadFactoryHelper::GetFactory(); + armnn::ClWorkloadFactory factory = + ClWorkloadFactoryHelper::GetFactory(ClWorkloadFactoryHelper::GetMemoryManager()); IsLayerSupportedTests<armnn::ClWorkloadFactory, armnn::DataType::Float16>(&factory); } BOOST_FIXTURE_TEST_CASE(IsLayerSupportedFloat32Cl, ClContextControlFixture) { - armnn::ClWorkloadFactory factory = ClWorkloadFactoryHelper::GetFactory(); + armnn::ClWorkloadFactory factory = + ClWorkloadFactoryHelper::GetFactory(ClWorkloadFactoryHelper::GetMemoryManager()); IsLayerSupportedTests<armnn::ClWorkloadFactory, armnn::DataType::Float32>(&factory); } BOOST_FIXTURE_TEST_CASE(IsLayerSupportedUint8Cl, ClContextControlFixture) { - armnn::ClWorkloadFactory factory = ClWorkloadFactoryHelper::GetFactory(); + armnn::ClWorkloadFactory factory = + ClWorkloadFactoryHelper::GetFactory(ClWorkloadFactoryHelper::GetMemoryManager()); IsLayerSupportedTests<armnn::ClWorkloadFactory, armnn::DataType::QuantisedAsymm8>(&factory); } diff --git a/src/backends/cl/test/ClOptimizedNetworkTests.cpp b/src/backends/cl/test/ClOptimizedNetworkTests.cpp index 7e321470c1..f8c1a327ef 100644 --- a/src/backends/cl/test/ClOptimizedNetworkTests.cpp +++ b/src/backends/cl/test/ClOptimizedNetworkTests.cpp @@ -34,7 +34,8 @@ BOOST_AUTO_TEST_CASE(OptimizeValidateGpuDeviceSupportLayerNoFallback) armnn::IOptimizedNetworkPtr optNet = armnn::Optimize(*net, backends, runtime->GetDeviceSpec()); BOOST_CHECK(optNet); // validate workloads - armnn::ClWorkloadFactory fact = ClWorkloadFactoryHelper::GetFactory(); + armnn::ClWorkloadFactory fact = + ClWorkloadFactoryHelper::GetFactory(ClWorkloadFactoryHelper::GetMemoryManager()); for (auto&& layer : static_cast<armnn::OptimizedNetwork*>(optNet.get())->GetGraph()) { BOOST_CHECK(layer->GetBackendId() == armnn::Compute::GpuAcc); diff --git a/src/backends/cl/test/ClWorkloadFactoryHelper.hpp b/src/backends/cl/test/ClWorkloadFactoryHelper.hpp index 7b60b8ad15..777bc84b8a 100644 --- a/src/backends/cl/test/ClWorkloadFactoryHelper.hpp +++ b/src/backends/cl/test/ClWorkloadFactoryHelper.hpp @@ -9,10 +9,9 @@ #include <backendsCommon/IMemoryManager.hpp> #include <backendsCommon/test/WorkloadFactoryHelper.hpp> +#include <cl/ClBackend.hpp> #include <cl/ClWorkloadFactory.hpp> -#include <arm_compute/runtime/CL/CLBufferAllocator.h> - #include <boost/polymorphic_pointer_cast.hpp> namespace @@ -21,11 +20,15 @@ namespace template<> struct WorkloadFactoryHelper<armnn::ClWorkloadFactory> { - static armnn::ClWorkloadFactory GetFactory() + static armnn::IBackendInternal::IMemoryManagerSharedPtr GetMemoryManager() { - armnn::IBackendInternal::IMemoryManagerSharedPtr memoryManager = - std::make_shared<armnn::ClMemoryManager>(std::make_unique<arm_compute::CLBufferAllocator>()); + armnn::ClBackend backend; + return backend.CreateMemoryManager(); + } + static armnn::ClWorkloadFactory GetFactory( + const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager) + { return armnn::ClWorkloadFactory(boost::polymorphic_pointer_downcast<armnn::ClMemoryManager>(memoryManager)); } }; diff --git a/src/backends/cl/test/OpenClTimerTest.cpp b/src/backends/cl/test/OpenClTimerTest.cpp index 6e55be6c3d..6f44cc4772 100644 --- a/src/backends/cl/test/OpenClTimerTest.cpp +++ b/src/backends/cl/test/OpenClTimerTest.cpp @@ -44,7 +44,8 @@ using FactoryType = ClWorkloadFactory; BOOST_AUTO_TEST_CASE(OpenClTimerBatchNorm) { - ClWorkloadFactory workloadFactory = ClWorkloadFactoryHelper::GetFactory(); + auto memoryManager = ClWorkloadFactoryHelper::GetMemoryManager(); + ClWorkloadFactory workloadFactory = ClWorkloadFactoryHelper::GetFactory(memoryManager); const unsigned int width = 2; const unsigned int height = 3; |