diff options
Diffstat (limited to 'src/core/NEON/kernels/NEMeanStdDevNormalizationKernel.cpp')
-rw-r--r-- | src/core/NEON/kernels/NEMeanStdDevNormalizationKernel.cpp | 54 |
1 files changed, 24 insertions, 30 deletions
diff --git a/src/core/NEON/kernels/NEMeanStdDevNormalizationKernel.cpp b/src/core/NEON/kernels/NEMeanStdDevNormalizationKernel.cpp index 37e88a8565..451031d696 100644 --- a/src/core/NEON/kernels/NEMeanStdDevNormalizationKernel.cpp +++ b/src/core/NEON/kernels/NEMeanStdDevNormalizationKernel.cpp @@ -28,12 +28,13 @@ #include "arm_compute/core/TensorInfo.h" #include "arm_compute/core/Types.h" #include "arm_compute/core/Window.h" -#include "src/core/CPP/Validate.h" -#include "src/core/NEON/NEMath.h" -#include "src/core/NEON/wrapper/wrapper.h" + #include "src/core/common/Registrars.h" +#include "src/core/CPP/Validate.h" #include "src/core/helpers/AutoConfiguration.h" #include "src/core/helpers/WindowHelpers.h" +#include "src/core/NEON/NEMath.h" +#include "src/core/NEON/wrapper/wrapper.h" #include "src/cpu/kernels/meanstddevnorm/list.h" namespace arm_compute @@ -46,7 +47,8 @@ struct MeanStdDevNormSelectorData }; using MeanStdDevNormSelctorPtr = std::add_pointer<bool(const MeanStdDevNormSelectorData &data)>::type; -using MeanStdDevNormUKernelPtr = std::add_pointer<void(ITensor *input, ITensor *output, float epsilon, const Window &window)>::type; +using MeanStdDevNormUKernelPtr = + std::add_pointer<void(ITensor *input, ITensor *output, float epsilon, const Window &window)>::type; struct MeanStdDevNormKernel { @@ -55,25 +57,15 @@ struct MeanStdDevNormKernel MeanStdDevNormUKernelPtr ukernel; }; -static const std::vector<MeanStdDevNormKernel> available_kernels = -{ - { - "fp32_neon_meanstddevnorm", - [](const MeanStdDevNormSelectorData & data) { return data.dt == DataType::F32; }, - REGISTER_FP32_NEON(arm_compute::cpu::neon_fp32_meanstddevnorm) - }, +static const std::vector<MeanStdDevNormKernel> available_kernels = { + {"fp32_neon_meanstddevnorm", [](const MeanStdDevNormSelectorData &data) { return data.dt == DataType::F32; }, + REGISTER_FP32_NEON(arm_compute::cpu::neon_fp32_meanstddevnorm)}, #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC - { - "fp16_neon_meanstddevnorm", - [](const MeanStdDevNormSelectorData & data) { return data.dt == DataType::F16; }, - REGISTER_FP16_NEON(arm_compute::cpu::neon_fp16_meanstddevnorm) - }, + {"fp16_neon_meanstddevnorm", [](const MeanStdDevNormSelectorData &data) { return data.dt == DataType::F16; }, + REGISTER_FP16_NEON(arm_compute::cpu::neon_fp16_meanstddevnorm)}, #endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC - { - "qasymm8_neon_meanstddevnorm", - [](const MeanStdDevNormSelectorData & data) { return data.dt == DataType::QASYMM8; }, - REGISTER_QASYMM8_NEON(arm_compute::cpu::neon_qasymm8_meanstddevnorm) - }, + {"qasymm8_neon_meanstddevnorm", [](const MeanStdDevNormSelectorData &data) { return data.dt == DataType::QASYMM8; }, + REGISTER_QASYMM8_NEON(arm_compute::cpu::neon_qasymm8_meanstddevnorm)}, }; /** Micro-kernel selector @@ -84,9 +76,9 @@ static const std::vector<MeanStdDevNormKernel> available_kernels = */ const MeanStdDevNormKernel *get_implementation(const MeanStdDevNormSelectorData &data) { - for(const auto &uk : available_kernels) + for (const auto &uk : available_kernels) { - if(uk.is_selected(data)) + if (uk.is_selected(data)) { return &uk; } @@ -103,7 +95,7 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, f ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F16, DataType::F32, DataType::QASYMM8); // Checks performed when output is configured - if((output != nullptr) && (output->total_size() != 0)) + if ((output != nullptr) && (output->total_size() != 0)) { ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input, output); ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output); @@ -113,7 +105,7 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, f std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input, ITensorInfo *output) { - if(output != nullptr) + if (output != nullptr) { ARM_COMPUTE_ERROR_ON_NULLPTR(input, output); // Output auto inizialitation if not yet initialized @@ -128,8 +120,7 @@ std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input, ITen } } // namespace -NEMeanStdDevNormalizationKernel::NEMeanStdDevNormalizationKernel() - : _input(nullptr), _output(nullptr), _epsilon(1e-8f) +NEMeanStdDevNormalizationKernel::NEMeanStdDevNormalizationKernel() : _input(nullptr), _output(nullptr), _epsilon(1e-8f) { } @@ -137,7 +128,8 @@ void NEMeanStdDevNormalizationKernel::configure(ITensor *input, ITensor *output, { ARM_COMPUTE_ERROR_ON_NULLPTR(input); - ARM_COMPUTE_ERROR_THROW_ON(NEMeanStdDevNormalizationKernel::validate(input->info(), (output != nullptr) ? output->info() : nullptr, epsilon)); + ARM_COMPUTE_ERROR_THROW_ON(NEMeanStdDevNormalizationKernel::validate( + input->info(), (output != nullptr) ? output->info() : nullptr, epsilon)); _input = input; _output = (output == nullptr) ? input : output; @@ -152,7 +144,9 @@ void NEMeanStdDevNormalizationKernel::configure(ITensor *input, ITensor *output, Status NEMeanStdDevNormalizationKernel::validate(const ITensorInfo *input, const ITensorInfo *output, float epsilon) { ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, output, epsilon)); - ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(input->clone().get(), (output != nullptr) ? output->clone().get() : nullptr).first); + ARM_COMPUTE_RETURN_ON_ERROR( + validate_and_configure_window(input->clone().get(), (output != nullptr) ? output->clone().get() : nullptr) + .first); return Status{}; } @@ -162,7 +156,7 @@ void NEMeanStdDevNormalizationKernel::run(const Window &window, const ThreadInfo ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this); ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(IKernel::window(), window); - const auto *uk = get_implementation(MeanStdDevNormSelectorData{ _output->info()->data_type() }); + const auto *uk = get_implementation(MeanStdDevNormSelectorData{_output->info()->data_type()}); ARM_COMPUTE_ERROR_ON(uk == nullptr || uk->ukernel == nullptr); uk->ukernel(_input, _output, _epsilon, window); |