diff options
Diffstat (limited to 'src/core/CL/kernels/CLElementwiseOperationKernel.cpp')
-rw-r--r-- | src/core/CL/kernels/CLElementwiseOperationKernel.cpp | 28 |
1 files changed, 20 insertions, 8 deletions
diff --git a/src/core/CL/kernels/CLElementwiseOperationKernel.cpp b/src/core/CL/kernels/CLElementwiseOperationKernel.cpp index 1ac35a286f..0f2e26f186 100644 --- a/src/core/CL/kernels/CLElementwiseOperationKernel.cpp +++ b/src/core/CL/kernels/CLElementwiseOperationKernel.cpp @@ -231,7 +231,7 @@ std::pair<Status, Window> validate_and_configure_window_for_division(ITensorInfo } // namespace CLElementwiseOperationKernel::CLElementwiseOperationKernel() - : _input1(nullptr), _input2(nullptr), _output(nullptr) + : _act_info(), _input1(nullptr), _input2(nullptr), _output(nullptr) { } @@ -256,6 +256,12 @@ void CLElementwiseOperationKernel::configure_common(const ICLTensor *input1, con // Set kernel build options CLBuildOptions build_opts = generate_build_options(*input1->info(), *input2->info(), *output->info()); + if(_act_info.enabled()) + { + build_opts.add_option("-DACTIVATION_TYPE=" + lower_string(string_from_activation_func(_act_info.activation()))); + build_opts.add_option("-DA_VAL=" + float_to_string_with_full_precision(_act_info.a())); + build_opts.add_option("-DB_VAL=" + float_to_string_with_full_precision(_act_info.b())); + } // Create kernel _kernel = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel(kernel_name, build_opts.options())); @@ -320,19 +326,23 @@ BorderSize CLElementwiseOperationKernel::border_size() const /** Arithmetic operations with saturation*/ -void CLSaturatedArithmeticOperationKernel::configure(ArithmeticOperation op, const ICLTensor *input1, const ICLTensor *input2, ICLTensor *output, const ConvertPolicy &policy) +void CLSaturatedArithmeticOperationKernel::configure(ArithmeticOperation op, const ICLTensor *input1, const ICLTensor *input2, ICLTensor *output, const ConvertPolicy &policy, + const ActivationLayerInfo &act_info) { - _policy = policy; - _op = op; + _policy = policy; + _op = op; + _act_info = act_info; configure_common(input1, input2, output); } -Status CLSaturatedArithmeticOperationKernel::validate(ArithmeticOperation op, const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output, const ConvertPolicy &policy) +Status CLSaturatedArithmeticOperationKernel::validate(ArithmeticOperation op, const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output, const ConvertPolicy &policy, + const ActivationLayerInfo &act_info) { ARM_COMPUTE_UNUSED(op, policy); ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input1, input2, output); ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments_with_arithmetic_rules(*input1, *input2, *output)); ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window_for_arithmetic_operators(*input1->clone(), *input2->clone(), *output->clone()).first); + ARM_COMPUTE_RETURN_ERROR_ON(act_info.enabled() && !is_data_type_float(output->data_type())); return Status{}; } @@ -369,13 +379,14 @@ std::string CLSaturatedArithmeticOperationKernel::name() /** Arithmetic operations*/ -void CLArithmeticOperationKernel::configure(ArithmeticOperation op, const ICLTensor *input1, const ICLTensor *input2, ICLTensor *output) +void CLArithmeticOperationKernel::configure(ArithmeticOperation op, const ICLTensor *input1, const ICLTensor *input2, ICLTensor *output, const ActivationLayerInfo &act_info) { - _op = op; + _op = op; + _act_info = act_info; configure_common(input1, input2, output); } -Status CLArithmeticOperationKernel::validate(ArithmeticOperation op, const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output) +Status CLArithmeticOperationKernel::validate(ArithmeticOperation op, const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output, const ActivationLayerInfo &act_info) { ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input1, input2, output); if(op == ArithmeticOperation::DIV || op == ArithmeticOperation::POWER) @@ -389,6 +400,7 @@ Status CLArithmeticOperationKernel::validate(ArithmeticOperation op, const ITens ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments_with_arithmetic_rules(*input1, *input2, *output)); ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window_for_arithmetic_operators(*input1->clone(), *input2->clone(), *output->clone()).first); } + ARM_COMPUTE_RETURN_ERROR_ON(act_info.enabled() && !is_data_type_float(output->data_type())); return Status{}; } |