aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/optimizations/ConvertConstPermuteLayersToConstLayers.hpp
blob: 2cc3e8eaef6947f0165f37e0fc0724e094174a0b (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
//
// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#pragma once

#include "Optimization.hpp"
#include <armnnUtils/Permute.hpp>
#include <ResolveType.hpp>

namespace armnn
{
namespace optimizations
{

class ConvertConstPermuteLayersToConstLayers
{
public:
    void Run(Graph& graph, InputSlot& connection) const
    {
        Layer& base = connection.GetConnectedOutputSlot()->GetOwningLayer();
        Layer& child = connection.GetOwningLayer();

        ARMNN_ASSERT(base.GetType() == LayerType::Constant);
        ARMNN_ASSERT(child.GetType() == LayerType::Permute);

        if (base.GetDataType() == child.GetDataType())
        {
            switch (base.GetDataType())
            {
                case DataType::Float16:
                    ReplaceConstPermuteLayer<DataType::Float16>(graph,
                                                                 PolymorphicDowncast<ConstantLayer*>(&base),
                                                                 PolymorphicDowncast<PermuteLayer*>(&child));
                    break;
                case DataType::Float32:
                    ReplaceConstPermuteLayer<DataType::Float32>(graph,
                                                                 PolymorphicDowncast<ConstantLayer*>(&base),
                                                                 PolymorphicDowncast<PermuteLayer*>(&child));
                    break;
                case DataType::QAsymmU8:
                    ReplaceConstPermuteLayer<DataType::QAsymmU8>(graph,
                                                                 PolymorphicDowncast<ConstantLayer*>(&base),
                                                                 PolymorphicDowncast<PermuteLayer*>(&child));
                    break;
                case DataType::Signed32:
                    ReplaceConstPermuteLayer<DataType::Signed32>(graph,
                                                                 PolymorphicDowncast<ConstantLayer*>(&base),
                                                                 PolymorphicDowncast<PermuteLayer*>(&child));
                    break;
                case DataType::QSymmS16:
                    ReplaceConstPermuteLayer<DataType::QSymmS16>(graph,
                                                                 PolymorphicDowncast<ConstantLayer*>(&base),
                                                                 PolymorphicDowncast<PermuteLayer*>(&child));
                    break;
                case DataType::QSymmS8:
                    ReplaceConstPermuteLayer<DataType::QSymmS8>(graph,
                                                                 PolymorphicDowncast<ConstantLayer*>(&base),
                                                                 PolymorphicDowncast<PermuteLayer*>(&child));
                    break;
                case DataType::QAsymmS8:
                    ReplaceConstPermuteLayer<DataType::QAsymmS8>(graph,
                                                                 PolymorphicDowncast<ConstantLayer*>(&base),
                                                                 PolymorphicDowncast<PermuteLayer*>(&child));
                    break;
                case DataType::BFloat16:
                    ReplaceConstPermuteLayer<DataType::BFloat16>(graph,
                                                                 PolymorphicDowncast<ConstantLayer*>(&base),
                                                                 PolymorphicDowncast<PermuteLayer*>(&child));
                    break;
                case DataType::Signed64:
                    ReplaceConstPermuteLayer<DataType::Signed64>(graph,
                                                                 PolymorphicDowncast<ConstantLayer*>(&base),
                                                                 PolymorphicDowncast<PermuteLayer*>(&child));
                    break;
                case DataType::Boolean:
                    ReplaceConstPermuteLayer<DataType::Boolean>(graph,
                                                                 PolymorphicDowncast<ConstantLayer*>(&base),
                                                                 PolymorphicDowncast<PermuteLayer*>(&child));
                    break;
            }
        }
    }
protected:
    ConvertConstPermuteLayersToConstLayers()  = default;
    ~ConvertConstPermuteLayersToConstLayers() = default;
private:
    template<armnn::DataType ArmnnType,
             typename T = armnn::ResolveType<ArmnnType>>
    static void ReplaceConstPermuteLayer(Graph& graph,
                                         ConstantLayer* constantLayer,
                                         PermuteLayer* permuteLayer)
    {
        IgnoreUnused(graph);
        /**
         * This optimisation is to find situations where a constant set of inputs is being provided to a Permute
         * layer. In this case we don't want the overhead of Permuting the values on every inference, instead we
         * want to Permute them once and store them in a Const layer to be used everytime as they will not change.
         */
        TensorInfo outputPermuteInfo = permuteLayer->GetOutputSlot(0).GetTensorInfo();
        std::vector<T> newValues(outputPermuteInfo.GetNumElements());
        armnnUtils::Permute(outputPermuteInfo.GetShape(), permuteLayer->GetPermutation(),
                            constantLayer->m_LayerOutput->Map(true), newValues.data(),
                            GetDataTypeSize(outputPermuteInfo.GetDataType()));

        TensorInfo newInfo = outputPermuteInfo;
        newInfo.SetConstant(true);
        ConstTensor newInput(newInfo, newValues);
        constantLayer->m_LayerOutput.reset(new ScopedTensorHandle(newInput));

        // Moves connections in permute output to the constant layer.
        // Permute layer will be removed if left unconnected.
        permuteLayer->GetOutputSlot().MoveAllConnections(constantLayer->GetOutputSlot());

        // Updating the output tensor
        constantLayer->GetOutputSlot(0).SetTensorInfo(newInfo);
        ARMNN_ASSERT(constantLayer->GetOutputSlot(0).GetTensorInfo().IsConstant() == true);
    }
};

/// Optimization that applies ConvertConstPermuteLayersToConstLayers to ConstantLayer -> PermuteLayer connections
/// (presumably matched connection-by-connection by OptimizeForConnection — confirm in Optimization.hpp).
using FusePermuteIntoConstLayer =
    OptimizeForConnection<ConstantLayer, PermuteLayer, ConvertConstPermuteLayersToConstLayers>;

} // namespace optimizations
} // namespace armnn