aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJames Conroy <james.conroy@arm.com>2020-12-07 16:59:03 +0000
committerJim Flynn <jim.flynn@arm.com>2020-12-08 11:50:02 +0000
commitf24375df4087b7d39062d8c46b190e7abea4bc9c (patch)
tree52a54d9bf3e6da78d5a31cc91cc602717ffb564e
parent6249d7e5f74323d2322fd69409db616efe46f8c8 (diff)
downloadarmnn-f24375df4087b7d39062d8c46b190e7abea4bc9c.tar.gz
IVGCVSW-5500 Fix transpose conv InferOutputShape
* Use kernelShape[0] as channels for outputShape. Signed-off-by: James Conroy <james.conroy@arm.com> Change-Id: I070c7ff68ae365d9505a5eb28c76f9e52da1e5f9
-rw-r--r--src/armnn/layers/TransposeConvolution2dLayer.cpp22
1 files changed, 3 insertions, 19 deletions
diff --git a/src/armnn/layers/TransposeConvolution2dLayer.cpp b/src/armnn/layers/TransposeConvolution2dLayer.cpp
index 1591213d9d..189e5f6168 100644
--- a/src/armnn/layers/TransposeConvolution2dLayer.cpp
+++ b/src/armnn/layers/TransposeConvolution2dLayer.cpp
@@ -78,27 +78,11 @@ std::vector<TensorShape> TransposeConvolution2dLayer::InferOutputShapes(
unsigned int wOutput = (wInput - 1) * m_Param.m_StrideX + wKernel - wPadding;
unsigned int hOutput = (hInput - 1) * m_Param.m_StrideY + hKernel - hPadding;
-
- unsigned int kernelElements = kernelShape[0] * kernelShape[dataLayoutIndex.GetChannelsIndex()];
- unsigned int inputElements = batches * inputShape[dataLayoutIndex.GetChannelsIndex()];
-
- ARMNN_ASSERT_MSG(inputElements != 0, "Invalid number of input elements");
-
- unsigned int channels;
- if (kernelElements >= inputElements)
- {
- ARMNN_ASSERT_MSG(kernelElements % inputElements == 0 , "Invalid number of elements");
- channels = kernelElements / inputElements;
- }
- else
- {
- ARMNN_ASSERT_MSG(inputElements % kernelElements == 0 , "Invalid number of elements");
- channels = kernelShape[0];
- }
+ unsigned int cOutput = kernelShape[0];
TensorShape tensorShape = m_Param.m_DataLayout == armnn::DataLayout::NHWC ?
- TensorShape( { batches, hOutput, wOutput, channels } ) :
- TensorShape( { batches, channels, hOutput, wOutput });
+ TensorShape( { batches, hOutput, wOutput, cOutput } ) :
+ TensorShape( { batches, cOutput, hOutput, wOutput });
return std::vector<TensorShape>({ tensorShape });
}