aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/layers/BatchToSpaceNdLayer.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/layers/BatchToSpaceNdLayer.cpp')
-rw-r--r--src/armnn/layers/BatchToSpaceNdLayer.cpp76
1 files changed, 25 insertions, 51 deletions
diff --git a/src/armnn/layers/BatchToSpaceNdLayer.cpp b/src/armnn/layers/BatchToSpaceNdLayer.cpp
index 9accf28137..a168fe8bbd 100644
--- a/src/armnn/layers/BatchToSpaceNdLayer.cpp
+++ b/src/armnn/layers/BatchToSpaceNdLayer.cpp
@@ -16,6 +16,8 @@
#include <DataLayoutIndexed.hpp>
+#include <numeric>
+
using namespace armnnUtils;
namespace armnn
@@ -55,68 +57,40 @@ void BatchToSpaceNdLayer::ValidateTensorShapesFromInputs()
std::vector<TensorShape> BatchToSpaceNdLayer::InferOutputShapes(const std::vector<TensorShape>& inputShapes) const
{
- const DataLayoutIndexed dataLayout = m_Param.m_DataLayout;
- const TensorShape& inputShape = inputShapes[0];
- unsigned int inBatchSize = inputShape[0];
- unsigned int channelSize = inputShape[dataLayout.GetChannelsIndex()];
-
- std::vector<unsigned int> theBlockShape = m_Param.m_BlockShape;
-
- unsigned int overallSize = inBatchSize * inputShape[dataLayout.GetHeightIndex()]
- * inputShape[dataLayout.GetWidthIndex()];
+ BOOST_ASSERT(inputShapes.size() == 1);
- std::vector<std::pair<unsigned int, unsigned int>> crops = m_Param.m_Crops;
-
- std::pair<unsigned int, unsigned int> yCrops = crops[0];
- std::pair<unsigned int, unsigned int> xCrops = crops[1];
-
- unsigned int inputHeight = inputShape[dataLayout.GetHeightIndex()];
- unsigned int outputHeight;
+ const TensorShape& inputShape = inputShapes[0];
+ TensorShape outputShape(inputShape);
- unsigned int yCropsTotal = yCrops.first + yCrops.second;
+ unsigned int accumulatedBlockShape = std::accumulate(m_Param.m_BlockShape.begin(),
+ m_Param.m_BlockShape.end(),
+ 1U,
+ std::multiplies<>());
- BOOST_ASSERT_MSG(yCropsTotal <= inputHeight,
- "BatchToSpaceLayer: Overall height crop should be less than or equal to the input height.");
+ BOOST_ASSERT(inputShape[0] % accumulatedBlockShape == 0);
- unsigned int croppedHeight = inputHeight - yCropsTotal;
+ outputShape[0] = inputShape[0] / accumulatedBlockShape;
- if (theBlockShape.at(0) > 0)
- {
- outputHeight = theBlockShape.at(0) * croppedHeight;
- }
- else
- {
- outputHeight = croppedHeight;
- }
+ DataLayoutIndexed dimensionIndices = m_Param.m_DataLayout;
+ unsigned int heightIndex = dimensionIndices.GetHeightIndex();
+ unsigned int widthIndex = dimensionIndices.GetWidthIndex();
- unsigned int outputWidth;
- unsigned int inputWidth = inputShape[dataLayout.GetWidthIndex()];
+ unsigned int heightCrop = m_Param.m_Crops[0].first + m_Param.m_Crops[0].second;
+ unsigned int widthCrop = m_Param.m_Crops[1].first + m_Param.m_Crops[1].second;
- unsigned int xCropsTotal = xCrops.first + xCrops.second;
+ unsigned int outputHeight = inputShape[heightIndex] * m_Param.m_BlockShape[0];
+ unsigned int outputWidth = inputShape[widthIndex] * m_Param.m_BlockShape[1];
- BOOST_ASSERT_MSG(xCropsTotal <= inputWidth,
- "BatchToSpaceLayer: Overall width crop should be less than or equal to the input width.");
- unsigned int croppedWidth = inputWidth - xCropsTotal;
+ BOOST_ASSERT_MSG(heightCrop <= outputHeight,
+ "BatchToSpaceLayer: Overall height crop should be less than or equal to the uncropped output height.");
- if (theBlockShape.at(1) > 0)
- {
- outputWidth = theBlockShape.at(1) * croppedWidth;
- }
- else
- {
- outputWidth = croppedWidth;
- }
+ BOOST_ASSERT_MSG(widthCrop <= outputWidth,
+ "BatchToSpaceLayer: Overall width crop should be less than or equal to the uncropped output width.");
- unsigned int outputBatchSize = overallSize / (outputHeight * outputWidth);
+ outputShape[heightIndex] = outputHeight - heightCrop;
+ outputShape[widthIndex] = outputWidth - widthCrop;
- if (dataLayout == DataLayout::NHWC)
- {
- return std::vector<TensorShape>({ TensorShape({ outputBatchSize, outputHeight, outputWidth, channelSize }) });
- }
- else
- {
- return std::vector<TensorShape>({ TensorShape({ outputBatchSize, channelSize, outputHeight, outputWidth }) });
- }
+ return std::vector<TensorShape>({ outputShape });
}
void BatchToSpaceNdLayer::Accept(ILayerVisitor& visitor) const