aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/test/CreateWorkloadClNeon.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/test/CreateWorkloadClNeon.hpp')
-rw-r--r--src/armnn/test/CreateWorkloadClNeon.hpp15
1 files changed, 7 insertions, 8 deletions
diff --git a/src/armnn/test/CreateWorkloadClNeon.hpp b/src/armnn/test/CreateWorkloadClNeon.hpp
index a41a70755f..d92111ac41 100644
--- a/src/armnn/test/CreateWorkloadClNeon.hpp
+++ b/src/armnn/test/CreateWorkloadClNeon.hpp
@@ -56,22 +56,21 @@ boost::test_tools::predicate_result CompareTensorHandleShape(IComputeTensorHandl
return true;
}
-template<template <DataType> class CopyFromCpuWorkload, template <DataType> class CopyToCpuWorkload,
- typename IComputeTensorHandle>
+template<typename IComputeTensorHandle>
void CreateMemCopyWorkloads(IWorkloadFactory& factory)
{
Graph graph;
RefWorkloadFactory refFactory;
- // create the layers we're testing
+ // Creates the layers we're testing.
Layer* const layer1 = graph.AddLayer<MemCopyLayer>("layer1");
Layer* const layer2 = graph.AddLayer<MemCopyLayer>("layer2");
- // create extra layers
+ // Creates extra layers.
Layer* const input = graph.AddLayer<InputLayer>(0, "input");
Layer* const output = graph.AddLayer<OutputLayer>(0, "output");
- // connect up
+ // Connects up.
TensorInfo tensorInfo({2, 3}, DataType::Float32);
Connect(input, layer1, tensorInfo);
Connect(layer1, layer2, tensorInfo);
@@ -83,8 +82,8 @@ void CreateMemCopyWorkloads(IWorkloadFactory& factory)
output->CreateTensorHandles(graph, refFactory);
// make the workloads and check them
- auto workload1 = MakeAndCheckWorkload<CopyFromCpuWorkload<DataType::Float32>>(*layer1, graph, factory);
- auto workload2 = MakeAndCheckWorkload<CopyToCpuWorkload<DataType::Float32>>(*layer2, graph, refFactory);
+ auto workload1 = MakeAndCheckWorkload<CopyMemGenericWorkload>(*layer1, graph, factory);
+ auto workload2 = MakeAndCheckWorkload<CopyMemGenericWorkload>(*layer2, graph, refFactory);
MemCopyQueueDescriptor queueDescriptor1 = workload1->GetData();
BOOST_TEST(queueDescriptor1.m_Inputs.size() == 1);
@@ -104,4 +103,4 @@ void CreateMemCopyWorkloads(IWorkloadFactory& factory)
BOOST_TEST((outputHandle2->GetTensorInfo() == TensorInfo({2, 3}, DataType::Float32)));
}
-} \ No newline at end of file
+} //namespace \ No newline at end of file