aboutsummaryrefslogtreecommitdiff
path: root/src/graph/GraphBuilder.cpp
diff options
context:
space:
mode:
authorGeorgios Pinitas <georgios.pinitas@arm.com>2018-07-20 13:23:44 +0100
committerAnthony Barbier <anthony.barbier@arm.com>2018-11-02 16:54:54 +0000
commite2220551b7a64b929650ba9a60529c31e70c13c5 (patch)
tree5d609887f15b4392cdade7bb388710ceafc62260 /src/graph/GraphBuilder.cpp
parenteff8d95991205e874091576e2d225f63246dd0bb (diff)
downloadComputeLibrary-e2220551b7a64b929650ba9a60529c31e70c13c5.tar.gz
COMPMID-1367: Enable NHWC in graph examples
Change-Id: Iabc54a3a1bdcd46a9a921cda39c7c85fef672b72 Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/141449 Reviewed-by: Giorgio Arena <giorgio.arena@arm.com> Reviewed-by: Anthony Barbier <anthony.barbier@arm.com> Tested-by: Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src/graph/GraphBuilder.cpp')
-rw-r--r--src/graph/GraphBuilder.cpp18
1 files changed, 11 insertions, 7 deletions
diff --git a/src/graph/GraphBuilder.cpp b/src/graph/GraphBuilder.cpp
index d26039ec35..b3721719d9 100644
--- a/src/graph/GraphBuilder.cpp
+++ b/src/graph/GraphBuilder.cpp
@@ -88,10 +88,14 @@ NodeID create_grouped_convolution(Graph &g, const NodeParams &params, NodeIdxPai
bool has_bias = (bias != EmptyNodeID);
// Split input
- NodeID input_split = GraphBuilder::add_split_node(g, params, input, num_groups, 2);
+ const TensorDescriptor input_tensor_desc = get_tensor_descriptor(g, g.node(input.node_id)->outputs()[0]);
+ const unsigned int input_idx = get_dimension_idx(input_tensor_desc, DataLayoutDimension::CHANNEL);
+ NodeID input_split = GraphBuilder::add_split_node(g, params, input, num_groups, input_idx);
// Split weights
- NodeID weights_split = GraphBuilder::add_split_node(g, params, { weights, 0 }, num_groups, 3);
+ const TensorDescriptor weights_tensor_desc = get_tensor_descriptor(g, g.node(weights)->outputs()[0]);
+ const unsigned int batch_idx = get_dimension_idx(weights_tensor_desc, DataLayoutDimension::BATCHES);
+ NodeID weights_split = GraphBuilder::add_split_node(g, params, { weights, 0 }, num_groups, batch_idx);
// Split bias
NodeID bias_split = EmptyNodeID;
@@ -122,7 +126,7 @@ NodeID create_grouped_convolution(Graph &g, const NodeParams &params, NodeIdxPai
}
// Depth concatenate output
- return GraphBuilder::add_depth_concatenate_node(g, params, convolution_outputs);
+ return GraphBuilder::add_concatenate_node(g, params, convolution_outputs, DataLayoutDimension::CHANNEL);
}
} // namespace
@@ -329,11 +333,11 @@ NodeID GraphBuilder::add_deconvolution_node(Graph &g, NodeParams params, NodeIdx
return deconv_nid;
}
-NodeID GraphBuilder::add_depth_concatenate_node(Graph &g, NodeParams params, std::vector<NodeIdxPair> inputs)
+NodeID GraphBuilder::add_concatenate_node(Graph &g, NodeParams params, std::vector<NodeIdxPair> inputs, DataLayoutDimension axis)
{
ARM_COMPUTE_ERROR_ON(inputs.size() == 0);
- NodeID nid = g.add_node<DepthConcatenateLayerNode>(inputs.size());
+ NodeID nid = g.add_node<ConcatenateLayerNode>(inputs.size(), axis);
unsigned int i = 0;
for(const auto &input : inputs)
@@ -508,9 +512,9 @@ NodeID GraphBuilder::add_scale_layer(Graph &g, const NodeParams &params, NodeIdx
NodeIdxPair add_const_nidxp = { add_const_nid, 0 };
// Create node and connect
- NodeID mul_node = GraphBuilder::add_elementwise_node(g, params, input, mul_const_nidxp, EltwiseOperation::MUL);
+ NodeID mul_node = GraphBuilder::add_elementwise_node(g, params, input, mul_const_nidxp, EltwiseOperation::Mul);
NodeIdxPair mulnode_nidxp = { mul_node, 0 };
- NodeID add_node = GraphBuilder::add_elementwise_node(g, params, mulnode_nidxp, add_const_nidxp, EltwiseOperation::ADD);
+ NodeID add_node = GraphBuilder::add_elementwise_node(g, params, mulnode_nidxp, add_const_nidxp, EltwiseOperation::Add);
return add_node;
}