author    Michele Di Giorgio <michele.digiorgio@arm.com>  2018-03-02 09:43:54 +0000
committer Anthony Barbier <anthony.barbier@arm.com>       2018-11-02 16:49:37 +0000
commit    4d33630096c769dd43716dd5607f151e3d5abef7 (patch)
tree      762897c2acac9553c0dad688d0c21842c8edff16 /src/core/GLES_COMPUTE
parent    1cd41495153c4e89d6195b42f870967339c1a13b (diff)
download  ComputeLibrary-4d33630096c769dd43716dd5607f151e3d5abef7.tar.gz
COMPMID-987: Make beta and gamma optional in BatchNormalization
Currently, beta and gamma are compulsory in Batch Normalization. There are networks that might not need one or both of them, so these should be optional, with beta (offset) defaulting to 0 and gamma (scale) to 1. This will also reduce some memory requirements.

Change-Id: I15bf1ec14b814be2acebf1be1a4fba9c4fbd3190
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/123237
Tested-by: Jenkins <bsgcomp@arm.com>
Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
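For reference, this is the computation the patch makes configurable, as a minimal scalar sketch in C++ (the function name and scalar form are illustrative only; the actual kernels operate per-channel on whole tensors):

    #include <cmath>

    // Batch normalization with optional beta (offset) and gamma (scale).
    // Passing nullptr applies the defaults (beta = 0, gamma = 1), mirroring
    // the USE_DEFAULT_BETA / USE_DEFAULT_GAMMA shader variants added below.
    float batch_normalize(float x, float mean, float var, float epsilon,
                          const float *beta = nullptr, const float *gamma = nullptr)
    {
        const float x_bar  = (x - mean) / std::sqrt(var + epsilon);
        const float scale  = (gamma != nullptr) ? *gamma : 1.f;
        const float offset = (beta != nullptr) ? *beta : 0.f;
        return scale * x_bar + offset;
    }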
Diffstat (limited to 'src/core/GLES_COMPUTE')
-rw-r--r--  src/core/GLES_COMPUTE/cs_shaders/batchnormalization_layer.cs       143
-rw-r--r--  src/core/GLES_COMPUTE/kernels/GCBatchNormalizationLayerKernel.cpp  171
2 files changed, 235 insertions, 79 deletions
diff --git a/src/core/GLES_COMPUTE/cs_shaders/batchnormalization_layer.cs b/src/core/GLES_COMPUTE/cs_shaders/batchnormalization_layer.cs
index 7629b255b7..81be9679b2 100644
--- a/src/core/GLES_COMPUTE/cs_shaders/batchnormalization_layer.cs
+++ b/src/core/GLES_COMPUTE/cs_shaders/batchnormalization_layer.cs
@@ -50,6 +50,8 @@ precision mediump float;
*
* @note The data type must be passed at compile time using "#define DATA_TYPE_NAME". e.g. "#define DATA_TYPE_FP32"
* @note Epsilon parameter in the batch normalization equation should be given as a preprocessor argument using "#define EPSILON". e.g. "#define EPSILON 0.1"
+ * @note Beta is optional, with a default value of 0. If beta is not provided, the preprocessor argument "USE_DEFAULT_BETA" should be given
+ * @note Gamma is optional, with a default value of 1. If gamma is not provided, the preprocessor argument "USE_DEFAULT_GAMMA" should be given
*
* @param[in] src_ptr Pointer to the first source tensor. Supported data types: F16/F32
* @param[in] src_attrs The attributes of the source tensor
@@ -59,10 +61,10 @@ precision mediump float;
* @param[in] mean_attrs The attributes of the mean tensor
* @param[in] var_ptr Pointer to the var tensor. Supported data types: same as @p src_ptr
* @param[in] var_attrs The attributes of the var tensor
- * @param[in] beta_ptr Pointer to the beta source tensor. Supported data types: same as @p src_ptr
- * @param[in] beta_attrs The attributes of the beta tensor
- * @param[in] gamma_ptr Pointer to the gamma source tensor. Supported data types: same as @p src_ptr
- * @param[in] gamma_attrs The attributes of the gamma tensor
+ * @param[in] beta_ptr (Optional) Pointer to the beta source tensor. If not provided, beta defaults to 0. Supported data types: same as @p src_ptr
+ * @param[in] beta_attrs (Optional) The attributes of the beta tensor
+ * @param[in] gamma_ptr (Optional) Pointer to the gamma source tensor. If not provided, gamma defaults to 1. Supported data types: same as @p src_ptr
+ * @param[in] gamma_attrs (Optional) The attributes of the gamma tensor
*/
SHADER_PARAMS_DECLARATION
{
@@ -70,8 +72,12 @@ SHADER_PARAMS_DECLARATION
Tensor3DAttributes dst_attrs;
VectorAttributes mean_attrs;
VectorAttributes var_attrs;
- VectorAttributes beta_attrs;
- VectorAttributes gamma_attrs;
+#ifndef USE_DEFAULT_BETA
+ VectorAttributes beta_attrs;
+#endif /* USE_DEFAULT_BETA */
+#ifndef USE_DEFAULT_GAMMA
+ VectorAttributes gamma_attrs;
+#endif /* USE_DEFAULT_GAMMA */
};
#ifdef DATA_TYPE_FP32
@@ -79,24 +85,34 @@ TENSOR_DECLARATION(1, srcBuffer, float, src_ptr, src_shift, 2, readonly);
TENSOR_DECLARATION(2, dstBuffer, float, dst_ptr, dst_shift, 2, writeonly);
TENSOR_DECLARATION(3, meanBuffer, float, mean_ptr, mean_shift, 2, readonly);
TENSOR_DECLARATION(4, varBuffer, float, var_ptr, var_shift, 2, readonly);
+#ifndef USE_DEFAULT_BETA
TENSOR_DECLARATION(5, betaBuffer, float, beta_ptr, beta_shift, 2, readonly);
+#endif /* USE_DEFAULT_BETA */
+#ifndef USE_DEFAULT_GAMMA
+#ifdef USE_DEFAULT_BETA
+TENSOR_DECLARATION(5, gammaBuffer, float, gamma_ptr, gamma_shift, 2, readonly);
+#else /* USE_DEFAULT_BETA */
TENSOR_DECLARATION(6, gammaBuffer, float, gamma_ptr, gamma_shift, 2, readonly);
+#endif /* USE_DEFAULT_BETA */
+#endif /* USE_DEFAULT_GAMMA */
void main(void)
{
- Tensor3DIterator src_iter = CONVERT_TO_TENSOR3D_ITERATOR(src_attrs, src_shift);
- Tensor3DIterator dst_iter = CONVERT_TO_TENSOR3D_ITERATOR(dst_attrs, dst_shift);
- VectorIterator mean_iter = CONVERT_TO_VECTOR_ITERATOR(mean_attrs, mean_shift);
- VectorIterator var_iter = CONVERT_TO_VECTOR_ITERATOR(var_attrs, var_shift);
- VectorIterator beta_iter = CONVERT_TO_VECTOR_ITERATOR(beta_attrs, beta_shift);
- VectorIterator gamma_iter = CONVERT_TO_VECTOR_ITERATOR(gamma_attrs, gamma_shift);
+ Tensor3DIterator src_iter = CONVERT_TO_TENSOR3D_ITERATOR(src_attrs, src_shift);
+ Tensor3DIterator dst_iter = CONVERT_TO_TENSOR3D_ITERATOR(dst_attrs, dst_shift);
+ VectorIterator mean_iter = CONVERT_TO_VECTOR_ITERATOR(mean_attrs, mean_shift);
+ VectorIterator var_iter = CONVERT_TO_VECTOR_ITERATOR(var_attrs, var_shift);
+#ifndef USE_DEFAULT_BETA
+ VectorIterator beta_iter = CONVERT_TO_VECTOR_ITERATOR(beta_attrs, beta_shift);
+#endif /* USE_DEFAULT_BETA */
+#ifndef USE_DEFAULT_GAMMA
+ VectorIterator gamma_iter = CONVERT_TO_VECTOR_ITERATOR(gamma_attrs, gamma_shift);
+#endif /* USE_DEFAULT_GAMMA */
float input_value = 0.f;
float denominator = 0.f;
float numerator = 0.f;
float x_bar = 0.f;
- float gamma_param = 0.f;
- float beta_param = 0.f;
uint current_slice = gl_GlobalInvocationID.z;
@@ -109,10 +125,18 @@ void main(void)
numerator = SUB_OP(input_value, numerator);
x_bar = MUL_OP(numerator, denominator);
- gamma_param = LOAD(gamma_ptr, TENSOR_OFFSET_ADVANCE_IN_BYTES(gamma_iter, current_slice * beta_attrs.stride_x));
- beta_param = LOAD(beta_ptr, TENSOR_OFFSET_ADVANCE_IN_BYTES(beta_iter, current_slice * beta_attrs.stride_x));
+#ifndef USE_DEFAULT_GAMMA
+ float gamma_param = LOAD(gamma_ptr, TENSOR_OFFSET_ADVANCE_IN_BYTES(gamma_iter, current_slice * gamma_attrs.stride_x));
+
+ x_bar = MUL_OP(gamma_param, x_bar);
+#endif /* USE_DEFAULT_GAMMA */
+#ifndef USE_DEFAULT_BETA
+ float beta_param = LOAD(beta_ptr, TENSOR_OFFSET_ADVANCE_IN_BYTES(beta_iter, current_slice * beta_attrs.stride_x));
+
+ x_bar = ADD_OP(x_bar, beta_param);
+#endif /* USE_DEFAULT_BETA */
- STORE_CURRENT_ITEM(dst_ptr, dst_iter, ACTIVATION_FUNC(ADD_OP(MUL_OP(gamma_param, x_bar), beta_param)));
+ STORE_CURRENT_ITEM(dst_ptr, dst_iter, ACTIVATION_FUNC(x_bar));
}
#elif defined(DATA_TYPE_FP16)
@@ -120,8 +144,16 @@ TENSOR_DECLARATION(1, srcBuffer, uvec2, src_ptr, src_shift, 3, readonly);
TENSOR_DECLARATION(2, dstBuffer, uvec2, dst_ptr, dst_shift, 3, writeonly);
TENSOR_DECLARATION(3, meanBuffer, uvec2, mean_ptr, mean_shift, 3, readonly);
TENSOR_DECLARATION(4, varBuffer, uvec2, var_ptr, var_shift, 3, readonly);
+#ifndef USE_DEFAULT_BETA
TENSOR_DECLARATION(5, betaBuffer, uvec2, beta_ptr, beta_shift, 3, readonly);
+#endif /* USE_DEFAULT_BETA */
+#ifndef USE_DEFAULT_GAMMA
+#ifdef USE_DEFAULT_BETA
+TENSOR_DECLARATION(5, gammaBuffer, uvec2, gamma_ptr, gamma_shift, 3, readonly);
+#else /* USE_DEFAULT_BETA */
TENSOR_DECLARATION(6, gammaBuffer, uvec2, gamma_ptr, gamma_shift, 3, readonly);
+#endif /* USE_DEFAULT_BETA */
+#endif /* USE_DEFAULT_GAMMA */
void main(void)
{
@@ -129,14 +161,18 @@ void main(void)
Tensor3DIterator dst_iter = CONVERT_TO_TENSOR3D_ITERATOR(dst_attrs, dst_shift);
VectorIterator mean_iter = CONVERT_TO_VECTOR_ITERATOR(mean_attrs, mean_shift);
VectorIterator var_iter = CONVERT_TO_VECTOR_ITERATOR(var_attrs, var_shift);
+#ifndef USE_DEFAULT_BETA
VectorIterator beta_iter = CONVERT_TO_VECTOR_ITERATOR(beta_attrs, beta_shift);
+#endif /* USE_DEFAULT_BETA */
+#ifndef USE_DEFAULT_GAMMA
VectorIterator gamma_iter = CONVERT_TO_VECTOR_ITERATOR(gamma_attrs, gamma_shift);
+#endif /* USE_DEFAULT_GAMMA */
vec4 unpacked_s[5];
float denominator;
float numerator;
- float gamma_param;
- float beta_param;
+ float gamma_param = 1.f;
+ float beta_param = 0.f;
vec4 x_bar;
vec4 result;
@@ -144,68 +180,87 @@ void main(void)
unpacked_s[0] = LOAD_UNPACK4_CURRENT_ITEM_HALF(src_ptr, src_iter);
unpacked_s[1] = LOAD_UNPACK4_HALF(var_ptr, TENSOR_OFFSET_ADVANCE_IN_BYTES(var_iter, current_slice * var_attrs.stride_x));
unpacked_s[2] = LOAD_UNPACK4_HALF(mean_ptr, TENSOR_OFFSET_ADVANCE_IN_BYTES(mean_iter, current_slice * mean_attrs.stride_x));
- unpacked_s[3] = LOAD_UNPACK4_HALF(gamma_ptr, TENSOR_OFFSET_ADVANCE_IN_BYTES(gamma_iter, current_slice * beta_attrs.stride_x));
+#ifndef USE_DEFAULT_GAMMA
+ unpacked_s[3] = LOAD_UNPACK4_HALF(gamma_ptr, TENSOR_OFFSET_ADVANCE_IN_BYTES(gamma_iter, current_slice * gamma_attrs.stride_x));
+#endif /* USE_DEFAULT_GAMMA */
+#ifndef USE_DEFAULT_BETA
unpacked_s[4] = LOAD_UNPACK4_HALF(beta_ptr, TENSOR_OFFSET_ADVANCE_IN_BYTES(beta_iter, current_slice * beta_attrs.stride_x));
+#endif /* USE_DEFAULT_BETA */
if((current_slice % uint(4)) == uint(0))
{
denominator = unpacked_s[1].x;
denominator = INVSQRT_OP(ADD_OP(denominator, SQCVT_SAT(float(ESPILON))));
- //Calculate x bar and store results
- numerator = unpacked_s[2].x;
- x_bar = MUL_OP(SUB_OP(unpacked_s[0], numerator), denominator);
+ // Calculate x bar
+ numerator = unpacked_s[2].x;
+ x_bar = MUL_OP(SUB_OP(unpacked_s[0], numerator), denominator);
+#ifndef USE_DEFAULT_GAMMA
gamma_param = unpacked_s[3].x;
+#endif /* USE_DEFAULT_GAMMA */
+#ifndef USE_DEFAULT_BETA
beta_param = unpacked_s[4].x;
- result = ACTIVATION_FUNC(ADD_OP(MUL_OP(gamma_param, x_bar), beta_param));
-
- STORE_PACK4_CURRENT_ITEM_HALF(dst_ptr, dst_iter, result);
+#endif /* USE_DEFAULT_BETA */
}
else if((current_slice % uint(4)) == uint(1))
{
denominator = unpacked_s[1].y;
denominator = INVSQRT_OP(ADD_OP(denominator, SQCVT_SAT(float(ESPILON))));
- //Calculate x bar and store results
- numerator = unpacked_s[2].y;
- x_bar = MUL_OP(SUB_OP(unpacked_s[0], numerator), denominator);
+ // Calculate x bar
+ numerator = unpacked_s[2].y;
+ x_bar = MUL_OP(SUB_OP(unpacked_s[0], numerator), denominator);
+#ifndef USE_DEFAULT_GAMMA
gamma_param = unpacked_s[3].y;
+#endif /* USE_DEFAULT_GAMMA */
+#ifndef USE_DEFAULT_BETA
beta_param = unpacked_s[4].y;
- result = ACTIVATION_FUNC(ADD_OP(MUL_OP(gamma_param, x_bar), beta_param));
-
- STORE_PACK4_CURRENT_ITEM_HALF(dst_ptr, dst_iter, result);
+#endif /* USE_DEFAULT_BETA */
}
else if((current_slice % uint(4)) == uint(2))
{
denominator = unpacked_s[1].z;
denominator = INVSQRT_OP(ADD_OP(denominator, SQCVT_SAT(float(ESPILON))));
- //Calculate x bar and store results
- numerator = unpacked_s[2].z;
- x_bar = MUL_OP(SUB_OP(unpacked_s[0], numerator), denominator);
+ // Calculate x bar
+ numerator = unpacked_s[2].z;
+ x_bar = MUL_OP(SUB_OP(unpacked_s[0], numerator), denominator);
+#ifndef USE_DEFAULT_GAMMA
gamma_param = unpacked_s[3].z;
+#endif /* USE_DEFAULT_GAMMA */
+#ifndef USE_DEFAULT_BETA
beta_param = unpacked_s[4].z;
- result = ACTIVATION_FUNC(ADD_OP(MUL_OP(gamma_param, x_bar), beta_param));
-
- STORE_PACK4_CURRENT_ITEM_HALF(dst_ptr, dst_iter, result);
+#endif /* USE_DEFAULT_BETA */
}
else
{
denominator = unpacked_s[1].w;
denominator = INVSQRT_OP(ADD_OP(denominator, SQCVT_SAT(float(ESPILON))));
- //Calculate x bar and store results
- numerator = unpacked_s[2].w;
- x_bar = MUL_OP(SUB_OP(unpacked_s[0], numerator), denominator);
+ // Calculate x bar
+ numerator = unpacked_s[2].w;
+ x_bar = MUL_OP(SUB_OP(unpacked_s[0], numerator), denominator);
+#ifndef USE_DEFAULT_GAMMA
gamma_param = unpacked_s[3].w;
+#endif /* USE_DEFAULT_GAMMA */
+#ifndef USE_DEFAULT_BETA
beta_param = unpacked_s[4].w;
- result = ACTIVATION_FUNC(ADD_OP(MUL_OP(gamma_param, x_bar), beta_param));
-
- STORE_PACK4_CURRENT_ITEM_HALF(dst_ptr, dst_iter, result);
+#endif /* USE_DEFAULT_BETA */
}
+
+#ifndef USE_DEFAULT_GAMMA
+ x_bar = MUL_OP(gamma_param, x_bar);
+#endif /* USE_DEFAULT_GAMMA */
+#ifndef USE_DEFAULT_BETA
+ x_bar = ADD_OP(x_bar, beta_param);
+#endif /* USE_DEFAULT_BETA */
+
+ result = ACTIVATION_FUNC(x_bar);
+
+ STORE_PACK4_CURRENT_ITEM_HALF(dst_ptr, dst_iter, result);
}
#endif /*DATA_TYPE_FP16*/
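Note on the TENSOR_DECLARATION cascade above: the nested preprocessor guards keep the SSBO binding points contiguous, so gamma binds at point 5 when beta is compiled out and at point 6 otherwise. A sketch of the assignment rule (this helper is purely illustrative and not part of the library; the kernel's run() method in the next file achieves the same effect by incrementing a binding_point counter):

    #include <utility>

    // Hypothetical helper: SSBO binding points for the optional tensors,
    // matching the shader's #ifdef cascade (src=1, dst=2, mean=3, var=4
    // are always bound). Returns {beta_binding, gamma_binding}; 0 = unbound.
    std::pair<unsigned int, unsigned int> optional_bindings(bool has_beta, bool has_gamma)
    {
        unsigned int binding_point = 4;
        const unsigned int beta_binding  = has_beta ? ++binding_point : 0u;
        const unsigned int gamma_binding = has_gamma ? ++binding_point : 0u;
        return { beta_binding, gamma_binding }; // {5,6}, {5,0}, {0,5} or {0,0}
    }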
diff --git a/src/core/GLES_COMPUTE/kernels/GCBatchNormalizationLayerKernel.cpp b/src/core/GLES_COMPUTE/kernels/GCBatchNormalizationLayerKernel.cpp
index cd93f6997e..9a592dfe00 100644
--- a/src/core/GLES_COMPUTE/kernels/GCBatchNormalizationLayerKernel.cpp
+++ b/src/core/GLES_COMPUTE/kernels/GCBatchNormalizationLayerKernel.cpp
@@ -36,32 +36,118 @@
using namespace arm_compute;
-GCBatchNormalizationLayerKernel::GCBatchNormalizationLayerKernel()
- : _input(nullptr), _output(nullptr), _mean(nullptr), _var(nullptr), _beta(nullptr), _gamma(nullptr), _epsilon(0.0f)
+namespace
{
-}
-
-void GCBatchNormalizationLayerKernel::configure(const IGCTensor *input, IGCTensor *output, const IGCTensor *mean, const IGCTensor *var, const IGCTensor *beta, const IGCTensor *gamma,
- float epsilon, ActivationLayerInfo act_info)
+Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output,
+ const ITensorInfo *mean, const ITensorInfo *var,
+ const ITensorInfo *beta, const ITensorInfo *gamma,
+ float epsilon, ActivationLayerInfo act_info)
{
+ ARM_COMPUTE_UNUSED(epsilon);
+ ARM_COMPUTE_UNUSED(var);
ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F16, DataType::F32);
- ARM_COMPUTE_ERROR_ON_NULLPTR(output);
- // Output tensor auto initialization if not yet initialized
- auto_init_if_empty(*output->info(), input->info()->tensor_shape(), 1, input->info()->data_type(), input->info()->fixed_point_position());
+ ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, mean, var);
+ ARM_COMPUTE_ERROR_ON_MISMATCHING_FIXED_POINT(input, mean, var);
+ ARM_COMPUTE_ERROR_ON_MISMATCHING_SHAPES(mean, var);
- ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, output, mean, var, beta, gamma);
- ARM_COMPUTE_ERROR_ON_MISMATCHING_FIXED_POINT(input, output, mean, var, beta, gamma);
- ARM_COMPUTE_ERROR_ON_MISMATCHING_SHAPES(input, output);
- ARM_COMPUTE_ERROR_ON_MISMATCHING_SHAPES(mean, var, beta, gamma);
+ if(output->total_size() != 0)
+ {
+ ARM_COMPUTE_ERROR_ON_MISMATCHING_SHAPES(input, output);
+ ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
+ ARM_COMPUTE_ERROR_ON_MISMATCHING_FIXED_POINT(input, output);
+ }
+
+ if(beta != nullptr)
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(mean, beta);
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, beta);
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input, beta);
+ }
+ if(gamma != nullptr)
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(mean, gamma);
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, gamma);
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input, gamma);
+ }
if(act_info.enabled())
{
- ARM_COMPUTE_ERROR_ON(input->info()->data_type() != DataType::F32 && input->info()->data_type() != DataType::F16);
+ ARM_COMPUTE_ERROR_ON(input->data_type() != DataType::F32 && input->data_type() != DataType::F16);
ARM_COMPUTE_ERROR_ON(act_info.activation() != ActivationLayerInfo::ActivationFunction::RELU
&& act_info.activation() != ActivationLayerInfo::ActivationFunction::BOUNDED_RELU
&& act_info.activation() != ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU);
ARM_COMPUTE_ERROR_ON(act_info.b() > act_info.a());
}
+ return Status{};
+}
+
+std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input, ITensorInfo *output,
+ ITensorInfo *mean, ITensorInfo *var,
+ ITensorInfo *beta, ITensorInfo *gamma)
+{
+ // Output tensor auto initialization if not yet initialized
+ auto_init_if_empty(*output, input->tensor_shape(), 1, input->data_type(), input->fixed_point_position());
+
+ unsigned int num_elems_processed_per_iteration = 1;
+ if(input->data_type() == DataType::F16)
+ {
+ num_elems_processed_per_iteration = 4;
+ }
+
+ // Configure kernel window
+ Window win = calculate_max_window(*input, Steps(num_elems_processed_per_iteration));
+
+ AccessWindowHorizontal input_access(input, 0, num_elems_processed_per_iteration);
+ AccessWindowHorizontal output_access(output, 0, num_elems_processed_per_iteration);
+ AccessWindowStatic mean_access(mean, 0, 0, mean->dimension(0) + 3, mean->dimension(1));
+ AccessWindowStatic var_access(var, 0, 0, var->dimension(0) + 3, var->dimension(1));
+
+ bool window_changed = false;
+ if(beta != nullptr)
+ {
+ AccessWindowStatic beta_access(beta, 0, 0, beta->dimension(0) + 3, beta->dimension(1));
+ if(gamma != nullptr)
+ {
+ AccessWindowStatic gamma_access(gamma, 0, 0, gamma->dimension(0) + 3, gamma->dimension(1));
+ window_changed = update_window_and_padding(win, input_access, output_access, mean_access, var_access, beta_access, gamma_access);
+ }
+ else
+ {
+ window_changed = update_window_and_padding(win, input_access, output_access, mean_access, var_access, beta_access);
+ }
+ }
+ else
+ {
+ if(gamma != nullptr)
+ {
+ AccessWindowStatic gamma_access(gamma, 0, 0, gamma->dimension(0) + 3, gamma->dimension(1));
+ window_changed = update_window_and_padding(win, input_access, output_access, mean_access, var_access, gamma_access);
+ }
+ else
+ {
+ window_changed = update_window_and_padding(win, input_access, output_access, mean_access, var_access);
+ }
+ }
+ output_access.set_valid_region(win, input->valid_region());
+
+ Status err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Status{};
+ return std::make_pair(err, win);
+}
+} // namespace
+
+GCBatchNormalizationLayerKernel::GCBatchNormalizationLayerKernel()
+ : _input(nullptr), _output(nullptr), _mean(nullptr), _var(nullptr), _beta(nullptr), _gamma(nullptr), _epsilon(0.0f)
+{
+}
+
+void GCBatchNormalizationLayerKernel::configure(const IGCTensor *input, IGCTensor *output, const IGCTensor *mean, const IGCTensor *var, const IGCTensor *beta, const IGCTensor *gamma,
+ float epsilon, ActivationLayerInfo act_info)
+{
+ ARM_COMPUTE_ERROR_ON_NULLPTR(input, output, mean, var);
+
+ ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), output->info(), mean->info(), var->info(),
+ (beta != nullptr) ? beta->info() : nullptr, (gamma != nullptr) ? gamma->info() : nullptr,
+ epsilon, act_info));
_input = input;
_output = output;
@@ -71,12 +157,6 @@ void GCBatchNormalizationLayerKernel::configure(const IGCTensor *input, IGCTenso
_gamma = gamma;
_epsilon = epsilon;
- unsigned int num_elems_processed_per_iteration = 1;
- if(input->info()->data_type() == DataType::F16)
- {
- num_elems_processed_per_iteration = 4;
- }
-
// Set build options
std::set<std::string> build_opts;
std::string dt_name = (input->info()->data_type() == DataType::F32) ? "DATA_TYPE_FP32" : "DATA_TYPE_FP16";
@@ -85,6 +165,14 @@ void GCBatchNormalizationLayerKernel::configure(const IGCTensor *input, IGCTenso
build_opts.emplace(("#define LOCAL_SIZE_X " + support::cpp11::to_string(1)));
build_opts.emplace(("#define LOCAL_SIZE_Y " + support::cpp11::to_string(1)));
build_opts.emplace(("#define LOCAL_SIZE_Z " + support::cpp11::to_string(1)));
+ if(beta == nullptr)
+ {
+ build_opts.emplace("#define USE_DEFAULT_BETA");
+ }
+ if(gamma == nullptr)
+ {
+ build_opts.emplace("#define USE_DEFAULT_GAMMA");
+ }
if(act_info.enabled())
{
@@ -97,19 +185,25 @@ void GCBatchNormalizationLayerKernel::configure(const IGCTensor *input, IGCTenso
_kernel = static_cast<GCKernel>(GCKernelLibrary::get().create_kernel("batchnormalization_layer", build_opts));
// Configure kernel window
- Window win = calculate_max_window(*input->info(), Steps(num_elems_processed_per_iteration));
+ auto win_config = validate_and_configure_window(input->info(), output->info(), mean->info(), var->info(),
+ (beta != nullptr) ? beta->info() : nullptr, (gamma != nullptr) ? gamma->info() : nullptr);
+ ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
- AccessWindowHorizontal input_access(input->info(), 0, num_elems_processed_per_iteration);
- AccessWindowHorizontal output_access(output->info(), 0, num_elems_processed_per_iteration);
- AccessWindowStatic mean_access(mean->info(), 0, 0, mean->info()->dimension(0) + 3, mean->info()->dimension(1));
- AccessWindowStatic var_access(var->info(), 0, 0, var->info()->dimension(0) + 3, var->info()->dimension(1));
- AccessWindowStatic beta_access(beta->info(), 0, 0, beta->info()->dimension(0) + 3, beta->info()->dimension(1));
- AccessWindowStatic gamma_access(gamma->info(), 0, 0, gamma->info()->dimension(0) + 3, gamma->info()->dimension(1));
+ IGCKernel::configure(win_config.second);
+}
- update_window_and_padding(win, input_access, output_access, mean_access, var_access, beta_access, gamma_access);
- output_access.set_valid_region(win, input->info()->valid_region());
+Status GCBatchNormalizationLayerKernel::validate(const ITensorInfo *input, const ITensorInfo *output,
+ const ITensorInfo *mean, const ITensorInfo *var,
+ const ITensorInfo *beta, const ITensorInfo *gamma,
+ float epsilon, ActivationLayerInfo act_info)
+{
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, output, mean, var, beta, gamma, epsilon, act_info));
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(input->clone().get(), output->clone().get(),
+ mean->clone().get(), var->clone().get(),
+ (beta != nullptr) ? beta->clone().get() : nullptr,
+ (gamma != nullptr) ? gamma->clone().get() : nullptr)
+ .first);
- IGCKernel::configure(win);
+ return Status{};
}
void GCBatchNormalizationLayerKernel::run(const Window &window)
@@ -127,11 +221,18 @@ void GCBatchNormalizationLayerKernel::run(const Window &window)
Window vector_slice = window.first_slice_window_1D();
vector_slice.set(Window::DimX, Window::Dimension(0, 0, 0));
- unsigned int idx = 2 * num_arguments_per_3D_tensor();
- add_1D_tensor_argument(idx, _mean, 3, vector_slice);
- add_1D_tensor_argument(idx, _var, 4, vector_slice);
- add_1D_tensor_argument(idx, _beta, 5, vector_slice);
- add_1D_tensor_argument(idx, _gamma, 6, vector_slice);
+ unsigned int idx = 2 * num_arguments_per_3D_tensor();
+ unsigned int binding_point = 3;
+ add_1D_tensor_argument(idx, _mean, binding_point, vector_slice);
+ add_1D_tensor_argument(idx, _var, ++binding_point, vector_slice);
+ if(_beta != nullptr)
+ {
+ add_1D_tensor_argument(idx, _beta, ++binding_point, vector_slice);
+ }
+ if(_gamma != nullptr)
+ {
+ add_1D_tensor_argument(idx, _gamma, ++binding_point, vector_slice);
+ }
slice.shift(Window::DimX, -(_output->info()->padding()).left);
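
Taken together, callers may now pass nullptr for beta and/or gamma. A minimal usage sketch, assuming the tensors are already allocated and a GLES compute context is set up as usual for GC kernels (error handling elided):

    #include "arm_compute/core/GLES_COMPUTE/kernels/GCBatchNormalizationLayerKernel.h"

    using namespace arm_compute;

    void configure_without_beta_and_gamma(const IGCTensor *input, IGCTensor *output,
                                          const IGCTensor *mean, const IGCTensor *var)
    {
        GCBatchNormalizationLayerKernel kernel;
        // With beta and gamma omitted, the kernel compiles the shader with
        // USE_DEFAULT_BETA and USE_DEFAULT_GAMMA (so beta = 0, gamma = 1)
        // and binds no extra buffers in run().
        kernel.configure(input, output, mean, var,
                         nullptr /* beta */, nullptr /* gamma */,
                         0.001f /* epsilon */, ActivationLayerInfo());
    }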