diff options
-rw-r--r-- | 1.2/HalPolicy.cpp | 8 | ||||
-rw-r--r-- | ConversionUtils.hpp | 14 |
2 files changed, 18 insertions, 4 deletions
diff --git a/1.2/HalPolicy.cpp b/1.2/HalPolicy.cpp index 5a940bea..d7452c68 100644 --- a/1.2/HalPolicy.cpp +++ b/1.2/HalPolicy.cpp @@ -243,8 +243,8 @@ bool HalPolicy::ConvertConv2d(const Operation& operation, const Model& model, Co const uint32_t inputX = inputInfo.GetShape()[widthIndex]; const uint32_t inputY = inputInfo.GetShape()[heightIndex]; - CalcPadding(inputX, kernelX, desc.m_StrideX, desc.m_PadLeft, desc.m_PadRight, paddingScheme); - CalcPadding(inputY, kernelY, desc.m_StrideY, desc.m_PadTop, desc.m_PadBottom, paddingScheme); + CalcPadding(inputX, kernelX, desc.m_StrideX, desc.m_DilationX, desc.m_PadLeft, desc.m_PadRight, paddingScheme); + CalcPadding(inputY, kernelY, desc.m_StrideY, desc.m_DilationY, desc.m_PadTop, desc.m_PadBottom, paddingScheme); } else if (operation.inputs.size() >= 10) @@ -400,8 +400,8 @@ bool HalPolicy::ConvertDepthwiseConv2d(const Operation& operation, const Model& const uint32_t inputX = inputInfo.GetShape()[widthIndex]; const uint32_t inputY = inputInfo.GetShape()[heightIndex]; - CalcPadding(inputX, kernelX, desc.m_StrideX, desc.m_PadLeft, desc.m_PadRight, paddingScheme); - CalcPadding(inputY, kernelY, desc.m_StrideY, desc.m_PadTop, desc.m_PadBottom, paddingScheme); + CalcPadding(inputX, kernelX, desc.m_StrideX, desc.m_DilationX, desc.m_PadLeft, desc.m_PadRight, paddingScheme); + CalcPadding(inputY, kernelY, desc.m_StrideY, desc.m_DilationY, desc.m_PadTop, desc.m_PadBottom, paddingScheme); } else if (operation.inputs.size() >= 11) { diff --git a/ConversionUtils.hpp b/ConversionUtils.hpp index c9be0003..c59da1d5 100644 --- a/ConversionUtils.hpp +++ b/ConversionUtils.hpp @@ -300,6 +300,20 @@ void CalcPadding(uint32_t input, uint32_t kernel, uint32_t stride, uint32_t& out outPadTail = boost::numeric_cast<uint32_t>(padTail); } +#ifdef ARMNN_ANDROID_NN_V1_2 + +void CalcPadding(uint32_t input, uint32_t kernel, uint32_t stride, uint32_t dilation, uint32_t& outPadHead, + uint32_t& outPadTail, android::nn::PaddingScheme scheme) +{ + int32_t padHead; + int32_t padTail; + calculateExplicitPadding(input, stride, dilation, kernel, scheme, &padHead, &padTail); + outPadHead = boost::numeric_cast<uint32_t>(padHead); + outPadTail = boost::numeric_cast<uint32_t>(padTail); +} + +#endif + Shape GetOperandShape(const V1_0::Operand& operand) { Shape shape; |