aboutsummaryrefslogtreecommitdiff
path: root/src/core/CL/kernels/CLElementwiseOperationKernel.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/CL/kernels/CLElementwiseOperationKernel.cpp')
-rw-r--r--src/core/CL/kernels/CLElementwiseOperationKernel.cpp69
1 files changed, 69 insertions, 0 deletions
diff --git a/src/core/CL/kernels/CLElementwiseOperationKernel.cpp b/src/core/CL/kernels/CLElementwiseOperationKernel.cpp
index efb3fe79e3..47439e15ab 100644
--- a/src/core/CL/kernels/CLElementwiseOperationKernel.cpp
+++ b/src/core/CL/kernels/CLElementwiseOperationKernel.cpp
@@ -26,6 +26,7 @@
#include "arm_compute/core/CL/CLHelpers.h"
#include "arm_compute/core/CL/ICLTensor.h"
#include "src/core/CL/CLValidate.h"
+#include "src/core/common/Validate.h"
#include "src/core/helpers/AutoConfiguration.h"
#include "src/core/helpers/WindowHelpers.h"
#include "support/Cast.h"
@@ -220,6 +221,18 @@ std::pair<Status, Window> validate_and_configure_window_for_arithmetic_operators
return configure_window_arithmetic_common(output);
}
+std::pair<Status, Window> validate_and_configure_window_for_logical_binary_operators(ITensorInfo &input1, ITensorInfo &input2, ITensorInfo &output)
+{
+ const std::pair<TensorShape, ValidRegion> broadcast_pair = ITensorInfo::broadcast_shape_and_valid_region(input1, input2);
+ const TensorShape &out_shape = broadcast_pair.first;
+
+ set_shape_if_empty(output, out_shape);
+ set_data_type_if_unknown(output, DataType::U8);
+
+ // The arithmetic utility functions can be share
+ return configure_window_arithmetic_common(output);
+}
+
std::pair<Status, Window> validate_and_configure_window_for_division(ITensorInfo &input1, ITensorInfo &input2, ITensorInfo &output)
{
const std::pair<TensorShape, ValidRegion> broadcast_pair = ITensorInfo::broadcast_shape_and_valid_region(input1, input2);
@@ -319,6 +332,62 @@ void CLElementwiseOperationKernel::run_op(ITensorPack &tensors, const Window &wi
while(collapsed.slide_window_slice_3D(slice));
}
+/** Logical binary */
+void CLLogicalBinaryKernel::configure(const CLCompileContext &compile_context, kernels::LogicalOperation op, ITensorInfo *input1, ITensorInfo *input2, ITensorInfo *output)
+{
+ ARM_COMPUTE_ERROR_ON_NULLPTR(input1, input2, output);
+ ARM_COMPUTE_ERROR_THROW_ON(CLLogicalBinaryKernel::validate(op, input1, input2, output));
+ _op = op;
+ configure_common(compile_context, input1, input2, output);
+}
+
+Status CLLogicalBinaryKernel::validate(kernels::LogicalOperation op, const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output)
+{
+ ARM_COMPUTE_UNUSED(op);
+ ARM_COMPUTE_ASSERT(op != kernels::LogicalOperation::Unknown && op != kernels::LogicalOperation::Not);
+ ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input1, input2, output);
+
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::U8);
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input1, input2);
+
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments_with_arithmetic_rules(*input1, *input2, *output));
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window_for_logical_binary_operators(*input1->clone(), *input2->clone(), *output->clone()).first);
+
+ return Status{};
+}
+
+std::string CLLogicalBinaryKernel::name()
+{
+ switch(_op)
+ {
+ case kernels::LogicalOperation::And:
+ return "AND";
+ case kernels::LogicalOperation::Or:
+ return "OR";
+ case kernels::LogicalOperation::Not:
+ /* fall through */
+ default:
+ ARM_COMPUTE_ASSERT(true);
+ }
+ return "";
+}
+
+std::pair<Status, Window> CLLogicalBinaryKernel::validate_and_configure_window(ITensorInfo &input1, ITensorInfo &input2, ITensorInfo &output)
+{
+ return validate_and_configure_window_for_logical_binary_operators(input1, input2, output);
+}
+
+CLBuildOptions CLLogicalBinaryKernel::generate_build_options(const ITensorInfo &input1, const ITensorInfo &input2, const ITensorInfo &output)
+{
+ // The arithmetic utility functions can be share
+ return generate_build_options_with_arithmetic_rules(input1, input2, output, name());
+}
+
+std::string CLLogicalBinaryKernel::generate_id_for_tuning(const std::string &kernel_name, const ITensorInfo &input1, const ITensorInfo &output)
+{
+ return generate_id_for_tuning_common(kernel_name, input1, output);
+}
+
/** Arithmetic operations with saturation*/
void CLSaturatedArithmeticOperationKernel::configure(ArithmeticOperation op, ITensorInfo *input1, ITensorInfo *input2, ITensorInfo *output, const ConvertPolicy &policy,