aboutsummaryrefslogtreecommitdiff
path: root/src/core/gpu/cl/kernels/ClElementwiseKernel.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/gpu/cl/kernels/ClElementwiseKernel.cpp')
-rw-r--r--src/core/gpu/cl/kernels/ClElementwiseKernel.cpp35
1 files changed, 33 insertions, 2 deletions
diff --git a/src/core/gpu/cl/kernels/ClElementwiseKernel.cpp b/src/core/gpu/cl/kernels/ClElementwiseKernel.cpp
index 8f12eb2215..335ee9c392 100644
--- a/src/core/gpu/cl/kernels/ClElementwiseKernel.cpp
+++ b/src/core/gpu/cl/kernels/ClElementwiseKernel.cpp
@@ -98,6 +98,29 @@ Status validate_arguments_with_float_only_supported_rules(const ITensorInfo &src
return Status{};
}
+Status validate_arguments_divide_operation(const ITensorInfo* src1, const ITensorInfo* src2, const ITensorInfo* dst)
+{ // Validates the tensor infos for an element-wise division; each check is an ARM_COMPUTE_RETURN_* macro that presumably returns an error Status early on failure.
+ ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(src1, src2, dst); // All three tensor infos are required.
+ ARM_COMPUTE_RETURN_ERROR_ON_F16_UNSUPPORTED(src1); // NOTE(review): presumably rejects F16 where half precision is unavailable - confirm against the macro.
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(src1, 1, DataType::F16, DataType::F32, DataType::S32); // src1 must be F16, F32 or S32 (the 1 is presumably the expected channel count).
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(src1, src2); // Both inputs must share one data type.
+
+ const TensorShape out_shape = TensorShape::broadcast_shape(src1->tensor_shape(), src2->tensor_shape()); // Shape obtained by broadcasting the two input shapes.
+
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(out_shape.total_size() == 0, "Inputs are not broadcast compatible"); // A zero-sized broadcast result is reported as incompatible inputs.
+
+ // dst is only checked when it reports a non-zero total size, i.e. when it has already been configured
+ if(dst->total_size() > 0)
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(dst, 1, DataType::F16, DataType::F32, DataType::S32); // dst is restricted to the same data types as src1 ...
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(src1, dst); // ... and must have src1's data type.
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(detail::have_different_dimensions(out_shape, dst->tensor_shape(), 0),
+ "Wrong shape for dst"); // dst shape must equal the broadcast shape (the 0 is presumably the first dimension compared - confirm).
+ }
+
+ return Status{}; // No check failed: return a default-constructed Status.
+}
+
Status validate_arguments_with_arithmetic_rules(const ITensorInfo &src1, const ITensorInfo &src2, const ITensorInfo &dst)
{
ARM_COMPUTE_RETURN_ERROR_ON_F16_UNSUPPORTED(&src1);
@@ -180,6 +203,8 @@ CLBuildOptions generate_build_options_with_arithmetic_rules(const ITensorInfo &s
build_opts.add_option("-DSCALE_IN2=" + float_to_string_with_full_precision(iq2info.scale));
build_opts.add_option("-DSCALE_OUT=" + float_to_string_with_full_precision(oqinfo.scale));
}
+ build_opts.add_option_if(src1.data_type() == DataType::S32, "-DS32");
+
return build_opts;
}
@@ -459,9 +484,15 @@ void ClArithmeticKernel::configure(const ClCompileContext &compile_context, Arit
Status ClArithmeticKernel::validate(ArithmeticOperation op, const ITensorInfo *src1, const ITensorInfo *src2, const ITensorInfo *dst, const ActivationLayerInfo &act_info)
{
ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(src1, src2, dst);
- if(op == ArithmeticOperation::DIV || op == ArithmeticOperation::POWER)
+ if(op == ArithmeticOperation::DIV)
{
- // Division and Power operators don't support integer arithmetic
+ // Division has partial integer support: S32 is accepted alongside F32/F16
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments_divide_operation(src1, src2, dst));
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window_for_division(*src1->clone(), *src2->clone(), *dst->clone()).first);
+ }
+ else if(op == ArithmeticOperation::POWER)
+ {
+ // Power operator doesn't support integer arithmetic
ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments_with_float_only_supported_rules(*src1, *src2, *dst));
ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window_for_division(*src1->clone(), *src2->clone(), *dst->clone()).first);
}