From cac13b1cfd593889271f8e2191be2039b8d88f36 Mon Sep 17 00:00:00 2001 From: Georgios Pinitas Date: Fri, 27 Apr 2018 19:07:19 +0100 Subject: COMPMID-1097: Port mobilenet to NHWC Change-Id: I789065bfa0d4ef133388e1904c5caf31e450f80f Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/129495 Tested-by: Jenkins Reviewed-by: Anthony Barbier --- src/graph/GraphBuilder.cpp | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) (limited to 'src/graph/GraphBuilder.cpp') diff --git a/src/graph/GraphBuilder.cpp b/src/graph/GraphBuilder.cpp index 4ad34e789c..56b31c7844 100644 --- a/src/graph/GraphBuilder.cpp +++ b/src/graph/GraphBuilder.cpp @@ -63,7 +63,7 @@ Status set_accessor_on_node(Graph &g, NodeID nid, bool is_output, size_t idx, IT NodeID add_const_node_with_name(Graph &g, NodeParams params, const std::string &name, TensorDescriptor desc, ITensorAccessorUPtr accessor) { params.name = params.name.empty() ? "" : params.name + name; - auto nid = GraphBuilder::add_const_node(g, params, desc, std::move(accessor)); + auto nid = GraphBuilder::add_const_node(g, params, std::move(desc), std::move(accessor)); set_node_params(g, nid, params); return nid; } @@ -165,7 +165,7 @@ NodeID GraphBuilder::add_batch_normalization_node(Graph &g, NodeParams params, N // Calculate Common Descriptor TensorDescriptor common_desc = input_tensor_desc; - common_desc.shape = TensorShape(common_desc.shape.z()); + common_desc.shape = TensorShape(get_dimension_size(input_tensor_desc, DataLayoutDimension::CHANNEL)); // Create mean and nodes auto mean_nid = add_const_node_with_name(g, params, "Mean", common_desc, std::move(mean_accessor)); @@ -221,8 +221,11 @@ NodeID GraphBuilder::add_convolution_node(Graph &g, NodeParams params, NodeIdxPa // Create weights node TensorDescriptor w_desc = input_tensor_desc; - w_desc.shape = TensorShape(kernel_spatial_extend.width, kernel_spatial_extend.height, w_desc.shape.z() / num_groups, depth); - + w_desc.shape.set(get_dimension_idx(input_tensor_desc, DataLayoutDimension::WIDTH), kernel_spatial_extend.width); + w_desc.shape.set(get_dimension_idx(input_tensor_desc, DataLayoutDimension::HEIGHT), kernel_spatial_extend.height); + w_desc.shape.set(get_dimension_idx(input_tensor_desc, DataLayoutDimension::CHANNEL), + get_dimension_size(input_tensor_desc, DataLayoutDimension::CHANNEL) / num_groups); + w_desc.shape.set(get_dimension_idx(input_tensor_desc, DataLayoutDimension::BATCHES), depth); if(!weights_quant_info.empty()) { w_desc.quant_info = weights_quant_info; @@ -290,8 +293,10 @@ NodeID GraphBuilder::add_depthwise_convolution_node(Graph &g, NodeParams params, // Create weights node TensorDescriptor w_desc = input_tensor_desc; - w_desc.shape = TensorShape(kernel_spatial_extend.width, kernel_spatial_extend.height, w_desc.shape.z()); - + w_desc.shape.set(get_dimension_idx(input_tensor_desc, DataLayoutDimension::WIDTH), kernel_spatial_extend.width); + w_desc.shape.set(get_dimension_idx(input_tensor_desc, DataLayoutDimension::HEIGHT), kernel_spatial_extend.height); + w_desc.shape.set(get_dimension_idx(input_tensor_desc, DataLayoutDimension::CHANNEL), + get_dimension_size(input_tensor_desc, DataLayoutDimension::CHANNEL)); if(!quant_info.empty()) { w_desc.quant_info = quant_info; @@ -353,9 +358,8 @@ NodeID GraphBuilder::add_fully_connected_layer(Graph &g, NodeParams params, Node const TensorDescriptor input_tensor_desc = get_tensor_descriptor(g, g.node(input.node_id)->outputs()[0]); // Create weights node - TensorDescriptor w_desc = input_tensor_desc; - w_desc.shape = FullyConnectedLayerNode::compute_weights_shape(input_tensor_desc.shape, num_outputs); - NodeID w_nid = add_const_node_with_name(g, params, "Weights", w_desc, std::move(weights_accessor)); + TensorDescriptor w_desc = FullyConnectedLayerNode::compute_weights_descriptor(input_tensor_desc, num_outputs); + NodeID w_nid = add_const_node_with_name(g, params, "Weights", w_desc, std::move(weights_accessor)); // Create bias nodes NodeID b_nid = EmptyNodeID; -- cgit v1.2.1