aboutsummaryrefslogtreecommitdiff
path: root/src/graph/GraphBuilder.cpp
diff options
context:
space:
mode:
authorMichele Di Giorgio <michele.digiorgio@arm.com>2019-03-22 15:25:32 +0000
committerMichele Di Giorgio <michele.digiorgio@arm.com>2019-03-25 17:16:41 +0000
commitec6997563a7cccf58431267cca39435ecd57cd32 (patch)
treeda355630d63858e31bf84acd5a0f588a1ea4f61f /src/graph/GraphBuilder.cpp
parent2761c2f0b60175469e959982a25ff0abdca6c9ce (diff)
downloadComputeLibrary-ec6997563a7cccf58431267cca39435ecd57cd32.tar.gz
COMPMID-2076: Add StackLayer to the graph API
Change-Id: Ifae23659c2471d9c052bc8adf066c5228d6e8b23 Signed-off-by: Michele Di Giorgio <michele.digiorgio@arm.com> Reviewed-on: https://review.mlplatform.org/c/893 Tested-by: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com> Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src/graph/GraphBuilder.cpp')
-rw-r--r--src/graph/GraphBuilder.cpp39
1 files changed, 25 insertions, 14 deletions
diff --git a/src/graph/GraphBuilder.cpp b/src/graph/GraphBuilder.cpp
index 74f60d5354..3f40aeadcb 100644
--- a/src/graph/GraphBuilder.cpp
+++ b/src/graph/GraphBuilder.cpp
@@ -81,6 +81,24 @@ NodeID create_simple_single_input_output_node(Graph &g, NodeParams &params, Node
return nid;
}
+
+template <typename NT, typename... Args>
+NodeID create_simple_multiple_input_single_output_node(Graph &g, NodeParams &params, std::vector<NodeIdxPair> inputs, Args &&... args)
+{
+ ARM_COMPUTE_ERROR_ON(inputs.size() == 0);
+
+ NodeID nid = g.add_node<NT>(std::forward<Args>(args)...);
+
+ unsigned int i = 0;
+ for(const auto &input : inputs)
+ {
+ CHECK_NODEIDX_PAIR(input, g);
+ g.add_connection(input.node_id, input.index, nid, i++);
+ }
+ set_node_params(g, nid, params);
+
+ return nid;
+}
} // namespace
NodeID GraphBuilder::add_const_node(Graph &g, NodeParams params, TensorDescriptor desc, ITensorAccessorUPtr accessor)
@@ -294,21 +312,9 @@ NodeID GraphBuilder::add_deconvolution_node(Graph &g, NodeParams params, NodeIdx
return deconv_nid;
}
-NodeID GraphBuilder::add_concatenate_node(Graph &g, NodeParams params, std::vector<NodeIdxPair> inputs, descriptors::ConcatLayerDescriptor concat_descriptor)
+NodeID GraphBuilder::add_concatenate_node(Graph &g, NodeParams params, const std::vector<NodeIdxPair> &inputs, descriptors::ConcatLayerDescriptor concat_descriptor)
{
- ARM_COMPUTE_ERROR_ON(inputs.size() == 0);
-
- NodeID nid = g.add_node<ConcatenateLayerNode>(inputs.size(), concat_descriptor);
-
- unsigned int i = 0;
- for(const auto &input : inputs)
- {
- CHECK_NODEIDX_PAIR(input, g);
- g.add_connection(input.node_id, input.index, nid, i++);
- }
- set_node_params(g, nid, params);
-
- return nid;
+ return create_simple_multiple_input_single_output_node<ConcatenateLayerNode>(g, params, inputs, inputs.size(), concat_descriptor);
}
NodeID GraphBuilder::add_depthwise_convolution_node(Graph &g, NodeParams params, NodeIdxPair input, Size2D kernel_spatial_extend,
@@ -627,6 +633,11 @@ NodeID GraphBuilder::add_split_node(Graph &g, NodeParams params, NodeIdxPair inp
return create_simple_single_input_output_node<SplitLayerNode>(g, params, input, num_splits, axis);
}
+NodeID GraphBuilder::add_stack_node(Graph &g, NodeParams params, const std::vector<NodeIdxPair> &inputs, int axis)
+{
+ return create_simple_multiple_input_single_output_node<StackLayerNode>(g, params, inputs, inputs.size(), axis);
+}
+
NodeID GraphBuilder::add_upsample_node(Graph &g, NodeParams params, NodeIdxPair input, Size2D info, InterpolationPolicy upsampling_policy)
{
return create_simple_single_input_output_node<UpsampleLayerNode>(g, params, input, info, upsampling_policy);