ArmNN
 22.05.01
ConvertConstPermuteLayersToConstLayers.hpp
Go to the documentation of this file.
1 //
2 // Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #pragma once
7 
8 #include "Optimization.hpp"
9 #include <armnnUtils/Permute.hpp>
10 #include <ResolveType.hpp>
11 
12 namespace armnn
13 {
14 namespace optimizations
15 {
16 
18 {
19 public:
20  void Run(Graph& graph, InputSlot& connection) const
21  {
22  Layer& base = connection.GetConnectedOutputSlot()->GetOwningLayer();
23  Layer& child = connection.GetOwningLayer();
24 
27 
28  if (base.GetDataType() == child.GetDataType())
29  {
30  switch (base.GetDataType())
31  {
32  case DataType::Float16:
33  ReplaceConstPermuteLayer<DataType::Float16>(graph,
34  PolymorphicDowncast<ConstantLayer*>(&base),
35  PolymorphicDowncast<PermuteLayer*>(&child));
36  break;
37  case DataType::Float32:
38  ReplaceConstPermuteLayer<DataType::Float32>(graph,
39  PolymorphicDowncast<ConstantLayer*>(&base),
40  PolymorphicDowncast<PermuteLayer*>(&child));
41  break;
42  case DataType::QAsymmU8:
43  ReplaceConstPermuteLayer<DataType::QAsymmU8>(graph,
44  PolymorphicDowncast<ConstantLayer*>(&base),
45  PolymorphicDowncast<PermuteLayer*>(&child));
46  break;
47  case DataType::Signed32:
48  ReplaceConstPermuteLayer<DataType::Signed32>(graph,
49  PolymorphicDowncast<ConstantLayer*>(&base),
50  PolymorphicDowncast<PermuteLayer*>(&child));
51  break;
52  case DataType::QSymmS16:
53  ReplaceConstPermuteLayer<DataType::QSymmS16>(graph,
54  PolymorphicDowncast<ConstantLayer*>(&base),
55  PolymorphicDowncast<PermuteLayer*>(&child));
56  break;
57  case DataType::QSymmS8:
58  ReplaceConstPermuteLayer<DataType::QSymmS8>(graph,
59  PolymorphicDowncast<ConstantLayer*>(&base),
60  PolymorphicDowncast<PermuteLayer*>(&child));
61  break;
62  case DataType::QAsymmS8:
63  ReplaceConstPermuteLayer<DataType::QAsymmS8>(graph,
64  PolymorphicDowncast<ConstantLayer*>(&base),
65  PolymorphicDowncast<PermuteLayer*>(&child));
66  break;
67  case DataType::BFloat16:
68  ReplaceConstPermuteLayer<DataType::BFloat16>(graph,
69  PolymorphicDowncast<ConstantLayer*>(&base),
70  PolymorphicDowncast<PermuteLayer*>(&child));
71  break;
72  case DataType::Signed64:
73  ReplaceConstPermuteLayer<DataType::Signed64>(graph,
74  PolymorphicDowncast<ConstantLayer*>(&base),
75  PolymorphicDowncast<PermuteLayer*>(&child));
76  break;
77  case DataType::Boolean:
78  ReplaceConstPermuteLayer<DataType::Boolean>(graph,
79  PolymorphicDowncast<ConstantLayer*>(&base),
80  PolymorphicDowncast<PermuteLayer*>(&child));
81  break;
82  }
83  }
84  }
85 protected:
88 private:
89  template<armnn::DataType ArmnnType,
90  typename T = armnn::ResolveType<ArmnnType>>
91  static void ReplaceConstPermuteLayer(Graph& graph,
92  ConstantLayer* constantLayer,
93  PermuteLayer* permuteLayer)
94  {
95  IgnoreUnused(graph);
96  /**
97  * This optimisation is to find situations where a constant set of inputs is being provided to a Permute
98  * layer. In this case we don't want the overhead of Permuting the values on every inference, instead we
99  * want to Permute them once and store them in a Const layer to be used everytime as they will not change.
100  */
101  TensorInfo outputPermuteInfo = permuteLayer->GetOutputSlot(0).GetTensorInfo();
102  std::vector<T> newValues(outputPermuteInfo.GetNumElements());
103  armnnUtils::Permute(outputPermuteInfo.GetShape(), permuteLayer->GetPermutation(),
104  constantLayer->m_LayerOutput->Map(true), newValues.data(),
105  GetDataTypeSize(outputPermuteInfo.GetDataType()));
106 
107  TensorInfo newInfo = outputPermuteInfo;
108  newInfo.SetConstant(true);
109  ConstTensor newInput(newInfo, newValues);
110  constantLayer->m_LayerOutput.reset(new ScopedTensorHandle(newInput));
111 
112  // Moves connections in permute output to the constant layer.
113  // Permute layer will be removed if left unconnected.
114  permuteLayer->GetOutputSlot().MoveAllConnections(constantLayer->GetOutputSlot());
115 
116  // Updating the output tensor
117  constantLayer->GetOutputSlot(0).SetTensorInfo(newInfo);
118  ARMNN_ASSERT(constantLayer->GetOutputSlot(0).GetTensorInfo().IsConstant() == true);
119  }
120 };
121 
123  PermuteLayer,
125 
126 } // namespace optimizations
127 } // namespace armnn
A layer that the constant data can be bound to.
bool IsConstant() const
Definition: Tensor.cpp:509
const TensorShape & GetShape() const
Definition: Tensor.hpp:191
std::shared_ptr< ConstTensorHandle > m_LayerOutput
Layer & GetOwningLayer() const
Definition: Layer.hpp:118
typename ResolveTypeImpl< DT >::Type ResolveType
Definition: ResolveType.hpp:79
Copyright (c) 2021 ARM Limited and Contributors.
void IgnoreUnused(Ts &&...)
This layer represents a permutation operation.
DataType
Definition: Types.hpp:48
DataType GetDataType() const
Definition: Tensor.hpp:198
A tensor defined by a TensorInfo (shape and data type) and an immutable backing store.
Definition: Tensor.hpp:327
LayerType GetType() const override
Returns the armnn::LayerType of this layer.
Definition: Layer.hpp:271
#define ARMNN_ASSERT(COND)
Definition: Assert.hpp:14
const OutputSlot * GetConnectedOutputSlot() const
Definition: Layer.hpp:56
Layer & GetOwningLayer() const
Definition: Layer.hpp:53
const PermutationVector & GetPermutation() const
void SetTensorInfo(const TensorInfo &tensorInfo) override
Definition: Layer.cpp:87
void SetConstant(const bool IsConstant=true)
Marks the data corresponding to this tensor info as constant.
Definition: Tensor.cpp:514
DataType GetDataType() const
Definition: Layer.cpp:313
const OutputSlot & GetOutputSlot(unsigned int index=0) const override
Get the const output slot handle by slot index.
Definition: Layer.hpp:324
const TensorInfo & GetTensorInfo() const override
Definition: Layer.cpp:92
void MoveAllConnections(OutputSlot &destination)
Moves all connections to another OutputSlot.
Definition: Layer.cpp:145
unsigned int GetNumElements() const
Definition: Tensor.hpp:196
constexpr unsigned int GetDataTypeSize(DataType dataType)
Definition: TypesUtils.hpp:151