diff options
author | Georgios Pinitas <georgios.pinitas@arm.com> | 2020-02-14 14:46:51 +0000 |
---|---|---|
committer | Derek Lamberti <derek.lamberti@arm.com> | 2020-02-18 10:10:07 +0000 |
commit | 5e90aab1cc25681c3e02b4d4436c24ee43400e91 (patch) | |
tree | 1121284177ec93b8e8e8ac07e98209639a562a8f /src/armnnUtils | |
parent | 0c2eeac6347533a1d3d456aebea492f5123388f3 (diff) | |
download | armnn-5e90aab1cc25681c3e02b4d4436c24ee43400e91.tar.gz |
COMPMID-3059: Add TF parser support for StridedSlice
Signed-off-by: Georgios Pinitas <georgios.pinitas@arm.com>
Change-Id: I31f25f26a50c9054b5650b1be127c84194b56be7
Diffstat (limited to 'src/armnnUtils')
-rw-r--r-- | src/armnnUtils/ParserHelper.cpp | 30 | ||||
-rw-r--r-- | src/armnnUtils/ParserHelper.hpp | 5 |
2 files changed, 35 insertions, 0 deletions
diff --git a/src/armnnUtils/ParserHelper.cpp b/src/armnnUtils/ParserHelper.cpp index ca6e42696e..9406553dff 100644 --- a/src/armnnUtils/ParserHelper.cpp +++ b/src/armnnUtils/ParserHelper.cpp @@ -101,4 +101,34 @@ void CalculateReducedOutputTensoInfo(const armnn::TensorInfo& inputTensorInfo, } } + +void CalculateStridedSliceOutputTensorInfo(const armnn::TensorInfo& inputTensorInfo, + const armnn::StridedSliceDescriptor& desc, + armnn::TensorInfo& outputTensorInfo) +{ + const armnn::TensorShape& inputShape = inputTensorInfo.GetShape(); + + std::vector<unsigned int> outputShapeVector; + for (unsigned int i = 0; i < inputTensorInfo.GetNumDimensions(); i++) + { + if (desc.m_ShrinkAxisMask & (1 << i)) + { + continue; + } + + int stride = desc.m_Stride[i]; + int start = desc.GetStartForAxis(inputShape, i); + int stop = desc.GetStopForAxis(inputShape, i, start); + + int newSize = stride > 0 ? ((stop - start) + stride - 1) / stride : + ((start - stop) - stride - 1) / -stride; + + newSize = std::max(0, newSize); + + outputShapeVector.push_back(static_cast<unsigned int>(newSize)); + } + + armnn::TensorShape outputTensorShape(inputTensorInfo.GetNumDimensions(), &outputShapeVector[0]); + outputTensorInfo = armnn::TensorInfo(armnn::TensorShape(outputTensorShape), inputTensorInfo.GetDataType()); +} } // namespace armnnUtils diff --git a/src/armnnUtils/ParserHelper.hpp b/src/armnnUtils/ParserHelper.hpp index d63408804d..28c7964ac1 100644 --- a/src/armnnUtils/ParserHelper.hpp +++ b/src/armnnUtils/ParserHelper.hpp @@ -25,4 +25,9 @@ void CalculateReducedOutputTensoInfo(const armnn::TensorInfo& inputTensorInfo, bool keepDims, armnn::TensorInfo& outputTensorInfo); +/// Create output tensor info for a StridedSlice operator +void CalculateStridedSliceOutputTensorInfo(const armnn::TensorInfo& inputTensorInfo, + const armnn::StridedSliceDescriptor& desc, + armnn::TensorInfo& outputTensorInfo); + } // namespace armnnUtils |