aboutsummaryrefslogtreecommitdiff
path: root/src/core/NEON/kernels/NEBatchNormalizationLayerKernel.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/NEON/kernels/NEBatchNormalizationLayerKernel.cpp')
-rw-r--r-- src/core/NEON/kernels/NEBatchNormalizationLayerKernel.cpp | 118
1 files changed, 1 insertions, 117 deletions
diff --git a/src/core/NEON/kernels/NEBatchNormalizationLayerKernel.cpp b/src/core/NEON/kernels/NEBatchNormalizationLayerKernel.cpp
index 6be50fdb0d..6aed41f3aa 100644
--- a/src/core/NEON/kernels/NEBatchNormalizationLayerKernel.cpp
+++ b/src/core/NEON/kernels/NEBatchNormalizationLayerKernel.cpp
@@ -43,7 +43,7 @@ validate_arguments(const ITensorInfo *input, const ITensorInfo *output, const IT
const ITensorInfo *beta, const ITensorInfo *gamma, float epsilon, ActivationLayerInfo act_info)
{
ARM_COMPUTE_UNUSED(epsilon);
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QS16, DataType::F16,
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F16,
DataType::F32);
if(act_info.enabled())
@@ -60,22 +60,18 @@ validate_arguments(const ITensorInfo *input, const ITensorInfo *output, const IT
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input, output);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_LAYOUT(input, output);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input, output);
}
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, mean, var);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input, mean, var);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(mean, var);
if(beta != nullptr)
{
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, beta);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input, beta);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(mean, beta);
}
if(gamma != nullptr)
{
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, gamma);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input, gamma);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(mean, gamma);
}
ARM_COMPUTE_RETURN_ERROR_ON(input->dimension(get_data_layout_dimension_index(input->data_layout(), DataLayoutDimension::CHANNEL)) != mean->dimension(0));
@@ -104,112 +100,6 @@ std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input, ITen
} //namespace
template <bool fused_activation>
-void NEBatchNormalizationLayerKernel::batch_normalization_qs8(const Window &window)
-{
- static_assert(!fused_activation, "Activation is not supported for QS8");
-
- Iterator input(_input, window);
- Iterator output(_output, window);
-
- // Hold information about the current feature map we are iterating.
- // Only compute denominator and NEON vectors once per feature map.
- int slice = -1;
-
- const int fixed_point_position = _input->info()->fixed_point_position();
- const auto input_mean = reinterpret_cast<const qint8_t *>(_mean->ptr_to_element(Coordinates(0, 0)));
- const auto input_var = reinterpret_cast<const qint8_t *>(_var->ptr_to_element(Coordinates(0, 0)));
- const auto input_gamma = (_gamma != nullptr) ? reinterpret_cast<const qint8_t *>(_gamma->ptr_to_element(Coordinates(0, 0))) : nullptr;
- const auto input_beta = (_beta != nullptr) ? reinterpret_cast<const qint8_t *>(_beta->ptr_to_element(Coordinates(0, 0))) : nullptr;
-
- qint8x16_t mean_vec = vdupq_n_qs8(0);
- qint8x16_t var_vec = vdupq_n_qs8(0);
- qint8x16_t gamma_vec = vdupq_n_qs8(sqcvt_qs8_f32(1, fixed_point_position));
- qint8x16_t beta_vec = vdupq_n_qs8(sqcvt_qs8_f32(0, fixed_point_position));
- qint8x16_t denominator = vdupq_n_qs8(0);
- const qint8x16_t epsilon_vec = vdupq_n_qs8(sqcvt_qs8_f32(_epsilon, fixed_point_position));
- execute_window_loop(window, [&](const Coordinates & id)
- {
- if(slice != id.z())
- {
- // Conctruct vectors
- mean_vec = vdupq_n_qs8(*(input_mean + id.z()));
- var_vec = vdupq_n_qs8(*(input_var + id.z()));
- if(input_gamma != nullptr)
- {
- gamma_vec = vdupq_n_qs8(*(input_gamma + id.z()));
- }
- if(input_beta != nullptr)
- {
- beta_vec = vdupq_n_qs8(*(input_beta + id.z()));
- }
-
- // Calculate denominator
- denominator = vqinvsqrtq_qs8(vqaddq_qs8(var_vec, epsilon_vec), fixed_point_position);
- slice = id.z();
- }
-
- // Calculate x bar and store results
- const qint8x16_t numerator = vqsubq_qs8(vld1q_qs8(reinterpret_cast<const qint8_t *>(input.ptr())), mean_vec);
- const qint8x16_t x_bar = vqmulq_qs8(numerator, denominator, fixed_point_position);
- vst1q_qs8(reinterpret_cast<qint8_t *>(output.ptr()), vqmlaq_qs8(beta_vec, x_bar, gamma_vec, fixed_point_position));
- },
- input, output);
-}
-
-template <bool fused_activation>
-void NEBatchNormalizationLayerKernel::batch_normalization_qs16(const Window &window)
-{
- static_assert(!fused_activation, "Activation is not supported for QS16");
-
- Iterator input(_input, window);
- Iterator output(_output, window);
-
- // Hold information about the current feature map we are iterating.
- // Only compute denominator and NEON vectors once per feature map.
- int slice = -1;
-
- const int fixed_point_position = _input->info()->fixed_point_position();
- const auto input_mean = reinterpret_cast<const qint16_t *>(_mean->ptr_to_element(Coordinates(0, 0)));
- const auto input_var = reinterpret_cast<const qint16_t *>(_var->ptr_to_element(Coordinates(0, 0)));
- const auto input_gamma = (_gamma != nullptr) ? reinterpret_cast<const qint16_t *>(_gamma->ptr_to_element(Coordinates(0, 0))) : nullptr;
- const auto input_beta = (_beta != nullptr) ? reinterpret_cast<const qint16_t *>(_beta->ptr_to_element(Coordinates(0, 0))) : nullptr;
-
- qint16x8_t mean_vec = vdupq_n_qs16(0);
- qint16x8_t var_vec = vdupq_n_qs16(0);
- qint16x8_t gamma_vec = vdupq_n_qs16(sqcvt_qs16_f32(1, fixed_point_position));
- qint16x8_t beta_vec = vdupq_n_qs16(sqcvt_qs16_f32(0, fixed_point_position));
- qint16x8_t denominator = vdupq_n_qs16(0);
- const qint16x8_t epsilon_vec = vdupq_n_qs16(sqcvt_qs16_f32(_epsilon, fixed_point_position));
- execute_window_loop(window, [&](const Coordinates & id)
- {
- if(slice != id.z())
- {
- // Conctruct vectors
- mean_vec = vdupq_n_qs16(*(input_mean + id.z()));
- var_vec = vdupq_n_qs16(*(input_var + id.z()));
- if(input_gamma != nullptr)
- {
- gamma_vec = vdupq_n_qs16(*(input_gamma + id.z()));
- }
- if(input_beta != nullptr)
- {
- beta_vec = vdupq_n_qs16(*(input_beta + id.z()));
- }
-
- // Calculate denominator
- denominator = vqinvsqrtq_qs16(vqaddq_qs16(var_vec, epsilon_vec), fixed_point_position);
- slice = id.z();
- }
-
- // Calculate x bar and store results
- const qint16x8_t numerator = vqsubq_qs16(vld1q_qs16(reinterpret_cast<const qint16_t *>(input.ptr())), mean_vec);
- const qint16x8_t x_bar = vqmulq_qs16(numerator, denominator, fixed_point_position);
- vst1q_qs16(reinterpret_cast<qint16_t *>(output.ptr()), vqmlaq_qs16(beta_vec, x_bar, gamma_vec, fixed_point_position));
- },
- input, output);
-}
-
-template <bool fused_activation>
void NEBatchNormalizationLayerKernel::batch_normalization_fp16_nchw(const Window &window)
{
static_assert(!fused_activation, "Activation is not supported for FP16");
@@ -406,12 +296,6 @@ void NEBatchNormalizationLayerKernel::configure_non_fused()
const bool is_nhwc = _input->info()->data_layout() == DataLayout::NHWC;
switch(_input->info()->data_type())
{
- case DataType::QS8:
- _func = &NEBatchNormalizationLayerKernel::batch_normalization_qs8<false>;
- break;
- case DataType::QS16:
- _func = &NEBatchNormalizationLayerKernel::batch_normalization_qs16<false>;
- break;
case DataType::F16:
_func = (is_nhwc) ? &NEBatchNormalizationLayerKernel::batch_normalization_fp16_nhwc<false> : &NEBatchNormalizationLayerKernel::batch_normalization_fp16_nchw<false>;
break;