diff options
Diffstat (limited to 'src/core/CL/kernels/CLGEMMLowpMatrixMultiplyKernel.cpp')
-rw-r--r-- | src/core/CL/kernels/CLGEMMLowpMatrixMultiplyKernel.cpp | 9 |
1 files changed, 6 insertions, 3 deletions
diff --git a/src/core/CL/kernels/CLGEMMLowpMatrixMultiplyKernel.cpp b/src/core/CL/kernels/CLGEMMLowpMatrixMultiplyKernel.cpp index 99e184050e..73b1d41eb1 100644 --- a/src/core/CL/kernels/CLGEMMLowpMatrixMultiplyKernel.cpp +++ b/src/core/CL/kernels/CLGEMMLowpMatrixMultiplyKernel.cpp @@ -108,6 +108,7 @@ Status validate_arguments(const ITensorInfo *input0, const ITensorInfo *input1, std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input0, ITensorInfo *input1, ITensorInfo *output, bool is_interleaved_transposed, const GEMMReshapeInfo &reshape_info, ElementsProcessed &num_elements_processed) { + const bool is_dot8_supported = dot8_supported(CLKernelLibrary::get().get_device()); unsigned int &num_elems_processed_per_iteration_x = num_elements_processed[0]; unsigned int &num_elems_processed_per_iteration_y = num_elements_processed[1]; bool reinterpret_input_as_3d = reshape_info.reinterpret_input_as_3d(); @@ -126,7 +127,7 @@ std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input0, ITe } // Output tensor auto inizialitation if not yet initialized - auto_init_if_empty(*output, input0->clone()->set_tensor_shape(compute_mm_shape(*input0, *input1, is_interleaved_transposed, reshape_info))); + auto_init_if_empty(*output, input0->clone()->set_tensor_shape(compute_mm_shape(*input0, *input1, is_interleaved_transposed, reshape_info)).set_data_type(DataType::S32)); TensorInfo tmp_info(*output); @@ -173,8 +174,9 @@ std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input0, ITe else { // Special case for 1xN, 2xN, 3xN and 4xN input0 tensor. num_elems_processed_per_iteration_x - num_elems_processed_per_iteration_x = 4; - num_elems_processed_per_iteration_y = std::min(static_cast<int>(output->dimension(1)), 5); + // Note: if the dot product instruction is available, the 8x2 tile has to be used + num_elems_processed_per_iteration_x = is_dot8_supported ? 8 : 4; + num_elems_processed_per_iteration_y = std::min(static_cast<int>(output->dimension(1)), is_dot8_supported ? 2 : 4); // Note: bottom paddings are calculated manually as the output can be reinterpreted as 3D tensor // The only way to set properly the paddings, it is to set those explicitly through the AccessWindowStatic @@ -270,6 +272,7 @@ void CLGEMMLowpMatrixMultiplyKernel::configure(const ICLTensor *input0, const IC // the correct step which is calculated as (16 * mult_transpose1xW_width) / 4) build_opts.add_option("-DCOLS_B=" + support::cpp11::to_string(input1->info()->dimension(0))); + build_opts.add_option("-DMULT_TRANSPOSE1XW_WIDTH=" + support::cpp11::to_string(mult_transpose1xW_width)); build_opts.add_option("-DTRANSPOSE1XW_WIDTH_STEP=" + support::cpp11::to_string(4 * mult_transpose1xW_width)); build_opts.add_option("-DMULT_INTERLEAVE4X4_HEIGHT=" + support::cpp11::to_string(mult_interleave4x4_height)); |