aboutsummaryrefslogtreecommitdiff
path: root/src/armnnTfLiteParser/test/OutputShapeOfSqueeze.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnnTfLiteParser/test/OutputShapeOfSqueeze.cpp')
-rw-r--r--src/armnnTfLiteParser/test/OutputShapeOfSqueeze.cpp61
1 files changed, 61 insertions, 0 deletions
diff --git a/src/armnnTfLiteParser/test/OutputShapeOfSqueeze.cpp b/src/armnnTfLiteParser/test/OutputShapeOfSqueeze.cpp
new file mode 100644
index 0000000000..590675b46c
--- /dev/null
+++ b/src/armnnTfLiteParser/test/OutputShapeOfSqueeze.cpp
@@ -0,0 +1,61 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// See LICENSE file in the project root for full license information.
+//
+
+#include <boost/test/unit_test.hpp>
+#include "../TfLiteParser.hpp"
+#include <iostream>
+#include <string>
+
+struct TfLiteParserFixture
+{
+
+ armnnTfLiteParser::TfLiteParser m_Parser;
+ unsigned int m_InputShape[4];
+
+ TfLiteParserFixture() : m_Parser( ), m_InputShape { 1, 2, 2, 1 } {
+ m_Parser.Create();
+ }
+ ~TfLiteParserFixture() { }
+
+};
+
+BOOST_AUTO_TEST_SUITE(TensorflowLiteParser);
+
+
+BOOST_FIXTURE_TEST_CASE( EmptySqueezeDims_OutputWithAllDimensionsSqueezed, TfLiteParserFixture )
+{
+
+ std::vector<uint32_t> squeezeDims = { };
+
+ armnn::TensorInfo inputTensorInfo = armnn::TensorInfo(4, m_InputShape, armnn::DataType::Float32);
+ armnn::TensorInfo outputTensorInfo = m_Parser.OutputShapeOfSqueeze(squeezeDims, inputTensorInfo);
+ BOOST_TEST(outputTensorInfo.GetNumElements() == 4);
+ BOOST_TEST(outputTensorInfo.GetNumDimensions() == 2);
+ BOOST_TEST((outputTensorInfo.GetShape() == armnn::TensorShape({ 2, 2 })));
+};
+
+BOOST_FIXTURE_TEST_CASE( SqueezeDimsNotIncludingSizeOneDimensions_NoDimensionsSqueezedInOutput, TfLiteParserFixture )
+{
+ std::vector<uint32_t> squeezeDims = { 1, 2 };
+
+ armnn::TensorInfo inputTensorInfo = armnn::TensorInfo(4, m_InputShape, armnn::DataType::Float32);
+ armnn::TensorInfo outputTensorInfo = m_Parser.OutputShapeOfSqueeze(squeezeDims, inputTensorInfo);
+ BOOST_TEST(outputTensorInfo.GetNumElements() == 4);
+ BOOST_TEST(outputTensorInfo.GetNumDimensions() == 4);
+ BOOST_TEST((outputTensorInfo.GetShape() == armnn::TensorShape({ 1, 2, 2, 1 })));
+};
+
+BOOST_FIXTURE_TEST_CASE( SqueezeDimsRangePartial_OutputWithDimensionsWithinRangeSqueezed, TfLiteParserFixture )
+{
+ std::vector<uint32_t> squeezeDims = { 1, 3 };
+
+ armnn::TensorInfo inputTensorInfo = armnn::TensorInfo(4, m_InputShape, armnn::DataType::Float32);
+ armnn::TensorInfo outputTensorInfo = m_Parser.OutputShapeOfSqueeze(squeezeDims, inputTensorInfo);
+ BOOST_TEST(outputTensorInfo.GetNumElements() == 4);
+ BOOST_TEST(outputTensorInfo.GetNumDimensions() == 3);
+ BOOST_TEST((outputTensorInfo.GetShape() == armnn::TensorShape({ 1, 2, 2 })));
+};
+
+BOOST_AUTO_TEST_SUITE_END(); \ No newline at end of file