diff options
Diffstat (limited to 'src/cpu/kernels/CpuActivationKernel.cpp')
-rw-r--r-- | src/cpu/kernels/CpuActivationKernel.cpp | 18 |
1 files changed, 11 insertions, 7 deletions
diff --git a/src/cpu/kernels/CpuActivationKernel.cpp b/src/cpu/kernels/CpuActivationKernel.cpp index 7cfa39b286..4253027231 100644 --- a/src/cpu/kernels/CpuActivationKernel.cpp +++ b/src/cpu/kernels/CpuActivationKernel.cpp @@ -43,6 +43,13 @@ namespace kernels { namespace { + +bool is_fp16_lut_supported(ActivationLayerInfo::ActivationFunction func) +{ + return func == ActivationLayerInfo::ActivationFunction::LOGISTIC || + func == ActivationLayerInfo::ActivationFunction::TANH; +} + static const std::vector<CpuActivationKernel::ActivationKernel> available_kernels = { #ifdef ARM_COMPUTE_ENABLE_SVE {"sve2_q8_activation_lut", @@ -85,10 +92,7 @@ static const std::vector<CpuActivationKernel::ActivationKernel> available_kernel REGISTER_QSYMM16_SVE2(arm_compute::cpu::sve2_qsymm16_activation)}, {"sve_fp16_activation_lut", [](const ActivationDataTypeISASelectorData &data) - { - return data.dt == DataType::F16 && data.isa.fp16 && data.isa.sve && - data.f == ActivationLayerInfo::ActivationFunction::LOGISTIC; - }, + { return data.dt == DataType::F16 && data.isa.fp16 && data.isa.sve && is_fp16_lut_supported(data.f); }, REGISTER_FP16_SVE(arm_compute::cpu::sve_fp16_activation_lut)}, {"sve_fp16_activation", [](const ActivationDataTypeISASelectorData &data) @@ -299,10 +303,10 @@ void CpuActivationKernel::configure(const ITensorInfo *src, ITensorInfo *dst, Ac activation_info.setLookupTable256(tmp_lut); } - if (src->data_type() == DataType::F16 && - activation_info.activation() == ActivationLayerInfo::ActivationFunction::LOGISTIC) + if (std::string(uk->name) == "sve_fp16_activation_lut") { - const LUTInfo info = {activation_info.activation(), src->data_type(), src->quantization_info()}; + const LUTInfo info = {activation_info.activation(), activation_info.a(), activation_info.b(), src->data_type(), + src->quantization_info().uniform()}; activation_info.setLookupTable65536((lut_manager.get_lut_table(info))); } #endif // __aarch64__ |