aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/layers/BatchToSpaceNdLayer.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/layers/BatchToSpaceNdLayer.cpp')
-rw-r--r--src/armnn/layers/BatchToSpaceNdLayer.cpp18
1 files changed, 7 insertions, 11 deletions
diff --git a/src/armnn/layers/BatchToSpaceNdLayer.cpp b/src/armnn/layers/BatchToSpaceNdLayer.cpp
index 595ce4a7fe..9366a8710b 100644
--- a/src/armnn/layers/BatchToSpaceNdLayer.cpp
+++ b/src/armnn/layers/BatchToSpaceNdLayer.cpp
@@ -57,23 +57,19 @@ std::vector<TensorShape> BatchToSpaceNdLayer::InferOutputShapes(const std::vecto
std::vector<unsigned int> theBlockShape = m_Param.m_BlockShape;
- unsigned int overallSize = inBatchSize;
+ unsigned int overallSize = inBatchSize * inputShape[dataLayout.GetHeightIndex()]
+ * inputShape[dataLayout.GetWidthIndex()];
- for (unsigned int i = 0; i < theBlockShape.size(); ++i)
- {
- overallSize = overallSize * theBlockShape.at(i);
- }
-
- std::vector<std::vector<unsigned int>> crops = m_Param.m_Crops;
+ std::vector<std::pair<unsigned int, unsigned int>> crops = m_Param.m_Crops;
- std::vector<unsigned int> yCrops = crops[0];
- std::vector<unsigned int> xCrops = crops[1];
+ std::pair<unsigned int, unsigned int> yCrops = crops[0];
+ std::pair<unsigned int, unsigned int> xCrops = crops[1];
unsigned int inputHeight = inputShape[dataLayout.GetHeightIndex()];
- unsigned int outputHeight = theBlockShape.at(0) * (inputHeight - (yCrops[0] + yCrops[1]));
+ unsigned int outputHeight = theBlockShape.at(0) * (inputHeight - (yCrops.first + yCrops.second));
unsigned int inputWidth = inputShape[dataLayout.GetWidthIndex()];
- unsigned int outputWidth = theBlockShape.at(1) * (inputWidth - (xCrops[0] + xCrops[1]));
+ unsigned int outputWidth = theBlockShape.at(1) * (inputWidth - (xCrops.first + xCrops.second));
unsigned int outputBatchSize = overallSize / (outputHeight * outputWidth);