aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/layers/Pooling2dLayer.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/layers/Pooling2dLayer.cpp')
-rw-r--r--src/armnn/layers/Pooling2dLayer.cpp13
1 files changed, 8 insertions, 5 deletions
diff --git a/src/armnn/layers/Pooling2dLayer.cpp b/src/armnn/layers/Pooling2dLayer.cpp
index d87ad0f19f..779ac2041e 100644
--- a/src/armnn/layers/Pooling2dLayer.cpp
+++ b/src/armnn/layers/Pooling2dLayer.cpp
@@ -37,10 +37,9 @@ std::vector<TensorShape> Pooling2dLayer::InferOutputShapes(const std::vector<Ten
// If we support multiple batch dimensions in the future, then this assert will need to change.
BOOST_ASSERT_MSG(inputShape.GetNumDimensions() == 4, "Pooling2dLayer will always have 4D input.");
-
- unsigned int inWidth = inputShape[3];
- unsigned int inHeight = inputShape[2];
- unsigned int inChannels = inputShape[1];
+ unsigned int inWidth = inputShape[m_Param.m_DataLayout.GetWidthIndex()];
+ unsigned int inHeight = inputShape[m_Param.m_DataLayout.GetHeightIndex()];
+ unsigned int inChannels = inputShape[m_Param.m_DataLayout.GetChannelsIndex()];
unsigned int inBatchSize = inputShape[0];
bool isGlobalPooling = (m_Param.m_StrideX==0 && m_Param.m_StrideY==0);
@@ -88,7 +87,11 @@ std::vector<TensorShape> Pooling2dLayer::InferOutputShapes(const std::vector<Ten
unsigned int outChannels = inChannels;
unsigned int outBatchSize = inBatchSize;
- return std::vector<TensorShape>({ TensorShape({outBatchSize, outChannels, outHeight, outWidth}) });
+ TensorShape tensorShape = m_Param.m_DataLayout == armnn::DataLayout::NHWC ?
+ TensorShape( { outBatchSize, outHeight, outWidth, outChannels } ) :
+ TensorShape( { outBatchSize, outChannels, outHeight, outWidth });
+
+ return std::vector<TensorShape>({ tensorShape });
}
void Pooling2dLayer::ValidateTensorShapesFromInputs()