aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/test
diff options
context:
space:
mode:
authorMatthew Jackson <matthew.jackson@arm.com>2019-07-11 12:07:09 +0100
committerÁron Virginás-Tar <aron.virginas-tar@arm.com>2019-07-17 14:31:02 +0000
commit81e601c5a5ebf3de3dd6418942708158de50252a (patch)
tree48307f6d49639d7bc9bfa2db96a2de33d1095861 /src/armnn/test
parent01bfd1781a18508577b9135408465ee76f346ae5 (diff)
downloadarmnn-81e601c5a5ebf3de3dd6418942708158de50252a.tar.gz
IVGCVSW-3419 Add reference workload support for the new Stack layer
* Added reference workload for the Stack layer * Added factory methods * Added validation support * Added unit tests Signed-off-by: Matthew Jackson <matthew.jackson@arm.com> Change-Id: Ib14b72c15f53a2a2ca152afc357ce2aa405ccc88
Diffstat (limited to 'src/armnn/test')
-rw-r--r--src/armnn/test/CreateWorkload.hpp47
1 files changed, 47 insertions, 0 deletions
diff --git a/src/armnn/test/CreateWorkload.hpp b/src/armnn/test/CreateWorkload.hpp
index 834aa0e620..4181db26d0 100644
--- a/src/armnn/test/CreateWorkload.hpp
+++ b/src/armnn/test/CreateWorkload.hpp
@@ -1376,4 +1376,51 @@ std::unique_ptr<SpaceToDepthWorkload> CreateSpaceToDepthWorkloadTest(armnn::IWor
return workload;
}
+template <typename StackWorkload>
+std::unique_ptr<StackWorkload> CreateStackWorkloadTest(armnn::IWorkloadFactory& factory,
+ armnn::Graph& graph,
+ const armnn::TensorShape& inputShape,
+ const armnn::TensorShape& outputShape,
+ unsigned int axis,
+ unsigned int numInputs,
+ armnn::DataType dataType)
+{
+ armnn::TensorInfo inputTensorInfo(inputShape, dataType);
+ armnn::TensorInfo outputTensorInfo(outputShape, dataType);
+
+ // Constructs the Stack layer.
+ armnn::StackDescriptor descriptor(axis, numInputs, inputShape);
+ Layer* const stackLayer = graph.AddLayer<StackLayer>(descriptor, "stack");
+ BOOST_CHECK(stackLayer != nullptr);
+
+ // Constructs layer inputs and output.
+ std::vector<Layer*> inputs;
+ for (unsigned int i=0; i<numInputs; ++i)
+ {
+ inputs.push_back(graph.AddLayer<InputLayer>(
+ static_cast<int>(i),
+ ("input" + std::to_string(i)).c_str()
+ ));
+ BOOST_CHECK(inputs[i] != nullptr);
+ }
+ Layer* const output = graph.AddLayer<OutputLayer>(0, "output");
+ BOOST_CHECK(output != nullptr);
+
+ // Adds connections.
+ for (unsigned int i=0; i<numInputs; ++i)
+ {
+ Connect(inputs[i], stackLayer, inputTensorInfo, 0, i);
+ }
+ Connect(stackLayer, output, outputTensorInfo, 0, 0);
+
+ CreateTensorHandles(graph, factory);
+
+ auto stackWorkload = MakeAndCheckWorkload<StackWorkload>(*stackLayer, graph, factory);
+ StackQueueDescriptor queueDescriptor = stackWorkload->GetData();
+ BOOST_TEST(queueDescriptor.m_Inputs.size() == numInputs);
+ BOOST_TEST(queueDescriptor.m_Outputs.size() == 1);
+
+ return stackWorkload;
+}
+
} // Anonymous namespace