aboutsummaryrefslogtreecommitdiff
path: root/src/core/CL/cl_kernels/batchnormalization_layer.cl
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/CL/cl_kernels/batchnormalization_layer.cl')
-rw-r--r--src/core/CL/cl_kernels/batchnormalization_layer.cl10
1 files changed, 5 insertions, 5 deletions
diff --git a/src/core/CL/cl_kernels/batchnormalization_layer.cl b/src/core/CL/cl_kernels/batchnormalization_layer.cl
index b7423d8757..f7aa5eb518 100644
--- a/src/core/CL/cl_kernels/batchnormalization_layer.cl
+++ b/src/core/CL/cl_kernels/batchnormalization_layer.cl
@@ -44,7 +44,7 @@
/** Apply batch normalization.
*
- * @param[in] input_ptr Pointer to the first source tensor. Supported data types: QS8/QS16/F32
+ * @param[in] input_ptr Pointer to the first source tensor. Supported data types: QS8/QS16/F16/F32
* @param[in] input_stride_x Stride of the first source tensor in X dimension (in bytes)
* @param[in] input_step_x input_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] input_stride_y Stride of the first source tensor in Y dimension (in bytes)
@@ -100,7 +100,7 @@ __kernel void batchnormalization_layer(TENSOR3D_DECLARATION(input),
Vector gamma = CONVERT_TO_VECTOR_STRUCT(gamma);
VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE)
- _in = 0;
+ data = 0;
VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE)
denominator = 0;
VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE)
@@ -114,13 +114,13 @@ __kernel void batchnormalization_layer(TENSOR3D_DECLARATION(input),
const int current_slice = get_global_id(2);
- _in = VLOAD(VEC_SIZE)(0, (__global DATA_TYPE *)in.ptr);
+ data = VLOAD(VEC_SIZE)(0, (__global DATA_TYPE *)in.ptr);
denominator = *((__global DATA_TYPE *)(var.ptr + current_slice * var.stride_x));
- denominator = INVSQRT_OP(ADD_OP(denominator, SQCVT_SAT(epsilon)));
+ denominator = INVSQRT_OP(ADD_OP(denominator, ((VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE))SQCVT_SAT(epsilon))));
// Calculate x bar and store results
numerator = *((__global DATA_TYPE *)(mean.ptr + current_slice * mean.stride_x));
- numerator = SUB_OP(_in, numerator);
+ numerator = SUB_OP(data, numerator);
x_bar = MUL_OP(numerator, denominator);
gamma_vec = *((__global DATA_TYPE *)(gamma.ptr + current_slice * beta.stride_x));