aboutsummaryrefslogtreecommitdiff
path: root/src/graph
diff options
context:
space:
mode:
authorGeorgios Pinitas <georgios.pinitas@arm.com>2018-11-12 18:36:19 +0000
committerGeorgios Pinitas <georgios.pinitas@arm.com>2018-11-14 16:00:45 +0000
commit1c32bf396eb690a54fd94487e3f258b2c7d31753 (patch)
tree83dafe008fe428133c1c531cd179e4cad256ef5c /src/graph
parent6c7c38e70c795077ba727aadeefc670888bec089 (diff)
downloadComputeLibrary-1c32bf396eb690a54fd94487e3f258b2c7d31753.tar.gz
COMPMID-1451: Perform fusion before GroupConvolution unrolling
Change-Id: Id94fb9c88a498d7b938f4f707e2e7b9b6df94880
Diffstat (limited to 'src/graph')
-rw-r--r--src/graph/PassManager.cpp4
-rw-r--r--src/graph/Utils.cpp16
-rw-r--r--src/graph/mutators/GroupedConvolutionMutator.cpp30
-rw-r--r--src/graph/mutators/NodeFusionMutator.cpp8
4 files changed, 30 insertions, 28 deletions
diff --git a/src/graph/PassManager.cpp b/src/graph/PassManager.cpp
index 8ed68bd99b..92860e2987 100644
--- a/src/graph/PassManager.cpp
+++ b/src/graph/PassManager.cpp
@@ -44,9 +44,9 @@ IGraphMutator *PassManager::pass(size_t index)
return (index >= _passes.size()) ? nullptr : _passes.at(index).get();
}
-void PassManager::append(std::unique_ptr<IGraphMutator> pass)
+void PassManager::append(std::unique_ptr<IGraphMutator> pass, bool conditional)
{
- if(pass)
+ if(pass && conditional)
{
ARM_COMPUTE_LOG_GRAPH_VERBOSE("Appending mutating pass : " << pass->name() << std::endl);
_passes.push_back(std::move(pass));
diff --git a/src/graph/Utils.cpp b/src/graph/Utils.cpp
index 0a85a7f119..71ec548f8a 100644
--- a/src/graph/Utils.cpp
+++ b/src/graph/Utils.cpp
@@ -78,20 +78,16 @@ PassManager create_default_pass_manager(Target target)
{
PassManager pm;
+ const bool is_target_gc = target == Target::GC;
+
// Passes that mutate graph IR
+ pm.append(support::cpp14::make_unique<NodeFusionMutator>(), !is_target_gc);
pm.append(support::cpp14::make_unique<GroupedConvolutionMutator>());
- if(target != Target::GC)
- {
- pm.append(support::cpp14::make_unique<NodeFusionMutator>());
- pm.append(support::cpp14::make_unique<InPlaceOperationMutator>());
- }
+ pm.append(support::cpp14::make_unique<InPlaceOperationMutator>(), !is_target_gc);
// Passes that mutate backend information
- if(target != Target::GC)
- {
- pm.append(support::cpp14::make_unique<DepthConcatSubTensorMutator>());
- pm.append(support::cpp14::make_unique<SplitLayerSubTensorMutator>());
- }
+ pm.append(support::cpp14::make_unique<DepthConcatSubTensorMutator>(), !is_target_gc);
+ pm.append(support::cpp14::make_unique<SplitLayerSubTensorMutator>(), !is_target_gc);
pm.append(support::cpp14::make_unique<NodeExecutionMethodMutator>());
return pm;
diff --git a/src/graph/mutators/GroupedConvolutionMutator.cpp b/src/graph/mutators/GroupedConvolutionMutator.cpp
index 1bcc11bcb9..d69d2cd7d0 100644
--- a/src/graph/mutators/GroupedConvolutionMutator.cpp
+++ b/src/graph/mutators/GroupedConvolutionMutator.cpp
@@ -41,7 +41,7 @@ namespace graph
namespace
{
NodeID create_grouped_convolution(Graph &g, const NodeParams &params, NodeIdxPair input, NodeID weights, NodeID bias,
- PadStrideInfo conv_info, ConvolutionMethod method, FastMathHint fast_math_hint, unsigned int num_groups)
+ PadStrideInfo conv_info, ConvolutionMethod method, ActivationLayerInfo fused_act, FastMathHint fast_math_hint, unsigned int num_groups)
{
bool has_bias = (bias != EmptyNodeID);
@@ -86,6 +86,10 @@ NodeID create_grouped_convolution(Graph &g, const NodeParams &params, NodeIdxPai
ARM_COMPUTE_ERROR_ON(node == nullptr);
node->set_common_node_parameters(group_params);
+ // Down-cast node
+ auto *conv_node = arm_compute::utils::cast::polymorphic_downcast<ConvolutionLayerNode *>(node);
+ conv_node->set_fused_activation(fused_act);
+
convolution_outputs.push_back({ conv_nid, 0 });
}
@@ -127,18 +131,20 @@ void GroupedConvolutionMutator::mutate(Graph &g)
auto *conv_node = arm_compute::utils::cast::polymorphic_downcast<ConvolutionLayerNode *>(node);
// Get internal convolution info
- // TODO (geopin01) : Create a descriptor
- const PadStrideInfo conv_info = conv_node->convolution_info();
- const ConvolutionMethod conv_method = conv_node->convolution_method();
- const FastMathHint fast_math_hint = conv_node->fast_math_hint();
- const unsigned int num_groups = conv_node->num_groups();
- const NodeParams params = conv_node->common_node_params();
- const Target assigned_target = conv_node->assigned_target();
+ // TODO (geopin01) : Create a descriptor or a clone interface
+ const PadStrideInfo conv_info = conv_node->convolution_info();
+ const ConvolutionMethod conv_method = conv_node->convolution_method();
+ const ActivationLayerInfo fused_act_info = conv_node->fused_activation();
+ const FastMathHint fast_math_hint = conv_node->fast_math_hint();
+ const unsigned int num_groups = conv_node->num_groups();
+ const NodeParams params = conv_node->common_node_params();
+ const Target assigned_target = conv_node->assigned_target();
// Extract node ids
- const NodeID input_id = conv_node->input_id(0);
- const NodeID weights_id = conv_node->input_id(1);
- const NodeID bias_id = conv_node->input_id(2);
+ ARM_COMPUTE_ERROR_ON(conv_node->input_edge(0) == nullptr || conv_node->input_edge(1) == nullptr);
+ const NodeID input_id = conv_node->input_edge(0)->producer()->id();
+ const NodeID weights_id = conv_node->input_edge(1)->producer()->id();
+ const NodeID bias_id = (conv_node->input_edge(2) != nullptr) ? conv_node->input_edge(2)->producer()->id() : EmptyNodeID;
// Get driving nodes
std::vector<NodeIdxPair> driving_nodes = get_driving_nodes(*node);
@@ -152,7 +158,7 @@ void GroupedConvolutionMutator::mutate(Graph &g)
// Create grouped convolution node
NodeID grouped_conv_id = create_grouped_convolution(g, params, { input_id, 0 }, weights_id, bias_id,
- conv_info, conv_method, fast_math_hint, num_groups);
+ conv_info, conv_method, fused_act_info, fast_math_hint, num_groups);
// Remove convolution node
g.remove_node(node->id());
diff --git a/src/graph/mutators/NodeFusionMutator.cpp b/src/graph/mutators/NodeFusionMutator.cpp
index 98c3a56018..9dc02d1ad1 100644
--- a/src/graph/mutators/NodeFusionMutator.cpp
+++ b/src/graph/mutators/NodeFusionMutator.cpp
@@ -73,13 +73,13 @@ void fuse_node_with_activation(Graph &g,
ARM_COMPUTE_LOG_GRAPH_VERBOSE("Fusing node with ID : " << output_edge->producer_id()
<< " with Activation Layer node with ID : " << output_edge->consumer_id() << std::endl);
- // Prevent fusion if batch normalization node has an output accessor
+ // Prevent fusion if fused node has an output accessor
if(n_node->output(0)->accessor() == nullptr)
{
// Get driving nodes of activation node
std::vector<NodeIdxPair> act_driving_nodes = get_driving_nodes(*act_node);
- // Set activation info to batch normalization
+ // Set activation info to fused node
n_node->set_fused_activation(act_node->activation_info());
// Extract activation node accessor if any
@@ -88,13 +88,13 @@ void fuse_node_with_activation(Graph &g,
// Remove activation node
g.remove_node(act_node->id());
- // Update batch normalization node outputs
+ // Update fused node outputs
for(auto &driving_node : act_driving_nodes)
{
g.add_connection(n_node->id(), 0, driving_node.node_id, driving_node.index);
}
- // Update accessor to batch normalization node
+ // Update accessor to fused node
n_node->output(0)->set_accessor(std::move(act_node_accessor));
}
else