From a0d2844d05dd9ae15733d426b04ab651457604ce Mon Sep 17 00:00:00 2001 From: Nattapat Chaimanowong Date: Wed, 21 Nov 2018 16:48:17 +0000 Subject: IVGCVSW-2086 Update StridedSliceLayer and StridedSliceDescriptor Change-Id: Ifa88a879dd239f60ab27330d6b73859393828ef0 --- src/armnn/Descriptors.cpp | 62 ++++++++++++++++++++++++++++++++++ src/armnn/layers/StridedSliceLayer.cpp | 35 ++++++++++++++++++- src/armnn/layers/StridedSliceLayer.hpp | 4 ++- 3 files changed, 99 insertions(+), 2 deletions(-) (limited to 'src/armnn') diff --git a/src/armnn/Descriptors.cpp b/src/armnn/Descriptors.cpp index a200c6462c..43f41a7397 100644 --- a/src/armnn/Descriptors.cpp +++ b/src/armnn/Descriptors.cpp @@ -306,4 +306,66 @@ void swap(ViewsDescriptor& first, ViewsDescriptor& second) swap(first.m_ViewSizes, second.m_ViewSizes); } +int StridedSliceDescriptor::GetStartForAxis(const TensorShape& inputShape, + unsigned int axis) const +{ + int start = m_Begin[axis]; + + if (m_BeginMask & (1 << axis)) + { + if (m_Stride[axis] > 0) + { + start = std::numeric_limits::min(); + } + else + { + start = std::numeric_limits::max(); + } + } + + const int axisSize = boost::numeric_cast(inputShape[axis]); + if (start < 0) + { + start += (axisSize); + } + + return std::max(0, std::min(start, axisSize - 1)); + +} + +int StridedSliceDescriptor::GetStopForAxis(const TensorShape& inputShape, + unsigned int axis, + int startForAxis) const +{ + + if (m_ShrinkAxisMask & (1 << axis)) + { + return startForAxis + 1; + } + + int stop = m_End[axis]; + + if (m_EndMask & (1 << axis)) + { + if (m_Stride[axis] > 0) + { + stop = std::numeric_limits::max(); + } + else + { + stop = std::numeric_limits::min(); + } + } + + const int axisSize = boost::numeric_cast(inputShape[axis]); + if (stop < 0) + { + stop += axisSize; + } + + return m_Stride[axis] > 0 ? std::max(0, std::min(stop, axisSize)) : + std::max(-1, std::min(stop, axisSize - 1)); + +} + } diff --git a/src/armnn/layers/StridedSliceLayer.cpp b/src/armnn/layers/StridedSliceLayer.cpp index f5e001c03f..a3dca25656 100644 --- a/src/armnn/layers/StridedSliceLayer.cpp +++ b/src/armnn/layers/StridedSliceLayer.cpp @@ -9,6 +9,8 @@ #include #include +#include + namespace armnn { @@ -41,6 +43,37 @@ StridedSliceLayer* StridedSliceLayer::Clone(Graph& graph) const return CloneBase(graph, m_Param, GetName()); } +std::vector StridedSliceLayer::InferOutputShapes( + const std::vector& inputShapes) const +{ + BOOST_ASSERT(inputShapes.size() == 1); + + TensorShape inputShape = inputShapes[0]; + std::vector outputShape; + + for (unsigned int i = 0; i < inputShape.GetNumDimensions(); i++) + { + if (m_Param.m_ShrinkAxisMask & (1 << i)) + { + continue; + } + + int stride = m_Param.m_Stride[i]; + int start = m_Param.GetStartForAxis(inputShape, i); + int stop = m_Param.GetStopForAxis(inputShape, i, start); + + int newSize = stride > 0 ? ((stop - start) + stride - 1) / stride : + ((start - stop) - stride - 1) / -stride; + + newSize = std::min(0, newSize); + + outputShape.push_back(boost::numeric_cast(newSize)); + } + + return std::vector({ + TensorShape(boost::numeric_cast(outputShape.size()), &outputShape[0]) }); +} + void StridedSliceLayer::ValidateTensorShapesFromInputs() { VerifyLayerConnections(1, CHECK_LOCATION()); @@ -55,4 +88,4 @@ void StridedSliceLayer::ValidateTensorShapesFromInputs() inferredShapes[0]); } -} // namespace armnn \ No newline at end of file +} // namespace armnn diff --git a/src/armnn/layers/StridedSliceLayer.hpp b/src/armnn/layers/StridedSliceLayer.hpp index 33a44243a5..c3aad53e19 100644 --- a/src/armnn/layers/StridedSliceLayer.hpp +++ b/src/armnn/layers/StridedSliceLayer.hpp @@ -17,6 +17,8 @@ public: StridedSliceLayer* Clone(Graph& graph) const override; + std::vector InferOutputShapes(const std::vector& inputShapes) const override; + void ValidateTensorShapesFromInputs() override; protected: @@ -24,4 +26,4 @@ protected: ~StridedSliceLayer() = default; }; -} // namespace \ No newline at end of file +} // namespace -- cgit v1.2.1