From 44b4e974590f1a6a07b235f203006cc9010b37e8 Mon Sep 17 00:00:00 2001 From: George Wort Date: Tue, 8 Jan 2019 11:41:54 +0000 Subject: COMPMID-1794: Add support for NHWC in CLROIAlignLayer Change-Id: If1df8f6c0549c986e607cbceb0977c80b2891b75 Reviewed-on: https://review.mlplatform.org/493 Tested-by: Arm Jenkins Reviewed-by: Isabella Gottardi Reviewed-by: Michele Di Giorgio --- src/core/CL/cl_kernels/roi_align_layer.cl | 55 ++++++++++++++++++++----------- 1 file changed, 36 insertions(+), 19 deletions(-) (limited to 'src/core/CL/cl_kernels/roi_align_layer.cl') diff --git a/src/core/CL/cl_kernels/roi_align_layer.cl b/src/core/CL/cl_kernels/roi_align_layer.cl index f52eb18078..a956860be2 100644 --- a/src/core/CL/cl_kernels/roi_align_layer.cl +++ b/src/core/CL/cl_kernels/roi_align_layer.cl @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018 ARM Limited. + * Copyright (c) 2018-2019 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -75,11 +75,17 @@ inline DATA_TYPE roi_align_1x1(const Tensor3D *input, float region_start_x, const float w2 = hy * lx; const float w3 = ly * hx; const float w4 = ly * lx; - - const DATA_TYPE data1 = *(__global DATA_TYPE *)tensor3D_offset(input, x_low, y_low, pz); - const DATA_TYPE data2 = *(__global DATA_TYPE *)tensor3D_offset(input, x_high, y_low, pz); - const DATA_TYPE data3 = *(__global DATA_TYPE *)tensor3D_offset(input, x_low, y_high, pz); - const DATA_TYPE data4 = *(__global DATA_TYPE *)tensor3D_offset(input, x_high, y_high, pz); +#if defined(NHWC) + const DATA_TYPE data1 = *(__global DATA_TYPE *)tensor3D_offset(input, pz, x_low, y_low); + const DATA_TYPE data2 = *(__global DATA_TYPE *)tensor3D_offset(input, pz, x_high, y_low); + const DATA_TYPE data3 = *(__global DATA_TYPE *)tensor3D_offset(input, pz, x_low, y_high); + const DATA_TYPE data4 = *(__global DATA_TYPE *)tensor3D_offset(input, pz, x_high, y_high); +#else // !defined(NHWC) + const DATA_TYPE data1 = *(__global DATA_TYPE *)tensor3D_offset(input, x_low, y_low, pz); + const DATA_TYPE data2 = *(__global DATA_TYPE *)tensor3D_offset(input, x_high, y_low, pz); + const DATA_TYPE data3 = *(__global DATA_TYPE *)tensor3D_offset(input, x_low, y_high, pz); + const DATA_TYPE data4 = *(__global DATA_TYPE *)tensor3D_offset(input, x_high, y_high, pz); +#endif // defined(NHWC) sum += w1 * data1 + w2 * data2 + w3 * data3 + w4 * data4; } } @@ -133,9 +139,15 @@ __kernel void roi_align_layer( Image rois = CONVERT_TO_IMAGE_STRUCT_NO_STEP(rois); Tensor3D output = CONVERT_TO_TENSOR3D_STRUCT_NO_STEP(output); - const int px = get_global_id(0); - const int py = get_global_id(1); - const int pw = get_global_id(2); +#if defined(NHWC) + const int px = get_global_id(1); + const int py = get_global_id(2); + const int pw = get_global_id(0); +#else // !defined(NHWC) + const int px = get_global_id(0); + const int py = get_global_id(1); + const int pw = get_global_id(2); +#endif // defined(NHWC) // Load roi parameters // roi is laid out as follows { batch_index, x1, y1, x2, y2 } @@ -161,7 +173,7 @@ __kernel void roi_align_layer( const float2 roi_bin_grid = SAMPLING_RATIO; #else // !defined(SAMPLING_RATIO) // Note that we subtract EPS_GRID before ceiling. This is to avoid situations where 1.000001 gets ceiled to 2. - const float2 roi_bin_grid = ceil(bin_size - EPS_GRID); + const float2 roi_bin_grid = ceil(bin_size - EPS_GRID); #endif // defined(SAMPLING_RATIO) // Move input and output pointer across the fourth dimension @@ -169,15 +181,20 @@ __kernel void roi_align_layer( output.ptr += pw * output_stride_w; for(int pz = 0; pz < MAX_DIM_Z; ++pz) { - *(__global DATA_TYPE *)tensor3D_offset(&output, px, py, pz) = (__global DATA_TYPE)roi_align_1x1(&input, - region_start.x, - bin_size.x, - roi_bin_grid.x, - region_end.x, - region_start.y, - bin_size.y, - roi_bin_grid.y, - region_end.y, pz); +#if defined(NHWC) + DATA_TYPE *_output_ptr = (__global DATA_TYPE *)tensor3D_offset(&output, pz, px, py); +#else // !defined(NHWC) + DATA_TYPE *_output_ptr = (__global DATA_TYPE *)tensor3D_offset(&output, px, py, pz); +#endif // defined(NHWC) + *_output_ptr = (__global DATA_TYPE)roi_align_1x1(&input, + region_start.x, + bin_size.x, + roi_bin_grid.x, + region_end.x, + region_start.y, + bin_size.y, + roi_bin_grid.y, + region_end.y, pz); } } #endif // Check for compile time constants -- cgit v1.2.1