aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/armnn/layers/StridedSliceLayer.cpp8
1 files changed, 8 insertions, 0 deletions
diff --git a/src/armnn/layers/StridedSliceLayer.cpp b/src/armnn/layers/StridedSliceLayer.cpp
index ae4fab0efd..fbe9815c06 100644
--- a/src/armnn/layers/StridedSliceLayer.cpp
+++ b/src/armnn/layers/StridedSliceLayer.cpp
@@ -49,6 +49,7 @@ std::vector<TensorShape> StridedSliceLayer::InferOutputShapes(
TensorShape inputShape = inputShapes[0];
std::vector<unsigned int> outputShape;
+ unsigned int amountDimShrunk{0};
for (unsigned int i = 0; i < inputShape.GetNumDimensions(); i++)
{
@@ -58,6 +59,8 @@ std::vector<TensorShape> StridedSliceLayer::InferOutputShapes(
if (m_Param.m_ShrinkAxisMask & (1 << i))
{
+ amountDimShrunk+=1;
+
// If the difference between the start point and the end point of the slice on an axis being shrunk
// is greater than 1 then throw an error as the output will not be large enough to hold the slice
if (((m_Param.m_Begin[i] - m_Param.m_End[i]) > 1) || ((m_Param.m_Begin[i] - m_Param.m_End[i]) < -1))
@@ -82,6 +85,11 @@ std::vector<TensorShape> StridedSliceLayer::InferOutputShapes(
outputShape.push_back(boost::numeric_cast<unsigned int>(newSize));
}
+ if (outputShape.size() == 0 && (inputShape.GetNumDimensions() - amountDimShrunk) == 0)
+ {
+ outputShape.push_back(1);
+ }
+
return std::vector<TensorShape>({
TensorShape(boost::numeric_cast<unsigned int>(outputShape.size()), &outputShape[0]) });
}