From 95807cef855738ca481ace30f32ed9f245a098dd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=89anna=20=C3=93=20Cath=C3=A1in?= Date: Mon, 12 Nov 2018 17:14:43 +0000 Subject: Tidying up multiple issues * Fixed error in InferOutputShape implementation * Added better error checking to the BatchToSpace implementation. * Added defaults to the batchToSpace descriptors. * Changed crops to be a vector of pairs to align with the SpaceToBatch implementation Change-Id: Ib1c16d871f0898a1caeb6629c1fee6380a773e14 --- src/armnn/layers/BatchToSpaceNdLayer.cpp | 18 +++++++----------- 1 file changed, 7 insertions(+), 11 deletions(-) (limited to 'src/armnn') diff --git a/src/armnn/layers/BatchToSpaceNdLayer.cpp b/src/armnn/layers/BatchToSpaceNdLayer.cpp index 595ce4a7fe..9366a8710b 100644 --- a/src/armnn/layers/BatchToSpaceNdLayer.cpp +++ b/src/armnn/layers/BatchToSpaceNdLayer.cpp @@ -57,23 +57,19 @@ std::vector BatchToSpaceNdLayer::InferOutputShapes(const std::vecto std::vector theBlockShape = m_Param.m_BlockShape; - unsigned int overallSize = inBatchSize; + unsigned int overallSize = inBatchSize * inputShape[dataLayout.GetHeightIndex()] + * inputShape[dataLayout.GetWidthIndex()]; - for (unsigned int i = 0; i < theBlockShape.size(); ++i) - { - overallSize = overallSize * theBlockShape.at(i); - } - - std::vector> crops = m_Param.m_Crops; + std::vector> crops = m_Param.m_Crops; - std::vector yCrops = crops[0]; - std::vector xCrops = crops[1]; + std::pair yCrops = crops[0]; + std::pair xCrops = crops[1]; unsigned int inputHeight = inputShape[dataLayout.GetHeightIndex()]; - unsigned int outputHeight = theBlockShape.at(0) * (inputHeight - (yCrops[0] + yCrops[1])); + unsigned int outputHeight = theBlockShape.at(0) * (inputHeight - (yCrops.first + yCrops.second)); unsigned int inputWidth = inputShape[dataLayout.GetWidthIndex()]; - unsigned int outputWidth = theBlockShape.at(1) * (inputWidth - (xCrops[0] + xCrops[1])); + unsigned int outputWidth = theBlockShape.at(1) * (inputWidth - (xCrops.first + xCrops.second)); unsigned int outputBatchSize = overallSize / (outputHeight * outputWidth); -- cgit v1.2.1