diff options
Diffstat (limited to 'src/armnn/test/CreateWorkload.hpp')
-rw-r--r-- | src/armnn/test/CreateWorkload.hpp | 78 |
1 files changed, 39 insertions, 39 deletions
diff --git a/src/armnn/test/CreateWorkload.hpp b/src/armnn/test/CreateWorkload.hpp index be52eadb57..135a4421cd 100644 --- a/src/armnn/test/CreateWorkload.hpp +++ b/src/armnn/test/CreateWorkload.hpp @@ -706,10 +706,10 @@ std::unique_ptr<SplitterWorkload> return workload; } -/// This function constructs a graph with both a splitter and a merger, and returns a pair of the workloads. -template<typename SplitterWorkload, typename MergerWorkload, armnn::DataType DataType> -std::pair<std::unique_ptr<SplitterWorkload>, std::unique_ptr<MergerWorkload>> - CreateSplitterMergerWorkloadTest(armnn::IWorkloadFactory& factory, armnn::Graph& graph) +/// This function constructs a graph with both a splitter and a concat, and returns a pair of the workloads. +template<typename SplitterWorkload, typename ConcatWorkload, armnn::DataType DataType> +std::pair<std::unique_ptr<SplitterWorkload>, std::unique_ptr<ConcatWorkload>> + CreateSplitterConcatWorkloadTest(armnn::IWorkloadFactory &factory, armnn::Graph &graph) { armnn::TensorInfo inputTensorInfo({ 1, 2, 100, 10 }, DataType); @@ -733,41 +733,41 @@ std::pair<std::unique_ptr<SplitterWorkload>, std::unique_ptr<MergerWorkload>> Layer* const splitter = graph.AddLayer<SplitterLayer>(splitterViews, "splitter"); BOOST_TEST_CHECKPOINT("created splitter layer"); - armnn::OriginsDescriptor mergerViews(2); - mergerViews.SetViewOriginCoord(0, 0, 0); - mergerViews.SetViewOriginCoord(0, 1, 1); - mergerViews.SetViewOriginCoord(0, 2, 0); - mergerViews.SetViewOriginCoord(0, 3, 0); + armnn::OriginsDescriptor concatViews(2); + concatViews.SetViewOriginCoord(0, 0, 0); + concatViews.SetViewOriginCoord(0, 1, 1); + concatViews.SetViewOriginCoord(0, 2, 0); + concatViews.SetViewOriginCoord(0, 3, 0); - mergerViews.SetViewOriginCoord(1, 0, 0); - mergerViews.SetViewOriginCoord(1, 1, 0); - mergerViews.SetViewOriginCoord(1, 2, 0); - mergerViews.SetViewOriginCoord(1, 3, 0); + concatViews.SetViewOriginCoord(1, 0, 0); + concatViews.SetViewOriginCoord(1, 1, 0); + concatViews.SetViewOriginCoord(1, 2, 0); + concatViews.SetViewOriginCoord(1, 3, 0); - Layer* const merger = graph.AddLayer<MergerLayer>(mergerViews, "merger"); - BOOST_TEST_CHECKPOINT("created merger layer"); + Layer* const concat = graph.AddLayer<ConcatLayer>(concatViews, "concat"); + BOOST_TEST_CHECKPOINT("created concat layer"); Layer* const output = graph.AddLayer<OutputLayer>(0, "output"); // Adds connections. Connect(input, splitter, inputTensorInfo, 0, 0); BOOST_TEST_CHECKPOINT("connect input to splitter"); - Connect(splitter, merger, splitTensorInfo1, 0, 1); // The splitter & merger are connected up. - BOOST_TEST_CHECKPOINT("connect splitter[0] to merger[1]"); - Connect(splitter, merger, splitTensorInfo2, 1, 0); // So that the outputs are flipped round. - BOOST_TEST_CHECKPOINT("connect splitter[1] to merger[0]"); - Connect(merger, output, inputTensorInfo, 0, 0); - BOOST_TEST_CHECKPOINT("connect merger to output"); + Connect(splitter, concat, splitTensorInfo1, 0, 1); // The splitter & concat are connected up. + BOOST_TEST_CHECKPOINT("connect splitter[0] to concat[1]"); + Connect(splitter, concat, splitTensorInfo2, 1, 0); // So that the outputs are flipped round. + BOOST_TEST_CHECKPOINT("connect splitter[1] to concat[0]"); + Connect(concat, output, inputTensorInfo, 0, 0); + BOOST_TEST_CHECKPOINT("connect concat to output"); CreateTensorHandles(graph, factory); BOOST_TEST_CHECKPOINT("created tensor handles"); auto workloadSplitter = MakeAndCheckWorkload<SplitterWorkload>(*splitter, graph, factory); BOOST_TEST_CHECKPOINT("created splitter workload"); - auto workloadMerger = MakeAndCheckWorkload<MergerWorkload>(*merger, graph, factory); - BOOST_TEST_CHECKPOINT("created merger workload"); + auto workloadConcat = MakeAndCheckWorkload<ConcatWorkload>(*concat, graph, factory); + BOOST_TEST_CHECKPOINT("created concat workload"); - return {std::move(workloadSplitter), std::move(workloadMerger)}; + return {std::move(workloadSplitter), std::move(workloadConcat)}; } @@ -1053,10 +1053,10 @@ std::unique_ptr<MeanWorkload> CreateMeanWorkloadTest(armnn::IWorkloadFactory& fa return workload; } -template<typename MergerWorkload, armnn::DataType DataType> -std::unique_ptr<MergerWorkload> CreateMergerWorkloadTest(armnn::IWorkloadFactory& factory, - armnn::Graph& graph, - const armnn::TensorShape& outputShape, +template<typename ConcatWorkload, armnn::DataType DataType> +std::unique_ptr<ConcatWorkload> CreateConcatWorkloadTest(armnn::IWorkloadFactory &factory, + armnn::Graph &graph, + const armnn::TensorShape &outputShape, unsigned int concatAxis) { armnn::TensorInfo inputTensorInfo({ 2, 3, 2, 5 }, DataType); @@ -1073,26 +1073,26 @@ std::unique_ptr<MergerWorkload> CreateMergerWorkloadTest(armnn::IWorkloadFactory inputShapes.end(), concatAxis); - Layer* const merger = graph.AddLayer<MergerLayer>(descriptor, "merger"); - BOOST_TEST_CHECKPOINT("created merger layer"); + Layer* const concat = graph.AddLayer<ConcatLayer>(descriptor, "concat"); + BOOST_TEST_CHECKPOINT("created concat layer"); Layer* const output = graph.AddLayer<OutputLayer>(0, "output"); // Adds connections. - Connect(input0, merger, inputTensorInfo, 0, 0); - BOOST_TEST_CHECKPOINT("connect input0 to merger"); - Connect(input1, merger, inputTensorInfo, 0, 1); - BOOST_TEST_CHECKPOINT("connect input1 to merger"); - Connect(merger, output, outputTensorInfo, 0, 0); - BOOST_TEST_CHECKPOINT("connect merger to output"); + Connect(input0, concat, inputTensorInfo, 0, 0); + BOOST_TEST_CHECKPOINT("connect input0 to concat"); + Connect(input1, concat, inputTensorInfo, 0, 1); + BOOST_TEST_CHECKPOINT("connect input1 to concat"); + Connect(concat, output, outputTensorInfo, 0, 0); + BOOST_TEST_CHECKPOINT("connect concat to output"); CreateTensorHandles(graph, factory); BOOST_TEST_CHECKPOINT("created tensor handles"); - auto workloadMerger = MakeAndCheckWorkload<MergerWorkload>(*merger, graph, factory); - BOOST_TEST_CHECKPOINT("created merger workload"); + auto workloadConcat = MakeAndCheckWorkload<ConcatWorkload>(*concat, graph, factory); + BOOST_TEST_CHECKPOINT("created concat workload"); - return std::move(workloadMerger); + return std::move(workloadConcat); } template <typename PreCompiledWorkload, armnn::DataType dataType> |