ArmNN
 24.02
TransposeAsReshape.hpp
Go to the documentation of this file.
1 //
2 // Copyright © 2020 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #pragma once
6 
7 #include "Optimization.hpp"
8 
9 namespace armnn
10 {
11 namespace optimizations
12 {
13 
15 {
16 public:
17  /// Run for every TransposeLayer. Replaces it with a ReshapeLayer if they are equivalent.
18  void Run(Graph& graph, TransposeLayer& transpose) const
19  {
20  if (IsReshape(transpose))
21  {
22  const TensorInfo& outInfo = transpose.GetOutputHandler().GetTensorInfo();
23 
24  const std::string name = std::string("as_reshape-") + transpose.GetName();
25  const ReshapeDescriptor descriptor{outInfo.GetShape()};
26  // Inserts NewLayer so layers don't need to be re-sorted.
27  auto reshape = graph.InsertNewLayer<ReshapeLayer>(transpose.GetInputSlot(0), descriptor, name.c_str());
28 
29  // Bypass transpose. It will be deleted since it's left unconnected.
30  transpose.GetOutputSlot().MoveAllConnections(reshape->GetOutputSlot());
31  }
32  }
33 
34 protected:
/// Special members are declared under `protected:`, so this type cannot be constructed or destroyed
/// directly by unrelated code, only by a derived type (presumably the OptimizeForType wrapper that
/// registers this pass -- NOTE(review): confirm).
/// The destructor is protected and non-virtual, which rules out deleting a derived object through a
/// TransposeAsReshapeImpl pointer while avoiding a vtable.
35  TransposeAsReshapeImpl() = default;
36  ~TransposeAsReshapeImpl() = default;
37 
38 private:
39  static bool IsReshape(const TransposeLayer& layer)
40  {
41  const TensorShape& outShape = layer.GetOutputHandler().GetTensorInfo().GetShape();
42  const PermutationVector& permutation = layer.GetPermutation();
43 
44  const unsigned int numDimensions = permutation.GetSize();
45  std::map<unsigned int, unsigned int> permuteMappings;
46  for (unsigned int i = 0; i < permutation.GetSize(); ++i)
47  {
48  permuteMappings[permutation[i]] = i;
49  }
50 
51  std::vector<unsigned int> permuteVector;
52  for (unsigned int i = 0; i < permutation.GetSize(); ++i)
53  {
54  permuteVector.push_back(permuteMappings.at(i));
55  }
56 
57  unsigned int lastGtOne = 0;
58  while ((lastGtOne < numDimensions) && (outShape[(permuteVector[lastGtOne])] == 1U))
59  {
60  ++lastGtOne;
61  }
62 
63  bool isReshape = true;
64  for (unsigned int i = lastGtOne + 1U; isReshape && (i < numDimensions); ++i)
65  {
66  if (outShape[permuteVector[i]] > 1U)
67  {
68  isReshape = permuteVector[lastGtOne] < permuteVector[i];
69  lastGtOne = i;
70  }
71  }
72 
73  return isReshape;
74  }
75 };
76 
78 
79 } // namespace optimizations
80 } // namespace armnn
armnn::TensorInfo
Definition: Tensor.hpp:152
armnn::Layer::GetOutputSlot
const OutputSlot & GetOutputSlot(unsigned int index=0) const override
Get the const output slot handle by slot index.
Definition: Layer.hpp:339
armnn::optimizations::TransposeAsReshapeImpl
Definition: TransposeAsReshape.hpp:14
Optimization.hpp
armnn::Layer::GetInputSlot
const InputSlot & GetInputSlot(unsigned int index) const override
Get a const input slot handle by slot index.
Definition: Layer.hpp:337
armnn::Layer::GetName
const char * GetName() const override
Returns the name of the layer.
Definition: Layer.hpp:332
armnn::TransposeLayer
This layer represents a transpose operation.
Definition: TransposeLayer.hpp:15
armnn::TensorShape
Definition: Tensor.hpp:20
armnn::ReshapeLayer
This layer represents a reshape operation.
Definition: ReshapeLayer.hpp:15
armnn::ReshapeDescriptor
A ReshapeDescriptor for the ReshapeLayer.
Definition: Descriptors.hpp:1023
armnn::optimizations::TransposeAsReshapeImpl::~TransposeAsReshapeImpl
~TransposeAsReshapeImpl()=default
armnn::Layer::GetOutputHandler
const OutputHandler & GetOutputHandler(unsigned int i=0) const
Definition: Layer.hpp:245
armnn::optimizations::TransposeAsReshapeImpl::TransposeAsReshapeImpl
TransposeAsReshapeImpl()=default
armnn::PermutationVector
Definition: Types.hpp:314
armnn::OutputSlot::MoveAllConnections
void MoveAllConnections(OutputSlot &destination)
Moves all connections to another OutputSlot.
Definition: Layer.cpp:145
armnn::OptimizeForType
Definition: Optimization.hpp:67
armnn::TransposeLayer::GetPermutation
const PermutationVector & GetPermutation() const
Definition: TransposeLayer.hpp:37
armnn::PermutationVector::GetSize
SizeType GetSize() const
Definition: Types.hpp:357
armnn::TensorInfo::GetShape
const TensorShape & GetShape() const
Definition: Tensor.hpp:193
armnn
Copyright (c) 2021 ARM Limited and Contributors.
Definition: 01_00_quick_start.dox:6
armnn::OutputHandler::GetTensorInfo
const TensorInfo & GetTensorInfo() const
Gets the matching TensorInfo for the output.
Definition: OutputHandler.hpp:42
armnn::optimizations::TransposeAsReshapeImpl::Run
void Run(Graph &graph, TransposeLayer &transpose) const
Run for every TransposeLayer. Replaces it with a ReshapeLayer if they are equivalent.
Definition: TransposeAsReshape.hpp:18
armnn::Graph
Definition: Graph.hpp:30
armnn::Graph::InsertNewLayer
LayerT * InsertNewLayer(InputSlot &insertBefore, Args &&... args)
Inserts a new layer between the output slot currently connected to insertBefore and insertBefore itself.
Definition: Graph.hpp:471