aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/test/CreateWorkload.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/test/CreateWorkload.hpp')
-rw-r--r--src/armnn/test/CreateWorkload.hpp42
1 files changed, 42 insertions, 0 deletions
diff --git a/src/armnn/test/CreateWorkload.hpp b/src/armnn/test/CreateWorkload.hpp
index 111df4b328..1a9bd56ac5 100644
--- a/src/armnn/test/CreateWorkload.hpp
+++ b/src/armnn/test/CreateWorkload.hpp
@@ -1050,4 +1050,46 @@ std::unique_ptr<MeanWorkload> CreateMeanWorkloadTest(armnn::IWorkloadFactory& fa
return workload;
}
+template<typename MergerWorkload, armnn::DataType DataType>
+std::unique_ptr<MergerWorkload> CreateMergerWorkloadTest(armnn::IWorkloadFactory& factory,
+ armnn::Graph& graph,
+ const armnn::TensorShape& outputShape,
+ unsigned int concatAxis)
+{
+ armnn::TensorInfo inputTensorInfo({ 2, 3, 2, 5 }, DataType);
+ armnn::TensorInfo outputTensorInfo(outputShape, DataType);
+
+ // Constructs the graph.
+ Layer* const input0 = graph.AddLayer<InputLayer>(0, "input0");
+ Layer* const input1 = graph.AddLayer<InputLayer>(1, "input1");
+ armnn::OriginsDescriptor descriptor;
+
+ std::vector<armnn::TensorShape> inputShapes{{ 2, 3, 2, 5 }, { 2, 3, 2, 5 }};
+
+ descriptor = CreateMergerDescriptorForConcatenation(inputShapes.begin(),
+ inputShapes.end(),
+ concatAxis);
+
+ Layer* const merger = graph.AddLayer<MergerLayer>(descriptor, "merger");
+ BOOST_TEST_CHECKPOINT("created merger layer");
+
+ Layer* const output = graph.AddLayer<OutputLayer>(0, "output");
+
+ // Adds connections.
+ Connect(input0, merger, inputTensorInfo, 0, 0);
+ BOOST_TEST_CHECKPOINT("connect input0 to merger");
+ Connect(input1, merger, inputTensorInfo, 0, 1);
+ BOOST_TEST_CHECKPOINT("connect input1 to merger");
+ Connect(merger, output, outputTensorInfo, 0, 0);
+ BOOST_TEST_CHECKPOINT("connect merger to output");
+
+ CreateTensorHandles(graph, factory);
+ BOOST_TEST_CHECKPOINT("created tensor handles");
+
+ auto workloadMerger = MakeAndCheckWorkload<MergerWorkload>(*merger, graph, factory);
+ BOOST_TEST_CHECKPOINT("created merger workload");
+
+ return std::move(workloadMerger);
+}
+
}