aboutsummaryrefslogtreecommitdiff
path: root/src/graph/GraphBuilder.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/graph/GraphBuilder.cpp')
-rw-r--r--src/graph/GraphBuilder.cpp30
1 files changed, 30 insertions, 0 deletions
diff --git a/src/graph/GraphBuilder.cpp b/src/graph/GraphBuilder.cpp
index 54bd066712..228f2d211a 100644
--- a/src/graph/GraphBuilder.cpp
+++ b/src/graph/GraphBuilder.cpp
@@ -393,6 +393,36 @@ NodeID GraphBuilder::add_detection_output_node(Graph &g, NodeParams params, Node
return detect_nid;
}
+NodeID GraphBuilder::add_detection_post_process_node(Graph &g, NodeParams params, NodeIdxPair input_box_encoding, NodeIdxPair input_class_prediction, const DetectionPostProcessLayerInfo &detect_info,
+ ITensorAccessorUPtr anchors_accessor, const QuantizationInfo &anchor_quant_info)
+{
+ check_nodeidx_pair(input_box_encoding, g);
+ check_nodeidx_pair(input_class_prediction, g);
+
+ // Get input tensor descriptor
+ const TensorDescriptor input_box_encoding_tensor_desc = get_tensor_descriptor(g, g.node(input_box_encoding.node_id)->outputs()[0]);
+
+ // Calculate anchor descriptor
+ TensorDescriptor anchor_desc = input_box_encoding_tensor_desc;
+ if(!anchor_quant_info.empty())
+ {
+ anchor_desc.quant_info = anchor_quant_info;
+ }
+
+ // Create anchors nodes
+ auto anchors_nid = add_const_node_with_name(g, params, "Anchors", anchor_desc, std::move(anchors_accessor));
+
+ // Create detection_output node and connect
+ NodeID detect_nid = g.add_node<DetectionPostProcessLayerNode>(detect_info);
+ g.add_connection(input_box_encoding.node_id, input_box_encoding.index, detect_nid, 0);
+ g.add_connection(input_class_prediction.node_id, input_class_prediction.index, detect_nid, 1);
+ g.add_connection(anchors_nid, 0, detect_nid, 2);
+
+ set_node_params(g, detect_nid, params);
+
+ return detect_nid;
+}
+
NodeID GraphBuilder::add_dummy_node(Graph &g, NodeParams params, NodeIdxPair input, TensorShape shape)
{
return create_simple_single_input_output_node<DummyNode>(g, params, input, shape);