diff options
Diffstat (limited to 'src/core/GLES_COMPUTE/kernels/GCDirectConvolutionLayerKernel.cpp')
-rw-r--r-- | src/core/GLES_COMPUTE/kernels/GCDirectConvolutionLayerKernel.cpp | 66 |
1 files changed, 62 insertions, 4 deletions
diff --git a/src/core/GLES_COMPUTE/kernels/GCDirectConvolutionLayerKernel.cpp b/src/core/GLES_COMPUTE/kernels/GCDirectConvolutionLayerKernel.cpp index a7d721d035..ab78fb994b 100644 --- a/src/core/GLES_COMPUTE/kernels/GCDirectConvolutionLayerKernel.cpp +++ b/src/core/GLES_COMPUTE/kernels/GCDirectConvolutionLayerKernel.cpp @@ -102,6 +102,7 @@ void GCDirectConvolutionLayerKernel<kernel_size>::configure(const IGCTensor *inp options.emplace("#define LOCAL_SIZE_Y " + support::cpp11::to_string(_lws[1])); options.emplace("#define LOCAL_SIZE_Z " + support::cpp11::to_string(_lws[2])); options.emplace("#define STRIDE_X " + support::cpp11::to_string(_conv_stride_x)); + options.emplace("#define STRIDE_Y " + support::cpp11::to_string(_conv_stride_y)); std::string dt_name = (input->info()->data_type() == DataType::F32) ? "DATA_TYPE_FP32" : "DATA_TYPE_FP16"; options.emplace(("#define " + dt_name)); @@ -148,6 +149,10 @@ void GCDirectConvolutionLayerKernel<kernel_size>::configure(const IGCTensor *inp num_elems_written_per_iteration_y = 3; num_elems_written_per_iteration_z = 2; #endif /* PROCESS_X_8ELEMENTS_Y_3ELEMENTS_FP16 */ +#undef PROCESS_X_8ELEMENTS_Y_3ELEMENTS_FP16 +#undef PROCESS_X_4ELEMENTS_Y_3ELEMENTS_FP16 +#undef PROCESS_X_4ELEMENTS_Y_4ELEMENTS_FP16 +#undef PROCESS_X_4ELEMENTS_Y_3ELEMENTS_Z_2ELEMENTS_FP16 break; case DataType::F32: @@ -193,6 +198,9 @@ void GCDirectConvolutionLayerKernel<kernel_size>::configure(const IGCTensor *inp #else /* PROCESS_1_ELEMENT */ #error Have to declare how many elements to process in one thread. #endif /* PROCESS_1_ELEMENT */ +#undef PROCESS_1_ELEMENT +#undef PROCESS_4_ELEMENT +#undef PROCESS_8_ELEMENT break; default: @@ -203,15 +211,65 @@ void GCDirectConvolutionLayerKernel<kernel_size>::configure(const IGCTensor *inp } else if(kernel_size == 1) { + if(weights->info()->dimension(2) % 2 == 0) + { + options.emplace("#define WEIGHTS_OPTIMIZATION"); + } switch(input->info()->data_type()) { case DataType::F16: +#define PROCESS_8X_2Y_1Z + +#if defined(PROCESS_4X_1Y_1Z) + options.emplace("#define PROCESS_4X_1Y_1Z"); + num_elems_read_per_iteration_x = 4; + num_elems_written_per_iteration_x = 4; +#elif defined(PROCESS_4X_2Y_1Z) + options.emplace("#define PROCESS_4X_2Y_1Z"); + num_elems_read_per_iteration_x = 4; + num_elems_read_per_iteration_y = 2; + num_elems_written_per_iteration_x = 4; + num_elems_written_per_iteration_y = 2; +#elif defined(PROCESS_4X_3Y_1Z) + options.emplace("#define PROCESS_4X_3Y_1Z"); + num_elems_read_per_iteration_x = 4; + num_elems_read_per_iteration_y = 3; + num_elems_written_per_iteration_x = 4; + num_elems_written_per_iteration_y = 3; +#elif defined(PROCESS_4X_4Y_1Z) + options.emplace("#define PROCESS_4X_4Y_1Z"); + num_elems_read_per_iteration_x = 4; + num_elems_read_per_iteration_y = 4; + num_elems_written_per_iteration_x = 4; + num_elems_written_per_iteration_y = 4; +#elif defined(PROCESS_4X_2Y_2Z) + ARM_COMPUTE_ERROR_ON_MSG((weights->info()->dimension(4) % 2) == 1, "Current 'weights->info()->dimension(4) % 2) == 1' is not supported"); + options.emplace("#define PROCESS_4X_2Y_2Z"); + num_elems_read_per_iteration_x = 4; + num_elems_read_per_iteration_y = 2; + num_elems_written_per_iteration_x = 4; + num_elems_written_per_iteration_y = 2; + num_elems_written_per_iteration_z = 2; +#elif defined(PROCESS_8X_1Y_1Z) + options.emplace("#define PROCESS_8X_1Y_1Z"); num_elems_read_per_iteration_x = 8; num_elems_written_per_iteration_x = 8; - if(weights->info()->dimension(2) % 2 == 0) - { - options.emplace("#define WEIGHTS_OPTIMIZATION"); - } +#elif defined(PROCESS_8X_2Y_1Z) + options.emplace("#define PROCESS_8X_2Y_1Z"); + num_elems_read_per_iteration_x = 8; + num_elems_read_per_iteration_y = 2; + num_elems_written_per_iteration_x = 8; + num_elems_written_per_iteration_y = 2; +#else /* PROCESS_4X_1Y_1Z */ +#error Have to declare how many elements to process in one thread. +#endif /* PROCESS_4X_1Y_1Z */ +#undef PROCESS_4X_1Y_1Z +#undef PROCESS_4X_2Y_1Z +#undef PROCESS_4X_3Y_1Z +#undef PROCESS_4X_4Y_1Z +#undef PROCESS_4X_2Y_2Z +#undef PROCESS_8X_1Y_1Z +#undef PROCESS_8X_2Y_1Z break; case DataType::F32: |