aboutsummaryrefslogtreecommitdiff
path: root/src/core
diff options
context:
space:
mode:
Diffstat (limited to 'src/core')
-rw-r--r--src/core/CL/cl_kernels/batchnormalization_layer.cl10
-rw-r--r--src/core/CL/kernels/CLBatchNormalizationLayerKernel.cpp2
2 files changed, 6 insertions, 6 deletions
diff --git a/src/core/CL/cl_kernels/batchnormalization_layer.cl b/src/core/CL/cl_kernels/batchnormalization_layer.cl
index b7423d8757..f7aa5eb518 100644
--- a/src/core/CL/cl_kernels/batchnormalization_layer.cl
+++ b/src/core/CL/cl_kernels/batchnormalization_layer.cl
@@ -44,7 +44,7 @@
/** Apply batch normalization.
*
- * @param[in] input_ptr Pointer to the first source tensor. Supported data types: QS8/QS16/F32
+ * @param[in] input_ptr Pointer to the first source tensor. Supported data types: QS8/QS16/F16/F32
* @param[in] input_stride_x Stride of the first source tensor in X dimension (in bytes)
* @param[in] input_step_x input_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] input_stride_y Stride of the first source tensor in Y dimension (in bytes)
@@ -100,7 +100,7 @@ __kernel void batchnormalization_layer(TENSOR3D_DECLARATION(input),
Vector gamma = CONVERT_TO_VECTOR_STRUCT(gamma);
VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE)
- _in = 0;
+ data = 0;
VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE)
denominator = 0;
VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE)
@@ -114,13 +114,13 @@ __kernel void batchnormalization_layer(TENSOR3D_DECLARATION(input),
const int current_slice = get_global_id(2);
- _in = VLOAD(VEC_SIZE)(0, (__global DATA_TYPE *)in.ptr);
+ data = VLOAD(VEC_SIZE)(0, (__global DATA_TYPE *)in.ptr);
denominator = *((__global DATA_TYPE *)(var.ptr + current_slice * var.stride_x));
- denominator = INVSQRT_OP(ADD_OP(denominator, SQCVT_SAT(epsilon)));
+ denominator = INVSQRT_OP(ADD_OP(denominator, ((VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE))SQCVT_SAT(epsilon))));
// Calculate x bar and store results
numerator = *((__global DATA_TYPE *)(mean.ptr + current_slice * mean.stride_x));
- numerator = SUB_OP(_in, numerator);
+ numerator = SUB_OP(data, numerator);
x_bar = MUL_OP(numerator, denominator);
gamma_vec = *((__global DATA_TYPE *)(gamma.ptr + current_slice * beta.stride_x));
diff --git a/src/core/CL/kernels/CLBatchNormalizationLayerKernel.cpp b/src/core/CL/kernels/CLBatchNormalizationLayerKernel.cpp
index 18c0c9721e..43f39f423f 100644
--- a/src/core/CL/kernels/CLBatchNormalizationLayerKernel.cpp
+++ b/src/core/CL/kernels/CLBatchNormalizationLayerKernel.cpp
@@ -45,7 +45,7 @@ CLBatchNormalizationLayerKernel::CLBatchNormalizationLayerKernel()
void CLBatchNormalizationLayerKernel::configure(ICLTensor *input, ICLTensor *output, const ICLTensor *mean, const ICLTensor *var, const ICLTensor *beta, const ICLTensor *gamma,
float epsilon)
{
- ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QS16, DataType::F32);
+ ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QS16, DataType::F16, DataType::F32);
_input = input;
_output = output;