aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/optimizations/OptimizeInverseConversions.hpp
blob: c3a56832ce32fec702b0d3124b5126dd5cae995d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once

#include "Optimization.hpp"

namespace armnn
{
namespace optimizations
{

class OptimizeInverseConversionsImpl
{
public:
    /// Run for every connection between two inverse data type conversion layers, i.e.
    /// Fp16ToFp32 followed by Fp32ToFp16 or vice-versa.
    ///
    /// The connections leaving the second ("child") conversion layer are moved onto the
    /// output slot that feeds the first ("base") conversion layer, so consumers of the
    /// round-trip conversion read the original data directly. No layer is erased here;
    /// the base layer's own output slot is not modified.
    ///
    /// @param graph      The graph being optimized. Not needed by this optimization.
    /// @param connection The child conversion layer's input slot, connected to the output
    ///                   of the base conversion layer.
    void Run(Graph& graph, InputSlot& connection) const
    {
        // Only the connection is needed; explicitly mark the graph as unused so builds with
        // -Wextra -Werror do not fail on -Wunused-parameter.
        (void)graph;

        Layer& base  = connection.GetConnectedOutputSlot()->GetOwningLayer();
        Layer& child = connection.GetOwningLayer();

        BOOST_ASSERT((base.GetType() == LayerType::ConvertFp16ToFp32 &&
                     child.GetType() == LayerType::ConvertFp32ToFp16) ||
                     (base.GetType() == LayerType::ConvertFp32ToFp16 &&
                     child.GetType() == LayerType::ConvertFp16ToFp32));

        // The base conversion must itself be fed by some output slot; otherwise there is
        // nothing to reroute the child's consumers to and the dereference below would be UB.
        auto* const source = base.GetInputSlot(0).GetConnectedOutputSlot();
        BOOST_ASSERT(source != nullptr);

        // Bypass both conversion layers
        child.GetOutputSlot().MoveAllConnections(*source);
    }

protected:
    // Protected, non-virtual: this class is only meant to be used as the wrapped
    // implementation of OptimizeForConnection, never instantiated or deleted directly.
    OptimizeInverseConversionsImpl()  = default;
    ~OptimizeInverseConversionsImpl() = default;
};

/// Connection optimization for a ConvertFp16ToFp32Layer paired with a ConvertFp32ToFp16Layer.
/// NOTE(review): presumably the first template argument is the producing (base) layer, so this
/// removes an Fp16 -> Fp32 -> Fp16 round trip; confirm against OptimizeForConnection.
using OptimizeInverseConversionsFp16 = OptimizeForConnection<ConvertFp16ToFp32Layer,
                                                             ConvertFp32ToFp16Layer,
                                                             OptimizeInverseConversionsImpl>;

/// Connection optimization for a ConvertFp32ToFp16Layer paired with a ConvertFp16ToFp32Layer.
/// NOTE(review): presumably removes an Fp32 -> Fp16 -> Fp32 round trip (same ordering caveat).
using OptimizeInverseConversionsFp32 = OptimizeForConnection<ConvertFp32ToFp16Layer,
                                                             ConvertFp16ToFp32Layer,
                                                             OptimizeInverseConversionsImpl>;

} // namespace optimizations
} // namespace armnn