diff options
Diffstat (limited to 'src/armnn/test/InferOutputTests.hpp')
-rw-r--r-- | src/armnn/test/InferOutputTests.hpp | 48 |
1 files changed, 47 insertions, 1 deletions
diff --git a/src/armnn/test/InferOutputTests.hpp b/src/armnn/test/InferOutputTests.hpp index 58a081a130..2dd2ff0e73 100644 --- a/src/armnn/test/InferOutputTests.hpp +++ b/src/armnn/test/InferOutputTests.hpp @@ -443,4 +443,50 @@ void DepthwiseConvolution2dInferOutputShapeTest() armnn::TensorShape expectedOutputShape(4, expectedOutputSizes.data()); BOOST_CHECK(expectedOutputShape == depthwiseConvolution2dLayer->InferOutputShapes(shapes).at(0)); -}
\ No newline at end of file +} + +// QuantizedLstm +void QuantizedLstmInferOutputShapeImpl(const std::vector<armnn::TensorShape>& inputShapes, + std::vector<armnn::TensorShape>& outputShapes) +{ + armnn::Graph graph; + armnn::QuantizedLstmLayer* const quantizedLstmLayer = graph.AddLayer<armnn::QuantizedLstmLayer>("quantizedLstm"); + outputShapes = quantizedLstmLayer->InferOutputShapes(inputShapes); +} + +void QuantizedLstmInferOutputShapeTest() +{ + // Input shapes + const std::vector<unsigned int> inputShape{ 2, 5 }; + const std::vector<unsigned int> previousCellStateInShape{ 2, 10 }; + const std::vector<unsigned int> previousOutputInShape{ 2, 10 }; + armnn::TensorShape inputTensorShape(2, inputShape.data()); + armnn::TensorShape previousCellStateInTensorShape(2, previousCellStateInShape.data()); + armnn::TensorShape previousOutputInTensorShape(2, previousOutputInShape.data()); + + std::vector<armnn::TensorShape> inShapes + { + inputTensorShape, + previousCellStateInTensorShape, + previousOutputInTensorShape + }; + + // Output shapes + const std::vector<unsigned int> cellStateOutShape{ 2, 10 }; + const std::vector<unsigned int> outputShape{ 2, 10 }; + armnn::TensorShape cellStateOutTensorShape(2, cellStateOutShape.data()); + armnn::TensorShape outputTensorShape(2, outputShape.data()); + + std::vector<armnn::TensorShape> expectedOutShapes + { + cellStateOutTensorShape, + outputTensorShape + }; + + std::vector<armnn::TensorShape> actualOutShapes; + BOOST_CHECK_NO_THROW(QuantizedLstmInferOutputShapeImpl(inShapes, actualOutShapes)); + + BOOST_CHECK(actualOutShapes.size() == 2); + BOOST_CHECK(expectedOutShapes[0] == actualOutShapes[0]); + BOOST_CHECK(expectedOutShapes[1] == actualOutShapes[1]); +} |