aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/optimizations/FuseBatchNorm.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/optimizations/FuseBatchNorm.hpp')
-rw-r--r--src/armnn/optimizations/FuseBatchNorm.hpp25
1 files changed, 8 insertions, 17 deletions
diff --git a/src/armnn/optimizations/FuseBatchNorm.hpp b/src/armnn/optimizations/FuseBatchNorm.hpp
index 3fb4b34d28..fe8238bf14 100644
--- a/src/armnn/optimizations/FuseBatchNorm.hpp
+++ b/src/armnn/optimizations/FuseBatchNorm.hpp
@@ -56,13 +56,12 @@ public:
armnnUtils::DataLayoutIndexed dataLayout(convDescriptor.m_DataLayout);
auto weightsShape = weightsInfo.GetShape();
- const unsigned int depthMultiplier = depthwise ? weightsShape[0] : 1;
- const unsigned int inputChannels = depthwise ? weightsShape[1] :
- weightsShape[dataLayout.GetChannelsIndex()];
- const unsigned int outputChannels = depthwise ? inputChannels * depthMultiplier : weightsShape[0];
- const unsigned int weightsHeight = depthwise ? weightsShape[2] :
+ const unsigned int inputChannels = parentOut->GetTensorInfo().GetShape()[dataLayout.GetChannelsIndex()];
+ const unsigned int depthMultiplier = depthwise ? weightsShape[3] / inputChannels : 1;
+ const unsigned int outputChannels = depthwise ? weightsShape[3] : weightsShape[0];
+ const unsigned int weightsHeight = depthwise ? weightsShape[1] :
weightsShape[dataLayout.GetHeightIndex()];
- const unsigned int weightsWidth = depthwise ? weightsShape[3] :
+ const unsigned int weightsWidth = depthwise ? weightsShape[2] :
weightsShape[dataLayout.GetWidthIndex()];
const auto* weightsBuffer = static_cast<const T*>(weightsTensor.GetMemoryArea());
@@ -79,7 +78,6 @@ public:
// fusedWeights = ( gamma * weights ) / ( std - epsilon);
std::vector<T> fusedWeightsVector(weightsVector.size());
- unsigned int depthwiseMultiplierIdx = 0;
for (unsigned int cInput = 0; cInput < inputChannels; ++cInput)
{
@@ -87,12 +85,6 @@ public:
{
T mult = gammaVector[cOut] / static_cast<T>(sqrtf (varianceVector[cOut] + epsilon));
- if (depthwise)
- {
- cInput = cOut / depthMultiplier;
- depthwiseMultiplierIdx = cOut % depthMultiplier;
- }
-
for (unsigned int h = 0; h < weightsHeight; ++h)
{
for (unsigned int w = 0; w < weightsWidth; ++w)
@@ -101,10 +93,9 @@ public:
if (depthwise)
{
- weightsIdx = depthwiseMultiplierIdx * weightsWidth * weightsHeight * inputChannels +
- cInput * weightsWidth * weightsHeight +
- h * weightsWidth +
- w;
+ cInput = cOut / depthMultiplier;
+ weightsIdx = w * outputChannels + cOut +
+ h * weightsWidth * outputChannels;
}
else if (convDescriptor.m_DataLayout == DataLayout::NHWC)
{