Runs for every exclusive connection between any base Convolution layer and a child BatchNorm layer, for non-quantized layers.
The child will be removed, the base will be removed if it's left unconnected. A new Convolution layer will be added, its weights and bias will be calculated using the weights and bias of the base Convolution layer combined with the parameters of the child BatchNorm layer.
29 Layer& base = connection.GetConnectedOutputSlot()->GetOwningLayer();
30 Layer& child = connection.GetOwningLayer();
37 if (base.GetDataType() == ArmnnType && child.GetDataType() == ArmnnType)
39 OutputSlot* parentOut = base.GetInputSlot(0).GetConnectedOutputSlot();
40 auto convLayer = PolymorphicDowncast<ConvLayer*>(&base);
41 auto batchNormLayer = PolymorphicDowncast<BatchNormalizationLayer*>(&child);
44 BatchNormalizationDescriptor batchNormDescriptor = batchNormLayer->GetParameters();
45 auto epsilon = batchNormDescriptor.m_Eps;
48 ConstTensor betaTensor(batchNormLayer->m_Beta->GetTensorInfo(), batchNormLayer->m_Beta->Map(
true));
49 ConstTensor gammaTensor(batchNormLayer->m_Gamma->GetTensorInfo(), batchNormLayer->m_Gamma->Map(
true));
50 ConstTensor meanTensor(batchNormLayer->m_Mean->GetTensorInfo(), batchNormLayer->m_Mean->Map(
true));
51 ConstTensor varTensor(batchNormLayer->m_Variance->GetTensorInfo(), batchNormLayer->m_Variance->Map(
true));
53 auto convDescriptor = convLayer->GetParameters();
54 auto weightsInfo(convLayer->m_Weight->GetTensorInfo());
55 ConstTensor weightsTensor(weightsInfo, convLayer->m_Weight->Map(
true));
58 auto weightsShape = weightsInfo.GetShape();
59 const unsigned int inputChannels = parentOut->GetTensorInfo().GetShape()[dataLayout.GetChannelsIndex()];
60 const unsigned int depthMultiplier = depthwise ? weightsShape[3] / inputChannels : 1;
61 const unsigned int outputChannels = depthwise ? weightsShape[3] : weightsShape[0];
62 const unsigned int weightsHeight = depthwise ? weightsShape[1] :
63 weightsShape[dataLayout.GetHeightIndex()];
64 const unsigned int weightsWidth = depthwise ? weightsShape[2] :
65 weightsShape[dataLayout.GetWidthIndex()];
67 const auto* weightsBuffer =
static_cast<const T*
>(weightsTensor.GetMemoryArea());
68 const auto* betaBuffer =
static_cast<const T*
>(betaTensor.GetMemoryArea());
69 const auto* gammaBuffer =
static_cast<const T*
>(gammaTensor.GetMemoryArea());
70 const auto* meanBuffer =
static_cast<const T*
>(meanTensor.GetMemoryArea());
71 const auto* varBuffer =
static_cast<const T*
>(varTensor.GetMemoryArea());
73 std::vector<T> weightsVector (weightsBuffer, weightsBuffer + weightsTensor.GetNumElements());
74 std::vector<T> betaVector (betaBuffer, betaBuffer + betaTensor.GetNumElements());
75 std::vector<T> gammaVector (gammaBuffer, gammaBuffer + gammaTensor.GetNumElements());
76 std::vector<T> meanVector (meanBuffer, meanBuffer + meanTensor.GetNumElements());
77 std::vector<T> varianceVector(varBuffer, varBuffer + varTensor.GetNumElements());
80 std::vector<T> fusedWeightsVector(weightsVector.size());
82 for (
unsigned int cInput = 0; cInput < inputChannels; ++cInput)
84 for (
unsigned int cOut = 0; cOut < outputChannels; ++cOut)
86 T mult = gammaVector[cOut] /
static_cast<T
>(sqrtf (varianceVector[cOut] + epsilon));
88 for (
unsigned int h = 0; h < weightsHeight; ++h)
90 for (
unsigned int w = 0; w < weightsWidth; ++w)
92 unsigned int weightsIdx = 0;
96 cInput = cOut / depthMultiplier;
97 weightsIdx = w * outputChannels + cOut +
98 h * weightsWidth * outputChannels;
102 weightsIdx = cOut * weightsHeight * weightsWidth * inputChannels +
103 h * weightsWidth * inputChannels +
109 weightsIdx = cOut * weightsWidth * weightsHeight * inputChannels +
110 cInput * weightsWidth * weightsHeight +
114 fusedWeightsVector[weightsIdx] = mult * weightsVector[weightsIdx];
119 ConstTensor fusedWeightsTensor(weightsInfo, fusedWeightsVector);
122 std::vector<T> fusedBiasVector(outputChannels);
123 if (convDescriptor.m_BiasEnabled)
126 "FuseBatchNorm: Bias data should not be null if bias is enabled.");
128 ConstTensor biasTensor(convLayer->m_Bias->GetTensorInfo(), convLayer->m_Bias->Map(
true));
129 const auto* biasBuffer =
static_cast<const T*
>(biasTensor.GetMemoryArea());
130 std::vector<T> biasVector(biasBuffer, biasBuffer + biasTensor.GetNumElements());
132 for (
unsigned int cOut = 0; cOut < outputChannels; ++cOut)
134 fusedBiasVector[cOut] = ((gammaVector[cOut] * (biasVector[cOut] - meanVector[cOut])) /
135 sqrtf(varianceVector[cOut] + epsilon)) + betaVector[cOut];
140 convDescriptor.m_BiasEnabled =
true;
141 std::vector<T> biasVector(outputChannels, T(0));
143 for (
unsigned int cOut = 0; cOut < outputChannels; ++cOut)
145 fusedBiasVector[cOut] = ((gammaVector[cOut] * (biasVector[cOut] - meanVector[cOut])) /
146 sqrtf(varianceVector[cOut] + epsilon)) + betaVector[cOut];
149 ConstTensor fusedBiasTensor(TensorInfo({outputChannels}, ArmnnType, 0.0f, 0,
true), fusedBiasVector);
152 const std::string name = std::string(
"fused-") + child.GetName() + std::string(
"-into-") + base.GetName();
153 auto& newConv2dLayer = *graph.InsertNewLayer<ConvLayer>(base.GetInputSlot(0),
156 newConv2dLayer.m_Weight = std::make_unique<ScopedTensorHandle>(fusedWeightsTensor);
157 newConv2dLayer.m_Bias = std::make_unique<ScopedTensorHandle>(ConstTensor(fusedBiasTensor));
160 newConv2dLayer.GetOutputSlot().MoveAllConnections(*parentOut);
162 parentOut = &newConv2dLayer.GetOutputSlot();
167 child.GetOutputSlot().MoveAllConnections(*parentOut);
void IgnoreUnused(Ts &&...)
#define ARMNN_ASSERT_MSG(COND, MSG)
Provides access to the appropriate indexes for Channels, Height and Width based on DataLayout...
#define ARMNN_ASSERT(COND)