ArmNN
 22.05.01
FuseBatchNorm< ConvLayer, ArmnnType, T > Class Template Reference

#include <FuseBatchNorm.hpp>

Public Member Functions

void Run (Graph &graph, InputSlot &connection) const
 Run for every exclusive connection between any base Convolution layer and a child BatchNorm layer for non-quantized layers. More...
 

Protected Member Functions

 FuseBatchNorm ()=default
 
 ~FuseBatchNorm ()=default
 

Detailed Description

template<typename ConvLayer, armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
class armnn::optimizations::FuseBatchNorm< ConvLayer, ArmnnType, T >

Definition at line 19 of file FuseBatchNorm.hpp.

Constructor & Destructor Documentation

◆ FuseBatchNorm()

◆ ~FuseBatchNorm()

~FuseBatchNorm ( )
protected default

Member Function Documentation

◆ Run()

void Run ( Graph & graph,
InputSlot & connection 
) const
inline

Run for every exclusive connection between any base Convolution layer and a child BatchNorm layer for non-quantized layers.

The child will be removed, the base will be removed if it's left unconnected. A new Convolution layer will be added, its weights and bias will be calculated using the weights and bias of the base Convolution layer combined with the parameters of the child BatchNorm layer.

Definition at line 27 of file FuseBatchNorm.hpp.

References Graph::AddLayer(), ARMNN_ASSERT, ARMNN_ASSERT_MSG, armnn::BatchNormalization, OutputSlot::Connect(), armnn::Convolution2d, armnn::DepthwiseConvolution2d, OutputSlot::Disconnect(), FuseBatchNorm< ConvLayer, ArmnnType, T >::FuseBatchNorm(), InputSlot::GetConnectedOutputSlot(), Layer::GetDataType(), BaseTensor< MemoryType >::GetInfo(), Layer::GetInputSlot(), Layer::GetName(), BaseTensor< MemoryType >::GetNumElements(), Layer::GetOutputSlot(), InputSlot::GetOwningLayer(), OutputSlot::GetOwningLayer(), TensorInfo::GetShape(), OutputSlot::GetTensorInfo(), Layer::GetType(), armnn::IgnoreUnused(), Graph::InsertNewLayer(), BatchNormalizationDescriptor::m_Eps, ConstantLayer::m_LayerOutput, OutputSlot::MoveAllConnections(), armnn::NHWC, OutputSlot::SetTensorInfo(), and FuseBatchNorm< ConvLayer, ArmnnType, T >::~FuseBatchNorm().

