diff options
Diffstat (limited to 'src/armnn')
-rw-r--r-- | src/armnn/LayerVisitorBase.hpp | 1 | ||||
-rw-r--r-- | src/armnn/Network.cpp | 8 | ||||
-rw-r--r-- | src/armnn/Network.hpp | 1 | ||||
-rw-r--r-- | src/armnn/layers/DetectionPostProcessLayer.cpp | 8 | ||||
-rw-r--r-- | src/armnn/test/TestLayerVisitor.hpp | 1 |
5 files changed, 15 insertions, 4 deletions
diff --git a/src/armnn/LayerVisitorBase.hpp b/src/armnn/LayerVisitorBase.hpp index 3b6a2ff578..641ca31e2d 100644 --- a/src/armnn/LayerVisitorBase.hpp +++ b/src/armnn/LayerVisitorBase.hpp @@ -57,6 +57,7 @@ public: virtual void VisitDetectionPostProcessLayer(const IConnectableLayer*, const DetectionPostProcessDescriptor&, + const ConstTensor&, const char*) { DefaultPolicy::Apply(); } virtual void VisitFullyConnectedLayer(const IConnectableLayer*, diff --git a/src/armnn/Network.cpp b/src/armnn/Network.cpp index 7897a81d1e..5c70003785 100644 --- a/src/armnn/Network.cpp +++ b/src/armnn/Network.cpp @@ -648,9 +648,13 @@ IConnectableLayer* Network::AddDepthwiseConvolution2dLayer( } IConnectableLayer* Network::AddDetectionPostProcessLayer(const armnn::DetectionPostProcessDescriptor& descriptor, - const char* name) + const ConstTensor& anchors, const char* name) { - return m_Graph->AddLayer<DetectionPostProcessLayer>(descriptor, name); + const auto layer = m_Graph->AddLayer<DetectionPostProcessLayer>(descriptor, name); + + layer->m_Anchors = std::make_unique<ScopedCpuTensorHandle>(anchors); + + return layer; } IConnectableLayer* Network::AddPermuteLayer(const PermuteDescriptor& permuteDescriptor, diff --git a/src/armnn/Network.hpp b/src/armnn/Network.hpp index 4239ac5ba4..66fb240979 100644 --- a/src/armnn/Network.hpp +++ b/src/armnn/Network.hpp @@ -59,6 +59,7 @@ public: IConnectableLayer* AddDetectionPostProcessLayer( const DetectionPostProcessDescriptor& descriptor, + const ConstTensor& anchors, const char* name = nullptr) override; IConnectableLayer* AddFullyConnectedLayer(const FullyConnectedDescriptor& fullyConnectedDescriptor, diff --git a/src/armnn/layers/DetectionPostProcessLayer.cpp b/src/armnn/layers/DetectionPostProcessLayer.cpp index 3eea198f90..289cee0bd7 100644 --- a/src/armnn/layers/DetectionPostProcessLayer.cpp +++ b/src/armnn/layers/DetectionPostProcessLayer.cpp @@ -24,12 +24,15 @@ std::unique_ptr<IWorkload> DetectionPostProcessLayer::CreateWorkload(const armnn const armnn::IWorkloadFactory& factory) const { DetectionPostProcessQueueDescriptor descriptor; + descriptor.m_Anchors = m_Anchors.get(); return factory.CreateDetectionPostProcess(descriptor, PrepInfoAndDesc(descriptor, graph)); } DetectionPostProcessLayer* DetectionPostProcessLayer::Clone(Graph& graph) const { - return CloneBase<DetectionPostProcessLayer>(graph, m_Param, GetName()); + auto layer = CloneBase<DetectionPostProcessLayer>(graph, m_Param, GetName()); + layer->m_Anchors = m_Anchors ? std::make_unique<ScopedCpuTensorHandle>(*m_Anchors) : nullptr; + return std::move(layer); } void DetectionPostProcessLayer::ValidateTensorShapesFromInputs() @@ -72,7 +75,8 @@ Layer::ConstantTensors DetectionPostProcessLayer::GetConstantTensorsByRef() void DetectionPostProcessLayer::Accept(ILayerVisitor& visitor) const { - visitor.VisitDetectionPostProcessLayer(this, GetParameters(), GetName()); + ConstTensor anchorTensor(m_Anchors->GetTensorInfo(), m_Anchors->GetConstTensor<void>()); + visitor.VisitDetectionPostProcessLayer(this, GetParameters(), anchorTensor, GetName()); } } // namespace armnn diff --git a/src/armnn/test/TestLayerVisitor.hpp b/src/armnn/test/TestLayerVisitor.hpp index 5775df0e61..6b9503291a 100644 --- a/src/armnn/test/TestLayerVisitor.hpp +++ b/src/armnn/test/TestLayerVisitor.hpp @@ -61,6 +61,7 @@ public: virtual void VisitDetectionPostProcessLayer(const IConnectableLayer* layer, const DetectionPostProcessDescriptor& descriptor, + const ConstTensor& anchors, const char* name = nullptr) {}; virtual void VisitFullyConnectedLayer(const IConnectableLayer* layer, |