diff options
Diffstat (limited to 'src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp')
-rw-r--r-- | src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp | 21 |
1 files changed, 13 insertions, 8 deletions
diff --git a/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp b/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp index 9c69800928..7b785bb8da 100644 --- a/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp +++ b/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp @@ -55,6 +55,7 @@ inline Status validate_arguments(const ITensorInfo *input0, const ITensorInfo *i ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input0, 1, DataType::QS8, DataType::QS16, DataType::F16, DataType::F32); ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input0, input1); ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input0, input1); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(input1->num_dimensions() > 3, "The number of dimensions for the matrix B must be <= 3"); if(!is_interleaved_transposed) { @@ -174,7 +175,7 @@ inline std::pair<Status, Window> validate_and_configure_window(ITensorInfo *inpu } // namespace CLGEMMMatrixMultiplyKernel::CLGEMMMatrixMultiplyKernel() - : _input0(nullptr), _input1(nullptr), _output(nullptr) + : _input0(nullptr), _input1(nullptr), _output(nullptr), _slide_matrix_b(true) { } @@ -192,9 +193,10 @@ void CLGEMMMatrixMultiplyKernel::configure(const ICLTensor *input0, const ICLTen // Perform validate step ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input0->info(), input1->info(), output->info(), is_interleaved_transposed, reshape_info)); - _input0 = input0; - _input1 = input1; - _output = output; + _input0 = input0; + _input1 = input1; + _output = output; + _slide_matrix_b = _input1->info()->num_dimensions() >= _input0->info()->num_dimensions(); const DataType data_type = input0->info()->data_type(); const int fp_pos = input0->info()->fixed_point_position(); @@ -257,6 +259,9 @@ void CLGEMMMatrixMultiplyKernel::configure(const ICLTensor *input0, const ICLTen "-DALPHA=" + float_to_string_with_full_precision(alpha)); } + // Do not slide matrix B if _slide_matrix_b = false + build_opts.add_option_if(!_slide_matrix_b, "-DMATRIX_B_DEPTH=" + support::cpp11::to_string(input1->info()->dimension(2))); + std::string kernel_name; if(is_interleaved_transposed) { @@ -365,7 +370,7 @@ void CLGEMMMatrixMultiplyKernel::run(const Window &window, cl::CommandQueue &que Window slice_b = slice; // Don't slice matrix B along the z dimension if matrix B has just 2 dimensions and matrix A more than 2 // This scenario can happen when the matrix multiplication is used to perform a convolution operation - if(_input1->info()->num_dimensions() < 3) + if(!_slide_matrix_b) { slice_b = slice_matrix_b; } @@ -374,9 +379,9 @@ void CLGEMMMatrixMultiplyKernel::run(const Window &window, cl::CommandQueue &que add_2D_tensor_argument(idx, _input0, slice); add_2D_tensor_argument(idx, _input1, slice_b); add_2D_tensor_argument(idx, _output, slice); - _kernel.setArg<cl_uint>(idx++, static_cast<unsigned int>(_input0->info()->strides_in_bytes()[3])); - _kernel.setArg<cl_uint>(idx++, static_cast<unsigned int>(_input1->info()->strides_in_bytes()[3])); - _kernel.setArg<cl_uint>(idx++, static_cast<unsigned int>(_output->info()->strides_in_bytes()[3])); + _kernel.setArg<cl_uint>(idx++, static_cast<unsigned int>(_input0->info()->strides_in_bytes()[2])); + _kernel.setArg<cl_uint>(idx++, static_cast<unsigned int>(_input1->info()->strides_in_bytes()[2])); + _kernel.setArg<cl_uint>(idx++, static_cast<unsigned int>(_output->info()->strides_in_bytes()[2])); enqueue(queue, *this, slice, _lws_hint); } while(window.slide_window_slice_3D(slice)); |