From 4b3fba1850fdf84ba3f9a0c98acf3de672330b34 Mon Sep 17 00:00:00 2001 From: Georgios Pinitas Date: Tue, 4 Jun 2019 17:31:46 +0100 Subject: COMPMID-2372: Add support for QASYMM8 for Tanh -Perform calculations in the floating point domain -Extends checks for Logistic as scale should be 1/256 and offset 0 Change-Id: I90ef4a042f053976936f5d28f8e09b54eec196a2 Signed-off-by: Georgios Pinitas Reviewed-on: https://review.mlplatform.org/c/1287 Tested-by: Arm Jenkins Reviewed-by: Michalis Spyrou Comments-Addressed: Arm Jenkins --- arm_compute/core/utils/helpers/float_ops.h | 98 +++++++++++++++++ src/core/CL/CLKernelLibrary.cpp | 2 +- src/core/CL/cl_kernels/activation_layer_qa8.cl | 119 +++++++++++---------- src/core/CL/kernels/CLActivationLayerKernel.cpp | 61 +++++++---- src/core/NEON/kernels/NEActivationLayerKernel.cpp | 57 ++++++++-- tests/datasets/ActivationFunctionsDataset.h | 9 +- tests/validation/CL/ActivationLayer.cpp | 2 +- tests/validation/NEON/ActivationLayer.cpp | 3 +- tests/validation/fixtures/ActivationLayerFixture.h | 53 ++++++--- tests/validation/reference/ActivationLayer.cpp | 16 +-- tests/validation/reference/ActivationLayer.h | 4 +- 11 files changed, 307 insertions(+), 117 deletions(-) create mode 100644 arm_compute/core/utils/helpers/float_ops.h diff --git a/arm_compute/core/utils/helpers/float_ops.h b/arm_compute/core/utils/helpers/float_ops.h new file mode 100644 index 0000000000..5012c8ef84 --- /dev/null +++ b/arm_compute/core/utils/helpers/float_ops.h @@ -0,0 +1,98 @@ +/* + * Copyright (c) 2019 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ */ +#ifndef __ARM_COMPUTE_UTILS_HELPERS_FLOAT_OPS_H__ +#define __ARM_COMPUTE_UTILS_HELPERS_FLOAT_OPS_H__ + +namespace arm_compute +{ +namespace helpers +{ +namespace float_ops +{ +union RawFloat +{ + /** Constructor + * + * @param[in] val Floating-point value + */ + explicit RawFloat(float val) + : f32(val) + { + } + /** Extract sign of floating point number + * + * @return Sign of floating point number + */ + int32_t sign() const + { + return i32 >> 31; + } + /** Extract exponent of floating point number + * + * @return Exponent of floating point number + */ + int32_t exponent() const + { + return (i32 >> 23) & 0xFF; + } + /** Extract mantissa of floating point number + * + * @return Mantissa of floating point number + */ + int32_t mantissa() const + { + return i32 & 0x007FFFFF; + } + + int32_t i32; + float f32; +}; + +/** Checks if two floating point numbers are equal given an allowed number of ULPs + * + * @param[in] a First number to compare + * @param[in] b Second number to compare + * @param[in] max_allowed_ulps Number of allowed ULPs + * + * @return True if number is close else false + */ +bool is_equal_ulps(float a, float b, int max_allowed_ulps = 0) +{ + RawFloat ra(a); + RawFloat rb(b); + + // Early check for sign + if(ra.sign() != rb.sign()) + { + return (a == b); + } + + // Check ULP distance + const int ulps = std::abs(ra.i32 - rb.i32); + return ulps <= max_allowed_ulps; +} +} // namespace float_ops +} // namespace helpers +} // namespace arm_compute +#endif /* __ARM_COMPUTE_UTILS_HELPERS_FLOAT_OPS_H__ */ diff --git a/src/core/CL/CLKernelLibrary.cpp b/src/core/CL/CLKernelLibrary.cpp index e426db28c9..9806796203 100644 --- a/src/core/CL/CLKernelLibrary.cpp +++ b/src/core/CL/CLKernelLibrary.cpp @@ -149,7 +149,7 @@ const std::map CLKernelLibrary::_kernel_program_map = { "accumulate_weighted", "accumulate.cl" }, { "activation_layer", "activation_layer.cl" }, { "activation_layer_qa8", "activation_layer_qa8.cl" }, - { "activation_layer_logistic_qa8", "activation_layer_qa8.cl" }, + { "activation_layer_qa8_f32", "activation_layer_qa8.cl" }, { "batch_to_space_nchw", "batch_to_space.cl" }, { "batch_to_space_static_nchw", "batch_to_space.cl" }, { "batch_to_space_nhwc", "batch_to_space.cl" }, diff --git a/src/core/CL/cl_kernels/activation_layer_qa8.cl b/src/core/CL/cl_kernels/activation_layer_qa8.cl index cfb61376ca..41f23ca79b 100644 --- a/src/core/CL/cl_kernels/activation_layer_qa8.cl +++ b/src/core/CL/cl_kernels/activation_layer_qa8.cl @@ -26,52 +26,25 @@ #define TYPE VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE) #define VEC_FLOAT VEC_DATA_TYPE(float, VEC_SIZE) -// RELU Activation -inline TYPE relu_op(TYPE x) -{ - return max((TYPE)CONST_0, x); -} -// Bounded RELU Activation -inline TYPE brelu_op(TYPE x) -{ - return min((TYPE)A_VAL, max(CONST_0, x)); -} -// Lower Upper Bounded RELU Activation -inline TYPE lu_brelu_op(TYPE x) -{ - return min(max(x, (TYPE)B_VAL), (TYPE)A_VAL); -} +#if defined(FLOAT_DOMAIN) +// Activations performed in the float domain -#define ACTIVATION_OP2(op, x) op##_op(x) -#define ACTIVATION_OP(op, x) ACTIVATION_OP2(op, x) - -#if defined(O1_VAL) && defined(O2_VAL) && defined(S1_VAL) && defined(S2_VAL) -#define PERFORM_ACTIVATION_QA8(act, data) \ - ({ \ - data = ACTIVATION_OP(act, data); \ - \ - VEC_DATA_TYPE(float, VEC_SIZE) \ - fdata = CONVERT(data, VEC_DATA_TYPE(float, VEC_SIZE)); \ - \ - fdata = round((fdata - (float)O1_VAL) * ((float)S1_VAL / (float)S2_VAL) + (float)O2_VAL); \ - data = CONVERT_SAT(fdata, VEC_DATA_TYPE(uchar, VEC_SIZE)); \ - }) -#else /* defined(O1_VAL) && 
defined(O2_VAL) && defined(S1_VAL) && defined(S2_VAL) */ -#define PERFORM_ACTIVATION_QA8(act, data) \ - ({ \ - data = ACTIVATION_OP(act, data); \ - }) -#endif /* defined(O1_VAL) && defined(O2_VAL) && defined(S1_VAL) && defined(S2_VAL) */ +#include "activation_float_helpers.h" -#if defined(ACT) +#if defined(O2_VAL) && defined(S2_VAL) +#define OFFSET_OUT O2_VAL +#define SCALE_OUT S2_VAL +#else // defined(O2_VAL) && defined(S2_VAL) +#define OFFSET_OUT O1_VAL +#define SCALE_OUT S1_VAL +#endif // defined(O2_VAL) && defined(S2_VAL) -/** This performs an activation function on QASYMM8 inputs. +/** This performs an activation function on QASYMM8 inputs with float transformations. * * @note In order to perform the activation function "in-place", the pre-processor -DIN_PLACE must be passed at compile time * * @note Datatype should be given as a preprocessor argument using -DDATA_TYPE=type. e.g. -DDATA_TYPE=short * @note Vector size should be given as a preprocessor argument using -DVEC_SIZE=size. e.g. -DVEC_SIZE=16 - * @note Activation function should be given as a preprocessor argument using -DACT=name. e.g. -DACT=TANH * @note A, B variables required by some activation functions are set using -DA_VAL= and -DB_VAL= respectively. * @note Quantization scales of the input/output tensors are passed in with -DS1_VAL= and -DS2_VAL= respectively. * @note Quantization offsets of the input/output tensors are passed in with -DO1_VAL= and -DO2_VAL= respectively. @@ -94,7 +67,7 @@ inline TYPE lu_brelu_op(TYPE x) * @param[in] output_step_z (Optional) output_stride_z * number of elements along Z processed per workitem(in bytes) * @param[in] output_offset_first_element_in_bytes (Optional) The offset of the first element in the destination image */ -__kernel void activation_layer_qa8( +__kernel void activation_layer_qa8_f32( TENSOR3D_DECLARATION(input) #ifndef IN_PLACE , @@ -113,29 +86,65 @@ __kernel void activation_layer_qa8( // Load data TYPE data = VLOAD(VEC_SIZE)(0, (__global DATA_TYPE *)input.ptr); - data = PERFORM_ACTIVATION_QA8(ACT, data); + VEC_FLOAT data_flt = CONVERT(data, VEC_FLOAT); + data_flt = round(data_flt - (float)O1_VAL) * ((float)S1_VAL); + data_flt = ACTIVATION(ACT, float, data_flt, A_VAL, B_VAL); + + data = CONVERT_SAT(round(data_flt / ((float)SCALE_OUT)) + (float)OFFSET_OUT, TYPE); // Store result VSTORE(VEC_SIZE) (data, 0, (__global DATA_TYPE *)output.ptr); } -#endif /* defined(ACT) */ +#else // defined(FLOAT_DOMAIN) +// Activations performed in the quantized domain -#if defined(O2_VAL) && defined(S2_VAL) -#define OFFSET_OUT O2_VAL -#define SCALE_OUT S2_VAL -#else // defined(O2_VAL) && defined(S2_VAL) -#define OFFSET_OUT O1_VAL -#define SCALE_OUT S1_VAL -#endif // defined(O2_VAL) && defined(S2_VAL) +// RELU Activation +inline TYPE relu_op(TYPE x) +{ + return max((TYPE)CONST_0, x); +} +// Bounded RELU Activation +inline TYPE brelu_op(TYPE x) +{ + return min((TYPE)A_VAL, max(CONST_0, x)); +} +// Lower Upper Bounded RELU Activation +inline TYPE lu_brelu_op(TYPE x) +{ + return min(max(x, (TYPE)B_VAL), (TYPE)A_VAL); +} + +#define ACTIVATION_OP2(op, x) op##_op(x) +#define ACTIVATION_OP(op, x) ACTIVATION_OP2(op, x) + +#if defined(O1_VAL) && defined(O2_VAL) && defined(S1_VAL) && defined(S2_VAL) +#define PERFORM_ACTIVATION_QA8(act, data) \ + ({ \ + data = ACTIVATION_OP(act, data); \ + \ + VEC_DATA_TYPE(float, VEC_SIZE) \ + fdata = CONVERT(data, VEC_DATA_TYPE(float, VEC_SIZE)); \ + \ + fdata = round((fdata - (float)O1_VAL) * ((float)S1_VAL / (float)S2_VAL) + (float)O2_VAL); \ + data = 
CONVERT_SAT(fdata, VEC_DATA_TYPE(uchar, VEC_SIZE)); \ + }) +#else /* defined(O1_VAL) && defined(O2_VAL) && defined(S1_VAL) && defined(S2_VAL) */ +#define PERFORM_ACTIVATION_QA8(act, data) \ + ({ \ + data = ACTIVATION_OP(act, data); \ + }) +#endif /* defined(O1_VAL) && defined(O2_VAL) && defined(S1_VAL) && defined(S2_VAL) */ -/** This performs a Logistic activation function on QASYMM8 inputs. +#if defined(ACT) +/** This performs an activation function on QASYMM8 inputs. * * @note In order to perform the activation function "in-place", the pre-processor -DIN_PLACE must be passed at compile time * * @note Datatype should be given as a preprocessor argument using -DDATA_TYPE=type. e.g. -DDATA_TYPE=short * @note Vector size should be given as a preprocessor argument using -DVEC_SIZE=size. e.g. -DVEC_SIZE=16 + * @note Activation function should be given as a preprocessor argument using -DACT=name. e.g. -DACT=TANH * @note A, B variables required by some activation functions are set using -DA_VAL= and -DB_VAL= respectively. * @note Quantization scales of the input/output tensors are passed in with -DS1_VAL= and -DS2_VAL= respectively. * @note Quantization offsets of the input/output tensors are passed in with -DO1_VAL= and -DO2_VAL= respectively. @@ -158,7 +167,7 @@ __kernel void activation_layer_qa8( * @param[in] output_step_z (Optional) output_stride_z * number of elements along Z processed per workitem(in bytes) * @param[in] output_offset_first_element_in_bytes (Optional) The offset of the first element in the destination image */ -__kernel void activation_layer_logistic_qa8( +__kernel void activation_layer_qa8( TENSOR3D_DECLARATION(input) #ifndef IN_PLACE , @@ -167,7 +176,7 @@ __kernel void activation_layer_logistic_qa8( ) { // Get pixels pointer - Tensor3D input = CONVERT_TO_TENSOR3D_STRUCT(input); + Tensor3D input = CONVERT_TO_TENSOR3D_STRUCT(input); #ifdef IN_PLACE Tensor3D output = input; #else /* IN_PLACE */ @@ -177,13 +186,11 @@ __kernel void activation_layer_logistic_qa8( // Load data TYPE data = VLOAD(VEC_SIZE)(0, (__global DATA_TYPE *)input.ptr); - VEC_FLOAT data_flt = CONVERT(data, VEC_FLOAT); - data_flt = round(data_flt - (float)O1_VAL) * ((float)S1_VAL); - data_flt = 1.f / (1.f + exp(-data_flt)); - - data = CONVERT_SAT(round(data_flt / ((float)SCALE_OUT)) + (float)OFFSET_OUT, TYPE); + data = PERFORM_ACTIVATION_QA8(ACT, data); // Store result VSTORE(VEC_SIZE) (data, 0, (__global DATA_TYPE *)output.ptr); } +#endif // defined(ACT) +#endif // defined(FLOAT_DOMAIN) diff --git a/src/core/CL/kernels/CLActivationLayerKernel.cpp b/src/core/CL/kernels/CLActivationLayerKernel.cpp index 65e6561b0a..34d1298d61 100644 --- a/src/core/CL/kernels/CLActivationLayerKernel.cpp +++ b/src/core/CL/kernels/CLActivationLayerKernel.cpp @@ -30,14 +30,14 @@ #include "arm_compute/core/Helpers.h" #include "arm_compute/core/IAccessWindow.h" #include "arm_compute/core/TensorInfo.h" +#include "arm_compute/core/Types.h" #include "arm_compute/core/Utils.h" #include "arm_compute/core/Window.h" - -#include "arm_compute/core/CL/CLHelpers.h" -#include "arm_compute/core/Types.h" +#include "arm_compute/core/utils/helpers/float_ops.h" #include "support/ToolchainSupport.h" #include +#include using namespace arm_compute; @@ -47,11 +47,23 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, c { ARM_COMPUTE_RETURN_ERROR_ON_F16_UNSUPPORTED(input); ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::U8, DataType::QASYMM8, DataType::F16, DataType::F32); - 
ARM_COMPUTE_RETURN_ERROR_ON_MSG((input->data_type() == DataType::QASYMM8) && (act_info.activation() != ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU) - && (act_info.activation() != ActivationLayerInfo::ActivationFunction::BOUNDED_RELU) - && (act_info.activation() != ActivationLayerInfo::ActivationFunction::RELU) - && (act_info.activation() != ActivationLayerInfo::ActivationFunction::LOGISTIC), - "For QASYMM8 only logistic, relu, lower bounded relu and lower-upper bounded relu are supported"); + + static std::set qs8_supported_activations = + { + ActivationLayerInfo::ActivationFunction::RELU, + ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, + ActivationLayerInfo::ActivationFunction::BOUNDED_RELU, + ActivationLayerInfo::ActivationFunction::LOGISTIC, + ActivationLayerInfo::ActivationFunction::TANH + }; + const DataType data_type = input->data_type(); + const QuantizationInfo &oq_info = (output != nullptr) ? output->quantization_info() : input->quantization_info(); + const ActivationLayerInfo::ActivationFunction f_act = act_info.activation(); + + ARM_COMPUTE_RETURN_ERROR_ON_MSG(is_data_type_quantized_asymmetric(data_type) && (qs8_supported_activations.count(f_act) == 0), + "For QASYMM8 only tanh, logistic, relu and lower/upper bounded relu are supported"); + ARM_COMPUTE_RETURN_ERROR_ON(is_data_type_quantized_asymmetric(data_type) && (f_act == ActivationLayerInfo::ActivationFunction::TANH) && (oq_info != QuantizationInfo(1.f / 128.f, 128))); + ARM_COMPUTE_RETURN_ERROR_ON(is_data_type_quantized_asymmetric(data_type) && (f_act == ActivationLayerInfo::ActivationFunction::LOGISTIC) && (oq_info != QuantizationInfo(1.f / 256.f, 0))); // Checks performed when output is configured if((output != nullptr) && (output->total_size() != 0)) @@ -122,7 +134,10 @@ void CLActivationLayerKernel::configure(ICLTensor *input, ICLTensor *output, Act int a_const_int = 0; int b_const_int = 0; - const bool is_quantized_asymmetric = is_data_type_quantized_asymmetric(dt); + const ActivationLayerInfo::ActivationFunction f_act = act_info.activation(); + const bool is_quantized_asymmetric = is_data_type_quantized_asymmetric(dt); + const bool perform_activation_in_float = (f_act == ActivationLayerInfo::ActivationFunction::LOGISTIC) || (f_act == ActivationLayerInfo::ActivationFunction::TANH); + // Create quantized version of constants a, b if needed if(is_quantized_asymmetric) { @@ -131,18 +146,29 @@ void CLActivationLayerKernel::configure(ICLTensor *input, ICLTensor *output, Act b_const_int = quantize_qasymm8(b_const, iq_info); } - const bool is_logistic_activation_quantized = is_quantized_asymmetric && act_info.activation() == ActivationLayerInfo::ActivationFunction::LOGISTIC; // Set build options CLBuildOptions build_opts; - build_opts.add_option_if(!is_logistic_activation_quantized, "-DACT=" + lower_string(string_from_activation_func(act_info.activation()))); + build_opts.add_option_if(perform_activation_in_float, "-DFLOAT_DOMAIN"); + build_opts.add_option_if(_run_in_place, "-DIN_PLACE"); + build_opts.add_option(("-DACT=" + lower_string(string_from_activation_func(f_act)))); build_opts.add_option(("-DDATA_TYPE=" + get_cl_type_from_data_type(dt))); build_opts.add_option(("-DVEC_SIZE=" + support::cpp11::to_string(num_elems_processed_per_iteration))); - if(is_quantized_asymmetric) + // Set A, B constants in build options + if(is_quantized_asymmetric && !perform_activation_in_float) { build_opts.add_option(("-DA_VAL=" + support::cpp11::to_string(a_const_int))); build_opts.add_option(("-DB_VAL=" + 
support::cpp11::to_string(b_const_int))); + } + else + { + build_opts.add_option(("-DA_VAL=" + float_to_string_with_full_precision(a_const))); + build_opts.add_option(("-DB_VAL=" + float_to_string_with_full_precision(b_const))); + } + // Set quantization info build options + if(is_quantized_asymmetric) + { const UniformQuantizationInfo iq_info = input->info()->quantization_info().uniform(); // Quantized value of 0 corresponds to the offset o1 @@ -151,7 +177,7 @@ void CLActivationLayerKernel::configure(ICLTensor *input, ICLTensor *output, Act build_opts.add_option(("-DO1_VAL=" + support::cpp11::to_string(iq_info.offset))); // Set scale and offset of the input and output if they have different quantization info - if(is_quantized_asymmetric && output != nullptr) + if(output != nullptr) { const UniformQuantizationInfo oq_info = output->info()->quantization_info().uniform(); @@ -162,19 +188,12 @@ void CLActivationLayerKernel::configure(ICLTensor *input, ICLTensor *output, Act } } } - else - { - build_opts.add_option(("-DA_VAL=" + float_to_string_with_full_precision(a_const))); - build_opts.add_option(("-DB_VAL=" + float_to_string_with_full_precision(b_const))); - } - - build_opts.add_option_if(_run_in_place, "-DIN_PLACE"); // Create kernel std::string kernel_name = std::string("activation_layer"); if(is_quantized_asymmetric) { - kernel_name += is_logistic_activation_quantized ? std::string("_logistic_qa8") : std::string("_qa8"); + kernel_name += perform_activation_in_float ? std::string("_qa8_f32") : std::string("_qa8"); } _kernel = static_cast(CLKernelLibrary::get().create_kernel(kernel_name, build_opts.options())); diff --git a/src/core/NEON/kernels/NEActivationLayerKernel.cpp b/src/core/NEON/kernels/NEActivationLayerKernel.cpp index 3f71553926..64342512a0 100644 --- a/src/core/NEON/kernels/NEActivationLayerKernel.cpp +++ b/src/core/NEON/kernels/NEActivationLayerKernel.cpp @@ -39,15 +39,33 @@ #include #include #include +#include using namespace arm_compute; namespace { -Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output) +Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, const ActivationLayerInfo &activation_info) { ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(input); ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::U8, DataType::QASYMM8, DataType::F16, DataType::F32); + static std::set qs8_supported_activations = + { + ActivationLayerInfo::ActivationFunction::RELU, + ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, + ActivationLayerInfo::ActivationFunction::BOUNDED_RELU, + ActivationLayerInfo::ActivationFunction::LOGISTIC, + ActivationLayerInfo::ActivationFunction::TANH + }; + const DataType data_type = input->data_type(); + const QuantizationInfo &oq_info = (output != nullptr) ? 
output->quantization_info() : input->quantization_info(); + const ActivationLayerInfo::ActivationFunction f_act = activation_info.activation(); + + ARM_COMPUTE_RETURN_ERROR_ON_MSG(is_data_type_quantized_asymmetric(data_type) && (qs8_supported_activations.count(f_act) == 0), + "For QASYMM8 only tanh, logistic, relu and lower/upper bounded relu are supported"); + ARM_COMPUTE_RETURN_ERROR_ON(is_data_type_quantized_asymmetric(data_type) && (f_act == ActivationLayerInfo::ActivationFunction::TANH) && (oq_info != QuantizationInfo(1.f / 128.f, 128))); + ARM_COMPUTE_RETURN_ERROR_ON(is_data_type_quantized_asymmetric(data_type) && (f_act == ActivationLayerInfo::ActivationFunction::LOGISTIC) && (oq_info != QuantizationInfo(1.f / 256.f, 0))); + // Checks performed when output is configured if((output != nullptr) && (output->total_size() != 0)) { @@ -96,12 +114,7 @@ void NEActivationLayerKernel::configure(ITensor *input, ITensor *output, Activat _output = output; } - ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), (output != nullptr) ? output->info() : nullptr)); - - ARM_COMPUTE_ERROR_ON_MSG((input->info()->data_type() == DataType::QASYMM8) && (activation_info.activation() != ActivationLayerInfo::ActivationFunction::RELU) - && (activation_info.activation() != ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU) && (activation_info.activation() != ActivationLayerInfo::ActivationFunction::BOUNDED_RELU) - && (activation_info.activation() != ActivationLayerInfo::ActivationFunction::LOGISTIC), - "For QASYMM8 only logistic, relu and lower/upper bounded relu are supported"); + ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), (output != nullptr) ? output->info() : nullptr, activation_info)); // Activation functions : FP32 static std::map act_map_f32 = @@ -146,6 +159,7 @@ void NEActivationLayerKernel::configure(ITensor *input, ITensor *output, Activat { ActivationFunction::BOUNDED_RELU, &NEActivationLayerKernel::activation }, { ActivationFunction::LU_BOUNDED_RELU, &NEActivationLayerKernel::activation }, { ActivationFunction::RELU, &NEActivationLayerKernel::activation }, + { ActivationFunction::TANH, &NEActivationLayerKernel::activation }, { ActivationFunction::IDENTITY, &NEActivationLayerKernel::activation }, }; @@ -328,6 +342,10 @@ typename std::enable_if::value, void>::type NEActivat const qasymm8_t const_0 = quantize_qasymm8(0.f, qi_in); const qasymm8x16_t vconst_0 = vdupq_n_u8(const_0); const auto vconst_1 = vdupq_n_f32(1.f); + const float32x4_t va_f32 = vdupq_n_f32(_act_info.a()); + const float32x4_t vb_f32 = vdupq_n_f32(_act_info.b()); + const float a_f32 = _act_info.a(); + const float b_f32 = _act_info.b(); // Initialise scale/offset for re-quantization float s = qi_in.scale / qi_out.scale; @@ -385,6 +403,23 @@ typename std::enable_if::value, void>::type NEActivat // Re-quantize to new output space tmp = vquantize(tmp_dep, qi_out); } + else if(act == ActivationFunction::TANH) + { + // De-quantize + const auto vin_deq = vdequantize(vin, qi_in); + // Perform activation + const float32x4x4_t tmp_dep = + { + { + wrapper::vmul(va_f32, wrapper::vtanh(wrapper::vmul(vin_deq.val[0], vb_f32))), + wrapper::vmul(va_f32, wrapper::vtanh(wrapper::vmul(vin_deq.val[1], vb_f32))), + wrapper::vmul(va_f32, wrapper::vtanh(wrapper::vmul(vin_deq.val[2], vb_f32))), + wrapper::vmul(va_f32, wrapper::vtanh(wrapper::vmul(vin_deq.val[3], vb_f32))), + } + }; + // Re-quantize to new output space + tmp = vquantize(tmp_dep, qi_out); + } else { ARM_COMPUTE_ERROR("Unsupported activation function"); @@ -418,6 
+453,12 @@ typename std::enable_if::value, void>::type NEActivat tmp_f = 1.f / (1.f + std::exp(-tmp_f)); tmp = quantize_qasymm8(tmp_f, qi_out); } + else if(act == ActivationFunction::TANH) + { + float tmp_f = dequantize_qasymm8(in, qi_in); + tmp_f = a_f32 * std::tanh(b_f32 * tmp_f); + tmp = quantize_qasymm8(tmp_f, qi_out); + } else { ARM_COMPUTE_ERROR("Unsupported activation function"); @@ -431,7 +472,7 @@ typename std::enable_if::value, void>::type NEActivat Status NEActivationLayerKernel::validate(const ITensorInfo *input, const ITensorInfo *output, const ActivationLayerInfo &act_info) { ARM_COMPUTE_UNUSED(act_info); - ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, output)); + ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, output, act_info)); ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(input->clone().get(), (output != nullptr) ? output->clone().get() : nullptr).first); return Status{}; diff --git a/tests/datasets/ActivationFunctionsDataset.h b/tests/datasets/ActivationFunctionsDataset.h index c5dc28f0c3..29fb21cec0 100644 --- a/tests/datasets/ActivationFunctionsDataset.h +++ b/tests/datasets/ActivationFunctionsDataset.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2018 ARM Limited. + * Copyright (c) 2017-2019 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -64,10 +64,11 @@ public: ActivationFunctionsQuantized() : ContainerDataset("ActivationFunctionQuantized", { - ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, - ActivationLayerInfo::ActivationFunction::RELU, + ActivationLayerInfo::ActivationFunction::RELU, + ActivationLayerInfo::ActivationFunction::BOUNDED_RELU, + ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, ActivationLayerInfo::ActivationFunction::LOGISTIC, - ActivationLayerInfo::ActivationFunction::BOUNDED_RELU + ActivationLayerInfo::ActivationFunction::TANH, }) { } diff --git a/tests/validation/CL/ActivationLayer.cpp b/tests/validation/CL/ActivationLayer.cpp index e95db7ca60..a286458483 100644 --- a/tests/validation/CL/ActivationLayer.cpp +++ b/tests/validation/CL/ActivationLayer.cpp @@ -137,7 +137,7 @@ DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip( TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::F32), // Window shrink TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32), TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::QASYMM8), - TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::QASYMM8), // Unsupported activation + TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::QASYMM8), // Invalid quantization info TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::F32), // Mismatching shapes }), framework::dataset::make("OutputInfo",{ TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::F16), diff --git a/tests/validation/NEON/ActivationLayer.cpp b/tests/validation/NEON/ActivationLayer.cpp index 3a91c9c3be..1d8fffa903 100644 --- a/tests/validation/NEON/ActivationLayer.cpp +++ b/tests/validation/NEON/ActivationLayer.cpp @@ -223,7 +223,8 @@ using NEActivationLayerQuantizedFixture = ActivationValidationQuantizedFixture void setup(TensorShape shape, bool in_place, ActivationLayerInfo::ActivationFunction function, float alpha_beta, DataType data_type, QuantizationInfo quantization_info) { - _quantization_info = quantization_info; - _data_type = data_type; - _function = function; - ActivationLayerInfo info(function, alpha_beta, alpha_beta); - _target = compute_target(shape, in_place, info, data_type, quantization_info); - _reference = compute_reference(shape, info, data_type, quantization_info); + _in_place = in_place; + 
_output_quantization_info = calculate_output_quantization_info(info, quantization_info); + _input_quantization_info = in_place ? _output_quantization_info : quantization_info; + _data_type = data_type; + _function = function; + + _target = compute_target(shape, info); + _reference = compute_reference(shape, info); } protected: @@ -85,16 +87,16 @@ protected: } } - TensorType compute_target(const TensorShape &shape, bool in_place, ActivationLayerInfo info, DataType data_type, QuantizationInfo quantization_info) + TensorType compute_target(const TensorShape &shape, ActivationLayerInfo info) { // Create tensors - TensorType src = create_tensor(shape, data_type, 1, quantization_info); - TensorType dst = create_tensor(shape, data_type, 1, quantization_info); + TensorType src = create_tensor(shape, _data_type, 1, _input_quantization_info); + TensorType dst = create_tensor(shape, _data_type, 1, _output_quantization_info); // Create and configure function FunctionType act_layer; - TensorType *dst_ptr = in_place ? &src : &dst; + TensorType *dst_ptr = _in_place ? &src : &dst; act_layer.configure(&src, dst_ptr, info); @@ -105,7 +107,7 @@ protected: src.allocator()->allocate(); ARM_COMPUTE_EXPECT(!src.info()->is_resizable(), framework::LogLevel::ERRORS); - if(!in_place) + if(!_in_place) { dst.allocator()->allocate(); ARM_COMPUTE_EXPECT(!dst.info()->is_resizable(), framework::LogLevel::ERRORS); @@ -117,7 +119,7 @@ protected: // Compute function act_layer.run(); - if(in_place) + if(_in_place) { return src; } @@ -127,20 +129,37 @@ protected: } } - SimpleTensor compute_reference(const TensorShape &shape, ActivationLayerInfo info, DataType data_type, QuantizationInfo quantization_info) + SimpleTensor compute_reference(const TensorShape &shape, ActivationLayerInfo info) { // Create reference - SimpleTensor src{ shape, data_type, 1, quantization_info }; + SimpleTensor src{ shape, _data_type, 1, _input_quantization_info }; // Fill reference fill(src); - return reference::activation_layer(src, info); + return reference::activation_layer(src, info, _output_quantization_info); + } + +private: + QuantizationInfo calculate_output_quantization_info(const ActivationLayerInfo &act_info, const QuantizationInfo &default_qinfo) + { + switch(act_info.activation()) + { + case ActivationLayerInfo::ActivationFunction::TANH: + return QuantizationInfo(1.f / 128.f, 128); + case ActivationLayerInfo::ActivationFunction::LOGISTIC: + return QuantizationInfo(1.f / 256.f, 0); + default: + return default_qinfo; + } } +protected: TensorType _target{}; SimpleTensor _reference{}; - QuantizationInfo _quantization_info{}; + bool _in_place{}; + QuantizationInfo _input_quantization_info{}; + QuantizationInfo _output_quantization_info{}; DataType _data_type{}; ActivationLayerInfo::ActivationFunction _function{}; }; diff --git a/tests/validation/reference/ActivationLayer.cpp b/tests/validation/reference/ActivationLayer.cpp index 9887e42386..f5e98aa7e8 100644 --- a/tests/validation/reference/ActivationLayer.cpp +++ b/tests/validation/reference/ActivationLayer.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2018 ARM Limited. + * Copyright (c) 2017-2019 ARM Limited. 
 *
 * SPDX-License-Identifier: MIT
 *
@@ -35,8 +35,10 @@ namespace validation
 namespace reference
 {
 template <typename T>
-SimpleTensor<T> activation_layer(const SimpleTensor<T> &src, ActivationLayerInfo info)
+SimpleTensor<T> activation_layer(const SimpleTensor<T> &src, ActivationLayerInfo info, const QuantizationInfo &oq_info)
 {
+    ARM_COMPUTE_UNUSED(oq_info);
+
     // Create reference
     SimpleTensor<T> dst{ src.shape(), src.data_type(), 1 };
 
@@ -53,16 +55,18 @@ SimpleTensor<T> activation_layer(const SimpleTensor<T> &src, ActivationLayerInfo
 }
 
 template <>
-SimpleTensor<uint8_t> activation_layer<uint8_t>(const SimpleTensor<uint8_t> &src, ActivationLayerInfo info)
+SimpleTensor<uint8_t> activation_layer<uint8_t>(const SimpleTensor<uint8_t> &src, ActivationLayerInfo info, const QuantizationInfo &oq_info)
 {
+    const QuantizationInfo dst_qinfo = oq_info.empty() ? src.quantization_info() : oq_info;
+
     SimpleTensor<float>   src_tmp = convert_from_asymmetric(src);
     SimpleTensor<float>   dst_tmp = activation_layer<float>(src_tmp, info);
-    SimpleTensor<uint8_t> dst     = convert_to_asymmetric(dst_tmp, src.quantization_info());
+    SimpleTensor<uint8_t> dst     = convert_to_asymmetric(dst_tmp, dst_qinfo);
     return dst;
 }
 
-template SimpleTensor<float> activation_layer(const SimpleTensor<float> &src, ActivationLayerInfo info);
-template SimpleTensor<half> activation_layer(const SimpleTensor<half> &src, ActivationLayerInfo info);
+template SimpleTensor<float> activation_layer(const SimpleTensor<float> &src, ActivationLayerInfo info, const QuantizationInfo &oq_info);
+template SimpleTensor<half> activation_layer(const SimpleTensor<half> &src, ActivationLayerInfo info, const QuantizationInfo &oq_info);
 } // namespace reference
 } // namespace validation
 } // namespace test
diff --git a/tests/validation/reference/ActivationLayer.h b/tests/validation/reference/ActivationLayer.h
index cc43bc10a8..5beca7c76d 100644
--- a/tests/validation/reference/ActivationLayer.h
+++ b/tests/validation/reference/ActivationLayer.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2017-2018 ARM Limited.
+ * Copyright (c) 2017-2019 ARM Limited.
 *
 * SPDX-License-Identifier: MIT
 *
@@ -87,7 +87,7 @@ inline T activate_float(T x, T a, T b, ActivationLayerInfo::ActivationFunction a
 }
 
 template <typename T>
-SimpleTensor<T> activation_layer(const SimpleTensor<T> &src, ActivationLayerInfo info);
+SimpleTensor<T> activation_layer(const SimpleTensor<T> &src, ActivationLayerInfo info, const QuantizationInfo &oq_info = QuantizationInfo());
 } // namespace reference
 } // namespace validation
 } // namespace test
-- 
cgit v1.2.1
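
For reference, the float-domain path added by this patch (the activation_layer_qa8_f32 CL kernel and the new NEON TANH branches) boils down to: dequantize with the input (scale, offset), evaluate the activation in fp32, then requantize with the fixed output quantization that validate() now enforces. Below is a minimal, self-contained C++ sketch of that arithmetic; the dequantize/quantize helpers here are illustrative stand-ins, not arm_compute APIs, and the input quantization values are arbitrary examples.

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>

// Illustrative stand-in: real_value = (quantized - offset) * scale
static float dequantize(uint8_t q, float scale, int offset)
{
    return (static_cast<int>(q) - offset) * scale;
}

// Illustrative stand-in: quantized = clamp(round(real_value / scale) + offset, 0, 255)
static uint8_t quantize(float f, float scale, int offset)
{
    const int q = static_cast<int>(std::lround(f / scale)) + offset;
    return static_cast<uint8_t>(std::min(255, std::max(0, q)));
}

int main()
{
    // Input tensor quantization (arbitrary example values)
    const float in_scale  = 0.1f;
    const int   in_offset = 128;

    // Fixed output quantization enforced by the new validate() checks:
    // TANH     -> scale 1/128, offset 128 (-1 -> 0, 0 -> 128, +1 saturates to 255)
    // LOGISTIC -> scale 1/256, offset 0   ( 0 -> 0, 0.5 -> 128, 1 saturates to 255)
    const float out_scale  = 1.f / 128.f;
    const int   out_offset = 128;

    // ActivationLayerInfo(TANH, a, b) computes a * tanh(b * x)
    const float a = 1.f, b = 1.f;

    const uint8_t inputs[] = { 0, 128, 200, 255 };
    for(uint8_t in : inputs)
    {
        const float   x   = dequantize(in, in_scale, in_offset); // to the float domain
        const float   y   = a * std::tanh(b * x);                // activation in fp32
        const uint8_t out = quantize(y, out_scale, out_offset);  // back to QASYMM8
        std::printf("in=%3u  x=% .3f  tanh=% .4f  out=%3u\n",
                    static_cast<unsigned>(in), x, y, static_cast<unsigned>(out));
    }
    return 0;
}

The clamp in quantize() mirrors the kernel's CONVERT_SAT and vquantize: with a = b = 1, a tanh output near +1 would land on 256 and must saturate to 255.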
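
The new arm_compute/core/utils/helpers/float_ops.h compares floats by ULP distance: it reinterprets the IEEE-754 bit pattern as an int32, takes the absolute difference, and falls back to exact equality when the signs differ. A hedged usage sketch follows, assuming the header is on the include path; <cstdint> and <cstdlib> are included first because the header as added relies on int32_t and std::abs already being visible.

#include <cstdint> // int32_t, used by float_ops.h
#include <cstdlib> // std::abs, used by float_ops.h
#include <cmath>   // std::nextafter
#include <cstdio>

#include "arm_compute/core/utils/helpers/float_ops.h"

int main()
{
    using arm_compute::helpers::float_ops::is_equal_ulps;

    const float a = 1.0f;
    const float b = std::nextafter(a, 2.0f); // smallest float strictly greater than 1.0f

    std::printf("%d\n", is_equal_ulps(a, b));    // 0: default allows 0 ULPs, i.e. exact bit equality
    std::printf("%d\n", is_equal_ulps(a, b, 1)); // 1: a and b are exactly 1 ULP apart
    std::printf("%d\n", is_equal_ulps(a, -a));   // 0: differing signs fall back to a == b
    return 0;
}

The bit reinterpretation only yields a meaningful ULP distance for finite values of the same sign, which is why the sign check runs before the integer subtraction.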