aboutsummaryrefslogtreecommitdiff
path: root/src/core/experimental/PostOp.h
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/experimental/PostOp.h')
-rw-r--r--src/core/experimental/PostOp.h49
1 files changed, 48 insertions, 1 deletions
diff --git a/src/core/experimental/PostOp.h b/src/core/experimental/PostOp.h
index 7d62bd95e1..b29f67ec5c 100644
--- a/src/core/experimental/PostOp.h
+++ b/src/core/experimental/PostOp.h
@@ -116,6 +116,47 @@ public:
ConvertPolicy _policy;
};
+template <typename TensorRelatedT>
+struct PostOpEltwisePRelu : public IPostOp<TensorRelatedT>
+{
+public:
+ PostOpEltwisePRelu(TensorRelatedT alpha_param, int prev_dst_pos, ConvertPolicy policy)
+ : _alpha_param{ alpha_param },
+ _prev_dst_pos{ prev_dst_pos },
+ _policy{ policy }
+ {
+ }
+ // NOTE: PostOps do not own any resources pointed to by TensorRelatedT if it's a pointer type, thus allow shallow copy
+ ~PostOpEltwisePRelu() override = default;
+ PostOpEltwisePRelu(const PostOpEltwisePRelu &) = default;
+ PostOpEltwisePRelu &operator=(const PostOpEltwisePRelu &) = default;
+ PostOpEltwisePRelu(PostOpEltwisePRelu &&) = default;
+ PostOpEltwisePRelu &operator=(PostOpEltwisePRelu &&) = default;
+ int prev_dst_pos() const override
+ {
+ return _prev_dst_pos;
+ }
+ PostOpType type() const override
+ {
+ return PostOpType::Eltwise_PRelu;
+ }
+ std::vector<TensorRelatedT *> arguments() override
+ {
+ return { &_alpha_param };
+ }
+ std::vector<const TensorRelatedT *> arguments() const override
+ {
+ return { &_alpha_param };
+ }
+ std::unique_ptr<IPostOp<TensorRelatedT>> clone() const override
+ {
+ return std::make_unique<PostOpEltwisePRelu<TensorRelatedT>>(*this);
+ }
+ TensorRelatedT _alpha_param;
+ int _prev_dst_pos;
+ ConvertPolicy _policy;
+};
+
/** Transform a PostOpList of type FromTensorT to one of type ToTensorT */
template <typename FromTensorT, typename ToTensorT>
PostOpList<ToTensorT> transform_post_op_list_arguments(const PostOpList<FromTensorT> &post_ops, std::function<ToTensorT(FromTensorT)> transform_arg)
@@ -138,6 +179,12 @@ PostOpList<ToTensorT> transform_post_op_list_arguments(const PostOpList<FromTens
transformed_post_ops.template push_back_op<PostOpEltwiseAdd<ToTensorT>>(transform_arg(_post_op->_addend), _post_op->_prev_dst_pos, _post_op->_policy);
break;
}
+ case PostOpType::Eltwise_PRelu:
+ {
+ const auto _post_op = utils::cast::polymorphic_downcast<const PostOpEltwisePRelu<FromTensorT> *>(post_op.get());
+ transformed_post_ops.template push_back_op<PostOpEltwisePRelu<ToTensorT>>(transform_arg(_post_op->_alpha_param), _post_op->_prev_dst_pos, _post_op->_policy);
+ break;
+ }
default:
{
ARM_COMPUTE_ERROR("Unsupported PostOpType");
@@ -168,4 +215,4 @@ PostOpTypeSequence get_post_op_sequence(const PostOpList<T> &post_ops)
} // namespace experimental
} // namespace arm_compute
-#endif //ARM_COMPUTE_EXPERIMENTAL_POSTOP \ No newline at end of file
+#endif //ARM_COMPUTE_EXPERIMENTAL_POSTOP