diff options
Diffstat (limited to 'src/graph/mutators/NodeFusionMutator.cpp')
-rw-r--r-- | src/graph/mutators/NodeFusionMutator.cpp | 38 |
1 files changed, 9 insertions, 29 deletions
diff --git a/src/graph/mutators/NodeFusionMutator.cpp b/src/graph/mutators/NodeFusionMutator.cpp index b530fb0c00..e37164c60c 100644 --- a/src/graph/mutators/NodeFusionMutator.cpp +++ b/src/graph/mutators/NodeFusionMutator.cpp @@ -30,6 +30,8 @@ #include "arm_compute/graph/nodes/FusedConvolutionBatchNormalizationNode.h" #include "arm_compute/graph/nodes/Nodes.h" +#include "src/graph/mutators/MutatorUtils.h" + #include "support/Cast.h" #include <set> @@ -265,33 +267,6 @@ void fuse_node_with_activation(Graph &g, const Edge *output_edge, const std::set } } -bool check_padding_info(const DataLayout &layout, const PaddingList &padding_list, PaddingInfo &pad_w, PaddingInfo &pad_h) -{ - if(layout == DataLayout::NCHW || layout == DataLayout::NHWC) - { - const PaddingInfo zero_padding(0, 0); - - const unsigned int height_index = get_dimension_idx(layout, DataLayoutDimension::HEIGHT); - const unsigned int width_index = get_dimension_idx(layout, DataLayoutDimension::WIDTH); - - pad_w = width_index < padding_list.size() ? padding_list[width_index] : zero_padding; - pad_h = height_index < padding_list.size() ? padding_list[height_index] : zero_padding; - - for(unsigned int i = 0; i < padding_list.size(); i++) - { - if(i != height_index && i != width_index && padding_list[i] != zero_padding) - { - // if the index is not either height or width, don't fuse - return false; - } - } - - return true; - } - - return false; -} - template <typename N> void fuse_pad_with_convolution(Graph &g, const Edge *output_edge) { @@ -304,9 +279,14 @@ void fuse_pad_with_convolution(Graph &g, const Edge *output_edge) { const DataLayout layout = input_edge->tensor()->desc().layout; const PaddingList padding_list = pad_node->padding(); - PaddingInfo pad_w, pad_h; - if(check_padding_info(layout, padding_list, pad_w, pad_h)) + const unsigned int height_index = get_dimension_idx(layout, DataLayoutDimension::HEIGHT); + const unsigned int width_index = get_dimension_idx(layout, DataLayoutDimension::WIDTH); + + const PaddingInfo pad_w = width_index < padding_list.size() ? padding_list[width_index] : PaddingInfo(0, 0); + const PaddingInfo pad_h = height_index < padding_list.size() ? padding_list[height_index] : PaddingInfo(0, 0); + + if(is_padding_in_height_or_width(layout, padding_list)) { // Add paddings to the convolution node const PadStrideInfo conv_info = conv_node->convolution_info(); |