aboutsummaryrefslogtreecommitdiff
path: root/src/core/CL/kernels/CLArithmeticAdditionKernel.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/CL/kernels/CLArithmeticAdditionKernel.cpp')
-rw-r--r--src/core/CL/kernels/CLArithmeticAdditionKernel.cpp25
1 files changed, 21 insertions, 4 deletions
diff --git a/src/core/CL/kernels/CLArithmeticAdditionKernel.cpp b/src/core/CL/kernels/CLArithmeticAdditionKernel.cpp
index b0177ab8b6..011807ad88 100644
--- a/src/core/CL/kernels/CLArithmeticAdditionKernel.cpp
+++ b/src/core/CL/kernels/CLArithmeticAdditionKernel.cpp
@@ -37,9 +37,15 @@ Status validate_arguments(const ITensorInfo &input1, const ITensorInfo &input2,
{
ARM_COMPUTE_UNUSED(policy);
ARM_COMPUTE_RETURN_ERROR_ON_F16_UNSUPPORTED(&input1);
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&input1, 1, DataType::U8, DataType::QS8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&input1, 1, DataType::U8, DataType::QASYMM8, DataType::QS8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32);
ARM_COMPUTE_RETURN_ERROR_ON_F16_UNSUPPORTED(&input2);
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&input2, 1, DataType::U8, DataType::QS8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&input2, 1, DataType::U8, DataType::QASYMM8, DataType::QS8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32);
+
+ const bool is_qasymm = is_data_type_quantized_asymmetric(input1.data_type()) || is_data_type_quantized_asymmetric(input2.data_type());
+ if(is_qasymm)
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(&input1, &input2);
+ }
const TensorShape out_shape = TensorShape::broadcast_shape(input1.tensor_shape(), input2.tensor_shape());
@@ -50,12 +56,16 @@ Status validate_arguments(const ITensorInfo &input1, const ITensorInfo &input2,
if(output.total_size() > 0)
{
ARM_COMPUTE_RETURN_ERROR_ON_F16_UNSUPPORTED(&output);
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&output, 1, DataType::U8, DataType::QS8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&output, 1, DataType::U8, DataType::QS8, DataType::QASYMM8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32);
ARM_COMPUTE_RETURN_ERROR_ON_MSG((output.data_type() == DataType::U8) && ((input1.data_type() != DataType::U8) || (input2.data_type() != DataType::U8)),
"Output can only be U8 if both inputs are U8");
ARM_COMPUTE_RETURN_ERROR_ON_MSG(detail::have_different_dimensions(out_shape, output.tensor_shape(), 0),
"Wrong shape for output");
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(&input1, &output);
+ if(is_qasymm)
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(&input1, &output);
+ }
}
return Status{};
@@ -124,6 +134,8 @@ void CLArithmeticAdditionKernel::configure(const ICLTensor *input1, const ICLTen
const bool has_float_out = is_data_type_float(output->info()->data_type());
+ std::string kernel_name = "arithmetic_add";
+
// Set kernel build options
std::set<std::string> build_opts;
build_opts.emplace((policy == ConvertPolicy::WRAP || has_float_out) ? "-DWRAP" : "-DSATURATE");
@@ -134,9 +146,14 @@ void CLArithmeticAdditionKernel::configure(const ICLTensor *input1, const ICLTen
{
build_opts.emplace("-DFIXED_POINT_POSITION=" + support::cpp11::to_string(input1->info()->fixed_point_position()));
}
+ else if(is_data_type_quantized_asymmetric(input1->info()->data_type()))
+ {
+ build_opts.emplace("-DOFFSET=" + support::cpp11::to_string(input1->info()->quantization_info().offset));
+ kernel_name += "_quantized";
+ }
// Create kernel
- _kernel = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel("arithmetic_add", build_opts));
+ _kernel = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel(kernel_name, build_opts));
ICLKernel::configure(win_config.second);
}