aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/layers
diff options
context:
space:
mode:
authorNarumol Prangnawarat <narumol.prangnawarat@arm.com>2019-08-02 15:08:59 +0100
committerNarumol Prangnawarat <narumol.prangnawarat@arm.com>2019-08-02 15:26:58 +0100
commitae7b832a6f5eda4b28577f57909111135a36dee9 (patch)
tree3ea84622726dc3800c741d7be3cebb775eb20bb2 /src/armnn/layers
parent87f65eab4abb65273ea11eb8ca876196ef82c6c5 (diff)
downloadarmnn-ae7b832a6f5eda4b28577f57909111135a36dee9.tar.gz
IVGCVSW-3604 Fix channel shape calculation in TransposeConvolution2dLayer::InferOutputShapes
Signed-off-by: Narumol Prangnawarat <narumol.prangnawarat@arm.com> Change-Id: I2e3d5922bb89c8f3b84ff5458fda981ff177c3ce
Diffstat (limited to 'src/armnn/layers')
-rw-r--r--src/armnn/layers/TransposeConvolution2dLayer.cpp8
1 files changed, 7 insertions, 1 deletions
diff --git a/src/armnn/layers/TransposeConvolution2dLayer.cpp b/src/armnn/layers/TransposeConvolution2dLayer.cpp
index 1a994e7442..77a333d881 100644
--- a/src/armnn/layers/TransposeConvolution2dLayer.cpp
+++ b/src/armnn/layers/TransposeConvolution2dLayer.cpp
@@ -67,7 +67,6 @@ std::vector<TensorShape> TransposeConvolution2dLayer::InferOutputShapes(
DataLayoutIndexed dataLayoutIndex(m_Param.m_DataLayout);
const unsigned int batches = inputShape[0];
- const unsigned int channels = inputShape[dataLayoutIndex.GetChannelsIndex()];
const unsigned int wInput = inputShape[dataLayoutIndex.GetWidthIndex()];
const unsigned int hInput = inputShape[dataLayoutIndex.GetHeightIndex()];
@@ -84,6 +83,13 @@ std::vector<TensorShape> TransposeConvolution2dLayer::InferOutputShapes(
unsigned int wOutput = wPaddedOutput - (m_Param.m_PadLeft + m_Param.m_PadRight);
unsigned int hOutput = hPaddedOutput - (m_Param.m_PadTop + m_Param.m_PadBottom);
+ unsigned int kernelElements = kernelShape[0] * kernelShape[dataLayoutIndex.GetChannelsIndex()];
+ unsigned int inputElements = batches * inputShape[dataLayoutIndex.GetChannelsIndex()];
+
+ BOOST_ASSERT_MSG(inputElements != 0, "Invalid number of input elements");
+ BOOST_ASSERT_MSG(kernelElements % inputElements == 0, "Invalid number of elements");
+ unsigned int channels = kernelElements / inputElements;
+
TensorShape tensorShape = m_Param.m_DataLayout == armnn::DataLayout::NHWC ?
TensorShape( { batches, hOutput, wOutput, channels } ) :
TensorShape( { batches, channels, hOutput, wOutput });