diff options
author | Matthew Jackson <matthew.jackson@arm.com> | 2019-07-11 12:07:09 +0100 |
---|---|---|
committer | Áron Virginás-Tar <aron.virginas-tar@arm.com> | 2019-07-17 14:31:02 +0000 |
commit | 81e601c5a5ebf3de3dd6418942708158de50252a (patch) | |
tree | 48307f6d49639d7bc9bfa2db96a2de33d1095861 /src/armnn | |
parent | 01bfd1781a18508577b9135408465ee76f346ae5 (diff) | |
download | armnn-81e601c5a5ebf3de3dd6418942708158de50252a.tar.gz |
IVGCVSW-3419 Add reference workload support for the new Stack layer
* Added reference workload for the Stack layer
* Added factory methods
* Added validation support
* Added unit tests
Signed-off-by: Matthew Jackson <matthew.jackson@arm.com>
Change-Id: Ib14b72c15f53a2a2ca152afc357ce2aa405ccc88
Diffstat (limited to 'src/armnn')
-rw-r--r-- | src/armnn/test/CreateWorkload.hpp | 47 |
1 files changed, 47 insertions, 0 deletions
diff --git a/src/armnn/test/CreateWorkload.hpp b/src/armnn/test/CreateWorkload.hpp index 834aa0e620..4181db26d0 100644 --- a/src/armnn/test/CreateWorkload.hpp +++ b/src/armnn/test/CreateWorkload.hpp @@ -1376,4 +1376,51 @@ std::unique_ptr<SpaceToDepthWorkload> CreateSpaceToDepthWorkloadTest(armnn::IWor return workload; } +template <typename StackWorkload> +std::unique_ptr<StackWorkload> CreateStackWorkloadTest(armnn::IWorkloadFactory& factory, + armnn::Graph& graph, + const armnn::TensorShape& inputShape, + const armnn::TensorShape& outputShape, + unsigned int axis, + unsigned int numInputs, + armnn::DataType dataType) +{ + armnn::TensorInfo inputTensorInfo(inputShape, dataType); + armnn::TensorInfo outputTensorInfo(outputShape, dataType); + + // Constructs the Stack layer. + armnn::StackDescriptor descriptor(axis, numInputs, inputShape); + Layer* const stackLayer = graph.AddLayer<StackLayer>(descriptor, "stack"); + BOOST_CHECK(stackLayer != nullptr); + + // Constructs layer inputs and output. + std::vector<Layer*> inputs; + for (unsigned int i=0; i<numInputs; ++i) + { + inputs.push_back(graph.AddLayer<InputLayer>( + static_cast<int>(i), + ("input" + std::to_string(i)).c_str() + )); + BOOST_CHECK(inputs[i] != nullptr); + } + Layer* const output = graph.AddLayer<OutputLayer>(0, "output"); + BOOST_CHECK(output != nullptr); + + // Adds connections. + for (unsigned int i=0; i<numInputs; ++i) + { + Connect(inputs[i], stackLayer, inputTensorInfo, 0, i); + } + Connect(stackLayer, output, outputTensorInfo, 0, 0); + + CreateTensorHandles(graph, factory); + + auto stackWorkload = MakeAndCheckWorkload<StackWorkload>(*stackLayer, graph, factory); + StackQueueDescriptor queueDescriptor = stackWorkload->GetData(); + BOOST_TEST(queueDescriptor.m_Inputs.size() == numInputs); + BOOST_TEST(queueDescriptor.m_Outputs.size() == 1); + + return stackWorkload; +} + } // Anonymous namespace |