aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/layers/Convolution2dLayer.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/layers/Convolution2dLayer.cpp')
-rw-r--r--src/armnn/layers/Convolution2dLayer.cpp20
1 files changed, 13 insertions, 7 deletions
diff --git a/src/armnn/layers/Convolution2dLayer.cpp b/src/armnn/layers/Convolution2dLayer.cpp
index d4b67cca3f..d611aedc06 100644
--- a/src/armnn/layers/Convolution2dLayer.cpp
+++ b/src/armnn/layers/Convolution2dLayer.cpp
@@ -58,22 +58,28 @@ std::vector<TensorShape> Convolution2dLayer::InferOutputShapes(const std::vector
// If we support multiple batch dimensions in the future, then this assert will need to change.
BOOST_ASSERT_MSG(inputShape.GetNumDimensions() == 4, "Convolutions will always have 4D input.");
- unsigned int inWidth = inputShape[3];
- unsigned int inHeight = inputShape[2];
+ DataLayoutIndexed dataLayoutIndex(m_Param.m_DataLayout);
+
+ unsigned int inWidth = inputShape[dataLayoutIndex.GetWidthIndex()];
+ unsigned int inHeight = inputShape[dataLayoutIndex.GetHeightIndex()];
unsigned int inBatchSize = inputShape[0];
- unsigned int filterWidth = filterShape[3];
+ unsigned int filterWidth = filterShape[dataLayoutIndex.GetWidthIndex()];
unsigned int readWidth = (inWidth + m_Param.m_PadLeft + m_Param.m_PadRight) - (filterWidth);
- unsigned int outWidth = 1+(readWidth / m_Param.m_StrideX);
+ unsigned int outWidth = 1 + (readWidth / m_Param.m_StrideX);
- unsigned int filterHeight = filterShape[2];
+ unsigned int filterHeight = filterShape[dataLayoutIndex.GetHeightIndex()];
unsigned int readHeight = (inHeight + m_Param.m_PadTop + m_Param.m_PadBottom) - (filterHeight);
- unsigned int outHeight = 1+(readHeight / m_Param.m_StrideY);
+ unsigned int outHeight = 1 + (readHeight / m_Param.m_StrideY);
unsigned int outChannels = filterShape[0];
unsigned int outBatchSize = inBatchSize;
- return std::vector<TensorShape>({ TensorShape({outBatchSize, outChannels, outHeight, outWidth})});
+ TensorShape tensorShape = m_Param.m_DataLayout == armnn::DataLayout::NHWC ?
+ TensorShape( { outBatchSize, outHeight, outWidth, outChannels } ) :
+ TensorShape( { outBatchSize, outChannels, outHeight, outWidth });
+
+ return std::vector<TensorShape>({ tensorShape });
}
void Convolution2dLayer::ValidateTensorShapesFromInputs()