ArmNN
 20.05
PermuteAsReshape.hpp
Go to the documentation of this file.
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #pragma once
6 
7 #include "Optimization.hpp"
8 
9 namespace armnn
10 {
11 namespace optimizations
12 {
13 
15 {
16 public:
17  /// Run for every PermuteLayer. Replaces it with a ReshapeLayer if they are equivalent.
18  void Run(Graph& graph, PermuteLayer& permute) const
19  {
20  if (IsReshape(permute))
21  {
22  const TensorInfo& outInfo = permute.GetOutputHandler().GetTensorInfo();
23 
24  const std::string name = std::string("as_reshape-") + permute.GetName();
25  const ReshapeDescriptor descriptor{outInfo.GetShape()};
26  // Inserts NewLayer so layers don't need to be re-sorted.
27  auto reshape = graph.InsertNewLayer<ReshapeLayer>(permute.GetInputSlot(0), descriptor, name.c_str());
28 
29  // Bypass permute. It will be deleted since it's left unconnected.
30  permute.GetOutputSlot().MoveAllConnections(reshape->GetOutputSlot());
31  }
32  }
33 
34 protected:
// Constructor and destructor are protected: the class can only be created and destroyed as a
// base-class subobject, never instantiated or deleted directly through this type.
35  PermuteAsReshapeImpl() = default;
// Non-virtual is safe here because a protected destructor forbids deletion via a base pointer.
36  ~PermuteAsReshapeImpl() = default;
37 
38 private:
39  static bool IsReshape(const PermuteLayer& layer)
40  {
41  const TensorShape& outShape = layer.GetOutputHandler().GetTensorInfo().GetShape();
42  const PermutationVector& permutation = layer.GetPermutation();
43 
44  const unsigned int numDimensions = permutation.GetSize();
45 
46  unsigned int lastGtOne = 0;
47  while ((lastGtOne < numDimensions) && (outShape[(permutation[lastGtOne])] == 1U))
48  {
49  ++lastGtOne;
50  }
51 
52  bool isReshape = true;
53  for (unsigned int i = lastGtOne + 1U; isReshape && (i < numDimensions); ++i)
54  {
55  if (outShape[permutation[i]] > 1U)
56  {
57  isReshape = permutation[lastGtOne] < permutation[i];
58  lastGtOne = i;
59  }
60  }
61 
62  return isReshape;
63  }
64 };
65 
67 
68 } // namespace optimizations
69 } // namespace armnn
const TensorShape & GetShape() const
Definition: Tensor.hpp:88
A ReshapeDescriptor for the ReshapeLayer.
This layer represents a reshape operation.
Copyright (c) 2020 ARM Limited.
SizeType GetSize() const
Definition: Types.hpp:202
This layer represents a permutation operation.
const InputSlot & GetInputSlot(unsigned int index) const override
Get a const input slot handle by slot index.
Definition: Layer.hpp:310
const PermutationVector & GetPermutation() const
const OutputHandler & GetOutputHandler(unsigned int i=0) const
Definition: Layer.hpp:221
const OutputSlot & GetOutputSlot(unsigned int index=0) const override
Get the const output slot handle by slot index.
Definition: Layer.hpp:312
const char * GetName() const override
Returns the name of the layer.
Definition: Layer.hpp:305
void Run(Graph &graph, PermuteLayer &permute) const
Run for every PermuteLayer. Replaces it with a ReshapeLayer if they are equivalent.
LayerT * InsertNewLayer(InputSlot &insertBefore, Args &&... args)
Inserts a new layer between the output slot currently connected to insertBefore and insertBefore itself.
Definition: Graph.hpp:410
void MoveAllConnections(OutputSlot &destination)
Moves all connections to another OutputSlot.
Definition: Layer.cpp:112
const TensorInfo & GetTensorInfo() const
Gets the matching TensorInfo for the output.