blob: 9a926a57a4efbc5b8fe91e5aa9d09e08f43b7209 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
|
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// See LICENSE file in the project root for full license information.
//
#pragma once
#include "Optimization.hpp"
namespace armnn
{
namespace optimizations
{

/// Folds two back-to-back ReshapeLayers into at most one ReshapeLayer.
///
/// Used as the implementation parameter of OptimizeForConnection (see OptimizeConsecutiveReshapes below).
/// The constructor and destructor are protected, so this class cannot be instantiated on its own.
class OptimizeConsecutiveReshapesImpl
{
public:
    /// Run for every connection between a base ReshapeLayer and a child ReshapeLayer.
    ///
    /// The consumers of the child are rewired so that they no longer go through either reshape:
    ///  - if the shape entering the base equals the shape leaving the child, the two reshapes cancel out
    ///    and the consumers are connected directly to the base's producer;
    ///  - otherwise a single equivalent ReshapeLayer (named "merged-<base>-with-<child>") is created,
    ///    fed by the base's producer, and the consumers are connected to it.
    /// The base layer keeps its original input either way.
    ///
    /// @param graph      Graph being optimized; receives the merged ReshapeLayer when one is needed.
    /// @param connection Input slot of the child ReshapeLayer that is fed by the base ReshapeLayer.
    void Run(Graph& graph, InputSlot& connection) const
    {
        Layer& base = connection.GetConnectedOutputSlot()->GetOwningLayer();
        Layer& child = connection.GetOwningLayer();

        BOOST_ASSERT(base.GetType() == LayerType::Reshape);
        BOOST_ASSERT(child.GetType() == LayerType::Reshape);

        // Producer feeding the base reshape. It is dereferenced below, so it must be connected.
        OutputSlot* parentOut = base.GetInputSlot(0).GetConnectedOutputSlot();
        BOOST_ASSERT(parentOut != nullptr);

        const TensorInfo& inInfo = parentOut->GetTensorInfo();
        const TensorInfo& outInfo = child.GetOutputHandler().GetTensorInfo();

        if (inInfo.GetShape() != outInfo.GetShape())
        {
            // Insert an equivalent reshape before the base layer.
            const std::string name = std::string("merged-") + base.GetName() + std::string("-with-") + child.GetName();
            const ReshapeDescriptor descriptor{outInfo.GetShape()};
            auto& newReshape = *graph.InsertNewLayer<ReshapeLayer>(base.GetInputSlot(0), descriptor, name.c_str());

            // The merged reshape produces exactly what the child produced.
            newReshape.GetOutputHandler().SetTensorInfo(outInfo);

            // Hand the base back to the original parent. This way the base (and anything else it feeds)
            // still sees its original input, and the new layer only serves the child's consumers.
            newReshape.GetOutputSlot().MoveAllConnections(*parentOut);

            // From here on, the child's consumers are fed by the merged reshape.
            parentOut = &newReshape.GetOutputSlot();
        }

        // Move the child's output connections to the parent (or to the merged reshape).
        // This leaves the child unconnected. The base is also left unconnected unless it has other consumers.
        // Unconnected layers are expected to be removed afterwards.
        child.GetOutputSlot().MoveAllConnections(*parentOut);
    }

protected:
    OptimizeConsecutiveReshapesImpl() = default;
    ~OptimizeConsecutiveReshapesImpl() = default;
};

/// Optimization applied to every ReshapeLayer -> ReshapeLayer connection; see OptimizeConsecutiveReshapesImpl::Run.
using OptimizeConsecutiveReshapes = OptimizeForConnection<ReshapeLayer, ReshapeLayer, OptimizeConsecutiveReshapesImpl>;

} // namespace optimizations
} // namespace armnn
|