diff options
Diffstat (limited to 'src/core/NEON/kernels/NEBatchNormalizationLayerKernel.cpp')
-rw-r--r-- | src/core/NEON/kernels/NEBatchNormalizationLayerKernel.cpp | 19 |
1 files changed, 10 insertions, 9 deletions
diff --git a/src/core/NEON/kernels/NEBatchNormalizationLayerKernel.cpp b/src/core/NEON/kernels/NEBatchNormalizationLayerKernel.cpp index 92000bb2f6..46551553c9 100644 --- a/src/core/NEON/kernels/NEBatchNormalizationLayerKernel.cpp +++ b/src/core/NEON/kernels/NEBatchNormalizationLayerKernel.cpp @@ -48,7 +48,8 @@ namespace { struct BatchNormalizationSelectorData { - DataType dt; + DataType dt; + const CPUInfo &ci; }; using BatchNormalizationSelectorPtr = std::add_pointer<bool(const BatchNormalizationSelectorData &data)>::type; using BatchNormalizationKernelPtr = std::add_pointer<void(ITensor *, ITensor *, const ITensor *, const ITensor *, const ITensor *, const ITensor *, @@ -63,19 +64,19 @@ struct BatchNormalizationKernel static const BatchNormalizationKernel available_kernels[] = { -#if defined(ENABLE_SVE) +#if defined(ARM_COMPUTE_ENABLE_SVE) { "fp16_sve_batch_normalization", - [](const BatchNormalizationSelectorData & data) { return data.dt == DataType::F16; }, + [](const BatchNormalizationSelectorData & data) { return data.dt == DataType::F16 && data.ci.has_sve(); }, REGISTER_FP16_SVE(arm_compute::cpu::fp16_sve_batch_normalization) }, { "f32_sve_batch_normalization", - [](const BatchNormalizationSelectorData & data) { return data.dt == DataType::F32; }, + [](const BatchNormalizationSelectorData & data) { return data.dt == DataType::F32 && data.ci.has_sve(); }, REGISTER_FP32_SVE(arm_compute::cpu::fp32_sve_batch_normalization) }, -#endif /* !defined(ENABLE_SVE) */ -#if defined(ENABLE_NEON) +#endif /* !defined(ARM_COMPUTE_ENABLE_SVE) */ +#if defined(ARM_COMPUTE_ENABLE_NEON) #if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) { "fp16_neon_batch_normalization", @@ -88,7 +89,7 @@ static const BatchNormalizationKernel available_kernels[] = [](const BatchNormalizationSelectorData & data) { return data.dt == DataType::F32; }, REGISTER_FP32_NEON(arm_compute::cpu::fp32_neon_batch_normalization) }, -#endif /* !defined(ENABLE_NEON) */ +#endif /* !defined(ARM_COMPUTE_ENABLE_NEON) */ }; const BatchNormalizationKernel *get_implementation(const BatchNormalizationSelectorData &data) @@ -109,7 +110,7 @@ validate_arguments(const ITensorInfo *input, const ITensorInfo *output, const IT { ARM_COMPUTE_UNUSED(epsilon); - const auto *uk = get_implementation(BatchNormalizationSelectorData{ input->data_type() }); + const auto *uk = get_implementation(BatchNormalizationSelectorData{ input->data_type(), CPUInfo::get() }); ARM_COMPUTE_RETURN_ERROR_ON(uk == nullptr || uk->ukernel == nullptr); if(act_info.enabled()) @@ -387,7 +388,7 @@ void NEBatchNormalizationLayerKernel::run(const Window &window, const ThreadInfo } else { - const auto *uk = get_implementation(BatchNormalizationSelectorData{ _input->info()->data_type() }); + const auto *uk = get_implementation(BatchNormalizationSelectorData{ _input->info()->data_type(), CPUInfo::get() }); uk->ukernel(_input, _output, _mean, _var, _beta, _gamma, _epsilon, _act_info, window); } } |