aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/layers/StridedSliceLayer.cpp
diff options
context:
space:
mode:
authorNattapat Chaimanowong <nattapat.chaimanowong@arm.com>2018-11-21 16:48:17 +0000
committernattapat.chaimanowong <nattapat.chaimanowong@arm.com>2018-11-21 17:13:30 +0000
commita0d2844d05dd9ae15733d426b04ab651457604ce (patch)
tree7ae0787d7be49ec2b1bb754ff35e13e8b71c73d2 /src/armnn/layers/StridedSliceLayer.cpp
parent3d93bc47f42b339d82cfcf56a90c9264dd46d70a (diff)
downloadarmnn-a0d2844d05dd9ae15733d426b04ab651457604ce.tar.gz
IVGCVSW-2086 Update StridedSliceLayer and StridedSliceDescriptor
Change-Id: Ifa88a879dd239f60ab27330d6b73859393828ef0
Diffstat (limited to 'src/armnn/layers/StridedSliceLayer.cpp')
-rw-r--r--src/armnn/layers/StridedSliceLayer.cpp35
1 files changed, 34 insertions, 1 deletions
diff --git a/src/armnn/layers/StridedSliceLayer.cpp b/src/armnn/layers/StridedSliceLayer.cpp
index f5e001c03f..a3dca25656 100644
--- a/src/armnn/layers/StridedSliceLayer.cpp
+++ b/src/armnn/layers/StridedSliceLayer.cpp
@@ -9,6 +9,8 @@
#include <backendsCommon/WorkloadData.hpp>
#include <backendsCommon/WorkloadFactory.hpp>
+#include <boost/numeric/conversion/cast.hpp>
+
namespace armnn
{
@@ -41,6 +43,37 @@ StridedSliceLayer* StridedSliceLayer::Clone(Graph& graph) const
return CloneBase<StridedSliceLayer>(graph, m_Param, GetName());
}
+std::vector<TensorShape> StridedSliceLayer::InferOutputShapes(
+ const std::vector<TensorShape>& inputShapes) const
+{
+ BOOST_ASSERT(inputShapes.size() == 1);
+
+ TensorShape inputShape = inputShapes[0];
+ std::vector<unsigned int> outputShape;
+
+ for (unsigned int i = 0; i < inputShape.GetNumDimensions(); i++)
+ {
+ if (m_Param.m_ShrinkAxisMask & (1 << i))
+ {
+ continue;
+ }
+
+ int stride = m_Param.m_Stride[i];
+ int start = m_Param.GetStartForAxis(inputShape, i);
+ int stop = m_Param.GetStopForAxis(inputShape, i, start);
+
+ int newSize = stride > 0 ? ((stop - start) + stride - 1) / stride :
+ ((start - stop) - stride - 1) / -stride;
+
+ newSize = std::min(0, newSize);
+
+ outputShape.push_back(boost::numeric_cast<unsigned int>(newSize));
+ }
+
+ return std::vector<TensorShape>({
+ TensorShape(boost::numeric_cast<unsigned int>(outputShape.size()), &outputShape[0]) });
+}
+
void StridedSliceLayer::ValidateTensorShapesFromInputs()
{
VerifyLayerConnections(1, CHECK_LOCATION());
@@ -55,4 +88,4 @@ void StridedSliceLayer::ValidateTensorShapesFromInputs()
inferredShapes[0]);
}
-} // namespace armnn \ No newline at end of file
+} // namespace armnn