diff options
Diffstat (limited to 'src/armnnUtils')
-rw-r--r-- | src/armnnUtils/ParserHelper.cpp | 30 | ||||
-rw-r--r-- | src/armnnUtils/ParserHelper.hpp | 5 |
2 files changed, 35 insertions, 0 deletions
diff --git a/src/armnnUtils/ParserHelper.cpp b/src/armnnUtils/ParserHelper.cpp index ca6e42696e..9406553dff 100644 --- a/src/armnnUtils/ParserHelper.cpp +++ b/src/armnnUtils/ParserHelper.cpp @@ -101,4 +101,34 @@ void CalculateReducedOutputTensoInfo(const armnn::TensorInfo& inputTensorInfo, } } + +void CalculateStridedSliceOutputTensorInfo(const armnn::TensorInfo& inputTensorInfo, + const armnn::StridedSliceDescriptor& desc, + armnn::TensorInfo& outputTensorInfo) +{ + const armnn::TensorShape& inputShape = inputTensorInfo.GetShape(); + + std::vector<unsigned int> outputShapeVector; + for (unsigned int i = 0; i < inputTensorInfo.GetNumDimensions(); i++) + { + if (desc.m_ShrinkAxisMask & (1 << i)) + { + continue; + } + + int stride = desc.m_Stride[i]; + int start = desc.GetStartForAxis(inputShape, i); + int stop = desc.GetStopForAxis(inputShape, i, start); + + int newSize = stride > 0 ? ((stop - start) + stride - 1) / stride : + ((start - stop) - stride - 1) / -stride; + + newSize = std::max(0, newSize); + + outputShapeVector.push_back(static_cast<unsigned int>(newSize)); + } + + armnn::TensorShape outputTensorShape(inputTensorInfo.GetNumDimensions(), &outputShapeVector[0]); + outputTensorInfo = armnn::TensorInfo(armnn::TensorShape(outputTensorShape), inputTensorInfo.GetDataType()); +} } // namespace armnnUtils diff --git a/src/armnnUtils/ParserHelper.hpp b/src/armnnUtils/ParserHelper.hpp index d63408804d..28c7964ac1 100644 --- a/src/armnnUtils/ParserHelper.hpp +++ b/src/armnnUtils/ParserHelper.hpp @@ -25,4 +25,9 @@ void CalculateReducedOutputTensoInfo(const armnn::TensorInfo& inputTensorInfo, bool keepDims, armnn::TensorInfo& outputTensorInfo); +/// Create output tensor info for a StridedSlice operator +void CalculateStridedSliceOutputTensorInfo(const armnn::TensorInfo& inputTensorInfo, + const armnn::StridedSliceDescriptor& desc, + armnn::TensorInfo& outputTensorInfo); + } // namespace armnnUtils |