author     Gian Marco Iodice <gianmarco.iodice@arm.com>    2021-04-08 17:20:00 +0100
committer  Georgios Pinitas <georgios.pinitas@arm.com>     2021-04-12 17:39:32 +0000
commit     0b76f7dd12240dc7a546c202ee80a7942d9898cd (patch)
tree       7dbd9ae56483e111952a0cab4f19d2c3f25157e7
parent     6dbcc0e4d2fd0c61602a1a0c4a0ac548da713087 (diff)
download   ComputeLibrary-0b76f7dd12240dc7a546c202ee80a7942d9898cd.tar.gz
Add support for cl_image in CLDirectConvolutionLayer
- The cl_image object can be used for the weights
- cl_image can only work for f32/f16
- Fix the implicit padding on the first dimension X

Resolves COMPMID-4341

Change-Id: I04e0901c69e7765c42afceca38c4a840645b9123
Signed-off-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/5393
Reviewed-by: Giorgio Arena <giorgio.arena@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
-rw-r--r--  src/core/CL/cl_kernels/direct_convolution.cl           62
-rw-r--r--  src/core/CL/cl_kernels/tile_helpers.h                  71
-rw-r--r--  src/core/CL/cl_kernels/winograd_output_transform.cl     8
-rw-r--r--  src/core/gpu/cl/kernels/ClDirectConvolutionKernel.cpp  94
-rw-r--r--  src/runtime/CL/functions/CLConvolutionLayer.cpp         2
5 files changed, 169 insertions(+), 68 deletions(-)
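
For context (not part of the patch): the mechanism this change builds on is the cl_khr_image2d_from_buffer extension, which lets an existing cl_buffer be re-read through the GPU's texture path without a copy. A minimal host-side sketch, assuming a context ctx and a weights buffer whose row pitch already satisfies the device's CL_DEVICE_IMAGE_PITCH_ALIGNMENT (the helper name is illustrative, not the library's API):

#include <CL/cl.h>
#include <cstddef>

// Hypothetical helper: view an existing cl_mem buffer as a 2D image.
// image_w is in RGBA texels (tensor channels / 4), row_pitch in bytes.
static cl_mem wrap_buffer_as_image2d(cl_context ctx, cl_mem buffer,
                                     size_t image_w, size_t image_h, size_t row_pitch)
{
    cl_image_format format = {};
    format.image_channel_order     = CL_RGBA;  // 4 values per texel
    format.image_channel_data_type = CL_FLOAT; // CL_HALF_FLOAT for F16

    cl_image_desc desc = {};
    desc.image_type      = CL_MEM_OBJECT_IMAGE2D;
    desc.image_width     = image_w;
    desc.image_height    = image_h;
    desc.image_row_pitch = row_pitch;
    desc.buffer          = buffer; // requires cl_khr_image2d_from_buffer

    cl_int err = CL_SUCCESS;
    cl_mem image = clCreateImage(ctx, CL_MEM_READ_ONLY, &format, &desc, nullptr, &err);
    return (err == CL_SUCCESS) ? image : nullptr;
}

The library's own create_image2d_from_buffer() from src/core/CL/CLUtils.h performs this job in run_op() further below.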
diff --git a/src/core/CL/cl_kernels/direct_convolution.cl b/src/core/CL/cl_kernels/direct_convolution.cl
index 96196bda8d..220179effb 100644
--- a/src/core/CL/cl_kernels/direct_convolution.cl
+++ b/src/core/CL/cl_kernels/direct_convolution.cl
@@ -122,6 +122,7 @@ __kernel void direct_convolution_nhwc(
#define _IDST_WIDTH DST_WIDTH
#define _IDST_HEIGHT DST_HEIGHT
#define _IDST_CHANNELS DST_CHANNELS
+#define _IY_MULTIPLIER (_IWEI_WIDTH * _IWEI_HEIGHT)
// If quantized, the output tile has to be quantized first before being stored to global memory
#if defined(IS_QUANTIZED)
@@ -136,8 +137,8 @@ __kernel void direct_convolution_nhwc(
// .v = access the whole vector (OpenCL vector)
// .s[x] = access the vector element at position x (scalar access)
- TILE(int, M0, 1, xi) = {{ { 0 } }};
- TILE(int, M0, 1, yi) = {{ { 0 } }};
+ TILE(int, M0, 1, xi);
+ TILE(int, M0, 1, yi);
// Convert the linear index to coordinate
LOOP_UNROLLING(int, i, 0, M0, 1)
@@ -148,29 +149,14 @@ __kernel void direct_convolution_nhwc(
yi[i].v -= PAD_TOP;
}
- uint wei_x = 0;
-
// Initialize the accumulators
- TILE(ACC_DATA_TYPE, M0, N0, c) = {{ { 0 } }};
+ TILE(ACC_DATA_TYPE, M0, N0, c) = { { { 0 } } };
for(int i = 0; i < (_IWEI_WIDTH * _IWEI_HEIGHT); ++i)
{
- uint src_x = 0;
- int xk = i % _IWEI_WIDTH;
- int yk = i / _IWEI_WIDTH;
-
- TILE(int, M0, 1, src_indirect_y) = {{ { 0 } }};
- TILE(int, M0, 1, src_indirect_mask) = {{ { 0 } }};
-
- // Calculate the source indirect Y and the source indirect mask
- // Since the indirect Y is clamped when out-of-bound, the mask is used to
- // force to zero the out-of-bound values
- LOOP_UNROLLING(int, i, 0, M0, 1)
- {
- src_indirect_y[i].v = (CLAMP(xi[i].v + xk, 0, (int)_ISRC_WIDTH - 1) + CLAMP(yi[i].v + yk, 0, (int)_ISRC_HEIGHT - 1) * _ISRC_WIDTH);
- src_indirect_y[i].v += bout * (int)_ISRC_WIDTH * (int)_ISRC_HEIGHT;
- src_indirect_mask[i].v = ((xi[i].v + xk) >= 0 && (xi[i].v + xk) < (int)_ISRC_WIDTH && (yi[i].v + yk) >= 0 && (yi[i].v + yk) < (int)_ISRC_HEIGHT);
- }
+ int ck = 0;
+ int xk = i % _IWEI_WIDTH;
+ int yk = i / _IWEI_WIDTH;
int k = 0;
for(; k <= (_ISRC_CHANNELS - K0); k += K0)
@@ -178,14 +164,16 @@ __kernel void direct_convolution_nhwc(
TILE(SRC_DATA_TYPE, M0, K0, a);
TILE(WEI_DATA_TYPE, N0, K0, b);
+ LOOP_UNROLLING(int, i, 0, M0, 1)
+ {
+ a[i].v = ZERO_VALUE;
+ }
+
// Load tile from the src tensor
- T_LOAD_INDIRECT(SRC_DATA_TYPE, M0, K0, SRC_TENSOR_TYPE, src, src_x, src_stride_y, src_indirect_y, a);
+ T_LOAD_NHWC_INDIRECT(SRC_DATA_TYPE, 1, M0, K0, SRC_TENSOR_TYPE, src, bout, yk, xk, ck, _ISRC_WIDTH, _ISRC_HEIGHT, src_stride_y, xi, yi, a);
// Load tile from the weights tensor
- T_LOAD(WEI_DATA_TYPE, N0, K0, WEI_TENSOR_TYPE, wei, wei_x, cout, wei_stride_w, b);
-
- // Fill with zero the out-of-bound rows
- T_ROWSET_MASK(SRC_DATA_TYPE, M0, K0, ZERO_VALUE, a, src_indirect_mask);
+ T_LOAD(WEI_DATA_TYPE, N0, K0, WEI_TENSOR_TYPE, wei, ck, cout * _IY_MULTIPLIER + i, _IY_MULTIPLIER, wei_stride_y, b);
// Compute the matrix multiplication between two tiles
T_MMUL(SRC_DATA_TYPE, WEI_DATA_TYPE, ACC_DATA_TYPE, M0, N0, K0, NT, T, a, b, c);
@@ -194,8 +182,7 @@ __kernel void direct_convolution_nhwc(
// The computation is not performed if both SRC_OFFSET and WEI_OFFSET are zero
T_OFFSET_CORRECTION(ACC_DATA_TYPE, M0, N0, K0, SRC_OFFSET, WEI_OFFSET, a, b, c);
- src_x += K0;
- wei_x += K0;
+ ck += K0;
}
// We voluntarily use SRC_CHANNELS rather than _DSRC_CHANNELS
@@ -207,14 +194,17 @@ __kernel void direct_convolution_nhwc(
TILE(SRC_DATA_TYPE, M0, 1, a);
TILE(WEI_DATA_TYPE, N0, 1, b);
+ LOOP_UNROLLING(int, i, 0, M0, 1)
+ {
+ a[i].v = ZERO_VALUE;
+ }
+
// Load tile from the src tensor
- T_LOAD_INDIRECT(SRC_DATA_TYPE, M0, 1, SRC_TENSOR_TYPE, src, src_x, src_stride_y, src_indirect_y, a);
+ T_LOAD_NHWC_INDIRECT(SRC_DATA_TYPE, 1, M0, 1, SRC_TENSOR_TYPE, src, bout, yk, xk, ck, _ISRC_WIDTH, _ISRC_HEIGHT, src_stride_y, xi, yi, a);
// Load tile from the weights tensor
- T_LOAD(WEI_DATA_TYPE, N0, 1, WEI_TENSOR_TYPE, wei, wei_x, cout, wei_stride_w, b);
-
- // Fill with zero the out-of-bound rows
- T_ROWSET_MASK(SRC_DATA_TYPE, M0, 1, ZERO_VALUE, a, src_indirect_mask);
+ // The T_LOAD for the left-over elements can only use BUFFER because we load one element per iteration
+ T_LOAD(WEI_DATA_TYPE, N0, 1, BUFFER, wei, ck, cout * _IY_MULTIPLIER + i, _IY_MULTIPLIER, wei_stride_y, b);
// Compute the matrix multiplication between two tiles
T_MMUL(SRC_DATA_TYPE, WEI_DATA_TYPE, ACC_DATA_TYPE, M0, N0, 1, NT, T, a, b, c);
@@ -223,8 +213,7 @@ __kernel void direct_convolution_nhwc(
// The computation is not performed if both SRC_OFFSET and WEI_OFFSET are zero
T_OFFSET_CORRECTION(ACC_DATA_TYPE, M0, N0, 1, SRC_OFFSET, WEI_OFFSET, a, b, c);
- ++src_x;
- ++wei_x;
+ ++ck;
}
#endif // ((SRC_CHANNELS % K0) != 0)
}
@@ -236,7 +225,7 @@ __kernel void direct_convolution_nhwc(
#if defined(HAS_BIAS)
TILE(BIA_DATA_TYPE, 1, N0, bias0);
- T_LOAD(BIA_DATA_TYPE, 1, N0, BUFFER, bia, cout, 0, 0, bias0);
+ T_LOAD(BIA_DATA_TYPE, 1, N0, BUFFER, bia, cout, 0, 1, 0, bias0);
// c = c + bias[broadcasted]
T_ADD_BROADCAST_X(ACC_DATA_TYPE, M0, N0, c, bias0, c);
@@ -274,4 +263,5 @@ __kernel void direct_convolution_nhwc(
#undef _IDST_WIDTH
#undef _IDST_HEIGHT
#undef _IDST_CHANNELS
+#undef _IY_MULTIPLIER
}
\ No newline at end of file
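
An illustration of the new weights addressing (assumptions mine, not part of the patch): each output channel cout owns _IY_MULTIPLIER = WEI_WIDTH * WEI_HEIGHT consecutive rows of IFM values, one row per kernel position. The T_LOAD call above therefore starts at row cout * _IY_MULTIPLIER + i and steps by _IY_MULTIPLIER rows, landing on the same kernel position of the next output channel:

#include <cstddef>
#include <cstdint>

// Hypothetical host-side mirror of the device-side address computation in
// T_LOAD(..., wei, ck, cout * _IY_MULTIPLIER + i, _IY_MULTIPLIER, wei_stride_y, b).
static const uint8_t *weights_row_ptr(const uint8_t *wei_base, size_t wei_stride_y,
                                      int cout, int kernel_pos, int tile_row,
                                      int kernel_w, int kernel_h)
{
    const int y_multiplier = kernel_w * kernel_h;              // _IY_MULTIPLIER
    const int y            = cout * y_multiplier + kernel_pos; // starting row (Y)
    // Tile row _i advances by y_multiplier rows, i.e. to kernel position
    // kernel_pos of output channel (cout + _i).
    return wei_base + (size_t)(y + tile_row * y_multiplier) * wei_stride_y;
}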
diff --git a/src/core/CL/cl_kernels/tile_helpers.h b/src/core/CL/cl_kernels/tile_helpers.h
index b963f8b5e3..496f2dd664 100644
--- a/src/core/CL/cl_kernels/tile_helpers.h
+++ b/src/core/CL/cl_kernels/tile_helpers.h
@@ -165,23 +165,26 @@
/** Load a tile from global memory (tensor)
*
- * @param[in] DATA_TYPE Data type
- * @param[in] HEIGHT Number of dst rows
- * @param[in] WIDTH Number of dst columns
- * @param[in] TENSOR_TYPE Type of cl_type used to store the tensor in global memory (BUFFER=cl_buffer, IMAGE=cl_image).
- * In case of cl_image, only WIDTH multiples of 4 are supported (4, 8, 16)
- * @param[in] TENSOR Tensor basename
- * @param[in] X Starting X position
- * @param[in] Y Starting Y position
- * @param[in] STRIDE_Y Stride Y (in bytes)
- * @param[out] dst Output tile
+ * @param[in] DATA_TYPE Data type
+ * @param[in] HEIGHT Number of dst rows
+ * @param[in] WIDTH Number of dst columns
+ * @param[in] TENSOR_TYPE Type of cl_type used to store the tensor in global memory (BUFFER=cl_buffer, IMAGE=cl_image).
+ * In case of cl_image, only WIDTH multiples of 4 are supported (4, 8, 16)
+ * @param[in] TENSOR Tensor basename
+ * @param[in] X Starting X position
+ * @param[in] Y Starting Y position
+ * @param[in] YI_MULTIPLIER Parameter used to multiply the internal row increment (_i).
+ * In common cases it should be 1, but it becomes useful when loading rows that are multiples of STRIDE_Y apart (e.g. when loading the weights of a convolution layer).
+ * In that case the address calculation is performed as: (Y + _i * YI_MULTIPLIER) * STRIDE_Y
+ * @param[in] STRIDE_Y Stride Y (in bytes) used to load each row.
+ * @param[out] dst Output tile
*/
-#define T_LOAD(DATA_TYPE, HEIGHT, WIDTH, TENSOR_TYPE, TENSOR, X, Y, STRIDE_Y, dst) \
+#define T_LOAD(DATA_TYPE, HEIGHT, WIDTH, TENSOR_TYPE, TENSOR, X, Y, YI_MULTIPLIER, STRIDE_Y, dst) \
({ \
LOOP_UNROLLING(int, _i, 0, HEIGHT, 1) \
{ \
- dst[_i].v = V_LOAD(DATA_TYPE, WIDTH, TENSOR_TYPE, TENSOR, X, ((Y) + _i), STRIDE_Y); \
- } \
+ dst[_i].v = V_LOAD(DATA_TYPE, WIDTH, TENSOR_TYPE, TENSOR, X, ((Y) + _i * (int)(YI_MULTIPLIER)), STRIDE_Y); \
+ } \
})
/** Load a tile from global memory (tensor) using an indirect Y index tile
@@ -223,7 +226,7 @@
* @param[in] STRIDE_Y Stride Y (in bytes)
* @param[out] dst Output tile
*/
-#define T_LOAD_NHWC(DATA_TYPE, TILE_HEIGHT, TILE_WIDTH, TILE_CHANNELS, TENSOR_TYPE, TENSOR, B, Y, X, C, TENSOR_WIDTH, TENSOR_HEIGHT, STRIDE_Y, dst) \
+#define T_LOAD_NHWC(DATA_TYPE, TILE_HEIGHT, TILE_WIDTH, TILE_CHANNELS, TENSOR_TYPE, TENSOR, B, Y, X, C, TENSOR_WIDTH, TENSOR_HEIGHT, STRIDE_Y, dst) \
({ \
LOOP_UNROLLING(int, _yk, 0, (TILE_HEIGHT), 1) \
{ \
@@ -235,9 +238,43 @@
if(_src_valid_y != 0) \
{ \
dst[_xk + _yk * (TILE_WIDTH)].v = V_LOAD(DATA_TYPE, TILE_CHANNELS, TENSOR_TYPE, TENSOR, C, _src_y, STRIDE_Y); \
- } \
- } \
- } \
+ } \
+ } \
+ } \
+ })
+
+/** Load a tile from global memory (tensor) when the tensor is stored using a NHWC layout using indirect X and Y coordinates
+ *
+ * @param[in] DATA_TYPE Data type
+ * @param[in] TILE_HEIGHT Number of elements to load from Y (height) dimension
+ * @param[in] TILE_WIDTH Number of elements to load from X (width) dimension
+ * @param[in] TILE_CHANNELS Number of elements to load from C (channel) dimension
+ * @param[in] TENSOR_TYPE Type of cl_type used to store the tensor in global memory (BUFFER=cl_buffer, IMAGE=cl_image). Currently only BUFFER is supported.
+ * In case of cl_image, only TILE_CHANNELS multiples of 4 are supported (4, 8, 16)
+ * @param[in] TENSOR Tensor basename
+ * @param[in] B Starting batch index
+ * @param[in] Y Starting Y index
+ * @param[in] X Starting X index
+ * @param[in] C Starting C index
+ * @param[in] TENSOR_WIDTH Width (number of elements along X) of the tensor
+ * @param[in] TENSOR_HEIGHT Height (number of elements along Y) of the tensor
+ * @param[in] STRIDE_Y Stride Y (in bytes)
+ * @param[in] xi A tile with (TILE_WIDTH x TILE_HEIGHT) values with the indirect X coordinate
+ * @param[in] yi A tile with (TILE_WIDTH x TILE_HEIGHT) values with the indirect Y coordinate
+ * @param[out] dst Output tile
+ */
+#define T_LOAD_NHWC_INDIRECT(DATA_TYPE, TILE_HEIGHT, TILE_WIDTH, TILE_CHANNELS, TENSOR_TYPE, TENSOR, B, Y, X, C, TENSOR_WIDTH, TENSOR_HEIGHT, STRIDE_Y, xi, yi, dst) \
+ ({ \
+ LOOP_UNROLLING(int, _i, 0, (TILE_WIDTH * TILE_HEIGHT), 1) \
+ { \
+ int _src_y = (X) + xi[_i].v + ((Y) + yi[_i].v) * (TENSOR_WIDTH); \
+ _src_y += (B) * (int)(TENSOR_WIDTH) * (int)(TENSOR_HEIGHT); \
+ int _src_valid_y = (((X) + xi[_i].v) >= 0 && ((X) + xi[_i].v) < (int)(TENSOR_WIDTH) && ((Y) + yi[_i].v) >= 0 && ((Y) + yi[_i].v) < (int)(TENSOR_HEIGHT)); \
+ if(_src_valid_y != 0) \
+ { \
+ dst[_i].v = V_LOAD(DATA_TYPE, TILE_CHANNELS, TENSOR_TYPE, TENSOR, C, _src_y, STRIDE_Y); \
+ } \
+ } \
})
/** Store a tile to global memory (tensor) using an indirect Y index tile and conditionally use a different length for the store
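
A plain C++ paraphrase of what T_LOAD_NHWC_INDIRECT does per tile element, under the assumption (matching the kernel above) that xi/yi were precomputed from the output coordinates and that dst was zero-initialised by the caller; this scalar model loads one value where the macro issues a TILE_CHANNELS-wide V_LOAD:

#include <vector>

// Hypothetical scalar model of T_LOAD_NHWC_INDIRECT.
static void load_nhwc_indirect(const float *src, int b, int y, int x,
                               int tensor_w, int tensor_h,
                               const std::vector<int> &xi, const std::vector<int> &yi,
                               std::vector<float> &dst /* pre-zeroed */)
{
    for(size_t i = 0; i < dst.size(); ++i)
    {
        const int px = x + xi[i];
        const int py = y + yi[i];
        // Out-of-bound elements are simply skipped: dst already holds
        // ZERO_VALUE, which is why the kernel no longer needs T_ROWSET_MASK.
        if(px >= 0 && px < tensor_w && py >= 0 && py < tensor_h)
        {
            const int src_y = px + py * tensor_w + b * tensor_w * tensor_h;
            dst[i] = src[src_y];
        }
    }
}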
diff --git a/src/core/CL/cl_kernels/winograd_output_transform.cl b/src/core/CL/cl_kernels/winograd_output_transform.cl
index 674a138d48..6bd90604e5 100644
--- a/src/core/CL/cl_kernels/winograd_output_transform.cl
+++ b/src/core/CL/cl_kernels/winograd_output_transform.cl
@@ -637,7 +637,7 @@ __kernel void winograd_output_transform_4x4_3x3_nhwc(
#if defined(HAS_BIAS)
TILE(DATA_TYPE, 1, N0, b);
- T_LOAD(DATA_TYPE, 1, N0, BUFFER, bias, cout, 0, 0, b);
+ T_LOAD(DATA_TYPE, 1, N0, BUFFER, bias, cout, 0, 1, 0, b);
// c = c + bias[broadcasted]
T_ADD_BROADCAST_X(DATA_TYPE, 4, N0, out, b, out);
@@ -718,7 +718,7 @@ __kernel void winograd_output_transform_4x4_3x3_nhwc(
#if defined(HAS_BIAS)
TILE(DATA_TYPE, 1, N0, b);
- T_LOAD(DATA_TYPE, 1, N0, BUFFER, bias, cout, 0, 0, b);
+ T_LOAD(DATA_TYPE, 1, N0, BUFFER, bias, cout, 0, 1, 0, b);
// c = c + bias[broadcasted]
T_ADD_BROADCAST_X(DATA_TYPE, 16, N0, out, b, out);
@@ -1070,7 +1070,7 @@ __kernel void winograd_output_transform_4x4_5x5_nhwc(
#if defined(HAS_BIAS)
TILE(DATA_TYPE, 1, N0, b);
- T_LOAD(DATA_TYPE, 1, N0, BUFFER, bias, cout, 0, 0, b);
+ T_LOAD(DATA_TYPE, 1, N0, BUFFER, bias, cout, 0, 1, 0, b);
// c = c + bias[broadcasted]
T_ADD_BROADCAST_X(DATA_TYPE, 4, N0, out, b, out);
@@ -1162,7 +1162,7 @@ __kernel void winograd_output_transform_4x4_5x5_nhwc(
#if defined(HAS_BIAS)
TILE(DATA_TYPE, 1, N0, b);
- T_LOAD(DATA_TYPE, 1, N0, BUFFER, bias, cout, 0, 0, b);
+ T_LOAD(DATA_TYPE, 1, N0, BUFFER, bias, cout, 0, 1, 0, b);
// c = c + bias[broadcasted]
T_ADD_BROADCAST_X(DATA_TYPE, 16, N0, out, b, out);
diff --git a/src/core/gpu/cl/kernels/ClDirectConvolutionKernel.cpp b/src/core/gpu/cl/kernels/ClDirectConvolutionKernel.cpp
index c6ca084386..bf26477895 100644
--- a/src/core/gpu/cl/kernels/ClDirectConvolutionKernel.cpp
+++ b/src/core/gpu/cl/kernels/ClDirectConvolutionKernel.cpp
@@ -33,12 +33,13 @@
#include "arm_compute/core/utils/misc/ShapeCalculator.h"
#include "arm_compute/core/utils/quantization/AsymmHelpers.h"
#include "src/core/AccessWindowStatic.h"
+#include "src/core/CL/CLUtils.h"
#include "src/core/CL/CLValidate.h"
+#include "src/core/CL/gemm/CLGEMMHelpers.h"
#include "src/core/helpers/AutoConfiguration.h"
#include "src/core/helpers/WindowHelpers.h"
#include "support/Cast.h"
#include "support/StringSupport.h"
-
namespace arm_compute
{
namespace opencl
@@ -276,7 +277,11 @@ std::pair<Status, Window> validate_and_configure_window(ITensorInfo *src, ITenso
if(data_layout == DataLayout::NHWC)
{
const unsigned int vec_size = std::min(static_cast<unsigned int>(dst->tensor_shape()[0]), 4u);
- const unsigned int num_rows = dst->tensor_shape()[0] > 16 ? 2u : 1U;
+ unsigned int num_rows = 1U;
+ if(dst->tensor_shape()[0] > 16)
+ {
+ num_rows = src->data_type() == DataType::F32 ? 2U : 4U;
+ }
// Create window and update padding
Window win = calculate_max_window(output_shape, Steps(vec_size, num_rows));
@@ -318,6 +323,50 @@ std::pair<Status, Window> validate_and_configure_window(ITensorInfo *src, ITenso
ARM_COMPUTE_ERROR("Not supported");
}
}
+
+bool export_to_cl_image_support(ITensorInfo *tensor, GPUTarget gpu_target, DataLayout data_layout)
+{
+ if(tensor->tensor_shape()[0] % 4 || (data_layout != DataLayout::NHWC))
+ {
+ return false;
+ }
+
+ // If not floating point
+ if(!is_data_type_float(tensor->data_type()))
+ {
+ return false;
+ }
+
+ if(gpu_target == GPUTarget::G71 || get_arch_from_target(gpu_target) == GPUTarget::MIDGARD)
+ {
+ return false;
+ }
+
+ // Check if the cl_khr_image2d_from_buffer extension is supported on the target platform
+ if(!image2d_from_buffer_supported(CLKernelLibrary::get().get_device()))
+ {
+ return false;
+ }
+
+ // Check cl image pitch alignment
+ if(get_cl_image_pitch_alignment(CLKernelLibrary::get().get_device()) == 0)
+ {
+ return false;
+ }
+
+ const size_t image_w = tensor->tensor_shape()[0] / 4;
+ const size_t image_h = tensor->tensor_shape()[1] * tensor->tensor_shape()[2] * tensor->tensor_shape()[3];
+ const size_t max_image_w = CLKernelLibrary::get().get_device().getInfo<CL_DEVICE_IMAGE2D_MAX_WIDTH>();
+ const size_t max_image_h = CLKernelLibrary::get().get_device().getInfo<CL_DEVICE_IMAGE2D_MAX_HEIGHT>();
+
+ if(image_w > max_image_w || image_h > max_image_h)
+ {
+ return false;
+ }
+
+ return true;
+}
+
} // namespace
BorderSize ClDirectConvolutionKernel::border_size() const
@@ -365,13 +414,19 @@ void ClDirectConvolutionKernel::configure(const CLCompileContext &compile_contex
kernel_name << "direct_convolution_nhwc";
- const unsigned int n0 = win_config.second.x().step();
- const unsigned int m0 = win_config.second.y().step();
+ const unsigned int n0 = win_config.second.x().step();
+ const unsigned int m0 = win_config.second.y().step();
+ const unsigned int k0 = adjust_vec_size(8u, src->dimension(channel_idx));
+ const unsigned int partial_store_n0 = dst->dimension(channel_idx) % n0;
+ const unsigned int pad_left = conv_info.pad_left();
+ const unsigned int pad_top = conv_info.pad_top();
+ const bool export_to_cl_image = export_to_cl_image_support(weights, gpu_target, _data_layout);
- const unsigned int k0 = adjust_vec_size(8u, src->dimension(channel_idx));
- const unsigned int partial_store_n0 = dst->dimension(channel_idx) % n0;
- const unsigned int pad_left = conv_info.pad_left();
- const unsigned int pad_top = conv_info.pad_top();
+ // Update the padding for the weights tensor if we can export to cl_image
+ if(export_to_cl_image)
+ {
+ arm_compute::cl_gemm::update_padding_for_cl_image(weights);
+ }
if(biases != nullptr)
{
@@ -390,7 +445,7 @@ void ClDirectConvolutionKernel::configure(const CLCompileContext &compile_contex
build_options.add_option("-DDST_HEIGHT=" + support::cpp11::to_string(dst->dimension(height_idx)));
build_options.add_option("-DDST_CHANNELS=" + support::cpp11::to_string(dst->dimension(channel_idx)));
build_options.add_option("-DDST_DATA_TYPE=" + get_cl_type_from_data_type(dst->data_type()));
- build_options.add_option("-DWEI_TENSOR_TYPE=BUFFER");
+ build_options.add_option_if_else(export_to_cl_image, "-DWEI_TENSOR_TYPE=IMAGE", "-DWEI_TENSOR_TYPE=BUFFER");
build_options.add_option("-DWEI_WIDTH=" + support::cpp11::to_string(weights->dimension(width_idx)));
build_options.add_option("-DWEI_HEIGHT=" + support::cpp11::to_string(weights->dimension(height_idx)));
build_options.add_option("-DWEI_DATA_TYPE=" + get_cl_type_from_data_type(weights->data_type()));
@@ -533,13 +588,32 @@ void ClDirectConvolutionKernel::run_op(ITensorPack &tensors, const Window &windo
if(_data_layout == DataLayout::NHWC)
{
- const size_t dim_y_collapsed = ceil_to_multiple(dst->info()->dimension(1) * dst->info()->dimension(2), slice.y().step());
+ cl::Image2D weights_cl_image;
+
+ const size_t dim_y_collapsed = ceil_to_multiple(dst->info()->dimension(1) * dst->info()->dimension(2), slice.y().step());
+ const bool export_to_cl_image = export_to_cl_image_support(weights->info(), get_target(), _data_layout);
+
slice.set(Window::DimY, Window::Dimension(0, dim_y_collapsed, slice.y().step()));
slice.set(Window::DimZ, Window::Dimension(0, dst->info()->dimension(3), 1));
+ if(export_to_cl_image)
+ {
+ const size_t image_w = weights->info()->dimension(0) / 4;
+ const size_t image_h = weights->info()->dimension(1) * weights->info()->dimension(2) * weights->info()->dimension(3);
+ const TensorShape shape2d(image_w, image_h);
+ const size_t image_row_pitch = weights->info()->strides_in_bytes()[1];
+
+ // Export cl_buffer to cl_image
+ weights_cl_image = create_image2d_from_buffer(CLKernelLibrary::get().context(), weights->cl_buffer(), shape2d, weights->info()->data_type(), image_row_pitch);
+ }
+
unsigned int idx = 0;
add_4D_tensor_argument(idx, src, slice);
add_4D_tensor_argument(idx, dst, slice);
+ if(export_to_cl_image)
+ {
+ _kernel.setArg(idx++, weights_cl_image);
+ }
add_4D_tensor_argument(idx, weights, slice);
if(biases != nullptr)
{
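
For reference, a hedged sketch of the geometry used by export_to_cl_image_support() and run_op() above: the weights buffer is viewed as C/4 RGBA texels per row and W * H * OFM rows, and that view must fit the device's 2D image limits (the helper below is illustrative only):

#include <cstddef>

// Hypothetical eligibility check mirroring the geometry in the patch:
// channels must be a multiple of 4 (one RGBA texel packs 4 values) and the
// resulting image must fit within the device limits queried via
// CL_DEVICE_IMAGE2D_MAX_WIDTH / CL_DEVICE_IMAGE2D_MAX_HEIGHT.
static bool weights_fit_cl_image(size_t c, size_t w, size_t h, size_t ofm,
                                 size_t max_image_w, size_t max_image_h)
{
    if(c % 4 != 0)
    {
        return false;
    }
    const size_t image_w = c / 4;       // RGBA texels per row
    const size_t image_h = w * h * ofm; // one row per (x, y, cout) triple
    return image_w <= max_image_w && image_h <= max_image_h;
}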
diff --git a/src/runtime/CL/functions/CLConvolutionLayer.cpp b/src/runtime/CL/functions/CLConvolutionLayer.cpp
index b1351f6747..ac18b966af 100644
--- a/src/runtime/CL/functions/CLConvolutionLayer.cpp
+++ b/src/runtime/CL/functions/CLConvolutionLayer.cpp
@@ -224,7 +224,7 @@ ConvolutionMethod CLConvolutionLayer::get_convolution_method(const ITensorInfo *
{
return ConvolutionMethod::DIRECT;
}
- if((weights->dimension(idx_h) > 7) && (input->dimension(idx_c) > output->dimension(idx_c)) && (CLDirectConvolutionLayer::validate(input, weights, nullptr, output, conv_info, act_info)))
+ if((weights->dimension(idx_h) > 7) && (input->dimension(idx_c) >= output->dimension(idx_c)) && (CLDirectConvolutionLayer::validate(input, weights, nullptr, output, conv_info, act_info)))
{
return ConvolutionMethod::DIRECT;
}
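
The hunk above only relaxes the channel condition from > to >=. Paraphrased (not the full get_convolution_method() logic), the direct-convolution preference for large kernels now reads:

// Sketch of the relaxed heuristic: kernels taller than 7 that do not expand
// the channel count now also prefer DIRECT, provided validate() accepts the
// configuration.
static bool prefer_direct(unsigned int kernel_h, unsigned int in_c,
                          unsigned int out_c, bool validate_ok)
{
    return kernel_h > 7 && in_c >= out_c && validate_ok;
}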