diff options
Diffstat (limited to 'src/core/CL/kernels/CLGEMMMatrixAccumulateBiasesKernel.cpp')
-rw-r--r-- | src/core/CL/kernels/CLGEMMMatrixAccumulateBiasesKernel.cpp | 12 |
1 files changed, 7 insertions, 5 deletions
diff --git a/src/core/CL/kernels/CLGEMMMatrixAccumulateBiasesKernel.cpp b/src/core/CL/kernels/CLGEMMMatrixAccumulateBiasesKernel.cpp index 289873c23f..75c1a6e629 100644 --- a/src/core/CL/kernels/CLGEMMMatrixAccumulateBiasesKernel.cpp +++ b/src/core/CL/kernels/CLGEMMMatrixAccumulateBiasesKernel.cpp @@ -43,20 +43,22 @@ CLGEMMMatrixAccumulateBiasesKernel::CLGEMMMatrixAccumulateBiasesKernel() void CLGEMMMatrixAccumulateBiasesKernel::configure(ICLTensor *accum, const ICLTensor *biases) { - ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(accum, 1, DataType::F16, DataType::F32); - ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(biases, 1, DataType::F16, DataType::F32); + ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(accum, 1, DataType::QS8, DataType::F16, DataType::F32); ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(biases, accum); + ARM_COMPUTE_ERROR_ON_MISMATCHING_FIXED_POINT(biases, accum); ARM_COMPUTE_ERROR_ON(biases->info()->num_dimensions() != 1); _biases = biases; _accum = accum; + std::set<std::string> build_opts; + build_opts.insert(("-DDATA_TYPE=" + get_cl_type_from_data_type(accum->info()->data_type()))); + // Create kernel - std::string data_type_name = lower_string(string_from_data_type(accum->info()->data_type())); - _kernel = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel("gemm_accumulate_biases_" + data_type_name)); + _kernel = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel("gemm_accumulate_biases", build_opts)); // Configure kernel window - const unsigned int num_elems_processed_per_iteration = max_cl_vector_width / data_size_from_type(accum->info()->data_type()); + const unsigned int num_elems_processed_per_iteration = 16; Window win = calculate_max_window(*_accum->info(), Steps(num_elems_processed_per_iteration)); |