From f24375df4087b7d39062d8c46b190e7abea4bc9c Mon Sep 17 00:00:00 2001 From: James Conroy Date: Mon, 7 Dec 2020 16:59:03 +0000 Subject: IVGCVSW-5500 Fix transpose conv InferOutputShape * Use kernelShape[0] as channels for outputShape. Signed-off-by: James Conroy Change-Id: I070c7ff68ae365d9505a5eb28c76f9e52da1e5f9 --- src/armnn/layers/TransposeConvolution2dLayer.cpp | 22 +++------------------- 1 file changed, 3 insertions(+), 19 deletions(-) (limited to 'src/armnn/layers') diff --git a/src/armnn/layers/TransposeConvolution2dLayer.cpp b/src/armnn/layers/TransposeConvolution2dLayer.cpp index 1591213d9d..189e5f6168 100644 --- a/src/armnn/layers/TransposeConvolution2dLayer.cpp +++ b/src/armnn/layers/TransposeConvolution2dLayer.cpp @@ -78,27 +78,11 @@ std::vector TransposeConvolution2dLayer::InferOutputShapes( unsigned int wOutput = (wInput - 1) * m_Param.m_StrideX + wKernel - wPadding; unsigned int hOutput = (hInput - 1) * m_Param.m_StrideY + hKernel - hPadding; - - unsigned int kernelElements = kernelShape[0] * kernelShape[dataLayoutIndex.GetChannelsIndex()]; - unsigned int inputElements = batches * inputShape[dataLayoutIndex.GetChannelsIndex()]; - - ARMNN_ASSERT_MSG(inputElements != 0, "Invalid number of input elements"); - - unsigned int channels; - if (kernelElements >= inputElements) - { - ARMNN_ASSERT_MSG(kernelElements % inputElements == 0 , "Invalid number of elements"); - channels = kernelElements / inputElements; - } - else - { - ARMNN_ASSERT_MSG(inputElements % kernelElements == 0 , "Invalid number of elements"); - channels = kernelShape[0]; - } + unsigned int cOutput = kernelShape[0]; TensorShape tensorShape = m_Param.m_DataLayout == armnn::DataLayout::NHWC ? - TensorShape( { batches, hOutput, wOutput, channels } ) : - TensorShape( { batches, channels, hOutput, wOutput }); + TensorShape( { batches, hOutput, wOutput, cOutput } ) : + TensorShape( { batches, cOutput, hOutput, wOutput }); return std::vector({ tensorShape }); } -- cgit v1.2.1