aboutsummaryrefslogtreecommitdiff
path: root/OutputShapeUtils.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'OutputShapeUtils.cpp')
-rw-r--r--OutputShapeUtils.cpp24
1 files changed, 24 insertions, 0 deletions
diff --git a/OutputShapeUtils.cpp b/OutputShapeUtils.cpp
index 0c897d11..ecec0b92 100644
--- a/OutputShapeUtils.cpp
+++ b/OutputShapeUtils.cpp
@@ -8,6 +8,7 @@
#include <DataLayoutIndexed.hpp>
#include <algorithm>
+#include <numeric>
#include <vector>
namespace
@@ -167,6 +168,29 @@ TensorShape InferResizeOutputShape(const TensorShape& inputShape, const ResizeDe
return outputShape;
}
+TensorShape InferSpaceToDepthOutputShape(const TensorShape& inputShape, const SpaceToDepthDescriptor& descriptor)
+{
+ TensorShape outputShape(inputShape);
+
+ armnnUtils::DataLayoutIndexed dataLayoutIndexed(descriptor.m_DataLayout);
+
+ const unsigned int cIndex = dataLayoutIndexed.GetChannelsIndex();
+ const unsigned int wIndex = dataLayoutIndexed.GetWidthIndex();
+ const unsigned int hIndex = dataLayoutIndexed.GetHeightIndex();
+
+ if (descriptor.m_BlockSize == 0)
+ {
+ throw InvalidArgumentException("Block size must be greater than zero");
+ }
+
+ outputShape[cIndex] = inputShape[cIndex] * descriptor.m_BlockSize * descriptor.m_BlockSize;
+
+ outputShape[hIndex] = inputShape[hIndex] / descriptor.m_BlockSize;
+ outputShape[wIndex] = inputShape[wIndex] / descriptor.m_BlockSize;
+
+ return outputShape;
+}
+
TensorShape InferSubOutputShape(const TensorShape& input0Shape, const TensorShape& input1Shape)
{
return CalculateMaxShape(input0Shape, input1Shape);