From e2220551b7a64b929650ba9a60529c31e70c13c5 Mon Sep 17 00:00:00 2001 From: Georgios Pinitas Date: Fri, 20 Jul 2018 13:23:44 +0100 Subject: COMPMID-1367: Enable NHWC in graph examples Change-Id: Iabc54a3a1bdcd46a9a921cda39c7c85fef672b72 Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/141449 Reviewed-by: Giorgio Arena Reviewed-by: Anthony Barbier Tested-by: Jenkins --- src/graph/GraphBuilder.cpp | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) (limited to 'src/graph/GraphBuilder.cpp') diff --git a/src/graph/GraphBuilder.cpp b/src/graph/GraphBuilder.cpp index d26039ec35..b3721719d9 100644 --- a/src/graph/GraphBuilder.cpp +++ b/src/graph/GraphBuilder.cpp @@ -88,10 +88,14 @@ NodeID create_grouped_convolution(Graph &g, const NodeParams ¶ms, NodeIdxPai bool has_bias = (bias != EmptyNodeID); // Split input - NodeID input_split = GraphBuilder::add_split_node(g, params, input, num_groups, 2); + const TensorDescriptor input_tensor_desc = get_tensor_descriptor(g, g.node(input.node_id)->outputs()[0]); + const unsigned int input_idx = get_dimension_idx(input_tensor_desc, DataLayoutDimension::CHANNEL); + NodeID input_split = GraphBuilder::add_split_node(g, params, input, num_groups, input_idx); // Split weights - NodeID weights_split = GraphBuilder::add_split_node(g, params, { weights, 0 }, num_groups, 3); + const TensorDescriptor weights_tensor_desc = get_tensor_descriptor(g, g.node(weights)->outputs()[0]); + const unsigned int batch_idx = get_dimension_idx(weights_tensor_desc, DataLayoutDimension::BATCHES); + NodeID weights_split = GraphBuilder::add_split_node(g, params, { weights, 0 }, num_groups, batch_idx); // Split bias NodeID bias_split = EmptyNodeID; @@ -122,7 +126,7 @@ NodeID create_grouped_convolution(Graph &g, const NodeParams ¶ms, NodeIdxPai } // Depth concatenate output - return GraphBuilder::add_depth_concatenate_node(g, params, convolution_outputs); + return GraphBuilder::add_concatenate_node(g, params, convolution_outputs, DataLayoutDimension::CHANNEL); } } // namespace @@ -329,11 +333,11 @@ NodeID GraphBuilder::add_deconvolution_node(Graph &g, NodeParams params, NodeIdx return deconv_nid; } -NodeID GraphBuilder::add_depth_concatenate_node(Graph &g, NodeParams params, std::vector inputs) +NodeID GraphBuilder::add_concatenate_node(Graph &g, NodeParams params, std::vector inputs, DataLayoutDimension axis) { ARM_COMPUTE_ERROR_ON(inputs.size() == 0); - NodeID nid = g.add_node(inputs.size()); + NodeID nid = g.add_node(inputs.size(), axis); unsigned int i = 0; for(const auto &input : inputs) @@ -508,9 +512,9 @@ NodeID GraphBuilder::add_scale_layer(Graph &g, const NodeParams ¶ms, NodeIdx NodeIdxPair add_const_nidxp = { add_const_nid, 0 }; // Create node and connect - NodeID mul_node = GraphBuilder::add_elementwise_node(g, params, input, mul_const_nidxp, EltwiseOperation::MUL); + NodeID mul_node = GraphBuilder::add_elementwise_node(g, params, input, mul_const_nidxp, EltwiseOperation::Mul); NodeIdxPair mulnode_nidxp = { mul_node, 0 }; - NodeID add_node = GraphBuilder::add_elementwise_node(g, params, mulnode_nidxp, add_const_nidxp, EltwiseOperation::ADD); + NodeID add_node = GraphBuilder::add_elementwise_node(g, params, mulnode_nidxp, add_const_nidxp, EltwiseOperation::Add); return add_node; } -- cgit v1.2.1