diff options
Diffstat (limited to 'src/armnn/optimizations/FuseBatchNorm.hpp')
-rw-r--r-- | src/armnn/optimizations/FuseBatchNorm.hpp | 68 |
1 files changed, 24 insertions, 44 deletions
diff --git a/src/armnn/optimizations/FuseBatchNorm.hpp b/src/armnn/optimizations/FuseBatchNorm.hpp index 6a50fc4a0c..bca0c7d00a 100644 --- a/src/armnn/optimizations/FuseBatchNorm.hpp +++ b/src/armnn/optimizations/FuseBatchNorm.hpp @@ -14,8 +14,8 @@ namespace armnn namespace optimizations { -template <typename ConvLayer, armnn::DataType ArmnnType, - typename T = armnn::ResolveType<ArmnnType>> +template<typename ConvLayer, armnn::DataType ArmnnType, + typename T = armnn::ResolveType<ArmnnType>> class FuseBatchNorm { public: @@ -26,7 +26,7 @@ public: /// combined with the parameters of the child BatchNorm layer. void Run(Graph& graph, InputSlot& connection) const { - Layer& base = connection.GetConnectedOutputSlot()->GetOwningLayer(); + Layer& base = connection.GetConnectedOutputSlot()->GetOwningLayer(); Layer& child = connection.GetOwningLayer(); bool depthwise = (base.GetType() == LayerType::DepthwiseConvolution2d); @@ -37,7 +37,7 @@ public: if (base.GetDataType() == ArmnnType && child.GetDataType() == ArmnnType) { OutputSlot* parentOut = base.GetInputSlot(0).GetConnectedOutputSlot(); - auto convLayer = PolymorphicDowncast<ConvLayer*>(&base); + auto convLayer = PolymorphicDowncast<ConvLayer*>(&base); auto batchNormLayer = PolymorphicDowncast<BatchNormalizationLayer*>(&child); // Read convolution and batch norm parameters @@ -50,25 +50,16 @@ public: ConstTensor meanTensor(batchNormLayer->m_Mean->GetTensorInfo(), batchNormLayer->m_Mean->Map(true)); ConstTensor varTensor(batchNormLayer->m_Variance->GetTensorInfo(), batchNormLayer->m_Variance->Map(true)); - auto convDescriptor = convLayer->GetParameters(); + auto convDescriptor = convLayer->GetParameters(); ConstTensor weightsTensor; - if (convLayer->GetNumInputSlots() > 1) - { - ARMNN_ASSERT_MSG(convLayer->GetInputSlots()[1].GetConnection() != nullptr, - "FuseBatchNorm: Weight data should not be null."); - InputSlot & oldSlotWeights = const_cast<InputSlot&>(convLayer->GetInputSlots()[1]); - OutputSlot & constantSlotWeights = const_cast<OutputSlot&>(*oldSlotWeights.GetConnectedOutputSlot()); - ConstantLayer* weightLayer = PolymorphicDowncast<ConstantLayer*>( - &constantSlotWeights.GetOwningLayer()); - weightsTensor = ConstTensor(weightLayer->m_LayerOutput->GetTensorInfo(), - weightLayer->m_LayerOutput->Map(true)); - } - else - { - ARMNN_ASSERT_MSG(convLayer->m_Weight != nullptr, - "FuseBatchNorm: Bias data should not be null if bias is enabled."); - weightsTensor = ConstTensor(convLayer->m_Weight->GetTensorInfo(), convLayer->m_Weight->Map(true)); - } + ARMNN_ASSERT_MSG(convLayer->GetInputSlots()[1].GetConnection() != nullptr, + "FuseBatchNorm: Weight data should not be null."); + + ConstantLayer* weightLayer = PolymorphicDowncast<ConstantLayer*>( + &base.GetInputSlot(1).GetConnectedOutputSlot()->GetOwningLayer()); + + weightsTensor = ConstTensor(weightLayer->m_LayerOutput->GetTensorInfo(), + weightLayer->m_LayerOutput->Map(true)); armnnUtils::DataLayoutIndexed dataLayout(convDescriptor.m_DataLayout); auto weightsShape = weightsTensor.GetInfo().GetShape(); @@ -76,9 +67,9 @@ public: const unsigned int depthMultiplier = depthwise ? weightsShape[3] / inputChannels : 1; const unsigned int outputChannels = depthwise ? weightsShape[3] : weightsShape[0]; const unsigned int weightsHeight = depthwise ? weightsShape[1] : - weightsShape[dataLayout.GetHeightIndex()]; + weightsShape[dataLayout.GetHeightIndex()]; const unsigned int weightsWidth = depthwise ? weightsShape[2] : - weightsShape[dataLayout.GetWidthIndex()]; + weightsShape[dataLayout.GetWidthIndex()]; const auto* weightsBuffer = static_cast<const T*>(weightsTensor.GetMemoryArea()); const auto* betaBuffer = static_cast<const T*>(betaTensor.GetMemoryArea()); @@ -99,7 +90,7 @@ public: { for (unsigned int cOut = 0; cOut < outputChannels; ++cOut) { - T mult = gammaVector[cOut] / static_cast<T>(sqrtf (varianceVector[cOut] + epsilon)); + T mult = gammaVector[cOut] / static_cast<T>(sqrtf(varianceVector[cOut] + epsilon)); for (unsigned int h = 0; h < weightsHeight; ++h) { @@ -140,23 +131,14 @@ public: if (biasWasEnabledBeforeOpt) { ConstTensor biasTensor; - if (convLayer->GetNumInputSlots() > 1) - { - ARMNN_ASSERT_MSG(convLayer->GetInputSlots()[2].GetConnection() != nullptr, - "FuseBatchNorm: Bias data should not be null if bias is enabled."); - InputSlot & oldSlotBias = const_cast<InputSlot&>(convLayer->GetInputSlots()[2]); - OutputSlot & constantSlotBias = const_cast<OutputSlot&>(*oldSlotBias.GetConnectedOutputSlot()); - ConstantLayer* biasLayer = PolymorphicDowncast<ConstantLayer*>( - &constantSlotBias.GetOwningLayer()); - biasTensor = ConstTensor(biasLayer->m_LayerOutput->GetTensorInfo(), - biasLayer->m_LayerOutput->Map(true)); - } - else - { - ARMNN_ASSERT_MSG(convLayer->m_Bias != nullptr, - "FuseBatchNorm: Bias data should not be null if bias is enabled."); - biasTensor = ConstTensor(convLayer->m_Bias->GetTensorInfo(), convLayer->m_Bias->Map(true)); - } + ARMNN_ASSERT_MSG(convLayer->GetInputSlots()[2].GetConnection() != nullptr, + "FuseBatchNorm: Bias data should not be null if bias is enabled."); + + ConstantLayer* biasLayer = PolymorphicDowncast<ConstantLayer*>( + &base.GetInputSlot(2).GetConnectedOutputSlot()->GetOwningLayer()); + + biasTensor = ConstTensor(biasLayer->m_LayerOutput->GetTensorInfo(), + biasLayer->m_LayerOutput->Map(true)); const auto* biasBuffer = static_cast<const T*>(biasTensor.GetMemoryArea()); std::vector<T> biasVector(biasBuffer, biasBuffer + biasTensor.GetNumElements()); @@ -192,8 +174,6 @@ public: // This optimization will always have 3 input slots on the Conv2d base layer if (newConv2dLayer.GetNumInputSlots() > 1) { - ConstantLayer* weightLayer = PolymorphicDowncast<ConstantLayer*>( - &base.GetInputSlot(1).GetConnectedOutputSlot()->GetOwningLayer()); // Remove old connection and connect to new layer2d weightLayer->GetOutputSlot(0).Disconnect(base.GetInputSlot(1)); weightLayer->GetOutputSlot(0).Connect(newConv2dLayer.GetInputSlot(1)); |