diff options
author | Ryan OShea <Ryan.OShea2@arm.com> | 2020-07-06 11:45:50 +0100 |
---|---|---|
committer | Ryan O'Shea <ryan.oshea2@arm.com> | 2020-07-06 16:21:33 +0000 |
commit | 06deacd58fbd4fbfd4884ab8024ef736f4f7105b (patch) | |
tree | 8c1c1d63bb43364ca05e6a408d41a02a14f79b73 /src/armnn | |
parent | b7c1831f95a5ecdde0fff068d4054a066309cec6 (diff) | |
download | armnn-06deacd58fbd4fbfd4884ab8024ef736f4f7105b.tar.gz |
IVGCVSW-4919 Strided Slice 0 Dimension Tensor Fix
* Add check Axis' shrunk to 0 dimensions
Signed-off-by: Ryan OShea <Ryan.OShea2@arm.com>
Change-Id: Ic2544f7538d2df4a561f88ce8909533424fa2a25
Diffstat (limited to 'src/armnn')
-rw-r--r-- | src/armnn/layers/StridedSliceLayer.cpp | 8 |
1 files changed, 8 insertions, 0 deletions
diff --git a/src/armnn/layers/StridedSliceLayer.cpp b/src/armnn/layers/StridedSliceLayer.cpp index ae4fab0efd..fbe9815c06 100644 --- a/src/armnn/layers/StridedSliceLayer.cpp +++ b/src/armnn/layers/StridedSliceLayer.cpp @@ -49,6 +49,7 @@ std::vector<TensorShape> StridedSliceLayer::InferOutputShapes( TensorShape inputShape = inputShapes[0]; std::vector<unsigned int> outputShape; + unsigned int amountDimShrunk{0}; for (unsigned int i = 0; i < inputShape.GetNumDimensions(); i++) { @@ -58,6 +59,8 @@ std::vector<TensorShape> StridedSliceLayer::InferOutputShapes( if (m_Param.m_ShrinkAxisMask & (1 << i)) { + amountDimShrunk+=1; + // If the difference between the start point and the end point of the slice on an axis being shrunk // is greater than 1 then throw an error as the output will not be large enough to hold the slice if (((m_Param.m_Begin[i] - m_Param.m_End[i]) > 1) || ((m_Param.m_Begin[i] - m_Param.m_End[i]) < -1)) @@ -82,6 +85,11 @@ std::vector<TensorShape> StridedSliceLayer::InferOutputShapes( outputShape.push_back(boost::numeric_cast<unsigned int>(newSize)); } + if (outputShape.size() == 0 && (inputShape.GetNumDimensions() - amountDimShrunk) == 0) + { + outputShape.push_back(1); + } + return std::vector<TensorShape>({ TensorShape(boost::numeric_cast<unsigned int>(outputShape.size()), &outputShape[0]) }); } |