aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNarumol Prangnawarat <narumol.prangnawarat@arm.com>2019-08-01 15:51:44 +0100
committerNarumol Prangnawarat <narumol.prangnawarat@arm.com>2019-08-02 14:24:00 +0000
commitc8bdb3943635a7beb0fdc8538803c8e61a3abe33 (patch)
tree691ae75b3a7160be7f98bc839846523287e9a7a4
parentf094dba07b1e1dd65d8b85adcbbd98f2d77cb07a (diff)
downloadandroid-nn-driver-c8bdb3943635a7beb0fdc8538803c8e61a3abe33.tar.gz
IVGCVSW-3604 Fix TransposeConv2d padding calculation
Signed-off-by: Narumol Prangnawarat <narumol.prangnawarat@arm.com> Signed-off-by: Aron Virginas-Tar <aron.virginas-tar@arm.com> Change-Id: I5c10ab18343ecf0ebeab24a436e5be2b6c2831c7
-rw-r--r--1.2/HalPolicy.cpp25
-rw-r--r--ConversionUtils.hpp6
2 files changed, 27 insertions, 4 deletions
diff --git a/1.2/HalPolicy.cpp b/1.2/HalPolicy.cpp
index 575ae2b2..477806ef 100644
--- a/1.2/HalPolicy.cpp
+++ b/1.2/HalPolicy.cpp
@@ -1607,11 +1607,28 @@ bool HalPolicy::ConvertTransposeConv2d(const Operation& operation, const Model&
const uint32_t kernelX = weights.GetShape()[widthIndex];
const uint32_t kernelY = weights.GetShape()[heightIndex];
- const uint32_t inputX = inputInfo.GetShape()[widthIndex];
- const uint32_t inputY = inputInfo.GetShape()[heightIndex];
+ const uint32_t outputX = outputInfo.GetShape()[widthIndex];
+ const uint32_t outputY = outputInfo.GetShape()[heightIndex];
+
+ int32_t padLeft{0};
+ int32_t padRight{0};
+ int32_t padTop{0};
+ int32_t padBottom{0};
+
+ CalcPaddingTransposeConv(outputX, kernelX, desc.m_StrideX, padLeft, padRight, paddingScheme);
+ CalcPaddingTransposeConv(outputY, kernelY, desc.m_StrideY, padTop, padBottom, paddingScheme);
+
+ // NOTE: The Android NN API allows for negative padding values in TransposeConv2d,
+ // but Arm NN only supports values >= 0
+ if (padLeft < 0 || padRight < 0 || padTop < 0 || padBottom < 0)
+ {
+ return Fail("%s: Negative padding values are not supported", __func__);
+ }
- CalcPadding(inputX, kernelX, desc.m_StrideX, desc.m_PadLeft, desc.m_PadRight, paddingScheme);
- CalcPadding(inputY, kernelY, desc.m_StrideY, desc.m_PadTop, desc.m_PadBottom, paddingScheme);
+ desc.m_PadLeft = boost::numeric_cast<uint32_t>(padLeft);
+ desc.m_PadRight = boost::numeric_cast<uint32_t>(padRight);
+ desc.m_PadTop = boost::numeric_cast<uint32_t>(padTop);
+ desc.m_PadBottom = boost::numeric_cast<uint32_t>(padBottom);
}
else if (operation.inputs.size() == 11)
{
diff --git a/ConversionUtils.hpp b/ConversionUtils.hpp
index 18a65413..9471d781 100644
--- a/ConversionUtils.hpp
+++ b/ConversionUtils.hpp
@@ -323,6 +323,12 @@ void CalcPadding(uint32_t input, uint32_t kernel, uint32_t stride, uint32_t dila
outPadTail = boost::numeric_cast<uint32_t>(padTail);
}
+void CalcPaddingTransposeConv(uint32_t output, uint32_t kernel, uint32_t stride, int32_t& outPadHead,
+ int32_t& outPadTail, android::nn::PaddingScheme scheme)
+{
+ calculateExplicitPaddingTransposeConv(output, stride, kernel, scheme, &outPadHead, &outPadTail);
+}
+
#endif
Shape GetOperandShape(const V1_0::Operand& operand)