aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAron Virginas-Tar <Aron.Virginas-Tar@arm.com>2019-09-17 12:58:22 +0100
committerDavid Monahan <david.monahan@arm.com>2020-01-15 15:56:58 +0000
commit1605fc9bda5e77a314c13d101f924971d3f52717 (patch)
treedffc56aa64ecff3ea35b76d8b9d45c761e3e06f4
parent4b7adae342658d81702301d8094b9212d73338ff (diff)
downloadarmnn-1605fc9bda5e77a314c13d101f924971d3f52717.tar.gz
IVGCVSW-3879 Fix output shape inference formula for TransposeConvolution2d
Signed-off-by: Aron Virginas-Tar <Aron.Virginas-Tar@arm.com> Change-Id: I766f4297b9daa26edacc2079fe62a083ba2fa68f
-rw-r--r--src/armnn/layers/TransposeConvolution2dLayer.cpp16
-rw-r--r--src/armnnTfLiteParser/test/TransposeConv.cpp2
2 files changed, 8 insertions, 10 deletions
diff --git a/src/armnn/layers/TransposeConvolution2dLayer.cpp b/src/armnn/layers/TransposeConvolution2dLayer.cpp
index 77a333d881..7bd2f3b9d8 100644
--- a/src/armnn/layers/TransposeConvolution2dLayer.cpp
+++ b/src/armnn/layers/TransposeConvolution2dLayer.cpp
@@ -66,7 +66,7 @@ std::vector<TensorShape> TransposeConvolution2dLayer::InferOutputShapes(
DataLayoutIndexed dataLayoutIndex(m_Param.m_DataLayout);
- const unsigned int batches = inputShape[0];
+ const unsigned int batches = inputShape[0];
const unsigned int wInput = inputShape[dataLayoutIndex.GetWidthIndex()];
const unsigned int hInput = inputShape[dataLayoutIndex.GetHeightIndex()];
@@ -74,20 +74,18 @@ std::vector<TensorShape> TransposeConvolution2dLayer::InferOutputShapes(
const unsigned int wKernel = kernelShape[dataLayoutIndex.GetWidthIndex()];
const unsigned int hKernel = kernelShape[dataLayoutIndex.GetHeightIndex()];
- const unsigned int wStridedInput = 1u + m_Param.m_StrideX * (wInput - 1);
- const unsigned int hStridedInput = 1u + m_Param.m_StrideY * (hInput - 1);
+ unsigned int wPadding = m_Param.m_PadLeft + m_Param.m_PadRight;
+ unsigned int hPadding = m_Param.m_PadTop + m_Param.m_PadBottom;
- const unsigned int wPaddedOutput = wStridedInput + wKernel - (wKernel % 2);
- const unsigned int hPaddedOutput = hStridedInput + hKernel - (hKernel % 2);
-
- unsigned int wOutput = wPaddedOutput - (m_Param.m_PadLeft + m_Param.m_PadRight);
- unsigned int hOutput = hPaddedOutput - (m_Param.m_PadTop + m_Param.m_PadBottom);
+ unsigned int wOutput = (wInput - 1) * m_Param.m_StrideX + wKernel - wPadding;
+ unsigned int hOutput = (hInput - 1) * m_Param.m_StrideY + hKernel - hPadding;
unsigned int kernelElements = kernelShape[0] * kernelShape[dataLayoutIndex.GetChannelsIndex()];
- unsigned int inputElements = batches * inputShape[dataLayoutIndex.GetChannelsIndex()];
+ unsigned int inputElements = batches * inputShape[dataLayoutIndex.GetChannelsIndex()];
BOOST_ASSERT_MSG(inputElements != 0, "Invalid number of input elements");
BOOST_ASSERT_MSG(kernelElements % inputElements == 0, "Invalid number of elements");
+
unsigned int channels = kernelElements / inputElements;
TensorShape tensorShape = m_Param.m_DataLayout == armnn::DataLayout::NHWC ?
diff --git a/src/armnnTfLiteParser/test/TransposeConv.cpp b/src/armnnTfLiteParser/test/TransposeConv.cpp
index 53212b58f9..46b02ac956 100644
--- a/src/armnnTfLiteParser/test/TransposeConv.cpp
+++ b/src/armnnTfLiteParser/test/TransposeConv.cpp
@@ -83,7 +83,7 @@ struct TransposeConvFixture : public ParserFlatbuffersFixture
"outputs": [ 3 ],
"builtin_options_type": "TransposeConvOptions",
"builtin_options": {
- "padding": "SAME",
+ "padding": "VALID",
"stride_w": )" + strideX + R"(,
"stride_h": )" + strideY + R"(
},