aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/core/CL/cl_kernels/batchnormalization_layer.cl10
-rw-r--r--src/core/CL/kernels/CLBatchNormalizationLayerKernel.cpp2
-rw-r--r--tests/validation/CL/BatchNormalizationLayer.cpp22
3 files changed, 23 insertions, 11 deletions
diff --git a/src/core/CL/cl_kernels/batchnormalization_layer.cl b/src/core/CL/cl_kernels/batchnormalization_layer.cl
index b7423d8757..f7aa5eb518 100644
--- a/src/core/CL/cl_kernels/batchnormalization_layer.cl
+++ b/src/core/CL/cl_kernels/batchnormalization_layer.cl
@@ -44,7 +44,7 @@
/** Apply batch normalization.
*
- * @param[in] input_ptr Pointer to the first source tensor. Supported data types: QS8/QS16/F32
+ * @param[in] input_ptr Pointer to the first source tensor. Supported data types: QS8/QS16/F16/F32
* @param[in] input_stride_x Stride of the first source tensor in X dimension (in bytes)
* @param[in] input_step_x input_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] input_stride_y Stride of the first source tensor in Y dimension (in bytes)
@@ -100,7 +100,7 @@ __kernel void batchnormalization_layer(TENSOR3D_DECLARATION(input),
Vector gamma = CONVERT_TO_VECTOR_STRUCT(gamma);
VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE)
- _in = 0;
+ data = 0;
VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE)
denominator = 0;
VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE)
@@ -114,13 +114,13 @@ __kernel void batchnormalization_layer(TENSOR3D_DECLARATION(input),
const int current_slice = get_global_id(2);
- _in = VLOAD(VEC_SIZE)(0, (__global DATA_TYPE *)in.ptr);
+ data = VLOAD(VEC_SIZE)(0, (__global DATA_TYPE *)in.ptr);
denominator = *((__global DATA_TYPE *)(var.ptr + current_slice * var.stride_x));
- denominator = INVSQRT_OP(ADD_OP(denominator, SQCVT_SAT(epsilon)));
+ denominator = INVSQRT_OP(ADD_OP(denominator, ((VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE))SQCVT_SAT(epsilon))));
// Calculate x bar and store results
numerator = *((__global DATA_TYPE *)(mean.ptr + current_slice * mean.stride_x));
- numerator = SUB_OP(_in, numerator);
+ numerator = SUB_OP(data, numerator);
x_bar = MUL_OP(numerator, denominator);
gamma_vec = *((__global DATA_TYPE *)(gamma.ptr + current_slice * beta.stride_x));
diff --git a/src/core/CL/kernels/CLBatchNormalizationLayerKernel.cpp b/src/core/CL/kernels/CLBatchNormalizationLayerKernel.cpp
index 18c0c9721e..43f39f423f 100644
--- a/src/core/CL/kernels/CLBatchNormalizationLayerKernel.cpp
+++ b/src/core/CL/kernels/CLBatchNormalizationLayerKernel.cpp
@@ -45,7 +45,7 @@ CLBatchNormalizationLayerKernel::CLBatchNormalizationLayerKernel()
void CLBatchNormalizationLayerKernel::configure(ICLTensor *input, ICLTensor *output, const ICLTensor *mean, const ICLTensor *var, const ICLTensor *beta, const ICLTensor *gamma,
float epsilon)
{
- ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QS16, DataType::F32);
+ ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QS16, DataType::F16, DataType::F32);
_input = input;
_output = output;
diff --git a/tests/validation/CL/BatchNormalizationLayer.cpp b/tests/validation/CL/BatchNormalizationLayer.cpp
index ac30c638b5..69f8d7b635 100644
--- a/tests/validation/CL/BatchNormalizationLayer.cpp
+++ b/tests/validation/CL/BatchNormalizationLayer.cpp
@@ -43,9 +43,10 @@ namespace validation
{
namespace
{
-constexpr AbsoluteTolerance<float> tolerance_f(0.00001f); /**< Tolerance value for comparing reference's output against implementation's output for DataType::F32 */
-constexpr AbsoluteTolerance<float> tolerance_qs8(3.0f); /**< Tolerance value for comparing reference's output against implementation's output for DataType::QS8 */
-constexpr AbsoluteTolerance<float> tolerance_qs16(6.0f); /**< Tolerance value for comparing reference's output against implementation's output for DataType::QS16 */
+constexpr AbsoluteTolerance<float> tolerance_f32(0.00001f); /**< Tolerance value for comparing reference's output against implementation's output for DataType::F32 */
+constexpr AbsoluteTolerance<float> tolerance_f16(0.01f); /**< Tolerance value for comparing reference's output against implementation's output for DataType::F16 */
+constexpr AbsoluteTolerance<float> tolerance_qs8(3.0f); /**< Tolerance value for comparing reference's output against implementation's output for DataType::QS8 */
+constexpr AbsoluteTolerance<float> tolerance_qs16(6.0f); /**< Tolerance value for comparing reference's output against implementation's output for DataType::QS16 */
} // namespace
TEST_SUITE(CL)
@@ -54,7 +55,7 @@ TEST_SUITE(BatchNormalizationLayer)
template <typename T>
using CLBatchNormalizationLayerFixture = BatchNormalizationLayerValidationFixture<CLTensor, CLAccessor, CLBatchNormalizationLayer, T>;
-DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(datasets::RandomBatchNormalizationLayerDataset(), framework::dataset::make("DataType", { DataType::QS8, DataType::QS16, DataType::F32 })),
+DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(datasets::RandomBatchNormalizationLayerDataset(), framework::dataset::make("DataType", { DataType::QS8, DataType::QS16, DataType::F16, DataType::F32 })),
shape0, shape1, epsilon, dt)
{
// Set fixed point position data type allowed
@@ -78,14 +79,25 @@ DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(datasets::Ran
}
TEST_SUITE(Float)
+TEST_SUITE(FP32)
FIXTURE_DATA_TEST_CASE(Random, CLBatchNormalizationLayerFixture<float>, framework::DatasetMode::PRECOMMIT, combine(datasets::RandomBatchNormalizationLayerDataset(),
framework::dataset::make("DataType", DataType::F32)))
{
// Validate output
- validate(CLAccessor(_target), _reference, tolerance_f, 0);
+ validate(CLAccessor(_target), _reference, tolerance_f32, 0);
}
TEST_SUITE_END()
+TEST_SUITE(FP16)
+FIXTURE_DATA_TEST_CASE(Random, CLBatchNormalizationLayerFixture<half>, framework::DatasetMode::PRECOMMIT, combine(datasets::RandomBatchNormalizationLayerDataset(),
+ framework::dataset::make("DataType", DataType::F16)))
+{
+ // Validate output
+ validate(CLAccessor(_target), _reference, tolerance_f16, 0);
+}
+TEST_SUITE_END()
+TEST_SUITE_END()
+
TEST_SUITE(Quantized)
template <typename T>
using CLBatchNormalizationLayerFixedPointFixture = BatchNormalizationLayerValidationFixedPointFixture<CLTensor, CLAccessor, CLBatchNormalizationLayer, T>;