aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--ConversionUtils.hpp20
1 files changed, 20 insertions, 0 deletions
diff --git a/ConversionUtils.hpp b/ConversionUtils.hpp
index 15381338..8313d045 100644
--- a/ConversionUtils.hpp
+++ b/ConversionUtils.hpp
@@ -3599,6 +3599,26 @@ bool ConvertStridedSlice(const HalOperation& operation, const HalModel& model, C
return false;
}
+ // Check if slice can fit in a inferred output
+ armnn::TensorShape inputShape = inputInfo.GetShape();
+ for (unsigned int i = 0; i < inputShape.GetNumDimensions(); i++)
+ {
+ int stride = descriptor.m_Stride[i];
+ int start = descriptor.GetStartForAxis(inputShape, i);
+ int stop = descriptor.GetStopForAxis(inputShape, i, start);
+
+ if (descriptor.m_ShrinkAxisMask & (1 << i))
+ {
+ // If the difference between the start point and the end point of the slice on an axis being shrunk
+ // is greater than 1 then throw an error as the output will not be large enough to hold the slice
+ if (((descriptor.m_Begin[i] - descriptor.m_End[i]) > 1)
+ || ((descriptor.m_Begin[i] - descriptor.m_End[i]) < -1))
+ {
+ return Fail("%s: StridedSlice: Output will not be large enough to hold the slice", __func__);
+ }
+ }
+ }
+
armnn::IConnectableLayer* const layer = data.m_Network->AddStridedSliceLayer(descriptor);
assert(layer != nullptr);
input.Connect(layer->GetInputSlot(0));