aboutsummaryrefslogtreecommitdiff
path: root/src/armnnUtils/ParserHelper.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnnUtils/ParserHelper.cpp')
-rw-r--r--src/armnnUtils/ParserHelper.cpp30
1 files changed, 30 insertions, 0 deletions
diff --git a/src/armnnUtils/ParserHelper.cpp b/src/armnnUtils/ParserHelper.cpp
index ca6e42696e..9406553dff 100644
--- a/src/armnnUtils/ParserHelper.cpp
+++ b/src/armnnUtils/ParserHelper.cpp
@@ -101,4 +101,34 @@ void CalculateReducedOutputTensoInfo(const armnn::TensorInfo& inputTensorInfo,
}
}
+
+void CalculateStridedSliceOutputTensorInfo(const armnn::TensorInfo& inputTensorInfo,
+ const armnn::StridedSliceDescriptor& desc,
+ armnn::TensorInfo& outputTensorInfo)
+{
+ const armnn::TensorShape& inputShape = inputTensorInfo.GetShape();
+
+ std::vector<unsigned int> outputShapeVector;
+ for (unsigned int i = 0; i < inputTensorInfo.GetNumDimensions(); i++)
+ {
+ if (desc.m_ShrinkAxisMask & (1 << i))
+ {
+ continue;
+ }
+
+ int stride = desc.m_Stride[i];
+ int start = desc.GetStartForAxis(inputShape, i);
+ int stop = desc.GetStopForAxis(inputShape, i, start);
+
+ int newSize = stride > 0 ? ((stop - start) + stride - 1) / stride :
+ ((start - stop) - stride - 1) / -stride;
+
+ newSize = std::max(0, newSize);
+
+ outputShapeVector.push_back(static_cast<unsigned int>(newSize));
+ }
+
+ armnn::TensorShape outputTensorShape(inputTensorInfo.GetNumDimensions(), &outputShapeVector[0]);
+ outputTensorInfo = armnn::TensorInfo(armnn::TensorShape(outputTensorShape), inputTensorInfo.GetDataType());
+}
} // namespace armnnUtils