diff options
Diffstat (limited to 'src/backends/backendsCommon')
4 files changed, 50 insertions, 10 deletions
diff --git a/src/backends/backendsCommon/LayerSupportBase.cpp b/src/backends/backendsCommon/LayerSupportBase.cpp index 55261b83cf..00f1d0223d 100644 --- a/src/backends/backendsCommon/LayerSupportBase.cpp +++ b/src/backends/backendsCommon/LayerSupportBase.cpp @@ -163,10 +163,15 @@ bool LayerSupportBase::IsDequantizeSupported(const TensorInfo& input, return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported); } -bool LayerSupportBase::IsDetectionPostProcessSupported(const armnn::TensorInfo& input0, - const armnn::TensorInfo& input1, - const armnn::DetectionPostProcessDescriptor& descriptor, - armnn::Optional<std::string&> reasonIfUnsupported) const +bool LayerSupportBase::IsDetectionPostProcessSupported(const TensorInfo& boxEncodings, + const TensorInfo& scores, + const TensorInfo& anchors, + const TensorInfo& detectionBoxes, + const TensorInfo& detectionClasses, + const TensorInfo& detectionScores, + const TensorInfo& numDetections, + const DetectionPostProcessDescriptor& descriptor, + Optional<std::string&> reasonIfUnsupported) const { return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported); } diff --git a/src/backends/backendsCommon/LayerSupportBase.hpp b/src/backends/backendsCommon/LayerSupportBase.hpp index e99cb67614..60f94d0c4d 100644 --- a/src/backends/backendsCommon/LayerSupportBase.hpp +++ b/src/backends/backendsCommon/LayerSupportBase.hpp @@ -96,8 +96,13 @@ public: const TensorInfo& output, Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override; - bool IsDetectionPostProcessSupported(const TensorInfo& input0, - const TensorInfo& input1, + bool IsDetectionPostProcessSupported(const TensorInfo& boxEncodings, + const TensorInfo& scores, + const TensorInfo& anchors, + const TensorInfo& detectionBoxes, + const TensorInfo& detectionClasses, + const TensorInfo& detectionScores, + const TensorInfo& numDetections, const DetectionPostProcessDescriptor& descriptor, Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override; diff --git a/src/backends/backendsCommon/WorkloadFactory.cpp b/src/backends/backendsCommon/WorkloadFactory.cpp index 1d4ed7e159..805ec7ba5f 100644 --- a/src/backends/backendsCommon/WorkloadFactory.cpp +++ b/src/backends/backendsCommon/WorkloadFactory.cpp @@ -272,12 +272,24 @@ bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId, } case LayerType::DetectionPostProcess: { - const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo(); - const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo(); auto cLayer = boost::polymorphic_downcast<const DetectionPostProcessLayer*>(&layer); + const TensorInfo& boxEncodings = layer.GetInputSlot(0).GetConnection()->GetTensorInfo(); + const TensorInfo& scores = layer.GetInputSlot(1).GetConnection()->GetTensorInfo(); + const TensorInfo& anchors = cLayer->m_Anchors->GetTensorInfo(); + + const TensorInfo& detectionBoxes = layer.GetOutputSlot(0).GetTensorInfo(); + const TensorInfo& detectionClasses = layer.GetOutputSlot(1).GetTensorInfo(); + const TensorInfo& detectionScores = layer.GetOutputSlot(2).GetTensorInfo(); + const TensorInfo& numDetections = layer.GetOutputSlot(3).GetTensorInfo(); + const DetectionPostProcessDescriptor& descriptor = cLayer->GetParameters(); - result = layerSupportObject->IsDetectionPostProcessSupported(input0, - input1, + result = layerSupportObject->IsDetectionPostProcessSupported(boxEncodings, + scores, + anchors, + detectionBoxes, + detectionClasses, + detectionScores, + numDetections, descriptor, reason); break; diff --git a/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp b/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp index 12d7143122..7ab5ee4ec4 100644 --- a/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp +++ b/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp @@ -238,6 +238,24 @@ struct DummyLayer<armnn::TransposeConvolution2dLayer> { }; +template<> +struct DummyLayer<armnn::DetectionPostProcessLayer> +{ + DummyLayer() + { + m_Layer = dummyGraph.AddLayer<armnn::DetectionPostProcessLayer>(armnn::DetectionPostProcessDescriptor(), ""); + m_Layer->m_Anchors = std::make_unique<armnn::ScopedCpuTensorHandle>( + armnn::TensorInfo(armnn::TensorShape({1,1,1,1}), armnn::DataType::Float32)); + } + + ~DummyLayer() + { + dummyGraph.EraseLayer(m_Layer); + } + + armnn::DetectionPostProcessLayer* m_Layer; +}; + template <typename LstmLayerType> struct DummyLstmLayer { |