aboutsummaryrefslogtreecommitdiff
path: root/src/graph/mutators/NodeFusionMutator.cpp
blob: 427d7b5095aa4d3735285485179c640afc5ae732 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
/*
 * Copyright (c) 2018-2019 ARM Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#include "arm_compute/graph/mutators/NodeFusionMutator.h"

#include "arm_compute/graph/GraphBuilder.h"
#include "arm_compute/graph/Logger.h"
#include "arm_compute/graph/Utils.h"
#include "arm_compute/graph/backends/BackendRegistry.h"
#include "arm_compute/graph/nodes/FusedConvolutionBatchNormalizationNode.h"
#include "arm_compute/graph/nodes/Nodes.h"

#include "arm_compute/core/utils/misc/Cast.h"

#include <set>

namespace arm_compute
{
namespace graph
{
namespace detail
{
/** Fuses a convolution node with the batch normalization node that consumes its output.
 *
 * The conv -> bn pair joined by @p output_edge is replaced with a single
 * FusedConvolutionBatchNormalizationNode whose input slots are wired as:
 *  - 0: convolution input
 *  - 1: convolution weights
 *  - 2: convolution bias (only if the convolution has one)
 *  - 3: batch normalization mean
 *  - 4: batch normalization variance
 *  - 5: batch normalization beta (only if the batch normalization node has one)
 *  - 6: batch normalization gamma (only if the batch normalization node has one)
 *
 * The fused node takes the convolution's assigned target, output quantization info and
 * convolution parameters, and the batch normalization node's epsilon, fused activation,
 * consumers and output accessor. Both original nodes are removed from the graph.
 *
 * Nothing is changed for grouped convolutions (num_groups() > 1) or when the convolution
 * output already has an accessor attached.
 *
 * @param[in,out] g           Graph to mutate
 * @param[in]     output_edge Edge from the convolution node (producer) to the batch normalization node (consumer)
 */
void fuse_convolution_with_batch_normalization(Graph &g, const Edge *output_edge)
{
    ARM_COMPUTE_ERROR_ON(output_edge == nullptr);

    auto *conv_node = arm_compute::utils::cast::polymorphic_downcast<ConvolutionLayerNode *>(output_edge->producer());
    auto *bn_node   = arm_compute::utils::cast::polymorphic_downcast<BatchNormalizationLayerNode *>(output_edge->consumer());

    // Not fusing if number of groups is greater than 1
    if(conv_node->num_groups() > 1)
    {
        return;
    }

    ARM_COMPUTE_LOG_GRAPH_VERBOSE("Fusing convolution node with ID : " << output_edge->producer_id()
                                  << " with BatchNormalization Layer node with ID : " << output_edge->consumer_id() << std::endl);

    // Prevent fusion if the convolution output has an accessor
    if(conv_node->output(0)->accessor() == nullptr)
    {
        const Target assigned_target = conv_node->assigned_target();

        // Extract conv inputs
        const auto   conv_input_id   = conv_node->input_edge(0)->producer_id();
        const auto   conv_weights_id = conv_node->input_edge(1)->producer_id();
        const auto   out_quant_info  = conv_node->output(0)->desc().quant_info;
        const auto   conv_info       = conv_node->convolution_info();
        const auto   conv_method     = conv_node->convolution_method();
        const auto   num_groups      = conv_node->num_groups();
        const auto   act_info        = bn_node->fused_activation();
        FastMathHint fast_math_hint  = conv_node->fast_math_hint();

        // Extract the mandatory bn inputs. Beta (input 3) and gamma (input 4) are optional:
        // their edges may be missing, so they are only connected further down if present.
        const auto bn_mean_id = bn_node->input_edge(1)->producer_id();
        const auto bn_var_id  = bn_node->input_edge(2)->producer_id();
        const auto epsilon    = bn_node->epsilon();

        // Create the fused node
        const NodeID fused_id = g.add_node<FusedConvolutionBatchNormalizationNode>(epsilon, conv_info, num_groups, conv_method, fast_math_hint, out_quant_info, act_info);

        // Optional convolution bias
        if(conv_node->input_edge(2) != nullptr)
        {
            const auto conv_bias_id = conv_node->input_edge(2)->producer_id();
            g.add_connection(conv_bias_id, 0, fused_id, 2);
        }

        // Add connections from the mandatory conv/batch_norm inputs to the fused node
        g.add_connection(conv_input_id, 0, fused_id, 0);
        g.add_connection(conv_weights_id, 0, fused_id, 1);
        g.add_connection(bn_mean_id, 0, fused_id, 3);
        g.add_connection(bn_var_id, 0, fused_id, 4);

        // Optional batch normalization beta and gamma: only connect them if the edges exist,
        // otherwise input_edge(i) is null and dereferencing it would crash
        if(bn_node->input_edge(3) != nullptr)
        {
            const auto bn_beta_id = bn_node->input_edge(3)->producer_id();
            g.add_connection(bn_beta_id, 0, fused_id, 5);
        }
        if(bn_node->input_edge(4) != nullptr)
        {
            const auto bn_gamma_id = bn_node->input_edge(4)->producer_id();
            g.add_connection(bn_gamma_id, 0, fused_id, 6);
        }

        auto                     fused_node       = g.node(fused_id);
        std::vector<NodeIdxPair> bn_driving_nodes = get_driving_nodes(*bn_node);

        // Extract batch normalization node accessor and name before the node is removed
        auto bn_node_accessor = bn_node->output(0)->extract_accessor();
        auto bn_node_name     = bn_node->name();

        // Remove batch normalization node; bn_node is not used after this point (presumably it is destroyed here)
        g.remove_node(bn_node->id());

        // Reconnect the former consumers of the batch normalization node to the fused node output
        for(auto &driving_node : bn_driving_nodes)
        {
            g.add_connection(fused_id, 0, driving_node.node_id, driving_node.index);
        }

        // Configure the fused output tensor once, and only when it has consumers
        if(!bn_driving_nodes.empty())
        {
            configure_tensor(fused_node->output(0));
        }

        // Update fused node outputs
        fused_node->output(0)->set_accessor(std::move(bn_node_accessor));
        fused_node->set_assigned_target(assigned_target);
        fused_node->set_common_node_parameters(NodeParams{ conv_node->name() + "+" + bn_node_name, assigned_target });

        // Remove convolution node
        g.remove_node(conv_node->id());
    }
    else
    {
        ARM_COMPUTE_LOG_GRAPH_VERBOSE("Prevented fusion of convolution with batch normalization due to the presence of an output accessor\n");
    }
}

