aboutsummaryrefslogtreecommitdiff
path: root/src/armnnUtils
diff options
context:
space:
mode:
authorGeorgios Pinitas <georgios.pinitas@arm.com>2020-02-14 14:46:51 +0000
committerDerek Lamberti <derek.lamberti@arm.com>2020-02-18 10:10:07 +0000
commit5e90aab1cc25681c3e02b4d4436c24ee43400e91 (patch)
tree1121284177ec93b8e8e8ac07e98209639a562a8f /src/armnnUtils
parent0c2eeac6347533a1d3d456aebea492f5123388f3 (diff)
downloadarmnn-5e90aab1cc25681c3e02b4d4436c24ee43400e91.tar.gz
COMPMID-3059: Add TF parser support for StridedSlice
Signed-off-by: Georgios Pinitas <georgios.pinitas@arm.com> Change-Id: I31f25f26a50c9054b5650b1be127c84194b56be7
Diffstat (limited to 'src/armnnUtils')
-rw-r--r--src/armnnUtils/ParserHelper.cpp30
-rw-r--r--src/armnnUtils/ParserHelper.hpp5
2 files changed, 35 insertions, 0 deletions
diff --git a/src/armnnUtils/ParserHelper.cpp b/src/armnnUtils/ParserHelper.cpp
index ca6e42696e..9406553dff 100644
--- a/src/armnnUtils/ParserHelper.cpp
+++ b/src/armnnUtils/ParserHelper.cpp
@@ -101,4 +101,34 @@ void CalculateReducedOutputTensoInfo(const armnn::TensorInfo& inputTensorInfo,
}
}
+
+void CalculateStridedSliceOutputTensorInfo(const armnn::TensorInfo& inputTensorInfo,
+ const armnn::StridedSliceDescriptor& desc,
+ armnn::TensorInfo& outputTensorInfo)
+{
+ const armnn::TensorShape& inputShape = inputTensorInfo.GetShape();
+
+ std::vector<unsigned int> outputShapeVector;
+ for (unsigned int i = 0; i < inputTensorInfo.GetNumDimensions(); i++)
+ {
+ if (desc.m_ShrinkAxisMask & (1 << i))
+ {
+ continue;
+ }
+
+ int stride = desc.m_Stride[i];
+ int start = desc.GetStartForAxis(inputShape, i);
+ int stop = desc.GetStopForAxis(inputShape, i, start);
+
+ int newSize = stride > 0 ? ((stop - start) + stride - 1) / stride :
+ ((start - stop) - stride - 1) / -stride;
+
+ newSize = std::max(0, newSize);
+
+ outputShapeVector.push_back(static_cast<unsigned int>(newSize));
+ }
+
+ armnn::TensorShape outputTensorShape(inputTensorInfo.GetNumDimensions(), &outputShapeVector[0]);
+ outputTensorInfo = armnn::TensorInfo(armnn::TensorShape(outputTensorShape), inputTensorInfo.GetDataType());
+}
} // namespace armnnUtils
diff --git a/src/armnnUtils/ParserHelper.hpp b/src/armnnUtils/ParserHelper.hpp
index d63408804d..28c7964ac1 100644
--- a/src/armnnUtils/ParserHelper.hpp
+++ b/src/armnnUtils/ParserHelper.hpp
@@ -25,4 +25,9 @@ void CalculateReducedOutputTensoInfo(const armnn::TensorInfo& inputTensorInfo,
bool keepDims,
armnn::TensorInfo& outputTensorInfo);
+/// Create output tensor info for a StridedSlice operator
+void CalculateStridedSliceOutputTensorInfo(const armnn::TensorInfo& inputTensorInfo,
+ const armnn::StridedSliceDescriptor& desc,
+ armnn::TensorInfo& outputTensorInfo);
+
} // namespace armnnUtils