author    Sadik Armagan <sadik.armagan@arm.com>  2021-01-11 15:15:01 +0000
committer Sadik Armagan <sadik.armagan@arm.com>  2021-02-05 12:10:06 +0000
commit    60bb9d80fa6fedfcb51afc0c9a74d6c2948873fd (patch)
tree      9b578645e2b066fa971a88778f003dbbaac5608e /src/armnnOnnxParser
parent    a4533faaeb07151476e074298f3403896f95668b (diff)
download  armnn-60bb9d80fa6fedfcb51afc0c9a74d6c2948873fd.tar.gz
MLCE-326 'Support Dilation in Conv2D in ONNX and Tensorflow Parsers'
Signed-off-by: Sadik Armagan <sadik.armagan@arm.com>
Change-Id: I4a0f07b1e8f80aff0d29405def1f33bde7944e31
Diffstat (limited to 'src/armnnOnnxParser')
-rw-r--r--  src/armnnOnnxParser/OnnxParser.cpp   |  74
-rw-r--r--  src/armnnOnnxParser/test/Conv2D.cpp  | 151
2 files changed, 197 insertions(+), 28 deletions(-)
diff --git a/src/armnnOnnxParser/OnnxParser.cpp b/src/armnnOnnxParser/OnnxParser.cpp
index 9f5aa1975a..b4e7133239 100644
--- a/src/armnnOnnxParser/OnnxParser.cpp
+++ b/src/armnnOnnxParser/OnnxParser.cpp
@@ -331,22 +331,28 @@ std::string TensorInfoAsString(const TensorInfo& info,
return ss.str();
}
-void CalcPadding(uint32_t inputSize, uint32_t filterSize, uint32_t stride, uint32_t* paddingFront,
- uint32_t* paddingBack, bool isUpper)
+void CalcPadding(uint32_t inputSize,
+ uint32_t filterSize,
+ uint32_t stride,
+ uint32_t dilation,
+ uint32_t* paddingFront,
+ uint32_t* paddingBack,
+ bool isUpper)
{
uint32_t outputSize = (inputSize + stride - 1) / stride;
- uint32_t temp = (outputSize - 1) * stride + filterSize;
+ uint32_t dilatedSize = filterSize + (dilation - 1) * (filterSize - 1);
+ uint32_t temp = (outputSize - 1) * stride + dilatedSize;
*paddingFront = (temp - inputSize) / 2;
*paddingBack = *paddingFront;
if((temp - inputSize) % 2 == 1)
{
if (isUpper)
{
- *paddingBack += 1;
+ *paddingBack += 1;
}
else
{
- *paddingFront += 1;
+ *paddingFront += 1;
}
}
}
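For reference, the change above widens the kernel to its dilated extent before computing SAME-style padding: dilatedSize = filterSize + (dilation - 1) * (filterSize - 1). The standalone sketch below (illustrative only, not part of the patch; the function name is made up) restates that arithmetic for a concrete case and shows that a dilation of 1 reduces to the original formula.

#include <cstdint>
#include <cstdio>

// Restatement of the padding arithmetic above, for illustration only.
// The front/back split shown corresponds to the isUpper == true branch.
void CalcPaddingSketch(uint32_t inputSize, uint32_t filterSize, uint32_t stride, uint32_t dilation)
{
    uint32_t outputSize  = (inputSize + stride - 1) / stride;               // SAME output size
    uint32_t dilatedSize = filterSize + (dilation - 1) * (filterSize - 1);  // effective kernel extent
    uint32_t totalPad    = (outputSize - 1) * stride + dilatedSize - inputSize;
    std::printf("dilatedSize=%u totalPad=%u front=%u back=%u\n",
                dilatedSize, totalPad, totalPad / 2, totalPad - totalPad / 2);
}

int main()
{
    CalcPaddingSketch(6, 3, 1, 2); // prints: dilatedSize=5 totalPad=4 front=2 back=2
    CalcPaddingSketch(6, 3, 1, 1); // dilation 1 gives the old result: totalPad=2 front=1 back=1
    return 0;
}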
@@ -1025,8 +1031,20 @@ void OnnxParserImpl::AddPoolingLayer(const onnx::NodeProto& node, Pooling2dDescr
auto inputInfo = *m_TensorsInfo[node.input(0)].m_info;
uint32_t inputHeight = inputInfo.GetShape()[2];
uint32_t inputWidth = inputInfo.GetShape()[3];
- CalcPadding(inputHeight, desc.m_PoolHeight, desc.m_StrideY, &desc.m_PadTop, &desc.m_PadBottom, isUpper);
- CalcPadding(inputWidth, desc.m_PoolWidth, desc.m_StrideX, &desc.m_PadLeft, &desc.m_PadRight, isUpper);
+ CalcPadding(inputHeight,
+ desc.m_PoolHeight,
+ desc.m_StrideY,
+ 1u,
+ &desc.m_PadTop,
+ &desc.m_PadBottom,
+ isUpper);
+ CalcPadding(inputWidth,
+ desc.m_PoolWidth,
+ desc.m_StrideX,
+ 1u,
+ &desc.m_PadLeft,
+ &desc.m_PadRight,
+ isUpper);
}
}
else
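The pooling call sites above pass a literal 1u for the new dilation parameter, since pooling in this parser carries no dilation; with dilation == 1 the updated formula collapses to the previous one, so pooling padding is unchanged:

// With dilation == 1:
// dilatedSize = poolSize + (1 - 1) * (poolSize - 1) == poolSize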
@@ -1327,25 +1345,6 @@ void OnnxParserImpl::ParseConv(const onnx::NodeProto& node)
auto inputInfo = *m_TensorsInfo[node.input(0)].m_info;
- std::vector<uint32_t> dilations = ReadOptionalNodeUint32ListAttribute(node, "dilations");
- if (!dilations.empty())
- {
- std::stringstream ss;
- ss << "[ ";
- for (auto dilation : dilations)
- {
- ss << dilation << ", ";
- if (dilation != 1u)
- {
- ss << "... ]";
- throw ParseException(
- fmt::format("ArmNN only supports Convolution layers with dilations [1,1], and node '{}' "
- "has dilatation {} {}",
- node.name(), ss.str(), CHECK_LOCATION().AsString()));
- }
- }
- }
-
Convolution2dDescriptor desc;
desc.m_BiasEnabled = false;
@@ -1361,6 +1360,13 @@ void OnnxParserImpl::ParseConv(const onnx::NodeProto& node)
desc.m_StrideY = strides[0];
}
+ std::vector<uint32_t> dilations = ReadOptionalNodeUint32ListAttribute(node, "dilations");
+ if(!dilations.empty())
+ {
+ desc.m_DilationX = dilations[1];
+ desc.m_DilationY = dilations[0];
+ }
+
std::vector<uint32_t> pads = ReadOptionalNodeUint32ListAttribute(node, "pads");
//Check new padding version first
if(pads.empty())
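An observation on the mapping (not part of the patch): for NCHW inputs the ONNX "dilations" attribute lists the spatial axes in height, width order, so dilations[0] populates m_DilationY and dilations[1] populates m_DilationX. When the attribute is absent, Convolution2dDescriptor keeps its default dilation of 1 in both directions, which preserves the old behaviour and is what allows the hard rejection of dilations other than [1,1] (removed above) to be dropped.

// e.g. dilations: [2, 2]  ->  desc.m_DilationY = 2; desc.m_DilationX = 2;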
@@ -1404,8 +1410,20 @@ void OnnxParserImpl::ParseConv(const onnx::NodeProto& node)
weightHeight = kernel_shape[0];
weightWidth = kernel_shape[1];
}
- CalcPadding(inputHeight, weightHeight, desc.m_StrideY, &desc.m_PadTop, &desc.m_PadBottom, isUpper);
- CalcPadding(inputWidth, weightWidth, desc.m_StrideX, &desc.m_PadLeft, &desc.m_PadRight, isUpper);
+ CalcPadding(inputHeight,
+ weightHeight,
+ desc.m_StrideY,
+ desc.m_DilationY,
+ &desc.m_PadTop,
+ &desc.m_PadBottom,
+ isUpper);
+ CalcPadding(inputWidth,
+ weightWidth,
+ desc.m_StrideX,
+ desc.m_DilationX,
+ &desc.m_PadLeft,
+ &desc.m_PadRight,
+ isUpper);
}
}
else
diff --git a/src/armnnOnnxParser/test/Conv2D.cpp b/src/armnnOnnxParser/test/Conv2D.cpp
index da67985107..a38cc192ed 100644
--- a/src/armnnOnnxParser/test/Conv2D.cpp
+++ b/src/armnnOnnxParser/test/Conv2D.cpp
@@ -438,6 +438,146 @@ struct Conv2DDimReducingFixture : public armnnUtils::ParserPrototxtFixture<armn
}
};
+struct Conv2DwithDilationFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
+{
+ Conv2DwithDilationFixture()
+ {
+ m_Prototext = R"(
+ ir_version: 3
+ producer_name: "CNTK"
+ producer_version: "2.5.1"
+ domain: "ai.cntk"
+ model_version: 1
+ graph {
+ name: "CNTKGraph"
+ input {
+ name: "Input"
+ type {
+ tensor_type {
+ elem_type: 1
+ shape {
+ dim {
+ dim_value: 1
+ }
+ dim {
+ dim_value: 1
+ }
+ dim {
+ dim_value: 6
+ }
+ dim {
+ dim_value: 6
+ }
+ }
+ }
+ }
+ }
+ input {
+ name: "Weight"
+ type {
+ tensor_type {
+ elem_type: 1
+ shape {
+ dim {
+ dim_value: 1
+ }
+ dim {
+ dim_value: 1
+ }
+ dim {
+ dim_value: 3
+ }
+ dim {
+ dim_value: 3
+ }
+ }
+ }
+ }
+ }
+ initializer {
+ dims: 1
+ dims: 1
+ dims: 3
+ dims: 3
+ data_type: 1
+ float_data: 2
+ float_data: 1
+ float_data: 0
+ float_data: 6
+ float_data: 2
+ float_data: 1
+ float_data: 4
+ float_data: 1
+ float_data: 2
+ name: "Weight"
+ }
+ node {
+ input: "Input"
+ input: "Weight"
+ output: "Output"
+ name: "Convolution"
+ op_type: "Conv"
+ attribute {
+ name: "kernel_shape"
+ ints: 3
+ ints: 3
+ type: INTS
+ }
+ attribute {
+ name: "strides"
+ ints: 1
+ ints: 1
+ type: INTS
+ }
+ attribute {
+ name: "auto_pad"
+ s: "VALID"
+ type: STRING
+ }
+ attribute {
+ name: "group"
+ i: 1
+ type: INT
+ }
+ attribute {
+ name: "dilations"
+ ints: 2
+ ints: 2
+ type: INTS
+ }
+ doc_string: ""
+ domain: ""
+ }
+ output {
+ name: "Output"
+ type {
+ tensor_type {
+ elem_type: 1
+ shape {
+ dim {
+ dim_value: 1
+ }
+ dim {
+ dim_value: 1
+ }
+ dim {
+ dim_value: 2
+ }
+ dim {
+ dim_value: 2
+ }
+ }
+ }
+ }
+ }
+ }
+ opset_import {
+ version: 7
+ })";
+ Setup();
+ }
+};
+
BOOST_FIXTURE_TEST_CASE(ValidConvTest, SimpleConv2DFixture)
{
RunTest<4>({{"Input", {1.0, 2.0, 3.0,
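For the new Conv2DwithDilationFixture above: a 3x3 kernel with dilation 2 has an effective receptive field of 2 * (3 - 1) + 1 = 5, so with auto_pad VALID and stride 1 the 6x6 input yields a (6 - 5 + 1) x (6 - 5 + 1) = 2x2 output, matching the declared 1x1x2x2 output shape. (This is a reading of the fixture, not part of the patch.)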
@@ -466,4 +606,15 @@ BOOST_FIXTURE_TEST_CASE(ValidConvDimReducTest, Conv2DDimReducingFixture)
1, 2, 3, 4}}});
}
+BOOST_FIXTURE_TEST_CASE(ValidConvWithDilationTest, Conv2DwithDilationFixture)
+{
+ RunTest<4>({{"Input", {1.0, 2.0, 3.0, 4.0, 5.0, 6.0,
+ 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
+ 1.0, 2.0, 3.0, 4.0, 5.0, 6.0,
+ 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
+ 1.0, 2.0, 3.0, 4.0, 5.0, 6.0,
+ 7.0, 8.0, 9.0, 10.0, 11.0, 12.0}}},
+ {{"Output", {39.0, 58.0, 153.0, 172.0 }}});
+}
+
BOOST_AUTO_TEST_SUITE_END()
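As a cross-check on the expected values in ValidConvWithDilationTest, the following standalone sketch (illustrative only, not part of the test suite) runs a naive single-channel dilated convolution over the fixture's input and weights and reproduces 39, 58, 153, 172:

#include <cstdio>

// Naive NCHW dilated convolution, single channel, VALID padding, stride 1.
int main()
{
    const int inH = 6, inW = 6, kH = 3, kW = 3, dilation = 2;
    const float input[inH][inW] = {
        { 1,  2,  3,  4,  5,  6}, { 7,  8,  9, 10, 11, 12},
        { 1,  2,  3,  4,  5,  6}, { 7,  8,  9, 10, 11, 12},
        { 1,  2,  3,  4,  5,  6}, { 7,  8,  9, 10, 11, 12}};
    const float weight[kH][kW] = {{2, 1, 0}, {6, 2, 1}, {4, 1, 2}};

    const int outH = inH - (dilation * (kH - 1) + 1) + 1;   // 6 - 5 + 1 = 2
    const int outW = inW - (dilation * (kW - 1) + 1) + 1;   // 2

    for (int y = 0; y < outH; ++y)
    {
        for (int x = 0; x < outW; ++x)
        {
            float acc = 0.0f;
            for (int ky = 0; ky < kH; ++ky)
            {
                for (int kx = 0; kx < kW; ++kx)
                {
                    // Sample the input at dilated offsets from the output position.
                    acc += input[y + ky * dilation][x + kx * dilation] * weight[ky][kx];
                }
            }
            std::printf("%.1f ", acc);   // prints: 39.0 58.0 153.0 172.0
        }
    }
    std::printf("\n");
    return 0;
}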