From 75eea338eb232ebdafa2fb84d22e711b5f964785 Mon Sep 17 00:00:00 2001 From: Sang-Hoon Park Date: Fri, 13 Nov 2020 13:44:13 +0000 Subject: COMPMID-3961: Add Logical OR/AND/NOT operator on CL Change-Id: I612aeed6affa17624fb9044964dd59c41a5c9888 Signed-off-by: Sang-Hoon Park Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/4448 Reviewed-by: Pablo Marquez Tello Tested-by: Arm Jenkins Comments-Addressed: Arm Jenkins --- tests/validation/reference/Logical.cpp | 31 ++++++++++++++++++------------- 1 file changed, 18 insertions(+), 13 deletions(-) (limited to 'tests/validation/reference/Logical.cpp') diff --git a/tests/validation/reference/Logical.cpp b/tests/validation/reference/Logical.cpp index 394525c392..9989ec841e 100644 --- a/tests/validation/reference/Logical.cpp +++ b/tests/validation/reference/Logical.cpp @@ -22,6 +22,8 @@ * SOFTWARE. */ #include "tests/validation/reference/Logical.h" +#include "src/core/KernelTypes.h" +#include "tests/framework/Asserts.h" namespace arm_compute { @@ -32,27 +34,30 @@ namespace validation namespace reference { template -T logical_op(LogicalBinaryOperation op, T src1, T src2) +T logical_binary_op(arm_compute::kernels::LogicalOperation op, T src1, T src2) { switch(op) { - case LogicalBinaryOperation::AND: + case arm_compute::kernels::LogicalOperation::And: return src1 && src2; - case LogicalBinaryOperation::OR: + case arm_compute::kernels::LogicalOperation::Or: return src1 || src2; - case LogicalBinaryOperation::UNKNOWN: + // The following operators are either invalid or not binary operator + case arm_compute::kernels::LogicalOperation::Not: + /* fall through */ + case arm_compute::kernels::LogicalOperation::Unknown: + /* fall through */ default: - ARM_COMPUTE_ERROR_ON_MSG(true, "unknown logical binary operation is given"); + ARM_COMPUTE_ASSERT(true); } - return false; + return T{}; } template struct BroadcastUnroll { template - static void unroll(LogicalBinaryOperation op, - const SimpleTensor &src1, const SimpleTensor &src2, SimpleTensor &dst, + static void unroll(arm_compute::kernels::LogicalOperation op, const SimpleTensor &src1, const SimpleTensor &src2, SimpleTensor &dst, Coordinates &id_src1, Coordinates &id_src2, Coordinates &id_dst) { const bool src1_is_broadcast = (src1.shape()[dim - 1] != dst.shape()[dim - 1]); @@ -79,10 +84,10 @@ template <> struct BroadcastUnroll<0> { template - static void unroll(LogicalBinaryOperation op, const SimpleTensor &src1, const SimpleTensor &src2, SimpleTensor &dst, + static void unroll(arm_compute::kernels::LogicalOperation op, const SimpleTensor &src1, const SimpleTensor &src2, SimpleTensor &dst, Coordinates &id_src1, Coordinates &id_src2, Coordinates &id_dst) { - dst[coord2index(dst.shape(), id_dst)] = logical_op(op, src1[coord2index(src1.shape(), id_src1)], src2[coord2index(src2.shape(), id_src2)]); + dst[coord2index(dst.shape(), id_dst)] = logical_binary_op(op, src1[coord2index(src1.shape(), id_src1)], src2[coord2index(src2.shape(), id_src2)]); } }; @@ -94,7 +99,7 @@ SimpleTensor logical_or(const SimpleTensor &src1, const SimpleTensor &s Coordinates id_dst{}; SimpleTensor dst{ TensorShape::broadcast_shape(src1.shape(), src2.shape()), src1.data_type() }; - BroadcastUnroll::unroll(LogicalBinaryOperation::OR, src1, src2, dst, id_src1, id_src2, id_dst); + BroadcastUnroll::unroll(arm_compute::kernels::LogicalOperation::Or, src1, src2, dst, id_src1, id_src2, id_dst); return dst; } @@ -107,7 +112,7 @@ SimpleTensor logical_and(const SimpleTensor &src1, const SimpleTensor & Coordinates id_dst{}; SimpleTensor dst{ TensorShape::broadcast_shape(src1.shape(), src2.shape()), src1.data_type() }; - BroadcastUnroll::unroll(LogicalBinaryOperation::AND, src1, src2, dst, id_src1, id_src2, id_dst); + BroadcastUnroll::unroll(arm_compute::kernels::LogicalOperation::And, src1, src2, dst, id_src1, id_src2, id_dst); return dst; } @@ -133,4 +138,4 @@ template SimpleTensor logical_not(const SimpleTensor &src1); } // namespace reference } // namespace validation } // namespace test -} // namespace arm_compute \ No newline at end of file +} // namespace arm_compute -- cgit v1.2.1