diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/armnn/layers/TransposeConvolution2dLayer.cpp | 13 |
1 files changed, 11 insertions, 2 deletions
diff --git a/src/armnn/layers/TransposeConvolution2dLayer.cpp b/src/armnn/layers/TransposeConvolution2dLayer.cpp index 05941f7d78..28258820ad 100644 --- a/src/armnn/layers/TransposeConvolution2dLayer.cpp +++ b/src/armnn/layers/TransposeConvolution2dLayer.cpp @@ -83,9 +83,18 @@ std::vector<TensorShape> TransposeConvolution2dLayer::InferOutputShapes( unsigned int inputElements = batches * inputShape[dataLayoutIndex.GetChannelsIndex()]; ARMNN_ASSERT_MSG(inputElements != 0, "Invalid number of input elements"); - ARMNN_ASSERT_MSG(kernelElements % inputElements == 0, "Invalid number of elements"); - unsigned int channels = kernelElements / inputElements; + unsigned int channels; + if (kernelElements >= inputElements) + { + ARMNN_ASSERT_MSG(kernelElements % inputElements == 0 , "Invalid number of elements"); + channels = kernelElements / inputElements; + } + else + { + ARMNN_ASSERT_MSG(inputElements % kernelElements == 0 , "Invalid number of elements"); + channels = kernelShape[0]; + } TensorShape tensorShape = m_Param.m_DataLayout == armnn::DataLayout::NHWC ? TensorShape( { batches, hOutput, wOutput, channels } ) : |