aboutsummaryrefslogtreecommitdiff
path: root/src/armnnTfParser/test/Convolution2d.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnnTfParser/test/Convolution2d.cpp')
-rw-r--r--src/armnnTfParser/test/Convolution2d.cpp306
1 files changed, 185 insertions, 121 deletions
diff --git a/src/armnnTfParser/test/Convolution2d.cpp b/src/armnnTfParser/test/Convolution2d.cpp
index cf714894a2..c58615f990 100644
--- a/src/armnnTfParser/test/Convolution2d.cpp
+++ b/src/armnnTfParser/test/Convolution2d.cpp
@@ -37,7 +37,22 @@ struct Convolution2dFixture : public armnnUtils::ParserPrototxtFixture<armnnTfPa
" i: " + std::to_string(stride) + " \n");
}
- std::string dilationString = std::to_string(dilation);
+ std::string dilationString;
+ if (dataLayout == "NHWC")
+ {
+ dilationString.append(" i: 1 \n"
+ " i: " + std::to_string(dilation) + " \n"
+ " i: " + std::to_string(dilation) + " \n"
+ " i: 1 \n");
+ }
+ else // dataLayout == "NCHW"
+ {
+ dilationString.append(" i: 1 \n"
+ " i: 1 \n"
+ " i: " + std::to_string(dilation) + " \n"
+ " i: " + std::to_string(dilation) + " \n");
+ }
+
m_Prototext = "node { \n"
" name: \"graphInput\" \n"
" op: \"Placeholder\" \n"
@@ -130,16 +145,10 @@ struct Convolution2dFixture : public armnnUtils::ParserPrototxtFixture<armnnTfPa
m_Prototext.append(" attr { \n"
" key: \"dilations\" \n"
" value { \n"
- " list { \n"
- " i: 1 \n"
- " i: ");
- m_Prototext.append(dilationString);
- m_Prototext.append(" \n"
- " i: ");
+ " list { \n");
m_Prototext.append(dilationString);
- m_Prototext.append(" \n"
- " i: 1 \n"
- " } \n"
+
+ m_Prototext.append(" } \n"
" } \n"
" } \n");
}
@@ -167,7 +176,6 @@ struct Convolution2dFixture : public armnnUtils::ParserPrototxtFixture<armnnTfPa
}
};
-
struct Convolution2dNhwcSameFixture : Convolution2dFixture
{
Convolution2dNhwcSameFixture() : Convolution2dFixture("NHWC", "SAME", 1){}
@@ -262,118 +270,174 @@ BOOST_FIXTURE_TEST_CASE(ParseConv2dDilation1Nchw, Convolution2dDilation1NchwFixt
RunTest<4>({1, 2, 3, 4, 5, 6}, {2, 4, 4, 6.5f, 10 , 8.5f});
}
+struct Convolution2dDilationFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
+{
+ explicit Convolution2dDilationFixture(const std::string& dataLayout, const std::string& paddingType)
+ : Convolution2dDilationFixture(dataLayout, paddingType, 1)
+ {}
+
+ explicit Convolution2dDilationFixture(const std::string& dataLayout, const std::string& paddingType,
+ int stride, int dilation = 0)
+ {
+ std::string strideString;
+ if (dataLayout == "NHWC")
+ {
+ strideString.append(" i: 1 \n"
+ " i: " + std::to_string(stride) + " \n"
+ " i: " + std::to_string(stride) + " \n"
+ " i: 1 \n");
+ }
+ else // dataLayout == "NCHW"
+ {
+ strideString.append(" i: 1 \n"
+ " i: 1 \n"
+ " i: " + std::to_string(stride) + " \n"
+ " i: " + std::to_string(stride) + " \n");
+ }
+
+ std::string dilationString;
+ if (dataLayout == "NHWC")
+ {
+ dilationString.append(" i: 1 \n"
+ " i: " + std::to_string(dilation) + " \n"
+ " i: " + std::to_string(dilation) + " \n"
+ " i: 1 \n");
+ }
+ else // dataLayout == "NCHW"
+ {
+ dilationString.append(" i: 1 \n"
+ " i: 1 \n"
+ " i: " + std::to_string(dilation) + " \n"
+ " i: " + std::to_string(dilation) + " \n");
+ }
-BOOST_AUTO_TEST_CASE(ParseConv2dDilation2)
+ m_Prototext = "node { \n"
+ " name: \"graphInput\" \n"
+ " op: \"Placeholder\" \n"
+ " attr { \n"
+ " key: \"dtype\" \n"
+ " value { \n"
+ " type: DT_FLOAT \n"
+ " } \n"
+ " } \n"
+ " attr { \n"
+ " key: \"shape\" \n"
+ " value { \n"
+ " shape { \n"
+ " } \n"
+ " } \n"
+ " } \n"
+ " } \n"
+ " node { \n"
+ " name: \"Const_1\" \n"
+ " op: \"Const\" \n"
+ " attr { \n"
+ " key: \"dtype\" \n"
+ " value { \n"
+ " type: DT_FLOAT \n"
+ " } \n"
+ " } \n"
+ " attr { \n"
+ " key: \"value\" \n"
+ " value { \n"
+ " tensor { \n"
+ " dtype: DT_FLOAT \n"
+ " tensor_shape { \n"
+ " dim { \n"
+ " size: 3 \n"
+ " } \n"
+ " dim { \n"
+ " size: 1 \n"
+ " } \n"
+ " dim { \n"
+ " size: 1 \n"
+ " } \n"
+ " dim { \n"
+ " size: 1 \n"
+ " } \n"
+ " } \n"
+ " tensor_content: \"\\001\\000\\000?\\000\\000\\000?\\001\\000\\000?\" \n"
+ " } \n"
+ " } \n"
+ " } \n"
+ "} \n"
+ "node { \n"
+ " name: \"potato\" \n"
+ " op: \"Conv2D\" \n"
+ " input: \"graphInput\" \n"
+ " input: \"Const_1\" \n"
+ " attr { \n"
+ " key: \"T\" \n"
+ " value { \n"
+ " type: DT_FLOAT \n"
+ " } \n"
+ " } \n"
+ " attr { \n"
+ " key: \"data_format\" \n"
+ " value { \n"
+ " s: \"";
+ m_Prototext.append(dataLayout);
+ m_Prototext.append("\"\n"
+ " } \n"
+ " } \n"
+ " attr { \n"
+ " key: \"padding\" \n"
+ " value { \n"
+ " s: \"");
+ m_Prototext.append(paddingType);
+ m_Prototext.append("\"\n"
+ " } \n"
+ " } \n"
+ " attr { \n"
+ " key: \"strides\" \n"
+ " value { \n"
+ " list { \n");
+ m_Prototext.append(strideString);
+
+ m_Prototext.append(" } \n"
+ " } \n"
+ " } \n");
+
+ if (dilation > 0)
+ {
+ m_Prototext.append(" attr { \n"
+ " key: \"dilations\" \n"
+ " value { \n"
+ " list { \n");
+ m_Prototext.append(dilationString);
+
+ m_Prototext.append(" } \n"
+ " } \n"
+ " } \n");
+ }
+ m_Prototext.append(" attr { \n"
+ " key: \"use_cudnn_on_gpu\" \n"
+ " value { \n"
+ " b: false \n"
+ " } \n"
+ " } \n"
+ "} \n");
+
+ // Manual height computation based on stride parameter.
+ std::array<unsigned int, 4> dims = { 1u, 1u, 6u, 6u };;
+
+ SetupSingleInputSingleOutput(armnn::TensorShape(4, dims.data()), "graphInput", "potato");
+ }
+};
+
+struct Convolution2dDilation2NchwValidFixture : Convolution2dDilationFixture
+{
+ Convolution2dDilation2NchwValidFixture() : Convolution2dDilationFixture("NCHW", "VALID", 1, 2){}
+};
+BOOST_FIXTURE_TEST_CASE(ParseConv2dDilation2NchwValid, Convolution2dDilation2NchwValidFixture)
{
- const char* prototext = ""
- "node {\n"
- " name: \"graphInput\"\n"
- " op: \"Placeholder\"\n"
- " attr {\n"
- " key: \"dtype\"\n"
- " value {\n"
- " type: DT_FLOAT\n"
- " }\n"
- " }\n"
- " attr {\n"
- " key: \"shape\"\n"
- " value {\n"
- " shape {\n"
- " }\n"
- " }\n"
- " }\n"
- "}\n"
- "node {\n"
- " name: \"Const_1\"\n"
- " op: \"Const\"\n"
- " attr {\n"
- " key: \"dtype\"\n"
- " value {\n"
- " type: DT_FLOAT\n"
- " }\n"
- " }\n"
- " attr {\n"
- " key: \"value\"\n"
- " value {\n"
- " tensor {\n"
- " dtype: DT_FLOAT\n"
- " tensor_shape {\n"
- " dim {\n"
- " size: 1\n"
- " }\n"
- " dim {\n"
- " size: 3\n"
- " }\n"
- " dim {\n"
- " size: 1\n"
- " }\n"
- " dim {\n"
- " size: 1\n"
- " }\n"
- " }\n"
- " tensor_content: \"\\000\\000\\000?\\000\\000\\200?\\000\\000\\000?\"\n"
- " }\n"
- " }\n"
- " }\n"
- "}\n"
- "node {\n"
- " name: \"potato\"\n"
- " op: \"Conv2D\"\n"
- " input: \"graphInput\"\n"
- " input: \"Const_1\"\n"
- " attr {\n"
- " key: \"T\"\n"
- " value {\n"
- " type: DT_FLOAT\n"
- " }\n"
- " }\n"
- " attr {\n"
- " key: \"data_format\"\n"
- " value {\n"
- " s: \"NHWC\"\n"
- " }\n"
- " }\n"
- " attr {\n"
- " key: \"padding\"\n"
- " value {\n"
- " s: \"SAME\"\n"
- " }\n"
- " }\n"
- " attr {\n"
- " key: \"strides\"\n"
- " value {\n"
- " list {\n"
- " i: 1\n"
- " i: 1\n"
- " i: 1\n"
- " i: 1\n"
- " }\n"
- " }\n"
- " }\n"
- " attr {\n"
- " key: \"dilations\"\n"
- " value {\n"
- " list {\n"
- " i: 1\n"
- " i: 2\n"
- " i: 2\n"
- " i: 1\n"
- " }\n"
- " }\n"
- " }\n"
- " attr {\n"
- " key: \"use_cudnn_on_gpu\"\n"
- " value {\n"
- " b: false\n"
- " }\n"
- " }\n"
- "}\n";
-
- std::map<std::string, armnn::TensorShape> inputShapes;
- armnn::TensorShape tensorShape = { 1, 3, 3, 1 };
- inputShapes["graphInput"] = tensorShape;
- armnnTfParser::ITfParserPtr parser = armnnTfParser::ITfParser::Create();
- BOOST_CHECK_THROW(parser->CreateNetworkFromString(prototext, inputShapes, { "potato" }), armnn::ParseException);
+ RunTest<4>({1.0, 2.0, 3.0, 4.0, 5.0, 6.0,
+ 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
+ 1.0, 2.0, 3.0, 4.0, 5.0, 6.0,
+ 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
+ 1.0, 2.0, 3.0, 4.0, 5.0, 6.0,
+ 7.0, 8.0, 9.0, 10.0, 11.0, 12.0},
+ {1.5f, 3.0f, 4.5f, 6.0f, 7.5f, 9.0f, 10.5f, 12.f, 13.5f, 15.0f, 16.5f, 18.0f});
}