diff options
Diffstat (limited to 'src/armnn')
-rw-r--r-- | src/armnn/layers/TransposeConvolution2dLayer.cpp | 23 |
1 files changed, 14 insertions, 9 deletions
diff --git a/src/armnn/layers/TransposeConvolution2dLayer.cpp b/src/armnn/layers/TransposeConvolution2dLayer.cpp index f79c5887fb..534d6b431e 100644 --- a/src/armnn/layers/TransposeConvolution2dLayer.cpp +++ b/src/armnn/layers/TransposeConvolution2dLayer.cpp @@ -1,5 +1,5 @@ // -// Copyright © 2017 Arm Ltd and Contributors. All rights reserved. +// Copyright © 2017,2023 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // @@ -98,20 +98,25 @@ void TransposeConvolution2dLayer::ValidateTensorShapesFromInputs() ARMNN_ASSERT_MSG(m_Weight != nullptr, "TransposeConvolution2dLayer: Weight data cannot be null."); std::vector<TensorShape> expectedOutputShape; + std::vector<TensorShape> outputShapeGivenAsInput; + + expectedOutputShape = InferOutputShapes({GetInputSlot(0).GetConnection()->GetTensorInfo().GetShape(), + m_Weight->GetTensorInfo().GetShape() }); + + ARMNN_ASSERT(expectedOutputShape.size() == 1); + // If output_shape was specified then use it rather than calculate an inferred output shape. if (m_Param.m_OutputShapeEnabled) { TensorShape shapeAsTensorShape(static_cast<unsigned int>(m_Param.m_OutputShape.size()), m_Param.m_OutputShape.data()); - expectedOutputShape.push_back(shapeAsTensorShape); - } - else - { - expectedOutputShape = InferOutputShapes({GetInputSlot(0).GetConnection()->GetTensorInfo().GetShape(), - m_Weight->GetTensorInfo().GetShape() }); - } + outputShapeGivenAsInput.push_back(shapeAsTensorShape); - ARMNN_ASSERT(expectedOutputShape.size() == 1); + ARMNN_ASSERT(outputShapeGivenAsInput.size() == 1); + ARMNN_ASSERT_MSG(expectedOutputShape == outputShapeGivenAsInput, + "TransposeConvolution2dLayer: output calculated by InferOutputShapes and " + "the output given as an input parameter to the layer are not matching"); + } ValidateAndCopyShape(outputShape, expectedOutputShape[0], m_ShapeInferenceMethod, "TransposeConvolution2dLayer"); } |