Diffstat (limited to 'src/armnn/optimizations/FuseBatchNorm.hpp')
-rw-r--r-- | src/armnn/optimizations/FuseBatchNorm.hpp | 84
1 file changed, 75 insertions(+), 9 deletions(-)
diff --git a/src/armnn/optimizations/FuseBatchNorm.hpp b/src/armnn/optimizations/FuseBatchNorm.hpp
index 66f722a8ef..6a50fc4a0c 100644
--- a/src/armnn/optimizations/FuseBatchNorm.hpp
+++ b/src/armnn/optimizations/FuseBatchNorm.hpp
@@ -50,12 +50,28 @@ public:
         ConstTensor meanTensor(batchNormLayer->m_Mean->GetTensorInfo(), batchNormLayer->m_Mean->Map(true));
         ConstTensor varTensor(batchNormLayer->m_Variance->GetTensorInfo(), batchNormLayer->m_Variance->Map(true));
 
-        auto convDescriptor = convLayer->GetParameters();
-        auto weightsInfo(convLayer->m_Weight->GetTensorInfo());
-        ConstTensor weightsTensor(weightsInfo, convLayer->m_Weight->Map(true));
+        auto convDescriptor = convLayer->GetParameters();
+        ConstTensor weightsTensor;
+        if (convLayer->GetNumInputSlots() > 1)
+        {
+            ARMNN_ASSERT_MSG(convLayer->GetInputSlots()[1].GetConnection() != nullptr,
+                             "FuseBatchNorm: Weight data should not be null.");
+            InputSlot& oldSlotWeights = const_cast<InputSlot&>(convLayer->GetInputSlots()[1]);
+            OutputSlot& constantSlotWeights = const_cast<OutputSlot&>(*oldSlotWeights.GetConnectedOutputSlot());
+            ConstantLayer* weightLayer = PolymorphicDowncast<ConstantLayer*>(
+                    &constantSlotWeights.GetOwningLayer());
+            weightsTensor = ConstTensor(weightLayer->m_LayerOutput->GetTensorInfo(),
+                                        weightLayer->m_LayerOutput->Map(true));
+        }
+        else
+        {
+            ARMNN_ASSERT_MSG(convLayer->m_Weight != nullptr,
+                             "FuseBatchNorm: Weight data should not be null.");
+            weightsTensor = ConstTensor(convLayer->m_Weight->GetTensorInfo(), convLayer->m_Weight->Map(true));
+        }
 
         armnnUtils::DataLayoutIndexed dataLayout(convDescriptor.m_DataLayout);
-        auto weightsShape = weightsInfo.GetShape();
+        auto weightsShape = weightsTensor.GetInfo().GetShape();
         const unsigned int inputChannels   = parentOut->GetTensorInfo().GetShape()[dataLayout.GetChannelsIndex()];
         const unsigned int depthMultiplier = depthwise ? weightsShape[3] / inputChannels : 1;
         const unsigned int outputChannels  = depthwise ? weightsShape[3] : weightsShape[0];
@@ -116,16 +132,32 @@ public:
                 }
             }
         }
-        ConstTensor fusedWeightsTensor(weightsInfo, fusedWeightsVector);
+        ConstTensor fusedWeightsTensor(weightsTensor.GetInfo(), fusedWeightsVector);
 
         // fusedBias = (gamma * (bias - mean)) / sqrt(variance + epsilon) + beta
         std::vector<T> fusedBiasVector(outputChannels);
-        if (convDescriptor.m_BiasEnabled)
+        bool biasWasEnabledBeforeOpt = convDescriptor.m_BiasEnabled;
+        if (biasWasEnabledBeforeOpt)
         {
-            ARMNN_ASSERT_MSG(convLayer->m_Bias != nullptr,
-                             "FuseBatchNorm: Bias data should not be null if bias is enabled.");
+            ConstTensor biasTensor;
+            if (convLayer->GetNumInputSlots() > 1)
+            {
+                ARMNN_ASSERT_MSG(convLayer->GetInputSlots()[2].GetConnection() != nullptr,
+                                 "FuseBatchNorm: Bias data should not be null if bias is enabled.");
+                InputSlot& oldSlotBias = const_cast<InputSlot&>(convLayer->GetInputSlots()[2]);
+                OutputSlot& constantSlotBias = const_cast<OutputSlot&>(*oldSlotBias.GetConnectedOutputSlot());
+                ConstantLayer* biasLayer = PolymorphicDowncast<ConstantLayer*>(
+                        &constantSlotBias.GetOwningLayer());
+                biasTensor = ConstTensor(biasLayer->m_LayerOutput->GetTensorInfo(),
+                                         biasLayer->m_LayerOutput->Map(true));
+            }
+            else
+            {
+                ARMNN_ASSERT_MSG(convLayer->m_Bias != nullptr,
+                                 "FuseBatchNorm: Bias data should not be null if bias is enabled.");
+                biasTensor = ConstTensor(convLayer->m_Bias->GetTensorInfo(), convLayer->m_Bias->Map(true));
+            }
 
-            ConstTensor biasTensor(convLayer->m_Bias->GetTensorInfo(), convLayer->m_Bias->Map(true));
             const auto* biasBuffer = static_cast<const T*>(biasTensor.GetMemoryArea());
             std::vector<T> biasVector(biasBuffer, biasBuffer + biasTensor.GetNumElements());
 
@@ -156,6 +188,40 @@ public:
 
         newConv2dLayer.m_Weight = std::make_unique<ScopedTensorHandle>(fusedWeightsTensor);
         newConv2dLayer.m_Bias   = std::make_unique<ScopedTensorHandle>(ConstTensor(fusedBiasTensor));
 
+        // Connect the weights and bias from the old to the new Conv2d layer.
+        // This path always has 3 input slots on the Conv2d base layer.
+        if (newConv2dLayer.GetNumInputSlots() > 1)
+        {
+            ConstantLayer* weightLayer = PolymorphicDowncast<ConstantLayer*>(
+                    &base.GetInputSlot(1).GetConnectedOutputSlot()->GetOwningLayer());
+            // Remove the old connection and connect to the new Conv2d layer.
+            weightLayer->GetOutputSlot(0).Disconnect(base.GetInputSlot(1));
+            weightLayer->GetOutputSlot(0).Connect(newConv2dLayer.GetInputSlot(1));
+            weightLayer->m_LayerOutput = newConv2dLayer.m_Weight;
+
+            // Move the bias const layer as normal if bias was enabled before the optimisation.
+            ConstantLayer* biasLayer;
+            if (biasWasEnabledBeforeOpt)
+            {
+                biasLayer = PolymorphicDowncast<ConstantLayer*>(
+                        &base.GetInputSlot(2).GetConnectedOutputSlot()->GetOwningLayer());
+                // Remove the old connection and connect to the new Conv2d layer.
+                biasLayer->GetOutputSlot(0).Disconnect(base.GetInputSlot(2));
+                biasLayer->GetOutputSlot(0).Connect(newConv2dLayer.GetInputSlot(2));
+
+            }
+            // Otherwise create a new bias layer and add it to the new Conv2d layer.
+            else
+            {
+                // Add in a bias constant layer.
+                biasLayer = graph.AddLayer<ConstantLayer>("Bias");
+                biasLayer->GetOutputSlot(0).SetTensorInfo(fusedBiasTensor.GetInfo());
+                biasLayer->GetOutputSlot(0).Connect(newConv2dLayer.GetInputSlot(2));
+            }
+            biasLayer->m_LayerOutput = newConv2dLayer.m_Bias;
+        }
+
         // Reconnects with original parent.
         newConv2dLayer.GetOutputSlot().MoveAllConnections(*parentOut);
         // Parent is now the new convolution2d layer.
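The arithmetic that fills fusedWeightsVector and fusedBiasVector is elided between the hunks above. As a rough illustration, here is a minimal standalone sketch of that folding step for the simple, non-depthwise case, assuming weights are laid out contiguously per output channel; the function name FoldBatchNormIntoConv and all parameter names are illustrative only, not part of the Arm NN API:

#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

struct FoldedParams
{
    std::vector<float> weights; // same layout as the input weights
    std::vector<float> bias;    // one value per output channel
};

// Folds BatchNorm parameters into convolution weights and bias:
//   fusedWeights[c] = (gamma[c] / sqrt(variance[c] + epsilon)) * weights[c]
//   fusedBias[c]    = (gamma[c] * (bias[c] - mean[c])) / sqrt(variance[c] + epsilon) + beta[c]
FoldedParams FoldBatchNormIntoConv(const std::vector<float>& weights,  // outputChannels contiguous blocks
                                   const std::vector<float>& bias,     // empty if bias was disabled
                                   const std::vector<float>& gamma,
                                   const std::vector<float>& beta,
                                   const std::vector<float>& mean,
                                   const std::vector<float>& variance,
                                   float epsilon)
{
    const std::size_t outputChannels = gamma.size();
    assert(outputChannels != 0 && weights.size() % outputChannels == 0);
    const std::size_t weightsPerChannel = weights.size() / outputChannels;

    FoldedParams folded;
    folded.weights.resize(weights.size());
    folded.bias.resize(outputChannels);

    for (std::size_t c = 0; c < outputChannels; ++c)
    {
        // Per-channel scale shared by weights and bias: gamma / sqrt(variance + epsilon).
        const float mult = gamma[c] / std::sqrt(variance[c] + epsilon);
        for (std::size_t w = 0; w < weightsPerChannel; ++w)
        {
            folded.weights[c * weightsPerChannel + w] = mult * weights[c * weightsPerChannel + w];
        }
        // A convolution without bias behaves as if its bias were zero.
        const float oldBias = bias.empty() ? 0.0f : bias[c];
        folded.bias[c] = mult * (oldBias - mean[c]) + beta[c];
    }
    return folded;
}

This sketch ignores the depth-multiplier reindexing that the real optimization performs for depthwise convolutions. Note that fusedBias[c] is generally non-zero even when the convolution had no bias of its own (the beta and mean terms remain), which is why the third hunk above must create a new bias ConstantLayer when bias was not enabled before the optimisation.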