aboutsummaryrefslogtreecommitdiff
path: root/src/graph/GraphBuilder.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/graph/GraphBuilder.cpp')
-rw-r--r--src/graph/GraphBuilder.cpp39
1 files changed, 25 insertions, 14 deletions
diff --git a/src/graph/GraphBuilder.cpp b/src/graph/GraphBuilder.cpp
index 74f60d5354..3f40aeadcb 100644
--- a/src/graph/GraphBuilder.cpp
+++ b/src/graph/GraphBuilder.cpp
@@ -81,6 +81,24 @@ NodeID create_simple_single_input_output_node(Graph &g, NodeParams &params, Node
return nid;
}
+
+template <typename NT, typename... Args>
+NodeID create_simple_multiple_input_single_output_node(Graph &g, NodeParams &params, std::vector<NodeIdxPair> inputs, Args &&... args)
+{
+ ARM_COMPUTE_ERROR_ON(inputs.size() == 0);
+
+ NodeID nid = g.add_node<NT>(std::forward<Args>(args)...);
+
+ unsigned int i = 0;
+ for(const auto &input : inputs)
+ {
+ CHECK_NODEIDX_PAIR(input, g);
+ g.add_connection(input.node_id, input.index, nid, i++);
+ }
+ set_node_params(g, nid, params);
+
+ return nid;
+}
} // namespace
NodeID GraphBuilder::add_const_node(Graph &g, NodeParams params, TensorDescriptor desc, ITensorAccessorUPtr accessor)
@@ -294,21 +312,9 @@ NodeID GraphBuilder::add_deconvolution_node(Graph &g, NodeParams params, NodeIdx
return deconv_nid;
}
-NodeID GraphBuilder::add_concatenate_node(Graph &g, NodeParams params, std::vector<NodeIdxPair> inputs, descriptors::ConcatLayerDescriptor concat_descriptor)
+NodeID GraphBuilder::add_concatenate_node(Graph &g, NodeParams params, const std::vector<NodeIdxPair> &inputs, descriptors::ConcatLayerDescriptor concat_descriptor)
{
- ARM_COMPUTE_ERROR_ON(inputs.size() == 0);
-
- NodeID nid = g.add_node<ConcatenateLayerNode>(inputs.size(), concat_descriptor);
-
- unsigned int i = 0;
- for(const auto &input : inputs)
- {
- CHECK_NODEIDX_PAIR(input, g);
- g.add_connection(input.node_id, input.index, nid, i++);
- }
- set_node_params(g, nid, params);
-
- return nid;
+ return create_simple_multiple_input_single_output_node<ConcatenateLayerNode>(g, params, inputs, inputs.size(), concat_descriptor);
}
NodeID GraphBuilder::add_depthwise_convolution_node(Graph &g, NodeParams params, NodeIdxPair input, Size2D kernel_spatial_extend,
@@ -627,6 +633,11 @@ NodeID GraphBuilder::add_split_node(Graph &g, NodeParams params, NodeIdxPair inp
return create_simple_single_input_output_node<SplitLayerNode>(g, params, input, num_splits, axis);
}
+NodeID GraphBuilder::add_stack_node(Graph &g, NodeParams params, const std::vector<NodeIdxPair> &inputs, int axis)
+{
+ return create_simple_multiple_input_single_output_node<StackLayerNode>(g, params, inputs, inputs.size(), axis);
+}
+
NodeID GraphBuilder::add_upsample_node(Graph &g, NodeParams params, NodeIdxPair input, Size2D info, InterpolationPolicy upsampling_policy)
{
return create_simple_single_input_output_node<UpsampleLayerNode>(g, params, input, info, upsampling_policy);