aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/test/LayerValidateOutputTest.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/test/LayerValidateOutputTest.cpp')
-rw-r--r--src/armnn/test/LayerValidateOutputTest.cpp23
1 files changed, 23 insertions, 0 deletions
diff --git a/src/armnn/test/LayerValidateOutputTest.cpp b/src/armnn/test/LayerValidateOutputTest.cpp
index acefd51110..d47959cb65 100644
--- a/src/armnn/test/LayerValidateOutputTest.cpp
+++ b/src/armnn/test/LayerValidateOutputTest.cpp
@@ -58,4 +58,27 @@ BOOST_AUTO_TEST_CASE(TestSpaceToDepthInferOutputShape)
BOOST_CHECK(expectedShape == spaceToDepthLayer->InferOutputShapes(shapes).at(0));
}
+BOOST_AUTO_TEST_CASE(TestPreluInferOutputShape)
+{
+ armnn::Graph graph;
+
+ armnn::PreluLayer* const preluLayer = graph.AddLayer<armnn::PreluLayer>("prelu");
+
+ std::vector<armnn::TensorShape> inputShapes
+ {
+ { 4, 1, 2 }, // Input shape
+ { 5, 4, 3, 1} // Alpha shape
+ };
+
+ const std::vector<armnn::TensorShape> expectedOutputShapes
+ {
+ { 5, 4, 3, 2 } // Output shape
+ };
+
+ const std::vector<armnn::TensorShape> outputShapes = preluLayer->InferOutputShapes(inputShapes);
+
+ BOOST_CHECK(outputShapes.size() == 1);
+ BOOST_CHECK(outputShapes[0] == expectedOutputShapes[0]);
+}
+
BOOST_AUTO_TEST_SUITE_END()