aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorgiuros01 <giuseppe.rossini@arm.com>2019-08-23 14:27:30 +0100
committerGiuseppe Rossini <giuseppe.rossini@arm.com>2019-08-30 13:37:28 +0000
commit351bd137e48c5276963274ac741b172483e98d21 (patch)
tree3ede92537c406d24f948acc51c1e6c0fac011036 /src
parentebe2e8ccc6f9504fdad95884a794be1e9f58803e (diff)
downloadComputeLibrary-351bd137e48c5276963274ac741b172483e98d21.tar.gz
compmid-2573: Investigate FP16 Winograd reference implementations
Change-Id: I5a3e692c046a5ad28a676c03e3e51950c64cf503 Signed-off-by: giuros01 <giuseppe.rossini@arm.com> Reviewed-on: https://review.mlplatform.org/c/1845 Reviewed-by: Pablo Marquez <pablo.tello@arm.com> Comments-Addressed: Arm Jenkins <bsgcomp@arm.com> Tested-by: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src')
-rw-r--r--src/graph/mutators/NodeFusionMutator.cpp23
1 files changed, 16 insertions, 7 deletions
diff --git a/src/graph/mutators/NodeFusionMutator.cpp b/src/graph/mutators/NodeFusionMutator.cpp
index f7f3454fad..4c3a905598 100644
--- a/src/graph/mutators/NodeFusionMutator.cpp
+++ b/src/graph/mutators/NodeFusionMutator.cpp
@@ -71,11 +71,10 @@ void fuse_convolution_with_batch_normalization(Graph &g, const Edge *output_edge
FastMathHint fast_math_hint = conv_node->fast_math_hint();
// Extract bn inputs
- const auto bn_mean_id = bn_node->input_edge(1)->producer_id();
- const auto bn_var_id = bn_node->input_edge(2)->producer_id();
- const auto bn_beta_id = bn_node->input_edge(3)->producer_id();
- const auto bn_gamma_id = bn_node->input_edge(4)->producer_id();
- const auto epsilon = bn_node->epsilon();
+ const auto bn_mean_id = bn_node->input_edge(1)->producer_id();
+ const auto bn_var_id = bn_node->input_edge(2)->producer_id();
+
+ const auto epsilon = bn_node->epsilon();
// Create the fused node
const NodeID fused_id = g.add_node<FusedConvolutionBatchNormalizationNode>(epsilon, conv_info, num_groups, conv_method, fast_math_hint, act_info);
@@ -91,8 +90,18 @@ void fuse_convolution_with_batch_normalization(Graph &g, const Edge *output_edge
g.add_connection(conv_weights_id, 0, fused_id, 1);
g.add_connection(bn_mean_id, 0, fused_id, 3);
g.add_connection(bn_var_id, 0, fused_id, 4);
- g.add_connection(bn_beta_id, 0, fused_id, 5);
- g.add_connection(bn_gamma_id, 0, fused_id, 6);
+
+ if(bn_node->input_edge(3) != nullptr)
+ {
+ const auto bn_beta_id = bn_node->input_edge(3)->producer_id();
+ g.add_connection(bn_beta_id, 0, fused_id, 5);
+ }
+
+ if(bn_node->input_edge(4) != nullptr)
+ {
+ const auto bn_gamma_id = bn_node->input_edge(4)->producer_id();
+ g.add_connection(bn_gamma_id, 0, fused_id, 6);
+ }
auto fused_node = g.node(fused_id);
std::vector<NodeIdxPair> bn_driving_nodes = get_driving_nodes(*bn_node);