diff options
Diffstat (limited to 'src/backends/reference/test/RefCreateWorkloadTests.cpp')
-rw-r--r-- | src/backends/reference/test/RefCreateWorkloadTests.cpp | 48 |
1 files changed, 28 insertions, 20 deletions
diff --git a/src/backends/reference/test/RefCreateWorkloadTests.cpp b/src/backends/reference/test/RefCreateWorkloadTests.cpp index 945a87430c..8fe18f5d78 100644 --- a/src/backends/reference/test/RefCreateWorkloadTests.cpp +++ b/src/backends/reference/test/RefCreateWorkloadTests.cpp @@ -36,6 +36,14 @@ void CheckInputsOutput(std::unique_ptr<Workload> workload, BOOST_TEST((inputHandle1->GetTensorInfo() == inputInfo1)); BOOST_TEST((outputHandle->GetTensorInfo() == outputInfo)); } + +armnn::RefWorkloadFactory GetFactory() +{ + std::shared_ptr<RefMemoryManager> memoryManager = std::make_shared<RefMemoryManager>(); + return RefWorkloadFactory(memoryManager); +} + + } BOOST_AUTO_TEST_SUITE(CreateWorkloadRef) @@ -44,7 +52,7 @@ template <typename ActivationWorkloadType, armnn::DataType DataType> static void RefCreateActivationWorkloadTest() { Graph graph; - RefWorkloadFactory factory; + RefWorkloadFactory factory = GetFactory(); auto workload = CreateActivationWorkloadTest<ActivationWorkloadType, DataType>(factory, graph); // Checks that outputs are as we expect them (see definition of CreateActivationWorkloadTest). @@ -70,7 +78,7 @@ template <typename WorkloadType, static void RefCreateElementwiseWorkloadTest() { Graph graph; - RefWorkloadFactory factory; + RefWorkloadFactory factory = GetFactory(); auto workload = CreateElementwiseWorkloadTest<WorkloadType, DescriptorType, LayerType, DataType>( factory, graph); @@ -180,7 +188,7 @@ template <typename BatchNormalizationWorkloadType, armnn::DataType DataType> static void RefCreateBatchNormalizationWorkloadTest(DataLayout dataLayout) { Graph graph; - RefWorkloadFactory factory; + RefWorkloadFactory factory = GetFactory(); auto workload = CreateBatchNormalizationWorkloadTest<BatchNormalizationWorkloadType, DataType>(factory, graph, dataLayout); @@ -244,7 +252,7 @@ BOOST_AUTO_TEST_CASE(CreateBatchNormalizationInt16WorkloadNhwc) BOOST_AUTO_TEST_CASE(CreateConvertFp16ToFp32Float32Workload) { Graph graph; - RefWorkloadFactory factory; + RefWorkloadFactory factory = GetFactory(); auto workload = CreateConvertFp16ToFp32WorkloadTest<RefConvertFp16ToFp32Workload>(factory, graph); // Checks that outputs and inputs are as we expect them @@ -255,7 +263,7 @@ BOOST_AUTO_TEST_CASE(CreateConvertFp16ToFp32Float32Workload) BOOST_AUTO_TEST_CASE(CreateConvertFp32ToFp16Float16Workload) { Graph graph; - RefWorkloadFactory factory; + RefWorkloadFactory factory = GetFactory(); auto workload = CreateConvertFp32ToFp16WorkloadTest<RefConvertFp32ToFp16Workload>(factory, graph); // Checks that outputs and inputs are as we expect them @@ -266,7 +274,7 @@ BOOST_AUTO_TEST_CASE(CreateConvertFp32ToFp16Float16Workload) static void RefCreateConvolution2dWorkloadTest(DataLayout dataLayout = DataLayout::NCHW) { Graph graph; - RefWorkloadFactory factory; + RefWorkloadFactory factory = GetFactory(); auto workload = CreateConvolution2dWorkloadTest<RefConvolution2dWorkload, DataType::Float32> (factory, graph, dataLayout); @@ -294,7 +302,7 @@ BOOST_AUTO_TEST_CASE(CreateConvolution2dFloatNhwcWorkload) static void RefCreateDepthwiseConvolutionWorkloadTest(DataLayout dataLayout) { Graph graph; - RefWorkloadFactory factory; + RefWorkloadFactory factory = GetFactory(); auto workload = CreateDepthwiseConvolution2dWorkloadTest<RefDepthwiseConvolution2dWorkload, DataType::Float32> (factory, graph, dataLayout); @@ -318,7 +326,7 @@ template <typename FullyConnectedWorkloadType, armnn::DataType DataType> static void RefCreateFullyConnectedWorkloadTest() { Graph graph; - RefWorkloadFactory factory; + RefWorkloadFactory factory = GetFactory(); auto workload = CreateFullyConnectedWorkloadTest<FullyConnectedWorkloadType, DataType>(factory, graph); // Checks that outputs and inputs are as we expect them (see definition of CreateFullyConnectedWorkloadTest). @@ -348,7 +356,7 @@ template <typename NormalizationWorkloadType, armnn::DataType DataType> static void RefCreateNormalizationWorkloadTest(DataLayout dataLayout) { Graph graph; - RefWorkloadFactory factory; + RefWorkloadFactory factory = GetFactory(); auto workload = CreateNormalizationWorkloadTest<NormalizationWorkloadType, DataType>(factory, graph, dataLayout); TensorShape inputShape; @@ -405,7 +413,7 @@ template <typename Pooling2dWorkloadType, armnn::DataType DataType> static void RefCreatePooling2dWorkloadTest(DataLayout dataLayout) { Graph graph; - RefWorkloadFactory factory; + RefWorkloadFactory factory = GetFactory(); auto workload = CreatePooling2dWorkloadTest<Pooling2dWorkloadType, DataType>(factory, graph, dataLayout); TensorShape inputShape; @@ -463,7 +471,7 @@ template <typename SoftmaxWorkloadType, armnn::DataType DataType> static void RefCreateSoftmaxWorkloadTest() { Graph graph; - RefWorkloadFactory factory; + RefWorkloadFactory factory = GetFactory(); auto workload = CreateSoftmaxWorkloadTest<SoftmaxWorkloadType, DataType>(factory, graph); // Checks that outputs and inputs are as we expect them (see definition of CreateSoftmaxWorkloadTest). @@ -492,7 +500,7 @@ template <typename SplitterWorkloadType, armnn::DataType DataType> static void RefCreateSplitterWorkloadTest() { Graph graph; - RefWorkloadFactory factory; + RefWorkloadFactory factory = GetFactory(); auto workload = CreateSplitterWorkloadTest<SplitterWorkloadType, DataType>(factory, graph); // Checks that outputs are as we expect them (see definition of CreateSplitterWorkloadTest). @@ -530,7 +538,7 @@ static void RefCreateSplitterConcatWorkloadTest() // of the concat. Graph graph; - RefWorkloadFactory factory; + RefWorkloadFactory factory = GetFactory(); auto workloads = CreateSplitterConcatWorkloadTest<SplitterWorkloadType, ConcatWorkloadType, DataType> (factory, graph); @@ -570,7 +578,7 @@ static void RefCreateSingleOutputMultipleInputsTest() // We created a splitter with two outputs. That each of those outputs is used by two different activation layers. Graph graph; - RefWorkloadFactory factory; + RefWorkloadFactory factory = GetFactory(); std::unique_ptr<SplitterWorkloadType> wlSplitter; std::unique_ptr<ActivationWorkloadType> wlActiv0_0; std::unique_ptr<ActivationWorkloadType> wlActiv0_1; @@ -617,7 +625,7 @@ template <typename ResizeBilinearWorkloadType, armnn::DataType DataType> static void RefCreateResizeBilinearTest(DataLayout dataLayout) { Graph graph; - RefWorkloadFactory factory; + RefWorkloadFactory factory = GetFactory(); auto workload = CreateResizeBilinearWorkloadTest<ResizeBilinearWorkloadType, DataType>(factory, graph, dataLayout); TensorShape inputShape; @@ -665,7 +673,7 @@ template <typename RsqrtWorkloadType, armnn::DataType DataType> static void RefCreateRsqrtTest() { Graph graph; - RefWorkloadFactory factory; + RefWorkloadFactory factory = GetFactory(); auto workload = CreateRsqrtWorkloadTest<RsqrtWorkloadType, DataType>(factory, graph); @@ -723,7 +731,7 @@ template <typename L2NormalizationWorkloadType, armnn::DataType DataType> static void RefCreateL2NormalizationTest(DataLayout dataLayout) { Graph graph; - RefWorkloadFactory factory; + RefWorkloadFactory factory = GetFactory(); auto workload = CreateL2NormalizationWorkloadTest<L2NormalizationWorkloadType, DataType>(factory, graph, dataLayout); @@ -781,7 +789,7 @@ template <typename ReshapeWorkloadType, armnn::DataType DataType> static void RefCreateReshapeWorkloadTest() { Graph graph; - RefWorkloadFactory factory; + RefWorkloadFactory factory = GetFactory(); auto workload = CreateReshapeWorkloadTest<ReshapeWorkloadType, DataType>(factory, graph); // Checks that outputs and inputs are as we expect them (see definition of CreateReshapeWorkloadTest). @@ -811,7 +819,7 @@ static void RefCreateConcatWorkloadTest(const armnn::TensorShape& outputShape, unsigned int concatAxis) { Graph graph; - RefWorkloadFactory factory; + RefWorkloadFactory factory = GetFactory(); auto workload = CreateConcatWorkloadTest<ConcatWorkloadType, DataType>(factory, graph, outputShape, concatAxis); CheckInputsOutput(std::move(workload), @@ -869,7 +877,7 @@ template <typename ConstantWorkloadType, armnn::DataType DataType> static void RefCreateConstantWorkloadTest(const armnn::TensorShape& outputShape) { armnn::Graph graph; - RefWorkloadFactory factory; + RefWorkloadFactory factory = GetFactory(); auto workload = CreateConstantWorkloadTest<ConstantWorkloadType, DataType>(factory, graph, outputShape); // Check output is as expected |