/** Folds an activation layer into the node of type N that produces its input.
 *
 * On success the activation info is handed to the producer via set_fused_activation(), the
 * activation node is removed, every consumer of the activation is reconnected to output 0 of
 * the producer, and the activation's output accessor (if any) is moved onto the producer's output.
 *
 * The graph is left untouched when the activation function is not listed in
 * @p supported_fused_activations or when the producer's output already has an accessor.
 *
 * @tparam N Producer node type; must provide set_fused_activation()
 *
 * @param[in,out] g                           Graph to mutate
 * @param[in]     output_edge                 Edge from the N node (producer) to the activation node (consumer)
 * @param[in]     supported_fused_activations Activation functions that are allowed to be fused
 */
template <typename N>
void fuse_node_with_activation(Graph &g, const Edge *output_edge, const std::set<Activation> &supported_fused_activations)
{
    ARM_COMPUTE_ERROR_ON(output_edge == nullptr);

    auto *producer   = arm_compute::utils::cast::polymorphic_downcast<N *>(output_edge->producer());
    auto *activation = arm_compute::utils::cast::polymorphic_downcast<ActivationLayerNode *>(output_edge->consumer());

    ARM_COMPUTE_ERROR_ON(activation->output(0) == nullptr || producer->output(0) == nullptr);

    const auto act_info = activation->activation_info();

    // Guard: only whitelisted activation functions can be fused
    const bool fusable = supported_fused_activations.find(act_info.activation()) != supported_fused_activations.end();
    if(!fusable)
    {
        return;
    }

    ARM_COMPUTE_LOG_GRAPH_VERBOSE("Fusing node with ID : " << output_edge->producer_id()
                                  << " with Activation Layer node with ID : " << output_edge->consumer_id() << std::endl);

    // Guard: leave the graph alone if the producer output is already observed through an accessor
    if(producer->output(0)->accessor() != nullptr)
    {
        ARM_COMPUTE_LOG_GRAPH_VERBOSE("Prevented fusion of node with activation due to the presence of an output accessor\n");
        return;
    }

    // Remember who consumes the activation output before the activation node goes away
    const std::vector<NodeIdxPair> consumers = get_driving_nodes(*activation);

    producer->set_fused_activation(act_info);

    // Take the activation's output accessor (if any) so it can be re-attached to the producer
    auto accessor = activation->output(0)->extract_accessor();

    g.remove_node(activation->id());

    // Route every former consumer of the activation to output 0 of the producer
    const auto producer_id = producer->id();
    for(const auto &consumer : consumers)
    {
        g.add_connection(producer_id, 0, consumer.node_id, consumer.index);
    }

    producer->output(0)->set_accessor(std::move(accessor));
}

