aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/reference/ElementwiseOperations.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'tests/validation/reference/ElementwiseOperations.cpp')
-rw-r--r--tests/validation/reference/ElementwiseOperations.cpp71
1 files changed, 42 insertions, 29 deletions
diff --git a/tests/validation/reference/ElementwiseOperations.cpp b/tests/validation/reference/ElementwiseOperations.cpp
index 2ffb0faa75..82f42a0c21 100644
--- a/tests/validation/reference/ElementwiseOperations.cpp
+++ b/tests/validation/reference/ElementwiseOperations.cpp
@@ -43,38 +43,51 @@ T arithm_op(ArithmeticOperation op, T src1, T src2, ConvertPolicy convert_policy
intermediate_type val;
- if(op == ArithmeticOperation::ADD)
+ switch(op)
{
- val = static_cast<intermediate_type>(src1) + static_cast<intermediate_type>(src2);
- }
- else if(op == ArithmeticOperation::SUB)
- {
- val = static_cast<intermediate_type>(src1) - static_cast<intermediate_type>(src2);
- }
- else if(op == ArithmeticOperation::MIN)
- {
- val = std::min(static_cast<intermediate_type>(src1), static_cast<intermediate_type>(src2));
- }
- else if(op == ArithmeticOperation::MAX)
- {
- val = std::max(static_cast<intermediate_type>(src1), static_cast<intermediate_type>(src2));
- }
- else if(op == ArithmeticOperation::SQUARED_DIFF)
- {
- intermediate_type tmp = (static_cast<intermediate_type>(src1) - static_cast<intermediate_type>(src2));
- val = tmp * tmp;
- }
- else if(op == ArithmeticOperation::DIV)
- {
- val = (static_cast<intermediate_type>(src1) / static_cast<intermediate_type>(src2));
- }
- else
- {
- ARM_COMPUTE_ERROR("Not handled");
+ case ArithmeticOperation::ADD:
+ {
+ val = static_cast<intermediate_type>(src1) + static_cast<intermediate_type>(src2);
+ break;
+ }
+ case ArithmeticOperation::SUB:
+ {
+ val = static_cast<intermediate_type>(src1) - static_cast<intermediate_type>(src2);
+ break;
+ }
+ case ArithmeticOperation::MIN:
+ {
+ val = std::min(static_cast<intermediate_type>(src1), static_cast<intermediate_type>(src2));
+ break;
+ }
+ case ArithmeticOperation::MAX:
+ {
+ val = std::max(static_cast<intermediate_type>(src1), static_cast<intermediate_type>(src2));
+ break;
+ }
+ case ArithmeticOperation::SQUARED_DIFF:
+ {
+ intermediate_type tmp = (static_cast<intermediate_type>(src1) - static_cast<intermediate_type>(src2));
+ val = tmp * tmp;
+ break;
+ }
+ case ArithmeticOperation::DIV:
+ {
+ val = (static_cast<intermediate_type>(src1) / static_cast<intermediate_type>(src2));
+ break;
+ }
+ case ArithmeticOperation::POWER:
+ {
+ val = std::pow(static_cast<intermediate_type>(src1), static_cast<intermediate_type>(src2));
+ break;
+ }
+ default:
+ {
+ ARM_COMPUTE_ERROR("Not handled");
+ }
}
-
T result;
- if(op == ArithmeticOperation::ADD || op == ArithmeticOperation::SUB || op == ArithmeticOperation::DIV)
+ if(op == ArithmeticOperation::ADD || op == ArithmeticOperation::SUB || op == ArithmeticOperation::DIV || op == ArithmeticOperation::POWER)
{
result = (convert_policy == ConvertPolicy::SATURATE) ? saturate_cast<T>(val) : static_cast<T>(val);
}