From fa7ad56c4fc9e63a2f9e9a16e97ac9c275a5e3d8 Mon Sep 17 00:00:00 2001 From: Georgios Pinitas Date: Tue, 15 May 2018 17:38:40 +0100 Subject: COMPMID-1163: NEON Scale NHWC failures Change-Id: Ice620385ce787b568b38fcbdddc94ef385396141 Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/131355 Reviewed-by: Anthony Barbier Tested-by: Jenkins --- src/core/NEON/kernels/NEScaleKernel.cpp | 94 ++++++++++++++++++++------------- 1 file changed, 56 insertions(+), 38 deletions(-) (limited to 'src/core/NEON/kernels/NEScaleKernel.cpp') diff --git a/src/core/NEON/kernels/NEScaleKernel.cpp b/src/core/NEON/kernels/NEScaleKernel.cpp index 3f57ffba5c..71116447f4 100644 --- a/src/core/NEON/kernels/NEScaleKernel.cpp +++ b/src/core/NEON/kernels/NEScaleKernel.cpp @@ -190,8 +190,8 @@ inline void scale_bilinear_nhwc_core(const ITensor *input, const ITensor *offset const size_t stride_w_elems = stride_w / sizeof(T); const size_t stride_h_elems = stride_h / sizeof(T); - const size_t input_width = input->info()->dimension(1); - const size_t input_height = input->info()->dimension(2); + const int input_width = input->info()->dimension(1); + const int input_height = input->info()->dimension(2); const T *border_area = reinterpret_cast(input->buffer() + input->info()->offset_first_element_in_bytes() - stride_w); @@ -200,55 +200,73 @@ inline void scale_bilinear_nhwc_core(const ITensor *input, const ITensor *offset return !(x < low_x || x > high_x || y < low_y || y > high_y); }; + int border_size = (border_mode == BorderMode::UNDEFINED) ? 0 : 1; + execute_window_loop(window, [&](const Coordinates & id) { - const auto offset = (*reinterpret_cast(offsets->ptr_to_element(Coordinates(id.y(), id.z())))) / sizeof(T); + const auto offset = (*reinterpret_cast(offsets->ptr_to_element(Coordinates(id.y(), id.z())))) / static_cast(sizeof(T)); const auto dx_scale = *reinterpret_cast(dx->ptr_to_element(Coordinates(id.y(), id.z()))); const auto dy_scale = *reinterpret_cast(dy->ptr_to_element(Coordinates(id.y(), id.z()))); const int in_yi = std::floor((id.z() + 0.5f) * hr - 0.5f); const int offset_row = in_yi * stride_h + id.x() * stride_c; const T *in_ptr = reinterpret_cast(in.ptr() + offset * stride_w + offset_row); - T a00 = 0, a01 = 0, a10 = 0, a11 = 0; - - if(border_mode == BorderMode::CONSTANT) - { - a00 = is_valid(offset, 0, input_width - 1, in_yi, 0, input_height - 1) ? *in_ptr : *border_area; - a01 = is_valid(offset + 1, 0, input_width - 1, in_yi, 0, input_height - 1) ? *(in_ptr + stride_w_elems) : *border_area; - a10 = is_valid(offset, 0, input_width - 1, in_yi + 1, 0, input_height - 1) ? *(in_ptr + stride_h_elems) : *border_area; - a11 = is_valid(offset + 1, 0, input_width - 1, in_yi + 1, 0, input_height - 1) ? *(in_ptr + stride_h_elems + stride_w_elems) : *border_area; - } - else if(border_mode == BorderMode::REPLICATE) + if(is_valid(offset, -border_size, input_width - 1 + border_size, in_yi, -border_size, input_height - 1 + border_size)) { - auto clamped_x = utility::clamp(offset, 0, input_width - 1); - auto clamped_x1 = utility::clamp(offset + 1, 0, input_width - 1); - auto clamped_y = utility::clamp(in_yi, 0, input_height - 1); - auto clamped_y1 = utility::clamp(in_yi + 1, 0, input_height - 1); - - a00 = *reinterpret_cast(in.ptr() + clamped_x * stride_w + clamped_y * stride_h + id.x() * stride_c); - a01 = *reinterpret_cast(in.ptr() + clamped_x1 * stride_w + clamped_y * stride_h + id.x() * stride_c); - a10 = *reinterpret_cast(in.ptr() + clamped_x * stride_w + clamped_y1 * stride_h + id.x() * stride_c); - a11 = *reinterpret_cast(in.ptr() + clamped_x1 * stride_w + clamped_y1 * stride_h + id.x() * stride_c); + T a00 = 0, a01 = 0, a10 = 0, a11 = 0; + + if(border_mode == BorderMode::CONSTANT) + { + a00 = is_valid(offset, 0, input_width - 1, in_yi, 0, input_height - 1) ? *in_ptr : *border_area; + a01 = is_valid(offset + 1, 0, input_width - 1, in_yi, 0, input_height - 1) ? *(in_ptr + stride_w_elems) : *border_area; + a10 = is_valid(offset, 0, input_width - 1, in_yi + 1, 0, input_height - 1) ? *(in_ptr + stride_h_elems) : *border_area; + a11 = is_valid(offset + 1, 0, input_width - 1, in_yi + 1, 0, input_height - 1) ? *(in_ptr + stride_h_elems + stride_w_elems) : *border_area; + } + else if(border_mode == BorderMode::REPLICATE) + { + auto clamped_x = utility::clamp(offset, 0, input_width - 1); + auto clamped_x1 = utility::clamp(offset + 1, 0, input_width - 1); + auto clamped_y = utility::clamp(in_yi, 0, input_height - 1); + auto clamped_y1 = utility::clamp(in_yi + 1, 0, input_height - 1); + + a00 = *reinterpret_cast(in.ptr() + clamped_x * stride_w + clamped_y * stride_h + id.x() * stride_c); + a01 = *reinterpret_cast(in.ptr() + clamped_x1 * stride_w + clamped_y * stride_h + id.x() * stride_c); + a10 = *reinterpret_cast(in.ptr() + clamped_x * stride_w + clamped_y1 * stride_h + id.x() * stride_c); + a11 = *reinterpret_cast(in.ptr() + clamped_x1 * stride_w + clamped_y1 * stride_h + id.x() * stride_c); + } + else + { + a00 = is_valid(offset, 0, input_width - 1, in_yi, 0, input_height - 1) ? *in_ptr : 0; + a01 = is_valid(offset + 1, 0, input_width - 1, in_yi, 0, input_height - 1) ? *(in_ptr + stride_w_elems) : 0; + a10 = is_valid(offset, 0, input_width - 1, in_yi + 1, 0, input_height - 1) ? *(in_ptr + stride_h_elems) : 0; + a11 = is_valid(offset + 1, 0, input_width - 1, in_yi + 1, 0, input_height - 1) ? *(in_ptr + stride_h_elems + stride_w_elems) : 0; + } + + // Perform interpolation + const float dx1 = 1.0f - dx_scale; + const float dy1 = 1.0f - dy_scale; + + const float w1 = dx1 * dy1; + const float w2 = dx_scale * dy1; + const float w3 = dx1 * dy_scale; + const float w4 = dx_scale * dy_scale; + + // Store result + *reinterpret_cast(out.ptr()) = static_cast(a00 * w1 + a01 * w2 + a10 * w3 + a11 * w4); } else { - a00 = *in_ptr; - a01 = *(in_ptr + stride_w_elems); - a10 = *(in_ptr + stride_h_elems); - a11 = *(in_ptr + stride_h_elems + stride_w_elems); + if(border_mode == BorderMode::CONSTANT) + { + *reinterpret_cast(out.ptr()) = *border_area; + } + else if(border_mode == BorderMode::REPLICATE) + { + auto clamped_x = utility::clamp(offset, 0, input_width - 1); + auto clamped_y = utility::clamp(in_yi, 0, input_height - 1); + *reinterpret_cast(out.ptr()) = *reinterpret_cast(in.ptr() + clamped_x * stride_w + clamped_y * stride_h + id.x() * stride_c); + } } - - // Perform interpolation - const float dx1 = 1.0f - dx_scale; - const float dy1 = 1.0f - dy_scale; - - const float w1 = dx1 * dy1; - const float w2 = dx_scale * dy1; - const float w3 = dx1 * dy_scale; - const float w4 = dx_scale * dy_scale; - - // Store result - *reinterpret_cast(out.ptr()) = static_cast(a00 * w1 + a01 * w2 + a10 * w3 + a11 * w4); }, in, out); } -- cgit v1.2.1