From a7acb3cbabeb66ce647684466a04c96b2963c9c9 Mon Sep 17 00:00:00 2001 From: Isabella Gottardi Date: Tue, 8 Jan 2019 13:48:44 +0000 Subject: COMPMID-1849: Implement CPPDetectionPostProcessLayer * Add DetectionPostProcessLayer * Add DetectionPostProcessLayer at the graph Change-Id: I7e56f6cffc26f112d26dfe74853085bb8ec7d849 Signed-off-by: Isabella Gottardi Reviewed-on: https://review.mlplatform.org/c/1639 Reviewed-by: Giuseppe Rossini Tested-by: Arm Jenkins --- src/graph/GraphBuilder.cpp | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) (limited to 'src/graph/GraphBuilder.cpp') diff --git a/src/graph/GraphBuilder.cpp b/src/graph/GraphBuilder.cpp index 54bd066712..228f2d211a 100644 --- a/src/graph/GraphBuilder.cpp +++ b/src/graph/GraphBuilder.cpp @@ -393,6 +393,36 @@ NodeID GraphBuilder::add_detection_output_node(Graph &g, NodeParams params, Node return detect_nid; } +NodeID GraphBuilder::add_detection_post_process_node(Graph &g, NodeParams params, NodeIdxPair input_box_encoding, NodeIdxPair input_class_prediction, const DetectionPostProcessLayerInfo &detect_info, + ITensorAccessorUPtr anchors_accessor, const QuantizationInfo &anchor_quant_info) +{ + check_nodeidx_pair(input_box_encoding, g); + check_nodeidx_pair(input_class_prediction, g); + + // Get input tensor descriptor + const TensorDescriptor input_box_encoding_tensor_desc = get_tensor_descriptor(g, g.node(input_box_encoding.node_id)->outputs()[0]); + + // Calculate anchor descriptor + TensorDescriptor anchor_desc = input_box_encoding_tensor_desc; + if(!anchor_quant_info.empty()) + { + anchor_desc.quant_info = anchor_quant_info; + } + + // Create anchors nodes + auto anchors_nid = add_const_node_with_name(g, params, "Anchors", anchor_desc, std::move(anchors_accessor)); + + // Create detection_output node and connect + NodeID detect_nid = g.add_node(detect_info); + g.add_connection(input_box_encoding.node_id, input_box_encoding.index, detect_nid, 0); + g.add_connection(input_class_prediction.node_id, input_class_prediction.index, detect_nid, 1); + g.add_connection(anchors_nid, 0, detect_nid, 2); + + set_node_params(g, detect_nid, params); + + return detect_nid; +} + NodeID GraphBuilder::add_dummy_node(Graph &g, NodeParams params, NodeIdxPair input, TensorShape shape) { return create_simple_single_input_output_node(g, params, input, shape); -- cgit v1.2.1