aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/optimizations/OptimizeConsecutiveReshapes.hpp
blob: 5047d5d678da1a0af01a9f2b7141fb13301583d3 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once

#include "Optimization.hpp"

namespace armnn
{
namespace optimizations
{

class OptimizeConsecutiveReshapesImpl
{
public:
    /// Run for every connection between a base RashapeLayer and a child ReshapeLayer.
    /// Inserts an equivalent ReshapeLayer that bypasses both for that connection.
    void Run(Graph& graph, InputSlot& connection) const
    {
        Layer& base = connection.GetConnectedOutputSlot()->GetOwningLayer();
        Layer& child = connection.GetOwningLayer();

        BOOST_ASSERT(base.GetType() == LayerType::Reshape);
        BOOST_ASSERT(child.GetType() == LayerType::Reshape);

        OutputSlot* parentOut = base.GetInputSlot(0).GetConnectedOutputSlot();

        const TensorInfo& inInfo = parentOut->GetTensorInfo();
        const TensorInfo& outInfo = child.GetOutputHandler().GetTensorInfo();

        if (inInfo.GetShape() != outInfo.GetShape())
        {
            // Inserts equivalent reshape before base layer.
            const std::string name = std::string("merged-") + base.GetName() + std::string("-with-") + child.GetName();
            const ReshapeDescriptor descriptor{outInfo.GetShape()};
            auto& newReshape = *graph.InsertNewLayer<ReshapeLayer>(base.GetInputSlot(0), descriptor, name.c_str());
            // Sets tensor info for new layer.
            newReshape.GetOutputHandler().SetTensorInfo(outInfo);
            // Reconnects base with original parent.
            newReshape.GetOutputSlot().MoveAllConnections(*parentOut);
            // Parent is now the new layer.
            parentOut = &newReshape.GetOutputSlot();
        }

        // Moves connections in child output to parent layer.
        // Child layer will be removed as it's left unconnected.
        // Base layer will be removed if left unconnected.
        child.GetOutputSlot().MoveAllConnections(*parentOut);
    }

protected:
    OptimizeConsecutiveReshapesImpl() = default;
    ~OptimizeConsecutiveReshapesImpl() = default;
};

/// Public name of the optimization: OptimizeConsecutiveReshapesImpl wrapped in OptimizeForConnection
/// with ReshapeLayer as both layer types.
/// NOTE(review): OptimizeForConnection is not defined in this header (presumably it comes from
/// Optimization.hpp); confirm there how it selects connections and invokes Run().
using OptimizeConsecutiveReshapes = OptimizeForConnection<ReshapeLayer, ReshapeLayer, OptimizeConsecutiveReshapesImpl>;

} // namespace optimizations
} // namespace armnn