ArmNN
 22.05.01
RedirectMembersToConstantInputs.hpp
Go to the documentation of this file.
1 //
2 // Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #pragma once
7 
8 #include "Optimization.hpp"
9 
12 
13 namespace armnn
14 {
15 namespace optimizations
16 {
17 
19 {
20 public:
21  /// Search for layers with ConstantLayers as inputs. If the inputs are constant redirect the layers member
22  /// variable for ConstTensors (e.g. m_weights) to the data stored in the ConstantLayer it is connected to.
23  void Run(Graph& graph, Layer& layer) const
24  {
25  IgnoreUnused(graph);
26 
27  switch (layer.GetType())
28  {
30  break;
32  RedirectWeightsAndBiases<Convolution2dLayer>(&layer);
33  break;
35  RedirectWeightsAndBiases<DepthwiseConvolution2dLayer>(&layer);
36  break;
38  break;
40  RedirectWeightsAndBiases<FullyConnectedLayer>(&layer);
41  break;
42  case LayerType::Lstm:
43  break;
45  break;
46  default:
47  break;
48  }
49  }
50 
51 protected:
54 
55 private:
56  template <typename LayerT>
57  static LayerT* RedirectWeightsAndBiases(Layer* layer)
58  {
59  LayerT* layerPtr = PolymorphicDowncast<LayerT*>(layer);
60 
61  // Loop through input slots to check for constant weights and biases layers.
62  // Weights index = 1, Biases index = 2.
63  for (unsigned int inputSlotIndex = 1; inputSlotIndex != layerPtr->GetNumInputSlots(); ++inputSlotIndex)
64  {
65  OutputSlot* outputSlot = layerPtr->GetInputSlot(inputSlotIndex).GetConnectedOutputSlot();
66  if (outputSlot->GetOwningLayer().GetType() == LayerType::Constant)
67  {
68  // Get constant layer and redirect base layer member variables.
69  ConstantLayer& constantLayer = dynamic_cast<ConstantLayer&>(outputSlot->GetOwningLayer());
70  if (inputSlotIndex == 1)
71  {
72  layerPtr->m_Weight = constantLayer.m_LayerOutput;
73  }
74  else if (inputSlotIndex == 2)
75  {
76  layerPtr->m_Bias = constantLayer.m_LayerOutput;
77  }
78  }
79  }
80 
81  return layerPtr;
82  }
83 };
84 
86 
87 } // namespace optimizations
88 } // namespace armnn
A layer that the constant data can be bound to.
std::shared_ptr< ConstTensorHandle > m_LayerOutput
Layer & GetOwningLayer() const
Definition: Layer.hpp:118
Copyright (c) 2021 ARM Limited and Contributors.
void IgnoreUnused(Ts &&...)
LayerType GetType() const override
Returns the armnn::LayerType of this layer.
Definition: Layer.hpp:271
void Run(Graph &graph, Layer &layer) const
Search for layers with ConstantLayers as inputs.