ArmNN
 20.02
TransposeAsReshape.hpp
Go to the documentation of this file.
1 //
2 // Copyright © 2020 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #pragma once
6 
7 #include "Optimization.hpp"
8 
9 namespace armnn
10 {
11 namespace optimizations
12 {
13 
15 {
16 public:
17  /// Run for every TransposeLayer. Replaces it with a ReshapeLayer if they are equivalent.
18  void Run(Graph& graph, TransposeLayer& transpose) const
19  {
20  if (IsReshape(transpose))
21  {
22  const TensorInfo& outInfo = transpose.GetOutputHandler().GetTensorInfo();
23 
24  const std::string name = std::string("as_reshape-") + transpose.GetName();
25  const ReshapeDescriptor descriptor{outInfo.GetShape()};
26  // Inserts NewLayer so layers don't need to be re-sorted.
27  auto reshape = graph.InsertNewLayer<ReshapeLayer>(transpose.GetInputSlot(0), descriptor, name.c_str());
28  reshape->GetOutputHandler().SetTensorInfo(outInfo);
29 
30  // Bypass transpose. It will be deleted since it's left unconnected.
31  transpose.GetOutputSlot().MoveAllConnections(reshape->GetOutputSlot());
32  }
33  }
34 
35 protected:
36  TransposeAsReshapeImpl() = default;
37  ~TransposeAsReshapeImpl() = default;
38 
39 private:
40  static bool IsReshape(const TransposeLayer& layer)
41  {
42  const TensorShape& outShape = layer.GetOutputHandler().GetTensorInfo().GetShape();
43  const PermutationVector& permutation = layer.GetPermutation();
44 
45  const unsigned int numDimensions = permutation.GetSize();
46  std::map<unsigned int, unsigned int> permuteMappings;
47  for (unsigned int i = 0; i < permutation.GetSize(); ++i)
48  {
49  permuteMappings[permutation[i]] = i;
50  }
51 
52  std::vector<unsigned int> permuteVector;
53  for (unsigned int i = 0; i < permutation.GetSize(); ++i)
54  {
55  permuteVector.push_back(permuteMappings.at(i));
56  }
57 
58  unsigned int lastGtOne = 0;
59  while ((lastGtOne < numDimensions) && (outShape[(permuteVector[lastGtOne])] == 1U))
60  {
61  ++lastGtOne;
62  }
63 
64  bool isReshape = true;
65  for (unsigned int i = lastGtOne + 1U; isReshape && (i < numDimensions); ++i)
66  {
67  if (outShape[permuteVector[i]] > 1U)
68  {
69  isReshape = permuteVector[lastGtOne] < permuteVector[i];
70  lastGtOne = i;
71  }
72  }
73 
74  return isReshape;
75  }
76 };
77 
79 
80 } // namespace optimizations
81 } // namespace armnn
const TensorShape & GetShape() const
Definition: Tensor.hpp:88
A ReshapeDescriptor for the ReshapeLayer.
This layer represents a reshape operation.
Copyright (c) 2020 ARM Limited.
SizeType GetSize() const
Definition: Types.hpp:202
const InputSlot & GetInputSlot(unsigned int index) const override
Get a const input slot handle by slot index.
Definition: Layer.hpp:310
This layer represents a transpose operation.
void SetTensorInfo(const TensorInfo &tensorInfo)
Sets the TensorInfo used by this output handler.
const OutputHandler & GetOutputHandler(unsigned int i=0) const
Definition: Layer.hpp:221
const OutputSlot & GetOutputSlot(unsigned int index=0) const override
Get the const output slot handle by slot index.
Definition: Layer.hpp:312
const char * GetName() const override
Returns the name of the layer.
Definition: Layer.hpp:305
LayerT * InsertNewLayer(InputSlot &insertBefore, Args &&... args)
Inserts a new layer between the output slot currently connected to insertBefore and insertBefore itself.
Definition: Graph.hpp:409
const PermutationVector & GetPermutation() const
void MoveAllConnections(OutputSlot &destination)
Moves all connections to another OutputSlot.
Definition: Layer.cpp:112
void Run(Graph &graph, TransposeLayer &transpose) const
Run for every TransposeLayer. Replaces it with a ReshapeLayer if they are equivalent.
const TensorInfo & GetTensorInfo() const
Gets the matching TensorInfo for the output.