aboutsummaryrefslogtreecommitdiff
path: root/arm_compute/graph/frontend/Layers.h
diff options
context:
space:
mode:
authorIsabella Gottardi <isabella.gottardi@arm.com>2019-01-08 13:48:44 +0000
committerIsabella Gottardi <isabella.gottardi@arm.com>2019-08-06 07:58:16 +0000
commita7acb3cbabeb66ce647684466a04c96b2963c9c9 (patch)
tree7988b75372c8ad1dfa3c8d028ab3a603a5e5a047 /arm_compute/graph/frontend/Layers.h
parent4746326ecb075dcfa123aaa8b38de5ec3e534b60 (diff)
downloadComputeLibrary-a7acb3cbabeb66ce647684466a04c96b2963c9c9.tar.gz
COMPMID-1849: Implement CPPDetectionPostProcessLayer
* Add DetectionPostProcessLayer * Add DetectionPostProcessLayer at the graph Change-Id: I7e56f6cffc26f112d26dfe74853085bb8ec7d849 Signed-off-by: Isabella Gottardi <isabella.gottardi@arm.com> Reviewed-on: https://review.mlplatform.org/c/1639 Reviewed-by: Giuseppe Rossini <giuseppe.rossini@arm.com> Tested-by: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'arm_compute/graph/frontend/Layers.h')
-rw-r--r--arm_compute/graph/frontend/Layers.h33
1 files changed, 33 insertions, 0 deletions
diff --git a/arm_compute/graph/frontend/Layers.h b/arm_compute/graph/frontend/Layers.h
index 3fc4af46d5..27a0cd3026 100644
--- a/arm_compute/graph/frontend/Layers.h
+++ b/arm_compute/graph/frontend/Layers.h
@@ -493,6 +493,39 @@ private:
SubStream _ss_prior;
DetectionOutputLayerInfo _detect_info;
};
+/** DetectionOutputPostProcess Layer */
+class DetectionPostProcessLayer final : public ILayer
+{
+public:
+ /** Construct a detection output layer.
+ *
+ * @param[in] sub_stream_class_prediction Class prediction graph sub-stream.
+ * @param[in] detect_info DetectionOutput parameters.
+ * @param[in] anchors Accessor to get anchors tensor data from.
+ * @param[in] out_quant_info (Optional) Output quantization info
+ */
+ DetectionPostProcessLayer(SubStream &&sub_stream_class_prediction, DetectionPostProcessLayerInfo detect_info, ITensorAccessorUPtr anchors,
+ const QuantizationInfo out_quant_info = QuantizationInfo())
+ : _sub_stream_class_prediction(std::move(sub_stream_class_prediction)), _detect_info(detect_info), _anchors(std::move(anchors)), _out_quant_info(std::move(out_quant_info))
+ {
+ }
+
+ NodeID create_layer(IStream &s) override
+ {
+ ARM_COMPUTE_ERROR_ON(_anchors == nullptr);
+
+ NodeParams common_params = { name(), s.hints().target_hint };
+ NodeIdxPair input_box_encoding = { s.tail_node(), 0 };
+ NodeIdxPair input_class_prediction = { _sub_stream_class_prediction.tail_node(), 0 };
+ return GraphBuilder::add_detection_post_process_node(s.graph(), common_params, input_box_encoding, input_class_prediction, _detect_info, std::move(_anchors), std::move(_out_quant_info));
+ }
+
+private:
+ SubStream _sub_stream_class_prediction;
+ DetectionPostProcessLayerInfo _detect_info;
+ ITensorAccessorUPtr _anchors;
+ const QuantizationInfo _out_quant_info;
+};
/** Dummy Layer */
class DummyLayer final : public ILayer
{