/** Scans the graph for N1 -> N2 node pairs and applies @p fuse_fcn to each one found.
 *
 * A pair qualifies when the N1 node has exactly one output edge (i.e. it is not a branching
 * node), the consumer of that edge is an N2 node, and @p prec returns true for the N1 node.
 *
 * @param[in,out] g                  Graph to mutate
 * @param[in]     prec               Extra precondition evaluated on the producer (N1) node
 * @param[in]     fuse_fcn           Fusion function, invoked as fuse_fcn(g, edge, optional_arguments...)
 * @param[in]     optional_arguments Extra arguments forwarded to @p fuse_fcn
 */
template <typename N1, typename N2, typename F, typename... Args>
void fuse_layer(Graph &g, std::function<bool(INode &)> const &prec, const F fuse_fcn, Args &&... optional_arguments)
{
    // fuse_fcn may add nodes to the graph (e.g. g.add_node in the convolution + batch normalization
    // fusion), which can grow the container returned by g.nodes() and invalidate any iterator held
    // over it. Iterate by index and re-read the size on every iteration instead; nodes appended
    // during the scan are visited too, so they can be considered for further fusions.
    for(NodeID i = 0; i < g.nodes().size(); ++i)
    {
        INode *node = g.node(i);

        // Skip null (e.g. removed) entries, nodes of another type and branching nodes
        if(node == nullptr || node->type() != N1::node_type || node->output_edges().size() != 1)
        {
            continue;
        }

        const auto output_edge_id = *node->output_edges().begin();
        const auto output_edge    = g.edge(output_edge_id);

        // Check if the following node is of type N2 and the precondition holds for the producer
        if((output_edge != nullptr) && (output_edge->consumer() != nullptr) && (output_edge->consumer()->type() == N2::node_type) && prec(*output_edge->producer()))
        {
            fuse_fcn(g, output_edge, optional_arguments...);
        }
    }
}
} // namespace detail

/** Returns the name identifying this mutator pass ("NodeFusionMutator"). */
const char *NodeFusionMutator::name()
{
    static constexpr const char *mutator_name = "NodeFusionMutator";
    return mutator_name;
}

/** Applies the enabled node fusions to @p g.
 *
 * An ActivationLayerNode using RELU, BOUNDED_RELU or LU_BOUNDED_RELU is fused into a preceding
 * BatchNormalizationLayerNode, ConvolutionLayerNode or DepthwiseConvolutionLayerNode. For the
 * depthwise case with a QASYMM8 output, fusion additionally requires the activation output to
 * carry the same quantization info as the depthwise output. Convolution + batch normalization
 * fusion is currently disabled.
 *
 * @param[in,out] g Graph to mutate
 */
void NodeFusionMutator::mutate(Graph &g)
{
    // Supported activations when fusing
    const std::set<Activation> supported_fused_activations = { Activation::RELU, Activation::BOUNDED_RELU, Activation::LU_BOUNDED_RELU };

    // Preconditions
    auto empty_prec = [](INode &)
    {
        return true;
    };
    auto qs8_prec = [&g](INode & n)
    {
        ARM_COMPUTE_ERROR_ON(n.output(0) == nullptr);
        ARM_COMPUTE_ERROR_ON(n.output_edges().empty());

        const auto output_edge_id = *n.output_edges().begin();
        const auto output_edge    = g.edge(output_edge_id);
        ARM_COMPUTE_ERROR_ON(output_edge == nullptr || output_edge->consumer() == nullptr || output_edge->consumer()->output(0) == nullptr);

        // To perform fusion the two nodes must have same output quantization information.
        // The comparison must be against the consumer: the producer of n's output edge is n
        // itself, so comparing against it would always succeed.
        const bool same_qinfo     = n.output(0)->desc().quant_info == output_edge->consumer()->output(0)->desc().quant_info;
        const bool output_qasymm8 = n.output(0)->desc().data_type == DataType::QASYMM8;

        // Non-QASYMM8 outputs are always eligible; QASYMM8 ones only with matching quantization
        return !output_qasymm8 || same_qinfo;
    };

    // Fusion mutations
    detail::fuse_layer<BatchNormalizationLayerNode, ActivationLayerNode>(g, empty_prec, detail::fuse_node_with_activation<BatchNormalizationLayerNode>, supported_fused_activations);
    detail::fuse_layer<ConvolutionLayerNode, ActivationLayerNode>(g, empty_prec, detail::fuse_node_with_activation<ConvolutionLayerNode>, supported_fused_activations);
    detail::fuse_layer<DepthwiseConvolutionLayerNode, ActivationLayerNode>(g, qs8_prec, detail::fuse_node_with_activation<DepthwiseConvolutionLayerNode>, supported_fused_activations);

    // TODO (COMPMID-2055): re-enable once we fuse bias and activations to convolution
    // detail::fuse_layer<ConvolutionLayerNode, BatchNormalizationLayerNode>(g, empty_prec, detail::fuse_convolution_with_batch_normalization);
}
} // namespace graph
} // namespace arm_compute