aboutsummaryrefslogtreecommitdiff
path: root/src/armnnTfLiteParser/test/InputOutputTensorNames.cpp
diff options
context:
space:
mode:
authortelsoa01 <telmo.soares@arm.com>2018-08-31 09:22:23 +0100
committertelsoa01 <telmo.soares@arm.com>2018-08-31 09:22:23 +0100
commitc577f2c6a3b4ddb6ba87a882723c53a248afbeba (patch)
treebd7d4c148df27f8be6649d313efb24f536b7cf34 /src/armnnTfLiteParser/test/InputOutputTensorNames.cpp
parent4c7098bfeab1ffe1cdc77f6c15548d3e73274746 (diff)
downloadarmnn-c577f2c6a3b4ddb6ba87a882723c53a248afbeba.tar.gz
Release 18.08
Diffstat (limited to 'src/armnnTfLiteParser/test/InputOutputTensorNames.cpp')
-rw-r--r--src/armnnTfLiteParser/test/InputOutputTensorNames.cpp138
1 files changed, 138 insertions, 0 deletions
diff --git a/src/armnnTfLiteParser/test/InputOutputTensorNames.cpp b/src/armnnTfLiteParser/test/InputOutputTensorNames.cpp
new file mode 100644
index 0000000000..fc88a4e58d
--- /dev/null
+++ b/src/armnnTfLiteParser/test/InputOutputTensorNames.cpp
@@ -0,0 +1,138 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// See LICENSE file in the project root for full license information.
+//
+
+#include <boost/test/unit_test.hpp>
+#include "ParserFlatbuffersFixture.hpp"
+#include "../TfLiteParser.hpp"
+
+BOOST_AUTO_TEST_SUITE(TensorflowLiteParser)
+
+struct EmptyNetworkFixture : public ParserFlatbuffersFixture
+{
+ explicit EmptyNetworkFixture() {
+ m_JsonString = R"(
+ {
+ "version": 3,
+ "operator_codes": [],
+ "subgraphs": [ {} ]
+ })";
+ }
+};
+
+BOOST_FIXTURE_TEST_CASE(EmptyNetworkHasNoInputsAndOutputs, EmptyNetworkFixture)
+{
+ Setup();
+ BOOST_TEST(m_Parser->GetSubgraphCount() == 1);
+ BOOST_TEST(m_Parser->GetSubgraphInputTensorNames(0).size() == 0);
+ BOOST_TEST(m_Parser->GetSubgraphOutputTensorNames(0).size() == 0);
+}
+
+struct MissingTensorsFixture : public ParserFlatbuffersFixture
+{
+ explicit MissingTensorsFixture()
+ {
+ m_JsonString = R"(
+ {
+ "version": 3,
+ "operator_codes": [],
+ "subgraphs": [{
+ "inputs" : [ 0, 1 ],
+ "outputs" : [ 2, 3 ],
+ }]
+ })";
+ }
+};
+
+BOOST_FIXTURE_TEST_CASE(MissingTensorsThrowException, MissingTensorsFixture)
+{
+ // this throws because it cannot do the input output tensor connections
+ BOOST_CHECK_THROW(Setup(), armnn::ParseException);
+}
+
+struct InvalidTensorsFixture : public ParserFlatbuffersFixture
+{
+ explicit InvalidTensorsFixture()
+ {
+ m_JsonString = R"(
+ {
+ "version": 3,
+ "operator_codes": [ ],
+ "subgraphs": [{
+ "tensors": [ {}, {}, {}, {} ],
+ "inputs" : [ 0, 1 ],
+ "outputs" : [ 2, 3 ],
+ }]
+ })";
+ }
+};
+
+BOOST_FIXTURE_TEST_CASE(InvalidTensorsThrowException, InvalidTensorsFixture)
+{
+ // this throws because it cannot do the input output tensor connections
+ BOOST_CHECK_THROW(Setup(), armnn::InvalidArgumentException);
+}
+
+struct ValidTensorsFixture : public ParserFlatbuffersFixture
+{
+ explicit ValidTensorsFixture()
+ {
+ m_JsonString = R"(
+ {
+ "version": 3,
+ "operator_codes": [ { "builtin_code": "AVERAGE_POOL_2D" } ],
+ "subgraphs": [{
+ "tensors": [ {
+ "shape": [ 1, 1, 1, 1 ],
+ "type": "FLOAT32",
+ "name": "In",
+ "buffer": 0,
+ }, {
+ "shape": [ 1, 1, 1, 1 ],
+ "type": "FLOAT32",
+ "name": "Out",
+ "buffer": 1,
+ }],
+ "inputs" : [ 0 ],
+ "outputs" : [ 1 ],
+ "operators": [{
+ "opcode_index": 0,
+ "inputs": [ 0 ],
+ "outputs": [ 1 ],
+ "builtin_options_type": "Pool2DOptions",
+ "builtin_options":
+ {
+ "padding": "VALID",
+ "stride_w": 1,
+ "stride_h": 1,
+ "filter_width": 1,
+ "filter_height": 1,
+ "fused_activation_function": "NONE"
+ },
+ "custom_options_format": "FLEXBUFFERS"
+ }]
+ }]
+ })";
+ }
+};
+
+BOOST_FIXTURE_TEST_CASE(GetValidInputOutputTensorNames, ValidTensorsFixture)
+{
+ Setup();
+ BOOST_CHECK_EQUAL(m_Parser->GetSubgraphInputTensorNames(0).size(), 1u);
+ BOOST_CHECK_EQUAL(m_Parser->GetSubgraphOutputTensorNames(0).size(), 1u);
+ BOOST_CHECK_EQUAL(m_Parser->GetSubgraphInputTensorNames(0)[0], "In");
+ BOOST_CHECK_EQUAL(m_Parser->GetSubgraphOutputTensorNames(0)[0], "Out");
+}
+
+BOOST_FIXTURE_TEST_CASE(ThrowIfSubgraphIdInvalidForInOutNames, ValidTensorsFixture)
+{
+ Setup();
+
+ // these throw because of the invalid subgraph id
+ BOOST_CHECK_THROW(m_Parser->GetSubgraphInputTensorNames(1), armnn::ParseException);
+ BOOST_CHECK_THROW(m_Parser->GetSubgraphOutputTensorNames(1), armnn::ParseException);
+}
+
+BOOST_AUTO_TEST_SUITE_END()