aboutsummaryrefslogtreecommitdiff
path: root/src/graph/mutators/NodeFusionMutator.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/graph/mutators/NodeFusionMutator.cpp')
-rw-r--r--src/graph/mutators/NodeFusionMutator.cpp13
1 files changed, 2 insertions, 11 deletions
diff --git a/src/graph/mutators/NodeFusionMutator.cpp b/src/graph/mutators/NodeFusionMutator.cpp
index 6677330cec..82bfe25a3e 100644
--- a/src/graph/mutators/NodeFusionMutator.cpp
+++ b/src/graph/mutators/NodeFusionMutator.cpp
@@ -25,6 +25,7 @@
#include "arm_compute/graph/Graph.h"
#include "arm_compute/graph/Logger.h"
+#include "arm_compute/graph/Utils.h"
#include "arm_compute/graph/nodes/Nodes.h"
#include "arm_compute/core/utils/misc/Cast.h"
@@ -71,17 +72,7 @@ void fuse_batch_norm_with_activation(Graph &g)
if(bn_node->output(0)->accessor() == nullptr)
{
// Get driving nodes of activation node
- std::vector<NodeIdxPair> act_driving_nodes;
- for(auto &act_output_edge_id : act_node->output_edges())
- {
- auto act_output_edge = g.edge(act_output_edge_id);
- if(act_output_edge != nullptr)
- {
- ARM_COMPUTE_ERROR_ON(act_output_edge->consumer() == nullptr);
- act_driving_nodes.push_back(
- { act_output_edge->consumer_id(), act_output_edge->consumer_idx() });
- }
- }
+ std::vector<NodeIdxPair> act_driving_nodes = get_driving_nodes(*act_node);
// Set activation info to batch normalization
bn_node->set_fused_activation(act_node->activation_info());