aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/test/CreateWorkload.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/test/CreateWorkload.hpp')
-rw-r--r--src/armnn/test/CreateWorkload.hpp47
1 files changed, 47 insertions, 0 deletions
diff --git a/src/armnn/test/CreateWorkload.hpp b/src/armnn/test/CreateWorkload.hpp
index 834aa0e620..4181db26d0 100644
--- a/src/armnn/test/CreateWorkload.hpp
+++ b/src/armnn/test/CreateWorkload.hpp
@@ -1376,4 +1376,51 @@ std::unique_ptr<SpaceToDepthWorkload> CreateSpaceToDepthWorkloadTest(armnn::IWor
return workload;
}
+template <typename StackWorkload>
+std::unique_ptr<StackWorkload> CreateStackWorkloadTest(armnn::IWorkloadFactory& factory,
+ armnn::Graph& graph,
+ const armnn::TensorShape& inputShape,
+ const armnn::TensorShape& outputShape,
+ unsigned int axis,
+ unsigned int numInputs,
+ armnn::DataType dataType)
+{
+ armnn::TensorInfo inputTensorInfo(inputShape, dataType);
+ armnn::TensorInfo outputTensorInfo(outputShape, dataType);
+
+ // Constructs the Stack layer.
+ armnn::StackDescriptor descriptor(axis, numInputs, inputShape);
+ Layer* const stackLayer = graph.AddLayer<StackLayer>(descriptor, "stack");
+ BOOST_CHECK(stackLayer != nullptr);
+
+ // Constructs layer inputs and output.
+ std::vector<Layer*> inputs;
+ for (unsigned int i=0; i<numInputs; ++i)
+ {
+ inputs.push_back(graph.AddLayer<InputLayer>(
+ static_cast<int>(i),
+ ("input" + std::to_string(i)).c_str()
+ ));
+ BOOST_CHECK(inputs[i] != nullptr);
+ }
+ Layer* const output = graph.AddLayer<OutputLayer>(0, "output");
+ BOOST_CHECK(output != nullptr);
+
+ // Adds connections.
+ for (unsigned int i=0; i<numInputs; ++i)
+ {
+ Connect(inputs[i], stackLayer, inputTensorInfo, 0, i);
+ }
+ Connect(stackLayer, output, outputTensorInfo, 0, 0);
+
+ CreateTensorHandles(graph, factory);
+
+ auto stackWorkload = MakeAndCheckWorkload<StackWorkload>(*stackLayer, graph, factory);
+ StackQueueDescriptor queueDescriptor = stackWorkload->GetData();
+ BOOST_TEST(queueDescriptor.m_Inputs.size() == numInputs);
+ BOOST_TEST(queueDescriptor.m_Outputs.size() == 1);
+
+ return stackWorkload;
+}
+
} // Anonymous namespace