aboutsummaryrefslogtreecommitdiff
path: root/src/graph/mutators/NodeFusionMutator.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/graph/mutators/NodeFusionMutator.cpp')
-rw-r--r--src/graph/mutators/NodeFusionMutator.cpp38
1 files changed, 9 insertions, 29 deletions
diff --git a/src/graph/mutators/NodeFusionMutator.cpp b/src/graph/mutators/NodeFusionMutator.cpp
index b530fb0c00..e37164c60c 100644
--- a/src/graph/mutators/NodeFusionMutator.cpp
+++ b/src/graph/mutators/NodeFusionMutator.cpp
@@ -30,6 +30,8 @@
#include "arm_compute/graph/nodes/FusedConvolutionBatchNormalizationNode.h"
#include "arm_compute/graph/nodes/Nodes.h"
+#include "src/graph/mutators/MutatorUtils.h"
+
#include "support/Cast.h"
#include <set>
@@ -265,33 +267,6 @@ void fuse_node_with_activation(Graph &g, const Edge *output_edge, const std::set
}
}
-bool check_padding_info(const DataLayout &layout, const PaddingList &padding_list, PaddingInfo &pad_w, PaddingInfo &pad_h)
-{
- if(layout == DataLayout::NCHW || layout == DataLayout::NHWC)
- {
- const PaddingInfo zero_padding(0, 0);
-
- const unsigned int height_index = get_dimension_idx(layout, DataLayoutDimension::HEIGHT);
- const unsigned int width_index = get_dimension_idx(layout, DataLayoutDimension::WIDTH);
-
- pad_w = width_index < padding_list.size() ? padding_list[width_index] : zero_padding;
- pad_h = height_index < padding_list.size() ? padding_list[height_index] : zero_padding;
-
- for(unsigned int i = 0; i < padding_list.size(); i++)
- {
- if(i != height_index && i != width_index && padding_list[i] != zero_padding)
- {
- // if the index is not either height or width, don't fuse
- return false;
- }
- }
-
- return true;
- }
-
- return false;
-}
-
template <typename N>
void fuse_pad_with_convolution(Graph &g, const Edge *output_edge)
{
@@ -304,9 +279,14 @@ void fuse_pad_with_convolution(Graph &g, const Edge *output_edge)
{
const DataLayout layout = input_edge->tensor()->desc().layout;
const PaddingList padding_list = pad_node->padding();
- PaddingInfo pad_w, pad_h;
- if(check_padding_info(layout, padding_list, pad_w, pad_h))
+ const unsigned int height_index = get_dimension_idx(layout, DataLayoutDimension::HEIGHT);
+ const unsigned int width_index = get_dimension_idx(layout, DataLayoutDimension::WIDTH);
+
+ const PaddingInfo pad_w = width_index < padding_list.size() ? padding_list[width_index] : PaddingInfo(0, 0);
+ const PaddingInfo pad_h = height_index < padding_list.size() ? padding_list[height_index] : PaddingInfo(0, 0);
+
+ if(is_padding_in_height_or_width(layout, padding_list))
{
// Add paddings to the convolution node
const PadStrideInfo conv_info = conv_node->convolution_info();