aboutsummaryrefslogtreecommitdiff
path: root/src/graph/GraphBuilder.cpp
diff options
context:
space:
mode:
authorGeorgios Pinitas <georgios.pinitas@arm.com>2018-04-27 19:07:19 +0100
committerAnthony Barbier <anthony.barbier@arm.com>2018-11-02 16:51:17 +0000
commitcac13b1cfd593889271f8e2191be2039b8d88f36 (patch)
treed1c5196877d7fbd5dcfbb9f9003faf6035f82a33 /src/graph/GraphBuilder.cpp
parentad0c7388f6261989a268ffb2d042f2bd80736e3f (diff)
downloadComputeLibrary-cac13b1cfd593889271f8e2191be2039b8d88f36.tar.gz
COMPMID-1097: Port mobilenet to NHWC
Change-Id: I789065bfa0d4ef133388e1904c5caf31e450f80f Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/129495 Tested-by: Jenkins <bsgcomp@arm.com> Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
Diffstat (limited to 'src/graph/GraphBuilder.cpp')
-rw-r--r--src/graph/GraphBuilder.cpp22
1 files changed, 13 insertions, 9 deletions
diff --git a/src/graph/GraphBuilder.cpp b/src/graph/GraphBuilder.cpp
index 4ad34e789c..56b31c7844 100644
--- a/src/graph/GraphBuilder.cpp
+++ b/src/graph/GraphBuilder.cpp
@@ -63,7 +63,7 @@ Status set_accessor_on_node(Graph &g, NodeID nid, bool is_output, size_t idx, IT
NodeID add_const_node_with_name(Graph &g, NodeParams params, const std::string &name, TensorDescriptor desc, ITensorAccessorUPtr accessor)
{
params.name = params.name.empty() ? "" : params.name + name;
- auto nid = GraphBuilder::add_const_node(g, params, desc, std::move(accessor));
+ auto nid = GraphBuilder::add_const_node(g, params, std::move(desc), std::move(accessor));
set_node_params(g, nid, params);
return nid;
}
@@ -165,7 +165,7 @@ NodeID GraphBuilder::add_batch_normalization_node(Graph &g, NodeParams params, N
// Calculate Common Descriptor
TensorDescriptor common_desc = input_tensor_desc;
- common_desc.shape = TensorShape(common_desc.shape.z());
+ common_desc.shape = TensorShape(get_dimension_size(input_tensor_desc, DataLayoutDimension::CHANNEL));
// Create mean and nodes
auto mean_nid = add_const_node_with_name(g, params, "Mean", common_desc, std::move(mean_accessor));
@@ -221,8 +221,11 @@ NodeID GraphBuilder::add_convolution_node(Graph &g, NodeParams params, NodeIdxPa
// Create weights node
TensorDescriptor w_desc = input_tensor_desc;
- w_desc.shape = TensorShape(kernel_spatial_extend.width, kernel_spatial_extend.height, w_desc.shape.z() / num_groups, depth);
-
+ w_desc.shape.set(get_dimension_idx(input_tensor_desc, DataLayoutDimension::WIDTH), kernel_spatial_extend.width);
+ w_desc.shape.set(get_dimension_idx(input_tensor_desc, DataLayoutDimension::HEIGHT), kernel_spatial_extend.height);
+ w_desc.shape.set(get_dimension_idx(input_tensor_desc, DataLayoutDimension::CHANNEL),
+ get_dimension_size(input_tensor_desc, DataLayoutDimension::CHANNEL) / num_groups);
+ w_desc.shape.set(get_dimension_idx(input_tensor_desc, DataLayoutDimension::BATCHES), depth);
if(!weights_quant_info.empty())
{
w_desc.quant_info = weights_quant_info;
@@ -290,8 +293,10 @@ NodeID GraphBuilder::add_depthwise_convolution_node(Graph &g, NodeParams params,
// Create weights node
TensorDescriptor w_desc = input_tensor_desc;
- w_desc.shape = TensorShape(kernel_spatial_extend.width, kernel_spatial_extend.height, w_desc.shape.z());
-
+ w_desc.shape.set(get_dimension_idx(input_tensor_desc, DataLayoutDimension::WIDTH), kernel_spatial_extend.width);
+ w_desc.shape.set(get_dimension_idx(input_tensor_desc, DataLayoutDimension::HEIGHT), kernel_spatial_extend.height);
+ w_desc.shape.set(get_dimension_idx(input_tensor_desc, DataLayoutDimension::CHANNEL),
+ get_dimension_size(input_tensor_desc, DataLayoutDimension::CHANNEL));
if(!quant_info.empty())
{
w_desc.quant_info = quant_info;
@@ -353,9 +358,8 @@ NodeID GraphBuilder::add_fully_connected_layer(Graph &g, NodeParams params, Node
const TensorDescriptor input_tensor_desc = get_tensor_descriptor(g, g.node(input.node_id)->outputs()[0]);
// Create weights node
- TensorDescriptor w_desc = input_tensor_desc;
- w_desc.shape = FullyConnectedLayerNode::compute_weights_shape(input_tensor_desc.shape, num_outputs);
- NodeID w_nid = add_const_node_with_name(g, params, "Weights", w_desc, std::move(weights_accessor));
+ TensorDescriptor w_desc = FullyConnectedLayerNode::compute_weights_descriptor(input_tensor_desc, num_outputs);
+ NodeID w_nid = add_const_node_with_name(g, params, "Weights", w_desc, std::move(weights_accessor));
// Create bias nodes
NodeID b_nid = EmptyNodeID;