aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/test/CreateWorkload.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/test/CreateWorkload.hpp')
-rw-r--r--src/armnn/test/CreateWorkload.hpp78
1 files changed, 39 insertions, 39 deletions
diff --git a/src/armnn/test/CreateWorkload.hpp b/src/armnn/test/CreateWorkload.hpp
index be52eadb57..135a4421cd 100644
--- a/src/armnn/test/CreateWorkload.hpp
+++ b/src/armnn/test/CreateWorkload.hpp
@@ -706,10 +706,10 @@ std::unique_ptr<SplitterWorkload>
return workload;
}
-/// This function constructs a graph with both a splitter and a merger, and returns a pair of the workloads.
-template<typename SplitterWorkload, typename MergerWorkload, armnn::DataType DataType>
-std::pair<std::unique_ptr<SplitterWorkload>, std::unique_ptr<MergerWorkload>>
- CreateSplitterMergerWorkloadTest(armnn::IWorkloadFactory& factory, armnn::Graph& graph)
+/// This function constructs a graph with both a splitter and a concat, and returns a pair of the workloads.
+template<typename SplitterWorkload, typename ConcatWorkload, armnn::DataType DataType>
+std::pair<std::unique_ptr<SplitterWorkload>, std::unique_ptr<ConcatWorkload>>
+ CreateSplitterConcatWorkloadTest(armnn::IWorkloadFactory &factory, armnn::Graph &graph)
{
armnn::TensorInfo inputTensorInfo({ 1, 2, 100, 10 }, DataType);
@@ -733,41 +733,41 @@ std::pair<std::unique_ptr<SplitterWorkload>, std::unique_ptr<MergerWorkload>>
Layer* const splitter = graph.AddLayer<SplitterLayer>(splitterViews, "splitter");
BOOST_TEST_CHECKPOINT("created splitter layer");
- armnn::OriginsDescriptor mergerViews(2);
- mergerViews.SetViewOriginCoord(0, 0, 0);
- mergerViews.SetViewOriginCoord(0, 1, 1);
- mergerViews.SetViewOriginCoord(0, 2, 0);
- mergerViews.SetViewOriginCoord(0, 3, 0);
+ armnn::OriginsDescriptor concatViews(2);
+ concatViews.SetViewOriginCoord(0, 0, 0);
+ concatViews.SetViewOriginCoord(0, 1, 1);
+ concatViews.SetViewOriginCoord(0, 2, 0);
+ concatViews.SetViewOriginCoord(0, 3, 0);
- mergerViews.SetViewOriginCoord(1, 0, 0);
- mergerViews.SetViewOriginCoord(1, 1, 0);
- mergerViews.SetViewOriginCoord(1, 2, 0);
- mergerViews.SetViewOriginCoord(1, 3, 0);
+ concatViews.SetViewOriginCoord(1, 0, 0);
+ concatViews.SetViewOriginCoord(1, 1, 0);
+ concatViews.SetViewOriginCoord(1, 2, 0);
+ concatViews.SetViewOriginCoord(1, 3, 0);
- Layer* const merger = graph.AddLayer<MergerLayer>(mergerViews, "merger");
- BOOST_TEST_CHECKPOINT("created merger layer");
+ Layer* const concat = graph.AddLayer<ConcatLayer>(concatViews, "concat");
+ BOOST_TEST_CHECKPOINT("created concat layer");
Layer* const output = graph.AddLayer<OutputLayer>(0, "output");
// Adds connections.
Connect(input, splitter, inputTensorInfo, 0, 0);
BOOST_TEST_CHECKPOINT("connect input to splitter");
- Connect(splitter, merger, splitTensorInfo1, 0, 1); // The splitter & merger are connected up.
- BOOST_TEST_CHECKPOINT("connect splitter[0] to merger[1]");
- Connect(splitter, merger, splitTensorInfo2, 1, 0); // So that the outputs are flipped round.
- BOOST_TEST_CHECKPOINT("connect splitter[1] to merger[0]");
- Connect(merger, output, inputTensorInfo, 0, 0);
- BOOST_TEST_CHECKPOINT("connect merger to output");
+ Connect(splitter, concat, splitTensorInfo1, 0, 1); // The splitter & concat are connected up.
+ BOOST_TEST_CHECKPOINT("connect splitter[0] to concat[1]");
+ Connect(splitter, concat, splitTensorInfo2, 1, 0); // So that the outputs are flipped round.
+ BOOST_TEST_CHECKPOINT("connect splitter[1] to concat[0]");
+ Connect(concat, output, inputTensorInfo, 0, 0);
+ BOOST_TEST_CHECKPOINT("connect concat to output");
CreateTensorHandles(graph, factory);
BOOST_TEST_CHECKPOINT("created tensor handles");
auto workloadSplitter = MakeAndCheckWorkload<SplitterWorkload>(*splitter, graph, factory);
BOOST_TEST_CHECKPOINT("created splitter workload");
- auto workloadMerger = MakeAndCheckWorkload<MergerWorkload>(*merger, graph, factory);
- BOOST_TEST_CHECKPOINT("created merger workload");
+ auto workloadConcat = MakeAndCheckWorkload<ConcatWorkload>(*concat, graph, factory);
+ BOOST_TEST_CHECKPOINT("created concat workload");
- return {std::move(workloadSplitter), std::move(workloadMerger)};
+ return {std::move(workloadSplitter), std::move(workloadConcat)};
}
@@ -1053,10 +1053,10 @@ std::unique_ptr<MeanWorkload> CreateMeanWorkloadTest(armnn::IWorkloadFactory& fa
return workload;
}
-template<typename MergerWorkload, armnn::DataType DataType>
-std::unique_ptr<MergerWorkload> CreateMergerWorkloadTest(armnn::IWorkloadFactory& factory,
- armnn::Graph& graph,
- const armnn::TensorShape& outputShape,
+template<typename ConcatWorkload, armnn::DataType DataType>
+std::unique_ptr<ConcatWorkload> CreateConcatWorkloadTest(armnn::IWorkloadFactory &factory,
+ armnn::Graph &graph,
+ const armnn::TensorShape &outputShape,
unsigned int concatAxis)
{
armnn::TensorInfo inputTensorInfo({ 2, 3, 2, 5 }, DataType);
@@ -1073,26 +1073,26 @@ std::unique_ptr<MergerWorkload> CreateMergerWorkloadTest(armnn::IWorkloadFactory
inputShapes.end(),
concatAxis);
- Layer* const merger = graph.AddLayer<MergerLayer>(descriptor, "merger");
- BOOST_TEST_CHECKPOINT("created merger layer");
+ Layer* const concat = graph.AddLayer<ConcatLayer>(descriptor, "concat");
+ BOOST_TEST_CHECKPOINT("created concat layer");
Layer* const output = graph.AddLayer<OutputLayer>(0, "output");
// Adds connections.
- Connect(input0, merger, inputTensorInfo, 0, 0);
- BOOST_TEST_CHECKPOINT("connect input0 to merger");
- Connect(input1, merger, inputTensorInfo, 0, 1);
- BOOST_TEST_CHECKPOINT("connect input1 to merger");
- Connect(merger, output, outputTensorInfo, 0, 0);
- BOOST_TEST_CHECKPOINT("connect merger to output");
+ Connect(input0, concat, inputTensorInfo, 0, 0);
+ BOOST_TEST_CHECKPOINT("connect input0 to concat");
+ Connect(input1, concat, inputTensorInfo, 0, 1);
+ BOOST_TEST_CHECKPOINT("connect input1 to concat");
+ Connect(concat, output, outputTensorInfo, 0, 0);
+ BOOST_TEST_CHECKPOINT("connect concat to output");
CreateTensorHandles(graph, factory);
BOOST_TEST_CHECKPOINT("created tensor handles");
- auto workloadMerger = MakeAndCheckWorkload<MergerWorkload>(*merger, graph, factory);
- BOOST_TEST_CHECKPOINT("created merger workload");
+ auto workloadConcat = MakeAndCheckWorkload<ConcatWorkload>(*concat, graph, factory);
+ BOOST_TEST_CHECKPOINT("created concat workload");
- return std::move(workloadMerger);
+ return std::move(workloadConcat);
}
template <typename PreCompiledWorkload, armnn::DataType dataType>