diff options
Diffstat (limited to 'src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.cpp')
-rw-r--r-- | src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.cpp | 29 |
1 files changed, 21 insertions, 8 deletions
diff --git a/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.cpp b/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.cpp index af06fecd00..24372657f5 100644 --- a/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.cpp +++ b/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.cpp @@ -68,20 +68,23 @@ Status validate_arguments(const ITensorInfo *input0, const ITensorInfo *input1, const int n = gemm_info.n(); const int k = gemm_info.k(); - TensorShape tensor_shape0{ input0->tensor_shape() }; - tensor_shape0.set(0, k); - tensor_shape0.set(1, m); - TensorShape tensor_shape1{ input1->tensor_shape() }; tensor_shape1.set(0, n); tensor_shape1.set(1, k); - const TensorInfo tensor_info0 = input0->clone()->set_tensor_shape(tensor_shape0); const TensorInfo tensor_info1 = input1->clone()->set_tensor_shape(tensor_shape1); const TensorInfo tensor_info_reshaped1 = input1->clone()->set_tensor_shape(compute_rhs_reshaped_shape(tensor_info1, rhs_info)); - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input0, &tensor_info0); + ARM_COMPUTE_RETURN_ERROR_ON(input0->dimension(0) != static_cast<unsigned int>(k)); + if(gemm_info.reinterpret_input_as_3d()) + { + ARM_COMPUTE_RETURN_ERROR_ON(input0->dimension(1) * input0->dimension(2) != static_cast<unsigned int>(m)); + } + else + { + ARM_COMPUTE_RETURN_ERROR_ON(input0->dimension(1) != static_cast<unsigned int>(m)); + } ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input1, &tensor_info_reshaped1); if(output->total_size() != 0) @@ -99,6 +102,7 @@ std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input0, ITe { unsigned int &num_elems_processed_per_iteration_x = num_elements_processed[0]; unsigned int &num_elems_processed_per_iteration_y = num_elements_processed[1]; + bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d(); bool reinterpret_output_as_3d = (gemm_info.depth_output_gemm3d() != 0); Window win{}; @@ -107,6 +111,10 @@ std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input0, ITe // In case both input and output have to be reinterpreted as 3D tensors, // force reinterpret_input_as_3d and reinterpret_output_as_3d to be false. + if(reinterpret_input_as_3d == reinterpret_output_as_3d) + { + reinterpret_output_as_3d = false; + } // Output tensor auto initialization if not yet initialized auto_init_if_empty(*output, input0->clone()->set_tensor_shape(compute_mm_shape(*input0, *input1, gemm_info))); @@ -147,7 +155,7 @@ std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input0, ITe window_changed = update_window_and_padding(win, input0_access, input1_access) || // window used by the execute_window_loop update_window_and_padding(win_out, output_access); // window used to update the padding requirements of output tensor - output_access.set_valid_region(win_out, ValidRegion(Coordinates(0, 0), output->tensor_shape())); + output_access.set_valid_region(win_out, ValidRegion(Coordinates(), output->tensor_shape())); // Collapse along the Z direction // This collapse needs to be here in order to tune the Z dimension of LWS @@ -181,6 +189,11 @@ void CLGEMMMatrixMultiplyReshapedOnlyRHSKernel::configure(const ICLTensor *input // In case both input and output have to be reinterpreted as 3D tensors, // force reinterpret_input_as_3d and reinterpret_output_as_3d to be false. + if(_reinterpret_input_as_3d == _reinterpret_output_as_3d) + { + _reinterpret_input_as_3d = false; + _reinterpret_output_as_3d = false; + } // Check if we need to slide the matrix B const unsigned int num_dimensions_input0 = _input0->info()->num_dimensions(); @@ -204,7 +217,7 @@ void CLGEMMMatrixMultiplyReshapedOnlyRHSKernel::configure(const ICLTensor *input build_opts.add_option_if(!_slide_matrix_b, "-DMATRIX_B_DEPTH=" + support::cpp11::to_string(input1->info()->dimension(2))); build_opts.add_option_if(rhs_info.interleave, "-DRHS_INTERLEAVE"); build_opts.add_option_if(_use_dummy_work_items, "-DDUMMY_WORK_ITEMS"); - build_opts.add_option("-DM=" + support::cpp11::to_string(gemm_info.m())); + build_opts.add_option("-DM=" + support::cpp11::to_string(input0->info()->dimension(1))); build_opts.add_option("-DN=" + support::cpp11::to_string(gemm_info.n())); build_opts.add_option("-DK=" + support::cpp11::to_string(gemm_info.k())); build_opts.add_option("-DM0=" + support::cpp11::to_string(lhs_info.m0)); |