diff options
-rw-r--r-- | ConversionUtils.hpp | 20 |
1 files changed, 20 insertions, 0 deletions
diff --git a/ConversionUtils.hpp b/ConversionUtils.hpp index 15381338..8313d045 100644 --- a/ConversionUtils.hpp +++ b/ConversionUtils.hpp @@ -3599,6 +3599,26 @@ bool ConvertStridedSlice(const HalOperation& operation, const HalModel& model, C return false; } + // Check if slice can fit in a inferred output + armnn::TensorShape inputShape = inputInfo.GetShape(); + for (unsigned int i = 0; i < inputShape.GetNumDimensions(); i++) + { + int stride = descriptor.m_Stride[i]; + int start = descriptor.GetStartForAxis(inputShape, i); + int stop = descriptor.GetStopForAxis(inputShape, i, start); + + if (descriptor.m_ShrinkAxisMask & (1 << i)) + { + // If the difference between the start point and the end point of the slice on an axis being shrunk + // is greater than 1 then throw an error as the output will not be large enough to hold the slice + if (((descriptor.m_Begin[i] - descriptor.m_End[i]) > 1) + || ((descriptor.m_Begin[i] - descriptor.m_End[i]) < -1)) + { + return Fail("%s: StridedSlice: Output will not be large enough to hold the slice", __func__); + } + } + } + armnn::IConnectableLayer* const layer = data.m_Network->AddStridedSliceLayer(descriptor); assert(layer != nullptr); input.Connect(layer->GetInputSlot(0)); |