aboutsummaryrefslogtreecommitdiff
path: root/src/armnn
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn')
-rw-r--r--src/armnn/layers/BatchToSpaceNdLayer.cpp76
-rw-r--r--src/armnn/test/LayerValidateOutputTest.cpp10
2 files changed, 30 insertions, 56 deletions
diff --git a/src/armnn/layers/BatchToSpaceNdLayer.cpp b/src/armnn/layers/BatchToSpaceNdLayer.cpp
index 9accf28137..a168fe8bbd 100644
--- a/src/armnn/layers/BatchToSpaceNdLayer.cpp
+++ b/src/armnn/layers/BatchToSpaceNdLayer.cpp
@@ -16,6 +16,8 @@
#include <DataLayoutIndexed.hpp>
+#include <numeric>
+
using namespace armnnUtils;
namespace armnn
@@ -55,68 +57,40 @@ void BatchToSpaceNdLayer::ValidateTensorShapesFromInputs()
std::vector<TensorShape> BatchToSpaceNdLayer::InferOutputShapes(const std::vector<TensorShape>& inputShapes) const
{
- const DataLayoutIndexed dataLayout = m_Param.m_DataLayout;
- const TensorShape& inputShape = inputShapes[0];
- unsigned int inBatchSize = inputShape[0];
- unsigned int channelSize = inputShape[dataLayout.GetChannelsIndex()];
-
- std::vector<unsigned int> theBlockShape = m_Param.m_BlockShape;
-
- unsigned int overallSize = inBatchSize * inputShape[dataLayout.GetHeightIndex()]
- * inputShape[dataLayout.GetWidthIndex()];
+ BOOST_ASSERT(inputShapes.size() == 1);
- std::vector<std::pair<unsigned int, unsigned int>> crops = m_Param.m_Crops;
-
- std::pair<unsigned int, unsigned int> yCrops = crops[0];
- std::pair<unsigned int, unsigned int> xCrops = crops[1];
-
- unsigned int inputHeight = inputShape[dataLayout.GetHeightIndex()];
- unsigned int outputHeight;
+ const TensorShape& inputShape = inputShapes[0];
+ TensorShape outputShape(inputShape);
- unsigned int yCropsTotal = yCrops.first + yCrops.second;
+ unsigned int accumulatedBlockShape = std::accumulate(m_Param.m_BlockShape.begin(),
+ m_Param.m_BlockShape.end(),
+ 1U,
+ std::multiplies<>());
- BOOST_ASSERT_MSG(yCropsTotal <= inputHeight,
- "BatchToSpaceLayer: Overall height crop should be less than or equal to the input height.");
+ BOOST_ASSERT(inputShape[0] % accumulatedBlockShape == 0);
- unsigned int croppedHeight = inputHeight - yCropsTotal;
+ outputShape[0] = inputShape[0] / accumulatedBlockShape;
- if (theBlockShape.at(0) > 0)
- {
- outputHeight = theBlockShape.at(0) * croppedHeight;
- }
- else
- {
- outputHeight = croppedHeight;
- }
+ DataLayoutIndexed dimensionIndices = m_Param.m_DataLayout;
+ unsigned int heightIndex = dimensionIndices.GetHeightIndex();
+ unsigned int widthIndex = dimensionIndices.GetWidthIndex();
- unsigned int outputWidth;
- unsigned int inputWidth = inputShape[dataLayout.GetWidthIndex()];
+ unsigned int heightCrop = m_Param.m_Crops[0].first + m_Param.m_Crops[0].second;
+ unsigned int widthCrop = m_Param.m_Crops[1].first + m_Param.m_Crops[1].second;
- unsigned int xCropsTotal = xCrops.first + xCrops.second;
+ unsigned int outputHeight = inputShape[heightIndex] * m_Param.m_BlockShape[0];
+ unsigned int outputWidth = inputShape[widthIndex] * m_Param.m_BlockShape[1];
- BOOST_ASSERT_MSG(xCropsTotal <= inputWidth,
- "BatchToSpaceLayer: Overall width crop should be less than or equal to the input width.");
- unsigned int croppedWidth = inputWidth - xCropsTotal;
+ BOOST_ASSERT_MSG(heightCrop <= outputHeight,
+ "BatchToSpaceLayer: Overall height crop should be less than or equal to the uncropped output height.");
- if (theBlockShape.at(1) > 0)
- {
- outputWidth = theBlockShape.at(1) * croppedWidth;
- }
- else
- {
- outputWidth = croppedWidth;
- }
+ BOOST_ASSERT_MSG(widthCrop <= outputWidth,
+ "BatchToSpaceLayer: Overall width crop should be less than or equal to the uncropped output width.");
- unsigned int outputBatchSize = overallSize / (outputHeight * outputWidth);
+ outputShape[heightIndex] = outputHeight - heightCrop;
+ outputShape[widthIndex] = outputWidth - widthCrop;
- if (dataLayout == DataLayout::NHWC)
- {
- return std::vector<TensorShape>({ TensorShape({ outputBatchSize, outputHeight, outputWidth, channelSize }) });
- }
- else
- {
- return std::vector<TensorShape>({ TensorShape({ outputBatchSize, channelSize, outputHeight, outputWidth }) });
- }
+ return std::vector<TensorShape>({ outputShape });
}
void BatchToSpaceNdLayer::Accept(ILayerVisitor& visitor) const
diff --git a/src/armnn/test/LayerValidateOutputTest.cpp b/src/armnn/test/LayerValidateOutputTest.cpp
index 62b9c4a0d8..999844e252 100644
--- a/src/armnn/test/LayerValidateOutputTest.cpp
+++ b/src/armnn/test/LayerValidateOutputTest.cpp
@@ -17,22 +17,22 @@ BOOST_AUTO_TEST_CASE(TestBatchToSpaceInferOutputShape)
armnn::Graph graph;
armnn::BatchToSpaceNdDescriptor descriptor;
- std::vector<unsigned int> theBlockShape = {2, 2};
- descriptor.m_BlockShape = theBlockShape;
+ descriptor.m_BlockShape = {2, 2};
+ descriptor.m_Crops = {{0, 0}, {2, 0}};
descriptor.m_DataLayout = armnn::DataLayout::NHWC;
armnn::BatchToSpaceNdLayer* const batchToSpaceLayer =
graph.AddLayer<armnn::BatchToSpaceNdLayer>(descriptor, "batchToSpace");
std::vector<armnn::TensorShape> shapes;
- const std::vector<unsigned int> theDimSizes = {4, 2, 2, 1};
+ const std::vector<unsigned int> theDimSizes = {8, 1, 3, 1};
armnn::TensorShape shape(4, theDimSizes.data());
shapes.push_back(shape);
- const std::vector<unsigned int> expectedDimSizes = {1, 4, 4, 1};
+ const std::vector<unsigned int> expectedDimSizes = {2, 2, 4, 1};
armnn::TensorShape expectedShape(4, expectedDimSizes.data());
BOOST_CHECK(expectedShape == batchToSpaceLayer->InferOutputShapes(shapes).at(0));
}
-BOOST_AUTO_TEST_SUITE_END() \ No newline at end of file
+BOOST_AUTO_TEST_SUITE_END()