about | summary | refs | log | tree | commit | diff
path: root/src/core/CL/cl_kernels/batchnormalization_layer.cl
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/CL/cl_kernels/batchnormalization_layer.cl')
-rw-r--r-- src/core/CL/cl_kernels/batchnormalization_layer.cl | 21
1 files changed, 20 insertions, 1 deletions
diff --git a/src/core/CL/cl_kernels/batchnormalization_layer.cl b/src/core/CL/cl_kernels/batchnormalization_layer.cl
index fbffefb3c0..5ddeb1a6a1 100644
--- a/src/core/CL/cl_kernels/batchnormalization_layer.cl
+++ b/src/core/CL/cl_kernels/batchnormalization_layer.cl
@@ -23,6 +23,8 @@
*/
#include "helpers.h"
+#if defined(VEC_SIZE) && defined(DATA_TYPE)
+
#if defined(FIXED_POINT_POSITION)
#include "fixed_point.h"
@@ -42,6 +44,16 @@
#endif /* FIXED_POINT_POSITION */
+#if defined(LU_BRELU) /* Build-time choice of the activation applied to the batch-normalized result before it is stored */
+#define ACTIVATION_FUNC(x) CLAMP(x, (DATA_TYPE)B_VAL, (DATA_TYPE)A_VAL) /* CLAMP bounded by B_VAL and A_VAL (presumably lower/upper; confirm against CLAMP in helpers.h) */
+#elif defined(BRELU)
+#define ACTIVATION_FUNC(x) CLAMP(x, (DATA_TYPE)0, (DATA_TYPE)A_VAL) /* CLAMP bounded by 0 and A_VAL */
+#elif defined(RELU)
+#define ACTIVATION_FUNC(x) max(x, (DATA_TYPE)0) /* element-wise maximum of x and 0 */
+#else /* !LU_BRELU && !BRELU && !RELU: no activation, value passes through unchanged */
+#define ACTIVATION_FUNC(x) (x)
+#endif /* LU_BRELU / BRELU / RELU */
+
/** Apply batch normalization.
*
* @param[in] input_ptr Pointer to the first source tensor. Supported data types: QS8/QS16/F16/F32
@@ -126,6 +138,13 @@ __kernel void batchnormalization_layer(TENSOR3D_DECLARATION(input),
gamma_vec = *((__global DATA_TYPE *)(gamma.ptr + current_slice * gamma.stride_x));
beta_vec = *((__global DATA_TYPE *)(beta.ptr + current_slice * beta.stride_x));
+ VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE)
+ res = ADD_OP(MUL_OP(gamma_vec, x_bar), beta_vec);
+
+ res = ACTIVATION_FUNC(res);
+
VSTORE(VEC_SIZE)
- (ADD_OP(MUL_OP(gamma_vec, x_bar), beta_vec), 0, (__global DATA_TYPE *)out.ptr);
+ (res, 0, (__global DATA_TYPE *)out.ptr);
}
+
+#endif /* defined(VEC_SIZE) && defined(DATA_TYPE) */
\ No newline at end of file