diff options
Diffstat (limited to 'src/core/CL/kernels/CLDequantizationLayerKernel.cpp')
-rw-r--r-- | src/core/CL/kernels/CLDequantizationLayerKernel.cpp | 12 |
1 files changed, 8 insertions, 4 deletions
diff --git a/src/core/CL/kernels/CLDequantizationLayerKernel.cpp b/src/core/CL/kernels/CLDequantizationLayerKernel.cpp index 0b066837a9..e383bc475d 100644 --- a/src/core/CL/kernels/CLDequantizationLayerKernel.cpp +++ b/src/core/CL/kernels/CLDequantizationLayerKernel.cpp @@ -40,7 +40,7 @@ namespace Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output) { ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, output); - ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8); + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::QSYMM8); if(output->tensor_shape().total_size() > 0) { @@ -95,15 +95,19 @@ void CLDequantizationLayerKernel::configure(const ICLTensor *input, ICLTensor *o } ICLKernel::configure_internal(win); - const UniformQuantizationInfo qinfo = input->info()->quantization_info().uniform(); + const UniformQuantizationInfo qinfo = input->info()->quantization_info().uniform(); + const int qoffset = is_data_type_quantized_asymmetric(input->info()->data_type()) ? qinfo.offset : 0; // Create kernel CLBuildOptions build_opts; build_opts.add_option("-DSCALE=" + float_to_string_with_full_precision(qinfo.scale)); - build_opts.add_option("-DOFFSET=" + support::cpp11::to_string(qinfo.offset)); + build_opts.add_option("-DOFFSET=" + support::cpp11::to_string(qoffset)); build_opts.add_option("-DVEC_SIZE=" + support::cpp11::to_string(vec_size_x)); - build_opts.add_option("-DDATA_TYPE=" + get_cl_type_from_data_type(output->info()->data_type())); + build_opts.add_option("-DDATA_TYPE_SRC=" + get_cl_type_from_data_type(input->info()->data_type())); + build_opts.add_option("-DDATA_TYPE_DST=" + get_cl_type_from_data_type(output->info()->data_type())); build_opts.add_option_if(multi_access_x, "-DLAST_ACCESSED_X=" + support::cpp11::to_string(std::max<int>(output_width_x - vec_size_x, 0))); + + // Create kernel name _kernel = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel("dequantization_layer", build_opts.options())); } |