diff options
Diffstat (limited to 'src/cpu/kernels/CpuActivationKernel.cpp')
-rw-r--r-- | src/cpu/kernels/CpuActivationKernel.cpp | 18 |
1 files changed, 10 insertions, 8 deletions
diff --git a/src/cpu/kernels/CpuActivationKernel.cpp b/src/cpu/kernels/CpuActivationKernel.cpp index ee9db99080..61efcb2dd6 100644 --- a/src/cpu/kernels/CpuActivationKernel.cpp +++ b/src/cpu/kernels/CpuActivationKernel.cpp @@ -46,7 +46,8 @@ namespace static const std::vector<CpuActivationKernel::ActivationKernel> available_kernels = { #ifdef __aarch64__ - { // Neon LUT implementantion takes precedence + { + // Neon LUT implementantion takes precedence "neon_q8_activation_lut", [](const ActivationDataTypeISASelectorData & data) { return ActivationLayerInfo::is_lut_supported(data.f, data.dt); }, REGISTER_Q8_NEON(arm_compute::cpu::neon_q8_activation_lut) @@ -54,27 +55,27 @@ static const std::vector<CpuActivationKernel::ActivationKernel> available_kernel #endif // __aarch64__ { "sve2_qu8_activation", - [](const ActivationDataTypeISASelectorData & data) { return data.dt == DataType::QASYMM8 && data.isa.sve2; }, + [](const ActivationDataTypeISASelectorData & data) { return data.dt == DataType::QASYMM8 && data.isa.sve2 && data.f != ActivationLayerInfo::ActivationFunction::GELU; }, REGISTER_QASYMM8_SVE2(arm_compute::cpu::sve2_qasymm8_activation) }, { "sve2_qs8_activation", - [](const ActivationDataTypeISASelectorData & data) { return data.dt == DataType::QASYMM8_SIGNED && data.isa.sve2; }, + [](const ActivationDataTypeISASelectorData & data) { return data.dt == DataType::QASYMM8_SIGNED && data.isa.sve2 && data.f != ActivationLayerInfo::ActivationFunction::GELU; }, REGISTER_QASYMM8_SIGNED_SVE2(arm_compute::cpu::sve2_qasymm8_signed_activation) }, { "sve2_qs16_activation", - [](const ActivationDataTypeISASelectorData & data) { return data.dt == DataType::QSYMM16 && data.isa.sve2; }, + [](const ActivationDataTypeISASelectorData & data) { return data.dt == DataType::QSYMM16 && data.isa.sve2 && data.f != ActivationLayerInfo::ActivationFunction::GELU; }, REGISTER_QSYMM16_SVE2(arm_compute::cpu::sve2_qsymm16_activation) }, { "sve_fp16_activation", - [](const ActivationDataTypeISASelectorData & data) { return data.dt == DataType::F16 && data.isa.sve && data.isa.fp16; }, + [](const ActivationDataTypeISASelectorData & data) { return data.dt == DataType::F16 && data.isa.sve && data.isa.fp16 && data.f != ActivationLayerInfo::ActivationFunction::GELU; }, REGISTER_FP16_SVE(arm_compute::cpu::sve_fp16_activation) }, { "sve_fp32_activation", - [](const ActivationDataTypeISASelectorData & data) { return data.dt == DataType::F32 && data.isa.sve; }, + [](const ActivationDataTypeISASelectorData & data) { return data.dt == DataType::F32 && data.isa.sve && data.f != ActivationLayerInfo::ActivationFunction::GELU; }, REGISTER_FP32_SVE(arm_compute::cpu::sve_fp32_activation) }, { @@ -105,7 +106,7 @@ static const std::vector<CpuActivationKernel::ActivationKernel> available_kernel }; /* Supported activation in the 8-bit integer domain */ -static const std::array<ActivationLayerInfo::ActivationFunction, 7> qasymm8_activations = +static const std::array<ActivationLayerInfo::ActivationFunction, 8> qasymm8_activations = { ActivationLayerInfo::ActivationFunction::RELU, ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, @@ -114,6 +115,7 @@ static const std::array<ActivationLayerInfo::ActivationFunction, 7> qasymm8_acti ActivationLayerInfo::ActivationFunction::TANH, ActivationLayerInfo::ActivationFunction::HARD_SWISH, ActivationLayerInfo::ActivationFunction::LEAKY_RELU, + ActivationLayerInfo::ActivationFunction::GELU, }; /* Supported activation in the 16-bit integer domain */ static const std::array<ActivationLayerInfo::ActivationFunction, 4> qsymm16_activations = @@ -193,7 +195,7 @@ void CpuActivationKernel::configure(const ITensorInfo *src, ITensorInfo *dst, Ac #ifdef __aarch64__ if(ActivationLayerInfo::is_lut_supported(activation_info.activation(), src->data_type())) { - activation_info.init_lut(src->data_type(), src->quantization_info().uniform(), (dst)?dst->quantization_info().uniform():src->quantization_info().uniform()); + activation_info.init_lut(src->data_type(), src->quantization_info().uniform(), (dst) ? dst->quantization_info().uniform() : src->quantization_info().uniform()); } #endif // __aarch64__ _act_info = activation_info; |