ArmNN
 21.02
FuseBatchNorm< ConvLayer, ArmnnType, T > Class Template Reference

#include <FuseBatchNorm.hpp>

Public Member Functions

void Run (Graph &graph, InputSlot &connection) const
 Run for every exclusive connection between any base Convolution layer and a child BatchNorm layer, for non-quantized layers. More...
 

Protected Member Functions

 FuseBatchNorm ()=default
 
 ~FuseBatchNorm ()=default
 

Detailed Description

template<typename ConvLayer, armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
class armnn::optimizations::FuseBatchNorm< ConvLayer, ArmnnType, T >

Definition at line 19 of file FuseBatchNorm.hpp.

Constructor & Destructor Documentation

◆ FuseBatchNorm()

◆ ~FuseBatchNorm()

~FuseBatchNorm ( )
protected default

Member Function Documentation

◆ Run()

void Run ( Graph & graph,
InputSlot & connection 
) const
inline

Run for every exclusive connection between any base Convolution layer and a child BatchNorm layer, for non-quantized layers.

The child will be removed, the base will be removed if it's left unconnected. A new Convolution layer will be added, its weights and bias will be calculated using the weights and bias of the base Convolution layer combined with the parameters of the child BatchNorm layer.

Definition at line 27 of file FuseBatchNorm.hpp.

References ARMNN_ASSERT, ARMNN_ASSERT_MSG, armnn::BatchNormalization, armnn::Convolution2d, armnn::DepthwiseConvolution2d, FuseBatchNorm< ConvLayer, ArmnnType, T >::FuseBatchNorm(), InputSlot::GetConnectedOutputSlot(), Layer::GetDataType(), Layer::GetInputSlot(), Layer::GetName(), Layer::GetOutputSlot(), InputSlot::GetOwningLayer(), OutputSlot::GetOwningLayer(), Layer::GetType(), armnn::IgnoreUnused(), Graph::InsertNewLayer(), BatchNormalizationDescriptor::m_Eps, OutputSlot::MoveAllConnections(), armnn::NHWC, and FuseBatchNorm< ConvLayer, ArmnnType, T >::~FuseBatchNorm().

