From 1c32bf396eb690a54fd94487e3f258b2c7d31753 Mon Sep 17 00:00:00 2001 From: Georgios Pinitas Date: Mon, 12 Nov 2018 18:36:19 +0000 Subject: COMPMID-1451: Perform fusion before GroupConvolution unrolling Change-Id: Id94fb9c88a498d7b938f4f707e2e7b9b6df94880 --- src/graph/PassManager.cpp | 4 ++-- src/graph/Utils.cpp | 16 +++++-------- src/graph/mutators/GroupedConvolutionMutator.cpp | 30 ++++++++++++++---------- src/graph/mutators/NodeFusionMutator.cpp | 8 +++---- 4 files changed, 30 insertions(+), 28 deletions(-) (limited to 'src/graph') diff --git a/src/graph/PassManager.cpp b/src/graph/PassManager.cpp index 8ed68bd99b..92860e2987 100644 --- a/src/graph/PassManager.cpp +++ b/src/graph/PassManager.cpp @@ -44,9 +44,9 @@ IGraphMutator *PassManager::pass(size_t index) return (index >= _passes.size()) ? nullptr : _passes.at(index).get(); } -void PassManager::append(std::unique_ptr pass) +void PassManager::append(std::unique_ptr pass, bool conditional) { - if(pass) + if(pass && conditional) { ARM_COMPUTE_LOG_GRAPH_VERBOSE("Appending mutating pass : " << pass->name() << std::endl); _passes.push_back(std::move(pass)); diff --git a/src/graph/Utils.cpp b/src/graph/Utils.cpp index 0a85a7f119..71ec548f8a 100644 --- a/src/graph/Utils.cpp +++ b/src/graph/Utils.cpp @@ -78,20 +78,16 @@ PassManager create_default_pass_manager(Target target) { PassManager pm; + const bool is_target_gc = target == Target::GC; + // Passes that mutate graph IR + pm.append(support::cpp14::make_unique(), !is_target_gc); pm.append(support::cpp14::make_unique()); - if(target != Target::GC) - { - pm.append(support::cpp14::make_unique()); - pm.append(support::cpp14::make_unique()); - } + pm.append(support::cpp14::make_unique(), !is_target_gc); // Passes that mutate backend information - if(target != Target::GC) - { - pm.append(support::cpp14::make_unique()); - pm.append(support::cpp14::make_unique()); - } + pm.append(support::cpp14::make_unique(), !is_target_gc); + pm.append(support::cpp14::make_unique(), !is_target_gc); pm.append(support::cpp14::make_unique()); return pm; diff --git a/src/graph/mutators/GroupedConvolutionMutator.cpp b/src/graph/mutators/GroupedConvolutionMutator.cpp index 1bcc11bcb9..d69d2cd7d0 100644 --- a/src/graph/mutators/GroupedConvolutionMutator.cpp +++ b/src/graph/mutators/GroupedConvolutionMutator.cpp @@ -41,7 +41,7 @@ namespace graph namespace { NodeID create_grouped_convolution(Graph &g, const NodeParams ¶ms, NodeIdxPair input, NodeID weights, NodeID bias, - PadStrideInfo conv_info, ConvolutionMethod method, FastMathHint fast_math_hint, unsigned int num_groups) + PadStrideInfo conv_info, ConvolutionMethod method, ActivationLayerInfo fused_act, FastMathHint fast_math_hint, unsigned int num_groups) { bool has_bias = (bias != EmptyNodeID); @@ -86,6 +86,10 @@ NodeID create_grouped_convolution(Graph &g, const NodeParams ¶ms, NodeIdxPai ARM_COMPUTE_ERROR_ON(node == nullptr); node->set_common_node_parameters(group_params); + // Down-cast node + auto *conv_node = arm_compute::utils::cast::polymorphic_downcast(node); + conv_node->set_fused_activation(fused_act); + convolution_outputs.push_back({ conv_nid, 0 }); } @@ -127,18 +131,20 @@ void GroupedConvolutionMutator::mutate(Graph &g) auto *conv_node = arm_compute::utils::cast::polymorphic_downcast(node); // Get internal convolution info - // TODO (geopin01) : Create a descriptor - const PadStrideInfo conv_info = conv_node->convolution_info(); - const ConvolutionMethod conv_method = conv_node->convolution_method(); - const FastMathHint fast_math_hint = conv_node->fast_math_hint(); - const unsigned int num_groups = conv_node->num_groups(); - const NodeParams params = conv_node->common_node_params(); - const Target assigned_target = conv_node->assigned_target(); + // TODO (geopin01) : Create a descriptor or a clone interface + const PadStrideInfo conv_info = conv_node->convolution_info(); + const ConvolutionMethod conv_method = conv_node->convolution_method(); + const ActivationLayerInfo fused_act_info = conv_node->fused_activation(); + const FastMathHint fast_math_hint = conv_node->fast_math_hint(); + const unsigned int num_groups = conv_node->num_groups(); + const NodeParams params = conv_node->common_node_params(); + const Target assigned_target = conv_node->assigned_target(); // Extract node ids - const NodeID input_id = conv_node->input_id(0); - const NodeID weights_id = conv_node->input_id(1); - const NodeID bias_id = conv_node->input_id(2); + ARM_COMPUTE_ERROR_ON(conv_node->input_edge(0) == nullptr || conv_node->input_edge(1) == nullptr); + const NodeID input_id = conv_node->input_edge(0)->producer()->id(); + const NodeID weights_id = conv_node->input_edge(1)->producer()->id(); + const NodeID bias_id = (conv_node->input_edge(2) != nullptr) ? conv_node->input_edge(2)->producer()->id() : EmptyNodeID; // Get driving nodes std::vector driving_nodes = get_driving_nodes(*node); @@ -152,7 +158,7 @@ void GroupedConvolutionMutator::mutate(Graph &g) // Create grouped convolution node NodeID grouped_conv_id = create_grouped_convolution(g, params, { input_id, 0 }, weights_id, bias_id, - conv_info, conv_method, fast_math_hint, num_groups); + conv_info, conv_method, fused_act_info, fast_math_hint, num_groups); // Remove convolution node g.remove_node(node->id()); diff --git a/src/graph/mutators/NodeFusionMutator.cpp b/src/graph/mutators/NodeFusionMutator.cpp index 98c3a56018..9dc02d1ad1 100644 --- a/src/graph/mutators/NodeFusionMutator.cpp +++ b/src/graph/mutators/NodeFusionMutator.cpp @@ -73,13 +73,13 @@ void fuse_node_with_activation(Graph &g, ARM_COMPUTE_LOG_GRAPH_VERBOSE("Fusing node with ID : " << output_edge->producer_id() << " with Activation Layer node with ID : " << output_edge->consumer_id() << std::endl); - // Prevent fusion if batch normalization node has an output accessor + // Prevent fusion if fused node has an output accessor if(n_node->output(0)->accessor() == nullptr) { // Get driving nodes of activation node std::vector act_driving_nodes = get_driving_nodes(*act_node); - // Set activation info to batch normalization + // Set activation info to fused node n_node->set_fused_activation(act_node->activation_info()); // Extract activation node accessor if any @@ -88,13 +88,13 @@ void fuse_node_with_activation(Graph &g, // Remove activation node g.remove_node(act_node->id()); - // Update batch normalization node outputs + // Update fused node outputs for(auto &driving_node : act_driving_nodes) { g.add_connection(n_node->id(), 0, driving_node.node_id, driving_node.index); } - // Update accessor to batch normalization node + // Update accessor to fused node n_node->output(0)->set_accessor(std::move(act_node_accessor)); } else -- cgit v1.2.1