aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRyan OShea <Ryan.OShea2@arm.com>2020-07-06 11:45:50 +0100
committerRyan O'Shea <ryan.oshea2@arm.com>2020-07-06 16:21:33 +0000
commit06deacd58fbd4fbfd4884ab8024ef736f4f7105b (patch)
tree8c1c1d63bb43364ca05e6a408d41a02a14f79b73
parentb7c1831f95a5ecdde0fff068d4054a066309cec6 (diff)
downloadarmnn-06deacd58fbd4fbfd4884ab8024ef736f4f7105b.tar.gz
IVGCVSW-4919 Strided Slice 0 Dimension Tensor Fix
* Add check Axis' shrunk to 0 dimensions Signed-off-by: Ryan OShea <Ryan.OShea2@arm.com> Change-Id: Ic2544f7538d2df4a561f88ce8909533424fa2a25
-rw-r--r--src/armnn/layers/StridedSliceLayer.cpp8
1 files changed, 8 insertions, 0 deletions
diff --git a/src/armnn/layers/StridedSliceLayer.cpp b/src/armnn/layers/StridedSliceLayer.cpp
index ae4fab0efd..fbe9815c06 100644
--- a/src/armnn/layers/StridedSliceLayer.cpp
+++ b/src/armnn/layers/StridedSliceLayer.cpp
@@ -49,6 +49,7 @@ std::vector<TensorShape> StridedSliceLayer::InferOutputShapes(
TensorShape inputShape = inputShapes[0];
std::vector<unsigned int> outputShape;
+ unsigned int amountDimShrunk{0};
for (unsigned int i = 0; i < inputShape.GetNumDimensions(); i++)
{
@@ -58,6 +59,8 @@ std::vector<TensorShape> StridedSliceLayer::InferOutputShapes(
if (m_Param.m_ShrinkAxisMask & (1 << i))
{
+ amountDimShrunk+=1;
+
// If the difference between the start point and the end point of the slice on an axis being shrunk
// is greater than 1 then throw an error as the output will not be large enough to hold the slice
if (((m_Param.m_Begin[i] - m_Param.m_End[i]) > 1) || ((m_Param.m_Begin[i] - m_Param.m_End[i]) < -1))
@@ -82,6 +85,11 @@ std::vector<TensorShape> StridedSliceLayer::InferOutputShapes(
outputShape.push_back(boost::numeric_cast<unsigned int>(newSize));
}
+ if (outputShape.size() == 0 && (inputShape.GetNumDimensions() - amountDimShrunk) == 0)
+ {
+ outputShape.push_back(1);
+ }
+
return std::vector<TensorShape>({
TensorShape(boost::numeric_cast<unsigned int>(outputShape.size()), &outputShape[0]) });
}