aboutsummaryrefslogtreecommitdiff
path: root/src/backends/reference/workloads/InstanceNorm.cpp
blob: 9d6532fa6eb7cfed72aacc0840b1ba602892ccdd (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
//
// Copyright © 2019 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include "InstanceNorm.hpp"
#include "RefWorkloadUtils.hpp"

#include <armnn/Tensor.hpp>

#include <DataLayoutIndexed.hpp>

#include <cmath>

namespace armnn
{

/// Reference implementation of Instance Normalization.
///
/// For every (batch n, channel c) pair this computes, over the H x W spatial
/// plane of that pair:
///   - the mean, and
///   - the population variance (sum of squared deviations divided by H * W),
/// and then writes, for every element of the plane:
///     out = (in - mean) * gamma / sqrt(variance + eps) + beta
///
/// gamma, beta and eps are single scalars read from the descriptor parameters.
/// The same values are used for every channel.
///
/// @param data          Queue descriptor. Supplies the input tensor info
///                      (m_Inputs[0]), the data layout (m_DataLayout) and
///                      m_Gamma / m_Beta / m_Eps.
/// @param inputDecoder  Source of input values, read as float.
/// @param outputEncoder Destination for the normalised values, written as float.
///
/// NOTE(review): the output encoder is addressed with indices computed from the
/// input shape. This assumes the output tensor has the same shape and layout as
/// the input; confirm against the workload's descriptor validation.
void InstanceNorm(const InstanceNormalizationQueueDescriptor& data,
                  Decoder<float>& inputDecoder,
                  Encoder<float>& outputEncoder)
{
    const TensorInfo& inputInfo = GetTensorInfo(data.m_Inputs[0]);
    // Bound by const reference: the shape is only read below.
    const TensorShape& inputShape = inputInfo.GetShape();

    armnnUtils::DataLayoutIndexed dataLayout(data.m_Parameters.m_DataLayout);

    const unsigned int inputBatches  = inputShape[0];
    const unsigned int inputHeight   = inputShape[dataLayout.GetHeightIndex()];
    const unsigned int inputWidth    = inputShape[dataLayout.GetWidthIndex()];
    const unsigned int inputChannels = inputShape[dataLayout.GetChannelsIndex()];

    const float beta  = data.m_Parameters.m_Beta;
    const float eps   = data.m_Parameters.m_Eps;
    const float gamma = data.m_Parameters.m_Gamma;

    // Number of elements in one H x W plane; the divisor for mean and variance.
    const float planeSize = static_cast<float>(inputHeight * inputWidth);

    for (unsigned int n = 0; n < inputBatches; ++n)
    {
        for (unsigned int c = 0; c < inputChannels; ++c)
        {
            // Calculate the mean of this (n, c) plane.
            float mean = 0.0f;
            for (unsigned int h = 0; h < inputHeight; ++h)
            {
                for (unsigned int w = 0; w < inputWidth; ++w)
                {
                    const unsigned int index = dataLayout.GetIndex(inputShape, n, c, h, w);

                    // Indexing is used here to position the decoder at `index`
                    // (the expression result is unused); Get() then reads it.
                    inputDecoder[index];
                    mean += inputDecoder.Get();
                }
            }
            mean /= planeSize;

            // Calculate the population variance in a second pass, using the mean above.
            float var = 0.0f;
            for (unsigned int h = 0; h < inputHeight; ++h)
            {
                for (unsigned int w = 0; w < inputWidth; ++w)
                {
                    const unsigned int index = dataLayout.GetIndex(inputShape, n, c, h, w);

                    inputDecoder[index];
                    const float deviation = inputDecoder.Get() - mean;
                    var += deviation * deviation;
                }
            }
            var /= planeSize;

            // Invariant for the whole plane. It is computed once here rather than
            // once per element. std::sqrt(float) returns float, so the results are
            // identical to computing it inline.
            const float stdDev = std::sqrt(var + eps);

            // Apply instance normalisation to every element of the plane.
            for (unsigned int h = 0; h < inputHeight; ++h)
            {
                for (unsigned int w = 0; w < inputWidth; ++w)
                {
                    const unsigned int index = dataLayout.GetIndex(inputShape, n, c, h, w);

                    inputDecoder[index];
                    outputEncoder[index];
                    outputEncoder.Set((inputDecoder.Get() - mean) * gamma / stdDev + beta);
                }
            }
        }
    }
}

} // namespace armnn