From 25ef7217ec4e13682bf37c87c0c6075a799ba1c0 Mon Sep 17 00:00:00 2001 From: Georgios Pinitas Date: Tue, 2 Jun 2020 23:00:41 +0100 Subject: COMPMID-3180: Remove padding from NEThreshold - Removes padding from NEThresholdKernel - Alters configuration interface to use a descriptor Change-Id: I394d5e1375454813856d9d206e61dc9a87c2cadc Signed-off-by: Georgios Pinitas Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/3300 Reviewed-by: Michele Di Giorgio Comments-Addressed: Arm Jenkins Tested-by: Arm Jenkins --- src/core/NEON/kernels/NEThresholdKernel.cpp | 171 +++++++++++++++++++++------- 1 file changed, 129 insertions(+), 42 deletions(-) (limited to 'src/core/NEON/kernels/NEThresholdKernel.cpp') diff --git a/src/core/NEON/kernels/NEThresholdKernel.cpp b/src/core/NEON/kernels/NEThresholdKernel.cpp index 5c3b2a7540..b8adc15e77 100644 --- a/src/core/NEON/kernels/NEThresholdKernel.cpp +++ b/src/core/NEON/kernels/NEThresholdKernel.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2016-2019 ARM Limited. + * Copyright (c) 2016-2020 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -28,30 +28,60 @@ #include "arm_compute/core/ITensor.h" #include "arm_compute/core/Validate.h" -#include +#include "arm_compute/core/NEON/wrapper/wrapper.h" namespace arm_compute { -class Coordinates; +namespace +{ +Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, const ThresholdKernelInfo &info) +{ + ARM_COMPUTE_UNUSED(info); + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::U8); + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8); + + // Checks performed when output is configured + if((output != nullptr) && (output->total_size() != 0)) + { + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input, output); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output); + } + + return Status{}; +} + +std::pair validate_and_configure_window(ITensorInfo *input, ITensorInfo *output) +{ + // Configure kernel window + Window win = calculate_max_window(*input, Steps()); + + // Output auto inizialitation if not yet initialized + auto_init_if_empty(*output, *input->clone()); + + // NEThresholdKernel doesn't need padding so update_window_and_padding() can be skipped + Coordinates coord; + coord.set_num_dimensions(output->num_dimensions()); + output->set_valid_region(ValidRegion(coord, output->tensor_shape())); + + return std::make_pair(Status{}, win); +} +} // namespace NEThresholdKernel::NEThresholdKernel() - : _func(nullptr), _input(nullptr), _output(nullptr), _threshold(0), _false_value(0), _true_value(0), _upper(0) + : _func(nullptr), _input(nullptr), _output(nullptr), _info() { } -void NEThresholdKernel::configure(const ITensor *input, ITensor *output, uint8_t threshold, uint8_t false_value, uint8_t true_value, ThresholdType type, uint8_t upper) +void NEThresholdKernel::configure(const ITensor *input, ITensor *output, const ThresholdKernelInfo &info) { - ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::U8); - ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8); + ARM_COMPUTE_ERROR_ON_NULLPTR(input, output); + ARM_COMPUTE_ERROR_THROW_ON(validate(input->info(), output->info(), info)); - _input = input; - _output = output; - _threshold = threshold; - _false_value = false_value; - _true_value = true_value; - _upper = upper; + _input = input; + _output = output; + _info = info; - switch(type) + switch(_info.type) { case ThresholdType::BINARY: _func = &NEThresholdKernel::run_binary; @@ -64,54 +94,111 @@ void NEThresholdKernel::configure(const ITensor *input, ITensor *output, uint8_t break; } - constexpr unsigned int num_elems_processed_per_iteration = 16; + // Configure kernel window + auto win_config = validate_and_configure_window(input->info(), output->info()); + ARM_COMPUTE_ERROR_THROW_ON(win_config.first); + ICPPKernel::configure(win_config.second); +} - Window win = calculate_max_window(*input->info(), Steps(num_elems_processed_per_iteration)); - AccessWindowHorizontal output_access(output->info(), 0, num_elems_processed_per_iteration); - update_window_and_padding(win, AccessWindowHorizontal(input->info(), 0, num_elems_processed_per_iteration), output_access); - output_access.set_valid_region(win, input->info()->valid_region()); +Status NEThresholdKernel::validate(const ITensorInfo *input, const ITensorInfo *output, const ThresholdKernelInfo &info) +{ + ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, output, info)); + ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(input->clone().get(), output->clone().get()).first); - INEKernel::configure(win); + return Status{}; } inline void NEThresholdKernel::run_binary(const Window &window) { - const uint8x16_t threshold = vdupq_n_u8(_threshold); - const uint8x16_t true_value = vdupq_n_u8(_true_value); - const uint8x16_t false_value = vdupq_n_u8(_false_value); + /** NEON vector tag type. */ + using Type = uint8_t; + using ExactTagType = typename wrapper::traits::neon_bitvector_tag_t; - Iterator input(_input, window); - Iterator output(_output, window); + const int window_step_x = 16 / sizeof(Type); + const auto window_start_x = static_cast(window.x().start()); + const auto window_end_x = static_cast(window.x().end()); - execute_window_loop(window, [&](const Coordinates &) - { - const uint8x16_t data = vld1q_u8(input.ptr()); - const uint8x16_t mask = vcgtq_u8(data, threshold); + Window win_collapsed = window.collapse_if_possible(window, Window::DimZ); + win_collapsed.set(Window::DimX, Window::Dimension(0, 1, 1)); + + const uint8_t threshold = _info.threshold; + const uint8_t true_value = _info.true_value; + const uint8_t false_value = _info.false_value; - vst1q_u8(output.ptr(), vbslq_u8(mask, true_value, false_value)); + const auto vthreshold = wrapper::vdup_n(threshold, ExactTagType{}); + const auto vtrue_value = wrapper::vdup_n(true_value, ExactTagType{}); + const auto vfalse_value = wrapper::vdup_n(false_value, ExactTagType{}); + + Iterator input(_input, win_collapsed); + Iterator output(_output, win_collapsed); + + execute_window_loop(win_collapsed, [&](const Coordinates &) + { + const auto input_ptr = reinterpret_cast(input.ptr()); + const auto output_ptr = reinterpret_cast(output.ptr()); + + int x = window_start_x; + for(; x <= (window_end_x - window_step_x); x += window_step_x) + { + const auto vdata = wrapper::vloadq(input_ptr + x); + const auto vmask = wrapper::vcgt(vdata, vthreshold); + wrapper::vstore(output_ptr + x, wrapper::vbsl(vmask, vtrue_value, vfalse_value)); + } + + for(; x < window_end_x; ++x) + { + const Type data = *(reinterpret_cast(input_ptr + x)); + *(output_ptr + x) = (data > threshold) ? true_value : false_value; + } }, input, output); } inline void NEThresholdKernel::run_range(const Window &window) { - const uint8x16_t lower_threshold = vdupq_n_u8(_threshold); - const uint8x16_t upper_threshold = vdupq_n_u8(_upper); - const uint8x16_t true_value = vdupq_n_u8(_true_value); - const uint8x16_t false_value = vdupq_n_u8(_false_value); + /** NEON vector tag type. */ + using Type = uint8_t; + using ExactTagType = typename wrapper::traits::neon_bitvector_tag_t; - Iterator input(_input, window); - Iterator output(_output, window); + const int window_step_x = 16 / sizeof(Type); + const auto window_start_x = static_cast(window.x().start()); + const auto window_end_x = static_cast(window.x().end()); - execute_window_loop(window, [&](const Coordinates &) - { - const uint8x16_t data = vld1q_u8(input.ptr()); + Window win_collapsed = window.collapse_if_possible(window, Window::DimZ); + win_collapsed.set(Window::DimX, Window::Dimension(0, 1, 1)); + + const uint8_t lower_threshold = _info.threshold; + const uint8_t upper_threshold = _info.upper; + const uint8_t true_value = _info.true_value; + const uint8_t false_value = _info.false_value; - uint8x16_t mask = vcleq_u8(data, upper_threshold); + const auto vlower_threshold = wrapper::vdup_n(lower_threshold, ExactTagType{}); + const auto vupper_threshold = wrapper::vdup_n(upper_threshold, ExactTagType{}); + const auto vtrue_value = wrapper::vdup_n(true_value, ExactTagType{}); + const auto vfalse_value = wrapper::vdup_n(false_value, ExactTagType{}); - mask = vandq_u8(vcgeq_u8(data, lower_threshold), mask); + Iterator input(_input, win_collapsed); + Iterator output(_output, win_collapsed); - vst1q_u8(output.ptr(), vbslq_u8(mask, true_value, false_value)); + execute_window_loop(win_collapsed, [&](const Coordinates &) + { + const auto input_ptr = reinterpret_cast(input.ptr()); + const auto output_ptr = reinterpret_cast(output.ptr()); + + int x = window_start_x; + for(; x <= (window_end_x - window_step_x); x += window_step_x) + { + const auto vdata = wrapper::vloadq(input_ptr + x); + auto vmask = wrapper::vcle(vdata, vupper_threshold); + vmask = wrapper::vand(wrapper::vcge(vdata, vlower_threshold), vmask); + wrapper::vstore(output_ptr + x, wrapper::vbsl(vmask, vtrue_value, vfalse_value)); + } + + for(; x < window_end_x; ++x) + { + const Type data = *(reinterpret_cast(input_ptr + x)); + *(output_ptr + x) = (data <= upper_threshold && data >= lower_threshold) ? true_value : false_value; + } }, input, output); } -- cgit v1.2.1