diff options
Diffstat (limited to 'src/armnn/layers/Pooling2dLayer.cpp')
-rw-r--r-- | src/armnn/layers/Pooling2dLayer.cpp | 30 |
1 files changed, 16 insertions, 14 deletions
diff --git a/src/armnn/layers/Pooling2dLayer.cpp b/src/armnn/layers/Pooling2dLayer.cpp index ede37d7604..68049101e7 100644 --- a/src/armnn/layers/Pooling2dLayer.cpp +++ b/src/armnn/layers/Pooling2dLayer.cpp @@ -29,15 +29,10 @@ Pooling2dLayer* Pooling2dLayer::Clone(Graph& graph) const return CloneBase<Pooling2dLayer>(graph, m_Param, GetName()); } -void Pooling2dLayer::ValidateTensorShapesFromInputs() +std::vector<TensorShape> Pooling2dLayer::InferOutputShapes(const std::vector<TensorShape>& inputShapes) const { - ConditionalThrow<LayerValidationException>(GetInputSlot(0).GetConnection() != nullptr, - "Pooling2dLayer: InputSlot must be connected to an OutputSlot"); - ConditionalThrow<LayerValidationException>(GetInputSlot(0).GetConnection()->IsTensorInfoSet(), - "Pooling2dLayer: TensorInfo must be set on connected InputSlot."); - - IOutputSlot* input = GetInputSlot(0).GetConnection(); - const TensorShape& inputShape = input->GetTensorInfo().GetShape(); + BOOST_ASSERT(inputShapes.size() == 1); + const TensorShape& inputShape = inputShapes[0]; // If we support multiple batch dimensions in the future, then this assert will need to change. BOOST_ASSERT_MSG(inputShape.GetNumDimensions() == 4, "Pooling2dLayer will always have 4D input."); @@ -75,8 +70,8 @@ void Pooling2dLayer::ValidateTensorShapesFromInputs() BOOST_ASSERT_MSG(false, "Unsupported Output Shape Rounding"); } - // Make sure that border operations will start from inside the input and not the padded area - // This is what both Caffe and CL does... + // MakeS sure that border operations will start from inside the input and not the padded area. + // This is what both Caffe and CL do... if ((size - 1)*stride >= inSize + lowPad) { --size; @@ -89,18 +84,25 @@ void Pooling2dLayer::ValidateTensorShapesFromInputs() m_Param.m_PaddingMethod, m_Param.m_OutputShapeRounding); outHeight= CalcSize(inHeight, m_Param.m_PadTop, m_Param.m_PadBottom, m_Param.m_PoolHeight, m_Param.m_StrideY, m_Param.m_PaddingMethod, m_Param.m_OutputShapeRounding); - - } unsigned int outChannels = inChannels; unsigned int outBatchSize = inBatchSize; - TensorShape shapeOut({outBatchSize, outChannels, outHeight, outWidth}); + return std::vector<TensorShape>({ TensorShape({outBatchSize, outChannels, outHeight, outWidth}) }); +} + +void Pooling2dLayer::ValidateTensorShapesFromInputs() +{ + VerifyLayerConnections(1, CHECK_LOCATION()); + + auto inferredShapes = InferOutputShapes({ GetInputSlot(0).GetConnection()->GetTensorInfo().GetShape() }); + + BOOST_ASSERT(inferredShapes.size() == 1); ConditionalThrowIfNotEqual<LayerValidationException>( "Pooling2dLayer: TensorShape set on OutputSlot[0] does not match the inferred shape.", GetOutputSlot(0).GetTensorInfo().GetShape(), - shapeOut); + inferredShapes[0]); } } // namespace armnn |