diff options
Diffstat (limited to 'src/graph/GraphBuilder.cpp')
-rw-r--r-- | src/graph/GraphBuilder.cpp | 30 |
1 files changed, 30 insertions, 0 deletions
diff --git a/src/graph/GraphBuilder.cpp b/src/graph/GraphBuilder.cpp index 54bd066712..228f2d211a 100644 --- a/src/graph/GraphBuilder.cpp +++ b/src/graph/GraphBuilder.cpp @@ -393,6 +393,36 @@ NodeID GraphBuilder::add_detection_output_node(Graph &g, NodeParams params, Node return detect_nid; } +NodeID GraphBuilder::add_detection_post_process_node(Graph &g, NodeParams params, NodeIdxPair input_box_encoding, NodeIdxPair input_class_prediction, const DetectionPostProcessLayerInfo &detect_info, + ITensorAccessorUPtr anchors_accessor, const QuantizationInfo &anchor_quant_info) +{ + check_nodeidx_pair(input_box_encoding, g); + check_nodeidx_pair(input_class_prediction, g); + + // Get input tensor descriptor + const TensorDescriptor input_box_encoding_tensor_desc = get_tensor_descriptor(g, g.node(input_box_encoding.node_id)->outputs()[0]); + + // Calculate anchor descriptor + TensorDescriptor anchor_desc = input_box_encoding_tensor_desc; + if(!anchor_quant_info.empty()) + { + anchor_desc.quant_info = anchor_quant_info; + } + + // Create anchors nodes + auto anchors_nid = add_const_node_with_name(g, params, "Anchors", anchor_desc, std::move(anchors_accessor)); + + // Create detection_output node and connect + NodeID detect_nid = g.add_node<DetectionPostProcessLayerNode>(detect_info); + g.add_connection(input_box_encoding.node_id, input_box_encoding.index, detect_nid, 0); + g.add_connection(input_class_prediction.node_id, input_class_prediction.index, detect_nid, 1); + g.add_connection(anchors_nid, 0, detect_nid, 2); + + set_node_params(g, detect_nid, params); + + return detect_nid; +} + NodeID GraphBuilder::add_dummy_node(Graph &g, NodeParams params, NodeIdxPair input, TensorShape shape) { return create_simple_single_input_output_node<DummyNode>(g, params, input, shape); |