ArmNN
 20.02
OptimizeConsecutiveReshapes.hpp
Go to the documentation of this file.
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #pragma once
6 
7 #include "Optimization.hpp"
8 
9 namespace armnn
10 {
11 namespace optimizations
12 {
13 
15 {
16 public:
17  /// Run for every connection between a base ReshapeLayer and a child ReshapeLayer.
18  /// Inserts an equivalent ReshapeLayer that bypasses both for that connection.
19  void Run(Graph& graph, InputSlot& connection) const
20  {
21  Layer& base = connection.GetConnectedOutputSlot()->GetOwningLayer();
22  Layer& child = connection.GetOwningLayer();
23 
24  BOOST_ASSERT(base.GetType() == LayerType::Reshape);
25  BOOST_ASSERT(child.GetType() == LayerType::Reshape);
26 
27  OutputSlot* parentOut = base.GetInputSlot(0).GetConnectedOutputSlot();
28 
29  const TensorInfo& inInfo = parentOut->GetTensorInfo();
30  const TensorInfo& outInfo = child.GetOutputHandler().GetTensorInfo();
31 
32  // This Optimization is only appropriate when the base ReshapeLayer is connected to the child ReshapeLayer
33  // and no other Layer.
34  if (base.GetOutputSlot(0).GetNumConnections() > 1)
35  {
36  return;
37  }
38 
39  if (inInfo.GetShape() != outInfo.GetShape())
40  {
41  // Inserts equivalent reshape before base layer.
42  const std::string name = std::string("merged-") + base.GetName() + std::string("-with-") + child.GetName();
43  const ReshapeDescriptor descriptor{outInfo.GetShape()};
44  auto& newReshape = *graph.InsertNewLayer<ReshapeLayer>(base.GetInputSlot(0), descriptor, name.c_str());
45 
46  // Sets tensor info for new layer.
47  newReshape.GetOutputHandler().SetTensorInfo(outInfo);
48  // Parent is now the new layer.
49  parentOut = &newReshape.GetOutputSlot();
50  }
51 
52  // Moves connections in child output to parent layer.
53  // Child layer will be removed as it's left unconnected.
54  // Base layer will be removed if left unconnected.
55  child.GetOutputSlot().MoveAllConnections(*parentOut);
56  }
57 
58 protected:
61 };
62 
64 
65 } // namespace optimizations
66 } // namespace armnn
const TensorShape & GetShape() const
Definition: Tensor.hpp:88
A ReshapeDescriptor for the ReshapeLayer.
Layer & GetOwningLayer() const
Definition: Layer.hpp:115
This layer represents a reshape operation.
Copyright (c) 2020 ARM Limited.
unsigned int GetNumConnections() const override
Definition: Layer.hpp:138
const InputSlot & GetInputSlot(unsigned int index) const override
Get a const input slot handle by slot index.
Definition: Layer.hpp:310
void Run(Graph &graph, InputSlot &connection) const
Run for every connection between a base ReshapeLayer and a child ReshapeLayer.
const OutputSlot * GetConnectedOutputSlot() const
Definition: Layer.hpp:55
Layer & GetOwningLayer() const
Definition: Layer.hpp:52
void SetTensorInfo(const TensorInfo &tensorInfo)
Sets the TensorInfo used by this output handler.
const OutputHandler & GetOutputHandler(unsigned int i=0) const
Definition: Layer.hpp:221
LayerType GetType() const
Definition: Layer.hpp:259
const OutputSlot & GetOutputSlot(unsigned int index=0) const override
Get the const output slot handle by slot index.
Definition: Layer.hpp:312
const char * GetName() const override
Returns the name of the layer.
Definition: Layer.hpp:305
LayerT * InsertNewLayer(InputSlot &insertBefore, Args &&... args)
Inserts a new layer between the output slot currently connected to insertBefore and insertBefore itself.
Definition: Graph.hpp:409
void MoveAllConnections(OutputSlot &destination)
Moves all connections to another OutputSlot.
Definition: Layer.cpp:112
const TensorInfo & GetTensorInfo() const
Gets the matching TensorInfo for the output.