diff options
Diffstat (limited to 'src/graph')
-rw-r--r-- | src/graph/mutators/NodeFusionMutator.cpp | 23 |
1 files changed, 16 insertions, 7 deletions
diff --git a/src/graph/mutators/NodeFusionMutator.cpp b/src/graph/mutators/NodeFusionMutator.cpp index f7f3454fad..4c3a905598 100644 --- a/src/graph/mutators/NodeFusionMutator.cpp +++ b/src/graph/mutators/NodeFusionMutator.cpp @@ -71,11 +71,10 @@ void fuse_convolution_with_batch_normalization(Graph &g, const Edge *output_edge FastMathHint fast_math_hint = conv_node->fast_math_hint(); // Extract bn inputs - const auto bn_mean_id = bn_node->input_edge(1)->producer_id(); - const auto bn_var_id = bn_node->input_edge(2)->producer_id(); - const auto bn_beta_id = bn_node->input_edge(3)->producer_id(); - const auto bn_gamma_id = bn_node->input_edge(4)->producer_id(); - const auto epsilon = bn_node->epsilon(); + const auto bn_mean_id = bn_node->input_edge(1)->producer_id(); + const auto bn_var_id = bn_node->input_edge(2)->producer_id(); + + const auto epsilon = bn_node->epsilon(); // Create the fused node const NodeID fused_id = g.add_node<FusedConvolutionBatchNormalizationNode>(epsilon, conv_info, num_groups, conv_method, fast_math_hint, act_info); @@ -91,8 +90,18 @@ void fuse_convolution_with_batch_normalization(Graph &g, const Edge *output_edge g.add_connection(conv_weights_id, 0, fused_id, 1); g.add_connection(bn_mean_id, 0, fused_id, 3); g.add_connection(bn_var_id, 0, fused_id, 4); - g.add_connection(bn_beta_id, 0, fused_id, 5); - g.add_connection(bn_gamma_id, 0, fused_id, 6); + + if(bn_node->input_edge(3) != nullptr) + { + const auto bn_beta_id = bn_node->input_edge(3)->producer_id(); + g.add_connection(bn_beta_id, 0, fused_id, 5); + } + + if(bn_node->input_edge(4) != nullptr) + { + const auto bn_gamma_id = bn_node->input_edge(4)->producer_id(); + g.add_connection(bn_gamma_id, 0, fused_id, 6); + } auto fused_node = g.node(fused_id); std::vector<NodeIdxPair> bn_driving_nodes = get_driving_nodes(*bn_node); |