From 0b1c2db5c29ed80b7f4dd0c4fd6d4ed91b3d1538 Mon Sep 17 00:00:00 2001 From: Georgios Pinitas Date: Fri, 4 Dec 2020 15:51:34 +0000 Subject: Remove (NE/CL)YoloLayer support YOLO layer is too specialized and specific to a single model type. Can be decomposed using split, activation and concatenate layers Partially Resolves: COMPMID-3996 Signed-off-by: Georgios Pinitas Change-Id: I3cde88f8d4cc7d8c70ce1bb3b32b00f8d09bdca2 Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/4678 Tested-by: Arm Jenkins Reviewed-by: Michele Di Giorgio --- src/graph/GraphBuilder.cpp | 45 +++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 43 insertions(+), 2 deletions(-) (limited to 'src/graph/GraphBuilder.cpp') diff --git a/src/graph/GraphBuilder.cpp b/src/graph/GraphBuilder.cpp index 2f74f065d5..2afc1e2533 100644 --- a/src/graph/GraphBuilder.cpp +++ b/src/graph/GraphBuilder.cpp @@ -737,9 +737,50 @@ NodeID GraphBuilder::add_upsample_node(Graph &g, NodeParams params, NodeIdxPair return create_simple_single_input_output_node(g, params, input, info, upsampling_policy); } -NodeID GraphBuilder::add_yolo_node(Graph &g, NodeParams params, NodeIdxPair input, ActivationLayerInfo act_info, int32_t num_classes) +NodeID GraphBuilder::add_yolo_node(Graph &g, NodeParams params, NodeIdxPair input, ActivationLayerInfo act_info) { - return create_simple_single_input_output_node(g, params, input, act_info, num_classes); + check_nodeidx_pair(input, g); + + // Get input tensor descriptor + const TensorDescriptor input_tensor_desc = get_tensor_descriptor(g, g.node(input.node_id)->outputs()[0]); + const bool is_nhwc = input_tensor_desc.layout == DataLayout::NHWC; + + // Box format: [Objectness:1][Box:4][Classes:N] + + // Activate objectness and front part of the box + const Coordinates box_start(0, 0, 0); + const Coordinates box_end = is_nhwc ? Coordinates(3, -1, -1) : Coordinates(-1, -1, 3); + NodeID box = g.add_node(box_start, box_end); + NodeID act_box = g.add_node(act_info); + set_node_params(g, box, params); + set_node_params(g, act_box, params); + g.add_connection(input.node_id, input.index, box, 0); + g.add_connection(box, 0, act_box, 0); + + // Immutable part + const Coordinates imm_start = is_nhwc ? Coordinates(3, 0, 0) : Coordinates(0, 0, 3); + const Coordinates imm_end = is_nhwc ? Coordinates(5, -1, -1) : Coordinates(-1, -1, 5); + NodeID imm = g.add_node(imm_start, imm_end); + set_node_params(g, imm, params); + g.add_connection(input.node_id, input.index, imm, 0); + + // Activation classes and end part of box + const Coordinates cls_start = is_nhwc ? Coordinates(5, 0, 0) : Coordinates(0, 0, 5); + const Coordinates cls_end = is_nhwc ? Coordinates(-1, -1, -1) : Coordinates(-1, -1, -1); + NodeID cls = g.add_node(cls_start, cls_end); + NodeID cls_act = g.add_node(act_info); + set_node_params(g, cls, params); + set_node_params(g, cls_act, params); + g.add_connection(input.node_id, input.index, cls, 0); + g.add_connection(cls, 0, cls_act, 0); + + NodeID concat = g.add_node(3, descriptors::ConcatLayerDescriptor(DataLayoutDimension::CHANNEL)); + set_node_params(g, concat, params); + g.add_connection(act_box, 0, concat, 0); + g.add_connection(imm, 0, concat, 1); + g.add_connection(cls_act, 0, concat, 2); + + return concat; } } // namespace graph } // namespace arm_compute -- cgit v1.2.1