diff options
Diffstat (limited to 'src/armnn/test/CreateWorkloadClNeon.hpp')
-rw-r--r-- | src/armnn/test/CreateWorkloadClNeon.hpp | 15 |
1 files changed, 7 insertions, 8 deletions
diff --git a/src/armnn/test/CreateWorkloadClNeon.hpp b/src/armnn/test/CreateWorkloadClNeon.hpp index a41a70755f..d92111ac41 100644 --- a/src/armnn/test/CreateWorkloadClNeon.hpp +++ b/src/armnn/test/CreateWorkloadClNeon.hpp @@ -56,22 +56,21 @@ boost::test_tools::predicate_result CompareTensorHandleShape(IComputeTensorHandl return true; } -template<template <DataType> class CopyFromCpuWorkload, template <DataType> class CopyToCpuWorkload, - typename IComputeTensorHandle> +template<typename IComputeTensorHandle> void CreateMemCopyWorkloads(IWorkloadFactory& factory) { Graph graph; RefWorkloadFactory refFactory; - // create the layers we're testing + // Creates the layers we're testing. Layer* const layer1 = graph.AddLayer<MemCopyLayer>("layer1"); Layer* const layer2 = graph.AddLayer<MemCopyLayer>("layer2"); - // create extra layers + // Creates extra layers. Layer* const input = graph.AddLayer<InputLayer>(0, "input"); Layer* const output = graph.AddLayer<OutputLayer>(0, "output"); - // connect up + // Connects up. TensorInfo tensorInfo({2, 3}, DataType::Float32); Connect(input, layer1, tensorInfo); Connect(layer1, layer2, tensorInfo); @@ -83,8 +82,8 @@ void CreateMemCopyWorkloads(IWorkloadFactory& factory) output->CreateTensorHandles(graph, refFactory); // make the workloads and check them - auto workload1 = MakeAndCheckWorkload<CopyFromCpuWorkload<DataType::Float32>>(*layer1, graph, factory); - auto workload2 = MakeAndCheckWorkload<CopyToCpuWorkload<DataType::Float32>>(*layer2, graph, refFactory); + auto workload1 = MakeAndCheckWorkload<CopyMemGenericWorkload>(*layer1, graph, factory); + auto workload2 = MakeAndCheckWorkload<CopyMemGenericWorkload>(*layer2, graph, refFactory); MemCopyQueueDescriptor queueDescriptor1 = workload1->GetData(); BOOST_TEST(queueDescriptor1.m_Inputs.size() == 1); @@ -104,4 +103,4 @@ void CreateMemCopyWorkloads(IWorkloadFactory& factory) BOOST_TEST((outputHandle2->GetTensorInfo() == TensorInfo({2, 3}, DataType::Float32))); } -}
\ No newline at end of file +} //namespace
\ No newline at end of file |