aboutsummaryrefslogtreecommitdiff
path: root/src/graph/GraphBuilder.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/graph/GraphBuilder.cpp')
-rw-r--r--src/graph/GraphBuilder.cpp22
1 files changed, 13 insertions, 9 deletions
diff --git a/src/graph/GraphBuilder.cpp b/src/graph/GraphBuilder.cpp
index 4ad34e789c..56b31c7844 100644
--- a/src/graph/GraphBuilder.cpp
+++ b/src/graph/GraphBuilder.cpp
@@ -63,7 +63,7 @@ Status set_accessor_on_node(Graph &g, NodeID nid, bool is_output, size_t idx, IT
NodeID add_const_node_with_name(Graph &g, NodeParams params, const std::string &name, TensorDescriptor desc, ITensorAccessorUPtr accessor)
{
params.name = params.name.empty() ? "" : params.name + name;
- auto nid = GraphBuilder::add_const_node(g, params, desc, std::move(accessor));
+ auto nid = GraphBuilder::add_const_node(g, params, std::move(desc), std::move(accessor));
set_node_params(g, nid, params);
return nid;
}
@@ -165,7 +165,7 @@ NodeID GraphBuilder::add_batch_normalization_node(Graph &g, NodeParams params, N
// Calculate Common Descriptor
TensorDescriptor common_desc = input_tensor_desc;
- common_desc.shape = TensorShape(common_desc.shape.z());
+ common_desc.shape = TensorShape(get_dimension_size(input_tensor_desc, DataLayoutDimension::CHANNEL));
// Create mean and nodes
auto mean_nid = add_const_node_with_name(g, params, "Mean", common_desc, std::move(mean_accessor));
@@ -221,8 +221,11 @@ NodeID GraphBuilder::add_convolution_node(Graph &g, NodeParams params, NodeIdxPa
// Create weights node
TensorDescriptor w_desc = input_tensor_desc;
- w_desc.shape = TensorShape(kernel_spatial_extend.width, kernel_spatial_extend.height, w_desc.shape.z() / num_groups, depth);
-
+ w_desc.shape.set(get_dimension_idx(input_tensor_desc, DataLayoutDimension::WIDTH), kernel_spatial_extend.width);
+ w_desc.shape.set(get_dimension_idx(input_tensor_desc, DataLayoutDimension::HEIGHT), kernel_spatial_extend.height);
+ w_desc.shape.set(get_dimension_idx(input_tensor_desc, DataLayoutDimension::CHANNEL),
+ get_dimension_size(input_tensor_desc, DataLayoutDimension::CHANNEL) / num_groups);
+ w_desc.shape.set(get_dimension_idx(input_tensor_desc, DataLayoutDimension::BATCHES), depth);
if(!weights_quant_info.empty())
{
w_desc.quant_info = weights_quant_info;
@@ -290,8 +293,10 @@ NodeID GraphBuilder::add_depthwise_convolution_node(Graph &g, NodeParams params,
// Create weights node
TensorDescriptor w_desc = input_tensor_desc;
- w_desc.shape = TensorShape(kernel_spatial_extend.width, kernel_spatial_extend.height, w_desc.shape.z());
-
+ w_desc.shape.set(get_dimension_idx(input_tensor_desc, DataLayoutDimension::WIDTH), kernel_spatial_extend.width);
+ w_desc.shape.set(get_dimension_idx(input_tensor_desc, DataLayoutDimension::HEIGHT), kernel_spatial_extend.height);
+ w_desc.shape.set(get_dimension_idx(input_tensor_desc, DataLayoutDimension::CHANNEL),
+ get_dimension_size(input_tensor_desc, DataLayoutDimension::CHANNEL));
if(!quant_info.empty())
{
w_desc.quant_info = quant_info;
@@ -353,9 +358,8 @@ NodeID GraphBuilder::add_fully_connected_layer(Graph &g, NodeParams params, Node
const TensorDescriptor input_tensor_desc = get_tensor_descriptor(g, g.node(input.node_id)->outputs()[0]);
// Create weights node
- TensorDescriptor w_desc = input_tensor_desc;
- w_desc.shape = FullyConnectedLayerNode::compute_weights_shape(input_tensor_desc.shape, num_outputs);
- NodeID w_nid = add_const_node_with_name(g, params, "Weights", w_desc, std::move(weights_accessor));
+ TensorDescriptor w_desc = FullyConnectedLayerNode::compute_weights_descriptor(input_tensor_desc, num_outputs);
+ NodeID w_nid = add_const_node_with_name(g, params, "Weights", w_desc, std::move(weights_accessor));
// Create bias nodes
NodeID b_nid = EmptyNodeID;