// path: root/src/armnn/optimizations/OptimizeInversePermutes.hpp
// blob: 48bfa354405b4f4921ff971b8b75ee22eeb60cbf
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once

#include "Optimization.hpp"

#include <boost/core/ignore_unused.hpp>

namespace armnn
{
namespace optimizations
{

class OptimizeInversePermutesImpl
{
public:
    /// Run for every connection between a base PermuteLayer and a child PermuteLayer.
    /// Bypasses both layers for that connection if one is the inverse of the other.
    void Run(Graph& graph, InputSlot& connection) const
    {
        boost::ignore_unused(graph);
        Layer& base = connection.GetConnectedOutputSlot()->GetOwningLayer();
        auto child = boost::polymorphic_downcast<PermuteLayer*>(&connection.GetOwningLayer());

        if (child->IsInverse(*boost::polymorphic_downcast<PermuteLayer*>(&base)))
        {
            // Bypass both layers. Child will be removed as it's left unconnected.
            // Base layer will be removed if left unconnected.
            child->GetOutputSlot().MoveAllConnections(*base.GetInputSlot(0).GetConnectedOutputSlot());
        }
    }

protected:
    OptimizeInversePermutesImpl() = default;
    ~OptimizeInversePermutesImpl() = default;
};

/// Optimization type that plugs OptimizeInversePermutesImpl into OptimizeForConnection
/// with PermuteLayer as both the base and the child layer type.
using OptimizeInversePermutes =
    OptimizeForConnection<PermuteLayer,
                          PermuteLayer,
                          OptimizeInversePermutesImpl>;

} // namespace optimizations
} // namespace armnn