aboutsummaryrefslogtreecommitdiff
path: root/src/graph/GraphBuilder.cpp
diff options
context:
space:
mode:
authorIsabella Gottardi <isabella.gottardi@arm.com>2019-01-08 13:48:44 +0000
committerIsabella Gottardi <isabella.gottardi@arm.com>2019-08-06 07:58:16 +0000
commita7acb3cbabeb66ce647684466a04c96b2963c9c9 (patch)
tree7988b75372c8ad1dfa3c8d028ab3a603a5e5a047 /src/graph/GraphBuilder.cpp
parent4746326ecb075dcfa123aaa8b38de5ec3e534b60 (diff)
downloadComputeLibrary-a7acb3cbabeb66ce647684466a04c96b2963c9c9.tar.gz
COMPMID-1849: Implement CPPDetectionPostProcessLayer
* Add DetectionPostProcessLayer * Add DetectionPostProcessLayer at the graph Change-Id: I7e56f6cffc26f112d26dfe74853085bb8ec7d849 Signed-off-by: Isabella Gottardi <isabella.gottardi@arm.com> Reviewed-on: https://review.mlplatform.org/c/1639 Reviewed-by: Giuseppe Rossini <giuseppe.rossini@arm.com> Tested-by: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src/graph/GraphBuilder.cpp')
-rw-r--r--src/graph/GraphBuilder.cpp30
1 files changed, 30 insertions, 0 deletions
diff --git a/src/graph/GraphBuilder.cpp b/src/graph/GraphBuilder.cpp
index 54bd066712..228f2d211a 100644
--- a/src/graph/GraphBuilder.cpp
+++ b/src/graph/GraphBuilder.cpp
@@ -393,6 +393,36 @@ NodeID GraphBuilder::add_detection_output_node(Graph &g, NodeParams params, Node
return detect_nid;
}
+NodeID GraphBuilder::add_detection_post_process_node(Graph &g, NodeParams params, NodeIdxPair input_box_encoding, NodeIdxPair input_class_prediction, const DetectionPostProcessLayerInfo &detect_info,
+ ITensorAccessorUPtr anchors_accessor, const QuantizationInfo &anchor_quant_info)
+{
+ check_nodeidx_pair(input_box_encoding, g);
+ check_nodeidx_pair(input_class_prediction, g);
+
+ // Get input tensor descriptor
+ const TensorDescriptor input_box_encoding_tensor_desc = get_tensor_descriptor(g, g.node(input_box_encoding.node_id)->outputs()[0]);
+
+ // Calculate anchor descriptor
+ TensorDescriptor anchor_desc = input_box_encoding_tensor_desc;
+ if(!anchor_quant_info.empty())
+ {
+ anchor_desc.quant_info = anchor_quant_info;
+ }
+
+ // Create anchors nodes
+ auto anchors_nid = add_const_node_with_name(g, params, "Anchors", anchor_desc, std::move(anchors_accessor));
+
+ // Create detection_output node and connect
+ NodeID detect_nid = g.add_node<DetectionPostProcessLayerNode>(detect_info);
+ g.add_connection(input_box_encoding.node_id, input_box_encoding.index, detect_nid, 0);
+ g.add_connection(input_class_prediction.node_id, input_class_prediction.index, detect_nid, 1);
+ g.add_connection(anchors_nid, 0, detect_nid, 2);
+
+ set_node_params(g, detect_nid, params);
+
+ return detect_nid;
+}
+
NodeID GraphBuilder::add_dummy_node(Graph &g, NodeParams params, NodeIdxPair input, TensorShape shape)
{
return create_simple_single_input_output_node<DummyNode>(g, params, input, shape);