diff options
Diffstat (limited to 'src/core/experimental')
-rw-r--r-- | src/core/experimental/PostOp.h | 49 |
1 files changed, 48 insertions, 1 deletions
diff --git a/src/core/experimental/PostOp.h b/src/core/experimental/PostOp.h index 7d62bd95e1..b29f67ec5c 100644 --- a/src/core/experimental/PostOp.h +++ b/src/core/experimental/PostOp.h @@ -116,6 +116,47 @@ public: ConvertPolicy _policy; }; +template <typename TensorRelatedT> +struct PostOpEltwisePRelu : public IPostOp<TensorRelatedT> +{ +public: + PostOpEltwisePRelu(TensorRelatedT alpha_param, int prev_dst_pos, ConvertPolicy policy) + : _alpha_param{ alpha_param }, + _prev_dst_pos{ prev_dst_pos }, + _policy{ policy } + { + } + // NOTE: PostOps do not own any resources pointed to by TensorRelatedT if it's a pointer type, thus allow shallow copy + ~PostOpEltwisePRelu() override = default; + PostOpEltwisePRelu(const PostOpEltwisePRelu &) = default; + PostOpEltwisePRelu &operator=(const PostOpEltwisePRelu &) = default; + PostOpEltwisePRelu(PostOpEltwisePRelu &&) = default; + PostOpEltwisePRelu &operator=(PostOpEltwisePRelu &&) = default; + int prev_dst_pos() const override + { + return _prev_dst_pos; + } + PostOpType type() const override + { + return PostOpType::Eltwise_PRelu; + } + std::vector<TensorRelatedT *> arguments() override + { + return { &_alpha_param }; + } + std::vector<const TensorRelatedT *> arguments() const override + { + return { &_alpha_param }; + } + std::unique_ptr<IPostOp<TensorRelatedT>> clone() const override + { + return std::make_unique<PostOpEltwisePRelu<TensorRelatedT>>(*this); + } + TensorRelatedT _alpha_param; + int _prev_dst_pos; + ConvertPolicy _policy; +}; + /** Transform a PostOpList of type FromTensorT to one of type ToTensorT */ template <typename FromTensorT, typename ToTensorT> PostOpList<ToTensorT> transform_post_op_list_arguments(const PostOpList<FromTensorT> &post_ops, std::function<ToTensorT(FromTensorT)> transform_arg) @@ -138,6 +179,12 @@ PostOpList<ToTensorT> transform_post_op_list_arguments(const PostOpList<FromTens transformed_post_ops.template push_back_op<PostOpEltwiseAdd<ToTensorT>>(transform_arg(_post_op->_addend), _post_op->_prev_dst_pos, _post_op->_policy); break; } + case PostOpType::Eltwise_PRelu: + { + const auto _post_op = utils::cast::polymorphic_downcast<const PostOpEltwisePRelu<FromTensorT> *>(post_op.get()); + transformed_post_ops.template push_back_op<PostOpEltwisePRelu<ToTensorT>>(transform_arg(_post_op->_alpha_param), _post_op->_prev_dst_pos, _post_op->_policy); + break; + } default: { ARM_COMPUTE_ERROR("Unsupported PostOpType"); @@ -168,4 +215,4 @@ PostOpTypeSequence get_post_op_sequence(const PostOpList<T> &post_ops) } // namespace experimental } // namespace arm_compute -#endif //ARM_COMPUTE_EXPERIMENTAL_POSTOP
\ No newline at end of file +#endif //ARM_COMPUTE_EXPERIMENTAL_POSTOP |