aboutsummaryrefslogtreecommitdiff
path: root/src/armnnTfParser/test/Split.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnnTfParser/test/Split.cpp')
-rw-r--r--src/armnnTfParser/test/Split.cpp226
1 files changed, 144 insertions, 82 deletions
diff --git a/src/armnnTfParser/test/Split.cpp b/src/armnnTfParser/test/Split.cpp
index de6b5d861e..87cd6544c9 100644
--- a/src/armnnTfParser/test/Split.cpp
+++ b/src/armnnTfParser/test/Split.cpp
@@ -11,93 +11,140 @@ BOOST_AUTO_TEST_SUITE(TensorflowParser)
struct SplitFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
{
- SplitFixture() {
- m_Prototext =
- "node { \n"
- " name: \"graphInput\" \n"
- " op: \"Placeholder\" \n"
- " attr { \n"
- " key: \"dtype\" \n"
- " value { \n"
- " type: DT_FLOAT \n"
- " } \n"
- " } \n"
- " attr { \n"
- " key: \"shape\" \n"
- " value { \n"
- " shape { \n"
- " } \n"
- " } \n"
- " } \n"
- " } \n"
- " node {"
- " name: \"splitInput\" \n"
- " op: \"Const\" \n"
- "attr {\n"
- " key: \"dtype\" \n"
- " value {"
- " type: DT_INT32"
- " }"
- "}"
- "attr {"
- " key: \"value\"\n"
- " value { "
- " tensor {"
- " dtype: DT_INT32"
- " tensor_shape {"
- "}"
- "int_val: 1"
- "}"
- "}"
- "}"
- "}"
- "node { \n"
- " name: \"Split\" \n"
- " op: \"Split\" \n"
- "input: \"graphInput\"\n"
- "input: \"splitInput\"\n"
- "attr { \n "
- "key: \"T\"\n"
- "value {\n"
- "type: DT_FLOAT\n"
- " }\n"
- "}\n"
- "\n"
- " attr { \n"
- " key: \"num_or_size_splits\" \n"
- " value { \n"
- " i:2 \n "
- " } \n"
- " } \n"
- "} \n"
- "node { \n"
- "name: \"Relu_1\"\n"
- "op: \"Relu\"\n"
- "input: \"Split:0\"\n"
- "attr { \n "
- "key: \"T\"\n"
- "value {\n"
- "type: DT_FLOAT\n"
- " }\n"
- "}\n"
- "}\n"
- "node { \n"
- "name: \"Relu_2\"\n"
- "op: \"Relu\"\n"
- "input: \"Split:1\"\n"
- "attr { \n "
- "key: \"T\"\n"
- "value {\n"
- "type: DT_FLOAT\n"
- " }\n"
- "}\n"
- "}\n";
+ SplitFixture(bool withDimZero=false) {
+ m_Prototext = R"(
+ node {
+ name: "graphInput"
+ op: "Placeholder"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "shape"
+ value {
+ shape {
+ }
+ }
+ }
+ }
+ node {
+ name: "graphInput2"
+ op: "Placeholder"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "shape"
+ value {
+ shape {
+ }
+ }
+ }
+ }
+ node {
+ name: "multiplication"
+ op : "Mul"
+ input: "graphInput"
+ input: "graphInput2"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ }
+ node {
+ name: "SplitInput"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ }
+ int_val: )";
- Setup( { { "graphInput", { 1, 2, 2 , 2} } },
+ if(withDimZero)
+ {
+ m_Prototext += std::to_string(3);
+ }
+ else
+ {
+ m_Prototext += std::to_string(1);
+ }
+
+ m_Prototext += R"(
+ }
+ }
+ }
+ }
+ node {
+ name: "Split"
+ op: "Split" )";
+ if(withDimZero)
+ {
+ m_Prototext += "input: \"SplitInput\"\n";
+ m_Prototext += "input: \"multiplication\"\n";
+ }
+ else
+ {
+ m_Prototext += "input: \"graphInput\"\n";
+ m_Prototext += "input: \"SplitInput\"\n";
+ }
+ m_Prototext += R"(
+ attr {
+ key: "num_or_size_splits"
+ value {
+ i: 2
+ }
+ }
+ }
+ node {
+ name: "Relu_1"
+ op: "Relu"
+ input: "Split:0"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ }
+ node {
+ name: "Relu_2"
+ op: "Relu"
+ input:"Split:1"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ } )";
+
+ Setup( { { "graphInput", { 1, 2, 2 , 2} } , { "graphInput2", { 1, 2, 2 , 2} }},
{ "Relu_1", "Relu_2" });
}
};
+struct InputFirstSplitFixture : SplitFixture
+{
+ InputFirstSplitFixture() : SplitFixture(true) {}
+};
+
BOOST_FIXTURE_TEST_CASE(ParseAxisOneSplitTwo, SplitFixture)
{
BOOST_TEST(
@@ -111,4 +158,19 @@ BOOST_FIXTURE_TEST_CASE(ParseAxisOneSplitTwo, SplitFixture)
{ "Relu_2", { 0.0f, 0.5f, 0.0f, 1.75f } } });
}
+BOOST_FIXTURE_TEST_CASE(ParseSplit, InputFirstSplitFixture)
+{
+
+ BOOST_TEST(
+ (m_Parser->GetNetworkOutputBindingInfo("Relu_1").second.GetShape() == armnn::TensorShape({ 1, 2, 2, 1 })));
+
+ BOOST_TEST(
+ (m_Parser->GetNetworkOutputBindingInfo("Relu_2").second.GetShape() == armnn::TensorShape({ 1, 2, 2, 1 })));
+
+ RunTest<4>({ { "graphInput", { -1.0f, -0.5f, 1.25f, -3.0f, 0.0f, 0.5f, -0.75f , 1.75f } } ,
+ { "graphInput2", { -1.0f, -0.5f, 1.25f, -3.0f, 0.0f, 0.5f, -0.75f , 1.75f } } },
+ { { "Relu_1", { 1.0f, 1.5625f, 0, 0.5625f } },
+ { "Relu_2", { 0.25, 9.0f, 0.25f, 3.0625f } } });
+}
+
BOOST_AUTO_TEST_SUITE_END()