ArmNN
 23.02
OptimizeConsecutiveReshapes.hpp
Go to the documentation of this file.
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #pragma once
6 
7 #include "Optimization.hpp"
8 
9 namespace armnn
10 {
11 namespace optimizations
12 {
13 
15 {
16 public:
17  /// Run for every connection between a base ReshapeLayer and a child ReshapeLayer.
18  /// Inserts an equivalent ReshapeLayer that bypasses both for that connection.
19  void Run(Graph& graph, InputSlot& connection) const
20  {
21  Layer& base = connection.GetConnectedOutputSlot()->GetOwningLayer();
22  Layer& child = connection.GetOwningLayer();
23 
26 
27  OutputSlot* parentOut = base.GetInputSlot(0).GetConnectedOutputSlot();
28 
29  const TensorInfo& inInfo = parentOut->GetTensorInfo();
30  const TensorInfo& outInfo = child.GetOutputHandler().GetTensorInfo();
31 
32  // This Optimization is only appropriate when the base ReshapeLayer is connected to the child ReshapeLayer
33  // and no other Layer.
34  if (base.GetOutputSlot(0).GetNumConnections() > 1)
35  {
36  return;
37  }
38 
39  if (inInfo.GetShape() != outInfo.GetShape())
40  {
41  // Inserts equivalent reshape before base layer.
42  const std::string name = std::string("merged-") + base.GetName() + std::string("-with-") + child.GetName();
43  const ReshapeDescriptor descriptor{outInfo.GetShape()};
44  auto& newReshape = *graph.InsertNewLayer<ReshapeLayer>(base.GetInputSlot(0), descriptor, name.c_str());
45 
46  // Parent is now the new layer.
47  parentOut = &newReshape.GetOutputSlot();
48  }
49 
50  // Moves connections in child output to parent layer.
51  // Child layer will be removed as it's left unconnected.
52  // Base layer will be removed if left unconnected.
53  child.GetOutputSlot().MoveAllConnections(*parentOut);
54  }
55 
56 protected:
59 };
60 
62 
63 } // namespace optimizations
64 } // namespace armnn
armnn::InputSlot::GetOwningLayer
Layer & GetOwningLayer() const
Definition: Layer.hpp:53
armnn::OutputSlot
Definition: Layer.hpp:87
armnn::OutputHandler::GetTensorInfo
const TensorInfo & GetTensorInfo() const
Gets the matching TensorInfo for the output.
Definition: OutputHandler.hpp:42
armnn::InputSlot
Definition: Layer.hpp:42
armnn::InputSlot::GetConnectedOutputSlot
const OutputSlot * GetConnectedOutputSlot() const
Definition: Layer.hpp:56
armnn::Layer
Definition: Layer.hpp:217
armnn::Graph::InsertNewLayer
LayerT * InsertNewLayer(InputSlot &insertBefore, Args &&... args)
Inserts a new layer between the output slot currently connected to insertBefore and insertBefore itself.
Definition: Graph.hpp:471
armnn::ReshapeLayer
This layer represents a reshape operation.
Definition: ReshapeLayer.hpp:15
armnn
Copyright (c) 2021 ARM Limited and Contributors.
Definition: 01_00_quick_start.dox:6
armnn::OutputSlot::GetTensorInfo
const TensorInfo & GetTensorInfo() const override
Definition: Layer.cpp:92
armnn::OptimizeForConnection
Definition: Optimization.hpp:118
armnn::optimizations::OptimizeConsecutiveReshapesImpl::Run
void Run(Graph &graph, InputSlot &connection) const
Run for every connection between a base ReshapeLayer and a child ReshapeLayer.
Definition: OptimizeConsecutiveReshapes.hpp:19
armnn::LayerType::Reshape
@ Reshape
armnn::Layer::GetOutputSlot
const OutputSlot & GetOutputSlot(unsigned int index=0) const override
Get the const output slot handle by slot index.
Definition: Layer.hpp:326
Optimization.hpp
armnn::Layer::GetOutputHandler
const OutputHandler & GetOutputHandler(unsigned int i=0) const
Definition: Layer.hpp:232
armnn::Layer::GetType
LayerType GetType() const override
Returns the armnn::LayerType of this layer.
Definition: Layer.hpp:273
armnn::TensorInfo
Definition: Tensor.hpp:152
armnn::optimizations::OptimizeConsecutiveReshapesImpl
Definition: OptimizeConsecutiveReshapes.hpp:14
armnn::OutputSlot::MoveAllConnections
void MoveAllConnections(OutputSlot &destination)
Moves all connections to another OutputSlot.
Definition: Layer.cpp:145
armnn::OutputSlot::GetOwningLayer
Layer & GetOwningLayer() const
Definition: Layer.hpp:119
armnn::TensorInfo::GetShape
const TensorShape & GetShape() const
Definition: Tensor.hpp:191
armnn::Layer::GetInputSlot
const InputSlot & GetInputSlot(unsigned int index) const override
Get a const input slot handle by slot index.
Definition: Layer.hpp:324
armnn::ReshapeDescriptor
A ReshapeDescriptor for the ReshapeLayer.
Definition: Descriptors.hpp:970
armnn::Graph
Definition: Graph.hpp:30
ARMNN_ASSERT
#define ARMNN_ASSERT(COND)
Definition: Assert.hpp:14
armnn::OutputSlot::GetNumConnections
unsigned int GetNumConnections() const override
Definition: Layer.hpp:145
armnn::optimizations::OptimizeConsecutiveReshapesImpl::~OptimizeConsecutiveReshapesImpl
~OptimizeConsecutiveReshapesImpl()=default
armnn::Layer::GetName
const char * GetName() const override
Returns the name of the layer.
Definition: Layer.hpp:319
armnn::optimizations::OptimizeConsecutiveReshapesImpl::OptimizeConsecutiveReshapesImpl
OptimizeConsecutiveReshapesImpl()=default