aboutsummaryrefslogtreecommitdiff
path: root/src/armnnTfParser/test/Split.cpp
diff options
context:
space:
mode:
authorSadik Armagan <sadik.armagan@arm.com>2018-12-27 11:23:44 +0000
committerSaoirse Stewart Arm <saoirse.stewart@arm.com>2019-01-07 10:42:57 +0000
commit2ad6cb486164ff3aabe4e9ecabc47f08da48da35 (patch)
tree57ad464aa77179d9e93d7e0c26830d67464667a6 /src/armnnTfParser/test/Split.cpp
parent747ef82c88f9afe14a8b80b6b3b34118353e97f2 (diff)
downloadarmnn-2ad6cb486164ff3aabe4e9ecabc47f08da48da35.tar.gz
IVGCVSW-2384 Add Split parser function to Tensor flow parser
* Added Unit test * Updated TensorFlowSupport.md file Change-Id: I5f07de5e91ffb681c0ad17c7c73ee0326e7f1e0a
Diffstat (limited to 'src/armnnTfParser/test/Split.cpp')
-rw-r--r--src/armnnTfParser/test/Split.cpp114
1 files changed, 114 insertions, 0 deletions
diff --git a/src/armnnTfParser/test/Split.cpp b/src/armnnTfParser/test/Split.cpp
new file mode 100644
index 0000000000..de6b5d861e
--- /dev/null
+++ b/src/armnnTfParser/test/Split.cpp
@@ -0,0 +1,114 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include <boost/test/unit_test.hpp>
+#include "armnnTfParser/ITfParser.hpp"
+#include "ParserPrototxtFixture.hpp"
+
+BOOST_AUTO_TEST_SUITE(TensorflowParser)
+
+struct SplitFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
+{
+ SplitFixture() {
+ m_Prototext =
+ "node { \n"
+ " name: \"graphInput\" \n"
+ " op: \"Placeholder\" \n"
+ " attr { \n"
+ " key: \"dtype\" \n"
+ " value { \n"
+ " type: DT_FLOAT \n"
+ " } \n"
+ " } \n"
+ " attr { \n"
+ " key: \"shape\" \n"
+ " value { \n"
+ " shape { \n"
+ " } \n"
+ " } \n"
+ " } \n"
+ " } \n"
+ " node {"
+ " name: \"splitInput\" \n"
+ " op: \"Const\" \n"
+ "attr {\n"
+ " key: \"dtype\" \n"
+ " value {"
+ " type: DT_INT32"
+ " }"
+ "}"
+ "attr {"
+ " key: \"value\"\n"
+ " value { "
+ " tensor {"
+ " dtype: DT_INT32"
+ " tensor_shape {"
+ "}"
+ "int_val: 1"
+ "}"
+ "}"
+ "}"
+ "}"
+ "node { \n"
+ " name: \"Split\" \n"
+ " op: \"Split\" \n"
+ "input: \"graphInput\"\n"
+ "input: \"splitInput\"\n"
+ "attr { \n "
+ "key: \"T\"\n"
+ "value {\n"
+ "type: DT_FLOAT\n"
+ " }\n"
+ "}\n"
+ "\n"
+ " attr { \n"
+ " key: \"num_or_size_splits\" \n"
+ " value { \n"
+ " i:2 \n "
+ " } \n"
+ " } \n"
+ "} \n"
+ "node { \n"
+ "name: \"Relu_1\"\n"
+ "op: \"Relu\"\n"
+ "input: \"Split:0\"\n"
+ "attr { \n "
+ "key: \"T\"\n"
+ "value {\n"
+ "type: DT_FLOAT\n"
+ " }\n"
+ "}\n"
+ "}\n"
+ "node { \n"
+ "name: \"Relu_2\"\n"
+ "op: \"Relu\"\n"
+ "input: \"Split:1\"\n"
+ "attr { \n "
+ "key: \"T\"\n"
+ "value {\n"
+ "type: DT_FLOAT\n"
+ " }\n"
+ "}\n"
+ "}\n";
+
+ Setup( { { "graphInput", { 1, 2, 2 , 2} } },
+ { "Relu_1", "Relu_2" });
+ }
+};
+
+BOOST_FIXTURE_TEST_CASE(ParseAxisOneSplitTwo, SplitFixture)
+{
+ BOOST_TEST(
+ (m_Parser->GetNetworkOutputBindingInfo("Relu_1").second.GetShape() == armnn::TensorShape({ 1, 1, 2, 2 })));
+
+ BOOST_TEST(
+ (m_Parser->GetNetworkOutputBindingInfo("Relu_2").second.GetShape() == armnn::TensorShape({ 1, 1, 2, 2 })));
+
+ RunTest<4>({ { "graphInput", { -1.0f, -0.5f, 1.25f, -3.0f, 0.0f, 0.5f, -0.75f, 1.75f } } },
+ { { "Relu_1", { 0.0f, 0.0f, 1.25f, 0.0f } },
+ { "Relu_2", { 0.0f, 0.5f, 0.0f, 1.75f } } });
+}
+
+BOOST_AUTO_TEST_SUITE_END()