aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/graph/mutators/NodeFusionMutator.cpp23
1 files changed, 16 insertions, 7 deletions
diff --git a/src/graph/mutators/NodeFusionMutator.cpp b/src/graph/mutators/NodeFusionMutator.cpp
index f7f3454fad..4c3a905598 100644
--- a/src/graph/mutators/NodeFusionMutator.cpp
+++ b/src/graph/mutators/NodeFusionMutator.cpp
@@ -71,11 +71,10 @@ void fuse_convolution_with_batch_normalization(Graph &g, const Edge *output_edge
FastMathHint fast_math_hint = conv_node->fast_math_hint();
// Extract bn inputs
- const auto bn_mean_id = bn_node->input_edge(1)->producer_id();
- const auto bn_var_id = bn_node->input_edge(2)->producer_id();
- const auto bn_beta_id = bn_node->input_edge(3)->producer_id();
- const auto bn_gamma_id = bn_node->input_edge(4)->producer_id();
- const auto epsilon = bn_node->epsilon();
+ const auto bn_mean_id = bn_node->input_edge(1)->producer_id();
+ const auto bn_var_id = bn_node->input_edge(2)->producer_id();
+
+ const auto epsilon = bn_node->epsilon();
// Create the fused node
const NodeID fused_id = g.add_node<FusedConvolutionBatchNormalizationNode>(epsilon, conv_info, num_groups, conv_method, fast_math_hint, act_info);
@@ -91,8 +90,18 @@ void fuse_convolution_with_batch_normalization(Graph &g, const Edge *output_edge
g.add_connection(conv_weights_id, 0, fused_id, 1);
g.add_connection(bn_mean_id, 0, fused_id, 3);
g.add_connection(bn_var_id, 0, fused_id, 4);
- g.add_connection(bn_beta_id, 0, fused_id, 5);
- g.add_connection(bn_gamma_id, 0, fused_id, 6);
+
+ if(bn_node->input_edge(3) != nullptr)
+ {
+ const auto bn_beta_id = bn_node->input_edge(3)->producer_id();
+ g.add_connection(bn_beta_id, 0, fused_id, 5);
+ }
+
+ if(bn_node->input_edge(4) != nullptr)
+ {
+ const auto bn_gamma_id = bn_node->input_edge(4)->producer_id();
+ g.add_connection(bn_gamma_id, 0, fused_id, 6);
+ }
auto fused_node = g.node(fused_id);
std::vector<NodeIdxPair> bn_driving_nodes = get_driving_nodes(*bn_node);