diff options
author | Aron Virginas-Tar <Aron.Virginas-Tar@arm.com> | 2019-07-25 11:24:42 +0100 |
---|---|---|
committer | Aron Virginas-Tar <Aron.Virginas-Tar@arm.com> | 2019-07-25 11:31:48 +0100 |
commit | ad1ab53f2898862e82f9b354853764fdcd1df97d (patch) | |
tree | 9df7f2c13686473a84d85988bcc21cc938002ad3 /OutputShapeUtils.cpp | |
parent | 6111316eb609bd71589b963cf6fc56b18ba3d241 (diff) | |
download | android-nn-driver-ad1ab53f2898862e82f9b354853764fdcd1df97d.tar.gz |
IVGCVSW-3569 Fix conversion of HAL1.2 SpaceToDepth
Signed-off-by: Aron Virginas-Tar <Aron.Virginas-Tar@arm.com>
Change-Id: I82e38c8a9e44e773c099e347f8ce0070bb5f8662
Diffstat (limited to 'OutputShapeUtils.cpp')
-rw-r--r-- | OutputShapeUtils.cpp | 24 |
1 files changed, 24 insertions, 0 deletions
diff --git a/OutputShapeUtils.cpp b/OutputShapeUtils.cpp index 0c897d11..ecec0b92 100644 --- a/OutputShapeUtils.cpp +++ b/OutputShapeUtils.cpp @@ -8,6 +8,7 @@ #include <DataLayoutIndexed.hpp> #include <algorithm> +#include <numeric> #include <vector> namespace @@ -167,6 +168,29 @@ TensorShape InferResizeOutputShape(const TensorShape& inputShape, const ResizeDe return outputShape; } +TensorShape InferSpaceToDepthOutputShape(const TensorShape& inputShape, const SpaceToDepthDescriptor& descriptor) +{ + TensorShape outputShape(inputShape); + + armnnUtils::DataLayoutIndexed dataLayoutIndexed(descriptor.m_DataLayout); + + const unsigned int cIndex = dataLayoutIndexed.GetChannelsIndex(); + const unsigned int wIndex = dataLayoutIndexed.GetWidthIndex(); + const unsigned int hIndex = dataLayoutIndexed.GetHeightIndex(); + + if (descriptor.m_BlockSize == 0) + { + throw InvalidArgumentException("Block size must be greater than zero"); + } + + outputShape[cIndex] = inputShape[cIndex] * descriptor.m_BlockSize * descriptor.m_BlockSize; + + outputShape[hIndex] = inputShape[hIndex] / descriptor.m_BlockSize; + outputShape[wIndex] = inputShape[wIndex] / descriptor.m_BlockSize; + + return outputShape; +} + TensorShape InferSubOutputShape(const TensorShape& input0Shape, const TensorShape& input1Shape) { return CalculateMaxShape(input0Shape, input1Shape); |