aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/layers/TransposeConvolution2dLayer.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/layers/TransposeConvolution2dLayer.cpp')
-rw-r--r--src/armnn/layers/TransposeConvolution2dLayer.cpp13
1 files changed, 11 insertions, 2 deletions
diff --git a/src/armnn/layers/TransposeConvolution2dLayer.cpp b/src/armnn/layers/TransposeConvolution2dLayer.cpp
index 05941f7d78..28258820ad 100644
--- a/src/armnn/layers/TransposeConvolution2dLayer.cpp
+++ b/src/armnn/layers/TransposeConvolution2dLayer.cpp
@@ -83,9 +83,18 @@ std::vector<TensorShape> TransposeConvolution2dLayer::InferOutputShapes(
unsigned int inputElements = batches * inputShape[dataLayoutIndex.GetChannelsIndex()];
ARMNN_ASSERT_MSG(inputElements != 0, "Invalid number of input elements");
- ARMNN_ASSERT_MSG(kernelElements % inputElements == 0, "Invalid number of elements");
- unsigned int channels = kernelElements / inputElements;
+ unsigned int channels;
+ if (kernelElements >= inputElements)
+ {
+ ARMNN_ASSERT_MSG(kernelElements % inputElements == 0 , "Invalid number of elements");
+ channels = kernelElements / inputElements;
+ }
+ else
+ {
+ ARMNN_ASSERT_MSG(inputElements % kernelElements == 0 , "Invalid number of elements");
+ channels = kernelShape[0];
+ }
TensorShape tensorShape = m_Param.m_DataLayout == armnn::DataLayout::NHWC ?
TensorShape( { batches, hOutput, wOutput, channels } ) :