ArmNN
 22.02
RedirectMembersToConstantInputs.hpp
Go to the documentation of this file.
1 //
2 // Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #pragma once
7 
8 #include "Optimization.hpp"
9 
12 
13 namespace armnn
14 {
15 namespace optimizations
16 {
17 
19 {
20 public:
    /// Search for layers with ConstantLayers as inputs. If the inputs are constant, redirect the layer's member
    /// variables for ConstTensors (e.g. m_Weight) to the data stored in the ConstantLayer they are connected to.
    ///
    /// @param graph Not used by this optimization (explicitly ignored via IgnoreUnused).
    /// @param layer The layer being visited; handling is selected by layer.GetType().
    void Run(Graph& graph, Layer& layer) const
    {
        IgnoreUnused(graph);

        // Dispatch on the layer type. Only the FullyConnected path does any redirection here;
        // every other visible branch simply breaks.
        switch (layer.GetType())
        {
            // NOTE(review): the case labels for the following statements are missing from this
            // rendering of the file. As written here they precede any case label and are therefore
            // unreachable — including the RedirectWeightsAndBiases<FullyConnectedLayer> call. That call
            // presumably belongs under `case LayerType::FullyConnected:`; confirm against the real header.
            break;
            break;
            break;
            break;
            RedirectWeightsAndBiases<FullyConnectedLayer>(&layer);
            break;
            case LayerType::Lstm:
            // Lstm layers are matched but nothing is redirected for them here.
            break;
            break;
            default:
            break;
        }
    }
48 
49 protected:
52 
53 private:
54  template <typename LayerT>
55  static LayerT* RedirectWeightsAndBiases(Layer* layer)
56  {
57  LayerT* layerPtr = PolymorphicDowncast<LayerT*>(layer);
58 
59  // Loop through input slots to check for constant weights and biases layers.
60  // Weights index = 1, Biases index = 2.
61  for (unsigned int inputSlotIndex = 1; inputSlotIndex != layerPtr->GetNumInputSlots(); ++inputSlotIndex)
62  {
63  OutputSlot* outputSlot = layerPtr->GetInputSlot(inputSlotIndex).GetConnectedOutputSlot();
64  if (outputSlot->GetOwningLayer().GetType() == LayerType::Constant)
65  {
66  // Get constant layer and redirect base layer member variables.
67  ConstantLayer& constantLayer = dynamic_cast<ConstantLayer&>(outputSlot->GetOwningLayer());
68  if (inputSlotIndex == 1)
69  {
70  layerPtr->m_Weight = constantLayer.m_LayerOutput;
71  }
72  else if (inputSlotIndex == 2)
73  {
74  layerPtr->m_Bias = constantLayer.m_LayerOutput;
75  }
76  }
77  }
78 
79  return layerPtr;
80  }
81 };
82 
84 
85 } // namespace optimizations
86 } // namespace armnn
A layer that the constant data can be bound to.
std::shared_ptr< ConstTensorHandle > m_LayerOutput
Layer & GetOwningLayer() const
Definition: Layer.hpp:118
Copyright (c) 2021 ARM Limited and Contributors.
void IgnoreUnused(Ts &&...)
LayerType GetType() const override
Returns the armnn::LayerType of this layer.
Definition: Layer.hpp:270
void Run(Graph &graph, Layer &layer) const
Search for layers with ConstantLayers as inputs.