about | summary | refs | log | tree | commit | diff
path: root/src/core/CL/kernels/CLGEMMLowpMatrixMultiplyKernel.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/CL/kernels/CLGEMMLowpMatrixMultiplyKernel.cpp')
-rw-r--r-- src/core/CL/kernels/CLGEMMLowpMatrixMultiplyKernel.cpp 99
1 files changed, 71 insertions, 28 deletions
diff --git a/src/core/CL/kernels/CLGEMMLowpMatrixMultiplyKernel.cpp b/src/core/CL/kernels/CLGEMMLowpMatrixMultiplyKernel.cpp
index ef572cfc7e..b3227c0db9 100644
--- a/src/core/CL/kernels/CLGEMMLowpMatrixMultiplyKernel.cpp
+++ b/src/core/CL/kernels/CLGEMMLowpMatrixMultiplyKernel.cpp
@@ -51,45 +51,88 @@ CLGEMMLowpMatrixMultiplyKernel::CLGEMMLowpMatrixMultiplyKernel()
{
}
-void CLGEMMLowpMatrixMultiplyKernel::configure(const ICLTensor *input0, const ICLTensor *input1, ICLTensor *output,
- int32_t a_offset, int32_t b_offset, int32_t output_offset, int32_t output_mult_int, int32_t shift)
+void CLGEMMLowpMatrixMultiplyKernel::configure(const ICLTensor *input0, const ICLTensor *input1, ICLTensor *output, bool is_interleaved_transposed)
{
- ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input0, 1, DataType::U8);
- ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::U8);
- ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8);
- ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input0, input1, output);
+ ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input0, 1, DataType::QASYMM8);
+ ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::QASYMM8);
+ ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::S32);
+
+ if(!is_interleaved_transposed)
+ {
+ ARM_COMPUTE_ERROR_ON(input0->info()->dimension(0) != input1->info()->dimension(1));
+ }
+
+ TensorShape in1_shape = input1->info()->tensor_shape();
+ in1_shape.collapse(2);
_input0 = input0;
_input1 = input1;
_output = output;
- // Create kernel and set static arguments
- std::set<std::string> build_opts = { ("-DCOLS_B=" + support::cpp11::to_string(input1->info()->dimension(0))) };
- _kernel = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel("gemm_mm_interleaved_transposed_u8", build_opts));
- unsigned int idx = 3 * num_arguments_per_2D_tensor(); //Skip the input and output parameters
- _kernel.setArg<int32_t>(idx++, a_offset);
- _kernel.setArg<int32_t>(idx++, b_offset);
- _kernel.setArg<int32_t>(idx++, output_offset);
- _kernel.setArg<int32_t>(idx++, output_mult_int);
- _kernel.setArg<int32_t>(idx++, shift);
+ CLBuildOptions build_opts;
- // Configure window
- constexpr unsigned int num_elems_processed_per_iteration_x = 16;
- constexpr unsigned int num_elems_processed_per_iteration_y = 4;
- constexpr unsigned int num_elems_read_per_iteration_input0 = 4;
- constexpr unsigned int num_elems_read_per_iteration_input1 = 16;
+ if(is_interleaved_transposed)
+ {
+ // Create kernel and set static arguments
+ build_opts.add_option("-DCOLS_B=" + support::cpp11::to_string(input1->info()->dimension(0)));
+ _kernel = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel("gemmlowp_mm_interleaved_transposed", build_opts.options()));
- Window win = calculate_max_window(*output->info(), Steps(num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y));
+ // Configure window
+ constexpr unsigned int num_elems_processed_per_iteration_x = 16;
+ constexpr unsigned int num_elems_processed_per_iteration_y = 4;
+ constexpr unsigned int num_elems_read_per_iteration_input0 = 4;
+ constexpr unsigned int num_elems_read_per_iteration_input1 = 16;
- AccessWindowRectangle input0_access(input0->info(), 0, 0, num_elems_read_per_iteration_input0, 1);
- AccessWindowRectangle input1_access(input1->info(), 0, 0, num_elems_read_per_iteration_input1, 1);
- AccessWindowRectangle output_access(output->info(), 0, 0, num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y);
+ Window win = calculate_max_window(*output->info(), Steps(num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y));
- update_window_and_padding(win, input0_access, input1_access, output_access);
+ AccessWindowRectangle input0_access(input0->info(), 0, 0, num_elems_read_per_iteration_input0, 1);
+ AccessWindowRectangle input1_access(input1->info(), 0, 0, num_elems_read_per_iteration_input1, 1);
+ AccessWindowRectangle output_access(output->info(), 0, 0, num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y);
- output_access.set_valid_region(win, ValidRegion(Coordinates(0, 0), output->info()->tensor_shape()));
+ update_window_and_padding(win, input0_access, input1_access, output_access);
+
+ output_access.set_valid_region(win, ValidRegion(Coordinates(0, 0), output->info()->tensor_shape()));
+
+ ICLKernel::configure(win);
+ }
+ else
+ {
+ // Special case for 1xN, 2xN, 3xN and 4xN input0 tensor: num_elems_processed_per_iteration_y is clamped to output's dimension(1) (at most 4)
+ constexpr unsigned int num_elems_processed_per_iteration_x = 16;
+ const unsigned int num_elems_processed_per_iteration_y = std::min(static_cast<int>(output->info()->dimension(1)), 4);
+
+ build_opts.add_option("-DCOLS_A=" + support::cpp11::to_string(input0->info()->dimension(0)));
+ build_opts.add_option("-DNUM_ELEMS_PROCESSED_PER_THREAD_X=" + support::cpp11::to_string(num_elems_processed_per_iteration_x));
+ build_opts.add_option("-DNUM_ELEMS_PROCESSED_PER_THREAD_Y=" + support::cpp11::to_string(num_elems_processed_per_iteration_y));
+
+ _kernel = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel("gemmlowp_mm", build_opts.options()));
+
+ // Configure window
+ Window win = calculate_max_window(*output->info(), Steps(num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y));
+
+ AccessWindowStatic input0_access(input0->info(), 0, 0, input0->info()->dimension(0), ceil_to_multiple(input0->info()->dimension(1), num_elems_processed_per_iteration_y));
+ AccessWindowStatic input1_access(input1->info(), 0, 0, ceil_to_multiple(input1->info()->dimension(0), num_elems_processed_per_iteration_x), input1->info()->dimension(1));
+ AccessWindowRectangle output_access(output->info(), 0, 0, num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y);
+
+ update_window_and_padding(win, input0_access, input1_access, output_access);
+
+ Coordinates coord;
+ coord.set_num_dimensions(output->info()->num_dimensions());
+ output_access.set_valid_region(win, ValidRegion(coord, output->info()->tensor_shape()));
+
+ ICLKernel::configure(win);
+ }
- ICLKernel::configure(win);
+ // Set config_id for enabling LWS tuning
+ _config_id = "gemmlowp_";
+ _config_id += (is_interleaved_transposed ? "reshaped_" : "");
+ _config_id += lower_string(string_from_data_type(input0->info()->data_type()));
+ _config_id += "_";
+ _config_id += support::cpp11::to_string(output->info()->dimension(1));
+ _config_id += "_";
+ _config_id += support::cpp11::to_string(output->info()->dimension(0));
+ _config_id += "_";
+ _config_id += (is_interleaved_transposed ? support::cpp11::to_string(input1->info()->dimension(0)) : support::cpp11::to_string(input1->info()->dimension(1)));
}
void CLGEMMLowpMatrixMultiplyKernel::run(const Window &window, cl::CommandQueue &queue)
@@ -117,7 +160,7 @@ void CLGEMMLowpMatrixMultiplyKernel::run(const Window &window, cl::CommandQueue
add_2D_tensor_argument(idx, _input0, slice);
add_2D_tensor_argument(idx, _input1, slice_b);
add_2D_tensor_argument(idx, _output, slice);
- enqueue(queue, *this, slice);
+ enqueue(queue, *this, slice, _lws_hint);
}
while(window.slide_window_slice_2D(slice));
}