about | summary | refs | log | tree | commit | diff
path: root/src/core
diff options
context:
space:
mode:
author ramelg01 <ramy.elgammal@arm.com> 2021-10-29 10:52:53 +0100
committer ramy.elgammal <ramy.elgammal@arm.com> 2021-11-04 11:10:56 +0000
commit 6049edadf0c89a026b3fcd1927ee7531d3c40278 (patch)
tree c12fcea637e41cdb9e1f72dc734e4a87d0b31981 /src/core
parent 71cbd28b7cf5115b0451d43e5c84cce4ae4d8ec7 (diff)
download ComputeLibrary-6049edadf0c89a026b3fcd1927ee7531d3c40278.tar.gz
Add PRelu to supported PostOps in:
- ClGemmMatrixMultiplyReshapedKernel
- ClGemmMatrixMultiplyNativeKernel
- ClGemmMatrixMultiplyReshapedOnlyRhsKernel

Resolves: COMPMID-4713
Change-Id: I3adcb1b3d4af37ebcbc3bee19cc1845885d08600
Signed-off-by: Ramy Elgammal <ramy.elgammal@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/6553
Reviewed-by: SiCong Li <sicong.li@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src/core')
-rw-r--r-- src/core/CL/CLUtils.cpp | 14
-rw-r--r-- src/core/CL/cl_kernels/common/experimental/gemm_fused_post_ops/fp_elementwise_op_helpers.h | 12
-rw-r--r-- src/core/experimental/PostOp.h | 49
3 files changed, 74 insertions, 1 deletions
diff --git a/src/core/CL/CLUtils.cpp b/src/core/CL/CLUtils.cpp
index 748b0f55a1..88b31c8349 100644
--- a/src/core/CL/CLUtils.cpp
+++ b/src/core/CL/CLUtils.cpp
@@ -151,6 +151,20 @@ void PostOpCLKernelUtils::set_post_ops_cl_build_options(CLBuildOptions &build_op
++arg_id;
}
}
+ else if(post_op->type() == experimental::PostOpType::Eltwise_PRelu)
+ {
+ size_t arg_id = 1;
+ const auto eltwise_op = slot_prefix + "_ELTWISE_OP=PRELU" + "_X_POS_" + support::cpp11::to_string(post_op->prev_dst_pos());
+ build_opts.add_option(eltwise_op);
+ for(const auto &tensor : post_op->arguments())
+ {
+ const auto height = slot_prefix + "_ELTWISE_ARG" + support::cpp11::to_string(arg_id) + "_HEIGHT=" + support::cpp11::to_string((*tensor)->dimension(1));
+ const auto width = slot_prefix + "_ELTWISE_ARG" + support::cpp11::to_string(arg_id) + "_WIDTH=" + support::cpp11::to_string((*tensor)->dimension(0));
+ build_opts.add_option(height);
+ build_opts.add_option(width);
+ ++arg_id;
+ }
+ }
}
}
diff --git a/src/core/CL/cl_kernels/common/experimental/gemm_fused_post_ops/fp_elementwise_op_helpers.h b/src/core/CL/cl_kernels/common/experimental/gemm_fused_post_ops/fp_elementwise_op_helpers.h
index 9ddf51a13c..b584251c2a 100644
--- a/src/core/CL/cl_kernels/common/experimental/gemm_fused_post_ops/fp_elementwise_op_helpers.h
+++ b/src/core/CL/cl_kernels/common/experimental/gemm_fused_post_ops/fp_elementwise_op_helpers.h
@@ -45,7 +45,13 @@
#if VEC_SIZE == 1
#define PRELU_X_POS_0(x, y) (x > 0 ? x : x * y)
#else // VEC_SIZE == 1
+
+#if defined(MIXED_PRECISION)
+#define PRELU_X_POS_0(x, y) (select(y * x, x, CONVERT((x > (DATA_TYPE_ACCUMULATOR)0), SELECT_VEC_DATA_TYPE(DATA_TYPE_ACCUMULATOR, VEC_SIZE))))
+#else // MIXED_PRECISION
#define PRELU_X_POS_0(x, y) (select(y * x, x, CONVERT((x > (DATA_TYPE)0), SELECT_VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE))))
+#endif // MIXED_PRECISION
+
#endif // VEC_SIZE == 1
#define DIV_X_POS_0(x, y) (x / y)
#define AND_X_POS_0(x, y) (CONVERT((x && y), VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE)) & ((VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE))1))
@@ -60,7 +66,13 @@
#if VEC_SIZE == 1
#define PRELU_X_POS_1(x, y) (y > 0 ? y : y * x)
#else // VEC_SIZE == 1
+
+#if defined(MIXED_PRECISION)
+#define PRELU_X_POS_1(x, y) (select(x * y, y, CONVERT((y > (DATA_TYPE_ACCUMULATOR)0), SELECT_VEC_DATA_TYPE(DATA_TYPE_ACCUMULATOR, VEC_SIZE))))
+#else // MIXED_PRECISION
#define PRELU_X_POS_1(x, y) (select(x * y, y, CONVERT((y > (DATA_TYPE)0), SELECT_VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE))))
+#endif // MIXED_PRECISION
+
#endif // VEC_SIZE == 1
#define DIV_X_POS_1(x, y) (y / x)
#define AND_X_POS_1(x, y) AND_X_POS_0(x, y)
diff --git a/src/core/experimental/PostOp.h b/src/core/experimental/PostOp.h
index 7d62bd95e1..b29f67ec5c 100644
--- a/src/core/experimental/PostOp.h
+++ b/src/core/experimental/PostOp.h
@@ -116,6 +116,47 @@ public:
ConvertPolicy _policy;
};
+template <typename TensorRelatedT> // TensorRelatedT: type of the stored alpha argument; may be a pointer type (see NOTE below)
+struct PostOpEltwisePRelu : public IPostOp<TensorRelatedT> // Post op that reports PostOpType::Eltwise_PRelu and carries a single argument
+{
+public:
+ PostOpEltwisePRelu(TensorRelatedT alpha_param, int prev_dst_pos, ConvertPolicy policy) // Stores all three arguments unchanged; no validation is done here
+ : _alpha_param{ alpha_param },
+ _prev_dst_pos{ prev_dst_pos },
+ _policy{ policy }
+ {
+ }
+ // NOTE: PostOps do not own any resources pointed to by TensorRelatedT if it's a pointer type, thus allow shallow copy
+ ~PostOpEltwisePRelu() override = default;
+ PostOpEltwisePRelu(const PostOpEltwisePRelu &) = default;
+ PostOpEltwisePRelu &operator=(const PostOpEltwisePRelu &) = default;
+ PostOpEltwisePRelu(PostOpEltwisePRelu &&) = default;
+ PostOpEltwisePRelu &operator=(PostOpEltwisePRelu &&) = default;
+ int prev_dst_pos() const override // Returns the value given to the constructor
+ {
+ return _prev_dst_pos;
+ }
+ PostOpType type() const override // Always PostOpType::Eltwise_PRelu for this struct
+ {
+ return PostOpType::Eltwise_PRelu;
+ }
+ std::vector<TensorRelatedT *> arguments() override // Mutable access: one-element vector pointing at _alpha_param
+ {
+ return { &_alpha_param };
+ }
+ std::vector<const TensorRelatedT *> arguments() const override // Const overload: same single argument, exposed as pointer-to-const
+ {
+ return { &_alpha_param };
+ }
+ std::unique_ptr<IPostOp<TensorRelatedT>> clone() const override // Returns a new heap-allocated copy made with the defaulted copy constructor
+ {
+ return std::make_unique<PostOpEltwisePRelu<TensorRelatedT>>(*this);
+ }
+ TensorRelatedT _alpha_param; // The single argument exposed through arguments()
+ int _prev_dst_pos; // Returned by prev_dst_pos(); NOTE(review): presumably the operand slot taken by the previous op's result - confirm against IPostOp
+ ConvertPolicy _policy; // Set by the constructor; not read by any member function of this struct
+};
+
/** Transform a PostOpList of type FromTensorT to one of type ToTensorT */
template <typename FromTensorT, typename ToTensorT>
PostOpList<ToTensorT> transform_post_op_list_arguments(const PostOpList<FromTensorT> &post_ops, std::function<ToTensorT(FromTensorT)> transform_arg)
@@ -138,6 +179,12 @@ PostOpList<ToTensorT> transform_post_op_list_arguments(const PostOpList<FromTens
transformed_post_ops.template push_back_op<PostOpEltwiseAdd<ToTensorT>>(transform_arg(_post_op->_addend), _post_op->_prev_dst_pos, _post_op->_policy);
break;
}
+ case PostOpType::Eltwise_PRelu:
+ {
+ const auto _post_op = utils::cast::polymorphic_downcast<const PostOpEltwisePRelu<FromTensorT> *>(post_op.get());
+ transformed_post_ops.template push_back_op<PostOpEltwisePRelu<ToTensorT>>(transform_arg(_post_op->_alpha_param), _post_op->_prev_dst_pos, _post_op->_policy);
+ break;
+ }
default:
{
ARM_COMPUTE_ERROR("Unsupported PostOpType");
@@ -168,4 +215,4 @@ PostOpTypeSequence get_post_op_sequence(const PostOpList<T> &post_ops)
} // namespace experimental
} // namespace arm_compute
-#endif //ARM_COMPUTE_EXPERIMENTAL_POSTOP \ No newline at end of file
+#endif //ARM_COMPUTE_EXPERIMENTAL_POSTOP