From 90231b8c9f680d323e4b93dcd0820a47925e6d24 Mon Sep 17 00:00:00 2001 From: Mike Kelly Date: Thu, 5 Nov 2020 15:44:56 +0000 Subject: IVGCVSW-5315 Create FuseBatchNorm class Signed-off-by: Teresa Charlin Signed-off-by: Mike Kelly Change-Id: Id0625c58dbeea79874bf986b70d136ed9390bf83 --- src/armnn/optimizations/FuseBatchNorm.hpp | 125 +++++++++++++++++++++--------- 1 file changed, 89 insertions(+), 36 deletions(-) (limited to 'src/armnn/optimizations/FuseBatchNorm.hpp') diff --git a/src/armnn/optimizations/FuseBatchNorm.hpp b/src/armnn/optimizations/FuseBatchNorm.hpp index e8e8c5d77f..9d25379930 100644 --- a/src/armnn/optimizations/FuseBatchNorm.hpp +++ b/src/armnn/optimizations/FuseBatchNorm.hpp @@ -7,13 +7,15 @@ #include "Optimization.hpp" #include +#include namespace armnn { namespace optimizations { -template +template > class FuseBatchNorm { public: @@ -27,10 +29,12 @@ public: Layer& base = connection.GetConnectedOutputSlot()->GetOwningLayer(); Layer& child = connection.GetOwningLayer(); - ARMNN_ASSERT(base.GetType() == LayerType::Convolution2d); + bool depthwise = (base.GetType() == LayerType::DepthwiseConvolution2d); + + ARMNN_ASSERT(base.GetType() == LayerType::Convolution2d || depthwise); ARMNN_ASSERT(child.GetType() == LayerType::BatchNormalization); - if (base.GetDataType() == DataType::Float32 && child.GetDataType() == DataType::Float32) + if (base.GetDataType() == ArmnnType && child.GetDataType() == ArmnnType) { OutputSlot* parentOut = base.GetInputSlot(0).GetConnectedOutputSlot(); auto convLayer = PolymorphicDowncast(&base); @@ -47,58 +51,92 @@ public: ConstTensor varTensor(batchNormLayer->m_Variance->GetTensorInfo(), batchNormLayer->m_Variance->Map(true)); auto convDescriptor = convLayer->GetParameters(); - ConstTensor weightsTensor(convLayer->m_Weight->GetTensorInfo(), convLayer->m_Weight->Map(true)); + auto weightsInfo(convLayer->m_Weight->GetTensorInfo()); + ConstTensor weightsTensor(weightsInfo, convLayer->m_Weight->Map(true)); armnnUtils::DataLayoutIndexed dataLayout(convDescriptor.m_DataLayout); - auto weightsShape = convLayer->m_Weight->GetTensorInfo().GetShape(); - const unsigned int outputChannels = weightsShape[0]; - const unsigned int inputChannels = weightsShape[dataLayout.GetChannelsIndex()]; - const unsigned int weightsHeight = weightsShape[dataLayout.GetHeightIndex()]; - const unsigned int weightsWidth = weightsShape[dataLayout.GetWidthIndex()]; - - const auto* weightsBuffer = static_cast(weightsTensor.GetMemoryArea()); - const auto* betaBuffer = static_cast(betaTensor.GetMemoryArea()); - const auto* gammaBuffer = static_cast(gammaTensor.GetMemoryArea()); - const auto* meanBuffer = static_cast(meanTensor.GetMemoryArea()); - const auto* varBuffer = static_cast(varTensor.GetMemoryArea()); - - std::vector weightsVector (weightsBuffer, weightsBuffer + weightsTensor.GetNumElements()); - std::vector betaVector (betaBuffer, betaBuffer + betaTensor.GetNumElements()); - std::vector gammaVector (gammaBuffer, gammaBuffer + gammaTensor.GetNumElements()); - std::vector meanVector (meanBuffer, meanBuffer + meanTensor.GetNumElements()); - std::vector varianceVector(varBuffer, varBuffer + varTensor.GetNumElements()); + auto weightsShape = weightsInfo.GetShape(); + const unsigned int depthMultiplier = depthwise ? weightsShape[0] : 1; + const unsigned int inputChannels = depthwise ? weightsShape[1] : + weightsShape[dataLayout.GetChannelsIndex()]; + const unsigned int outputChannels = depthwise ? inputChannels * depthMultiplier : weightsShape[0]; + const unsigned int weightsHeight = depthwise ? weightsShape[2] : + weightsShape[dataLayout.GetHeightIndex()]; + const unsigned int weightsWidth = depthwise ? weightsShape[3] : + weightsShape[dataLayout.GetWidthIndex()]; + + const auto* weightsBuffer = static_cast(weightsTensor.GetMemoryArea()); + const auto* betaBuffer = static_cast(betaTensor.GetMemoryArea()); + const auto* gammaBuffer = static_cast(gammaTensor.GetMemoryArea()); + const auto* meanBuffer = static_cast(meanTensor.GetMemoryArea()); + const auto* varBuffer = static_cast(varTensor.GetMemoryArea()); + + std::vector weightsVector (weightsBuffer, weightsBuffer + weightsTensor.GetNumElements()); + std::vector betaVector (betaBuffer, betaBuffer + betaTensor.GetNumElements()); + std::vector gammaVector (gammaBuffer, gammaBuffer + gammaTensor.GetNumElements()); + std::vector meanVector (meanBuffer, meanBuffer + meanTensor.GetNumElements()); + std::vector varianceVector(varBuffer, varBuffer + varTensor.GetNumElements()); // fusedWeights = ( gamma * weights ) / ( std - epsilon); - std::vector fusedWeightsVector(weightsVector.size()); + std::vector fusedWeightsVector(weightsVector.size()); + unsigned int depthwiseMultiplierIdx = 0; - unsigned int i = 0; - for (unsigned int cOut = 0; cOut < outputChannels; ++cOut) + for (unsigned int cInput = 0; cInput < inputChannels; ++cInput) { - auto mult = gammaVector[cOut] / sqrtf (varianceVector[cOut] + epsilon); - for (unsigned int cInput = 0; cInput < inputChannels; ++cInput) + for (unsigned int cOut = 0; cOut < outputChannels; ++cOut) { + T mult = gammaVector[cOut] / static_cast(sqrtf (varianceVector[cOut] + epsilon)); + + if (depthwise) + { + cInput = cOut / depthMultiplier; + depthwiseMultiplierIdx = cOut % depthMultiplier; + } + for (unsigned int h = 0; h < weightsHeight; ++h) { for (unsigned int w = 0; w < weightsWidth; ++w) { - fusedWeightsVector[i] = mult * weightsVector[i]; - i++; + unsigned int weightsIdx = 0; + + if (depthwise) + { + weightsIdx = depthwiseMultiplierIdx * weightsWidth * weightsHeight * inputChannels + + cInput * weightsWidth * weightsHeight + + h * weightsWidth + + w; + } + else if (convDescriptor.m_DataLayout == DataLayout::NHWC) + { + weightsIdx = cOut * weightsHeight * weightsWidth * inputChannels + + h * weightsWidth * inputChannels + + w * inputChannels + + cInput; + } + else + { + weightsIdx = cOut * weightsWidth * weightsHeight * inputChannels + + cInput * weightsWidth * weightsHeight + + h * weightsWidth + + w; + } + fusedWeightsVector[weightsIdx] = mult * weightsVector[weightsIdx]; } } } } - ConstTensor fusedWeightsTensor(convLayer->m_Weight->GetTensorInfo(), fusedWeightsVector); + ConstTensor fusedWeightsTensor(weightsInfo, fusedWeightsVector); // fusedBias = (gamma * (bias - mean)) / (variance - epsilon) + beta; - std::vector fusedBiasVector(outputChannels); + std::vector fusedBiasVector(outputChannels); if (convDescriptor.m_BiasEnabled) { ARMNN_ASSERT_MSG(convLayer->m_Bias != nullptr, "FuseBatchNorm: Bias data should not be null if bias is enabled."); ConstTensor biasTensor(convLayer->m_Bias->GetTensorInfo(), convLayer->m_Bias->Map(true)); - const auto* biasBuffer = static_cast(biasTensor.GetMemoryArea()); - std::vector biasVector(biasBuffer, biasBuffer + biasTensor.GetNumElements()); + const auto* biasBuffer = static_cast(biasTensor.GetMemoryArea()); + std::vector biasVector(biasBuffer, biasBuffer + biasTensor.GetNumElements()); for (unsigned int cOut = 0; cOut < outputChannels; ++cOut) { @@ -109,7 +147,7 @@ public: else { convDescriptor.m_BiasEnabled = true; - std::vector biasVector(outputChannels, 0); + std::vector biasVector(outputChannels, T(0)); for (unsigned int cOut = 0; cOut < outputChannels; ++cOut) { @@ -117,7 +155,7 @@ public: sqrtf(varianceVector[cOut] + epsilon)) + betaVector[cOut]; } } - ConstTensor fusedBiasTensor(TensorInfo({outputChannels}, DataType::Float32), fusedBiasVector); + ConstTensor fusedBiasTensor(TensorInfo({outputChannels}, ArmnnType), fusedBiasVector); // Insert the new convolution layer that has batch norm parameters fused into const std::string name = std::string("fused-") + child.GetName() + std::string("-into-") + base.GetName(); @@ -143,10 +181,25 @@ protected: ~FuseBatchNorm() = default; }; -using FuseBatchNormIntoConvolution2D = +using FuseBatchNormIntoConvolution2DFloat32 = OptimizeForExclusiveConnection>; + FuseBatchNorm>; + +using FuseBatchNormIntoConvolution2DFloat16 = + OptimizeForExclusiveConnection>; + +using FuseBatchNormIntoDepthwiseConvolution2DFloat32 = + OptimizeForExclusiveConnection>; + +using FuseBatchNormIntoDepthwiseConvolution2DFloat16 = + OptimizeForExclusiveConnection>; } // namespace optimizations } // namespace armnn \ No newline at end of file -- cgit v1.2.1