From 980a9168b81d778f4902973b4920b54c103907e0 Mon Sep 17 00:00:00 2001 From: Georgios Pinitas Date: Wed, 3 Jun 2020 20:16:46 +0100 Subject: COMPMID-3177: Remove padding from NEBatchNormalizationLayer Signed-off-by: Georgios Pinitas Change-Id: I9be23e6ef1f552eb159e39fda16c82fa20124094 Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/3307 Tested-by: Arm Jenkins Reviewed-by: Gian Marco Iodice Comments-Addressed: Arm Jenkins --- .../NEON/kernels/NEBatchNormalizationLayerKernel.h | 29 +- .../kernels/detail/NEActivationFunctionDetail.h | 108 ++++-- .../kernels/NEBatchNormalizationLayerKernel.cpp | 372 +++++++++------------ tests/validation/NEON/BatchNormalizationLayer.cpp | 43 +-- 4 files changed, 266 insertions(+), 286 deletions(-) diff --git a/arm_compute/core/NEON/kernels/NEBatchNormalizationLayerKernel.h b/arm_compute/core/NEON/kernels/NEBatchNormalizationLayerKernel.h index d59ed7baf0..7371e3c177 100644 --- a/arm_compute/core/NEON/kernels/NEBatchNormalizationLayerKernel.h +++ b/arm_compute/core/NEON/kernels/NEBatchNormalizationLayerKernel.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2019 ARM Limited. + * Copyright (c) 2017-2020 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -28,6 +28,7 @@ namespace arm_compute { +// Forward declarations class ITensor; /** Interface for the batch normalization layer kernel. @@ -97,40 +98,26 @@ private: /** Configure execution function in case of fused activation **/ void configure_fused(); - /** Template function to run batch normalization on fp16 - * - * @tparam fused_activation Boolean that flags if its a fused activation or not - * - * @param[in] window Region on which to execute the kernel. (Must be a valid region of the window returned by window()). - */ - template - void batch_normalization_fp16_nchw(const Window &window); - /** Template function to run batch normalization on fp16 on tensors with NHWC format - * - * @tparam fused_activation Boolean that flags if its a fused activation or not - * - * @param[in] window Region on which to execute the kernel. (Must be a valid region of the window returned by window()). - */ - template - void batch_normalization_fp16_nhwc(const Window &window); /** Template function to run batch normalization on fp32 * + * @tparam T Specialization data type * @tparam fused_activation Boolean that flags if its a fused activation or not * @tparam F Activation function functor to run * * @param[in] window Region on which to execute the kernel. (Must be a valid region of the window returned by window()). */ - template - void batch_normalization_fp32_nchw(const Window &window); + template + void batch_normalization_nchw(const Window &window); /** Template function to run batch normalization on fp32 on tensors with NHWC format * + * @tparam T Specialization data type * @tparam fused_activation Boolean that flags if its a fused activation or not * @tparam F Activation function functor to run * * @param[in] window Region on which to execute the kernel. (Must be a valid region of the window returned by window()). */ - template - void batch_normalization_fp32_nhwc(const Window &window); + template + void batch_normalization_nhwc(const Window &window); /** Common signature for all the batch normalization functions * * @param[in] window Region on which to execute the kernel. diff --git a/arm_compute/core/NEON/kernels/detail/NEActivationFunctionDetail.h b/arm_compute/core/NEON/kernels/detail/NEActivationFunctionDetail.h index 4861559695..7945418ac5 100644 --- a/arm_compute/core/NEON/kernels/detail/NEActivationFunctionDetail.h +++ b/arm_compute/core/NEON/kernels/detail/NEActivationFunctionDetail.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2019 ARM Limited. + * Copyright (c) 2018-2020 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -45,6 +45,7 @@ struct dummy { ARM_COMPUTE_UNUSED(act_info); } + /** Run activation function. * * @param[in] vval Vector of values. @@ -53,6 +54,15 @@ struct dummy { ARM_COMPUTE_UNUSED(vval); } + + /** Run activation function. + * + * @param[in] val Scalar value. + */ + void operator()(T &val) + { + ARM_COMPUTE_UNUSED(val); + } }; /** Linear activation object */ template @@ -68,8 +78,10 @@ struct linear * @param[in] act_info Activation layer information. */ explicit linear(ActivationLayerInfo act_info) - : valpha(wrapper::vdup_n(static_cast(act_info.a()), ExactTagType{})), - vbeta(wrapper::vdup_n(static_cast(act_info.b()), ExactTagType{})) + : alpha(act_info.a()), + beta(act_info.b()), + valpha(wrapper::vdup_n(static_cast(alpha), ExactTagType{})), + vbeta(wrapper::vdup_n(static_cast(beta), ExactTagType{})) { } @@ -79,13 +91,22 @@ struct linear */ void operator()(ExactType &vval) { - vval = wrapper::vmla(vval, valpha, vbeta); + vval = wrapper::vmla(vbeta, vval, valpha); } - /** Vector of alphas. */ - const ExactType valpha; - /** Vector of betas. */ - const ExactType vbeta; + /** Run activation function. + * + * @param[in] val Scalar value. + */ + void operator()(T &val) + { + val = alpha * val + beta; + } + + const T alpha; /**< Scalar alpha */ + const T beta; /**< Scalar alpha */ + const ExactType valpha; /**< Vector of alphas. */ + const ExactType vbeta; /**< Vector of betas. */ }; /** Square activation object */ template @@ -113,6 +134,15 @@ struct square { vval = wrapper::vmul(vval, vval); } + + /** Run activation function. + * + * @param[in] val Scalar value. + */ + void operator()(T &val) + { + val = val * val; + } }; /** Logistic activation object */ template @@ -128,7 +158,7 @@ struct logistic * @param[in] act_info Activation layer information. */ explicit logistic(ActivationLayerInfo act_info) - : vone(wrapper::vdup_n(static_cast(1.f), ExactTagType{})) + : vone(wrapper::vdup_n(static_cast(1), ExactTagType{})) { ARM_COMPUTE_UNUSED(act_info); } @@ -142,6 +172,15 @@ struct logistic vval = wrapper::vinv(wrapper::vadd(vone, wrapper::vexpq(wrapper::vneg(vval)))); } + /** Run activation function. + * + * @param[in] val Scalar value. + */ + void operator()(T &val) + { + val = 1 / (1 + std::exp(-val)); + } + /** Vector of ones. */ const ExactType vone; }; @@ -159,7 +198,7 @@ struct relu * @param[in] act_info Activation layer information. */ explicit relu(ActivationLayerInfo act_info) - : vzero(wrapper::vdup_n(static_cast(0.f), ExactTagType{})) + : vzero(wrapper::vdup_n(static_cast(0), ExactTagType{})) { ARM_COMPUTE_UNUSED(act_info); } @@ -173,6 +212,15 @@ struct relu vval = wrapper::vmax(vzero, vval); } + /** Run activation function. + * + * @param[in] val Scalar value. + */ + void operator()(T &val) + { + val = std::max(static_cast(0), val); + } + /** Vector of zeroes. */ const ExactType vzero; }; @@ -190,7 +238,8 @@ struct brelu * @param[in] act_info Activation layer information. */ explicit brelu(ActivationLayerInfo act_info) - : vzero(wrapper::vdup_n(static_cast(0.f), ExactTagType{})), + : alpha(act_info.a()), + vzero(wrapper::vdup_n(static_cast(0), ExactTagType{})), valpha(wrapper::vdup_n(static_cast(act_info.a()), ExactTagType{})) { } @@ -204,10 +253,18 @@ struct brelu vval = wrapper::vmin(valpha, wrapper::vmax(vzero, vval)); } - /** Vector of zeroes. */ - const ExactType vzero; - /** Vector of alphas. */ - const ExactType valpha; + /** Run activation function. + * + * @param[in] val Scalar value. + */ + void operator()(T &val) + { + val = std::min(alpha, std::max(static_cast(0), val)); + } + + const T alpha; /** Scalar alpha */ + const ExactType vzero; /** Vector of zeroes. */ + const ExactType valpha; /** Vector of alphas. */ }; /** Lower-Upper Bounded RELU activation object */ template @@ -223,7 +280,9 @@ struct lubrelu * @param[in] act_info Activation layer information. */ explicit lubrelu(ActivationLayerInfo act_info) - : valpha(wrapper::vdup_n(static_cast(act_info.a()), ExactTagType{})), + : alpha(act_info.a()), + beta(act_info.b()), + valpha(wrapper::vdup_n(static_cast(act_info.a()), ExactTagType{})), vbeta(wrapper::vdup_n(static_cast(act_info.b()), ExactTagType{})) { } @@ -237,10 +296,19 @@ struct lubrelu vval = wrapper::vmin(valpha, wrapper::vmax(vbeta, vval)); } - /** Vector of alphas. */ - const ExactType valpha; - /** Vector of betas. */ - const ExactType vbeta; + /** Run activation function. + * + * @param[in] val Scalar value. + */ + void operator()(T &val) + { + val = std::min(alpha, std::max(beta, val)); + } + + const T alpha; /**< Scalar alpha */ + const T beta; /**< Scalar alpha */ + const ExactType valpha; /** Vector of alphas. */ + const ExactType vbeta; /** Vector of betas. */ }; } // namespace detail } // namespace arm_compute diff --git a/src/core/NEON/kernels/NEBatchNormalizationLayerKernel.cpp b/src/core/NEON/kernels/NEBatchNormalizationLayerKernel.cpp index 6bd30ee845..3d84ce8449 100644 --- a/src/core/NEON/kernels/NEBatchNormalizationLayerKernel.cpp +++ b/src/core/NEON/kernels/NEBatchNormalizationLayerKernel.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2019 ARM Limited. + * Copyright (c) 2017-2020 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -33,10 +33,12 @@ #include "arm_compute/core/Validate.h" #include "arm_compute/core/Window.h" -#include +#include "arm_compute/core/NEON/wrapper/wrapper.h" -using namespace arm_compute; +#include +namespace arm_compute +{ namespace { Status @@ -82,56 +84,41 @@ validate_arguments(const ITensorInfo *input, const ITensorInfo *output, const IT std::pair validate_and_configure_window(ITensorInfo *input, ITensorInfo *output, ITensorInfo *mean, ITensorInfo *var, ITensorInfo *gamma, ITensorInfo *beta) { - if(output != nullptr) - { - // Output tensor auto initialization if not yet initialized - auto_init_if_empty(*output, *input->clone()); - } + ARM_COMPUTE_UNUSED(mean, var, gamma, beta); - unsigned int num_elems_processed_per_iteration = 16 / input->element_size(); - - Window win = calculate_max_window(*input, Steps(num_elems_processed_per_iteration)); - AccessWindowHorizontal input_access(input, 0, num_elems_processed_per_iteration); - bool window_changed = update_window_and_padding(win, input_access); + // Configure kernel window + Window win = calculate_max_window(*input, Steps()); if(output != nullptr) { - AccessWindowHorizontal output_access(output, 0, num_elems_processed_per_iteration); - window_changed |= update_window_and_padding(win, output_access); - output_access.set_valid_region(win, input->valid_region()); - } - - // Mean, var, gamma and beta get parallelized for the NHWC case as they follow the channel dimension, which is along the first axis - if(input->data_layout() == DataLayout::NHWC) - { - AccessWindowHorizontal mean_access(mean, 0, num_elems_processed_per_iteration); - AccessWindowHorizontal var_access(var, 0, num_elems_processed_per_iteration); - window_changed |= update_window_and_padding(win, mean_access, var_access); + // Output auto initialization if not yet initialized + auto_init_if_empty(*output, *input->clone()); - if(gamma != nullptr) - { - AccessWindowHorizontal gamma_access(gamma, 0, num_elems_processed_per_iteration); - window_changed |= update_window_and_padding(win, gamma_access); - } - if(beta != nullptr) - { - AccessWindowHorizontal beta_access(beta, 0, num_elems_processed_per_iteration); - window_changed |= update_window_and_padding(win, beta_access); - } + // NEBatchNormalizationLayerKernel doesn't need padding so update_window_and_padding() can be skipped + Coordinates coord; + coord.set_num_dimensions(output->num_dimensions()); + output->set_valid_region(ValidRegion(coord, output->tensor_shape())); } - Status err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Status{}; - return std::make_pair(err, win); + return std::make_pair(Status{}, win); } } //namespace -template -void NEBatchNormalizationLayerKernel::batch_normalization_fp16_nchw(const Window &window) +template +void NEBatchNormalizationLayerKernel::batch_normalization_nchw(const Window &window) { - ARM_COMPUTE_UNUSED(window); -#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC - Iterator input(_input, window); - Iterator output(_output, window); + /** NEON vector tag type. */ + using ExactTagType = typename wrapper::traits::neon_bitvector_tag_t; + + const int window_step_x = 16 / sizeof(T); + const auto window_start_x = static_cast(window.x().start()); + const auto window_end_x = static_cast(window.x().end()); + + Window win_to_use = window; + win_to_use.set(Window::DimX, Window::Dimension(0, 1, 1)); + + Iterator input(_input, win_to_use); + Iterator output(_output, win_to_use); F activation_functor(_act_info); @@ -139,196 +126,168 @@ void NEBatchNormalizationLayerKernel::batch_normalization_fp16_nchw(const Window // Only compute denominator and NEON vectors once per feature map. int slice = -1; - const auto input_mean = reinterpret_cast(_mean->ptr_to_element(Coordinates(0, 0))); - const auto input_var = reinterpret_cast(_var->ptr_to_element(Coordinates(0, 0))); - const auto input_gamma = (_gamma != nullptr) ? reinterpret_cast(_gamma->ptr_to_element(Coordinates(0, 0))) : nullptr; - const auto input_beta = (_beta != nullptr) ? reinterpret_cast(_beta->ptr_to_element(Coordinates(0, 0))) : nullptr; - - float16x8_t mean_vec = vdupq_n_f16(0.0); - float16x8_t var_vec = vdupq_n_f16(0.0); - float16x8_t gamma_vec = vdupq_n_f16(1.0); - float16x8_t beta_vec = vdupq_n_f16(0.0); - float16x8_t denominator = vdupq_n_f16(0.0); - const float16x8_t epsilon_vec = vdupq_n_f16(_epsilon); - execute_window_loop(window, [&](const Coordinates & id) + const auto input_mean = reinterpret_cast(_mean->ptr_to_element(Coordinates(0, 0))); + const auto input_var = reinterpret_cast(_var->ptr_to_element(Coordinates(0, 0))); + const auto input_gamma = (_gamma != nullptr) ? reinterpret_cast(_gamma->ptr_to_element(Coordinates(0, 0))) : nullptr; + const auto input_beta = (_beta != nullptr) ? reinterpret_cast(_beta->ptr_to_element(Coordinates(0, 0))) : nullptr; + + T mean = static_cast(0); + T var = static_cast(0); + T gamma = static_cast(1); + T beta = static_cast(0); + T denominator = static_cast(0); + + auto mean_vec = wrapper::vdup_n(mean, ExactTagType{}); + auto var_vec = wrapper::vdup_n(var, ExactTagType{}); + auto gamma_vec = wrapper::vdup_n(gamma, ExactTagType{}); + auto beta_vec = wrapper::vdup_n(beta, ExactTagType{}); + auto denominator_vec = wrapper::vdup_n(denominator, ExactTagType{}); + const auto epsilon_vec = wrapper::vdup_n(static_cast(_epsilon), ExactTagType{}); + execute_window_loop(win_to_use, [&](const Coordinates & id) { + const auto input_ptr = reinterpret_cast(input.ptr()); + const auto output_ptr = reinterpret_cast(output.ptr()); + if(slice != id.z()) { - // Conctruct vectors - mean_vec = vdupq_n_f16(*(input_mean + id.z())); - var_vec = vdupq_n_f16(*(input_var + id.z())); + mean = input_mean[id.z()]; + var = input_var[id.z()]; + mean_vec = wrapper::vdup_n(mean, ExactTagType{}); + var_vec = wrapper::vdup_n(var, ExactTagType{}); if(input_gamma != nullptr) { - gamma_vec = vdupq_n_f16(*(input_gamma + id.z())); + gamma = input_gamma[id.z()]; + gamma_vec = wrapper::vdup_n(gamma, ExactTagType{}); } if(input_beta != nullptr) { - beta_vec = vdupq_n_f16(*(input_beta + id.z())); + beta = input_beta[id.z()]; + beta_vec = wrapper::vdup_n(beta, ExactTagType{}); } // Calculate denominator - denominator = vinvsqrtq_f16(vaddq_f16(var_vec, epsilon_vec)); - slice = id.z(); + denominator_vec = wrapper::vinvsqrt(wrapper::vadd(var_vec, epsilon_vec)); + denominator = wrapper::vgetlane(denominator_vec, 0); + slice = id.z(); } - // Calculate x bar and store results - const float16x8_t numerator = vsubq_f16(vld1q_f16(reinterpret_cast(input.ptr())), mean_vec); - const float16x8_t x_bar = vmulq_f16(numerator, denominator); - float16x8_t res = vaddq_f16(beta_vec, vmulq_f16(x_bar, gamma_vec)); - - // Perform fused activation - if(fused_activation) + // Perform core calculations using vector operations + int x = window_start_x; + for(; x <= (window_end_x - window_step_x); x += window_step_x) { - activation_functor(res); - } + // Calculate x bar + const auto numerator = wrapper::vsub(wrapper::vloadq(input_ptr + x), mean_vec); + const auto x_bar = wrapper::vmul(numerator, denominator_vec); + auto res = wrapper::vmla(beta_vec, x_bar, gamma_vec); - vst1q_f16(reinterpret_cast(output.ptr()), res); - }, - input, output); -#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ -} + // Perform fused activation + if(fused_activation) + { + activation_functor(res); + } -template -void NEBatchNormalizationLayerKernel::batch_normalization_fp16_nhwc(const Window &window) -{ - ARM_COMPUTE_UNUSED(window); -#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC - Iterator input(_input, window); - Iterator output(_output, window); + // Store results + wrapper::vstore(output_ptr + x, res); + } - F activation_functor(_act_info); + // Compute left-over elements + for(; x < window_end_x; ++x) + { + const T numerator = input_ptr[x] - mean; + const T x_bar = numerator * denominator; + T res = beta + x_bar * gamma; - const auto input_mean = reinterpret_cast(_mean->ptr_to_element(Coordinates(0, 0))); - const auto input_var = reinterpret_cast(_var->ptr_to_element(Coordinates(0, 0))); - const auto input_gamma = (_gamma != nullptr) ? reinterpret_cast(_gamma->ptr_to_element(Coordinates(0, 0))) : nullptr; - const auto input_beta = (_beta != nullptr) ? reinterpret_cast(_beta->ptr_to_element(Coordinates(0, 0))) : nullptr; + // Perform fused activation + if(fused_activation) + { + activation_functor(res); + } - const float16x8_t epsilon_vec = vdupq_n_f16(_epsilon); - execute_window_loop(window, [&](const Coordinates & id) - { - // Conctruct vectors - const float16x8_t mean_vec = vld1q_f16(input_mean + id.x()); - const float16x8_t var_vec = vld1q_f16(input_var + id.x()); - const float16x8_t gamma_vec = (input_gamma != nullptr) ? vld1q_f16(input_gamma + id.x()) : vdupq_n_f16(1.0); - const float16x8_t beta_vec = (input_beta != nullptr) ? vld1q_f16(input_beta + id.x()) : vdupq_n_f16(0.0); - // Calculate denominator - const float16x8_t denominator = vinvsqrtq_f16(vaddq_f16(var_vec, epsilon_vec)); - - // Calculate x bar and store results - const float16x8_t numerator = vsubq_f16(vld1q_f16(reinterpret_cast(input.ptr())), mean_vec); - const float16x8_t x_bar = vmulq_f16(numerator, denominator); - float16x8_t res = vaddq_f16(beta_vec, vmulq_f16(x_bar, gamma_vec)); - - // Perform fused activation - if(fused_activation) - { - activation_functor(res); + // Store results + *(output_ptr + x) = res; } - - vst1q_f16(reinterpret_cast(output.ptr()), res); }, input, output); -#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ } -template -void NEBatchNormalizationLayerKernel::batch_normalization_fp32_nchw(const Window &window) +template +void NEBatchNormalizationLayerKernel::batch_normalization_nhwc(const Window &window) { - Iterator input(_input, window); - Iterator output(_output, window); + /** NEON vector tag type. */ + using ExactTagType = typename wrapper::traits::neon_bitvector_tag_t; + + const int window_step_x = 16 / sizeof(T); + const auto window_start_x = static_cast(window.x().start()); + const auto window_end_x = static_cast(window.x().end()); + + Window win_collapsed = window.collapse_if_possible(window, Window::DimZ); + win_collapsed.set(Window::DimX, Window::Dimension(0, 1, 1)); + + Iterator input(_input, win_collapsed); + Iterator output(_output, win_collapsed); F activation_functor(_act_info); - // Hold information about the current feature map we are iterating. - // Only compute denominator and NEON vectors once per feature map. - int slice = -1; + const auto input_mean = reinterpret_cast(_mean->ptr_to_element(Coordinates(0, 0))); + const auto input_var = reinterpret_cast(_var->ptr_to_element(Coordinates(0, 0))); + const auto input_gamma = (_gamma != nullptr) ? reinterpret_cast(_gamma->ptr_to_element(Coordinates(0, 0))) : nullptr; + const auto input_beta = (_beta != nullptr) ? reinterpret_cast(_beta->ptr_to_element(Coordinates(0, 0))) : nullptr; - const auto input_mean = reinterpret_cast(_mean->ptr_to_element(Coordinates(0, 0))); - const auto input_var = reinterpret_cast(_var->ptr_to_element(Coordinates(0, 0))); - const auto input_gamma = (_gamma != nullptr) ? reinterpret_cast(_gamma->ptr_to_element(Coordinates(0, 0))) : nullptr; - const auto input_beta = (_beta != nullptr) ? reinterpret_cast(_beta->ptr_to_element(Coordinates(0, 0))) : nullptr; - - float32x4_t mean_vec = vdupq_n_f32(0.0); - float32x4_t var_vec = vdupq_n_f32(0.0); - float32x4_t gamma_vec = vdupq_n_f32(1.0); - float32x4_t beta_vec = vdupq_n_f32(0.0); - float32x4_t denominator = vdupq_n_f32(0.0); - const float32x4_t epsilon_vec = vdupq_n_f32(_epsilon); - execute_window_loop(window, [&](const Coordinates & id) + const auto epsilon_vec = wrapper::vdup_n(static_cast(_epsilon), ExactTagType{}); + execute_window_loop(win_collapsed, [&](const Coordinates &) { - if(slice != id.z()) + const auto input_ptr = reinterpret_cast(input.ptr()); + const auto output_ptr = reinterpret_cast(output.ptr()); + + // Perform core calculations using vector operations + int x = window_start_x; + for(; x <= (window_end_x - window_step_x); x += window_step_x) { // Conctruct vectors - mean_vec = vdupq_n_f32(*(input_mean + id.z())); - var_vec = vdupq_n_f32(*(input_var + id.z())); - if(input_gamma != nullptr) - { - gamma_vec = vdupq_n_f32(*(input_gamma + id.z())); - } - if(input_beta != nullptr) - { - beta_vec = vdupq_n_f32(*(input_beta + id.z())); - } + const auto mean_vec = wrapper::vloadq(input_mean + x); + const auto var_vec = wrapper::vloadq(input_var + x); + const auto gamma_vec = (input_gamma != nullptr) ? wrapper::vloadq(input_gamma + x) : wrapper::vdup_n(static_cast(1.f), ExactTagType{}); + const auto beta_vec = (input_beta != nullptr) ? wrapper::vloadq(input_beta + x) : wrapper::vdup_n(static_cast(0.f), ExactTagType{}); // Calculate denominator - denominator = vinvsqrtq_f32(vaddq_f32(var_vec, epsilon_vec)); - slice = id.z(); - } + const auto denominator = wrapper::vinvsqrt(wrapper::vadd(var_vec, epsilon_vec)); - // Calculate x bar - const float32x4_t numerator = vsubq_f32(vld1q_f32(reinterpret_cast(input.ptr())), mean_vec); - const float32x4_t x_bar = vmulq_f32(numerator, denominator); - float32x4_t res = vmlaq_f32(beta_vec, x_bar, gamma_vec); + // Calculate x bar + const auto numerator = wrapper::vsub(wrapper::vloadq(input_ptr + x), mean_vec); + const auto x_bar = wrapper::vmul(numerator, denominator); + auto res = wrapper::vmla(beta_vec, x_bar, gamma_vec); - // Perform fused activation - if(fused_activation) - { - activation_functor(res); - } + // Perform fused activation + if(fused_activation) + { + activation_functor(res); + } - // Store results - vst1q_f32(reinterpret_cast(output.ptr()), res); - }, - input, output); -} + // Store results + wrapper::vstore(output_ptr + x, res); + } -template -void NEBatchNormalizationLayerKernel::batch_normalization_fp32_nhwc(const Window &window) -{ - Iterator input(_input, window); - Iterator output(_output, window); + // Compute left-over elements + for(; x < window_end_x; ++x) + { + // Conctruct vectors + const T gamma = (input_gamma != nullptr) ? input_gamma[x] : 1.f; + const T beta = (input_beta != nullptr) ? input_beta[x] : 0.f; - F activation_functor(_act_info); + const T denominator = sqrt(input_var[x] + _epsilon); + const T numerator = input_ptr[x] - input_mean[x]; + const T x_bar = numerator / denominator; + T res = beta + x_bar * gamma; - const auto input_mean = reinterpret_cast(_mean->ptr_to_element(Coordinates(0, 0))); - const auto input_var = reinterpret_cast(_var->ptr_to_element(Coordinates(0, 0))); - const auto input_gamma = (_gamma != nullptr) ? reinterpret_cast(_gamma->ptr_to_element(Coordinates(0, 0))) : nullptr; - const auto input_beta = (_beta != nullptr) ? reinterpret_cast(_beta->ptr_to_element(Coordinates(0, 0))) : nullptr; + // Perform fused activation + if(fused_activation) + { + activation_functor(res); + } - const float32x4_t epsilon_vec = vdupq_n_f32(_epsilon); - execute_window_loop(window, [&](const Coordinates & id) - { - // Conctruct vectors - const float32x4_t mean_vec = vld1q_f32(input_mean + id.x()); - const float32x4_t var_vec = vld1q_f32(input_var + id.x()); - const float32x4_t gamma_vec = (input_gamma != nullptr) ? vld1q_f32(input_gamma + id.x()) : vdupq_n_f32(1.0); - const float32x4_t beta_vec = (input_beta != nullptr) ? vld1q_f32(input_beta + id.x()) : vdupq_n_f32(0.0); - // Calculate denominator - const float32x4_t denominator = vinvsqrtq_f32(vaddq_f32(var_vec, epsilon_vec)); - - // Calculate x bar - const float32x4_t numerator = vsubq_f32(vld1q_f32(reinterpret_cast(input.ptr())), mean_vec); - const float32x4_t x_bar = vmulq_f32(numerator, denominator); - float32x4_t res = vmlaq_f32(beta_vec, x_bar, gamma_vec); - - // Perform fused activation - if(fused_activation) - { - activation_functor(res); + // Store results + *reinterpret_cast(output_ptr + x) = res; } - - // Store results - vst1q_f32(reinterpret_cast(output.ptr()), res); }, input, output); } @@ -340,13 +299,13 @@ void NEBatchNormalizationLayerKernel::configure_non_fused() { #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC case DataType::F16: - _func = (is_nhwc) ? &NEBatchNormalizationLayerKernel::batch_normalization_fp16_nhwc> : - &NEBatchNormalizationLayerKernel::batch_normalization_fp16_nchw>; + _func = (is_nhwc) ? &NEBatchNormalizationLayerKernel::batch_normalization_nhwc> : + &NEBatchNormalizationLayerKernel::batch_normalization_nchw>; break; #endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC case DataType::F32: - _func = (is_nhwc) ? &NEBatchNormalizationLayerKernel::batch_normalization_fp32_nhwc> : - &NEBatchNormalizationLayerKernel::batch_normalization_fp32_nchw>; + _func = (is_nhwc) ? &NEBatchNormalizationLayerKernel::batch_normalization_nhwc> : + &NEBatchNormalizationLayerKernel::batch_normalization_nchw>; break; default: ARM_COMPUTE_ERROR("Element size not supported"); @@ -359,31 +318,31 @@ void NEBatchNormalizationLayerKernel::configure_fused() // NCHW Fused Batched Normalization with activation functions : FP32 static std::map bn_fused_map_f32_nchw = { - { ActivationLayerInfo::ActivationFunction::RELU, &NEBatchNormalizationLayerKernel::batch_normalization_fp32_nchw> }, - { ActivationLayerInfo::ActivationFunction::BOUNDED_RELU, &NEBatchNormalizationLayerKernel::batch_normalization_fp32_nchw> }, - { ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, &NEBatchNormalizationLayerKernel::batch_normalization_fp32_nchw> } + { ActivationLayerInfo::ActivationFunction::RELU, &NEBatchNormalizationLayerKernel::batch_normalization_nchw> }, + { ActivationLayerInfo::ActivationFunction::BOUNDED_RELU, &NEBatchNormalizationLayerKernel::batch_normalization_nchw> }, + { ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, &NEBatchNormalizationLayerKernel::batch_normalization_nchw> } }; // NHWC Fused Batched Normalization with activation functions : FP32 static std::map bn_fused_map_f32_nhwc = { - { ActivationLayerInfo::ActivationFunction::RELU, &NEBatchNormalizationLayerKernel::batch_normalization_fp32_nhwc> }, - { ActivationLayerInfo::ActivationFunction::BOUNDED_RELU, &NEBatchNormalizationLayerKernel::batch_normalization_fp32_nhwc> }, - { ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, &NEBatchNormalizationLayerKernel::batch_normalization_fp32_nhwc> } + { ActivationLayerInfo::ActivationFunction::RELU, &NEBatchNormalizationLayerKernel::batch_normalization_nhwc> }, + { ActivationLayerInfo::ActivationFunction::BOUNDED_RELU, &NEBatchNormalizationLayerKernel::batch_normalization_nhwc> }, + { ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, &NEBatchNormalizationLayerKernel::batch_normalization_nhwc> } }; #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC // NCHW Fused Batched Normalization with activation functions : FP16 static std::map bn_fused_map_f16_nchw = { - { ActivationLayerInfo::ActivationFunction::RELU, &NEBatchNormalizationLayerKernel::batch_normalization_fp16_nchw> }, - { ActivationLayerInfo::ActivationFunction::BOUNDED_RELU, &NEBatchNormalizationLayerKernel::batch_normalization_fp16_nchw> }, - { ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, &NEBatchNormalizationLayerKernel::batch_normalization_fp16_nchw> } + { ActivationLayerInfo::ActivationFunction::RELU, &NEBatchNormalizationLayerKernel::batch_normalization_nchw> }, + { ActivationLayerInfo::ActivationFunction::BOUNDED_RELU, &NEBatchNormalizationLayerKernel::batch_normalization_nchw> }, + { ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, &NEBatchNormalizationLayerKernel::batch_normalization_nchw> } }; // NHWC Fused Batched Normalization with activation functions : FP16 static std::map bn_fused_map_f16_nhwc = { - { ActivationLayerInfo::ActivationFunction::RELU, &NEBatchNormalizationLayerKernel::batch_normalization_fp16_nhwc> }, - { ActivationLayerInfo::ActivationFunction::BOUNDED_RELU, &NEBatchNormalizationLayerKernel::batch_normalization_fp16_nhwc> }, - { ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, &NEBatchNormalizationLayerKernel::batch_normalization_fp16_nhwc> } + { ActivationLayerInfo::ActivationFunction::RELU, &NEBatchNormalizationLayerKernel::batch_normalization_nhwc> }, + { ActivationLayerInfo::ActivationFunction::BOUNDED_RELU, &NEBatchNormalizationLayerKernel::batch_normalization_nhwc> }, + { ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, &NEBatchNormalizationLayerKernel::batch_normalization_nhwc> } }; #endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC @@ -475,3 +434,4 @@ void NEBatchNormalizationLayerKernel::run(const Window &window, const ThreadInfo (this->*_func)(window); } +} // namespace arm_compute diff --git a/tests/validation/NEON/BatchNormalizationLayer.cpp b/tests/validation/NEON/BatchNormalizationLayer.cpp index 58b7474b41..6075e6be8d 100644 --- a/tests/validation/NEON/BatchNormalizationLayer.cpp +++ b/tests/validation/NEON/BatchNormalizationLayer.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2019 ARM Limited. + * Copyright (c) 2017-2020 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -71,69 +71,34 @@ TEST_SUITE(BatchNormalizationLayer) template using NEBatchNormalizationLayerFixture = BatchNormalizationLayerValidationFixture; -DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(combine(combine(datasets::SmallRandomBatchNormalizationLayerDataset(), - combine(framework::dataset::make("UseBeta", { false, true }), framework::dataset::make("UseGamma", { false, true }))), - framework::dataset::make("DataType", { DataType::F32 })), - framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })), - shape0, shape1, epsilon, use_beta, use_gamma, dt, data_layout) -{ - TensorShape src_dst_shapes = shape0; - if(data_layout == DataLayout::NHWC) - { - permute(src_dst_shapes, PermutationVector(2U, 0U, 1U)); - } - - // Create tensors - Tensor src = create_tensor(src_dst_shapes, dt, 1, QuantizationInfo(), data_layout); - Tensor dst = create_tensor(src_dst_shapes, dt, 1, QuantizationInfo(), data_layout); - Tensor mean = create_tensor(shape1, dt, 1); - Tensor var = create_tensor(shape1, dt, 1); - Tensor beta = create_tensor(shape1, dt, 1); - Tensor gamma = create_tensor(shape1, dt, 1); - - // Create and Configure function - NEBatchNormalizationLayer norm; - Tensor *beta_ptr = use_beta ? &beta : nullptr; - Tensor *gamma_ptr = use_gamma ? &gamma : nullptr; - norm.configure(&src, &dst, &mean, &var, beta_ptr, gamma_ptr, epsilon); - - // Validate valid region - const ValidRegion valid_region = shape_to_valid_region(src_dst_shapes); - validate(dst.info()->valid_region(), valid_region); -} - // *INDENT-OFF* // clang-format off DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(zip( - framework::dataset::make("InputInfo", { TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32), - TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::F32), // Window shrink + framework::dataset::make("InputInfo", { TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::F32), TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32), // Mismatching data types TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32), // Mismatching data types TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32), // Invalid mean/var/beta/gamma shape TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32), // Fused activation's a < b }), - framework::dataset::make("OutputInfo",{ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32), - TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::F32), + framework::dataset::make("OutputInfo",{ TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::F32), TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32), TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F16), TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32), TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32), })), framework::dataset::make("MVBGInfo",{ TensorInfo(TensorShape(2U), 1, DataType::F32), - TensorInfo(TensorShape(2U), 1, DataType::F32), TensorInfo(TensorShape(2U), 1, DataType::F16), TensorInfo(TensorShape(2U), 1, DataType::F32), TensorInfo(TensorShape(5U), 1, DataType::F32), TensorInfo(TensorShape(2U), 1, DataType::F32), })), framework::dataset::make("ActivationLayerInfo",{ ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU), - ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU), ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::BOUNDED_RELU, 6.f), ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::BOUNDED_RELU, 6.f), ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, 6.f), ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, 2.f, 6.f), })), - framework::dataset::make("Expected", { true, false, false, false, false, false})), + framework::dataset::make("Expected", { true, false, false, false, false})), input_info, output_info, mvbg_info, act_info, expected) { const auto &mean_info = mvbg_info; -- cgit v1.2.1