aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/layers/TransposeConvolution2dLayer.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/layers/TransposeConvolution2dLayer.cpp')
-rw-r--r--src/armnn/layers/TransposeConvolution2dLayer.cpp20
1 files changed, 15 insertions, 5 deletions
diff --git a/src/armnn/layers/TransposeConvolution2dLayer.cpp b/src/armnn/layers/TransposeConvolution2dLayer.cpp
index ffe92bbbd2..8a264253e0 100644
--- a/src/armnn/layers/TransposeConvolution2dLayer.cpp
+++ b/src/armnn/layers/TransposeConvolution2dLayer.cpp
@@ -111,16 +111,26 @@ void TransposeConvolution2dLayer::ValidateTensorShapesFromInputs(ShapeInferenceM
ARMNN_ASSERT_MSG(m_Weight != nullptr, "TransposeConvolution2dLayer: Weight data cannot be null.");
- auto inferredShapes = InferOutputShapes({
- GetInputSlot(0).GetConnection()->GetTensorInfo().GetShape(),
- m_Weight->GetTensorInfo().GetShape() });
+ std::vector<TensorShape> expectedOutputShape;
+ // If output_shape was specified then use it rather than calculate an inferred output shape.
+ if (m_Param.m_OutputShapeEnabled)
+ {
+ TensorShape shapeAsTensorShape(static_cast<unsigned int>(m_Param.m_OutputShape.size()),
+ m_Param.m_OutputShape.data());
+ expectedOutputShape.push_back(shapeAsTensorShape);
+ }
+ else
+ {
+ expectedOutputShape = InferOutputShapes({GetInputSlot(0).GetConnection()->GetTensorInfo().GetShape(),
+ m_Weight->GetTensorInfo().GetShape() });
+ }
- ARMNN_ASSERT(inferredShapes.size() == 1);
+ ARMNN_ASSERT(expectedOutputShape.size() == 1);
ConditionalThrowIfNotEqual<LayerValidationException>(
"TransposeConvolution2dLayer: TensorShape set on OutputSlot[0] does not match the inferred shape.",
GetOutputSlot(0).GetTensorInfo().GetShape(),
- inferredShapes[0]);
+ expectedOutputShape[0]);
}
Layer::ConstantTensors TransposeConvolution2dLayer::GetConstantTensorsByRef()