aboutsummaryrefslogtreecommitdiff
path: root/src/backends/reference/workloads/BatchNormImpl.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/reference/workloads/BatchNormImpl.hpp')
-rw-r--r--src/backends/reference/workloads/BatchNormImpl.hpp56
1 files changed, 56 insertions, 0 deletions
diff --git a/src/backends/reference/workloads/BatchNormImpl.hpp b/src/backends/reference/workloads/BatchNormImpl.hpp
new file mode 100644
index 0000000000..a7579c8373
--- /dev/null
+++ b/src/backends/reference/workloads/BatchNormImpl.hpp
@@ -0,0 +1,56 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include "RefWorkloadUtils.hpp"
+
+#include <armnn/Tensor.hpp>
+
+#include <cmath>
+
+namespace armnn
+{
+
+template<typename NormData>
+static void BatchNormImpl(NormData data,
+ const float* varIn,
+ const float* meanIn,
+ const float* gammaIn,
+ const float* betaIn,
+ float * outputData,
+ const float * inputData)
+{
+ const TensorInfo& inputInfo0 = GetTensorInfo(data.m_Inputs[0]);
+ for (unsigned int c = 0; c < inputInfo0.GetShape()[1]; c++)
+ {
+ float var = varIn[c];
+ float mean = meanIn[c];
+ float gamma = gammaIn[c];
+ float beta = betaIn[c];
+
+ float mult = gamma / sqrtf(var + data.m_Parameters.m_Eps);
+ float add = beta - mult * mean;
+
+ for (unsigned int n = 0; n < inputInfo0.GetShape()[0]; n++)
+ {
+ for (unsigned int j = 0; j < inputInfo0.GetShape()[2]; j++)
+ {
+ for (unsigned int i = 0; i < inputInfo0.GetShape()[3]; i++)
+ {
+ unsigned int index = i +
+ j*inputInfo0.GetShape()[3] +
+ c*inputInfo0.GetShape()[3] * inputInfo0.GetShape()[2] +
+ n*inputInfo0.GetShape()[3] * inputInfo0.GetShape()[2]
+ * inputInfo0.GetShape()[1];
+
+ outputData[index] = mult * inputData[index] + add;
+ }
+ }
+ }
+ }
+}
+
+} //namespace armnn