28  {
29  Layer& base = connection.GetConnectedOutputSlot()->GetOwningLayer();
30  Layer& child = connection.GetOwningLayer();
31 
32  bool depthwise = (base.GetType() == LayerType::DepthwiseConvolution2d);
33 
34  ARMNN_ASSERT(base.GetType() == LayerType::Convolution2d || depthwise);
35  ARMNN_ASSERT(child.GetType() == LayerType::BatchNormalization);
36 
37  if (base.GetDataType() == ArmnnType && child.GetDataType() == ArmnnType)
38  {
39  OutputSlot* parentOut = base.GetInputSlot(0).GetConnectedOutputSlot();
40  auto convLayer = PolymorphicDowncast<ConvLayer*>(&base);
41  auto batchNormLayer = PolymorphicDowncast<BatchNormalizationLayer*>(&child);
42 
43  // Read convolution and batch norm parameters
44  BatchNormalizationDescriptor batchNormDescriptor = batchNormLayer->GetParameters();
45  auto epsilon = batchNormDescriptor.m_Eps;
46  IgnoreUnused(epsilon);
47 
48  ConstTensor betaTensor(batchNormLayer->m_Beta->GetTensorInfo(), batchNormLayer->m_Beta->Map(true));
49  ConstTensor gammaTensor(batchNormLayer->m_Gamma->GetTensorInfo(), batchNormLayer->m_Gamma->Map(true));
50  ConstTensor meanTensor(batchNormLayer->m_Mean->GetTensorInfo(), batchNormLayer->m_Mean->Map(true));
51  ConstTensor varTensor(batchNormLayer->m_Variance->GetTensorInfo(), batchNormLayer->m_Variance->Map(true));
52 
53  auto convDescriptor = convLayer->GetParameters();
54  ConstTensor weightsTensor;
55  ARMNN_ASSERT_MSG(convLayer->GetInputSlots()[1].GetConnection() != nullptr,
56  "FuseBatchNorm: Weight data should not be null.");
57 
58  ConstantLayer* weightLayer = PolymorphicDowncast<ConstantLayer*>(
59  &base.GetInputSlot(1).GetConnectedOutputSlot()->GetOwningLayer());
60 
61  weightsTensor = ConstTensor(weightLayer->m_LayerOutput->GetTensorInfo(),
62  weightLayer->m_LayerOutput->Map(true));
63 
64  armnnUtils::DataLayoutIndexed dataLayout(convDescriptor.m_DataLayout);
65  auto weightsShape = weightsTensor.GetInfo().GetShape();
66  const unsigned int inputChannels = parentOut->GetTensorInfo().GetShape()[dataLayout.GetChannelsIndex()];
67  const unsigned int depthMultiplier = depthwise ? weightsShape[3] / inputChannels : 1;
68  const unsigned int outputChannels = depthwise ? weightsShape[3] : weightsShape[0];
69  const unsigned int weightsHeight = depthwise ? weightsShape[1] :
70  weightsShape[dataLayout.GetHeightIndex()];
71  const unsigned int weightsWidth = depthwise ? weightsShape[2] :
72  weightsShape[dataLayout.GetWidthIndex()];
73 
74  const auto* weightsBuffer = static_cast<const T*>(weightsTensor.GetMemoryArea());
75  const auto* betaBuffer = static_cast<const T*>(betaTensor.GetMemoryArea());
76  const auto* gammaBuffer = static_cast<const T*>(gammaTensor.GetMemoryArea());
77  const auto* meanBuffer = static_cast<const T*>(meanTensor.GetMemoryArea());
78  const auto* varBuffer = static_cast<const T*>(varTensor.GetMemoryArea());
79 
80  std::vector<T> weightsVector (weightsBuffer, weightsBuffer + weightsTensor.GetNumElements());
81  std::vector<T> betaVector (betaBuffer, betaBuffer + betaTensor.GetNumElements());
82  std::vector<T> gammaVector (gammaBuffer, gammaBuffer + gammaTensor.GetNumElements());
83  std::vector<T> meanVector (meanBuffer, meanBuffer + meanTensor.GetNumElements());
84  std::vector<T> varianceVector(varBuffer, varBuffer + varTensor.GetNumElements());
85 
86  // fusedWeights = ( gamma * weights ) / ( std - epsilon);
87  std::vector<T> fusedWeightsVector(weightsVector.size());
88 
89  for (unsigned int cInput = 0; cInput < inputChannels; ++cInput)
90  {
91  for (unsigned int cOut = 0; cOut < outputChannels; ++cOut)
92  {
93  T mult = gammaVector[cOut] / static_cast<T>(sqrtf(varianceVector[cOut] + epsilon));
94 
95  for (unsigned int h = 0; h < weightsHeight; ++h)
96  {
97  for (unsigned int w = 0; w < weightsWidth; ++w)
98  {
99  unsigned int weightsIdx = 0;
100 
101  if (depthwise)
102  {
103  cInput = cOut / depthMultiplier;
104  weightsIdx = w * outputChannels + cOut +
105  h * weightsWidth * outputChannels;
106  }
107  else if (convDescriptor.m_DataLayout == DataLayout::NHWC)
108  {
109  weightsIdx = cOut * weightsHeight * weightsWidth * inputChannels +
110  h * weightsWidth * inputChannels +
111  w * inputChannels +
112  cInput;
113  }
114  else
115  {
116  weightsIdx = cOut * weightsWidth * weightsHeight * inputChannels +
117  cInput * weightsWidth * weightsHeight +
118  h * weightsWidth +
119  w;
120  }
121  fusedWeightsVector[weightsIdx] = mult * weightsVector[weightsIdx];
122  }
123  }
124  }
125  }
126  ConstTensor fusedWeightsTensor(weightsTensor.GetInfo(), fusedWeightsVector);
127 
128  // fusedBias = (gamma * (bias - mean)) / (variance - epsilon) + beta;
129  std::vector<T> fusedBiasVector(outputChannels);
130  bool biasWasEnabledBeforeOpt = convDescriptor.m_BiasEnabled;
131  if (biasWasEnabledBeforeOpt)
132  {
133  ConstTensor biasTensor;
134  ARMNN_ASSERT_MSG(convLayer->GetInputSlots()[2].GetConnection() != nullptr,
135  "FuseBatchNorm: Bias data should not be null if bias is enabled.");
136 
137  ConstantLayer* biasLayer = PolymorphicDowncast<ConstantLayer*>(
138  &base.GetInputSlot(2).GetConnectedOutputSlot()->GetOwningLayer());
139 
140  biasTensor = ConstTensor(biasLayer->m_LayerOutput->GetTensorInfo(),
141  biasLayer->m_LayerOutput->Map(true));
142 
143  const auto* biasBuffer = static_cast<const T*>(biasTensor.GetMemoryArea());
144  std::vector<T> biasVector(biasBuffer, biasBuffer + biasTensor.GetNumElements());
145 
146  for (unsigned int cOut = 0; cOut < outputChannels; ++cOut)
147  {
148  fusedBiasVector[cOut] = ((gammaVector[cOut] * (biasVector[cOut] - meanVector[cOut])) /
149  sqrtf(varianceVector[cOut] + epsilon)) + betaVector[cOut];
150  }
151  }
152  else
153  {
154  convDescriptor.m_BiasEnabled = true;
155  std::vector<T> biasVector(outputChannels, T(0));
156 
157  for (unsigned int cOut = 0; cOut < outputChannels; ++cOut)
158  {
159  fusedBiasVector[cOut] = ((gammaVector[cOut] * (biasVector[cOut] - meanVector[cOut])) /
160  sqrtf(varianceVector[cOut] + epsilon)) + betaVector[cOut];
161  }
162  }
163  ConstTensor fusedBiasTensor(TensorInfo({outputChannels}, ArmnnType, 0.0f, 0, true), fusedBiasVector);
164 
165  // Insert the new convolution layer that has batch norm parameters fused into
166  const std::string name = std::string("fused-") + child.GetName() + std::string("-into-") + base.GetName();
167  auto& newConv2dLayer = *graph.InsertNewLayer<ConvLayer>(base.GetInputSlot(0),
168  convDescriptor,
169  name.c_str());
170  newConv2dLayer.m_Weight = std::make_unique<ScopedTensorHandle>(fusedWeightsTensor);
171  newConv2dLayer.m_Bias = std::make_unique<ScopedTensorHandle>(ConstTensor(fusedBiasTensor));
172 
173  // Connect weights and bias from old to new Conv2d layer
174  // This optimization will always have 3 input slots on the Conv2d base layer
175  if (newConv2dLayer.GetNumInputSlots() > 1)
176  {
177  // Remove old connection and connect to new layer2d
178  weightLayer->GetOutputSlot(0).Disconnect(base.GetInputSlot(1));
179  weightLayer->GetOutputSlot(0).Connect(newConv2dLayer.GetInputSlot(1));
180  weightLayer->m_LayerOutput = newConv2dLayer.m_Weight;
181 
182  // Move bias const layers as normal if it was enabled before the optimisation
183  ConstantLayer* biasLayer;
184  if (biasWasEnabledBeforeOpt)
185  {
186  biasLayer = PolymorphicDowncast<ConstantLayer*>(
187  &base.GetInputSlot(2).GetConnectedOutputSlot()->GetOwningLayer());
188  // Remove old connection and connect to new layer2d
189  biasLayer->GetOutputSlot(0).Disconnect(base.GetInputSlot(2));
190  biasLayer->GetOutputSlot(0).Connect(newConv2dLayer.GetInputSlot(2));
191 
192  }
193  // Otherwise create a new bias layer and add to the new convolution2d
194  else
195  {
196  // Add in bias constant layer
197  biasLayer = graph.AddLayer<ConstantLayer>("Bias");
198  biasLayer->GetOutputSlot(0).SetTensorInfo(fusedBiasTensor.GetInfo());
199  biasLayer->GetOutputSlot(0).Connect(newConv2dLayer.GetInputSlot(2));
200  }
201  biasLayer->m_LayerOutput = newConv2dLayer.m_Bias;
202  }
203 
204 
205  // Reconnects with original parent.
206  newConv2dLayer.GetOutputSlot().MoveAllConnections(*parentOut);
207  // Parent is now the new convolution2d layer.
208  parentOut = &newConv2dLayer.GetOutputSlot();
209 
210  // Moves connections in child output to parent layer.
211  // Child layer will be removed as it's left unconnected.
212  // Base layer will be removed if left unconnected.
213  child.GetOutputSlot().MoveAllConnections(*parentOut);
214  }
215  }
void IgnoreUnused(Ts &&...)
#define ARMNN_ASSERT_MSG(COND, MSG)
Definition: Assert.hpp:15
Provides access to the appropriate indexes for Channels, Height and Width based on DataLayout...
#define ARMNN_ASSERT(COND)
Definition: Assert.hpp:14

The documentation for this class was generated from the following file: