ArmNN
 23.05
MoveTransposeUpImpl Class Reference

#include <MoveTransposeUp.hpp>

Public Member Functions

void Run (Graph &graph, InputSlot &connection) const
 Run for every connection between a base Layer (any) and a child TransposeLayer. More...
 

Protected Member Functions

 MoveTransposeUpImpl ()=default
 
 ~MoveTransposeUpImpl ()=default
 

Detailed Description

Definition at line 16 of file MoveTransposeUp.hpp.

Constructor & Destructor Documentation

◆ MoveTransposeUpImpl()

MoveTransposeUpImpl ( )
protected, default

◆ ~MoveTransposeUpImpl()

~MoveTransposeUpImpl ( )
protected, default

Member Function Documentation

◆ Run()

void Run ( Graph & graph,
InputSlot & connection 
) const
inline

Run for every connection between a base Layer (any) and a child TransposeLayer.

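If the type of the base layer allows it, the permutation is moved to the inputs of the base layer. That is, an equivalent transpose (same permutation) is inserted before every input of the base layer, and all connections on the output of the child transpose layer are moved to the output of the base layer. The bypassed transpose layer is left unconnected so that it can be removed. This only happens when the base layer's output has exactly one connection.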

Definition at line 23 of file MoveTransposeUp.hpp.

24  {
25  OutputSlot& baseOutput = *connection.GetConnectedOutputSlot();
26 
27  if (baseOutput.GetNumConnections() == 1U)
28  {
29  Layer& base = baseOutput.GetOwningLayer();
30 
31  if (CanMoveTransposeToInputs(base))
32  {
33  auto transpose = PolymorphicDowncast<TransposeLayer*>(&connection.GetOwningLayer());
34  const PermutationVector& perm = transpose->GetPermutation();
35 
36  // Inserts an equivalent transpose before every input of the base layer.
37  for (auto baseInput = base.BeginInputSlots(); baseInput != base.EndInputSlots(); ++baseInput)
38  {
39  // Inserts a new transpose layer.
40  const std::string name = std::string("moved_up-") + transpose->GetName();
41  TransposeLayer& permLayer = *graph.InsertNewLayer<TransposeLayer>(*baseInput, perm, name.c_str());
42 
43  // Sets output tensor info for the new layer.
44  OutputSlot& parentOutput = *permLayer.GetInputSlot(0).GetConnectedOutputSlot();
45  const TensorInfo permOutInfo = armnnUtils::TransposeTensorShape(parentOutput.GetTensorInfo(), perm);
46  permLayer.GetOutputHandler().SetTensorInfo(permOutInfo);
47  }
48 
49  // Bypasses transpose. It will be removed as it's left unconnected.
50  transpose->GetOutputSlot().MoveAllConnections(base.GetOutputSlot());
51  }
52  }
53  }

References Layer::BeginInputSlots(), Layer::EndInputSlots(), InputSlot::GetConnectedOutputSlot(), Layer::GetInputSlot(), OutputSlot::GetNumConnections(), Layer::GetOutputHandler(), Layer::GetOutputSlot(), InputSlot::GetOwningLayer(), OutputSlot::GetOwningLayer(), Graph::InsertNewLayer(), OutputHandler::SetTensorInfo(), and armnnUtils::TransposeTensorShape().


The documentation for this class was generated from the following file: MoveTransposeUp.hpp
armnnUtils::TransposeTensorShape
armnn::TensorShape TransposeTensorShape(const armnn::TensorShape &srcShape, const armnn::PermutationVector &mappings)
Definition: Transpose.cpp:98