aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--include/armnn/Descriptors.hpp13
-rw-r--r--src/armnn/Descriptors.cpp62
-rw-r--r--src/armnn/layers/StridedSliceLayer.cpp35
-rw-r--r--src/armnn/layers/StridedSliceLayer.hpp4
-rw-r--r--src/backends/backendsCommon/WorkloadData.cpp8
5 files changed, 115 insertions, 7 deletions
diff --git a/include/armnn/Descriptors.hpp b/include/armnn/Descriptors.hpp
index b705abe729..22dd0d2991 100644
--- a/include/armnn/Descriptors.hpp
+++ b/include/armnn/Descriptors.hpp
@@ -427,10 +427,6 @@ struct PadDescriptor
struct StridedSliceDescriptor
{
- StridedSliceDescriptor()
- : m_DataLayout(DataLayout::NCHW)
- {}
-
StridedSliceDescriptor(const std::vector<int>& begin,
const std::vector<int>& end,
const std::vector<int>& stride)
@@ -445,6 +441,15 @@ struct StridedSliceDescriptor
, m_DataLayout(DataLayout::NCHW)
{}
+ StridedSliceDescriptor()
+ : StridedSliceDescriptor({}, {}, {})
+ {}
+
+ int GetStartForAxis(const TensorShape& inputShape, unsigned int axis) const;
+ int GetStopForAxis(const TensorShape& inputShape,
+ unsigned int axis,
+ int startForAxis) const;
+
std::vector<int> m_Begin;
std::vector<int> m_End;
std::vector<int> m_Stride;
diff --git a/src/armnn/Descriptors.cpp b/src/armnn/Descriptors.cpp
index a200c6462c..43f41a7397 100644
--- a/src/armnn/Descriptors.cpp
+++ b/src/armnn/Descriptors.cpp
@@ -306,4 +306,66 @@ void swap(ViewsDescriptor& first, ViewsDescriptor& second)
swap(first.m_ViewSizes, second.m_ViewSizes);
}
+int StridedSliceDescriptor::GetStartForAxis(const TensorShape& inputShape,
+ unsigned int axis) const
+{
+ int start = m_Begin[axis];
+
+ if (m_BeginMask & (1 << axis))
+ {
+ if (m_Stride[axis] > 0)
+ {
+ start = std::numeric_limits<int>::min();
+ }
+ else
+ {
+ start = std::numeric_limits<int>::max();
+ }
+ }
+
+ const int axisSize = boost::numeric_cast<int>(inputShape[axis]);
+ if (start < 0)
+ {
+ start += (axisSize);
+ }
+
+ return std::max(0, std::min(start, axisSize - 1));
+
+}
+
+int StridedSliceDescriptor::GetStopForAxis(const TensorShape& inputShape,
+ unsigned int axis,
+ int startForAxis) const
+{
+
+ if (m_ShrinkAxisMask & (1 << axis))
+ {
+ return startForAxis + 1;
+ }
+
+ int stop = m_End[axis];
+
+ if (m_EndMask & (1 << axis))
+ {
+ if (m_Stride[axis] > 0)
+ {
+ stop = std::numeric_limits<int>::max();
+ }
+ else
+ {
+ stop = std::numeric_limits<int>::min();
+ }
+ }
+
+ const int axisSize = boost::numeric_cast<int>(inputShape[axis]);
+ if (stop < 0)
+ {
+ stop += axisSize;
+ }
+
+ return m_Stride[axis] > 0 ? std::max(0, std::min(stop, axisSize)) :
+ std::max(-1, std::min(stop, axisSize - 1));
+
+}
+
}
diff --git a/src/armnn/layers/StridedSliceLayer.cpp b/src/armnn/layers/StridedSliceLayer.cpp
index f5e001c03f..a3dca25656 100644
--- a/src/armnn/layers/StridedSliceLayer.cpp
+++ b/src/armnn/layers/StridedSliceLayer.cpp
@@ -9,6 +9,8 @@
#include <backendsCommon/WorkloadData.hpp>
#include <backendsCommon/WorkloadFactory.hpp>
+#include <boost/numeric/conversion/cast.hpp>
+
namespace armnn
{
@@ -41,6 +43,37 @@ StridedSliceLayer* StridedSliceLayer::Clone(Graph& graph) const
return CloneBase<StridedSliceLayer>(graph, m_Param, GetName());
}
+std::vector<TensorShape> StridedSliceLayer::InferOutputShapes(
+ const std::vector<TensorShape>& inputShapes) const
+{
+ BOOST_ASSERT(inputShapes.size() == 1);
+
+ TensorShape inputShape = inputShapes[0];
+ std::vector<unsigned int> outputShape;
+
+ for (unsigned int i = 0; i < inputShape.GetNumDimensions(); i++)
+ {
+ if (m_Param.m_ShrinkAxisMask & (1 << i))
+ {
+ continue;
+ }
+
+ int stride = m_Param.m_Stride[i];
+ int start = m_Param.GetStartForAxis(inputShape, i);
+ int stop = m_Param.GetStopForAxis(inputShape, i, start);
+
+ int newSize = stride > 0 ? ((stop - start) + stride - 1) / stride :
+ ((start - stop) - stride - 1) / -stride;
+
+ newSize = std::min(0, newSize);
+
+ outputShape.push_back(boost::numeric_cast<unsigned int>(newSize));
+ }
+
+ return std::vector<TensorShape>({
+ TensorShape(boost::numeric_cast<unsigned int>(outputShape.size()), &outputShape[0]) });
+}
+
void StridedSliceLayer::ValidateTensorShapesFromInputs()
{
VerifyLayerConnections(1, CHECK_LOCATION());
@@ -55,4 +88,4 @@ void StridedSliceLayer::ValidateTensorShapesFromInputs()
inferredShapes[0]);
}
-} // namespace armnn \ No newline at end of file
+} // namespace armnn
diff --git a/src/armnn/layers/StridedSliceLayer.hpp b/src/armnn/layers/StridedSliceLayer.hpp
index 33a44243a5..c3aad53e19 100644
--- a/src/armnn/layers/StridedSliceLayer.hpp
+++ b/src/armnn/layers/StridedSliceLayer.hpp
@@ -17,6 +17,8 @@ public:
StridedSliceLayer* Clone(Graph& graph) const override;
+ std::vector<TensorShape> InferOutputShapes(const std::vector<TensorShape>& inputShapes) const override;
+
void ValidateTensorShapesFromInputs() override;
protected:
@@ -24,4 +26,4 @@ protected:
~StridedSliceLayer() = default;
};
-} // namespace \ No newline at end of file
+} // namespace
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp
index d5e3638a06..af57fee935 100644
--- a/src/backends/backendsCommon/WorkloadData.cpp
+++ b/src/backends/backendsCommon/WorkloadData.cpp
@@ -946,6 +946,12 @@ void StridedSliceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) con
const TensorInfo& input = workloadInfo.m_InputTensorInfos[0];
const uint32_t rank = input.GetNumDimensions();
+ if (rank > 4)
+ {
+ throw InvalidArgumentException(
+ "StridedSliceLayer: Input tensors with rank greater than 4 are not supported");
+ }
+
// Begin, End & Stride length must be of rank(input0)
if (m_Parameters.m_Begin.size() != rank)
{
@@ -975,4 +981,4 @@ void StridedSliceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) con
}
}
-} //namespace armnn \ No newline at end of file
+} //namespace armnn