From 81e671ef4f2a8fb3128fba402610b9de28b57891 Mon Sep 17 00:00:00 2001 From: Usama Arif Date: Mon, 13 May 2019 13:33:14 +0100 Subject: COMPMID-2269: Implement POW operator for NEON Change-Id: I7135f665d89da3c24c9bbe00e991a64713a41d0e Signed-off-by: Usama Arif Reviewed-on: https://review.mlplatform.org/c/1128 Reviewed-by: Michalis Spyrou Comments-Addressed: Arm Jenkins Tested-by: Arm Jenkins --- .../validation/reference/ElementwiseOperations.cpp | 71 +++++++++++++--------- 1 file changed, 42 insertions(+), 29 deletions(-) (limited to 'tests/validation/reference/ElementwiseOperations.cpp') diff --git a/tests/validation/reference/ElementwiseOperations.cpp b/tests/validation/reference/ElementwiseOperations.cpp index 2ffb0faa75..82f42a0c21 100644 --- a/tests/validation/reference/ElementwiseOperations.cpp +++ b/tests/validation/reference/ElementwiseOperations.cpp @@ -43,38 +43,51 @@ T arithm_op(ArithmeticOperation op, T src1, T src2, ConvertPolicy convert_policy intermediate_type val; - if(op == ArithmeticOperation::ADD) + switch(op) { - val = static_cast(src1) + static_cast(src2); - } - else if(op == ArithmeticOperation::SUB) - { - val = static_cast(src1) - static_cast(src2); - } - else if(op == ArithmeticOperation::MIN) - { - val = std::min(static_cast(src1), static_cast(src2)); - } - else if(op == ArithmeticOperation::MAX) - { - val = std::max(static_cast(src1), static_cast(src2)); - } - else if(op == ArithmeticOperation::SQUARED_DIFF) - { - intermediate_type tmp = (static_cast(src1) - static_cast(src2)); - val = tmp * tmp; - } - else if(op == ArithmeticOperation::DIV) - { - val = (static_cast(src1) / static_cast(src2)); - } - else - { - ARM_COMPUTE_ERROR("Not handled"); + case ArithmeticOperation::ADD: + { + val = static_cast(src1) + static_cast(src2); + break; + } + case ArithmeticOperation::SUB: + { + val = static_cast(src1) - static_cast(src2); + break; + } + case ArithmeticOperation::MIN: + { + val = std::min(static_cast(src1), static_cast(src2)); + break; + } + case ArithmeticOperation::MAX: + { + val = std::max(static_cast(src1), static_cast(src2)); + break; + } + case ArithmeticOperation::SQUARED_DIFF: + { + intermediate_type tmp = (static_cast(src1) - static_cast(src2)); + val = tmp * tmp; + break; + } + case ArithmeticOperation::DIV: + { + val = (static_cast(src1) / static_cast(src2)); + break; + } + case ArithmeticOperation::POWER: + { + val = std::pow(static_cast(src1), static_cast(src2)); + break; + } + default: + { + ARM_COMPUTE_ERROR("Not handled"); + } } - T result; - if(op == ArithmeticOperation::ADD || op == ArithmeticOperation::SUB || op == ArithmeticOperation::DIV) + if(op == ArithmeticOperation::ADD || op == ArithmeticOperation::SUB || op == ArithmeticOperation::DIV || op == ArithmeticOperation::POWER) { result = (convert_policy == ConvertPolicy::SATURATE) ? saturate_cast(val) : static_cast(val); } -- cgit v1.2.1