diff options
Diffstat (limited to 'src/armnn/layers/StridedSliceLayer.cpp')
-rw-r--r-- | src/armnn/layers/StridedSliceLayer.cpp | 8 |
1 files changed, 8 insertions, 0 deletions
diff --git a/src/armnn/layers/StridedSliceLayer.cpp b/src/armnn/layers/StridedSliceLayer.cpp index ae4fab0efd..fbe9815c06 100644 --- a/src/armnn/layers/StridedSliceLayer.cpp +++ b/src/armnn/layers/StridedSliceLayer.cpp @@ -49,6 +49,7 @@ std::vector<TensorShape> StridedSliceLayer::InferOutputShapes( TensorShape inputShape = inputShapes[0]; std::vector<unsigned int> outputShape; + unsigned int amountDimShrunk{0}; for (unsigned int i = 0; i < inputShape.GetNumDimensions(); i++) { @@ -58,6 +59,8 @@ std::vector<TensorShape> StridedSliceLayer::InferOutputShapes( if (m_Param.m_ShrinkAxisMask & (1 << i)) { + amountDimShrunk+=1; + // If the difference between the start point and the end point of the slice on an axis being shrunk // is greater than 1 then throw an error as the output will not be large enough to hold the slice if (((m_Param.m_Begin[i] - m_Param.m_End[i]) > 1) || ((m_Param.m_Begin[i] - m_Param.m_End[i]) < -1)) @@ -82,6 +85,11 @@ std::vector<TensorShape> StridedSliceLayer::InferOutputShapes( outputShape.push_back(boost::numeric_cast<unsigned int>(newSize)); } + if (outputShape.size() == 0 && (inputShape.GetNumDimensions() - amountDimShrunk) == 0) + { + outputShape.push_back(1); + } + return std::vector<TensorShape>({ TensorShape(boost::numeric_cast<unsigned int>(outputShape.size()), &outputShape[0]) }); } |