From 8fda1cb6f4142133fff045a6f9c18778757c316c Mon Sep 17 00:00:00 2001
From: Pablo Tello
Date: Wed, 5 Jul 2017 15:20:38 +0100
Subject: COMPMID-421: Added FP16 support in BatchNormalizationLayer.

Change-Id: I7142e0e8466ef79e016ae56d285e8e9291573e52
Reviewed-on: http://mpd-gerrit.cambridge.arm.com/79814
Reviewed-by: Moritz Pflanzer
Reviewed-by: Anthony Barbier
Tested-by: Kaizen
---
 arm_compute/core/NEON/NEMath.h                     | 10 +++
 arm_compute/core/NEON/NEMath.inl                   |  3 +-
 .../kernels/NEBatchNormalizationLayerKernel.cpp    | 53 +++++++++++-
 tests/NEON/Helper.h                                | 16 ++++
 tests/validation/Helpers.h                         | 23 +++++
 tests/validation/NEON/BatchNormalizationLayer.cpp  | 98 +++++++++++++++-------
 tests/validation/Reference.cpp                     | 67 ++++++++-------
 7 files changed, 208 insertions(+), 62 deletions(-)

diff --git a/arm_compute/core/NEON/NEMath.h b/arm_compute/core/NEON/NEMath.h
index b467a600d6..39f0c3bf77 100644
--- a/arm_compute/core/NEON/NEMath.h
+++ b/arm_compute/core/NEON/NEMath.h
@@ -36,6 +36,16 @@ namespace arm_compute
  */
 float32x4_t vinvsqrtq_f32(float32x4_t x);
 
+#ifdef ARM_COMPUTE_ENABLE_FP16
+/** Calculate inverse square root.
+ *
+ * @param[in] x Input value.
+ *
+ * @return The calculated inverse square root.
+ */
+float16x8_t vinvsqrtq_f16(float16x8_t x);
+#endif /* ARM_COMPUTE_ENABLE_FP16 */
+
 /** Calculate reciprocal.
  *
  * @param[in] x Input value.
diff --git a/arm_compute/core/NEON/NEMath.inl b/arm_compute/core/NEON/NEMath.inl
index 1d90029147..08f6749ac9 100644
--- a/arm_compute/core/NEON/NEMath.inl
+++ b/arm_compute/core/NEON/NEMath.inl
@@ -141,7 +141,6 @@ inline float32x4_t vpowq_f32(float32x4_t val, float32x4_t n)
 {
     return vexpq_f32(vmulq_f32(n, vlogq_f32(val)));
 }
-
 #ifdef ARM_COMPUTE_ENABLE_FP16
 /* Exponent polynomial coefficients */
 const std::array<float16x8_t, 8> exp_tab_f16 =
@@ -172,12 +171,12 @@ const std::array<float16x8_t, 8> log_tab_f16 =
         vdupq_n_f16(0.0141278216615f),
     }
 };
+
 inline float16x8_t vinvsqrtq_f16(float16x8_t x)
 {
     float16x8_t sqrt_reciprocal = vrsqrteq_f16(x);
     sqrt_reciprocal             = vmulq_f16(vrsqrtsq_f16(vmulq_f16(x, sqrt_reciprocal), sqrt_reciprocal), sqrt_reciprocal);
     sqrt_reciprocal             = vmulq_f16(vrsqrtsq_f16(vmulq_f16(x, sqrt_reciprocal), sqrt_reciprocal), sqrt_reciprocal);
-
     return sqrt_reciprocal;
 }
 
diff --git a/src/core/NEON/kernels/NEBatchNormalizationLayerKernel.cpp b/src/core/NEON/kernels/NEBatchNormalizationLayerKernel.cpp
index d1adfa7aec..290a3c59ba 100644
--- a/src/core/NEON/kernels/NEBatchNormalizationLayerKernel.cpp
+++ b/src/core/NEON/kernels/NEBatchNormalizationLayerKernel.cpp
@@ -169,9 +169,54 @@ void batch_normalization_fp32(const ITensor *in, ITensor *out, const ITensor *me
     input, output);
 }
 
+#ifdef ARM_COMPUTE_ENABLE_FP16
+void batch_normalization_fp16(const ITensor *in, ITensor *out, const ITensor *mean, const ITensor *var, const ITensor *beta, const ITensor *gamma, float epsilon, const Window &window)
+{
+    Iterator input(in, window);
+    Iterator output(out, window);
+
+    // Hold information about the current feature map we are iterating.
+    // Only compute denominator and NEON vectors once per feature map.
+    int slice = -1;
+
+    const auto input_mean  = reinterpret_cast<const float16_t *>(mean->ptr_to_element(Coordinates(0, 0)));
+    const auto input_var   = reinterpret_cast<const float16_t *>(var->ptr_to_element(Coordinates(0, 0)));
+    const auto input_gamma = reinterpret_cast<const float16_t *>(gamma->ptr_to_element(Coordinates(0, 0)));
+    const auto input_beta  = reinterpret_cast<const float16_t *>(beta->ptr_to_element(Coordinates(0, 0)));
+
+    float16x8_t       mean_vec    = vdupq_n_f16(0.0);
+    float16x8_t       var_vec     = vdupq_n_f16(0.0);
+    float16x8_t       gamma_vec   = vdupq_n_f16(0.0);
+    float16x8_t       beta_vec    = vdupq_n_f16(0.0);
+    float16x8_t       denominator = vdupq_n_f16(0.0);
+    const float16x8_t epsilon_vec = vdupq_n_f16(epsilon);
+    execute_window_loop(window, [&](const Coordinates & id)
+    {
+        if(slice != id.z())
+        {
+            // Construct vectors
+            mean_vec  = vdupq_n_f16(*(input_mean + id.z()));
+            var_vec   = vdupq_n_f16(*(input_var + id.z()));
+            gamma_vec = vdupq_n_f16(*(input_gamma + id.z()));
+            beta_vec  = vdupq_n_f16(*(input_beta + id.z()));
+
+            // Calculate denominator
+            denominator = vinvsqrtq_f16(vaddq_f16(var_vec, epsilon_vec));
+            slice       = id.z();
+        }
+
+        // Calculate x bar and store results
+        const float16x8_t numerator = vsubq_f16(vld1q_f16(reinterpret_cast<const float16_t *>(input.ptr())), mean_vec);
+        const float16x8_t x_bar     = vmulq_f16(numerator, denominator);
+        vst1q_f16(reinterpret_cast<float16_t *>(output.ptr()), vaddq_f16(beta_vec, vmulq_f16(x_bar, gamma_vec)));
+    },
+    input, output);
+}
+#endif /* ARM_COMPUTE_ENABLE_FP16 */
+
 void NEBatchNormalizationLayerKernel::configure(const ITensor *input, ITensor *output, const ITensor *mean, const ITensor *var, const ITensor *beta, const ITensor *gamma, float epsilon)
 {
-    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QS16, DataType::F32);
+    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QS16, DataType::F16, DataType::F32);
     ARM_COMPUTE_ERROR_ON_NULLPTR(output);
 
     // Output tensor auto initialization if not yet initialized
@@ -207,6 +252,12 @@ void NEBatchNormalizationLayerKernel::configure(const ITensor *input, ITensor *o
             _func                             = &batch_normalization_fp32;
             num_elems_processed_per_iteration = 4;
             break;
+        case DataType::F16:
+#ifdef ARM_COMPUTE_ENABLE_FP16
+            _func                             = &batch_normalization_fp16;
+            num_elems_processed_per_iteration = 8;
+            break;
+#endif /* ARM_COMPUTE_ENABLE_FP16 */
         default:
             ARM_COMPUTE_ERROR("Element size not supported");
             break;
diff --git a/tests/NEON/Helper.h b/tests/NEON/Helper.h
index 5b0f750fca..9e60a48ecd 100644
--- a/tests/NEON/Helper.h
+++ b/tests/NEON/Helper.h
@@ -25,8 +25,10 @@
 #define __ARM_COMPUTE_TEST_NEON_HELPER_H__
 
 #include "arm_compute/runtime/Array.h"
+#include "tests/Globals.h"
 
 #include <algorithm>
+#include <array>
 #include <vector>
 
 namespace arm_compute
@@ -44,6 +46,20 @@ Array<T> create_array(const std::vector<T> &v)
 
     return array;
 }
+
+template <typename D, typename T, typename... Ts>
+void fill_tensors(D &&dist, std::initializer_list<int> seeds, T &&tensor, Ts &&... other_tensors)
+{
+    const std::array < T, 1 + sizeof...(Ts) > tensors{ { std::forward<T>(tensor), std::forward<Ts>(other_tensors)... } };
+    std::vector<int> vs(seeds);
+    ARM_COMPUTE_ERROR_ON(vs.size() != tensors.size());
+    int k = 0;
+    for(auto tp : tensors)
+    {
+        library->fill(Accessor(*tp), std::forward<D>(dist), vs[k++]);
+    }
+}
+
 } // namespace test
 } // namespace arm_compute
 #endif /* __ARM_COMPUTE_TEST_NEON_HELPER_H__ */
diff --git a/tests/validation/Helpers.h b/tests/validation/Helpers.h
index 2793c22147..8d70de6958 100644
--- a/tests/validation/Helpers.h
+++ b/tests/validation/Helpers.h
@@ -25,11 +25,13 @@
 #define __ARM_COMPUTE_TEST_VALIDATION_HELPERS_H__
 
 #include "arm_compute/core/Types.h"
+#include "tests/Globals.h"
 #include "tests/ILutAccessor.h"
 #include "tests/Types.h"
 #include "tests/validation/ValidationUserConfiguration.h"
 #include "tests/validation/half.h"
 
+#include <array>
 #include <random>
 #include <type_traits>
 #include <utility>
@@ -41,6 +43,27 @@ namespace test
 {
 namespace validation
 {
+/** Helper function to fill one or more tensors with values drawn from the given distribution.
+ *
+ * @param[in]     dist          Distribution to be used to get the values for the tensor.
+ * @param[in]     seeds         List of seeds to be used to fill each tensor.
+ * @param[in,out] tensor        Tensor to be initialized with the values of the distribution.
+ * @param[in,out] other_tensors (Optional) One or more tensors to be filled.
+ *
+ */
+template <typename D, typename T, typename... Ts>
+void fill_tensors(D &&dist, std::initializer_list<int> seeds, T &&tensor, Ts &&... other_tensors)
+{
+    const std::array < T, 1 + sizeof...(Ts) > tensors{ { std::forward<T>(tensor), std::forward<Ts>(other_tensors)... } };
+    std::vector<int> vs(seeds);
+    ARM_COMPUTE_ERROR_ON(vs.size() != tensors.size());
+    int k = 0;
+    for(auto tp : tensors)
+    {
+        library->fill(*tp, std::forward<D>(dist), vs[k++]);
+    }
+}
+
 /** Helper function to get the testing range for each activation layer.
  *
  * @param[in] activation Activation function to test.
diff --git a/tests/validation/NEON/BatchNormalizationLayer.cpp b/tests/validation/NEON/BatchNormalizationLayer.cpp
index 279257d071..9898beb7db 100644
--- a/tests/validation/NEON/BatchNormalizationLayer.cpp
+++ b/tests/validation/NEON/BatchNormalizationLayer.cpp
@@ -25,6 +25,7 @@
 #include "TypePrinter.h"
 #include "dataset/BatchNormalizationLayerDataset.h"
 #include "tests/Globals.h"
+#include "tests/NEON/Helper.h"
 #include "tests/Utils.h"
 #include "tests/validation/Helpers.h"
 #include "validation/Datasets.h"
@@ -41,9 +42,12 @@ using namespace arm_compute::test::validation;
 
 namespace
 {
-const float tolerance_f    = 1e-05; /**< Tolerance value for comparing reference's output against floating point implementation's output */
-const float tolerance_qs8  = 6;     /**< Tolerance value for comparing reference's output against quantized implementation's output */
-const float tolerance_qs16 = 6;     /**< Tolerance value for comparing reference's output against quantized implementation's output */
+const float tolerance_qs8  = 6;      /**< Tolerance value for comparing reference's output against quantized implementation's output */
+const float tolerance_qs16 = 6;      /**< Tolerance value for comparing reference's output against quantized implementation's output */
+const float tolerance_f32  = 1e-05f; /**< Tolerance value for comparing reference's output against floating point implementation's output */
+#ifdef ARM_COMPUTE_ENABLE_FP16
+const float tolerance_f16 = 0.01f; /**< Tolerance value for comparing reference's output against half precision floating point implementation's output */
+#endif /* ARM_COMPUTE_ENABLE_FP16 */
 
 /** Compute Neon batch normalization function.
  *
@@ -83,38 +87,51 @@ Tensor compute_reference_batch_normalization_layer(const TensorShape &shape0, co
     BOOST_TEST(!gamma.info()->is_resizable());
 
     // Fill tensors
-    if(dt == DataType::F32)
+    switch(dt)
     {
-        float min_bound = 0.f;
-        float max_bound = 0.f;
-        std::tie(min_bound, max_bound) = get_batchnormalization_layer_test_bounds<float>();
-        std::uniform_real_distribution<> distribution(min_bound, max_bound);
-        std::uniform_real_distribution<> distribution_var(0, max_bound);
-        library->fill(Accessor(src), distribution, 0);
-        library->fill(Accessor(mean), distribution, 1);
-        library->fill(Accessor(var), distribution_var, 0);
-        library->fill(Accessor(beta), distribution, 3);
-        library->fill(Accessor(gamma), distribution, 4);
-    }
-    else
-    {
-        int min_bound = 0;
-        int max_bound = 0;
-        if(dt == DataType::QS8)
+        case DataType::QS8:
+        {
+            const std::pair<int8_t, int8_t> bounds = get_batchnormalization_layer_test_bounds<int8_t>(fixed_point_position);
+            std::uniform_int_distribution<> distribution(bounds.first, bounds.second);
+            std::uniform_int_distribution<> distribution_var(0, bounds.second);
+            test::fill_tensors(distribution, { 0, 1, 3, 4 }, &src, &mean, &beta, &gamma);
+            test::fill_tensors(distribution_var, { 0 }, &var);
+            break;
+        }
+        case DataType::QS16:
+        {
+            const std::pair<int16_t, int16_t> bounds = get_batchnormalization_layer_test_bounds<int16_t>(fixed_point_position);
+            std::uniform_int_distribution<> distribution(bounds.first, bounds.second);
+            std::uniform_int_distribution<> distribution_var(0, bounds.second);
+            test::fill_tensors(distribution, { 0, 1, 3, 4 }, &src, &mean, &beta, &gamma);
+            test::fill_tensors(distribution_var, { 0 }, &var);
+            break;
+        }
+#ifdef ARM_COMPUTE_ENABLE_FP16
+        case DataType::F16:
         {
-            std::tie(min_bound, max_bound) = get_batchnormalization_layer_test_bounds<int8_t>(fixed_point_position);
+            const std::pair<float, float> bounds = get_batchnormalization_layer_test_bounds<float>();
+            std::uniform_real_distribution<> distribution(bounds.first, bounds.second);
+            std::uniform_real_distribution<> distribution_var(0, bounds.second);
+            test::fill_tensors(distribution, { 0, 1, 3, 4 }, &src, &mean, &beta, &gamma);
+            test::fill_tensors(distribution_var, { 0 }, &var);
+            break;
         }
-        else
+#endif /* ARM_COMPUTE_ENABLE_FP16 */
+        case DataType::F32:
         {
-            std::tie(min_bound, max_bound) = get_batchnormalization_layer_test_bounds<int16_t>(fixed_point_position);
+            const std::pair<float, float> bounds = get_batchnormalization_layer_test_bounds<float>();
+            std::uniform_real_distribution<> distribution(bounds.first, bounds.second);
+            std::uniform_real_distribution<> distribution_var(0, bounds.second);
+            test::fill_tensors(distribution, { 0, 1, 3, 4 }, &src, &mean, &beta, &gamma);
+            test::fill_tensors(distribution_var, { 0 }, &var);
+            break;
+        }
+        default:
+        {
+            ARM_COMPUTE_ERROR("Not supported");
+            break;
         }
-        std::uniform_int_distribution<> distribution(min_bound, max_bound);
-        std::uniform_int_distribution<> distribution_var(0, max_bound);
-        library->fill(Accessor(src), distribution, 0);
-        library->fill(Accessor(mean), distribution, 1);
-        library->fill(Accessor(var), distribution_var, 0);
-        library->fill(Accessor(beta), distribution, 3);
-        library->fill(Accessor(gamma), distribution, 4);
     }
 
     // Compute function
@@ -177,9 +194,28 @@ BOOST_DATA_TEST_CASE(Random,
     RawTensor ref_dst = Reference::compute_reference_batch_normalization_layer(obj.shape0, obj.shape1, dt, obj.epsilon);
 
     // Validate output
-    validate(Accessor(dst), ref_dst, tolerance_f, 0);
+    validate(Accessor(dst), ref_dst, tolerance_f32, 0);
+}
+BOOST_AUTO_TEST_SUITE_END()
+
+#ifdef ARM_COMPUTE_ENABLE_FP16
+BOOST_AUTO_TEST_SUITE(Float16)
+BOOST_TEST_DECORATOR(*boost::unit_test::label("precommit"))
+BOOST_DATA_TEST_CASE(Random,
+                     RandomBatchNormalizationLayerDataset() * boost::unit_test::data::make(DataType::F16),
+                     obj, dt)
+{
+    // Compute function
+    Tensor dst = compute_reference_batch_normalization_layer(obj.shape0, obj.shape1, dt, obj.epsilon);
+
+    // Compute reference
+    RawTensor ref_dst = Reference::compute_reference_batch_normalization_layer(obj.shape0, obj.shape1, dt, obj.epsilon);
+
+    // Validate output
+    validate(Accessor(dst), ref_dst, tolerance_f16, 0);
 }
 BOOST_AUTO_TEST_SUITE_END()
+#endif /* ARM_COMPUTE_ENABLE_FP16 */
 
 BOOST_AUTO_TEST_SUITE(Quantized)
 BOOST_AUTO_TEST_SUITE(QS8)
diff --git a/tests/validation/Reference.cpp b/tests/validation/Reference.cpp
index b7553f3b7b..e9ddea78cb 100644
--- a/tests/validation/Reference.cpp
+++ b/tests/validation/Reference.cpp
@@ -513,39 +513,50 @@ RawTensor Reference::compute_reference_batch_normalization_layer(const TensorSha
     RawTensor ref_beta(shape1, dt, 1, fixed_point_position);
     RawTensor ref_gamma(shape1, dt, 1, fixed_point_position);
 
-    // Fill tensors with values from -1 to 1.
-    if(dt == DataType::F32)
-    {
-        float min_bound = 0.f;
-        float max_bound = 0.f;
-        std::tie(min_bound, max_bound) = get_batchnormalization_layer_test_bounds<float>();
-        std::uniform_real_distribution<> distribution(min_bound, max_bound);
-        std::uniform_real_distribution<> distribution_var(0, max_bound);
-        library->fill(ref_src, distribution, 0);
-        library->fill(ref_mean, distribution, 1);
-        library->fill(ref_var, distribution_var, 0);
-        library->fill(ref_beta, distribution, 3);
-        library->fill(ref_gamma, distribution, 4);
-    }
-    else
+    // Fill tensors
+    switch(dt)
     {
-        int min_bound = 0;
-        int max_bound = 0;
-        if(dt == DataType::QS8)
+        case DataType::QS8:
         {
-            std::tie(min_bound, max_bound) = get_batchnormalization_layer_test_bounds<int8_t>(fixed_point_position);
+            const std::pair<int8_t, int8_t> bounds = get_batchnormalization_layer_test_bounds<int8_t>(fixed_point_position);
+            std::uniform_int_distribution<> distribution(bounds.first, bounds.second);
+            std::uniform_int_distribution<> distribution_var(0, bounds.second);
+            fill_tensors(distribution, { 0, 1, 3, 4 }, &ref_src, &ref_mean, &ref_beta, &ref_gamma);
+            fill_tensors(distribution_var, { 0 }, &ref_var);
+            break;
         }
-        else
+        case DataType::QS16:
         {
-            std::tie(min_bound, max_bound) = get_batchnormalization_layer_test_bounds<int16_t>(fixed_point_position);
+            const std::pair<int16_t, int16_t> bounds = get_batchnormalization_layer_test_bounds<int16_t>(fixed_point_position);
+            std::uniform_int_distribution<> distribution(bounds.first, bounds.second);
+            std::uniform_int_distribution<> distribution_var(0, bounds.second);
+            fill_tensors(distribution, { 0, 1, 3, 4 }, &ref_src, &ref_mean, &ref_beta, &ref_gamma);
+            fill_tensors(distribution_var, { 0 }, &ref_var);
+            break;
+        }
+        case DataType::F16:
+        {
+            const std::pair<float, float> bounds = get_batchnormalization_layer_test_bounds<float>();
+            std::uniform_real_distribution<> distribution(bounds.first, bounds.second);
+            std::uniform_real_distribution<> distribution_var(0, bounds.second);
+            fill_tensors(distribution, { 0, 1, 3, 4 }, &ref_src, &ref_mean, &ref_beta, &ref_gamma);
+            fill_tensors(distribution_var, { 0 }, &ref_var);
+            break;
+        }
+        case DataType::F32:
+        {
+            const std::pair<float, float> bounds = get_batchnormalization_layer_test_bounds<float>();
+            std::uniform_real_distribution<> distribution(bounds.first, bounds.second);
+            std::uniform_real_distribution<> distribution_var(0, bounds.second);
+            fill_tensors(distribution, { 0, 1, 3, 4 }, &ref_src, &ref_mean, &ref_beta, &ref_gamma);
+            fill_tensors(distribution_var, { 0 }, &ref_var);
+            break;
+        }
+        default:
+        {
+            ARM_COMPUTE_ERROR("Not supported");
+            break;
+        }
-        std::uniform_int_distribution<> distribution(min_bound, max_bound);
-        std::uniform_int_distribution<> distribution_var(0, max_bound);
-        library->fill(ref_src, distribution, 0);
-        library->fill(ref_mean, distribution, 1);
-        library->fill(ref_var, distribution_var, 0);
-        library->fill(ref_beta, distribution, 3);
-        library->fill(ref_gamma, distribution, 4);
     }
 
     // Compute reference
-- 
cgit v1.2.1
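
Note on the arithmetic behind the new kernel: the FP16 path added to NEBatchNormalizationLayerKernel.cpp applies the usual batch normalization formula, out = gamma * (x - mean) / sqrt(var + epsilon) + beta, to eight half-precision values per iteration, and vinvsqrtq_f16 approximates 1 / sqrt(x) with vrsqrteq_f16 followed by two Newton-Raphson refinement steps (the two vrsqrtsq_f16 multiplies in NEMath.inl). Below is a minimal scalar sketch of the same computation; the function name and the flat std::vector layout are illustrative assumptions, not part of the patch or of the library API.

// Minimal scalar sketch (illustrative only, not library code): per feature map z,
// the kernel computes denominator = 1 / sqrt(var[z] + epsilon) once, then for every
// element x in that map writes gamma[z] * (x - mean[z]) * denominator + beta[z].
#include <cmath>
#include <cstddef>
#include <vector>

void batch_normalization_scalar(const std::vector<float> &in, std::vector<float> &out,
                                const std::vector<float> &mean, const std::vector<float> &var,
                                const std::vector<float> &beta, const std::vector<float> &gamma,
                                float epsilon, std::size_t elements_per_map)
{
    for(std::size_t i = 0; i < in.size(); ++i)
    {
        const std::size_t z           = i / elements_per_map;               // feature map index (id.z() in the kernel)
        const float       denominator = 1.f / std::sqrt(var[z] + epsilon);  // what vinvsqrtq_f16(var + eps) approximates
        const float       x_bar       = (in[i] - mean[z]) * denominator;    // normalized value
        out[i]                        = gamma[z] * x_bar + beta[z];         // scale by gamma, shift by beta
    }
}

Calling this with per-channel mean/var/beta/gamma vectors of length num_feature_maps and elements_per_map = width * height reproduces what the NEON kernel writes, up to half-precision rounding, which is why the tests compare the F16 path against the reference with the looser tolerance_f16 = 0.01f instead of the F32 tolerance of 1e-05f.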