aboutsummaryrefslogtreecommitdiff
path: root/src/backends/reference/workloads/InstanceNorm.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/reference/workloads/InstanceNorm.cpp')
-rw-r--r--src/backends/reference/workloads/InstanceNorm.cpp86
1 files changed, 86 insertions, 0 deletions
diff --git a/src/backends/reference/workloads/InstanceNorm.cpp b/src/backends/reference/workloads/InstanceNorm.cpp
new file mode 100644
index 0000000000..9d6532fa6e
--- /dev/null
+++ b/src/backends/reference/workloads/InstanceNorm.cpp
@@ -0,0 +1,86 @@
+//
+// Copyright © 2019 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "InstanceNorm.hpp"
+#include "RefWorkloadUtils.hpp"
+
+#include <armnn/Tensor.hpp>
+
+#include <DataLayoutIndexed.hpp>
+
+#include <cmath>
+
+namespace armnn
+{
+
+void InstanceNorm(const InstanceNormalizationQueueDescriptor& data,
+ Decoder<float>& inputDecoder,
+ Encoder<float>& outputEncoder)
+{
+ const TensorInfo& inputInfo = GetTensorInfo(data.m_Inputs[0]);
+ const TensorShape inputShape = inputInfo.GetShape();
+
+ armnnUtils::DataLayoutIndexed dataLayout(data.m_Parameters.m_DataLayout);
+
+ unsigned int inputBatches = inputShape[0];
+ unsigned int inputHeight = inputShape[dataLayout.GetHeightIndex()];
+ unsigned int inputWidth = inputShape[dataLayout.GetWidthIndex()];
+ unsigned int inputChannels = inputShape[dataLayout.GetChannelsIndex()];
+
+ float beta = data.m_Parameters.m_Beta;
+ float eps = data.m_Parameters.m_Eps;
+ float gamma = data.m_Parameters.m_Gamma;
+
+ for (unsigned int n = 0; n < inputBatches; ++n)
+ {
+ for (unsigned int c = 0; c < inputChannels; ++c)
+ {
+ float mean = 0, var = 0;
+
+ //Calculate Mean
+ for (unsigned int h = 0; h < inputHeight; h++)
+ {
+ for (unsigned int w = 0; w < inputWidth; w++)
+ {
+ unsigned int index = dataLayout.GetIndex(inputShape, n, c, h, w);
+
+ inputDecoder[index];
+ float value = inputDecoder.Get();
+ mean += value;
+ }
+ }
+ mean /= static_cast<float>(inputHeight * inputWidth);
+
+ //Calculate Variance
+ for (unsigned int h = 0; h < inputHeight; h++)
+ {
+ for (unsigned int w = 0; w < inputWidth; w++)
+ {
+ unsigned int index = dataLayout.GetIndex(inputShape, n, c, h, w);
+
+ inputDecoder[index];
+ float value = inputDecoder.Get();
+ var += (value - mean) * (value - mean);
+ }
+ }
+ var /= static_cast<float>(inputHeight * inputWidth);
+
+ // Apply Instance Normalisation
+ for (unsigned int h = 0; h < inputHeight; ++h)
+ {
+ for (unsigned int w = 0; w < inputWidth; ++w)
+ {
+ unsigned int index = dataLayout.GetIndex(inputShape, n, c, h, w);
+ inputDecoder[index];
+ outputEncoder[index];
+ outputEncoder.Set((inputDecoder.Get() - mean) * gamma / std::sqrt ( var + eps) + beta);
+ }
+
+ }
+ }
+ }
+}
+
+} // namespace armnn