diff options
Diffstat (limited to 'src/core/NEON/kernels/NEBatchNormalizationLayerKernel.cpp')
-rw-r--r-- | src/core/NEON/kernels/NEBatchNormalizationLayerKernel.cpp | 130 |
1 files changed, 82 insertions, 48 deletions
diff --git a/src/core/NEON/kernels/NEBatchNormalizationLayerKernel.cpp b/src/core/NEON/kernels/NEBatchNormalizationLayerKernel.cpp index 1f730a2c3c..d1bdfac2da 100644 --- a/src/core/NEON/kernels/NEBatchNormalizationLayerKernel.cpp +++ b/src/core/NEON/kernels/NEBatchNormalizationLayerKernel.cpp @@ -62,9 +62,21 @@ validate_arguments(const ITensorInfo *input, const ITensorInfo *output, const IT ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input, output); } - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, mean, var, beta, gamma); - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input, mean, var, beta, gamma); - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(mean, var, beta, gamma); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, mean, var); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input, mean, var); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(mean, var); + if(beta != nullptr) + { + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, beta); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input, beta); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(mean, beta); + } + if(gamma != nullptr) + { + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, gamma); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input, gamma); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(mean, gamma); + } ARM_COMPUTE_RETURN_ERROR_ON(input->dimension(2) != mean->dimension(0)); return Status{}; @@ -72,6 +84,12 @@ validate_arguments(const ITensorInfo *input, const ITensorInfo *output, const IT std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input, ITensorInfo *output) { + if(output != nullptr) + { + // Output tensor auto initialization if not yet initialized + auto_init_if_empty(*output, *input->clone()); + } + unsigned int num_elems_processed_per_iteration = 16 / input->element_size(); Window win = calculate_max_window(*input, Steps(num_elems_processed_per_iteration)); @@ -99,13 +117,13 @@ void NEBatchNormalizationLayerKernel::batch_normalization_qs8(const Window &wind const int fixed_point_position = _input->info()->fixed_point_position(); const auto input_mean = reinterpret_cast<const qint8_t *>(_mean->ptr_to_element(Coordinates(0, 0))); const auto input_var = reinterpret_cast<const qint8_t *>(_var->ptr_to_element(Coordinates(0, 0))); - const auto input_gamma = reinterpret_cast<const qint8_t *>(_gamma->ptr_to_element(Coordinates(0, 0))); - const auto input_beta = reinterpret_cast<const qint8_t *>(_beta->ptr_to_element(Coordinates(0, 0))); + const auto input_gamma = (_gamma != nullptr) ? reinterpret_cast<const qint8_t *>(_gamma->ptr_to_element(Coordinates(0, 0))) : nullptr; + const auto input_beta = (_beta != nullptr) ? reinterpret_cast<const qint8_t *>(_beta->ptr_to_element(Coordinates(0, 0))) : nullptr; qint8x16_t mean_vec = vdupq_n_qs8(0); qint8x16_t var_vec = vdupq_n_qs8(0); - qint8x16_t gamma_vec = vdupq_n_qs8(0); - qint8x16_t beta_vec = vdupq_n_qs8(0); + qint8x16_t gamma_vec = vdupq_n_qs8(sqcvt_qs8_f32(1, fixed_point_position)); + qint8x16_t beta_vec = vdupq_n_qs8(sqcvt_qs8_f32(0, fixed_point_position)); qint8x16_t denominator = vdupq_n_qs8(0); const qint8x16_t epsilon_vec = vdupq_n_qs8(sqcvt_qs8_f32(_epsilon, fixed_point_position)); execute_window_loop(window, [&](const Coordinates & id) @@ -113,10 +131,16 @@ void NEBatchNormalizationLayerKernel::batch_normalization_qs8(const Window &wind if(slice != id.z()) { // Conctruct vectors - mean_vec = vdupq_n_qs8(*(input_mean + id.z())); - var_vec = vdupq_n_qs8(*(input_var + id.z())); - gamma_vec = vdupq_n_qs8(*(input_gamma + id.z())); - beta_vec = vdupq_n_qs8(*(input_beta + id.z())); + mean_vec = vdupq_n_qs8(*(input_mean + id.z())); + var_vec = vdupq_n_qs8(*(input_var + id.z())); + if(input_gamma != nullptr) + { + gamma_vec = vdupq_n_qs8(*(input_gamma + id.z())); + } + if(input_beta != nullptr) + { + beta_vec = vdupq_n_qs8(*(input_beta + id.z())); + } // Calculate denominator denominator = vqinvsqrtq_qs8(vqaddq_qs8(var_vec, epsilon_vec), fixed_point_position); @@ -146,13 +170,13 @@ void NEBatchNormalizationLayerKernel::batch_normalization_qs16(const Window &win const int fixed_point_position = _input->info()->fixed_point_position(); const auto input_mean = reinterpret_cast<const qint16_t *>(_mean->ptr_to_element(Coordinates(0, 0))); const auto input_var = reinterpret_cast<const qint16_t *>(_var->ptr_to_element(Coordinates(0, 0))); - const auto input_gamma = reinterpret_cast<const qint16_t *>(_gamma->ptr_to_element(Coordinates(0, 0))); - const auto input_beta = reinterpret_cast<const qint16_t *>(_beta->ptr_to_element(Coordinates(0, 0))); + const auto input_gamma = (_gamma != nullptr) ? reinterpret_cast<const qint16_t *>(_gamma->ptr_to_element(Coordinates(0, 0))) : nullptr; + const auto input_beta = (_beta != nullptr) ? reinterpret_cast<const qint16_t *>(_beta->ptr_to_element(Coordinates(0, 0))) : nullptr; qint16x8_t mean_vec = vdupq_n_qs16(0); qint16x8_t var_vec = vdupq_n_qs16(0); - qint16x8_t gamma_vec = vdupq_n_qs16(0); - qint16x8_t beta_vec = vdupq_n_qs16(0); + qint16x8_t gamma_vec = vdupq_n_qs16(sqcvt_qs16_f32(1, fixed_point_position)); + qint16x8_t beta_vec = vdupq_n_qs16(sqcvt_qs16_f32(0, fixed_point_position)); qint16x8_t denominator = vdupq_n_qs16(0); const qint16x8_t epsilon_vec = vdupq_n_qs16(sqcvt_qs16_f32(_epsilon, fixed_point_position)); execute_window_loop(window, [&](const Coordinates & id) @@ -160,10 +184,16 @@ void NEBatchNormalizationLayerKernel::batch_normalization_qs16(const Window &win if(slice != id.z()) { // Conctruct vectors - mean_vec = vdupq_n_qs16(*(input_mean + id.z())); - var_vec = vdupq_n_qs16(*(input_var + id.z())); - gamma_vec = vdupq_n_qs16(*(input_gamma + id.z())); - beta_vec = vdupq_n_qs16(*(input_beta + id.z())); + mean_vec = vdupq_n_qs16(*(input_mean + id.z())); + var_vec = vdupq_n_qs16(*(input_var + id.z())); + if(input_gamma != nullptr) + { + gamma_vec = vdupq_n_qs16(*(input_gamma + id.z())); + } + if(input_beta != nullptr) + { + beta_vec = vdupq_n_qs16(*(input_beta + id.z())); + } // Calculate denominator denominator = vqinvsqrtq_qs16(vqaddq_qs16(var_vec, epsilon_vec), fixed_point_position); @@ -194,12 +224,12 @@ void NEBatchNormalizationLayerKernel::batch_normalization_fp16(const Window &win const auto input_mean = reinterpret_cast<const float16_t *>(_mean->ptr_to_element(Coordinates(0, 0))); const auto input_var = reinterpret_cast<const float16_t *>(_var->ptr_to_element(Coordinates(0, 0))); - const auto input_gamma = reinterpret_cast<const float16_t *>(_gamma->ptr_to_element(Coordinates(0, 0))); - const auto input_beta = reinterpret_cast<const float16_t *>(_beta->ptr_to_element(Coordinates(0, 0))); + const auto input_gamma = (_gamma != nullptr) ? reinterpret_cast<const float16_t *>(_gamma->ptr_to_element(Coordinates(0, 0))) : nullptr; + const auto input_beta = (_beta != nullptr) ? reinterpret_cast<const float16_t *>(_beta->ptr_to_element(Coordinates(0, 0))) : nullptr; float16x8_t mean_vec = vdupq_n_f16(0.0); float16x8_t var_vec = vdupq_n_f16(0.0); - float16x8_t gamma_vec = vdupq_n_f16(0.0); + float16x8_t gamma_vec = vdupq_n_f16(1.0); float16x8_t beta_vec = vdupq_n_f16(0.0); float16x8_t denominator = vdupq_n_f16(0.0); const float16x8_t epsilon_vec = vdupq_n_f16(_epsilon); @@ -208,10 +238,16 @@ void NEBatchNormalizationLayerKernel::batch_normalization_fp16(const Window &win if(slice != id.z()) { // Conctruct vectors - mean_vec = vdupq_n_f16(*(input_mean + id.z())); - var_vec = vdupq_n_f16(*(input_var + id.z())); - gamma_vec = vdupq_n_f16(*(input_gamma + id.z())); - beta_vec = vdupq_n_f16(*(input_beta + id.z())); + mean_vec = vdupq_n_f16(*(input_mean + id.z())); + var_vec = vdupq_n_f16(*(input_var + id.z())); + if(input_gamma != nullptr) + { + gamma_vec = vdupq_n_f16(*(input_gamma + id.z())); + } + if(input_beta != nullptr) + { + beta_vec = vdupq_n_f16(*(input_beta + id.z())); + } // Calculate denominator denominator = vinvsqrtq_f16(vaddq_f16(var_vec, epsilon_vec)); @@ -241,12 +277,12 @@ void NEBatchNormalizationLayerKernel::batch_normalization_fp32(const Window &win const auto input_mean = reinterpret_cast<const float *>(_mean->ptr_to_element(Coordinates(0, 0))); const auto input_var = reinterpret_cast<const float *>(_var->ptr_to_element(Coordinates(0, 0))); - const auto input_gamma = reinterpret_cast<const float *>(_gamma->ptr_to_element(Coordinates(0, 0))); - const auto input_beta = reinterpret_cast<const float *>(_beta->ptr_to_element(Coordinates(0, 0))); + const auto input_gamma = (_gamma != nullptr) ? reinterpret_cast<const float *>(_gamma->ptr_to_element(Coordinates(0, 0))) : nullptr; + const auto input_beta = (_beta != nullptr) ? reinterpret_cast<const float *>(_beta->ptr_to_element(Coordinates(0, 0))) : nullptr; float32x4_t mean_vec = vdupq_n_f32(0.0); float32x4_t var_vec = vdupq_n_f32(0.0); - float32x4_t gamma_vec = vdupq_n_f32(0.0); + float32x4_t gamma_vec = vdupq_n_f32(1.0); float32x4_t beta_vec = vdupq_n_f32(0.0); float32x4_t denominator = vdupq_n_f32(0.0); const float32x4_t epsilon_vec = vdupq_n_f32(_epsilon); @@ -255,10 +291,16 @@ void NEBatchNormalizationLayerKernel::batch_normalization_fp32(const Window &win if(slice != id.z()) { // Conctruct vectors - mean_vec = vdupq_n_f32(*(input_mean + id.z())); - var_vec = vdupq_n_f32(*(input_var + id.z())); - gamma_vec = vdupq_n_f32(*(input_gamma + id.z())); - beta_vec = vdupq_n_f32(*(input_beta + id.z())); + mean_vec = vdupq_n_f32(*(input_mean + id.z())); + var_vec = vdupq_n_f32(*(input_var + id.z())); + if(input_gamma != nullptr) + { + gamma_vec = vdupq_n_f32(*(input_gamma + id.z())); + } + if(input_beta != nullptr) + { + beta_vec = vdupq_n_f32(*(input_beta + id.z())); + } // Calculate denominator denominator = vinvsqrtq_f32(vaddq_f32(var_vec, epsilon_vec)); @@ -335,21 +377,12 @@ void NEBatchNormalizationLayerKernel::configure(ITensor *input, ITensor *output, const ITensor *beta, const ITensor *gamma, float epsilon, ActivationLayerInfo act_info) { - ARM_COMPUTE_ERROR_ON_NULLPTR(input, mean, var, beta, gamma); + ARM_COMPUTE_ERROR_ON_NULLPTR(input, mean, var); - ITensorInfo *output_info = nullptr; - - if(nullptr != output) - { - // Output tensor auto initialization if not yet initialized - auto_init_if_empty(*output->info(), *input->info()); - - output_info = output->info(); - } - - ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), output_info, + ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), (output != nullptr) ? output->info() : nullptr, mean->info(), var->info(), - beta->info(), gamma->info(), + (beta != nullptr) ? beta->info() : nullptr, + (gamma != nullptr) ? gamma->info() : nullptr, epsilon, act_info)); _input = input; @@ -361,7 +394,8 @@ void NEBatchNormalizationLayerKernel::configure(ITensor *input, ITensor *output, _epsilon = epsilon; _act_info = act_info; - if(output != nullptr) + const bool run_in_place = (output == nullptr) || (output == input); + if(!run_in_place) { _output = output; } @@ -377,7 +411,7 @@ void NEBatchNormalizationLayerKernel::configure(ITensor *input, ITensor *output, } // Configure kernel window - auto win_config = validate_and_configure_window(input->info(), output_info); + auto win_config = validate_and_configure_window(input->info(), (run_in_place) ? nullptr : output->info()); ARM_COMPUTE_ERROR_THROW_ON(win_config.first); INEKernel::configure(win_config.second); } |