diff options
Diffstat (limited to 'OutputShapeUtils.cpp')
-rw-r--r-- | OutputShapeUtils.cpp | 24 |
1 files changed, 24 insertions, 0 deletions
diff --git a/OutputShapeUtils.cpp b/OutputShapeUtils.cpp index 0c897d11..ecec0b92 100644 --- a/OutputShapeUtils.cpp +++ b/OutputShapeUtils.cpp @@ -8,6 +8,7 @@ #include <DataLayoutIndexed.hpp> #include <algorithm> +#include <numeric> #include <vector> namespace @@ -167,6 +168,29 @@ TensorShape InferResizeOutputShape(const TensorShape& inputShape, const ResizeDe return outputShape; } +TensorShape InferSpaceToDepthOutputShape(const TensorShape& inputShape, const SpaceToDepthDescriptor& descriptor) +{ + TensorShape outputShape(inputShape); + + armnnUtils::DataLayoutIndexed dataLayoutIndexed(descriptor.m_DataLayout); + + const unsigned int cIndex = dataLayoutIndexed.GetChannelsIndex(); + const unsigned int wIndex = dataLayoutIndexed.GetWidthIndex(); + const unsigned int hIndex = dataLayoutIndexed.GetHeightIndex(); + + if (descriptor.m_BlockSize == 0) + { + throw InvalidArgumentException("Block size must be greater than zero"); + } + + outputShape[cIndex] = inputShape[cIndex] * descriptor.m_BlockSize * descriptor.m_BlockSize; + + outputShape[hIndex] = inputShape[hIndex] / descriptor.m_BlockSize; + outputShape[wIndex] = inputShape[wIndex] / descriptor.m_BlockSize; + + return outputShape; +} + TensorShape InferSubOutputShape(const TensorShape& input0Shape, const TensorShape& input1Shape) { return CalculateMaxShape(input0Shape, input1Shape); |