aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/reference/PostOps.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'tests/validation/reference/PostOps.cpp')
-rw-r--r--tests/validation/reference/PostOps.cpp16
1 files changed, 16 insertions, 0 deletions
diff --git a/tests/validation/reference/PostOps.cpp b/tests/validation/reference/PostOps.cpp
index 1a8fb990c8..a81b1c1905 100644
--- a/tests/validation/reference/PostOps.cpp
+++ b/tests/validation/reference/PostOps.cpp
@@ -59,6 +59,22 @@ SimpleTensor<T> post_ops(const SimpleTensor<T> &a, experimental::PostOpList<Simp
dst = reference::arithmetic_operation(ArithmeticOperation::ADD, dst, _post_op->_addend, dst, _post_op->_policy);
break;
}
+ case experimental::PostOpType::Eltwise_PRelu:
+ {
+ const auto _post_op = utils::cast::polymorphic_downcast<const experimental::PostOpEltwisePRelu<SimpleTensor<T>> *>(post_op.get());
+
+ // If previous main operation output is the the first pRelu argument, then pass it as src1 parameter of the arithmetic operation
+ if(_post_op->_prev_dst_pos == 0)
+ {
+ dst = reference::arithmetic_operation(ArithmeticOperation::PRELU, dst, _post_op->_alpha_param, dst, _post_op->_policy);
+ }
+ // If previous main operation output is the the second pRelu argument, then pass it as src2 parameter of the arithmetic operation
+ else if(_post_op->_prev_dst_pos == 1)
+ {
+ dst = reference::arithmetic_operation(ArithmeticOperation::PRELU, _post_op->_alpha_param, dst, dst, _post_op->_policy);
+ }
+ break;
+ }
default:
{
ARM_COMPUTE_ERROR("Unsupported PostOpType");