diff options
Diffstat (limited to 'src/armnn/test/InferOutputTests.hpp')
-rw-r--r-- | src/armnn/test/InferOutputTests.hpp | 58 |
1 files changed, 57 insertions, 1 deletions
diff --git a/src/armnn/test/InferOutputTests.hpp b/src/armnn/test/InferOutputTests.hpp index b03449b568..70afbc9b3f 100644 --- a/src/armnn/test/InferOutputTests.hpp +++ b/src/armnn/test/InferOutputTests.hpp @@ -7,7 +7,6 @@ #include "TestUtils.hpp" - #include <Graph.hpp> #include <layers/ArgMinMaxLayer.hpp> #include <layers/BatchToSpaceNdLayer.hpp> @@ -530,6 +529,63 @@ void DepthwiseConvolution2dInferOutputShapeTest() BOOST_CHECK(expectedOutputShape == depthwiseConvolution2dLayer->InferOutputShapes(shapes).at(0)); } +// QLstm +void QLstmInferOutputShapeImpl(const armnn::QLstmDescriptor descriptor, + const std::vector<armnn::TensorShape>& inputShapes, + std::vector<armnn::TensorShape>& outputShapes) +{ + armnn::Graph graph; + armnn::QLstmLayer* const qLstmLayer = graph.AddLayer<armnn::QLstmLayer>(descriptor, "qLstm"); + outputShapes = qLstmLayer->InferOutputShapes(inputShapes); +} + +void QLstmInferOutputShapeTest() +{ + armnn::QLstmDescriptor descriptor; + descriptor.m_PeepholeEnabled = true; + descriptor.m_CifgEnabled = false; + descriptor.m_ProjectionEnabled = false; + + // Input shapes + const std::vector<unsigned int> inputShape{ 2, 5 }; + const std::vector<unsigned int> previousOutputInShape{ 2, 4 }; + const std::vector<unsigned int> previousCellStateInShape{ 2, 4 }; + + armnn::TensorShape inputTensorShape(2, inputShape.data()); + armnn::TensorShape previousOutputInTensorShape(2, previousOutputInShape.data()); + armnn::TensorShape previousCellStateInTensorShape(2, previousCellStateInShape.data()); + + std::vector<armnn::TensorShape> inShapes + { + inputTensorShape, + previousOutputInTensorShape, + previousCellStateInTensorShape + }; + + // Output shapes + const std::vector<unsigned int> outputStateOutShape{ 2, 4 }; + const std::vector<unsigned int> cellStateOutShape{ 2, 4 }; + const std::vector<unsigned int> outputShape{ 2, 4 }; + armnn::TensorShape outputStateOutTensorShape(2, outputShape.data()); + armnn::TensorShape cellStateOutTensorShape(2, cellStateOutShape.data()); + armnn::TensorShape outputTensorShape(2, outputShape.data()); + + std::vector<armnn::TensorShape> expectedOutShapes + { + outputStateOutTensorShape, + cellStateOutTensorShape, + outputTensorShape + }; + + std::vector<armnn::TensorShape> actualOutShapes; + BOOST_CHECK_NO_THROW(QLstmInferOutputShapeImpl(descriptor, inShapes, actualOutShapes)); + + BOOST_CHECK(actualOutShapes.size() == 3); + BOOST_CHECK(expectedOutShapes[0] == actualOutShapes[0]); + BOOST_CHECK(expectedOutShapes[1] == actualOutShapes[1]); + BOOST_CHECK(expectedOutShapes[2] == actualOutShapes[2]); +} + // QuantizedLstm void QuantizedLstmInferOutputShapeImpl(const std::vector<armnn::TensorShape>& inputShapes, std::vector<armnn::TensorShape>& outputShapes) |