diff options
author | Michele Di Giorgio <michele.digiorgio@arm.com> | 2019-03-22 15:25:32 +0000 |
---|---|---|
committer | Michele Di Giorgio <michele.digiorgio@arm.com> | 2019-03-25 17:16:41 +0000 |
commit | ec6997563a7cccf58431267cca39435ecd57cd32 (patch) | |
tree | da355630d63858e31bf84acd5a0f588a1ea4f61f /src/graph/GraphBuilder.cpp | |
parent | 2761c2f0b60175469e959982a25ff0abdca6c9ce (diff) | |
download | ComputeLibrary-ec6997563a7cccf58431267cca39435ecd57cd32.tar.gz |
COMPMID-2076: Add StackLayer to the graph API
Change-Id: Ifae23659c2471d9c052bc8adf066c5228d6e8b23
Signed-off-by: Michele Di Giorgio <michele.digiorgio@arm.com>
Reviewed-on: https://review.mlplatform.org/c/893
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src/graph/GraphBuilder.cpp')
-rw-r--r-- | src/graph/GraphBuilder.cpp | 39 |
1 files changed, 25 insertions, 14 deletions
diff --git a/src/graph/GraphBuilder.cpp b/src/graph/GraphBuilder.cpp index 74f60d5354..3f40aeadcb 100644 --- a/src/graph/GraphBuilder.cpp +++ b/src/graph/GraphBuilder.cpp @@ -81,6 +81,24 @@ NodeID create_simple_single_input_output_node(Graph &g, NodeParams ¶ms, Node return nid; } + +template <typename NT, typename... Args> +NodeID create_simple_multiple_input_single_output_node(Graph &g, NodeParams ¶ms, std::vector<NodeIdxPair> inputs, Args &&... args) +{ + ARM_COMPUTE_ERROR_ON(inputs.size() == 0); + + NodeID nid = g.add_node<NT>(std::forward<Args>(args)...); + + unsigned int i = 0; + for(const auto &input : inputs) + { + CHECK_NODEIDX_PAIR(input, g); + g.add_connection(input.node_id, input.index, nid, i++); + } + set_node_params(g, nid, params); + + return nid; +} } // namespace NodeID GraphBuilder::add_const_node(Graph &g, NodeParams params, TensorDescriptor desc, ITensorAccessorUPtr accessor) @@ -294,21 +312,9 @@ NodeID GraphBuilder::add_deconvolution_node(Graph &g, NodeParams params, NodeIdx return deconv_nid; } -NodeID GraphBuilder::add_concatenate_node(Graph &g, NodeParams params, std::vector<NodeIdxPair> inputs, descriptors::ConcatLayerDescriptor concat_descriptor) +NodeID GraphBuilder::add_concatenate_node(Graph &g, NodeParams params, const std::vector<NodeIdxPair> &inputs, descriptors::ConcatLayerDescriptor concat_descriptor) { - ARM_COMPUTE_ERROR_ON(inputs.size() == 0); - - NodeID nid = g.add_node<ConcatenateLayerNode>(inputs.size(), concat_descriptor); - - unsigned int i = 0; - for(const auto &input : inputs) - { - CHECK_NODEIDX_PAIR(input, g); - g.add_connection(input.node_id, input.index, nid, i++); - } - set_node_params(g, nid, params); - - return nid; + return create_simple_multiple_input_single_output_node<ConcatenateLayerNode>(g, params, inputs, inputs.size(), concat_descriptor); } NodeID GraphBuilder::add_depthwise_convolution_node(Graph &g, NodeParams params, NodeIdxPair input, Size2D kernel_spatial_extend, @@ -627,6 +633,11 @@ NodeID GraphBuilder::add_split_node(Graph &g, NodeParams params, NodeIdxPair inp return create_simple_single_input_output_node<SplitLayerNode>(g, params, input, num_splits, axis); } +NodeID GraphBuilder::add_stack_node(Graph &g, NodeParams params, const std::vector<NodeIdxPair> &inputs, int axis) +{ + return create_simple_multiple_input_single_output_node<StackLayerNode>(g, params, inputs, inputs.size(), axis); +} + NodeID GraphBuilder::add_upsample_node(Graph &g, NodeParams params, NodeIdxPair input, Size2D info, InterpolationPolicy upsampling_policy) { return create_simple_single_input_output_node<UpsampleLayerNode>(g, params, input, info, upsampling_policy); |