aboutsummaryrefslogtreecommitdiff
path: root/src/cpu/kernels/CpuActivationKernel.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/cpu/kernels/CpuActivationKernel.cpp')
-rw-r--r--src/cpu/kernels/CpuActivationKernel.cpp18
1 files changed, 10 insertions, 8 deletions
diff --git a/src/cpu/kernels/CpuActivationKernel.cpp b/src/cpu/kernels/CpuActivationKernel.cpp
index ee9db99080..61efcb2dd6 100644
--- a/src/cpu/kernels/CpuActivationKernel.cpp
+++ b/src/cpu/kernels/CpuActivationKernel.cpp
@@ -46,7 +46,8 @@ namespace
static const std::vector<CpuActivationKernel::ActivationKernel> available_kernels =
{
#ifdef __aarch64__
- { // Neon LUT implementantion takes precedence
+ {
+ // Neon LUT implementantion takes precedence
"neon_q8_activation_lut",
[](const ActivationDataTypeISASelectorData & data) { return ActivationLayerInfo::is_lut_supported(data.f, data.dt); },
REGISTER_Q8_NEON(arm_compute::cpu::neon_q8_activation_lut)
@@ -54,27 +55,27 @@ static const std::vector<CpuActivationKernel::ActivationKernel> available_kernel
#endif // __aarch64__
{
"sve2_qu8_activation",
- [](const ActivationDataTypeISASelectorData & data) { return data.dt == DataType::QASYMM8 && data.isa.sve2; },
+ [](const ActivationDataTypeISASelectorData & data) { return data.dt == DataType::QASYMM8 && data.isa.sve2 && data.f != ActivationLayerInfo::ActivationFunction::GELU; },
REGISTER_QASYMM8_SVE2(arm_compute::cpu::sve2_qasymm8_activation)
},
{
"sve2_qs8_activation",
- [](const ActivationDataTypeISASelectorData & data) { return data.dt == DataType::QASYMM8_SIGNED && data.isa.sve2; },
+ [](const ActivationDataTypeISASelectorData & data) { return data.dt == DataType::QASYMM8_SIGNED && data.isa.sve2 && data.f != ActivationLayerInfo::ActivationFunction::GELU; },
REGISTER_QASYMM8_SIGNED_SVE2(arm_compute::cpu::sve2_qasymm8_signed_activation)
},
{
"sve2_qs16_activation",
- [](const ActivationDataTypeISASelectorData & data) { return data.dt == DataType::QSYMM16 && data.isa.sve2; },
+ [](const ActivationDataTypeISASelectorData & data) { return data.dt == DataType::QSYMM16 && data.isa.sve2 && data.f != ActivationLayerInfo::ActivationFunction::GELU; },
REGISTER_QSYMM16_SVE2(arm_compute::cpu::sve2_qsymm16_activation)
},
{
"sve_fp16_activation",
- [](const ActivationDataTypeISASelectorData & data) { return data.dt == DataType::F16 && data.isa.sve && data.isa.fp16; },
+ [](const ActivationDataTypeISASelectorData & data) { return data.dt == DataType::F16 && data.isa.sve && data.isa.fp16 && data.f != ActivationLayerInfo::ActivationFunction::GELU; },
REGISTER_FP16_SVE(arm_compute::cpu::sve_fp16_activation)
},
{
"sve_fp32_activation",
- [](const ActivationDataTypeISASelectorData & data) { return data.dt == DataType::F32 && data.isa.sve; },
+ [](const ActivationDataTypeISASelectorData & data) { return data.dt == DataType::F32 && data.isa.sve && data.f != ActivationLayerInfo::ActivationFunction::GELU; },
REGISTER_FP32_SVE(arm_compute::cpu::sve_fp32_activation)
},
{
@@ -105,7 +106,7 @@ static const std::vector<CpuActivationKernel::ActivationKernel> available_kernel
};
/* Supported activation in the 8-bit integer domain */
-static const std::array<ActivationLayerInfo::ActivationFunction, 7> qasymm8_activations =
+static const std::array<ActivationLayerInfo::ActivationFunction, 8> qasymm8_activations =
{
ActivationLayerInfo::ActivationFunction::RELU,
ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU,
@@ -114,6 +115,7 @@ static const std::array<ActivationLayerInfo::ActivationFunction, 7> qasymm8_acti
ActivationLayerInfo::ActivationFunction::TANH,
ActivationLayerInfo::ActivationFunction::HARD_SWISH,
ActivationLayerInfo::ActivationFunction::LEAKY_RELU,
+ ActivationLayerInfo::ActivationFunction::GELU,
};
/* Supported activation in the 16-bit integer domain */
static const std::array<ActivationLayerInfo::ActivationFunction, 4> qsymm16_activations =
@@ -193,7 +195,7 @@ void CpuActivationKernel::configure(const ITensorInfo *src, ITensorInfo *dst, Ac
#ifdef __aarch64__
if(ActivationLayerInfo::is_lut_supported(activation_info.activation(), src->data_type()))
{
- activation_info.init_lut(src->data_type(), src->quantization_info().uniform(), (dst)?dst->quantization_info().uniform():src->quantization_info().uniform());
+ activation_info.init_lut(src->data_type(), src->quantization_info().uniform(), (dst) ? dst->quantization_info().uniform() : src->quantization_info().uniform());
}
#endif // __aarch64__
_act_info = activation_info;