diff options
Diffstat (limited to 'src/core/CL/kernels/CLGEMMLowpMatrixMultiplyKernel.cpp')
-rw-r--r-- | src/core/CL/kernels/CLGEMMLowpMatrixMultiplyKernel.cpp | 99 |
1 files changed, 71 insertions, 28 deletions
diff --git a/src/core/CL/kernels/CLGEMMLowpMatrixMultiplyKernel.cpp b/src/core/CL/kernels/CLGEMMLowpMatrixMultiplyKernel.cpp index ef572cfc7e..b3227c0db9 100644 --- a/src/core/CL/kernels/CLGEMMLowpMatrixMultiplyKernel.cpp +++ b/src/core/CL/kernels/CLGEMMLowpMatrixMultiplyKernel.cpp @@ -51,45 +51,88 @@ CLGEMMLowpMatrixMultiplyKernel::CLGEMMLowpMatrixMultiplyKernel() { } -void CLGEMMLowpMatrixMultiplyKernel::configure(const ICLTensor *input0, const ICLTensor *input1, ICLTensor *output, - int32_t a_offset, int32_t b_offset, int32_t output_offset, int32_t output_mult_int, int32_t shift) +void CLGEMMLowpMatrixMultiplyKernel::configure(const ICLTensor *input0, const ICLTensor *input1, ICLTensor *output, bool is_interleaved_transposed) { - ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input0, 1, DataType::U8); - ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::U8); - ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8); - ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input0, input1, output); + ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input0, 1, DataType::QASYMM8); + ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::QASYMM8); + ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::S32); + + if(!is_interleaved_transposed) + { + ARM_COMPUTE_ERROR_ON(input0->info()->dimension(0) != input1->info()->dimension(1)); + } + + TensorShape in1_shape = input1->info()->tensor_shape(); + in1_shape.collapse(2); _input0 = input0; _input1 = input1; _output = output; - // Create kernel and set static arguments - std::set<std::string> build_opts = { ("-DCOLS_B=" + support::cpp11::to_string(input1->info()->dimension(0))) }; - _kernel = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel("gemm_mm_interleaved_transposed_u8", build_opts)); - unsigned int idx = 3 * num_arguments_per_2D_tensor(); //Skip the input and output parameters - _kernel.setArg<int32_t>(idx++, a_offset); - _kernel.setArg<int32_t>(idx++, b_offset); - _kernel.setArg<int32_t>(idx++, output_offset); - _kernel.setArg<int32_t>(idx++, output_mult_int); - _kernel.setArg<int32_t>(idx++, shift); + CLBuildOptions build_opts; - // Configure window - constexpr unsigned int num_elems_processed_per_iteration_x = 16; - constexpr unsigned int num_elems_processed_per_iteration_y = 4; - constexpr unsigned int num_elems_read_per_iteration_input0 = 4; - constexpr unsigned int num_elems_read_per_iteration_input1 = 16; + if(is_interleaved_transposed) + { + // Create kernel and set static arguments + build_opts.add_option("-DCOLS_B=" + support::cpp11::to_string(input1->info()->dimension(0))); + _kernel = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel("gemmlowp_mm_interleaved_transposed", build_opts.options())); - Window win = calculate_max_window(*output->info(), Steps(num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y)); + // Configure window + constexpr unsigned int num_elems_processed_per_iteration_x = 16; + constexpr unsigned int num_elems_processed_per_iteration_y = 4; + constexpr unsigned int num_elems_read_per_iteration_input0 = 4; + constexpr unsigned int num_elems_read_per_iteration_input1 = 16; - AccessWindowRectangle input0_access(input0->info(), 0, 0, num_elems_read_per_iteration_input0, 1); - AccessWindowRectangle input1_access(input1->info(), 0, 0, num_elems_read_per_iteration_input1, 1); - AccessWindowRectangle output_access(output->info(), 0, 0, num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y); + Window win = calculate_max_window(*output->info(), Steps(num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y)); - update_window_and_padding(win, input0_access, input1_access, output_access); + AccessWindowRectangle input0_access(input0->info(), 0, 0, num_elems_read_per_iteration_input0, 1); + AccessWindowRectangle input1_access(input1->info(), 0, 0, num_elems_read_per_iteration_input1, 1); + AccessWindowRectangle output_access(output->info(), 0, 0, num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y); - output_access.set_valid_region(win, ValidRegion(Coordinates(0, 0), output->info()->tensor_shape())); + update_window_and_padding(win, input0_access, input1_access, output_access); + + output_access.set_valid_region(win, ValidRegion(Coordinates(0, 0), output->info()->tensor_shape())); + + ICLKernel::configure(win); + } + else + { + // Special case for 1xN, 2xN, 3xN and 4xN input0 tensor. num_elems_processed_per_iteration_x + constexpr unsigned int num_elems_processed_per_iteration_x = 16; + const unsigned int num_elems_processed_per_iteration_y = std::min(static_cast<int>(output->info()->dimension(1)), 4); + + build_opts.add_option("-DCOLS_A=" + support::cpp11::to_string(input0->info()->dimension(0))); + build_opts.add_option("-DNUM_ELEMS_PROCESSED_PER_THREAD_X=" + support::cpp11::to_string(num_elems_processed_per_iteration_x)); + build_opts.add_option("-DNUM_ELEMS_PROCESSED_PER_THREAD_Y=" + support::cpp11::to_string(num_elems_processed_per_iteration_y)); + + _kernel = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel("gemmlowp_mm", build_opts.options())); + + // Configure window + Window win = calculate_max_window(*output->info(), Steps(num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y)); + + AccessWindowStatic input0_access(input0->info(), 0, 0, input0->info()->dimension(0), ceil_to_multiple(input0->info()->dimension(1), num_elems_processed_per_iteration_y)); + AccessWindowStatic input1_access(input1->info(), 0, 0, ceil_to_multiple(input1->info()->dimension(0), num_elems_processed_per_iteration_x), input1->info()->dimension(1)); + AccessWindowRectangle output_access(output->info(), 0, 0, num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y); + + update_window_and_padding(win, input0_access, input1_access, output_access); + + Coordinates coord; + coord.set_num_dimensions(output->info()->num_dimensions()); + output_access.set_valid_region(win, ValidRegion(coord, output->info()->tensor_shape())); + + ICLKernel::configure(win); + } - ICLKernel::configure(win); + // Set config_id for enabling LWS tuning + _config_id = "gemmlowp_"; + _config_id += (is_interleaved_transposed ? "reshaped_" : ""); + _config_id += lower_string(string_from_data_type(input0->info()->data_type())); + _config_id += "_"; + _config_id += support::cpp11::to_string(output->info()->dimension(1)); + _config_id += "_"; + _config_id += support::cpp11::to_string(output->info()->dimension(0)); + _config_id += "_"; + _config_id += (is_interleaved_transposed ? support::cpp11::to_string(input1->info()->dimension(0)) : support::cpp11::to_string(input1->info()->dimension(1))); } void CLGEMMLowpMatrixMultiplyKernel::run(const Window &window, cl::CommandQueue &queue) @@ -117,7 +160,7 @@ void CLGEMMLowpMatrixMultiplyKernel::run(const Window &window, cl::CommandQueue add_2D_tensor_argument(idx, _input0, slice); add_2D_tensor_argument(idx, _input1, slice_b); add_2D_tensor_argument(idx, _output, slice); - enqueue(queue, *this, slice); + enqueue(queue, *this, slice, _lws_hint); } while(window.slide_window_slice_2D(slice)); } |