diff options
Diffstat (limited to 'src/armnn/optimizations/FuseBatchNorm.hpp')
-rw-r--r-- | src/armnn/optimizations/FuseBatchNorm.hpp | 25 |
1 files changed, 8 insertions, 17 deletions
diff --git a/src/armnn/optimizations/FuseBatchNorm.hpp b/src/armnn/optimizations/FuseBatchNorm.hpp index 3fb4b34d28..fe8238bf14 100644 --- a/src/armnn/optimizations/FuseBatchNorm.hpp +++ b/src/armnn/optimizations/FuseBatchNorm.hpp @@ -56,13 +56,12 @@ public: armnnUtils::DataLayoutIndexed dataLayout(convDescriptor.m_DataLayout); auto weightsShape = weightsInfo.GetShape(); - const unsigned int depthMultiplier = depthwise ? weightsShape[0] : 1; - const unsigned int inputChannels = depthwise ? weightsShape[1] : - weightsShape[dataLayout.GetChannelsIndex()]; - const unsigned int outputChannels = depthwise ? inputChannels * depthMultiplier : weightsShape[0]; - const unsigned int weightsHeight = depthwise ? weightsShape[2] : + const unsigned int inputChannels = parentOut->GetTensorInfo().GetShape()[dataLayout.GetChannelsIndex()]; + const unsigned int depthMultiplier = depthwise ? weightsShape[3] / inputChannels : 1; + const unsigned int outputChannels = depthwise ? weightsShape[3] : weightsShape[0]; + const unsigned int weightsHeight = depthwise ? weightsShape[1] : weightsShape[dataLayout.GetHeightIndex()]; - const unsigned int weightsWidth = depthwise ? weightsShape[3] : + const unsigned int weightsWidth = depthwise ? weightsShape[2] : weightsShape[dataLayout.GetWidthIndex()]; const auto* weightsBuffer = static_cast<const T*>(weightsTensor.GetMemoryArea()); @@ -79,7 +78,6 @@ public: // fusedWeights = ( gamma * weights ) / ( std - epsilon); std::vector<T> fusedWeightsVector(weightsVector.size()); - unsigned int depthwiseMultiplierIdx = 0; for (unsigned int cInput = 0; cInput < inputChannels; ++cInput) { @@ -87,12 +85,6 @@ public: { T mult = gammaVector[cOut] / static_cast<T>(sqrtf (varianceVector[cOut] + epsilon)); - if (depthwise) - { - cInput = cOut / depthMultiplier; - depthwiseMultiplierIdx = cOut % depthMultiplier; - } - for (unsigned int h = 0; h < weightsHeight; ++h) { for (unsigned int w = 0; w < weightsWidth; ++w) @@ -101,10 +93,9 @@ public: if (depthwise) { - weightsIdx = depthwiseMultiplierIdx * weightsWidth * weightsHeight * inputChannels + - cInput * weightsWidth * weightsHeight + - h * weightsWidth + - w; + cInput = cOut / depthMultiplier; + weightsIdx = w * outputChannels + cOut + + h * weightsWidth * outputChannels; } else if (convDescriptor.m_DataLayout == DataLayout::NHWC) { |