diff options
Diffstat (limited to 'src/armnn/layers/StridedSliceLayer.cpp')
-rw-r--r-- | src/armnn/layers/StridedSliceLayer.cpp | 35 |
1 files changed, 34 insertions, 1 deletions
diff --git a/src/armnn/layers/StridedSliceLayer.cpp b/src/armnn/layers/StridedSliceLayer.cpp index f5e001c03f..a3dca25656 100644 --- a/src/armnn/layers/StridedSliceLayer.cpp +++ b/src/armnn/layers/StridedSliceLayer.cpp @@ -9,6 +9,8 @@ #include <backendsCommon/WorkloadData.hpp> #include <backendsCommon/WorkloadFactory.hpp> +#include <boost/numeric/conversion/cast.hpp> + namespace armnn { @@ -41,6 +43,37 @@ StridedSliceLayer* StridedSliceLayer::Clone(Graph& graph) const return CloneBase<StridedSliceLayer>(graph, m_Param, GetName()); } +std::vector<TensorShape> StridedSliceLayer::InferOutputShapes( + const std::vector<TensorShape>& inputShapes) const +{ + BOOST_ASSERT(inputShapes.size() == 1); + + TensorShape inputShape = inputShapes[0]; + std::vector<unsigned int> outputShape; + + for (unsigned int i = 0; i < inputShape.GetNumDimensions(); i++) + { + if (m_Param.m_ShrinkAxisMask & (1 << i)) + { + continue; + } + + int stride = m_Param.m_Stride[i]; + int start = m_Param.GetStartForAxis(inputShape, i); + int stop = m_Param.GetStopForAxis(inputShape, i, start); + + int newSize = stride > 0 ? ((stop - start) + stride - 1) / stride : + ((start - stop) - stride - 1) / -stride; + + newSize = std::min(0, newSize); + + outputShape.push_back(boost::numeric_cast<unsigned int>(newSize)); + } + + return std::vector<TensorShape>({ + TensorShape(boost::numeric_cast<unsigned int>(outputShape.size()), &outputShape[0]) }); +} + void StridedSliceLayer::ValidateTensorShapesFromInputs() { VerifyLayerConnections(1, CHECK_LOCATION()); @@ -55,4 +88,4 @@ void StridedSliceLayer::ValidateTensorShapesFromInputs() inferredShapes[0]); } -} // namespace armnn
\ No newline at end of file +} // namespace armnn |