aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/layers/BatchToSpaceNdLayer.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/layers/BatchToSpaceNdLayer.cpp')
-rw-r--r--src/armnn/layers/BatchToSpaceNdLayer.cpp35
1 files changed, 33 insertions, 2 deletions
diff --git a/src/armnn/layers/BatchToSpaceNdLayer.cpp b/src/armnn/layers/BatchToSpaceNdLayer.cpp
index 9366a8710b..aff818e664 100644
--- a/src/armnn/layers/BatchToSpaceNdLayer.cpp
+++ b/src/armnn/layers/BatchToSpaceNdLayer.cpp
@@ -66,10 +66,41 @@ std::vector<TensorShape> BatchToSpaceNdLayer::InferOutputShapes(const std::vecto
std::pair<unsigned int, unsigned int> xCrops = crops[1];
unsigned int inputHeight = inputShape[dataLayout.GetHeightIndex()];
- unsigned int outputHeight = theBlockShape.at(0) * (inputHeight - (yCrops.first + yCrops.second));
+ unsigned int outputHeight;
+ unsigned int yCropsTotal = yCrops.first + yCrops.second;
+
+ BOOST_ASSERT_MSG(yCropsTotal <= inputHeight,
+ "BatchToSpaceLayer: Overall height crop should be less than or equal to the input height.");
+
+ unsigned int croppedHeight = inputHeight - yCropsTotal;
+
+ if (theBlockShape.at(0) > 0)
+ {
+ outputHeight = theBlockShape.at(0) * croppedHeight;
+ }
+ else
+ {
+ outputHeight = croppedHeight;
+ }
+
+ unsigned int outputWidth;
unsigned int inputWidth = inputShape[dataLayout.GetWidthIndex()];
- unsigned int outputWidth = theBlockShape.at(1) * (inputWidth - (xCrops.first + xCrops.second));
+
+ unsigned int xCropsTotal = xCrops.first + xCrops.second;
+
+ BOOST_ASSERT_MSG(xCropsTotal <= inputWidth,
+ "BatchToSpaceLayer: Overall width crop should be less than or equal to the input width.");
+ unsigned int croppedWidth = inputWidth - xCropsTotal;
+
+ if (theBlockShape.at(1) > 0)
+ {
+ outputWidth = theBlockShape.at(1) * croppedWidth;
+ }
+ else
+ {
+ outputWidth = croppedWidth;
+ }
unsigned int outputBatchSize = overallSize / (outputHeight * outputWidth);