Diffstat (limited to 'src/graph2/GraphBuilder.cpp')
-rw-r--r--  src/graph2/GraphBuilder.cpp  135
1 file changed, 78 insertions, 57 deletions
diff --git a/src/graph2/GraphBuilder.cpp b/src/graph2/GraphBuilder.cpp
index aaf70c4e61..e6fc2afe21 100644
--- a/src/graph2/GraphBuilder.cpp
+++ b/src/graph2/GraphBuilder.cpp
@@ -46,6 +46,7 @@ Status set_node_params(Graph &g, NodeID nid, NodeParams &params)
return Status{};
}
+
Status set_accessor_on_node(Graph &g, NodeID nid, bool is_output, size_t idx, ITensorAccessorUPtr accessor)
{
INode *node = g.node(nid);
@@ -66,6 +67,55 @@ NodeID add_const_node_with_name(Graph &g, NodeParams params, const std::string &
set_node_params(g, nid, params);
return nid;
}
+
+template <typename NT, typename... Args>
+NodeID create_simple_single_input_output_node(Graph &g, NodeParams &params, NodeIdxPair input, Args &&... args)
+{
+ CHECK_NODEIDX_PAIR(input, g);
+
+ NodeID nid = g.add_node<NT>(std::forward<Args>(args)...);
+ g.add_connection(input.node_id, input.index, nid, 0);
+ set_node_params(g, nid, params);
+
+ return nid;
+}
+
+NodeID create_grouped_convolution(Graph &g, NodeParams &params, NodeIdxPair input, NodeID weights, NodeID bias,
+ PadStrideInfo conv_info, ConvolutionMethod method, unsigned int num_groups)
+{
+ bool has_bias = (bias != EmptyNodeID);
+
+ // Split input
+ NodeID input_split = GraphBuilder::add_split_node(g, params, input, num_groups, 2);
+
+ // Split weights
+ NodeID weights_split = GraphBuilder::add_split_node(g, params, { weights, 0 }, num_groups, 3);
+
+ // Split bias
+ NodeID bias_split = EmptyNodeID;
+ if(has_bias)
+ {
+ // Split bias
+ bias_split = GraphBuilder::add_split_node(g, params, { bias, 0 }, num_groups, 0);
+ }
+
+ std::vector<NodeIdxPair> convolution_outputs;
+ for(unsigned int i = 0; i < num_groups; ++i)
+ {
+ NodeID conv_nid = g.add_node<ConvolutionLayerNode>(conv_info, method);
+ g.add_connection(input_split, i, conv_nid, 0);
+ g.add_connection(weights_split, i, conv_nid, 1);
+ if(has_bias)
+ {
+ g.add_connection(bias_split, i, conv_nid, 2);
+ }
+ set_node_params(g, conv_nid, params);
+ convolution_outputs.push_back({ conv_nid, 0 });
+ }
+
+ // Depth concatenate output
+ return GraphBuilder::add_depth_concatenate_node(g, params, convolution_outputs);
+}
} // namespace
NodeID GraphBuilder::add_const_node(Graph &g, NodeParams params, TensorDescriptor desc, ITensorAccessorUPtr accessor)
@@ -98,13 +148,7 @@ NodeID GraphBuilder::add_output_node(Graph &g, NodeParams params, NodeIdxPair in
NodeID GraphBuilder::add_activation_node(Graph &g, NodeParams params, NodeIdxPair input, ActivationLayerInfo act_info)
{
- CHECK_NODEIDX_PAIR(input, g);
-
- NodeID nid = g.add_node<ActivationLayerNode>(act_info);
- g.add_connection(input.node_id, input.index, nid, 0);
- set_node_params(g, nid, params);
-
- return nid;
+ return create_simple_single_input_output_node<ActivationLayerNode>(g, params, input, act_info);
}
NodeID GraphBuilder::add_batch_normalization_node(Graph &g, NodeParams params, NodeIdxPair input, float epsilon,
@@ -161,7 +205,7 @@ NodeID GraphBuilder::add_batch_normalization_node(Graph &g, NodeParams params, N
NodeID GraphBuilder::add_convolution_node(Graph &g, NodeParams params, NodeIdxPair input,
Size2D kernel_spatial_extend, unsigned int depth, PadStrideInfo conv_info,
- ConvolutionMethod method,
+ unsigned int num_groups, ConvolutionMethod method,
ITensorAccessorUPtr weights_accessor, ITensorAccessorUPtr bias_accessor)
{
CHECK_NODEIDX_PAIR(input, g);
@@ -175,7 +219,7 @@ NodeID GraphBuilder::add_convolution_node(Graph &g, NodeParams params, NodeIdxPa
// Create weights node
TensorDescriptor w_desc = input_tensor_desc;
- w_desc.shape = TensorShape(kernel_spatial_extend.width, kernel_spatial_extend.height, w_desc.shape.z(), depth);
+ w_desc.shape = TensorShape(kernel_spatial_extend.width, kernel_spatial_extend.height, w_desc.shape.z() / num_groups, depth);
NodeID w_nid = add_const_node_with_name(g, params, "Weights", w_desc, std::move(weights_accessor));
// Create bias nodes
@@ -187,17 +231,24 @@ NodeID GraphBuilder::add_convolution_node(Graph &g, NodeParams params, NodeIdxPa
b_nid = add_const_node_with_name(g, params, "Bias", b_desc, std::move(bias_accessor));
}
- // Create convolution node and connect
- NodeID conv_nid = g.add_node<ConvolutionLayerNode>(conv_info, method);
- g.add_connection(input.node_id, input.index, conv_nid, 0);
- g.add_connection(w_nid, 0, conv_nid, 1);
- if(has_bias)
+ if(num_groups == 1)
{
- g.add_connection(b_nid, 0, conv_nid, 2);
+ // Create convolution node and connect
+ NodeID conv_nid = g.add_node<ConvolutionLayerNode>(conv_info, method);
+ g.add_connection(input.node_id, input.index, conv_nid, 0);
+ g.add_connection(w_nid, 0, conv_nid, 1);
+ if(has_bias)
+ {
+ g.add_connection(b_nid, 0, conv_nid, 2);
+ }
+ set_node_params(g, conv_nid, params);
+
+ return conv_nid;
+ }
+ else
+ {
+ return create_grouped_convolution(g, params, input, w_nid, b_nid, conv_info, method, num_groups);
}
- set_node_params(g, conv_nid, params);
-
- return conv_nid;
}
NodeID GraphBuilder::add_depth_concatenate_node(Graph &g, NodeParams params, std::vector<NodeIdxPair> inputs)
@@ -273,14 +324,7 @@ NodeID GraphBuilder::add_elementwise_node(Graph &g, NodeParams params, NodeIdxPa
NodeID GraphBuilder::add_flatten_node(Graph &g, NodeParams params, NodeIdxPair input)
{
- CHECK_NODEIDX_PAIR(input, g);
-
- NodeID nid = g.add_node<FlattenLayerNode>();
- g.add_connection(input.node_id, input.index, nid, 0);
-
- set_node_params(g, nid, params);
-
- return nid;
+ return create_simple_single_input_output_node<FlattenLayerNode>(g, params, input);
}
NodeID GraphBuilder::add_fully_connected_layer(Graph &g, NodeParams params, NodeIdxPair input, unsigned int num_outputs,
@@ -324,50 +368,27 @@ NodeID GraphBuilder::add_fully_connected_layer(Graph &g, NodeParams params, Node
NodeID GraphBuilder::add_normalization_node(Graph &g, NodeParams params, NodeIdxPair input, NormalizationLayerInfo norm_info)
{
- CHECK_NODEIDX_PAIR(input, g);
-
- NodeID nid = g.add_node<NormalizationLayerNode>(norm_info);
- g.add_connection(input.node_id, input.index, nid, 0);
-
- set_node_params(g, nid, params);
-
- return nid;
+ return create_simple_single_input_output_node<NormalizationLayerNode>(g, params, input, norm_info);
}
NodeID GraphBuilder::add_pooling_node(Graph &g, NodeParams params, NodeIdxPair input, PoolingLayerInfo pool_info)
{
- CHECK_NODEIDX_PAIR(input, g);
-
- NodeID nid = g.add_node<PoolingLayerNode>(pool_info);
- g.add_connection(input.node_id, input.index, nid, 0);
-
- set_node_params(g, nid, params);
-
- return nid;
+ return create_simple_single_input_output_node<PoolingLayerNode>(g, params, input, pool_info);
}
NodeID GraphBuilder::add_reshape_node(Graph &g, NodeParams params, NodeIdxPair input, TensorShape shape)
{
- CHECK_NODEIDX_PAIR(input, g);
-
- NodeID nid = g.add_node<ReshapeLayerNode>(shape);
- g.add_connection(input.node_id, input.index, nid, 0);
-
- set_node_params(g, nid, params);
-
- return nid;
+ return create_simple_single_input_output_node<ReshapeLayerNode>(g, params, input, shape);
}
NodeID GraphBuilder::add_softmax_node(Graph &g, NodeParams params, NodeIdxPair input, float beta)
{
- CHECK_NODEIDX_PAIR(input, g);
-
- NodeID nid = g.add_node<SoftmaxLayerNode>(beta);
- g.add_connection(input.node_id, input.index, nid, 0);
-
- set_node_params(g, nid, params);
+ return create_simple_single_input_output_node<SoftmaxLayerNode>(g, params, input, beta);
+}
- return nid;
+NodeID GraphBuilder::add_split_node(Graph &g, NodeParams params, NodeIdxPair input, unsigned int num_splits, unsigned int axis)
+{
+ return create_simple_single_input_output_node<SplitLayerNode>(g, params, input, num_splits, axis);
}
} // namespace graph2
} // namespace arm_compute
\ No newline at end of file
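
For reference, a minimal usage sketch of the new num_groups parameter (not part of the patch). The NodeParams initializer, the pre-existing input node ID, and enum values such as Target::NEON and ConvolutionMethod::DEFAULT are assumptions for illustration only:

// Build a 3x3, 64-OFM convolution with 2 groups on an already-created input node.
// With num_groups > 1 the builder splits the input along axis 2 and the weights
// along axis 3, creates one ConvolutionLayerNode per group, and depth-concatenates
// the per-group outputs (see create_grouped_convolution in the patch above).
using namespace arm_compute;
using namespace arm_compute::graph2;

Graph       g;
NodeParams  params{ "conv1", Target::NEON };            // assumed field layout { name, target }
NodeIdxPair input{ /* existing input node id */ 0, 0 }; // assumed to come from add_input_node()

NodeID conv_nid = GraphBuilder::add_convolution_node(
    g, params, input,
    Size2D(3, 3), /* depth (OFMs) */ 64, PadStrideInfo(1, 1, 1, 1),
    /* num_groups */ 2, ConvolutionMethod::DEFAULT,
    nullptr /* weights accessor */, nullptr /* bias accessor */);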