aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/layers/BatchToSpaceNdLayer.cpp
diff options
context:
space:
mode:
authorÉanna Ó Catháin <eanna.ocathain@arm.com>2018-11-12 17:14:43 +0000
committerSaoirse Stewart Arm <saoirse.stewart@arm.com>2018-11-12 21:59:48 +0000
commit95807cef855738ca481ace30f32ed9f245a098dd (patch)
tree7df7181d6dc19f3db3054614076478af4f417e8a /src/armnn/layers/BatchToSpaceNdLayer.cpp
parent111b5d94d7e854c21377f8d2c0b4234317a903f6 (diff)
downloadarmnn-95807cef855738ca481ace30f32ed9f245a098dd.tar.gz
Tidying up multiple issues
* Fixed error in InferOutputShape implementation * Added better error checking to the BatchToSpace implementation. * Added defaults to the batchToSpace descriptors. * Changed crops to be a vector of pairs to align with the SpaceToBatch implementation Change-Id: Ib1c16d871f0898a1caeb6629c1fee6380a773e14
Diffstat (limited to 'src/armnn/layers/BatchToSpaceNdLayer.cpp')
-rw-r--r--src/armnn/layers/BatchToSpaceNdLayer.cpp18
1 files changed, 7 insertions, 11 deletions
diff --git a/src/armnn/layers/BatchToSpaceNdLayer.cpp b/src/armnn/layers/BatchToSpaceNdLayer.cpp
index 595ce4a7fe..9366a8710b 100644
--- a/src/armnn/layers/BatchToSpaceNdLayer.cpp
+++ b/src/armnn/layers/BatchToSpaceNdLayer.cpp
@@ -57,23 +57,19 @@ std::vector<TensorShape> BatchToSpaceNdLayer::InferOutputShapes(const std::vecto
std::vector<unsigned int> theBlockShape = m_Param.m_BlockShape;
- unsigned int overallSize = inBatchSize;
+ unsigned int overallSize = inBatchSize * inputShape[dataLayout.GetHeightIndex()]
+ * inputShape[dataLayout.GetWidthIndex()];
- for (unsigned int i = 0; i < theBlockShape.size(); ++i)
- {
- overallSize = overallSize * theBlockShape.at(i);
- }
-
- std::vector<std::vector<unsigned int>> crops = m_Param.m_Crops;
+ std::vector<std::pair<unsigned int, unsigned int>> crops = m_Param.m_Crops;
- std::vector<unsigned int> yCrops = crops[0];
- std::vector<unsigned int> xCrops = crops[1];
+ std::pair<unsigned int, unsigned int> yCrops = crops[0];
+ std::pair<unsigned int, unsigned int> xCrops = crops[1];
unsigned int inputHeight = inputShape[dataLayout.GetHeightIndex()];
- unsigned int outputHeight = theBlockShape.at(0) * (inputHeight - (yCrops[0] + yCrops[1]));
+ unsigned int outputHeight = theBlockShape.at(0) * (inputHeight - (yCrops.first + yCrops.second));
unsigned int inputWidth = inputShape[dataLayout.GetWidthIndex()];
- unsigned int outputWidth = theBlockShape.at(1) * (inputWidth - (xCrops[0] + xCrops[1]));
+ unsigned int outputWidth = theBlockShape.at(1) * (inputWidth - (xCrops.first + xCrops.second));
unsigned int outputBatchSize = overallSize / (outputHeight * outputWidth);