diff options
author | ramelg01 <ramy.elgammal@arm.com> | 2021-10-29 10:52:53 +0100 |
---|---|---|
committer | ramy.elgammal <ramy.elgammal@arm.com> | 2021-11-04 11:10:56 +0000 |
commit | 6049edadf0c89a026b3fcd1927ee7531d3c40278 (patch) | |
tree | c12fcea637e41cdb9e1f72dc734e4a87d0b31981 /tests/validation/reference | |
parent | 71cbd28b7cf5115b0451d43e5c84cce4ae4d8ec7 (diff) | |
download | ComputeLibrary-6049edadf0c89a026b3fcd1927ee7531d3c40278.tar.gz |
Add PRelu to supported PostOps in:
- ClGemmMatrixMultiplyReshapedKernel
- ClGemmMatrixMultiplyNativeKernel
- ClGemmMatrixMultiplyReshapedOnlyRhsKernel
Resolves: COMPMID-4713
Change-Id: I3adcb1b3d4af37ebcbc3bee19cc1845885d08600
Signed-off-by: Ramy Elgammal <ramy.elgammal@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/6553
Reviewed-by: SiCong Li <sicong.li@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'tests/validation/reference')
-rw-r--r-- | tests/validation/reference/PostOps.cpp | 16 |
1 files changed, 16 insertions, 0 deletions
diff --git a/tests/validation/reference/PostOps.cpp b/tests/validation/reference/PostOps.cpp index 1a8fb990c8..a81b1c1905 100644 --- a/tests/validation/reference/PostOps.cpp +++ b/tests/validation/reference/PostOps.cpp @@ -59,6 +59,22 @@ SimpleTensor<T> post_ops(const SimpleTensor<T> &a, experimental::PostOpList<Simp dst = reference::arithmetic_operation(ArithmeticOperation::ADD, dst, _post_op->_addend, dst, _post_op->_policy); break; } + case experimental::PostOpType::Eltwise_PRelu: + { + const auto _post_op = utils::cast::polymorphic_downcast<const experimental::PostOpEltwisePRelu<SimpleTensor<T>> *>(post_op.get()); + + // If previous main operation output is the the first pRelu argument, then pass it as src1 parameter of the arithmetic operation + if(_post_op->_prev_dst_pos == 0) + { + dst = reference::arithmetic_operation(ArithmeticOperation::PRELU, dst, _post_op->_alpha_param, dst, _post_op->_policy); + } + // If previous main operation output is the the second pRelu argument, then pass it as src2 parameter of the arithmetic operation + else if(_post_op->_prev_dst_pos == 1) + { + dst = reference::arithmetic_operation(ArithmeticOperation::PRELU, _post_op->_alpha_param, dst, dst, _post_op->_policy); + } + break; + } default: { ARM_COMPUTE_ERROR("Unsupported PostOpType"); |