diff options
Diffstat (limited to 'src/armnn/Descriptors.cpp')
-rw-r--r-- | src/armnn/Descriptors.cpp | 62 |
1 files changed, 62 insertions, 0 deletions
diff --git a/src/armnn/Descriptors.cpp b/src/armnn/Descriptors.cpp index a200c6462c..43f41a7397 100644 --- a/src/armnn/Descriptors.cpp +++ b/src/armnn/Descriptors.cpp @@ -306,4 +306,66 @@ void swap(ViewsDescriptor& first, ViewsDescriptor& second) swap(first.m_ViewSizes, second.m_ViewSizes); } +int StridedSliceDescriptor::GetStartForAxis(const TensorShape& inputShape, + unsigned int axis) const +{ + int start = m_Begin[axis]; + + if (m_BeginMask & (1 << axis)) + { + if (m_Stride[axis] > 0) + { + start = std::numeric_limits<int>::min(); + } + else + { + start = std::numeric_limits<int>::max(); + } + } + + const int axisSize = boost::numeric_cast<int>(inputShape[axis]); + if (start < 0) + { + start += (axisSize); + } + + return std::max(0, std::min(start, axisSize - 1)); + +} + +int StridedSliceDescriptor::GetStopForAxis(const TensorShape& inputShape, + unsigned int axis, + int startForAxis) const +{ + + if (m_ShrinkAxisMask & (1 << axis)) + { + return startForAxis + 1; + } + + int stop = m_End[axis]; + + if (m_EndMask & (1 << axis)) + { + if (m_Stride[axis] > 0) + { + stop = std::numeric_limits<int>::max(); + } + else + { + stop = std::numeric_limits<int>::min(); + } + } + + const int axisSize = boost::numeric_cast<int>(inputShape[axis]); + if (stop < 0) + { + stop += axisSize; + } + + return m_Stride[axis] > 0 ? std::max(0, std::min(stop, axisSize)) : + std::max(-1, std::min(stop, axisSize - 1)); + +} + } |