diff options
Diffstat (limited to 'arm_compute/graph/frontend/Layers.h')
-rw-r--r-- | arm_compute/graph/frontend/Layers.h | 33 |
1 files changed, 33 insertions, 0 deletions
diff --git a/arm_compute/graph/frontend/Layers.h b/arm_compute/graph/frontend/Layers.h index 3fc4af46d5..27a0cd3026 100644 --- a/arm_compute/graph/frontend/Layers.h +++ b/arm_compute/graph/frontend/Layers.h @@ -493,6 +493,39 @@ private: SubStream _ss_prior; DetectionOutputLayerInfo _detect_info; }; +/** DetectionOutputPostProcess Layer */ +class DetectionPostProcessLayer final : public ILayer +{ +public: + /** Construct a detection output layer. + * + * @param[in] sub_stream_class_prediction Class prediction graph sub-stream. + * @param[in] detect_info DetectionOutput parameters. + * @param[in] anchors Accessor to get anchors tensor data from. + * @param[in] out_quant_info (Optional) Output quantization info + */ + DetectionPostProcessLayer(SubStream &&sub_stream_class_prediction, DetectionPostProcessLayerInfo detect_info, ITensorAccessorUPtr anchors, + const QuantizationInfo out_quant_info = QuantizationInfo()) + : _sub_stream_class_prediction(std::move(sub_stream_class_prediction)), _detect_info(detect_info), _anchors(std::move(anchors)), _out_quant_info(std::move(out_quant_info)) + { + } + + NodeID create_layer(IStream &s) override + { + ARM_COMPUTE_ERROR_ON(_anchors == nullptr); + + NodeParams common_params = { name(), s.hints().target_hint }; + NodeIdxPair input_box_encoding = { s.tail_node(), 0 }; + NodeIdxPair input_class_prediction = { _sub_stream_class_prediction.tail_node(), 0 }; + return GraphBuilder::add_detection_post_process_node(s.graph(), common_params, input_box_encoding, input_class_prediction, _detect_info, std::move(_anchors), std::move(_out_quant_info)); + } + +private: + SubStream _sub_stream_class_prediction; + DetectionPostProcessLayerInfo _detect_info; + ITensorAccessorUPtr _anchors; + const QuantizationInfo _out_quant_info; +}; /** Dummy Layer */ class DummyLayer final : public ILayer { |