diff options
Diffstat (limited to 'src/armnn')
-rw-r--r-- | src/armnn/layers/TransposeConvolution2dLayer.cpp | 22 |
1 files changed, 3 insertions, 19 deletions
diff --git a/src/armnn/layers/TransposeConvolution2dLayer.cpp b/src/armnn/layers/TransposeConvolution2dLayer.cpp index 1591213d9d..189e5f6168 100644 --- a/src/armnn/layers/TransposeConvolution2dLayer.cpp +++ b/src/armnn/layers/TransposeConvolution2dLayer.cpp @@ -78,27 +78,11 @@ std::vector<TensorShape> TransposeConvolution2dLayer::InferOutputShapes( unsigned int wOutput = (wInput - 1) * m_Param.m_StrideX + wKernel - wPadding; unsigned int hOutput = (hInput - 1) * m_Param.m_StrideY + hKernel - hPadding; - - unsigned int kernelElements = kernelShape[0] * kernelShape[dataLayoutIndex.GetChannelsIndex()]; - unsigned int inputElements = batches * inputShape[dataLayoutIndex.GetChannelsIndex()]; - - ARMNN_ASSERT_MSG(inputElements != 0, "Invalid number of input elements"); - - unsigned int channels; - if (kernelElements >= inputElements) - { - ARMNN_ASSERT_MSG(kernelElements % inputElements == 0 , "Invalid number of elements"); - channels = kernelElements / inputElements; - } - else - { - ARMNN_ASSERT_MSG(inputElements % kernelElements == 0 , "Invalid number of elements"); - channels = kernelShape[0]; - } + unsigned int cOutput = kernelShape[0]; TensorShape tensorShape = m_Param.m_DataLayout == armnn::DataLayout::NHWC ? - TensorShape( { batches, hOutput, wOutput, channels } ) : - TensorShape( { batches, channels, hOutput, wOutput }); + TensorShape( { batches, hOutput, wOutput, cOutput } ) : + TensorShape( { batches, cOutput, hOutput, wOutput }); return std::vector<TensorShape>({ tensorShape }); } |