28  {
29  Layer& base = connection.GetConnectedOutputSlot()->GetOwningLayer();
30  Layer& child = connection.GetOwningLayer();
31 
32  bool depthwise = (base.GetType() == LayerType::DepthwiseConvolution2d);
33 
34  ARMNN_ASSERT(base.GetType() == LayerType::Convolution2d || depthwise);
35  ARMNN_ASSERT(child.GetType() == LayerType::BatchNormalization);
36 
37  if (base.GetDataType() == ArmnnType && child.GetDataType() == ArmnnType)
38  {
39  OutputSlot* parentOut = base.GetInputSlot(0).GetConnectedOutputSlot();
40  auto convLayer = PolymorphicDowncast<ConvLayer*>(&base);
41  auto batchNormLayer = PolymorphicDowncast<BatchNormalizationLayer*>(&child);
42 
43  // Read convolution and batch norm parameters
44  BatchNormalizationDescriptor batchNormDescriptor = batchNormLayer->GetParameters();
45  auto epsilon = batchNormDescriptor.m_Eps;
46  IgnoreUnused(epsilon);
47 
48  ConstTensor betaTensor(batchNormLayer->m_Beta->GetTensorInfo(), batchNormLayer->m_Beta->Map(true));
49  ConstTensor gammaTensor(batchNormLayer->m_Gamma->GetTensorInfo(), batchNormLayer->m_Gamma->Map(true));
50  ConstTensor meanTensor(batchNormLayer->m_Mean->GetTensorInfo(), batchNormLayer->m_Mean->Map(true));
51  ConstTensor varTensor(batchNormLayer->m_Variance->GetTensorInfo(), batchNormLayer->m_Variance->Map(true));
52 
53  auto convDescriptor = convLayer->GetParameters();
54  auto weightsInfo(convLayer->m_Weight->GetTensorInfo());
55  ConstTensor weightsTensor(weightsInfo, convLayer->m_Weight->Map(true));
56 
57  armnnUtils::DataLayoutIndexed dataLayout(convDescriptor.m_DataLayout);
58  auto weightsShape = weightsInfo.GetShape();
59  const unsigned int depthMultiplier = depthwise ? weightsShape[0] : 1;
60  const unsigned int inputChannels = depthwise ? weightsShape[1] :
61  weightsShape[dataLayout.GetChannelsIndex()];
62  const unsigned int outputChannels = depthwise ? inputChannels * depthMultiplier : weightsShape[0];
63  const unsigned int weightsHeight = depthwise ? weightsShape[2] :
64  weightsShape[dataLayout.GetHeightIndex()];
65  const unsigned int weightsWidth = depthwise ? weightsShape[3] :
66  weightsShape[dataLayout.GetWidthIndex()];
67 
68  const auto* weightsBuffer = static_cast<const T*>(weightsTensor.GetMemoryArea());
69  const auto* betaBuffer = static_cast<const T*>(betaTensor.GetMemoryArea());
70  const auto* gammaBuffer = static_cast<const T*>(gammaTensor.GetMemoryArea());
71  const auto* meanBuffer = static_cast<const T*>(meanTensor.GetMemoryArea());
72  const auto* varBuffer = static_cast<const T*>(varTensor.GetMemoryArea());
73 
74  std::vector<T> weightsVector (weightsBuffer, weightsBuffer + weightsTensor.GetNumElements());
75  std::vector<T> betaVector (betaBuffer, betaBuffer + betaTensor.GetNumElements());
76  std::vector<T> gammaVector (gammaBuffer, gammaBuffer + gammaTensor.GetNumElements());
77  std::vector<T> meanVector (meanBuffer, meanBuffer + meanTensor.GetNumElements());
78  std::vector<T> varianceVector(varBuffer, varBuffer + varTensor.GetNumElements());
79 
 80  // fusedWeights = ( gamma * weights ) / sqrt( variance + epsilon );
81  std::vector<T> fusedWeightsVector(weightsVector.size());
82  unsigned int depthwiseMultiplierIdx = 0;
83 
84  for (unsigned int cInput = 0; cInput < inputChannels; ++cInput)
85  {
86  for (unsigned int cOut = 0; cOut < outputChannels; ++cOut)
87  {
88  T mult = gammaVector[cOut] / static_cast<T>(sqrtf (varianceVector[cOut] + epsilon));
89 
90  if (depthwise)
91  {
92  cInput = cOut / depthMultiplier;
93  depthwiseMultiplierIdx = cOut % depthMultiplier;
94  }
95 
96  for (unsigned int h = 0; h < weightsHeight; ++h)
97  {
98  for (unsigned int w = 0; w < weightsWidth; ++w)
99  {
100  unsigned int weightsIdx = 0;
101 
102  if (depthwise)
103  {
104  weightsIdx = depthwiseMultiplierIdx * weightsWidth * weightsHeight * inputChannels +
105  cInput * weightsWidth * weightsHeight +
106  h * weightsWidth +
107  w;
108  }
109  else if (convDescriptor.m_DataLayout == DataLayout::NHWC)
110  {
111  weightsIdx = cOut * weightsHeight * weightsWidth * inputChannels +
112  h * weightsWidth * inputChannels +
113  w * inputChannels +
114  cInput;
115  }
116  else
117  {
118  weightsIdx = cOut * weightsWidth * weightsHeight * inputChannels +
119  cInput * weightsWidth * weightsHeight +
120  h * weightsWidth +
121  w;
122  }
123  fusedWeightsVector[weightsIdx] = mult * weightsVector[weightsIdx];
124  }
125  }
126  }
127  }
128  ConstTensor fusedWeightsTensor(weightsInfo, fusedWeightsVector);
129 
 130  // fusedBias = (gamma * (bias - mean)) / sqrt(variance + epsilon) + beta;
131  std::vector<T> fusedBiasVector(outputChannels);
132  if (convDescriptor.m_BiasEnabled)
133  {
134  ARMNN_ASSERT_MSG(convLayer->m_Bias != nullptr,
135  "FuseBatchNorm: Bias data should not be null if bias is enabled.");
136 
137  ConstTensor biasTensor(convLayer->m_Bias->GetTensorInfo(), convLayer->m_Bias->Map(true));
138  const auto* biasBuffer = static_cast<const T*>(biasTensor.GetMemoryArea());
139  std::vector<T> biasVector(biasBuffer, biasBuffer + biasTensor.GetNumElements());
140 
141  for (unsigned int cOut = 0; cOut < outputChannels; ++cOut)
142  {
143  fusedBiasVector[cOut] = ((gammaVector[cOut] * (biasVector[cOut] - meanVector[cOut])) /
144  sqrtf(varianceVector[cOut] + epsilon)) + betaVector[cOut];
145  }
146  }
147  else
148  {
149  convDescriptor.m_BiasEnabled = true;
150  std::vector<T> biasVector(outputChannels, T(0));
151 
152  for (unsigned int cOut = 0; cOut < outputChannels; ++cOut)
153  {
154  fusedBiasVector[cOut] = ((gammaVector[cOut] * (biasVector[cOut] - meanVector[cOut])) /
155  sqrtf(varianceVector[cOut] + epsilon)) + betaVector[cOut];
156  }
157  }
158  ConstTensor fusedBiasTensor(TensorInfo({outputChannels}, ArmnnType), fusedBiasVector);
159 
160  // Insert the new convolution layer that has batch norm parameters fused into
161  const std::string name = std::string("fused-") + child.GetName() + std::string("-into-") + base.GetName();
162  auto& newConv2dLayer = *graph.InsertNewLayer<ConvLayer>(base.GetInputSlot(0),
163  convDescriptor,
164  name.c_str());
165  newConv2dLayer.m_Weight = std::make_unique<ScopedCpuTensorHandle>(fusedWeightsTensor);
166  newConv2dLayer.m_Bias = std::make_unique<ScopedCpuTensorHandle>(ConstTensor(fusedBiasTensor));
167 
168  // Reconnects with original parent.
169  newConv2dLayer.GetOutputSlot().MoveAllConnections(*parentOut);
170  // Parent is now the new convolution2d layer.
171  parentOut = &newConv2dLayer.GetOutputSlot();
172 
173  // Moves connections in child output to parent layer.
174  // Child layer will be removed as it's left unconnected.
175  // Base layer will be removed if left unconnected.
176  child.GetOutputSlot().MoveAllConnections(*parentOut);
177  }
178  }
void IgnoreUnused(Ts &&...)
#define ARMNN_ASSERT_MSG(COND, MSG)
Definition: Assert.hpp:15
Provides access to the appropriate indexes for Channels, Height and Width based on DataLayout...
#define ARMNN_ASSERT(COND)
Definition: Assert.hpp:14

The documentation for this class was generated from the following file: