blob: 98e87c36c6213147e2175c38005666eeb74fe885 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
|
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once
#include "Optimization.hpp"
#include <armnn/utility/IgnoreUnused.hpp>
namespace armnn
{
namespace optimizations
{
template <typename PermuteType>
class OptimizeInversePermutesImpl
{
public:
/// Run for every connection between a base PermuteLayer and a child PermuteLayer.
/// Bypasses both layers for that connection if one is the inverse of the other.
void Run(Graph& graph, InputSlot& connection) const
{
IgnoreUnused(graph);
Layer& base = connection.GetConnectedOutputSlot()->GetOwningLayer();
auto child = boost::polymorphic_downcast<PermuteType*>(&connection.GetOwningLayer());
if (child->IsInverse(*boost::polymorphic_downcast<PermuteType*>(&base)))
{
// Bypass both layers. Child will be removed as it's left unconnected.
// Base layer will be removed if left unconnected.
child->GetOutputSlot().MoveAllConnections(*base.GetInputSlot(0).GetConnectedOutputSlot());
}
}
protected:
OptimizeInversePermutesImpl() = default;
~OptimizeInversePermutesImpl() = default;
};
/// Bypasses PermuteLayer -> PermuteLayer pairs in which the second layer is the inverse of the first.
using OptimizeInversePermutes =
    OptimizeForConnection<PermuteLayer, PermuteLayer, OptimizeInversePermutesImpl<PermuteLayer>>;

/// Bypasses TransposeLayer -> TransposeLayer pairs in which the second layer is the inverse of the first.
using OptimizeInverseTransposes =
    OptimizeForConnection<TransposeLayer, TransposeLayer, OptimizeInversePermutesImpl<TransposeLayer>>;
} // namespace optimizations
} // namespace armnn
|