diff options
author | ramelg01 <ramy.elgammal@arm.com> | 2021-10-29 10:52:53 +0100 |
---|---|---|
committer | ramy.elgammal <ramy.elgammal@arm.com> | 2021-11-04 11:10:56 +0000 |
commit | 6049edadf0c89a026b3fcd1927ee7531d3c40278 (patch) | |
tree | c12fcea637e41cdb9e1f72dc734e4a87d0b31981 /src/core/experimental/PostOp.h | |
parent | 71cbd28b7cf5115b0451d43e5c84cce4ae4d8ec7 (diff) | |
download | ComputeLibrary-6049edadf0c89a026b3fcd1927ee7531d3c40278.tar.gz |
Add PRelu to supported PostOps in:
- ClGemmMatrixMultiplyReshapedKernel
- ClGemmMatrixMultiplyNativeKernel
- ClGemmMatrixMultiplyReshapedOnlyRhsKernel
Resolves: COMPMID-4713
Change-Id: I3adcb1b3d4af37ebcbc3bee19cc1845885d08600
Signed-off-by: Ramy Elgammal <ramy.elgammal@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/6553
Reviewed-by: SiCong Li <sicong.li@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src/core/experimental/PostOp.h')
-rw-r--r-- | src/core/experimental/PostOp.h | 49 |
1 files changed, 48 insertions, 1 deletions
diff --git a/src/core/experimental/PostOp.h b/src/core/experimental/PostOp.h index 7d62bd95e1..b29f67ec5c 100644 --- a/src/core/experimental/PostOp.h +++ b/src/core/experimental/PostOp.h @@ -116,6 +116,47 @@ public: ConvertPolicy _policy; }; +template <typename TensorRelatedT> +struct PostOpEltwisePRelu : public IPostOp<TensorRelatedT> +{ +public: + PostOpEltwisePRelu(TensorRelatedT alpha_param, int prev_dst_pos, ConvertPolicy policy) + : _alpha_param{ alpha_param }, + _prev_dst_pos{ prev_dst_pos }, + _policy{ policy } + { + } + // NOTE: PostOps do not own any resources pointed to by TensorRelatedT if it's a pointer type, thus allow shallow copy + ~PostOpEltwisePRelu() override = default; + PostOpEltwisePRelu(const PostOpEltwisePRelu &) = default; + PostOpEltwisePRelu &operator=(const PostOpEltwisePRelu &) = default; + PostOpEltwisePRelu(PostOpEltwisePRelu &&) = default; + PostOpEltwisePRelu &operator=(PostOpEltwisePRelu &&) = default; + int prev_dst_pos() const override + { + return _prev_dst_pos; + } + PostOpType type() const override + { + return PostOpType::Eltwise_PRelu; + } + std::vector<TensorRelatedT *> arguments() override + { + return { &_alpha_param }; + } + std::vector<const TensorRelatedT *> arguments() const override + { + return { &_alpha_param }; + } + std::unique_ptr<IPostOp<TensorRelatedT>> clone() const override + { + return std::make_unique<PostOpEltwisePRelu<TensorRelatedT>>(*this); + } + TensorRelatedT _alpha_param; + int _prev_dst_pos; + ConvertPolicy _policy; +}; + /** Transform a PostOpList of type FromTensorT to one of type ToTensorT */ template <typename FromTensorT, typename ToTensorT> PostOpList<ToTensorT> transform_post_op_list_arguments(const PostOpList<FromTensorT> &post_ops, std::function<ToTensorT(FromTensorT)> transform_arg) @@ -138,6 +179,12 @@ PostOpList<ToTensorT> transform_post_op_list_arguments(const PostOpList<FromTens transformed_post_ops.template push_back_op<PostOpEltwiseAdd<ToTensorT>>(transform_arg(_post_op->_addend), _post_op->_prev_dst_pos, _post_op->_policy); break; } + case PostOpType::Eltwise_PRelu: + { + const auto _post_op = utils::cast::polymorphic_downcast<const PostOpEltwisePRelu<FromTensorT> *>(post_op.get()); + transformed_post_ops.template push_back_op<PostOpEltwisePRelu<ToTensorT>>(transform_arg(_post_op->_alpha_param), _post_op->_prev_dst_pos, _post_op->_policy); + break; + } default: { ARM_COMPUTE_ERROR("Unsupported PostOpType"); @@ -168,4 +215,4 @@ PostOpTypeSequence get_post_op_sequence(const PostOpList<T> &post_ops) } // namespace experimental } // namespace arm_compute -#endif //ARM_COMPUTE_EXPERIMENTAL_POSTOP
\ No newline at end of file +#endif //ARM_COMPUTE_EXPERIMENTAL_POSTOP |