From 7dc18207d86299cf605a2316e86dc9c098708729 Mon Sep 17 00:00:00 2001 From: Keith Mok Date: Sun, 20 Dec 2020 13:45:51 -0800 Subject: Add Caffe Parser Dilation support Signed-off-by: Keith Mok Change-Id: I3a85de2d082d489fbf5a775c2ae551080d189294 --- src/armnnCaffeParser/CaffeParser.cpp | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/src/armnnCaffeParser/CaffeParser.cpp b/src/armnnCaffeParser/CaffeParser.cpp index d11da466b8..51b58ccea3 100644 --- a/src/armnnCaffeParser/CaffeParser.cpp +++ b/src/armnnCaffeParser/CaffeParser.cpp @@ -479,12 +479,12 @@ void CaffeParserBase::AddConvLayerWithSplits(const caffe::LayerParameter& layerP outputShape.add_dim(2); outputShape.set_dim( 2, (static_cast( - static_cast(inputShape.dim(2) + 2 * desc.m_PadBottom - kernelH) / + static_cast(inputShape.dim(2) + 2 * desc.m_PadBottom - (desc.m_DilationX * (kernelH - 1) + 1)) / static_cast(desc.m_StrideY)) + 1)); outputShape.add_dim(3); outputShape.set_dim( 3, (static_cast( - static_cast(inputShape.dim(3) + 2 * desc.m_PadRight - kernelW) / + static_cast(inputShape.dim(3) + 2 * desc.m_PadRight - (desc.m_DilationY * (kernelW - 1) + 1)) / static_cast(desc.m_StrideX)) + 1)); // Load the weight data for ALL groups @@ -609,6 +609,8 @@ void CaffeParserBase::AddConvLayerWithDepthwiseConv(const caffe::LayerParameter& desc.m_PadBottom = convDesc.m_PadBottom; desc.m_StrideX = convDesc.m_StrideX; desc.m_StrideY = convDesc.m_StrideY; + desc.m_DilationX = convDesc.m_DilationX; + desc.m_DilationY = convDesc.m_DilationY; desc.m_BiasEnabled = convDesc.m_BiasEnabled; unsigned int numFilters = convParam.num_output(); @@ -621,12 +623,12 @@ void CaffeParserBase::AddConvLayerWithDepthwiseConv(const caffe::LayerParameter& outputShape.add_dim(2); outputShape.set_dim( 2, (static_cast( - static_cast(inputShape.dim(2) + 2 * desc.m_PadBottom - kernelH) / + static_cast(inputShape.dim(2) + 2 * desc.m_PadBottom - (desc.m_DilationX * (kernelH - 1) + 1)) / static_cast(desc.m_StrideY)) + 1)); outputShape.add_dim(3); outputShape.set_dim( 3, (static_cast( - static_cast(inputShape.dim(3) + 2 * desc.m_PadRight - kernelW) / + static_cast(inputShape.dim(3) + 2 * desc.m_PadRight - (desc.m_DilationY * (kernelW - 1) + 1)) / static_cast(desc.m_StrideX)) + 1)); // Load the weight data @@ -682,7 +684,6 @@ void CaffeParserBase::AddConvLayerWithDepthwiseConv(const caffe::LayerParameter& void CaffeParserBase::ParseConvLayer(const LayerParameter& layerParam) { // Ignored Caffe Parameters - // * Dilation Size // * Weight Filler // * Bias Filler // * Engine @@ -717,6 +718,10 @@ void CaffeParserBase::ParseConvLayer(const LayerParameter& layerParam) unsigned int padW = GET_OPTIONAL_WITH_VECTOR_FALLBACK(convParam, ConvolutionParameter, pad_w, pad, unsigned int, 0u); + unsigned int dilationH = convParam.dilation_size() > 0 ? convParam.dilation(0) : 1; + unsigned int dilationW = convParam.dilation_size() > 1 ? convParam.dilation(1) : + convParam.dilation_size() > 0 ? convParam.dilation(0) : 1; + Convolution2dDescriptor convolution2dDescriptor; convolution2dDescriptor.m_PadLeft = padW; convolution2dDescriptor.m_PadRight = padW; @@ -724,6 +729,8 @@ void CaffeParserBase::ParseConvLayer(const LayerParameter& layerParam) convolution2dDescriptor.m_PadBottom = padH; convolution2dDescriptor.m_StrideX = strideW; convolution2dDescriptor.m_StrideY = strideH; + convolution2dDescriptor.m_DilationX = dilationW; + convolution2dDescriptor.m_DilationY = dilationH; convolution2dDescriptor.m_BiasEnabled = convParam.has_bias_term() ? convParam.bias_term() : true; if (numGroups > numFilters) @@ -789,12 +796,12 @@ void CaffeParserBase::ParseConvLayer(const LayerParameter& layerParam) outputShape.add_dim(2); outputShape.set_dim( 2, (static_cast( - static_cast(inputShape.dim(2) + 2 * padH - kernelH) / + static_cast(inputShape.dim(2) + 2 * padH - (dilationH * (kernelH - 1) + 1)) / static_cast(strideH)) + 1)); outputShape.add_dim(3); outputShape.set_dim( 3, (static_cast( - static_cast(inputShape.dim(3) + 2 * padW - kernelW) / + static_cast(inputShape.dim(3) + 2 * padW - (dilationW * (kernelW - 1) + 1)) / static_cast(strideW)) + 1)); // Load the weight data for ALL groups -- cgit v1.2.1