diff options
Diffstat (limited to 'src/armnn/layers/BatchToSpaceNdLayer.cpp')
-rw-r--r-- | src/armnn/layers/BatchToSpaceNdLayer.cpp | 18 |
1 files changed, 7 insertions, 11 deletions
diff --git a/src/armnn/layers/BatchToSpaceNdLayer.cpp b/src/armnn/layers/BatchToSpaceNdLayer.cpp index 595ce4a7fe..9366a8710b 100644 --- a/src/armnn/layers/BatchToSpaceNdLayer.cpp +++ b/src/armnn/layers/BatchToSpaceNdLayer.cpp @@ -57,23 +57,19 @@ std::vector<TensorShape> BatchToSpaceNdLayer::InferOutputShapes(const std::vecto std::vector<unsigned int> theBlockShape = m_Param.m_BlockShape; - unsigned int overallSize = inBatchSize; + unsigned int overallSize = inBatchSize * inputShape[dataLayout.GetHeightIndex()] + * inputShape[dataLayout.GetWidthIndex()]; - for (unsigned int i = 0; i < theBlockShape.size(); ++i) - { - overallSize = overallSize * theBlockShape.at(i); - } - - std::vector<std::vector<unsigned int>> crops = m_Param.m_Crops; + std::vector<std::pair<unsigned int, unsigned int>> crops = m_Param.m_Crops; - std::vector<unsigned int> yCrops = crops[0]; - std::vector<unsigned int> xCrops = crops[1]; + std::pair<unsigned int, unsigned int> yCrops = crops[0]; + std::pair<unsigned int, unsigned int> xCrops = crops[1]; unsigned int inputHeight = inputShape[dataLayout.GetHeightIndex()]; - unsigned int outputHeight = theBlockShape.at(0) * (inputHeight - (yCrops[0] + yCrops[1])); + unsigned int outputHeight = theBlockShape.at(0) * (inputHeight - (yCrops.first + yCrops.second)); unsigned int inputWidth = inputShape[dataLayout.GetWidthIndex()]; - unsigned int outputWidth = theBlockShape.at(1) * (inputWidth - (xCrops[0] + xCrops[1])); + unsigned int outputWidth = theBlockShape.at(1) * (inputWidth - (xCrops.first + xCrops.second)); unsigned int outputBatchSize = overallSize / (outputHeight * outputWidth); |