aboutsummaryrefslogtreecommitdiff
path: root/src/backends/backendsCommon
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/backendsCommon')
-rw-r--r--src/backends/backendsCommon/ILayerSupport.cpp8
-rw-r--r--src/backends/backendsCommon/WorkloadData.cpp37
-rw-r--r--src/backends/backendsCommon/WorkloadData.hpp6
-rw-r--r--src/backends/backendsCommon/WorkloadFactory.cpp11
-rw-r--r--src/backends/backendsCommon/WorkloadFactory.hpp3
-rw-r--r--src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp2
6 files changed, 67 insertions, 0 deletions
diff --git a/src/backends/backendsCommon/ILayerSupport.cpp b/src/backends/backendsCommon/ILayerSupport.cpp
index 2cd57b7ad7..dc106e344e 100644
--- a/src/backends/backendsCommon/ILayerSupport.cpp
+++ b/src/backends/backendsCommon/ILayerSupport.cpp
@@ -279,6 +279,14 @@ bool ILayerSupport::IsSplitterSupported(const TensorInfo& input,
return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported);
}
+bool ILayerSupport::IsStridedSliceSupported(const TensorInfo& input,
+ const TensorInfo& output,
+ const StridedSliceDescriptor& descriptor,
+ Optional<std::string&> reasonIfUnsupported) const
+{
+ return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported);
+}
+
bool ILayerSupport::IsSubtractionSupported(const TensorInfo& input0,
const TensorInfo& input1,
const TensorInfo& output,
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp
index 9fbdfe94c2..e1146543ff 100644
--- a/src/backends/backendsCommon/WorkloadData.cpp
+++ b/src/backends/backendsCommon/WorkloadData.cpp
@@ -924,4 +924,41 @@ void BatchToSpaceNdQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) c
ValidateSingleOutput(workloadInfo, "BatchToSpaceNdQueueDescriptor");
}
+void StridedSliceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
+{
+ ValidateSingleInput(workloadInfo, "StridedSliceQueueDescriptor");
+ ValidateSingleOutput(workloadInfo, "StridedSliceQueueDescriptor");
+
+ const TensorInfo& input = workloadInfo.m_InputTensorInfos[0];
+ const uint32_t rank = input.GetNumDimensions();
+
+ // Begin, End & Stride length must be of rank(input0)
+ if (m_Parameters.m_Begin.size() != rank)
+ {
+ throw InvalidArgumentException("StridedSliceLayer: Begin length must be of rank input0("
+ + to_string(rank) + ")");
+ }
+
+ if (m_Parameters.m_End.size() != rank)
+ {
+ throw InvalidArgumentException("StridedSliceLayer: End length must be of rank input0("
+ + to_string(rank) + ")");
+ }
+
+ if (m_Parameters.m_Stride.size() != rank)
+ {
+ throw InvalidArgumentException("StridedSliceLayer: Stride length must be of rank input0("
+ + to_string(rank) + ")");
+ }
+
+ // Stride entries must be non-zero
+ for (auto& stride : m_Parameters.m_Stride)
+ {
+ if (stride == 0)
+ {
+ throw InvalidArgumentException("StridedSliceLayer: Stride entries must be non-zero");
+ }
+ }
+}
+
} //namespace armnn \ No newline at end of file
diff --git a/src/backends/backendsCommon/WorkloadData.hpp b/src/backends/backendsCommon/WorkloadData.hpp
index d54a71aa8c..8cc60d0a96 100644
--- a/src/backends/backendsCommon/WorkloadData.hpp
+++ b/src/backends/backendsCommon/WorkloadData.hpp
@@ -339,4 +339,10 @@ struct BatchToSpaceNdQueueDescriptor : QueueDescriptorWithParameters<BatchToSpac
{
void Validate(const WorkloadInfo& workloadInfo) const;
};
+
+struct StridedSliceQueueDescriptor : QueueDescriptorWithParameters<StridedSliceDescriptor>
+{
+ void Validate(const WorkloadInfo& workloadInfo) const;
+};
+
} //namespace armnn
diff --git a/src/backends/backendsCommon/WorkloadFactory.cpp b/src/backends/backendsCommon/WorkloadFactory.cpp
index bb63b336e9..dc38f1a721 100644
--- a/src/backends/backendsCommon/WorkloadFactory.cpp
+++ b/src/backends/backendsCommon/WorkloadFactory.cpp
@@ -592,6 +592,17 @@ bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId,
reason);
break;
}
+ case LayerType::StridedSlice:
+ {
+ auto cLayer = boost::polymorphic_downcast<const StridedSliceLayer*>(&layer);
+ const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
+ const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
+ result = layerSupportObject->IsStridedSliceSupported(OverrideDataType(input, dataType),
+ OverrideDataType(output, dataType),
+ cLayer->GetParameters(),
+ reason);
+ break;
+ }
case LayerType::Subtraction:
{
const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
diff --git a/src/backends/backendsCommon/WorkloadFactory.hpp b/src/backends/backendsCommon/WorkloadFactory.hpp
index cd1ca25bb2..a1d0400f11 100644
--- a/src/backends/backendsCommon/WorkloadFactory.hpp
+++ b/src/backends/backendsCommon/WorkloadFactory.hpp
@@ -138,6 +138,9 @@ public:
virtual std::unique_ptr<IWorkload> CreatePad(const PadQueueDescriptor& descriptor,
const WorkloadInfo& Info) const = 0;
+
+ virtual std::unique_ptr<IWorkload> CreateStridedSlice(const StridedSliceQueueDescriptor& descriptor,
+ const WorkloadInfo& Info) const = 0;
};
} //namespace armnn
diff --git a/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp b/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp
index 25079058f6..7817e42321 100644
--- a/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp
+++ b/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp
@@ -374,6 +374,8 @@ DECLARE_LAYER_POLICY_2_PARAM(SpaceToBatchNd)
DECLARE_LAYER_POLICY_2_PARAM(Splitter)
+DECLARE_LAYER_POLICY_2_PARAM(StridedSlice)
+
DECLARE_LAYER_POLICY_1_PARAM(Subtraction)