From 5e90aab1cc25681c3e02b4d4436c24ee43400e91 Mon Sep 17 00:00:00 2001 From: Georgios Pinitas Date: Fri, 14 Feb 2020 14:46:51 +0000 Subject: COMPMID-3059: Add TF parser support for StridedSlice Signed-off-by: Georgios Pinitas Change-Id: I31f25f26a50c9054b5650b1be127c84194b56be7 --- src/armnnUtils/ParserHelper.cpp | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) (limited to 'src/armnnUtils/ParserHelper.cpp') diff --git a/src/armnnUtils/ParserHelper.cpp b/src/armnnUtils/ParserHelper.cpp index ca6e42696e..9406553dff 100644 --- a/src/armnnUtils/ParserHelper.cpp +++ b/src/armnnUtils/ParserHelper.cpp @@ -101,4 +101,34 @@ void CalculateReducedOutputTensoInfo(const armnn::TensorInfo& inputTensorInfo, } } + +void CalculateStridedSliceOutputTensorInfo(const armnn::TensorInfo& inputTensorInfo, + const armnn::StridedSliceDescriptor& desc, + armnn::TensorInfo& outputTensorInfo) +{ + const armnn::TensorShape& inputShape = inputTensorInfo.GetShape(); + + std::vector outputShapeVector; + for (unsigned int i = 0; i < inputTensorInfo.GetNumDimensions(); i++) + { + if (desc.m_ShrinkAxisMask & (1 << i)) + { + continue; + } + + int stride = desc.m_Stride[i]; + int start = desc.GetStartForAxis(inputShape, i); + int stop = desc.GetStopForAxis(inputShape, i, start); + + int newSize = stride > 0 ? ((stop - start) + stride - 1) / stride : + ((start - stop) - stride - 1) / -stride; + + newSize = std::max(0, newSize); + + outputShapeVector.push_back(static_cast(newSize)); + } + + armnn::TensorShape outputTensorShape(inputTensorInfo.GetNumDimensions(), &outputShapeVector[0]); + outputTensorInfo = armnn::TensorInfo(armnn::TensorShape(outputTensorShape), inputTensorInfo.GetDataType()); +} } // namespace armnnUtils -- cgit v1.2.1