aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/optimizations/FuseBatchNorm.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/optimizations/FuseBatchNorm.hpp')
-rw-r--r--src/armnn/optimizations/FuseBatchNorm.hpp84
1 files changed, 75 insertions, 9 deletions
diff --git a/src/armnn/optimizations/FuseBatchNorm.hpp b/src/armnn/optimizations/FuseBatchNorm.hpp
index 66f722a8ef..6a50fc4a0c 100644
--- a/src/armnn/optimizations/FuseBatchNorm.hpp
+++ b/src/armnn/optimizations/FuseBatchNorm.hpp
@@ -50,12 +50,28 @@ public:
ConstTensor meanTensor(batchNormLayer->m_Mean->GetTensorInfo(), batchNormLayer->m_Mean->Map(true));
ConstTensor varTensor(batchNormLayer->m_Variance->GetTensorInfo(), batchNormLayer->m_Variance->Map(true));
- auto convDescriptor = convLayer->GetParameters();
- auto weightsInfo(convLayer->m_Weight->GetTensorInfo());
- ConstTensor weightsTensor(weightsInfo, convLayer->m_Weight->Map(true));
+ auto convDescriptor = convLayer->GetParameters();
+ ConstTensor weightsTensor;
+ if (convLayer->GetNumInputSlots() > 1)
+ {
+ ARMNN_ASSERT_MSG(convLayer->GetInputSlots()[1].GetConnection() != nullptr,
+ "FuseBatchNorm: Weight data should not be null.");
+ InputSlot & oldSlotWeights = const_cast<InputSlot&>(convLayer->GetInputSlots()[1]);
+ OutputSlot & constantSlotWeights = const_cast<OutputSlot&>(*oldSlotWeights.GetConnectedOutputSlot());
+ ConstantLayer* weightLayer = PolymorphicDowncast<ConstantLayer*>(
+ &constantSlotWeights.GetOwningLayer());
+ weightsTensor = ConstTensor(weightLayer->m_LayerOutput->GetTensorInfo(),
+ weightLayer->m_LayerOutput->Map(true));
+ }
+ else
+ {
+ ARMNN_ASSERT_MSG(convLayer->m_Weight != nullptr,
+ "FuseBatchNorm: Bias data should not be null if bias is enabled.");
+ weightsTensor = ConstTensor(convLayer->m_Weight->GetTensorInfo(), convLayer->m_Weight->Map(true));
+ }
armnnUtils::DataLayoutIndexed dataLayout(convDescriptor.m_DataLayout);
- auto weightsShape = weightsInfo.GetShape();
+ auto weightsShape = weightsTensor.GetInfo().GetShape();
const unsigned int inputChannels = parentOut->GetTensorInfo().GetShape()[dataLayout.GetChannelsIndex()];
const unsigned int depthMultiplier = depthwise ? weightsShape[3] / inputChannels : 1;
const unsigned int outputChannels = depthwise ? weightsShape[3] : weightsShape[0];
@@ -116,16 +132,32 @@ public:
}
}
}
- ConstTensor fusedWeightsTensor(weightsInfo, fusedWeightsVector);
+ ConstTensor fusedWeightsTensor(weightsTensor.GetInfo(), fusedWeightsVector);
// fusedBias = (gamma * (bias - mean)) / (variance - epsilon) + beta;
std::vector<T> fusedBiasVector(outputChannels);
- if (convDescriptor.m_BiasEnabled)
+ bool biasWasEnabledBeforeOpt = convDescriptor.m_BiasEnabled;
+ if (biasWasEnabledBeforeOpt)
{
- ARMNN_ASSERT_MSG(convLayer->m_Bias != nullptr,
- "FuseBatchNorm: Bias data should not be null if bias is enabled.");
+ ConstTensor biasTensor;
+ if (convLayer->GetNumInputSlots() > 1)
+ {
+ ARMNN_ASSERT_MSG(convLayer->GetInputSlots()[2].GetConnection() != nullptr,
+ "FuseBatchNorm: Bias data should not be null if bias is enabled.");
+ InputSlot & oldSlotBias = const_cast<InputSlot&>(convLayer->GetInputSlots()[2]);
+ OutputSlot & constantSlotBias = const_cast<OutputSlot&>(*oldSlotBias.GetConnectedOutputSlot());
+ ConstantLayer* biasLayer = PolymorphicDowncast<ConstantLayer*>(
+ &constantSlotBias.GetOwningLayer());
+ biasTensor = ConstTensor(biasLayer->m_LayerOutput->GetTensorInfo(),
+ biasLayer->m_LayerOutput->Map(true));
+ }
+ else
+ {
+ ARMNN_ASSERT_MSG(convLayer->m_Bias != nullptr,
+ "FuseBatchNorm: Bias data should not be null if bias is enabled.");
+ biasTensor = ConstTensor(convLayer->m_Bias->GetTensorInfo(), convLayer->m_Bias->Map(true));
+ }
- ConstTensor biasTensor(convLayer->m_Bias->GetTensorInfo(), convLayer->m_Bias->Map(true));
const auto* biasBuffer = static_cast<const T*>(biasTensor.GetMemoryArea());
std::vector<T> biasVector(biasBuffer, biasBuffer + biasTensor.GetNumElements());
@@ -156,6 +188,40 @@ public:
newConv2dLayer.m_Weight = std::make_unique<ScopedTensorHandle>(fusedWeightsTensor);
newConv2dLayer.m_Bias = std::make_unique<ScopedTensorHandle>(ConstTensor(fusedBiasTensor));
+ // Connect weights and bias from old to new Conv2d layer
+ // This optimization will always have 3 input slots on the Conv2d base layer
+ if (newConv2dLayer.GetNumInputSlots() > 1)
+ {
+ ConstantLayer* weightLayer = PolymorphicDowncast<ConstantLayer*>(
+ &base.GetInputSlot(1).GetConnectedOutputSlot()->GetOwningLayer());
+ // Remove old connection and connect to new layer2d
+ weightLayer->GetOutputSlot(0).Disconnect(base.GetInputSlot(1));
+ weightLayer->GetOutputSlot(0).Connect(newConv2dLayer.GetInputSlot(1));
+ weightLayer->m_LayerOutput = newConv2dLayer.m_Weight;
+
+ // Move bias const layers as normal if it was enabled before the optimisation
+ ConstantLayer* biasLayer;
+ if (biasWasEnabledBeforeOpt)
+ {
+ biasLayer = PolymorphicDowncast<ConstantLayer*>(
+ &base.GetInputSlot(2).GetConnectedOutputSlot()->GetOwningLayer());
+ // Remove old connection and connect to new layer2d
+ biasLayer->GetOutputSlot(0).Disconnect(base.GetInputSlot(2));
+ biasLayer->GetOutputSlot(0).Connect(newConv2dLayer.GetInputSlot(2));
+
+ }
+ // Otherwise create a new bias layer and add to the new convolution2d
+ else
+ {
+ // Add in bias constant layer
+ biasLayer = graph.AddLayer<ConstantLayer>("Bias");
+ biasLayer->GetOutputSlot(0).SetTensorInfo(fusedBiasTensor.GetInfo());
+ biasLayer->GetOutputSlot(0).Connect(newConv2dLayer.GetInputSlot(2));
+ }
+ biasLayer->m_LayerOutput = newConv2dLayer.m_Bias;
+ }
+
+
// Reconnects with original parent.
newConv2dLayer.GetOutputSlot().MoveAllConnections(*parentOut);
// Parent is now the new convolution2d layer.