aboutsummaryrefslogtreecommitdiff
path: root/src/armnnTfParser/test/ConcatOfConcats.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnnTfParser/test/ConcatOfConcats.cpp')
-rw-r--r--src/armnnTfParser/test/ConcatOfConcats.cpp316
1 files changed, 316 insertions, 0 deletions
diff --git a/src/armnnTfParser/test/ConcatOfConcats.cpp b/src/armnnTfParser/test/ConcatOfConcats.cpp
new file mode 100644
index 0000000000..7316b9f1ac
--- /dev/null
+++ b/src/armnnTfParser/test/ConcatOfConcats.cpp
@@ -0,0 +1,316 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// See LICENSE file in the project root for full license information.
+//
+
+#include <boost/test/unit_test.hpp>
+#include "armnnTfParser/ITfParser.hpp"
+#include "ParserPrototxtFixture.hpp"
+
+BOOST_AUTO_TEST_SUITE(TensorflowParser)
+
+struct ConcatOfConcatsFixture : public ParserPrototxtFixture<armnnTfParser::ITfParser>
+{
+ explicit ConcatOfConcatsFixture(const armnn::TensorShape& inputShape0, const armnn::TensorShape& inputShape1,
+ const armnn::TensorShape& inputShape2, const armnn::TensorShape& inputShape3,
+ unsigned int concatDim)
+ {
+ m_Prototext = R"(
+ node {
+ name: "graphInput0"
+ op: "Placeholder"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "shape"
+ value {
+ shape {
+ }
+ }
+ }
+ }
+ node {
+ name: "graphInput1"
+ op: "Placeholder"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "shape"
+ value {
+ shape {
+ }
+ }
+ }
+ }
+ node {
+ name: "graphInput2"
+ op: "Placeholder"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "shape"
+ value {
+ shape {
+ }
+ }
+ }
+ }
+ node {
+ name: "graphInput3"
+ op: "Placeholder"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "shape"
+ value {
+ shape {
+ }
+ }
+ }
+ }
+ node {
+ name: "Relu"
+ op: "Relu"
+ input: "graphInput0"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ }
+ node {
+ name: "Relu_1"
+ op: "Relu"
+ input: "graphInput1"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ }
+ node {
+ name: "Relu_2"
+ op: "Relu"
+ input: "graphInput2"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ }
+ node {
+ name: "Relu_3"
+ op: "Relu"
+ input: "graphInput3"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ }
+ node {
+ name: "concat/axis"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ }
+ int_val: )";
+ m_Prototext += std::to_string(concatDim);
+ m_Prototext += R"(
+ }
+ }
+ }
+ }
+ node {
+ name: "concat"
+ op: "ConcatV2"
+ input: "Relu"
+ input: "Relu_1"
+ input: "concat/axis"
+ attr {
+ key: "N"
+ value {
+ i: 2
+ }
+ }
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "Tidx"
+ value {
+ type: DT_INT32
+ }
+ }
+ }
+ node {
+ name: "concat_1/axis"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ }
+ int_val: )";
+ m_Prototext += std::to_string(concatDim);
+ m_Prototext += R"(
+ }
+ }
+ }
+ }
+ node {
+ name: "concat_1"
+ op: "ConcatV2"
+ input: "Relu_2"
+ input: "Relu_3"
+ input: "concat_1/axis"
+ attr {
+ key: "N"
+ value {
+ i: 2
+ }
+ }
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "Tidx"
+ value {
+ type: DT_INT32
+ }
+ }
+ }
+ node {
+ name: "concat_2/axis"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ }
+ int_val: )";
+ m_Prototext += std::to_string(concatDim);
+ m_Prototext += R"(
+ }
+ }
+ }
+ }
+ node {
+ name: "concat_2"
+ op: "ConcatV2"
+ input: "concat"
+ input: "concat_1"
+ input: "concat_2/axis"
+ attr {
+ key: "N"
+ value {
+ i: 2
+ }
+ }
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "Tidx"
+ value {
+ type: DT_INT32
+ }
+ }
+ }
+ )";
+
+ Setup({{ "graphInput0", inputShape0 },
+ { "graphInput1", inputShape1 },
+ { "graphInput2", inputShape2 },
+ { "graphInput3", inputShape3}}, {"concat_2"});
+ }
+};
+
+struct ConcatOfConcatsFixtureNCHW : ConcatOfConcatsFixture
+{
+ ConcatOfConcatsFixtureNCHW() : ConcatOfConcatsFixture({ 1, 1, 2, 2 }, { 1, 1, 2, 2 }, { 1, 1, 2, 2 },
+ { 1, 1, 2, 2 }, 1 ) {}
+};
+
+struct ConcatOfConcatsFixtureNHWC : ConcatOfConcatsFixture
+{
+ ConcatOfConcatsFixtureNHWC() : ConcatOfConcatsFixture({ 1, 1, 2, 2 }, { 1, 1, 2, 2 }, { 1, 1, 2, 2 },
+ { 1, 1, 2, 2 }, 3 ) {}
+};
+
+BOOST_FIXTURE_TEST_CASE(ParseConcatOfConcatsNCHW, ConcatOfConcatsFixtureNCHW)
+{
+ RunTest<4>({{"graphInput0", {0.0, 1.0, 2.0, 3.0}},
+ {"graphInput1", {4.0, 5.0, 6.0, 7.0}},
+ {"graphInput2", {8.0, 9.0, 10.0, 11.0}},
+ {"graphInput3", {12.0, 13.0, 14.0, 15.0}}},
+ {{"concat_2", { 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0,
+ 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0 }}});
+}
+
+BOOST_FIXTURE_TEST_CASE(ParseConcatOfConcatsNHWC, ConcatOfConcatsFixtureNHWC)
+{
+ RunTest<4>({{"graphInput0", {0.0, 1.0, 2.0, 3.0}},
+ {"graphInput1", {4.0, 5.0, 6.0, 7.0}},
+ {"graphInput2", {8.0, 9.0, 10.0, 11.0}},
+ {"graphInput3", {12.0, 13.0, 14.0, 15.0}}},
+ {{"concat_2", { 0.0, 1.0, 4.0, 5.0, 8.0, 9.0, 12.0, 13.0,
+ 2.0, 3.0, 6.0, 7.0, 10.0, 11.0, 14.0, 15.0 }}});
+}
+
+BOOST_AUTO_TEST_SUITE_END()