diff options
Diffstat (limited to 'src/gpu/cl')
-rw-r--r-- | src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel.cpp | 9 |
1 files changed, 6 insertions, 3 deletions
diff --git a/src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel.cpp b/src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel.cpp index 4ca5580443..4ca4b83f9c 100644 --- a/src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel.cpp +++ b/src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel.cpp @@ -41,7 +41,6 @@ #include "src/gpu/cl/kernels/gemm/ClGemmHelpers.h" #include "support/Cast.h" #include "support/StringSupport.h" - namespace arm_compute { namespace opencl @@ -101,7 +100,7 @@ Status validate_arguments(const ITensorInfo *src0, ARM_COMPUTE_RETURN_ERROR_ON(src0->dimension(0) != k); // Validate the reinterpreted-as-3D-case - if (gemm_info.reinterpret_input_as_3d != 0) + if (gemm_info.reinterpret_input_as_3d) { ARM_COMPUTE_RETURN_ERROR_ON(src0->dimension(1) * src0->dimension(2) != m); } @@ -284,8 +283,12 @@ void ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel::configure(const CLCompileCon build_opts.add_option("-DA_VAL=" + float_to_string_with_full_precision(gemm_info.activation_info.a())); build_opts.add_option("-DB_VAL=" + float_to_string_with_full_precision(gemm_info.activation_info.b())); + build_opts.add_option_if(gemm_info.reinterpret_input_as_3d, "-DREINTERPRET_INPUT_AS_3D"); + build_opts.add_option_if(gemm_info.depth_output_gemm3d != 0, "-DREINTERPRET_OUTPUT_AS_3D"); + build_opts.add_option_if(src1->num_dimensions() > 2, "-DBATCHED_RHS"); + std::string kernel_name("gemm_mm_reshaped_only_rhs_nt_mmul"); - kernel_name += rhs_info.export_to_cl_image ? "_texture" : ""; + kernel_name += _export_to_cl_image ? "_texture" : ""; // A macro guard to compile ONLY the kernel of interest build_opts.add_option("-D" + upper_string(kernel_name)); |