diff options
author | James Conroy <james.conroy@arm.com> | 2020-12-07 16:59:03 +0000 |
---|---|---|
committer | Jim Flynn <jim.flynn@arm.com> | 2020-12-08 11:50:02 +0000 |
commit | f24375df4087b7d39062d8c46b190e7abea4bc9c (patch) | |
tree | 52a54d9bf3e6da78d5a31cc91cc602717ffb564e /src/armnn/layers/TransposeConvolution2dLayer.cpp | |
parent | 6249d7e5f74323d2322fd69409db616efe46f8c8 (diff) | |
download | armnn-f24375df4087b7d39062d8c46b190e7abea4bc9c.tar.gz |
IVGCVSW-5500 Fix transpose conv InferOutputShape
* Use kernelShape[0] as channels for outputShape.
Signed-off-by: James Conroy <james.conroy@arm.com>
Change-Id: I070c7ff68ae365d9505a5eb28c76f9e52da1e5f9
Diffstat (limited to 'src/armnn/layers/TransposeConvolution2dLayer.cpp')
-rw-r--r-- | src/armnn/layers/TransposeConvolution2dLayer.cpp | 22 |
1 files changed, 3 insertions, 19 deletions
diff --git a/src/armnn/layers/TransposeConvolution2dLayer.cpp b/src/armnn/layers/TransposeConvolution2dLayer.cpp index 1591213d9d..189e5f6168 100644 --- a/src/armnn/layers/TransposeConvolution2dLayer.cpp +++ b/src/armnn/layers/TransposeConvolution2dLayer.cpp @@ -78,27 +78,11 @@ std::vector<TensorShape> TransposeConvolution2dLayer::InferOutputShapes( unsigned int wOutput = (wInput - 1) * m_Param.m_StrideX + wKernel - wPadding; unsigned int hOutput = (hInput - 1) * m_Param.m_StrideY + hKernel - hPadding; - - unsigned int kernelElements = kernelShape[0] * kernelShape[dataLayoutIndex.GetChannelsIndex()]; - unsigned int inputElements = batches * inputShape[dataLayoutIndex.GetChannelsIndex()]; - - ARMNN_ASSERT_MSG(inputElements != 0, "Invalid number of input elements"); - - unsigned int channels; - if (kernelElements >= inputElements) - { - ARMNN_ASSERT_MSG(kernelElements % inputElements == 0 , "Invalid number of elements"); - channels = kernelElements / inputElements; - } - else - { - ARMNN_ASSERT_MSG(inputElements % kernelElements == 0 , "Invalid number of elements"); - channels = kernelShape[0]; - } + unsigned int cOutput = kernelShape[0]; TensorShape tensorShape = m_Param.m_DataLayout == armnn::DataLayout::NHWC ? - TensorShape( { batches, hOutput, wOutput, channels } ) : - TensorShape( { batches, channels, hOutput, wOutput }); + TensorShape( { batches, hOutput, wOutput, cOutput } ) : + TensorShape( { batches, cOutput, hOutput, wOutput }); return std::vector<TensorShape>({ tensorShape }); } |