ArmNN 21.08
FuseBatchNorm< ConvLayer, ArmnnType, T > Class Template Reference

#include <FuseBatchNorm.hpp>

Public Member Functions

void Run (Graph &graph, InputSlot &connection) const
 Run for every exclusive connection between any base Convolution layer and a child BatchNorm layer, for non-quantized layers.
 

Protected Member Functions

 FuseBatchNorm ()=default
 
 ~FuseBatchNorm ()=default
 

Detailed Description

template<typename ConvLayer, armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
class armnn::optimizations::FuseBatchNorm< ConvLayer, ArmnnType, T >

Definition at line 19 of file FuseBatchNorm.hpp.

Constructor & Destructor Documentation

◆ FuseBatchNorm()

FuseBatchNorm ( )
protected default

◆ ~FuseBatchNorm()

~FuseBatchNorm ( )
protected default

Member Function Documentation

◆ Run()

void Run ( Graph & graph,
InputSlot & connection 
) const
inline

Run for every exclusive connection between any base Convolution layer and a child BatchNorm layer, for non-quantized layers.

The child layer is removed, and the base layer is removed if it is left unconnected. A new Convolution layer is added in their place; its weights and bias are calculated by combining the weights and bias of the base Convolution layer with the parameters of the child BatchNorm layer.
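Concretely, since both the convolution and the batch normalization are affine per output channel, the fusion computes (writing \(\gamma\), \(\beta\), \(\mu\), \(\sigma^2\) for the BatchNorm scale, offset, mean and variance, and \(\epsilon\) for m_Eps):

\[
W'_{c} = \frac{\gamma_{c}}{\sqrt{\sigma^{2}_{c} + \epsilon}}\,W_{c},
\qquad
b'_{c} = \frac{\gamma_{c}\,(b_{c} - \mu_{c})}{\sqrt{\sigma^{2}_{c} + \epsilon}} + \beta_{c}
\]

for every output channel \(c\). These are exactly the fusedWeights and fusedBias values computed in the code below.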

Definition at line 27 of file FuseBatchNorm.hpp.

References ARMNN_ASSERT, ARMNN_ASSERT_MSG, armnn::BatchNormalization, armnn::Convolution2d, armnn::DepthwiseConvolution2d, FuseBatchNorm< ConvLayer, ArmnnType, T >::FuseBatchNorm(), InputSlot::GetConnectedOutputSlot(), Layer::GetDataType(), Layer::GetInputSlot(), Layer::GetName(), Layer::GetOutputSlot(), InputSlot::GetOwningLayer(), OutputSlot::GetOwningLayer(), TensorInfo::GetShape(), OutputSlot::GetTensorInfo(), Layer::GetType(), armnn::IgnoreUnused(), Graph::InsertNewLayer(), BatchNormalizationDescriptor::m_Eps, OutputSlot::MoveAllConnections(), armnn::NHWC, and FuseBatchNorm< ConvLayer, ArmnnType, T >::~FuseBatchNorm().

{
    Layer& base  = connection.GetConnectedOutputSlot()->GetOwningLayer();
    Layer& child = connection.GetOwningLayer();

    bool depthwise = (base.GetType() == LayerType::DepthwiseConvolution2d);

    ARMNN_ASSERT(base.GetType() == LayerType::Convolution2d || depthwise);
    ARMNN_ASSERT(child.GetType() == LayerType::BatchNormalization);

    if (base.GetDataType() == ArmnnType && child.GetDataType() == ArmnnType)
    {
        OutputSlot* parentOut = base.GetInputSlot(0).GetConnectedOutputSlot();
        auto convLayer      = PolymorphicDowncast<ConvLayer*>(&base);
        auto batchNormLayer = PolymorphicDowncast<BatchNormalizationLayer*>(&child);

        // Read convolution and batch norm parameters
        BatchNormalizationDescriptor batchNormDescriptor = batchNormLayer->GetParameters();
        auto epsilon = batchNormDescriptor.m_Eps;
        IgnoreUnused(epsilon);

        ConstTensor betaTensor(batchNormLayer->m_Beta->GetTensorInfo(), batchNormLayer->m_Beta->Map(true));
        ConstTensor gammaTensor(batchNormLayer->m_Gamma->GetTensorInfo(), batchNormLayer->m_Gamma->Map(true));
        ConstTensor meanTensor(batchNormLayer->m_Mean->GetTensorInfo(), batchNormLayer->m_Mean->Map(true));
        ConstTensor varTensor(batchNormLayer->m_Variance->GetTensorInfo(), batchNormLayer->m_Variance->Map(true));

        auto convDescriptor = convLayer->GetParameters();
        auto weightsInfo(convLayer->m_Weight->GetTensorInfo());
        ConstTensor weightsTensor(weightsInfo, convLayer->m_Weight->Map(true));

        armnnUtils::DataLayoutIndexed dataLayout(convDescriptor.m_DataLayout);
        auto weightsShape = weightsInfo.GetShape();
        const unsigned int inputChannels   = parentOut->GetTensorInfo().GetShape()[dataLayout.GetChannelsIndex()];
        const unsigned int depthMultiplier = depthwise ? weightsShape[3] / inputChannels : 1;
        const unsigned int outputChannels  = depthwise ? weightsShape[3] : weightsShape[0];
        const unsigned int weightsHeight   = depthwise ? weightsShape[1] :
                                             weightsShape[dataLayout.GetHeightIndex()];
        const unsigned int weightsWidth    = depthwise ? weightsShape[2] :
                                             weightsShape[dataLayout.GetWidthIndex()];

        const auto* weightsBuffer = static_cast<const T*>(weightsTensor.GetMemoryArea());
        const auto* betaBuffer    = static_cast<const T*>(betaTensor.GetMemoryArea());
        const auto* gammaBuffer   = static_cast<const T*>(gammaTensor.GetMemoryArea());
        const auto* meanBuffer    = static_cast<const T*>(meanTensor.GetMemoryArea());
        const auto* varBuffer     = static_cast<const T*>(varTensor.GetMemoryArea());

        std::vector<T> weightsVector (weightsBuffer, weightsBuffer + weightsTensor.GetNumElements());
        std::vector<T> betaVector    (betaBuffer, betaBuffer + betaTensor.GetNumElements());
        std::vector<T> gammaVector   (gammaBuffer, gammaBuffer + gammaTensor.GetNumElements());
        std::vector<T> meanVector    (meanBuffer, meanBuffer + meanTensor.GetNumElements());
        std::vector<T> varianceVector(varBuffer, varBuffer + varTensor.GetNumElements());
        // fusedWeights = (gamma * weights) / sqrt(variance + epsilon)
        std::vector<T> fusedWeightsVector(weightsVector.size());

        for (unsigned int cInput = 0; cInput < inputChannels; ++cInput)
        {
            for (unsigned int cOut = 0; cOut < outputChannels; ++cOut)
            {
                T mult = gammaVector[cOut] / static_cast<T>(sqrtf(varianceVector[cOut] + epsilon));

                for (unsigned int h = 0; h < weightsHeight; ++h)
                {
                    for (unsigned int w = 0; w < weightsWidth; ++w)
                    {
                        unsigned int weightsIdx = 0;

                        if (depthwise)
                        {
                            cInput = cOut / depthMultiplier;
                            weightsIdx = w * outputChannels + cOut +
                                         h * weightsWidth * outputChannels;
                        }
                        else if (convDescriptor.m_DataLayout == DataLayout::NHWC)
                        {
                            weightsIdx = cOut * weightsHeight * weightsWidth * inputChannels +
                                         h * weightsWidth * inputChannels +
                                         w * inputChannels +
                                         cInput;
                        }
                        else
                        {
                            weightsIdx = cOut * weightsWidth * weightsHeight * inputChannels +
                                         cInput * weightsWidth * weightsHeight +
                                         h * weightsWidth +
                                         w;
                        }
                        fusedWeightsVector[weightsIdx] = mult * weightsVector[weightsIdx];
                    }
                }
            }
        }
        ConstTensor fusedWeightsTensor(weightsInfo, fusedWeightsVector);
        // fusedBias = (gamma * (bias - mean)) / sqrt(variance + epsilon) + beta
        std::vector<T> fusedBiasVector(outputChannels);
        if (convDescriptor.m_BiasEnabled)
        {
            ARMNN_ASSERT_MSG(convLayer->m_Bias != nullptr,
                             "FuseBatchNorm: Bias data should not be null if bias is enabled.");

            ConstTensor biasTensor(convLayer->m_Bias->GetTensorInfo(), convLayer->m_Bias->Map(true));
            const auto* biasBuffer = static_cast<const T*>(biasTensor.GetMemoryArea());
            std::vector<T> biasVector(biasBuffer, biasBuffer + biasTensor.GetNumElements());

            for (unsigned int cOut = 0; cOut < outputChannels; ++cOut)
            {
                fusedBiasVector[cOut] = ((gammaVector[cOut] * (biasVector[cOut] - meanVector[cOut])) /
                                         sqrtf(varianceVector[cOut] + epsilon)) + betaVector[cOut];
            }
        }
        else
        {
            convDescriptor.m_BiasEnabled = true;
            std::vector<T> biasVector(outputChannels, T(0));

            for (unsigned int cOut = 0; cOut < outputChannels; ++cOut)
            {
                fusedBiasVector[cOut] = ((gammaVector[cOut] * (biasVector[cOut] - meanVector[cOut])) /
                                         sqrtf(varianceVector[cOut] + epsilon)) + betaVector[cOut];
            }
        }
        ConstTensor fusedBiasTensor(TensorInfo({outputChannels}, ArmnnType), fusedBiasVector);

        // Insert the new convolution layer that has the batch norm parameters fused into it
        const std::string name = std::string("fused-") + child.GetName() + std::string("-into-") + base.GetName();
        auto& newConv2dLayer = *graph.InsertNewLayer<ConvLayer>(base.GetInputSlot(0),
                                                                convDescriptor,
                                                                name.c_str());
        newConv2dLayer.m_Weight = std::make_unique<ScopedTensorHandle>(fusedWeightsTensor);
        newConv2dLayer.m_Bias   = std::make_unique<ScopedTensorHandle>(ConstTensor(fusedBiasTensor));

        // Reconnects with original parent.
        newConv2dLayer.GetOutputSlot().MoveAllConnections(*parentOut);
        // Parent is now the new convolution2d layer.
        parentOut = &newConv2dLayer.GetOutputSlot();

        // Moves connections in child output to parent layer.
        // Child layer will be removed as it's left unconnected.
        // Base layer will be removed if left unconnected.
        child.GetOutputSlot().MoveAllConnections(*parentOut);
    }
}
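For context, here is a minimal sketch of how a fuser like this is typically registered with the graph optimizer. It assumes the OptimizeForExclusiveConnection wrapper and the Optimizer::Pass / MakeOptimizations entry points found elsewhere in ArmNN; the alias name is illustrative rather than the exact one declared in FuseBatchNorm.hpp.

    // Illustrative alias: wraps FuseBatchNorm so the optimizer applies it to every
    // exclusive Convolution2d -> BatchNormalization connection in the graph.
    using FuseBatchNormIntoConv2dFp32 =
        OptimizeForExclusiveConnection<Convolution2dLayer,
                                       BatchNormalizationLayer,
                                       FuseBatchNorm<Convolution2dLayer, armnn::DataType::Float32>>;

    // Running the pass; each matching connection is handed to Run() above.
    Optimizer::Pass(graph, MakeOptimizations(FuseBatchNormIntoConv2dFp32()));

Because Run() only fires when the base and child data types both equal ArmnnType, a separate instantiation is needed per data type (for example, one each for Float32 and Float16).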

The documentation for this class was generated from the following file:

FuseBatchNorm.hpp