From ec6997563a7cccf58431267cca39435ecd57cd32 Mon Sep 17 00:00:00 2001 From: Michele Di Giorgio Date: Fri, 22 Mar 2019 15:25:32 +0000 Subject: COMPMID-2076: Add StackLayer to the graph API Change-Id: Ifae23659c2471d9c052bc8adf066c5228d6e8b23 Signed-off-by: Michele Di Giorgio Reviewed-on: https://review.mlplatform.org/c/893 Tested-by: Arm Jenkins Reviewed-by: Georgios Pinitas Comments-Addressed: Arm Jenkins --- src/graph/GraphBuilder.cpp | 39 +++++++++++++++++++++++++-------------- 1 file changed, 25 insertions(+), 14 deletions(-) (limited to 'src/graph/GraphBuilder.cpp') diff --git a/src/graph/GraphBuilder.cpp b/src/graph/GraphBuilder.cpp index 74f60d5354..3f40aeadcb 100644 --- a/src/graph/GraphBuilder.cpp +++ b/src/graph/GraphBuilder.cpp @@ -81,6 +81,24 @@ NodeID create_simple_single_input_output_node(Graph &g, NodeParams ¶ms, Node return nid; } + +template +NodeID create_simple_multiple_input_single_output_node(Graph &g, NodeParams ¶ms, std::vector inputs, Args &&... args) +{ + ARM_COMPUTE_ERROR_ON(inputs.size() == 0); + + NodeID nid = g.add_node(std::forward(args)...); + + unsigned int i = 0; + for(const auto &input : inputs) + { + CHECK_NODEIDX_PAIR(input, g); + g.add_connection(input.node_id, input.index, nid, i++); + } + set_node_params(g, nid, params); + + return nid; +} } // namespace NodeID GraphBuilder::add_const_node(Graph &g, NodeParams params, TensorDescriptor desc, ITensorAccessorUPtr accessor) @@ -294,21 +312,9 @@ NodeID GraphBuilder::add_deconvolution_node(Graph &g, NodeParams params, NodeIdx return deconv_nid; } -NodeID GraphBuilder::add_concatenate_node(Graph &g, NodeParams params, std::vector inputs, descriptors::ConcatLayerDescriptor concat_descriptor) +NodeID GraphBuilder::add_concatenate_node(Graph &g, NodeParams params, const std::vector &inputs, descriptors::ConcatLayerDescriptor concat_descriptor) { - ARM_COMPUTE_ERROR_ON(inputs.size() == 0); - - NodeID nid = g.add_node(inputs.size(), concat_descriptor); - - unsigned int i = 0; - for(const auto &input : inputs) - { - CHECK_NODEIDX_PAIR(input, g); - g.add_connection(input.node_id, input.index, nid, i++); - } - set_node_params(g, nid, params); - - return nid; + return create_simple_multiple_input_single_output_node(g, params, inputs, inputs.size(), concat_descriptor); } NodeID GraphBuilder::add_depthwise_convolution_node(Graph &g, NodeParams params, NodeIdxPair input, Size2D kernel_spatial_extend, @@ -627,6 +633,11 @@ NodeID GraphBuilder::add_split_node(Graph &g, NodeParams params, NodeIdxPair inp return create_simple_single_input_output_node(g, params, input, num_splits, axis); } +NodeID GraphBuilder::add_stack_node(Graph &g, NodeParams params, const std::vector &inputs, int axis) +{ + return create_simple_multiple_input_single_output_node(g, params, inputs, inputs.size(), axis); +} + NodeID GraphBuilder::add_upsample_node(Graph &g, NodeParams params, NodeIdxPair input, Size2D info, InterpolationPolicy upsampling_policy) { return create_simple_single_input_output_node(g, params, input, info, upsampling_policy); -- cgit v1.2.1