From ae7b832a6f5eda4b28577f57909111135a36dee9 Mon Sep 17 00:00:00 2001 From: Narumol Prangnawarat Date: Fri, 2 Aug 2019 15:08:59 +0100 Subject: IVGCVSW-3604 Fix channel shape calculation in TransposeConvolution2dLayer::InferOutputShapes Signed-off-by: Narumol Prangnawarat Change-Id: I2e3d5922bb89c8f3b84ff5458fda981ff177c3ce --- src/armnn/layers/TransposeConvolution2dLayer.cpp | 8 +++++++- src/armnn/test/InferOutputTests.hpp | 2 +- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/src/armnn/layers/TransposeConvolution2dLayer.cpp b/src/armnn/layers/TransposeConvolution2dLayer.cpp index 1a994e7442..77a333d881 100644 --- a/src/armnn/layers/TransposeConvolution2dLayer.cpp +++ b/src/armnn/layers/TransposeConvolution2dLayer.cpp @@ -67,7 +67,6 @@ std::vector TransposeConvolution2dLayer::InferOutputShapes( DataLayoutIndexed dataLayoutIndex(m_Param.m_DataLayout); const unsigned int batches = inputShape[0]; - const unsigned int channels = inputShape[dataLayoutIndex.GetChannelsIndex()]; const unsigned int wInput = inputShape[dataLayoutIndex.GetWidthIndex()]; const unsigned int hInput = inputShape[dataLayoutIndex.GetHeightIndex()]; @@ -84,6 +83,13 @@ std::vector TransposeConvolution2dLayer::InferOutputShapes( unsigned int wOutput = wPaddedOutput - (m_Param.m_PadLeft + m_Param.m_PadRight); unsigned int hOutput = hPaddedOutput - (m_Param.m_PadTop + m_Param.m_PadBottom); + unsigned int kernelElements = kernelShape[0] * kernelShape[dataLayoutIndex.GetChannelsIndex()]; + unsigned int inputElements = batches * inputShape[dataLayoutIndex.GetChannelsIndex()]; + + BOOST_ASSERT_MSG(inputElements != 0, "Invalid number of input elements"); + BOOST_ASSERT_MSG(kernelElements % inputElements == 0, "Invalid number of elements"); + unsigned int channels = kernelElements / inputElements; + TensorShape tensorShape = m_Param.m_DataLayout == armnn::DataLayout::NHWC ? TensorShape( { batches, hOutput, wOutput, channels } ) : TensorShape( { batches, channels, hOutput, wOutput }); diff --git a/src/armnn/test/InferOutputTests.hpp b/src/armnn/test/InferOutputTests.hpp index 2dd2ff0e73..c428a9db61 100644 --- a/src/armnn/test/InferOutputTests.hpp +++ b/src/armnn/test/InferOutputTests.hpp @@ -406,7 +406,7 @@ void TransposeConvolution2dInferOutputShapeTest() armnn::TensorShape filterShape(4, filterSize.data()); shapes.push_back(filterShape); - const std::vector expectedOutputSizes = {1, 2, 6, 6}; + const std::vector expectedOutputSizes = {1, 1, 6, 6}; armnn::TensorShape expectedOutputShape(4, expectedOutputSizes.data()); BOOST_CHECK(expectedOutputShape == transposeConvolution2dLayer->InferOutputShapes(shapes).at(0)); -- cgit v1.2.1