From 5caf907efc31e774f8afde54b17a5596477772f6 Mon Sep 17 00:00:00 2001 From: Aron Virginas-Tar Date: Wed, 14 Nov 2018 18:35:18 +0000 Subject: IVGCVSW-2136: Remove memory management methods from workload factories Change-Id: Idc0f94590566ac362f7e1d1999361d025cc2f67a --- src/backends/cl/test/ClCreateWorkloadTests.cpp | 75 ++++++++++++++++++-------- 1 file changed, 54 insertions(+), 21 deletions(-) (limited to 'src/backends/cl/test/ClCreateWorkloadTests.cpp') diff --git a/src/backends/cl/test/ClCreateWorkloadTests.cpp b/src/backends/cl/test/ClCreateWorkloadTests.cpp index 978b3bce9a..b243ca8007 100644 --- a/src/backends/cl/test/ClCreateWorkloadTests.cpp +++ b/src/backends/cl/test/ClCreateWorkloadTests.cpp @@ -27,7 +27,8 @@ template static void ClCreateActivationWorkloadTest() { Graph graph; - ClWorkloadFactory factory = ClWorkloadFactoryHelper::GetFactory(); + ClWorkloadFactory factory = + ClWorkloadFactoryHelper::GetFactory(ClWorkloadFactoryHelper::GetMemoryManager()); auto workload = CreateActivationWorkloadTest(factory, graph); @@ -57,7 +58,9 @@ template (factory, graph); // Checks that inputs/outputs are as we expect them (see definition of CreateArithmeticWorkloadTest). @@ -146,7 +149,8 @@ template static void ClCreateBatchNormalizationWorkloadTest(DataLayout dataLayout) { Graph graph; - ClWorkloadFactory factory = ClWorkloadFactoryHelper::GetFactory(); + ClWorkloadFactory factory = + ClWorkloadFactoryHelper::GetFactory(ClWorkloadFactoryHelper::GetMemoryManager()); auto workload = CreateBatchNormalizationWorkloadTest (factory, graph, dataLayout); @@ -195,7 +199,9 @@ BOOST_AUTO_TEST_CASE(CreateBatchNormalizationNhwcFloat16NhwcWorkload) BOOST_AUTO_TEST_CASE(CreateConvertFp16ToFp32Workload) { Graph graph; - ClWorkloadFactory factory = ClWorkloadFactoryHelper::GetFactory(); + ClWorkloadFactory factory = + ClWorkloadFactoryHelper::GetFactory(ClWorkloadFactoryHelper::GetMemoryManager()); + auto workload = CreateConvertFp16ToFp32WorkloadTest(factory, graph); ConvertFp16ToFp32QueueDescriptor queueDescriptor = workload->GetData(); @@ -211,7 +217,9 @@ BOOST_AUTO_TEST_CASE(CreateConvertFp16ToFp32Workload) BOOST_AUTO_TEST_CASE(CreateConvertFp32ToFp16Workload) { Graph graph; - ClWorkloadFactory factory = ClWorkloadFactoryHelper::GetFactory(); + ClWorkloadFactory factory = + ClWorkloadFactoryHelper::GetFactory(ClWorkloadFactoryHelper::GetMemoryManager()); + auto workload = CreateConvertFp32ToFp16WorkloadTest(factory, graph); ConvertFp32ToFp16QueueDescriptor queueDescriptor = workload->GetData(); @@ -228,7 +236,9 @@ template static void ClConvolution2dWorkloadTest(DataLayout dataLayout) { Graph graph; - ClWorkloadFactory factory = ClWorkloadFactoryHelper::GetFactory(); + ClWorkloadFactory factory = + ClWorkloadFactoryHelper::GetFactory(ClWorkloadFactoryHelper::GetMemoryManager()); + auto workload = CreateConvolution2dWorkloadTest(factory, graph, dataLayout); @@ -270,7 +280,8 @@ template (factory, graph, dataLayout); @@ -300,7 +311,9 @@ template static void ClDirectConvolution2dWorkloadTest() { Graph graph; - ClWorkloadFactory factory = ClWorkloadFactoryHelper::GetFactory(); + ClWorkloadFactory factory = + ClWorkloadFactoryHelper::GetFactory(ClWorkloadFactoryHelper::GetMemoryManager()); + auto workload = CreateDirectConvolution2dWorkloadTest(factory, graph); // Checks that outputs and inputs are as we expect them (see definition of CreateDirectConvolution2dWorkloadTest). @@ -330,7 +343,9 @@ template (factory, graph); @@ -357,7 +372,9 @@ template static void ClNormalizationWorkloadTest(DataLayout dataLayout) { Graph graph; - ClWorkloadFactory factory = ClWorkloadFactoryHelper::GetFactory(); + ClWorkloadFactory factory = + ClWorkloadFactoryHelper::GetFactory(ClWorkloadFactoryHelper::GetMemoryManager()); + auto workload = CreateNormalizationWorkloadTest(factory, graph, dataLayout); // Checks that inputs/outputs are as we expect them (see definition of CreateNormalizationWorkloadTest). @@ -398,7 +415,8 @@ template static void ClPooling2dWorkloadTest(DataLayout dataLayout) { Graph graph; - ClWorkloadFactory factory = ClWorkloadFactoryHelper::GetFactory(); + ClWorkloadFactory factory = + ClWorkloadFactoryHelper::GetFactory(ClWorkloadFactoryHelper::GetMemoryManager()); auto workload = CreatePooling2dWorkloadTest(factory, graph, dataLayout); @@ -440,7 +458,8 @@ template static void ClCreateReshapeWorkloadTest() { Graph graph; - ClWorkloadFactory factory = ClWorkloadFactoryHelper::GetFactory(); + ClWorkloadFactory factory = + ClWorkloadFactoryHelper::GetFactory(ClWorkloadFactoryHelper::GetMemoryManager()); auto workload = CreateReshapeWorkloadTest(factory, graph); @@ -472,7 +491,8 @@ template static void ClSoftmaxWorkloadTest() { Graph graph; - ClWorkloadFactory factory = ClWorkloadFactoryHelper::GetFactory(); + ClWorkloadFactory factory = + ClWorkloadFactoryHelper::GetFactory(ClWorkloadFactoryHelper::GetMemoryManager()); auto workload = CreateSoftmaxWorkloadTest(factory, graph); @@ -500,7 +520,8 @@ template static void ClSplitterWorkloadTest() { Graph graph; - ClWorkloadFactory factory = ClWorkloadFactoryHelper::GetFactory(); + ClWorkloadFactory factory = + ClWorkloadFactoryHelper::GetFactory(ClWorkloadFactoryHelper::GetMemoryManager()); auto workload = CreateSplitterWorkloadTest(factory, graph); @@ -541,7 +562,8 @@ static void ClSplitterMergerTest() // of the merger. Graph graph; - ClWorkloadFactory factory = ClWorkloadFactoryHelper::GetFactory(); + ClWorkloadFactory factory = + ClWorkloadFactoryHelper::GetFactory(ClWorkloadFactoryHelper::GetMemoryManager()); auto workloads = CreateSplitterMergerWorkloadTest @@ -590,7 +612,9 @@ BOOST_AUTO_TEST_CASE(CreateSingleOutputMultipleInputs) // We create a splitter with two outputs. That each of those outputs is used by two different activation layers. Graph graph; - ClWorkloadFactory factory = ClWorkloadFactoryHelper::GetFactory(); + ClWorkloadFactory factory = + ClWorkloadFactoryHelper::GetFactory(ClWorkloadFactoryHelper::GetMemoryManager()); + std::unique_ptr wlSplitter; std::unique_ptr wlActiv0_0; std::unique_ptr wlActiv0_1; @@ -625,7 +649,9 @@ BOOST_AUTO_TEST_CASE(CreateSingleOutputMultipleInputs) BOOST_AUTO_TEST_CASE(CreateMemCopyWorkloadsCl) { - ClWorkloadFactory factory = ClWorkloadFactoryHelper::GetFactory(); + ClWorkloadFactory factory = + ClWorkloadFactoryHelper::GetFactory(ClWorkloadFactoryHelper::GetMemoryManager()); + CreateMemCopyWorkloads(factory); } @@ -633,7 +659,9 @@ template (factory, graph, dataLayout); @@ -677,7 +705,9 @@ template static void ClCreateLstmWorkloadTest() { Graph graph; - ClWorkloadFactory factory = ClWorkloadFactoryHelper::GetFactory(); + ClWorkloadFactory factory = + ClWorkloadFactoryHelper::GetFactory(ClWorkloadFactoryHelper::GetMemoryManager()); + auto workload = CreateLstmWorkloadTest(factory, graph); LstmQueueDescriptor queueDescriptor = workload->GetData(); @@ -696,7 +726,8 @@ template (factory, graph, dataLayout); @@ -742,7 +773,9 @@ template static void ClMeanWorkloadTest() { Graph graph; - ClWorkloadFactory factory = ClWorkloadFactoryHelper::GetFactory(); + ClWorkloadFactory factory = + ClWorkloadFactoryHelper::GetFactory(ClWorkloadFactoryHelper::GetMemoryManager()); + auto workload = CreateMeanWorkloadTest(factory, graph); // Checks that inputs/outputs are as we expect them (see definition of CreateMeanWorkloadTest). -- cgit v1.2.1