diff options
author | Georgios Pinitas <georgios.pinitas@arm.com> | 2018-04-27 19:07:19 +0100 |
---|---|---|
committer | Anthony Barbier <anthony.barbier@arm.com> | 2018-11-02 16:51:17 +0000 |
commit | cac13b1cfd593889271f8e2191be2039b8d88f36 (patch) | |
tree | d1c5196877d7fbd5dcfbb9f9003faf6035f82a33 /src/graph/GraphBuilder.cpp | |
parent | ad0c7388f6261989a268ffb2d042f2bd80736e3f (diff) | |
download | ComputeLibrary-cac13b1cfd593889271f8e2191be2039b8d88f36.tar.gz |
COMPMID-1097: Port mobilenet to NHWC
Change-Id: I789065bfa0d4ef133388e1904c5caf31e450f80f
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/129495
Tested-by: Jenkins <bsgcomp@arm.com>
Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
Diffstat (limited to 'src/graph/GraphBuilder.cpp')
-rw-r--r-- | src/graph/GraphBuilder.cpp | 22 |
1 files changed, 13 insertions, 9 deletions
diff --git a/src/graph/GraphBuilder.cpp b/src/graph/GraphBuilder.cpp index 4ad34e789c..56b31c7844 100644 --- a/src/graph/GraphBuilder.cpp +++ b/src/graph/GraphBuilder.cpp @@ -63,7 +63,7 @@ Status set_accessor_on_node(Graph &g, NodeID nid, bool is_output, size_t idx, IT NodeID add_const_node_with_name(Graph &g, NodeParams params, const std::string &name, TensorDescriptor desc, ITensorAccessorUPtr accessor) { params.name = params.name.empty() ? "" : params.name + name; - auto nid = GraphBuilder::add_const_node(g, params, desc, std::move(accessor)); + auto nid = GraphBuilder::add_const_node(g, params, std::move(desc), std::move(accessor)); set_node_params(g, nid, params); return nid; } @@ -165,7 +165,7 @@ NodeID GraphBuilder::add_batch_normalization_node(Graph &g, NodeParams params, N // Calculate Common Descriptor TensorDescriptor common_desc = input_tensor_desc; - common_desc.shape = TensorShape(common_desc.shape.z()); + common_desc.shape = TensorShape(get_dimension_size(input_tensor_desc, DataLayoutDimension::CHANNEL)); // Create mean and nodes auto mean_nid = add_const_node_with_name(g, params, "Mean", common_desc, std::move(mean_accessor)); @@ -221,8 +221,11 @@ NodeID GraphBuilder::add_convolution_node(Graph &g, NodeParams params, NodeIdxPa // Create weights node TensorDescriptor w_desc = input_tensor_desc; - w_desc.shape = TensorShape(kernel_spatial_extend.width, kernel_spatial_extend.height, w_desc.shape.z() / num_groups, depth); - + w_desc.shape.set(get_dimension_idx(input_tensor_desc, DataLayoutDimension::WIDTH), kernel_spatial_extend.width); + w_desc.shape.set(get_dimension_idx(input_tensor_desc, DataLayoutDimension::HEIGHT), kernel_spatial_extend.height); + w_desc.shape.set(get_dimension_idx(input_tensor_desc, DataLayoutDimension::CHANNEL), + get_dimension_size(input_tensor_desc, DataLayoutDimension::CHANNEL) / num_groups); + w_desc.shape.set(get_dimension_idx(input_tensor_desc, DataLayoutDimension::BATCHES), depth); if(!weights_quant_info.empty()) { w_desc.quant_info = weights_quant_info; @@ -290,8 +293,10 @@ NodeID GraphBuilder::add_depthwise_convolution_node(Graph &g, NodeParams params, // Create weights node TensorDescriptor w_desc = input_tensor_desc; - w_desc.shape = TensorShape(kernel_spatial_extend.width, kernel_spatial_extend.height, w_desc.shape.z()); - + w_desc.shape.set(get_dimension_idx(input_tensor_desc, DataLayoutDimension::WIDTH), kernel_spatial_extend.width); + w_desc.shape.set(get_dimension_idx(input_tensor_desc, DataLayoutDimension::HEIGHT), kernel_spatial_extend.height); + w_desc.shape.set(get_dimension_idx(input_tensor_desc, DataLayoutDimension::CHANNEL), + get_dimension_size(input_tensor_desc, DataLayoutDimension::CHANNEL)); if(!quant_info.empty()) { w_desc.quant_info = quant_info; @@ -353,9 +358,8 @@ NodeID GraphBuilder::add_fully_connected_layer(Graph &g, NodeParams params, Node const TensorDescriptor input_tensor_desc = get_tensor_descriptor(g, g.node(input.node_id)->outputs()[0]); // Create weights node - TensorDescriptor w_desc = input_tensor_desc; - w_desc.shape = FullyConnectedLayerNode::compute_weights_shape(input_tensor_desc.shape, num_outputs); - NodeID w_nid = add_const_node_with_name(g, params, "Weights", w_desc, std::move(weights_accessor)); + TensorDescriptor w_desc = FullyConnectedLayerNode::compute_weights_descriptor(input_tensor_desc, num_outputs); + NodeID w_nid = add_const_node_with_name(g, params, "Weights", w_desc, std::move(weights_accessor)); // Create bias nodes NodeID b_nid = EmptyNodeID; |