aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorFinn Williams <Finn.Williams@arm.com>2020-06-19 16:26:11 +0100
committerFinn Williams <Finn.Williams@arm.com>2020-06-23 13:11:26 +0100
commit8b23635b3049b24dc424905f33345456f3f6eb7d (patch)
tree27dba4b4a602618a3116ebda0d27c58fa37d3a68 /src
parent0204f09cb7462c675c45c76cd13d677d67f73589 (diff)
downloadarmnn-8b23635b3049b24dc424905f33345456f3f6eb7d.tar.gz
IVGCVSW-4924 Fix edge case for transposeConv2d shape inference
Signed-off-by: Finn Williams <Finn.Williams@arm.com> Change-Id: I0147ad10aeb16cf5c876cbf09434279ba6813714
Diffstat (limited to 'src')
-rw-r--r--src/armnn/layers/TransposeConvolution2dLayer.cpp13
1 files changed, 11 insertions, 2 deletions
diff --git a/src/armnn/layers/TransposeConvolution2dLayer.cpp b/src/armnn/layers/TransposeConvolution2dLayer.cpp
index 05941f7d78..28258820ad 100644
--- a/src/armnn/layers/TransposeConvolution2dLayer.cpp
+++ b/src/armnn/layers/TransposeConvolution2dLayer.cpp
@@ -83,9 +83,18 @@ std::vector<TensorShape> TransposeConvolution2dLayer::InferOutputShapes(
unsigned int inputElements = batches * inputShape[dataLayoutIndex.GetChannelsIndex()];
ARMNN_ASSERT_MSG(inputElements != 0, "Invalid number of input elements");
- ARMNN_ASSERT_MSG(kernelElements % inputElements == 0, "Invalid number of elements");
- unsigned int channels = kernelElements / inputElements;
+ unsigned int channels;
+ if (kernelElements >= inputElements)
+ {
+ ARMNN_ASSERT_MSG(kernelElements % inputElements == 0 , "Invalid number of elements");
+ channels = kernelElements / inputElements;
+ }
+ else
+ {
+ ARMNN_ASSERT_MSG(inputElements % kernelElements == 0 , "Invalid number of elements");
+ channels = kernelShape[0];
+ }
TensorShape tensorShape = m_Param.m_DataLayout == armnn::DataLayout::NHWC ?
TensorShape( { batches, hOutput, wOutput, channels } ) :