aboutsummaryrefslogtreecommitdiff
path: root/src/core/CL
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/CL')
-rw-r--r--src/core/CL/CLUtils.cpp14
-rw-r--r--src/core/CL/cl_kernels/common/experimental/gemm_fused_post_ops/fp_elementwise_op_helpers.h12
2 files changed, 26 insertions, 0 deletions
diff --git a/src/core/CL/CLUtils.cpp b/src/core/CL/CLUtils.cpp
index 748b0f55a1..88b31c8349 100644
--- a/src/core/CL/CLUtils.cpp
+++ b/src/core/CL/CLUtils.cpp
@@ -151,6 +151,20 @@ void PostOpCLKernelUtils::set_post_ops_cl_build_options(CLBuildOptions &build_op
++arg_id;
}
}
+ else if(post_op->type() == experimental::PostOpType::Eltwise_PRelu)
+ {
+ size_t arg_id = 1;
+ const auto eltwise_op = slot_prefix + "_ELTWISE_OP=PRELU" + "_X_POS_" + support::cpp11::to_string(post_op->prev_dst_pos());
+ build_opts.add_option(eltwise_op);
+ for(const auto &tensor : post_op->arguments())
+ {
+ const auto height = slot_prefix + "_ELTWISE_ARG" + support::cpp11::to_string(arg_id) + "_HEIGHT=" + support::cpp11::to_string((*tensor)->dimension(1));
+ const auto width = slot_prefix + "_ELTWISE_ARG" + support::cpp11::to_string(arg_id) + "_WIDTH=" + support::cpp11::to_string((*tensor)->dimension(0));
+ build_opts.add_option(height);
+ build_opts.add_option(width);
+ ++arg_id;
+ }
+ }
}
}
diff --git a/src/core/CL/cl_kernels/common/experimental/gemm_fused_post_ops/fp_elementwise_op_helpers.h b/src/core/CL/cl_kernels/common/experimental/gemm_fused_post_ops/fp_elementwise_op_helpers.h
index 9ddf51a13c..b584251c2a 100644
--- a/src/core/CL/cl_kernels/common/experimental/gemm_fused_post_ops/fp_elementwise_op_helpers.h
+++ b/src/core/CL/cl_kernels/common/experimental/gemm_fused_post_ops/fp_elementwise_op_helpers.h
@@ -45,7 +45,13 @@
#if VEC_SIZE == 1
#define PRELU_X_POS_0(x, y) (x > 0 ? x : x * y)
#else // VEC_SIZE == 1
+
+#if defined(MIXED_PRECISION)
+#define PRELU_X_POS_0(x, y) (select(y * x, x, CONVERT((x > (DATA_TYPE_ACCUMULATOR)0), SELECT_VEC_DATA_TYPE(DATA_TYPE_ACCUMULATOR, VEC_SIZE))))
+#else // MIXED_PRECISION
#define PRELU_X_POS_0(x, y) (select(y * x, x, CONVERT((x > (DATA_TYPE)0), SELECT_VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE))))
+#endif // MIXED_PRECISION
+
#endif // VEC_SIZE == 1
#define DIV_X_POS_0(x, y) (x / y)
#define AND_X_POS_0(x, y) (CONVERT((x && y), VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE)) & ((VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE))1))
@@ -60,7 +66,13 @@
#if VEC_SIZE == 1
#define PRELU_X_POS_1(x, y) (y > 0 ? y : y * x)
#else // VEC_SIZE == 1
+
+#if defined(MIXED_PRECISION)
+#define PRELU_X_POS_1(x, y) (select(x * y, y, CONVERT((y > (DATA_TYPE_ACCUMULATOR)0), SELECT_VEC_DATA_TYPE(DATA_TYPE_ACCUMULATOR, VEC_SIZE))))
+#else // MIXED_PRECISION
#define PRELU_X_POS_1(x, y) (select(x * y, y, CONVERT((y > (DATA_TYPE)0), SELECT_VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE))))
+#endif // MIXED_PRECISION
+
#endif // VEC_SIZE == 1
#define DIV_X_POS_1(x, y) (y / x)
#define AND_X_POS_1(x, y) AND_X_POS_0(x, y)