aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/layers/SpaceToBatchNdLayer.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/layers/SpaceToBatchNdLayer.cpp')
-rw-r--r--src/armnn/layers/SpaceToBatchNdLayer.cpp31
1 files changed, 12 insertions, 19 deletions
diff --git a/src/armnn/layers/SpaceToBatchNdLayer.cpp b/src/armnn/layers/SpaceToBatchNdLayer.cpp
index 151b6a5301..a758617e2e 100644
--- a/src/armnn/layers/SpaceToBatchNdLayer.cpp
+++ b/src/armnn/layers/SpaceToBatchNdLayer.cpp
@@ -1,15 +1,11 @@
//
-// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
+// Copyright © 2017,2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#include "SpaceToBatchNdLayer.hpp"
#include "LayerCloneBase.hpp"
-#include <armnn/TypesUtils.hpp>
-
-#include <armnnUtils/DataLayoutIndexed.hpp>
-
#include <armnn/backends/WorkloadData.hpp>
#include <armnn/backends/WorkloadFactory.hpp>
@@ -42,9 +38,7 @@ SpaceToBatchNdLayer* SpaceToBatchNdLayer::Clone(Graph& graph) const
std::vector<TensorShape> SpaceToBatchNdLayer::InferOutputShapes(const std::vector<TensorShape>& inputShapes) const
{
- ARMNN_ASSERT(inputShapes.size() == 1);
-
- TensorShape inputShape = inputShapes[0];
+ const TensorShape inputShape = inputShapes[0];
TensorShape outputShape(inputShape);
outputShape[0] = inputShape[0] * std::accumulate(m_Param.m_BlockShape.begin(),
@@ -52,17 +46,16 @@ std::vector<TensorShape> SpaceToBatchNdLayer::InferOutputShapes(const std::vecto
1U,
std::multiplies<>());
- DataLayoutIndexed dimensionIndices = m_Param.m_DataLayout;
- unsigned int heightIndex = dimensionIndices.GetHeightIndex();
- unsigned int widthIndex = dimensionIndices.GetWidthIndex();
-
- std::pair<unsigned int, unsigned int> heightPad = m_Param.m_PadList[0];
- std::pair<unsigned int, unsigned int> widthPad = m_Param.m_PadList[1];
-
- outputShape[heightIndex] =
- (inputShape[heightIndex] + heightPad.first + heightPad.second) / m_Param.m_BlockShape[0];
- outputShape[widthIndex] =
- (inputShape[widthIndex] + widthPad.first + widthPad.second) / m_Param.m_BlockShape[1];
+ // In a 4D tensor, there will be 2 spatialDimensions (H and W), and the for loop will run twice.
+ // In a 3D tensor, there will be 1 spatialDimensions, and the for loop will run once.
+ unsigned int firstSpatialDimension = m_Param.m_DataLayout == DataLayout::NCHW ? 2 : 1;
+ for (unsigned int i = 0; i < m_Param.m_BlockShape.size(); ++i)
+ {
+ unsigned int spatialDimension = firstSpatialDimension + i;
+ outputShape[spatialDimension] =
+ (inputShape[spatialDimension] + m_Param.m_PadList[i].first + m_Param.m_PadList[i].second)
+ / m_Param.m_BlockShape[i];
+ }
return std::vector<TensorShape>({ outputShape });
}