aboutsummaryrefslogtreecommitdiff
path: root/OutputShapeUtils.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'OutputShapeUtils.cpp')
-rw-r--r--OutputShapeUtils.cpp22
1 files changed, 22 insertions, 0 deletions
diff --git a/OutputShapeUtils.cpp b/OutputShapeUtils.cpp
index 285e25f4..6c936ee7 100644
--- a/OutputShapeUtils.cpp
+++ b/OutputShapeUtils.cpp
@@ -144,6 +144,28 @@ TensorShape InferPreluOutputShape(const TensorShape& inputShape, const TensorSha
return CalculateMaxShape(inputShape, alphaShape);
}
+TensorShape InferResizeOutputShape(const TensorShape& inputShape, const ResizeDescriptor& descriptor)
+{
+ if (inputShape.GetNumDimensions() != 4)
+ {
+ throw InvalidArgumentException("Input shape for Resize must be 4D");
+ }
+
+ armnnUtils::DataLayoutIndexed dataLayoutIndexed(descriptor.m_DataLayout);
+
+ const unsigned int cIndex = dataLayoutIndexed.GetChannelsIndex();
+ const unsigned int wIndex = dataLayoutIndexed.GetWidthIndex();
+ const unsigned int hIndex = dataLayoutIndexed.GetHeightIndex();
+
+ TensorShape outputShape(4);
+ outputShape[0] = inputShape[0];
+ outputShape[cIndex] = inputShape[cIndex];
+ outputShape[wIndex] = descriptor.m_TargetWidth;
+ outputShape[hIndex] = descriptor.m_TargetHeight;
+
+ return outputShape;
+}
+
TensorShape InferSubOutputShape(const TensorShape& input0Shape, const TensorShape& input1Shape)
{
return CalculateMaxShape(input0Shape, input1Shape);