diff options
author | Pablo Tello <pablo.tello@arm.com> | 2017-07-05 15:20:38 +0100 |
---|---|---|
committer | Anthony Barbier <anthony.barbier@arm.com> | 2018-09-17 14:16:42 +0100 |
commit | 8fda1cb6f4142133fff045a6f9c18778757c316c (patch) | |
tree | 3f0ad562b24cc3c76e8a745cb59cd584b664ec57 /src/core/NEON | |
parent | 8df3fafde3dcf131def3471db8e8b1a1c34b354b (diff) | |
download | ComputeLibrary-8fda1cb6f4142133fff045a6f9c18778757c316c.tar.gz |
COMPMID-421: Added FP16 support in BatchNormalizationLayer.
Change-Id: I7142e0e8466ef79e016ae56d285e8e9291573e52
Reviewed-on: http://mpd-gerrit.cambridge.arm.com/79814
Reviewed-by: Moritz Pflanzer <moritz.pflanzer@arm.com>
Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
Tested-by: Kaizen <jeremy.johnson+kaizengerrit@arm.com>
Diffstat (limited to 'src/core/NEON')
-rw-r--r-- | src/core/NEON/kernels/NEBatchNormalizationLayerKernel.cpp | 53 |
1 files changed, 52 insertions, 1 deletions
diff --git a/src/core/NEON/kernels/NEBatchNormalizationLayerKernel.cpp b/src/core/NEON/kernels/NEBatchNormalizationLayerKernel.cpp index d1adfa7aec..290a3c59ba 100644 --- a/src/core/NEON/kernels/NEBatchNormalizationLayerKernel.cpp +++ b/src/core/NEON/kernels/NEBatchNormalizationLayerKernel.cpp @@ -169,9 +169,54 @@ void batch_normalization_fp32(const ITensor *in, ITensor *out, const ITensor *me input, output); } +#ifdef ARM_COMPUTE_ENABLE_FP16 +void batch_normalization_fp16(const ITensor *in, ITensor *out, const ITensor *mean, const ITensor *var, const ITensor *beta, const ITensor *gamma, float epsilon, const Window &window) +{ + Iterator input(in, window); + Iterator output(out, window); + + // Hold information about the current feature map we are iterating. + // Only compute denominator and NEON vectors once per feature map. + int slice = -1; + + const auto input_mean = reinterpret_cast<const float16_t *>(mean->ptr_to_element(Coordinates(0, 0))); + const auto input_var = reinterpret_cast<const float16_t *>(var->ptr_to_element(Coordinates(0, 0))); + const auto input_gamma = reinterpret_cast<const float16_t *>(gamma->ptr_to_element(Coordinates(0, 0))); + const auto input_beta = reinterpret_cast<const float16_t *>(beta->ptr_to_element(Coordinates(0, 0))); + + float16x8_t mean_vec = vdupq_n_f16(0.0); + float16x8_t var_vec = vdupq_n_f16(0.0); + float16x8_t gamma_vec = vdupq_n_f16(0.0); + float16x8_t beta_vec = vdupq_n_f16(0.0); + float16x8_t denominator = vdupq_n_f16(0.0); + const float16x8_t epsilon_vec = vdupq_n_f16(epsilon); + execute_window_loop(window, [&](const Coordinates & id) + { + if(slice != id.z()) + { + // Conctruct vectors + mean_vec = vdupq_n_f16(*(input_mean + id.z())); + var_vec = vdupq_n_f16(*(input_var + id.z())); + gamma_vec = vdupq_n_f16(*(input_gamma + id.z())); + beta_vec = vdupq_n_f16(*(input_beta + id.z())); + + // Calculate denominator + denominator = vinvsqrtq_f16(vaddq_f16(var_vec, epsilon_vec)); + slice = id.z(); + } + + // Calculate x bar and store results + const float16x8_t numerator = vsubq_f16(vld1q_f16(reinterpret_cast<const float16_t *>(input.ptr())), mean_vec); + const float16x8_t x_bar = vmulq_f16(numerator, denominator); + vst1q_f16(reinterpret_cast<float16_t *>(output.ptr()), vaddq_f16(beta_vec, vmulq_f16(x_bar, gamma_vec))); + }, + input, output); +} +#endif /* ARM_COMPUTE_ENABLE_FP16 */ + void NEBatchNormalizationLayerKernel::configure(const ITensor *input, ITensor *output, const ITensor *mean, const ITensor *var, const ITensor *beta, const ITensor *gamma, float epsilon) { - ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QS16, DataType::F32); + ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QS16, DataType::F16, DataType::F32); ARM_COMPUTE_ERROR_ON_NULLPTR(output); // Output tensor auto initialization if not yet initialized @@ -207,6 +252,12 @@ void NEBatchNormalizationLayerKernel::configure(const ITensor *input, ITensor *o _func = &batch_normalization_fp32; num_elems_processed_per_iteration = 4; break; + case DataType::F16: +#ifdef ARM_COMPUTE_ENABLE_FP16 + _func = &batch_normalization_fp16; + num_elems_processed_per_iteration = 8; + break; +#endif /* ARM_COMPUTE_ENABLE_FP16 */ default: ARM_COMPUTE_ERROR("Element size not supported"); break; |