diff options
Diffstat (limited to 'src/graph/GraphBuilder.cpp')
-rw-r--r-- | src/graph/GraphBuilder.cpp | 45 |
1 files changed, 43 insertions, 2 deletions
diff --git a/src/graph/GraphBuilder.cpp b/src/graph/GraphBuilder.cpp index 2f74f065d5..2afc1e2533 100644 --- a/src/graph/GraphBuilder.cpp +++ b/src/graph/GraphBuilder.cpp @@ -737,9 +737,50 @@ NodeID GraphBuilder::add_upsample_node(Graph &g, NodeParams params, NodeIdxPair return create_simple_single_input_output_node<UpsampleLayerNode>(g, params, input, info, upsampling_policy); } -NodeID GraphBuilder::add_yolo_node(Graph &g, NodeParams params, NodeIdxPair input, ActivationLayerInfo act_info, int32_t num_classes) +NodeID GraphBuilder::add_yolo_node(Graph &g, NodeParams params, NodeIdxPair input, ActivationLayerInfo act_info) { - return create_simple_single_input_output_node<YOLOLayerNode>(g, params, input, act_info, num_classes); + check_nodeidx_pair(input, g); + + // Get input tensor descriptor + const TensorDescriptor input_tensor_desc = get_tensor_descriptor(g, g.node(input.node_id)->outputs()[0]); + const bool is_nhwc = input_tensor_desc.layout == DataLayout::NHWC; + + // Box format: [Objectness:1][Box:4][Classes:N] + + // Activate objectness and front part of the box + const Coordinates box_start(0, 0, 0); + const Coordinates box_end = is_nhwc ? Coordinates(3, -1, -1) : Coordinates(-1, -1, 3); + NodeID box = g.add_node<SliceLayerNode>(box_start, box_end); + NodeID act_box = g.add_node<ActivationLayerNode>(act_info); + set_node_params(g, box, params); + set_node_params(g, act_box, params); + g.add_connection(input.node_id, input.index, box, 0); + g.add_connection(box, 0, act_box, 0); + + // Immutable part + const Coordinates imm_start = is_nhwc ? Coordinates(3, 0, 0) : Coordinates(0, 0, 3); + const Coordinates imm_end = is_nhwc ? Coordinates(5, -1, -1) : Coordinates(-1, -1, 5); + NodeID imm = g.add_node<SliceLayerNode>(imm_start, imm_end); + set_node_params(g, imm, params); + g.add_connection(input.node_id, input.index, imm, 0); + + // Activation classes and end part of box + const Coordinates cls_start = is_nhwc ? Coordinates(5, 0, 0) : Coordinates(0, 0, 5); + const Coordinates cls_end = is_nhwc ? Coordinates(-1, -1, -1) : Coordinates(-1, -1, -1); + NodeID cls = g.add_node<SliceLayerNode>(cls_start, cls_end); + NodeID cls_act = g.add_node<ActivationLayerNode>(act_info); + set_node_params(g, cls, params); + set_node_params(g, cls_act, params); + g.add_connection(input.node_id, input.index, cls, 0); + g.add_connection(cls, 0, cls_act, 0); + + NodeID concat = g.add_node<ConcatenateLayerNode>(3, descriptors::ConcatLayerDescriptor(DataLayoutDimension::CHANNEL)); + set_node_params(g, concat, params); + g.add_connection(act_box, 0, concat, 0); + g.add_connection(imm, 0, concat, 1); + g.add_connection(cls_act, 0, concat, 2); + + return concat; } } // namespace graph } // namespace arm_compute |