aboutsummaryrefslogtreecommitdiff
path: root/OutputShapeUtils.cpp
diff options
context:
space:
mode:
authorAron Virginas-Tar <Aron.Virginas-Tar@arm.com>2019-07-16 11:32:29 +0100
committerÁron Virginás-Tar <aron.virginas-tar@arm.com>2019-07-16 13:31:47 +0000
commitbe5d356c1663009c69c934fd860db1f91863e3d1 (patch)
tree603c9b744147c6875998beadb931afec00e8d3a9 /OutputShapeUtils.cpp
parent9fd373954d64fbae72d1726bbdfc57a18a3a2f6d (diff)
downloadandroid-nn-driver-be5d356c1663009c69c934fd860db1f91863e3d1.tar.gz
IVGCVSW-3522 Support dynamic output shape in hal_1_2::HalPolicy::ConvertResize
Signed-off-by: Aron Virginas-Tar <Aron.Virginas-Tar@arm.com> Change-Id: I962f9759679f539566f7bc3aa75ed3e0bffe7c9f
Diffstat (limited to 'OutputShapeUtils.cpp')
-rw-r--r--OutputShapeUtils.cpp22
1 files changed, 22 insertions, 0 deletions
diff --git a/OutputShapeUtils.cpp b/OutputShapeUtils.cpp
index 285e25f4..6c936ee7 100644
--- a/OutputShapeUtils.cpp
+++ b/OutputShapeUtils.cpp
@@ -144,6 +144,28 @@ TensorShape InferPreluOutputShape(const TensorShape& inputShape, const TensorSha
return CalculateMaxShape(inputShape, alphaShape);
}
+TensorShape InferResizeOutputShape(const TensorShape& inputShape, const ResizeDescriptor& descriptor)
+{
+ if (inputShape.GetNumDimensions() != 4)
+ {
+ throw InvalidArgumentException("Input shape for Resize must be 4D");
+ }
+
+ armnnUtils::DataLayoutIndexed dataLayoutIndexed(descriptor.m_DataLayout);
+
+ const unsigned int cIndex = dataLayoutIndexed.GetChannelsIndex();
+ const unsigned int wIndex = dataLayoutIndexed.GetWidthIndex();
+ const unsigned int hIndex = dataLayoutIndexed.GetHeightIndex();
+
+ TensorShape outputShape(4);
+ outputShape[0] = inputShape[0];
+ outputShape[cIndex] = inputShape[cIndex];
+ outputShape[wIndex] = descriptor.m_TargetWidth;
+ outputShape[hIndex] = descriptor.m_TargetHeight;
+
+ return outputShape;
+}
+
TensorShape InferSubOutputShape(const TensorShape& input0Shape, const TensorShape& input1Shape)
{
return CalculateMaxShape(input0Shape, input1Shape);