From 8b23635b3049b24dc424905f33345456f3f6eb7d Mon Sep 17 00:00:00 2001 From: Finn Williams Date: Fri, 19 Jun 2020 16:26:11 +0100 Subject: IVGCVSW-4924 Fix edge case for transposeConv2d shape inference Signed-off-by: Finn Williams Change-Id: I0147ad10aeb16cf5c876cbf09434279ba6813714 --- src/armnn/layers/TransposeConvolution2dLayer.cpp | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) (limited to 'src') diff --git a/src/armnn/layers/TransposeConvolution2dLayer.cpp b/src/armnn/layers/TransposeConvolution2dLayer.cpp index 05941f7d78..28258820ad 100644 --- a/src/armnn/layers/TransposeConvolution2dLayer.cpp +++ b/src/armnn/layers/TransposeConvolution2dLayer.cpp @@ -83,9 +83,18 @@ std::vector TransposeConvolution2dLayer::InferOutputShapes( unsigned int inputElements = batches * inputShape[dataLayoutIndex.GetChannelsIndex()]; ARMNN_ASSERT_MSG(inputElements != 0, "Invalid number of input elements"); - ARMNN_ASSERT_MSG(kernelElements % inputElements == 0, "Invalid number of elements"); - unsigned int channels = kernelElements / inputElements; + unsigned int channels; + if (kernelElements >= inputElements) + { + ARMNN_ASSERT_MSG(kernelElements % inputElements == 0 , "Invalid number of elements"); + channels = kernelElements / inputElements; + } + else + { + ARMNN_ASSERT_MSG(inputElements % kernelElements == 0 , "Invalid number of elements"); + channels = kernelShape[0]; + } TensorShape tensorShape = m_Param.m_DataLayout == armnn::DataLayout::NHWC ? TensorShape( { batches, hOutput, wOutput, channels } ) : -- cgit v1.2.1