diff options
author | Michele Di Giorgio <michele.digiorgio@arm.com> | 2018-03-02 09:43:54 +0000 |
---|---|---|
committer | Anthony Barbier <anthony.barbier@arm.com> | 2018-11-02 16:49:37 +0000 |
commit | 4d33630096c769dd43716dd5607f151e3d5abef7 (patch) | |
tree | 762897c2acac9553c0dad688d0c21842c8edff16 /src/core/NEON/kernels/NEBatchNormalizationLayerKernel.cpp | |
parent | 1cd41495153c4e89d6195b42f870967339c1a13b (diff) | |
download | ComputeLibrary-4d33630096c769dd43716dd5607f151e3d5abef7.tar.gz |
COMPMID-987: Make beta and gamma optional in BatchNormalization
Currently, beta and gamma are compulsory in Batch normalization. There are
networks that might not need one or both of them, so these parameters should be
optional, with beta (offset) defaulting to zero and gamma (scale) defaulting to 1.
This will also reduce some memory requirements.
Change-Id: I15bf1ec14b814be2acebf1be1a4fba9c4fbd3190
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/123237
Tested-by: Jenkins <bsgcomp@arm.com>
Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
Diffstat (limited to 'src/core/NEON/kernels/NEBatchNormalizationLayerKernel.cpp')
-rw-r--r-- | src/core/NEON/kernels/NEBatchNormalizationLayerKernel.cpp | 130 |
1 file changed, 82 insertions, 48 deletions
diff --git a/src/core/NEON/kernels/NEBatchNormalizationLayerKernel.cpp b/src/core/NEON/kernels/NEBatchNormalizationLayerKernel.cpp index 1f730a2c3c..d1bdfac2da 100644 --- a/src/core/NEON/kernels/NEBatchNormalizationLayerKernel.cpp +++ b/src/core/NEON/kernels/NEBatchNormalizationLayerKernel.cpp @@ -62,9 +62,21 @@ validate_arguments(const ITensorInfo *input, const ITensorInfo *output, const IT ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input, output); } - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, mean, var, beta, gamma); - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input, mean, var, beta, gamma); - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(mean, var, beta, gamma); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, mean, var); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input, mean, var); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(mean, var); + if(beta != nullptr) + { + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, beta); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input, beta); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(mean, beta); + } + if(gamma != nullptr) + { + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, gamma); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input, gamma); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(mean, gamma); + } ARM_COMPUTE_RETURN_ERROR_ON(input->dimension(2) != mean->dimension(0)); return Status{}; @@ -72,6 +84,12 @@ validate_arguments(const ITensorInfo *input, const ITensorInfo *output, const IT std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input, ITensorInfo *output) { + if(output != nullptr) + { + // Output tensor auto initialization if not yet initialized + auto_init_if_empty(*output, *input->clone()); + } + unsigned int num_elems_processed_per_iteration = 16 / input->element_size(); Window win = calculate_max_window(*input, Steps(num_elems_processed_per_iteration)); @@ -99,13 +117,13 @@ void 
NEBatchNormalizationLayerKernel::batch_normalization_qs8(const Window &wind const int fixed_point_position = _input->info()->fixed_point_position(); const auto input_mean = reinterpret_cast<const qint8_t *>(_mean->ptr_to_element(Coordinates(0, 0))); const auto input_var = reinterpret_cast<const qint8_t *>(_var->ptr_to_element(Coordinates(0, 0))); - const auto input_gamma = reinterpret_cast<const qint8_t *>(_gamma->ptr_to_element(Coordinates(0, 0))); - const auto input_beta = reinterpret_cast<const qint8_t *>(_beta->ptr_to_element(Coordinates(0, 0))); + const auto input_gamma = (_gamma != nullptr) ? reinterpret_cast<const qint8_t *>(_gamma->ptr_to_element(Coordinates(0, 0))) : nullptr; + const auto input_beta = (_beta != nullptr) ? reinterpret_cast<const qint8_t *>(_beta->ptr_to_element(Coordinates(0, 0))) : nullptr; qint8x16_t mean_vec = vdupq_n_qs8(0); qint8x16_t var_vec = vdupq_n_qs8(0); - qint8x16_t gamma_vec = vdupq_n_qs8(0); - qint8x16_t beta_vec = vdupq_n_qs8(0); + qint8x16_t gamma_vec = vdupq_n_qs8(sqcvt_qs8_f32(1, fixed_point_position)); + qint8x16_t beta_vec = vdupq_n_qs8(sqcvt_qs8_f32(0, fixed_point_position)); qint8x16_t denominator = vdupq_n_qs8(0); const qint8x16_t epsilon_vec = vdupq_n_qs8(sqcvt_qs8_f32(_epsilon, fixed_point_position)); execute_window_loop(window, [&](const Coordinates & id) @@ -113,10 +131,16 @@ void NEBatchNormalizationLayerKernel::batch_normalization_qs8(const Window &wind if(slice != id.z()) { // Conctruct vectors - mean_vec = vdupq_n_qs8(*(input_mean + id.z())); - var_vec = vdupq_n_qs8(*(input_var + id.z())); - gamma_vec = vdupq_n_qs8(*(input_gamma + id.z())); - beta_vec = vdupq_n_qs8(*(input_beta + id.z())); + mean_vec = vdupq_n_qs8(*(input_mean + id.z())); + var_vec = vdupq_n_qs8(*(input_var + id.z())); + if(input_gamma != nullptr) + { + gamma_vec = vdupq_n_qs8(*(input_gamma + id.z())); + } + if(input_beta != nullptr) + { + beta_vec = vdupq_n_qs8(*(input_beta + id.z())); + } // Calculate denominator denominator = 
vqinvsqrtq_qs8(vqaddq_qs8(var_vec, epsilon_vec), fixed_point_position); @@ -146,13 +170,13 @@ void NEBatchNormalizationLayerKernel::batch_normalization_qs16(const Window &win const int fixed_point_position = _input->info()->fixed_point_position(); const auto input_mean = reinterpret_cast<const qint16_t *>(_mean->ptr_to_element(Coordinates(0, 0))); const auto input_var = reinterpret_cast<const qint16_t *>(_var->ptr_to_element(Coordinates(0, 0))); - const auto input_gamma = reinterpret_cast<const qint16_t *>(_gamma->ptr_to_element(Coordinates(0, 0))); - const auto input_beta = reinterpret_cast<const qint16_t *>(_beta->ptr_to_element(Coordinates(0, 0))); + const auto input_gamma = (_gamma != nullptr) ? reinterpret_cast<const qint16_t *>(_gamma->ptr_to_element(Coordinates(0, 0))) : nullptr; + const auto input_beta = (_beta != nullptr) ? reinterpret_cast<const qint16_t *>(_beta->ptr_to_element(Coordinates(0, 0))) : nullptr; qint16x8_t mean_vec = vdupq_n_qs16(0); qint16x8_t var_vec = vdupq_n_qs16(0); - qint16x8_t gamma_vec = vdupq_n_qs16(0); - qint16x8_t beta_vec = vdupq_n_qs16(0); + qint16x8_t gamma_vec = vdupq_n_qs16(sqcvt_qs16_f32(1, fixed_point_position)); + qint16x8_t beta_vec = vdupq_n_qs16(sqcvt_qs16_f32(0, fixed_point_position)); qint16x8_t denominator = vdupq_n_qs16(0); const qint16x8_t epsilon_vec = vdupq_n_qs16(sqcvt_qs16_f32(_epsilon, fixed_point_position)); execute_window_loop(window, [&](const Coordinates & id) @@ -160,10 +184,16 @@ void NEBatchNormalizationLayerKernel::batch_normalization_qs16(const Window &win if(slice != id.z()) { // Conctruct vectors - mean_vec = vdupq_n_qs16(*(input_mean + id.z())); - var_vec = vdupq_n_qs16(*(input_var + id.z())); - gamma_vec = vdupq_n_qs16(*(input_gamma + id.z())); - beta_vec = vdupq_n_qs16(*(input_beta + id.z())); + mean_vec = vdupq_n_qs16(*(input_mean + id.z())); + var_vec = vdupq_n_qs16(*(input_var + id.z())); + if(input_gamma != nullptr) + { + gamma_vec = vdupq_n_qs16(*(input_gamma + id.z())); + } + if(input_beta 
!= nullptr) + { + beta_vec = vdupq_n_qs16(*(input_beta + id.z())); + } // Calculate denominator denominator = vqinvsqrtq_qs16(vqaddq_qs16(var_vec, epsilon_vec), fixed_point_position); @@ -194,12 +224,12 @@ void NEBatchNormalizationLayerKernel::batch_normalization_fp16(const Window &win const auto input_mean = reinterpret_cast<const float16_t *>(_mean->ptr_to_element(Coordinates(0, 0))); const auto input_var = reinterpret_cast<const float16_t *>(_var->ptr_to_element(Coordinates(0, 0))); - const auto input_gamma = reinterpret_cast<const float16_t *>(_gamma->ptr_to_element(Coordinates(0, 0))); - const auto input_beta = reinterpret_cast<const float16_t *>(_beta->ptr_to_element(Coordinates(0, 0))); + const auto input_gamma = (_gamma != nullptr) ? reinterpret_cast<const float16_t *>(_gamma->ptr_to_element(Coordinates(0, 0))) : nullptr; + const auto input_beta = (_beta != nullptr) ? reinterpret_cast<const float16_t *>(_beta->ptr_to_element(Coordinates(0, 0))) : nullptr; float16x8_t mean_vec = vdupq_n_f16(0.0); float16x8_t var_vec = vdupq_n_f16(0.0); - float16x8_t gamma_vec = vdupq_n_f16(0.0); + float16x8_t gamma_vec = vdupq_n_f16(1.0); float16x8_t beta_vec = vdupq_n_f16(0.0); float16x8_t denominator = vdupq_n_f16(0.0); const float16x8_t epsilon_vec = vdupq_n_f16(_epsilon); @@ -208,10 +238,16 @@ void NEBatchNormalizationLayerKernel::batch_normalization_fp16(const Window &win if(slice != id.z()) { // Conctruct vectors - mean_vec = vdupq_n_f16(*(input_mean + id.z())); - var_vec = vdupq_n_f16(*(input_var + id.z())); - gamma_vec = vdupq_n_f16(*(input_gamma + id.z())); - beta_vec = vdupq_n_f16(*(input_beta + id.z())); + mean_vec = vdupq_n_f16(*(input_mean + id.z())); + var_vec = vdupq_n_f16(*(input_var + id.z())); + if(input_gamma != nullptr) + { + gamma_vec = vdupq_n_f16(*(input_gamma + id.z())); + } + if(input_beta != nullptr) + { + beta_vec = vdupq_n_f16(*(input_beta + id.z())); + } // Calculate denominator denominator = vinvsqrtq_f16(vaddq_f16(var_vec, epsilon_vec)); @@ 
-241,12 +277,12 @@ void NEBatchNormalizationLayerKernel::batch_normalization_fp32(const Window &win const auto input_mean = reinterpret_cast<const float *>(_mean->ptr_to_element(Coordinates(0, 0))); const auto input_var = reinterpret_cast<const float *>(_var->ptr_to_element(Coordinates(0, 0))); - const auto input_gamma = reinterpret_cast<const float *>(_gamma->ptr_to_element(Coordinates(0, 0))); - const auto input_beta = reinterpret_cast<const float *>(_beta->ptr_to_element(Coordinates(0, 0))); + const auto input_gamma = (_gamma != nullptr) ? reinterpret_cast<const float *>(_gamma->ptr_to_element(Coordinates(0, 0))) : nullptr; + const auto input_beta = (_beta != nullptr) ? reinterpret_cast<const float *>(_beta->ptr_to_element(Coordinates(0, 0))) : nullptr; float32x4_t mean_vec = vdupq_n_f32(0.0); float32x4_t var_vec = vdupq_n_f32(0.0); - float32x4_t gamma_vec = vdupq_n_f32(0.0); + float32x4_t gamma_vec = vdupq_n_f32(1.0); float32x4_t beta_vec = vdupq_n_f32(0.0); float32x4_t denominator = vdupq_n_f32(0.0); const float32x4_t epsilon_vec = vdupq_n_f32(_epsilon); @@ -255,10 +291,16 @@ void NEBatchNormalizationLayerKernel::batch_normalization_fp32(const Window &win if(slice != id.z()) { // Conctruct vectors - mean_vec = vdupq_n_f32(*(input_mean + id.z())); - var_vec = vdupq_n_f32(*(input_var + id.z())); - gamma_vec = vdupq_n_f32(*(input_gamma + id.z())); - beta_vec = vdupq_n_f32(*(input_beta + id.z())); + mean_vec = vdupq_n_f32(*(input_mean + id.z())); + var_vec = vdupq_n_f32(*(input_var + id.z())); + if(input_gamma != nullptr) + { + gamma_vec = vdupq_n_f32(*(input_gamma + id.z())); + } + if(input_beta != nullptr) + { + beta_vec = vdupq_n_f32(*(input_beta + id.z())); + } // Calculate denominator denominator = vinvsqrtq_f32(vaddq_f32(var_vec, epsilon_vec)); @@ -335,21 +377,12 @@ void NEBatchNormalizationLayerKernel::configure(ITensor *input, ITensor *output, const ITensor *beta, const ITensor *gamma, float epsilon, ActivationLayerInfo act_info) { - 
ARM_COMPUTE_ERROR_ON_NULLPTR(input, mean, var, beta, gamma); + ARM_COMPUTE_ERROR_ON_NULLPTR(input, mean, var); - ITensorInfo *output_info = nullptr; - - if(nullptr != output) - { - // Output tensor auto initialization if not yet initialized - auto_init_if_empty(*output->info(), *input->info()); - - output_info = output->info(); - } - - ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), output_info, + ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), (output != nullptr) ? output->info() : nullptr, mean->info(), var->info(), - beta->info(), gamma->info(), + (beta != nullptr) ? beta->info() : nullptr, + (gamma != nullptr) ? gamma->info() : nullptr, epsilon, act_info)); _input = input; @@ -361,7 +394,8 @@ void NEBatchNormalizationLayerKernel::configure(ITensor *input, ITensor *output, _epsilon = epsilon; _act_info = act_info; - if(output != nullptr) + const bool run_in_place = (output == nullptr) || (output == input); + if(!run_in_place) { _output = output; } @@ -377,7 +411,7 @@ void NEBatchNormalizationLayerKernel::configure(ITensor *input, ITensor *output, } // Configure kernel window - auto win_config = validate_and_configure_window(input->info(), output_info); + auto win_config = validate_and_configure_window(input->info(), (run_in_place) ? nullptr : output->info()); ARM_COMPUTE_ERROR_THROW_ON(win_config.first); INEKernel::configure(win_config.second); } |