aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/Descriptors.cpp
diff options
context:
space:
mode:
authorNattapat Chaimanowong <nattapat.chaimanowong@arm.com>2018-11-21 16:48:17 +0000
committernattapat.chaimanowong <nattapat.chaimanowong@arm.com>2018-11-21 17:13:30 +0000
commita0d2844d05dd9ae15733d426b04ab651457604ce (patch)
tree7ae0787d7be49ec2b1bb754ff35e13e8b71c73d2 /src/armnn/Descriptors.cpp
parent3d93bc47f42b339d82cfcf56a90c9264dd46d70a (diff)
downloadarmnn-a0d2844d05dd9ae15733d426b04ab651457604ce.tar.gz
IVGCVSW-2086 Update StridedSliceLayer and StridedSliceDescriptor
Change-Id: Ifa88a879dd239f60ab27330d6b73859393828ef0
Diffstat (limited to 'src/armnn/Descriptors.cpp')
-rw-r--r--src/armnn/Descriptors.cpp62
1 files changed, 62 insertions, 0 deletions
diff --git a/src/armnn/Descriptors.cpp b/src/armnn/Descriptors.cpp
index a200c6462c..43f41a7397 100644
--- a/src/armnn/Descriptors.cpp
+++ b/src/armnn/Descriptors.cpp
@@ -306,4 +306,66 @@ void swap(ViewsDescriptor& first, ViewsDescriptor& second)
swap(first.m_ViewSizes, second.m_ViewSizes);
}
+int StridedSliceDescriptor::GetStartForAxis(const TensorShape& inputShape,
+ unsigned int axis) const
+{
+ int start = m_Begin[axis];
+
+ if (m_BeginMask & (1 << axis))
+ {
+ if (m_Stride[axis] > 0)
+ {
+ start = std::numeric_limits<int>::min();
+ }
+ else
+ {
+ start = std::numeric_limits<int>::max();
+ }
+ }
+
+ const int axisSize = boost::numeric_cast<int>(inputShape[axis]);
+ if (start < 0)
+ {
+ start += (axisSize);
+ }
+
+ return std::max(0, std::min(start, axisSize - 1));
+
+}
+
+int StridedSliceDescriptor::GetStopForAxis(const TensorShape& inputShape,
+ unsigned int axis,
+ int startForAxis) const
+{
+
+ if (m_ShrinkAxisMask & (1 << axis))
+ {
+ return startForAxis + 1;
+ }
+
+ int stop = m_End[axis];
+
+ if (m_EndMask & (1 << axis))
+ {
+ if (m_Stride[axis] > 0)
+ {
+ stop = std::numeric_limits<int>::max();
+ }
+ else
+ {
+ stop = std::numeric_limits<int>::min();
+ }
+ }
+
+ const int axisSize = boost::numeric_cast<int>(inputShape[axis]);
+ if (stop < 0)
+ {
+ stop += axisSize;
+ }
+
+ return m_Stride[axis] > 0 ? std::max(0, std::min(stop, axisSize)) :
+ std::max(-1, std::min(stop, axisSize - 1));
+
+}
+
}