From 351bd137e48c5276963274ac741b172483e98d21 Mon Sep 17 00:00:00 2001 From: giuros01 Date: Fri, 23 Aug 2019 14:27:30 +0100 Subject: compmid-2573: Investigate FP16 Winograd reference implementations Change-Id: I5a3e692c046a5ad28a676c03e3e51950c64cf503 Signed-off-by: giuros01 Reviewed-on: https://review.mlplatform.org/c/1845 Reviewed-by: Pablo Marquez Comments-Addressed: Arm Jenkins Tested-by: Arm Jenkins --- src/graph/mutators/NodeFusionMutator.cpp | 23 ++++++++++++++++------- 1 file changed, 16 insertions(+), 7 deletions(-) (limited to 'src') diff --git a/src/graph/mutators/NodeFusionMutator.cpp b/src/graph/mutators/NodeFusionMutator.cpp index f7f3454fad..4c3a905598 100644 --- a/src/graph/mutators/NodeFusionMutator.cpp +++ b/src/graph/mutators/NodeFusionMutator.cpp @@ -71,11 +71,10 @@ void fuse_convolution_with_batch_normalization(Graph &g, const Edge *output_edge FastMathHint fast_math_hint = conv_node->fast_math_hint(); // Extract bn inputs - const auto bn_mean_id = bn_node->input_edge(1)->producer_id(); - const auto bn_var_id = bn_node->input_edge(2)->producer_id(); - const auto bn_beta_id = bn_node->input_edge(3)->producer_id(); - const auto bn_gamma_id = bn_node->input_edge(4)->producer_id(); - const auto epsilon = bn_node->epsilon(); + const auto bn_mean_id = bn_node->input_edge(1)->producer_id(); + const auto bn_var_id = bn_node->input_edge(2)->producer_id(); + + const auto epsilon = bn_node->epsilon(); // Create the fused node const NodeID fused_id = g.add_node(epsilon, conv_info, num_groups, conv_method, fast_math_hint, act_info); @@ -91,8 +90,18 @@ void fuse_convolution_with_batch_normalization(Graph &g, const Edge *output_edge g.add_connection(conv_weights_id, 0, fused_id, 1); g.add_connection(bn_mean_id, 0, fused_id, 3); g.add_connection(bn_var_id, 0, fused_id, 4); - g.add_connection(bn_beta_id, 0, fused_id, 5); - g.add_connection(bn_gamma_id, 0, fused_id, 6); + + if(bn_node->input_edge(3) != nullptr) + { + const auto bn_beta_id = bn_node->input_edge(3)->producer_id(); + g.add_connection(bn_beta_id, 0, fused_id, 5); + } + + if(bn_node->input_edge(4) != nullptr) + { + const auto bn_gamma_id = bn_node->input_edge(4)->producer_id(); + g.add_connection(bn_gamma_id, 0, fused_id, 6); + } auto fused_node = g.node(fused_id); std::vector bn_driving_nodes = get_driving_nodes(*bn_node); -- cgit v1.2.1