From 5cdda351b4e12c5299173ec6b0fc75a948bdcda0 Mon Sep 17 00:00:00 2001 From: narpra01 Date: Mon, 19 Nov 2018 15:30:27 +0000 Subject: IVGCVSW-2105 - Unit tests for merger * Add LayerTests * Add WorkloadTests !android-nn-driver:166 Change-Id: I903461002879f60fc9f8ae929f18784e2d9b1fc1 --- src/armnn/test/CreateWorkload.hpp | 42 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 42 insertions(+) (limited to 'src/armnn/test/CreateWorkload.hpp') diff --git a/src/armnn/test/CreateWorkload.hpp b/src/armnn/test/CreateWorkload.hpp index 111df4b328..1a9bd56ac5 100644 --- a/src/armnn/test/CreateWorkload.hpp +++ b/src/armnn/test/CreateWorkload.hpp @@ -1050,4 +1050,46 @@ std::unique_ptr CreateMeanWorkloadTest(armnn::IWorkloadFactory& fa return workload; } +template +std::unique_ptr CreateMergerWorkloadTest(armnn::IWorkloadFactory& factory, + armnn::Graph& graph, + const armnn::TensorShape& outputShape, + unsigned int concatAxis) +{ + armnn::TensorInfo inputTensorInfo({ 2, 3, 2, 5 }, DataType); + armnn::TensorInfo outputTensorInfo(outputShape, DataType); + + // Constructs the graph. + Layer* const input0 = graph.AddLayer(0, "input0"); + Layer* const input1 = graph.AddLayer(1, "input1"); + armnn::OriginsDescriptor descriptor; + + std::vector inputShapes{{ 2, 3, 2, 5 }, { 2, 3, 2, 5 }}; + + descriptor = CreateMergerDescriptorForConcatenation(inputShapes.begin(), + inputShapes.end(), + concatAxis); + + Layer* const merger = graph.AddLayer(descriptor, "merger"); + BOOST_TEST_CHECKPOINT("created merger layer"); + + Layer* const output = graph.AddLayer(0, "output"); + + // Adds connections. + Connect(input0, merger, inputTensorInfo, 0, 0); + BOOST_TEST_CHECKPOINT("connect input0 to merger"); + Connect(input1, merger, inputTensorInfo, 0, 1); + BOOST_TEST_CHECKPOINT("connect input1 to merger"); + Connect(merger, output, outputTensorInfo, 0, 0); + BOOST_TEST_CHECKPOINT("connect merger to output"); + + CreateTensorHandles(graph, factory); + BOOST_TEST_CHECKPOINT("created tensor handles"); + + auto workloadMerger = MakeAndCheckWorkload(*merger, graph, factory); + BOOST_TEST_CHECKPOINT("created merger workload"); + + return std::move(workloadMerger); +} + } -- cgit v1.2.1