diff options
Diffstat (limited to 'src/graph/GraphBuilder.cpp')
-rw-r--r-- | src/graph/GraphBuilder.cpp | 18 |
1 files changed, 11 insertions, 7 deletions
diff --git a/src/graph/GraphBuilder.cpp b/src/graph/GraphBuilder.cpp index d26039ec35..b3721719d9 100644 --- a/src/graph/GraphBuilder.cpp +++ b/src/graph/GraphBuilder.cpp @@ -88,10 +88,14 @@ NodeID create_grouped_convolution(Graph &g, const NodeParams ¶ms, NodeIdxPai bool has_bias = (bias != EmptyNodeID); // Split input - NodeID input_split = GraphBuilder::add_split_node(g, params, input, num_groups, 2); + const TensorDescriptor input_tensor_desc = get_tensor_descriptor(g, g.node(input.node_id)->outputs()[0]); + const unsigned int input_idx = get_dimension_idx(input_tensor_desc, DataLayoutDimension::CHANNEL); + NodeID input_split = GraphBuilder::add_split_node(g, params, input, num_groups, input_idx); // Split weights - NodeID weights_split = GraphBuilder::add_split_node(g, params, { weights, 0 }, num_groups, 3); + const TensorDescriptor weights_tensor_desc = get_tensor_descriptor(g, g.node(weights)->outputs()[0]); + const unsigned int batch_idx = get_dimension_idx(weights_tensor_desc, DataLayoutDimension::BATCHES); + NodeID weights_split = GraphBuilder::add_split_node(g, params, { weights, 0 }, num_groups, batch_idx); // Split bias NodeID bias_split = EmptyNodeID; @@ -122,7 +126,7 @@ NodeID create_grouped_convolution(Graph &g, const NodeParams ¶ms, NodeIdxPai } // Depth concatenate output - return GraphBuilder::add_depth_concatenate_node(g, params, convolution_outputs); + return GraphBuilder::add_concatenate_node(g, params, convolution_outputs, DataLayoutDimension::CHANNEL); } } // namespace @@ -329,11 +333,11 @@ NodeID GraphBuilder::add_deconvolution_node(Graph &g, NodeParams params, NodeIdx return deconv_nid; } -NodeID GraphBuilder::add_depth_concatenate_node(Graph &g, NodeParams params, std::vector<NodeIdxPair> inputs) +NodeID GraphBuilder::add_concatenate_node(Graph &g, NodeParams params, std::vector<NodeIdxPair> inputs, DataLayoutDimension axis) { ARM_COMPUTE_ERROR_ON(inputs.size() == 0); - NodeID nid = g.add_node<DepthConcatenateLayerNode>(inputs.size()); + NodeID nid = g.add_node<ConcatenateLayerNode>(inputs.size(), axis); unsigned int i = 0; for(const auto &input : inputs) @@ -508,9 +512,9 @@ NodeID GraphBuilder::add_scale_layer(Graph &g, const NodeParams ¶ms, NodeIdx NodeIdxPair add_const_nidxp = { add_const_nid, 0 }; // Create node and connect - NodeID mul_node = GraphBuilder::add_elementwise_node(g, params, input, mul_const_nidxp, EltwiseOperation::MUL); + NodeID mul_node = GraphBuilder::add_elementwise_node(g, params, input, mul_const_nidxp, EltwiseOperation::Mul); NodeIdxPair mulnode_nidxp = { mul_node, 0 }; - NodeID add_node = GraphBuilder::add_elementwise_node(g, params, mulnode_nidxp, add_const_nidxp, EltwiseOperation::ADD); + NodeID add_node = GraphBuilder::add_elementwise_node(g, params, mulnode_nidxp, add_const_nidxp, EltwiseOperation::Add); return add_node; } |