diff options
Diffstat (limited to 'tests/validation/reference/ElementwiseOperations.cpp')
-rw-r--r-- | tests/validation/reference/ElementwiseOperations.cpp | 71 |
1 files changed, 42 insertions, 29 deletions
diff --git a/tests/validation/reference/ElementwiseOperations.cpp b/tests/validation/reference/ElementwiseOperations.cpp index 2ffb0faa75..82f42a0c21 100644 --- a/tests/validation/reference/ElementwiseOperations.cpp +++ b/tests/validation/reference/ElementwiseOperations.cpp @@ -43,38 +43,51 @@ T arithm_op(ArithmeticOperation op, T src1, T src2, ConvertPolicy convert_policy intermediate_type val; - if(op == ArithmeticOperation::ADD) + switch(op) { - val = static_cast<intermediate_type>(src1) + static_cast<intermediate_type>(src2); - } - else if(op == ArithmeticOperation::SUB) - { - val = static_cast<intermediate_type>(src1) - static_cast<intermediate_type>(src2); - } - else if(op == ArithmeticOperation::MIN) - { - val = std::min(static_cast<intermediate_type>(src1), static_cast<intermediate_type>(src2)); - } - else if(op == ArithmeticOperation::MAX) - { - val = std::max(static_cast<intermediate_type>(src1), static_cast<intermediate_type>(src2)); - } - else if(op == ArithmeticOperation::SQUARED_DIFF) - { - intermediate_type tmp = (static_cast<intermediate_type>(src1) - static_cast<intermediate_type>(src2)); - val = tmp * tmp; - } - else if(op == ArithmeticOperation::DIV) - { - val = (static_cast<intermediate_type>(src1) / static_cast<intermediate_type>(src2)); - } - else - { - ARM_COMPUTE_ERROR("Not handled"); + case ArithmeticOperation::ADD: + { + val = static_cast<intermediate_type>(src1) + static_cast<intermediate_type>(src2); + break; + } + case ArithmeticOperation::SUB: + { + val = static_cast<intermediate_type>(src1) - static_cast<intermediate_type>(src2); + break; + } + case ArithmeticOperation::MIN: + { + val = std::min(static_cast<intermediate_type>(src1), static_cast<intermediate_type>(src2)); + break; + } + case ArithmeticOperation::MAX: + { + val = std::max(static_cast<intermediate_type>(src1), static_cast<intermediate_type>(src2)); + break; + } + case ArithmeticOperation::SQUARED_DIFF: + { + intermediate_type tmp = (static_cast<intermediate_type>(src1) - static_cast<intermediate_type>(src2)); + val = tmp * tmp; + break; + } + case ArithmeticOperation::DIV: + { + val = (static_cast<intermediate_type>(src1) / static_cast<intermediate_type>(src2)); + break; + } + case ArithmeticOperation::POWER: + { + val = std::pow(static_cast<intermediate_type>(src1), static_cast<intermediate_type>(src2)); + break; + } + default: + { + ARM_COMPUTE_ERROR("Not handled"); + } } - T result; - if(op == ArithmeticOperation::ADD || op == ArithmeticOperation::SUB || op == ArithmeticOperation::DIV) + if(op == ArithmeticOperation::ADD || op == ArithmeticOperation::SUB || op == ArithmeticOperation::DIV || op == ArithmeticOperation::POWER) { result = (convert_policy == ConvertPolicy::SATURATE) ? saturate_cast<T>(val) : static_cast<T>(val); } |