diff options
Diffstat (limited to 'tests/validation/reference')
-rw-r--r-- | tests/validation/reference/Logical.cpp | 31 | ||||
-rw-r--r-- | tests/validation/reference/Logical.h | 9 |
2 files changed, 19 insertions, 21 deletions
diff --git a/tests/validation/reference/Logical.cpp b/tests/validation/reference/Logical.cpp index 394525c392..9989ec841e 100644 --- a/tests/validation/reference/Logical.cpp +++ b/tests/validation/reference/Logical.cpp @@ -22,6 +22,8 @@ * SOFTWARE. */ #include "tests/validation/reference/Logical.h" +#include "src/core/KernelTypes.h" +#include "tests/framework/Asserts.h" namespace arm_compute { @@ -32,27 +34,30 @@ namespace validation namespace reference { template <typename T> -T logical_op(LogicalBinaryOperation op, T src1, T src2) +T logical_binary_op(arm_compute::kernels::LogicalOperation op, T src1, T src2) { switch(op) { - case LogicalBinaryOperation::AND: + case arm_compute::kernels::LogicalOperation::And: return src1 && src2; - case LogicalBinaryOperation::OR: + case arm_compute::kernels::LogicalOperation::Or: return src1 || src2; - case LogicalBinaryOperation::UNKNOWN: + // The following operators are either invalid or not binary operator + case arm_compute::kernels::LogicalOperation::Not: + /* fall through */ + case arm_compute::kernels::LogicalOperation::Unknown: + /* fall through */ default: - ARM_COMPUTE_ERROR_ON_MSG(true, "unknown logical binary operation is given"); + ARM_COMPUTE_ASSERT(true); } - return false; + return T{}; } template <size_t dim> struct BroadcastUnroll { template <typename T> - static void unroll(LogicalBinaryOperation op, - const SimpleTensor<T> &src1, const SimpleTensor<T> &src2, SimpleTensor<T> &dst, + static void unroll(arm_compute::kernels::LogicalOperation op, const SimpleTensor<T> &src1, const SimpleTensor<T> &src2, SimpleTensor<T> &dst, Coordinates &id_src1, Coordinates &id_src2, Coordinates &id_dst) { const bool src1_is_broadcast = (src1.shape()[dim - 1] != dst.shape()[dim - 1]); @@ -79,10 +84,10 @@ template <> struct BroadcastUnroll<0> { template <typename T> - static void unroll(LogicalBinaryOperation op, const SimpleTensor<T> &src1, const SimpleTensor<T> &src2, SimpleTensor<T> &dst, + static void unroll(arm_compute::kernels::LogicalOperation op, const SimpleTensor<T> &src1, const SimpleTensor<T> &src2, SimpleTensor<T> &dst, Coordinates &id_src1, Coordinates &id_src2, Coordinates &id_dst) { - dst[coord2index(dst.shape(), id_dst)] = logical_op(op, src1[coord2index(src1.shape(), id_src1)], src2[coord2index(src2.shape(), id_src2)]); + dst[coord2index(dst.shape(), id_dst)] = logical_binary_op(op, src1[coord2index(src1.shape(), id_src1)], src2[coord2index(src2.shape(), id_src2)]); } }; @@ -94,7 +99,7 @@ SimpleTensor<T> logical_or(const SimpleTensor<T> &src1, const SimpleTensor<T> &s Coordinates id_dst{}; SimpleTensor<T> dst{ TensorShape::broadcast_shape(src1.shape(), src2.shape()), src1.data_type() }; - BroadcastUnroll<Coordinates::num_max_dimensions>::unroll(LogicalBinaryOperation::OR, src1, src2, dst, id_src1, id_src2, id_dst); + BroadcastUnroll<Coordinates::num_max_dimensions>::unroll(arm_compute::kernels::LogicalOperation::Or, src1, src2, dst, id_src1, id_src2, id_dst); return dst; } @@ -107,7 +112,7 @@ SimpleTensor<T> logical_and(const SimpleTensor<T> &src1, const SimpleTensor<T> & Coordinates id_dst{}; SimpleTensor<T> dst{ TensorShape::broadcast_shape(src1.shape(), src2.shape()), src1.data_type() }; - BroadcastUnroll<Coordinates::num_max_dimensions>::unroll(LogicalBinaryOperation::AND, src1, src2, dst, id_src1, id_src2, id_dst); + BroadcastUnroll<Coordinates::num_max_dimensions>::unroll(arm_compute::kernels::LogicalOperation::And, src1, src2, dst, id_src1, id_src2, id_dst); return dst; } @@ -133,4 +138,4 @@ template SimpleTensor<uint8_t> logical_not(const SimpleTensor<uint8_t> &src1); } // namespace reference } // namespace validation } // namespace test -} // namespace arm_compute
\ No newline at end of file +} // namespace arm_compute diff --git a/tests/validation/reference/Logical.h b/tests/validation/reference/Logical.h index fb906b70b6..0d2bef9a43 100644 --- a/tests/validation/reference/Logical.h +++ b/tests/validation/reference/Logical.h @@ -34,13 +34,6 @@ namespace validation { namespace reference { -enum class LogicalBinaryOperation -{ - UNKNOWN = 0, - AND = 1, - OR = 2 -}; - template <typename T> SimpleTensor<T> logical_or(const SimpleTensor<T> &src1, const SimpleTensor<T> &src2); template <typename T> @@ -51,4 +44,4 @@ SimpleTensor<T> logical_not(const SimpleTensor<T> &src1); } // namespace validation } // namespace test } // namespace arm_compute -#endif /* ARM_COMPUTE_TEST_LOGICAL_H */
\ No newline at end of file +#endif /* ARM_COMPUTE_TEST_LOGICAL_H */ |