diff options
author | Finn Williams <Finn.Williams@arm.com> | 2020-06-19 16:26:11 +0100 |
---|---|---|
committer | Finn Williams <Finn.Williams@arm.com> | 2020-06-23 13:11:26 +0100 |
commit | 8b23635b3049b24dc424905f33345456f3f6eb7d (patch) | |
tree | 27dba4b4a602618a3116ebda0d27c58fa37d3a68 /src/armnn | |
parent | 0204f09cb7462c675c45c76cd13d677d67f73589 (diff) | |
download | armnn-8b23635b3049b24dc424905f33345456f3f6eb7d.tar.gz |
IVGCVSW-4924 Fix edge case for transposeConv2d shape inference
Signed-off-by: Finn Williams <Finn.Williams@arm.com>
Change-Id: I0147ad10aeb16cf5c876cbf09434279ba6813714
Diffstat (limited to 'src/armnn')
-rw-r--r-- | src/armnn/layers/TransposeConvolution2dLayer.cpp | 13 |
1 files changed, 11 insertions, 2 deletions
diff --git a/src/armnn/layers/TransposeConvolution2dLayer.cpp b/src/armnn/layers/TransposeConvolution2dLayer.cpp index 05941f7d78..28258820ad 100644 --- a/src/armnn/layers/TransposeConvolution2dLayer.cpp +++ b/src/armnn/layers/TransposeConvolution2dLayer.cpp @@ -83,9 +83,18 @@ std::vector<TensorShape> TransposeConvolution2dLayer::InferOutputShapes( unsigned int inputElements = batches * inputShape[dataLayoutIndex.GetChannelsIndex()]; ARMNN_ASSERT_MSG(inputElements != 0, "Invalid number of input elements"); - ARMNN_ASSERT_MSG(kernelElements % inputElements == 0, "Invalid number of elements"); - unsigned int channels = kernelElements / inputElements; + unsigned int channels; + if (kernelElements >= inputElements) + { + ARMNN_ASSERT_MSG(kernelElements % inputElements == 0 , "Invalid number of elements"); + channels = kernelElements / inputElements; + } + else + { + ARMNN_ASSERT_MSG(inputElements % kernelElements == 0 , "Invalid number of elements"); + channels = kernelShape[0]; + } TensorShape tensorShape = m_Param.m_DataLayout == armnn::DataLayout::NHWC ? TensorShape( { batches, hOutput, wOutput, channels } ) : |