diff options
Diffstat (limited to 'src/core/NEON/kernels/NEScaleKernel.cpp')
-rw-r--r-- | src/core/NEON/kernels/NEScaleKernel.cpp | 38 |
1 files changed, 18 insertions, 20 deletions
diff --git a/src/core/NEON/kernels/NEScaleKernel.cpp b/src/core/NEON/kernels/NEScaleKernel.cpp index 0f329a1c2c..38a0706c12 100644 --- a/src/core/NEON/kernels/NEScaleKernel.cpp +++ b/src/core/NEON/kernels/NEScaleKernel.cpp @@ -28,6 +28,7 @@ #include "arm_compute/core/Helpers.h" #include "arm_compute/core/NEON/wrapper/wrapper.h" #include "arm_compute/core/Window.h" +#include "arm_compute/core/utils/misc/Rounding.h" #include "arm_compute/core/utils/misc/Utility.h" #include "src/core/utils/ScaleUtils.h" @@ -167,7 +168,7 @@ std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input, ITen template <typename T> inline void scale_nearest_nhwc_core(const ITensor *input, const ITensor *offsets, ITensor *output, - float hr, Window window, const Window &win_in, size_t stride_w, size_t stride_h, size_t stride_c, float sampling_offset) + float hr, Window window, const Window &win_in, size_t stride_w, size_t stride_h, size_t stride_c, float sampling_offset, bool align_corners) { const int window_step_x = 16 / sizeof(T); const auto window_start_x = static_cast<int32_t>(window.x().start()); @@ -183,7 +184,7 @@ inline void scale_nearest_nhwc_core(const ITensor *input, const ITensor *offsets execute_window_loop(window, [&](const Coordinates & id) { const int32_t offset = *reinterpret_cast<const int32_t *>(offsets->ptr_to_element(Coordinates(id.y(), id.z()))); - const int in_yi = std::floor((id.z() + sampling_offset) * hr); + const auto in_yi = static_cast<int>(align_corners ? arm_compute::utils::rounding::round_half_away_from_zero((id.z() + sampling_offset) * hr) : std::floor((id.z() + sampling_offset) * hr)); const int offset_row = in_yi * stride_h; int32_t x = window_start_x; for(; x < window_end_x - window_step_x; x += window_step_x) @@ -460,8 +461,8 @@ void NEScaleKernel::scale_nearest_nchw(const Window &window) const auto offsets_ptr = reinterpret_cast<const int32_t *>(offsets.ptr()); const uint8_t *const in_ptr = in.ptr(); - const int in_yi = std::floor((id.y() + _sampling_offset) * hr); - const int in_yi_clamped = std::min(static_cast<int>(_input->info()->dimension(1)), std::max(in_yi, -1)); + const auto in_yi = static_cast<int>(_align_corners ? arm_compute::utils::rounding::round_half_away_from_zero((id.y() + _sampling_offset) * hr) : std::floor((id.y() + _sampling_offset) * hr)); + const int in_yi_clamped = std::min(static_cast<int>(_input->info()->dimension(1)), std::max(in_yi, -1)); ARM_COMPUTE_ERROR_ON(in_yi_clamped < -1 || in_yi_clamped > static_cast<int>(_input->info()->dimension(1))); const int offset_row = in_yi_clamped * input_stride; @@ -497,8 +498,8 @@ void NEScaleKernel::scale_nearest_nchw(const Window &window) const auto offsets_ptr = reinterpret_cast<const int32_t *>(offsets.ptr()); const uint8_t *const in_ptr = in.ptr(); - const int in_yi = std::floor((id.y() + _sampling_offset) * hr); - const int in_yi_clamped = std::min(static_cast<int>(_input->info()->dimension(1)), std::max(in_yi, -1)); + const auto in_yi = static_cast<int>(_align_corners ? arm_compute::utils::rounding::round_half_away_from_zero((id.y() + _sampling_offset) * hr) : std::floor((id.y() + _sampling_offset) * hr)); + const int in_yi_clamped = std::min(static_cast<int>(_input->info()->dimension(1)), std::max(in_yi, -1)); ARM_COMPUTE_ERROR_ON(in_yi_clamped < -1 || in_yi_clamped > static_cast<int>(_input->info()->dimension(1))); const int offset_row = in_yi_clamped * input_stride; @@ -537,9 +538,8 @@ void NEScaleKernel::scale_nearest_nchw(const Window &window) execute_window_loop(window, [&](const Coordinates & id) { const auto offsets_ptr = reinterpret_cast<const int32_t *>(offsets.ptr()); - - const int in_yi = std::floor((id.y() + _sampling_offset) * hr); - const int offset_row = in_yi * input_stride; + const auto in_yi = static_cast<int>(_align_corners ? arm_compute::utils::rounding::round_half_away_from_zero((id.y() + _sampling_offset) * hr) : std::floor((id.y() + _sampling_offset) * hr)); + const int offset_row = in_yi * input_stride; tmp.val[0] = vsetq_lane_s16(*reinterpret_cast<const int16_t *>(in.ptr() + offsets_ptr[0] + offset_row), tmp.val[0], 0); tmp.val[0] = vsetq_lane_s16(*reinterpret_cast<const int16_t *>(in.ptr() + offsets_ptr[2] + offset_row), tmp.val[0], 1); @@ -578,9 +578,8 @@ void NEScaleKernel::scale_nearest_nchw(const Window &window) execute_window_loop(window, [&](const Coordinates & id) { const auto offsets_ptr = reinterpret_cast<const int32_t *>(offsets.ptr()); - - const int in_yi = std::floor((id.y() + _sampling_offset) * hr); - const int offset_row = in_yi * input_stride; + const auto in_yi = static_cast<int>(_align_corners ? arm_compute::utils::rounding::round_half_away_from_zero((id.y() + _sampling_offset) * hr) : std::floor((id.y() + _sampling_offset) * hr)); + const int offset_row = in_yi * input_stride; tmp.val[0] = vsetq_lane_f16(*reinterpret_cast<const __fp16 *>(in.ptr() + offsets_ptr[0] + offset_row), tmp.val[0], 0); tmp.val[0] = vsetq_lane_f16(*reinterpret_cast<const __fp16 *>(in.ptr() + offsets_ptr[2] + offset_row), tmp.val[0], 1); @@ -621,9 +620,8 @@ void NEScaleKernel::scale_nearest_nchw(const Window &window) execute_window_loop(window, [&](const Coordinates & id) { const auto offsets_ptr = reinterpret_cast<const int32_t *>(offsets.ptr()); - - const int in_yi = std::floor((id.y() + _sampling_offset) * hr); - const int offset_row = in_yi * input_stride; + const auto in_yi = static_cast<int>(_align_corners ? arm_compute::utils::rounding::round_half_away_from_zero((id.y() + _sampling_offset) * hr) : std::floor((id.y() + _sampling_offset) * hr)); + const int offset_row = in_yi * input_stride; tmp.val[0] = vsetq_lane_f32(*reinterpret_cast<const float *>(in.ptr() + offsets_ptr[0] + offset_row), tmp.val[0], 0); tmp.val[0] = vsetq_lane_f32(*reinterpret_cast<const float *>(in.ptr() + offsets_ptr[4] + offset_row), tmp.val[0], 1); @@ -1024,7 +1022,7 @@ void NEScaleKernel::scale_nhwc(const Window &window) { if(_policy == InterpolationPolicy::NEAREST_NEIGHBOR) { - scale_nearest_nhwc_core<int8_t>(_input, _offsets, _output, hr, window, win_in, input_stride_w, input_stride_h, input_stride_c, _sampling_offset); + scale_nearest_nhwc_core<int8_t>(_input, _offsets, _output, hr, window, win_in, input_stride_w, input_stride_h, input_stride_c, _sampling_offset, _align_corners); } else { @@ -1038,7 +1036,7 @@ void NEScaleKernel::scale_nhwc(const Window &window) { if(_policy == InterpolationPolicy::NEAREST_NEIGHBOR) { - scale_nearest_nhwc_core<uint8_t>(_input, _offsets, _output, hr, window, win_in, input_stride_w, input_stride_h, input_stride_c, _sampling_offset); + scale_nearest_nhwc_core<uint8_t>(_input, _offsets, _output, hr, window, win_in, input_stride_w, input_stride_h, input_stride_c, _sampling_offset, _align_corners); } else { @@ -1051,7 +1049,7 @@ void NEScaleKernel::scale_nhwc(const Window &window) { if(_policy == InterpolationPolicy::NEAREST_NEIGHBOR) { - scale_nearest_nhwc_core<int16_t>(_input, _offsets, _output, hr, window, win_in, input_stride_w, input_stride_h, input_stride_c, _sampling_offset); + scale_nearest_nhwc_core<int16_t>(_input, _offsets, _output, hr, window, win_in, input_stride_w, input_stride_h, input_stride_c, _sampling_offset, _align_corners); } else { @@ -1066,7 +1064,7 @@ void NEScaleKernel::scale_nhwc(const Window &window) if(_policy == InterpolationPolicy::NEAREST_NEIGHBOR) { scale_nearest_nhwc_core<float16_t>(_input, _offsets, _output, hr, - window, win_in, input_stride_w, input_stride_h, input_stride_c, _sampling_offset); + window, win_in, input_stride_w, input_stride_h, input_stride_c, _sampling_offset, _align_corners); } else { @@ -1080,7 +1078,7 @@ void NEScaleKernel::scale_nhwc(const Window &window) { if(_policy == InterpolationPolicy::NEAREST_NEIGHBOR) { - scale_nearest_nhwc_core<float>(_input, _offsets, _output, hr, window, win_in, input_stride_w, input_stride_h, input_stride_c, _sampling_offset); + scale_nearest_nhwc_core<float>(_input, _offsets, _output, hr, window, win_in, input_stride_w, input_stride_h, input_stride_c, _sampling_offset, _align_corners); } else { |