diff options
Diffstat (limited to 'src/graph/mutators/NodeFusionMutator.cpp')
-rw-r--r-- | src/graph/mutators/NodeFusionMutator.cpp | 13 |
1 files changed, 2 insertions, 11 deletions
diff --git a/src/graph/mutators/NodeFusionMutator.cpp b/src/graph/mutators/NodeFusionMutator.cpp index 6677330cec..82bfe25a3e 100644 --- a/src/graph/mutators/NodeFusionMutator.cpp +++ b/src/graph/mutators/NodeFusionMutator.cpp @@ -25,6 +25,7 @@ #include "arm_compute/graph/Graph.h" #include "arm_compute/graph/Logger.h" +#include "arm_compute/graph/Utils.h" #include "arm_compute/graph/nodes/Nodes.h" #include "arm_compute/core/utils/misc/Cast.h" @@ -71,17 +72,7 @@ void fuse_batch_norm_with_activation(Graph &g) if(bn_node->output(0)->accessor() == nullptr) { // Get driving nodes of activation node - std::vector<NodeIdxPair> act_driving_nodes; - for(auto &act_output_edge_id : act_node->output_edges()) - { - auto act_output_edge = g.edge(act_output_edge_id); - if(act_output_edge != nullptr) - { - ARM_COMPUTE_ERROR_ON(act_output_edge->consumer() == nullptr); - act_driving_nodes.push_back( - { act_output_edge->consumer_id(), act_output_edge->consumer_idx() }); - } - } + std::vector<NodeIdxPair> act_driving_nodes = get_driving_nodes(*act_node); // Set activation info to batch normalization bn_node->set_fused_activation(act_node->activation_info()); |