aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorramelg01 <ramy.elgammal@arm.com>2021-10-29 10:52:53 +0100
committerramy.elgammal <ramy.elgammal@arm.com>2021-11-04 11:10:56 +0000
commit6049edadf0c89a026b3fcd1927ee7531d3c40278 (patch)
treec12fcea637e41cdb9e1f72dc734e4a87d0b31981
parent71cbd28b7cf5115b0451d43e5c84cce4ae4d8ec7 (diff)
downloadComputeLibrary-6049edadf0c89a026b3fcd1927ee7531d3c40278.tar.gz
Add PRelu to supported PostOps in:
- ClGemmMatrixMultiplyReshapedKernel - ClGemmMatrixMultiplyNativeKernel - ClGemmMatrixMultiplyReshapedOnlyRhsKernel Resolves: COMPMID-4713 Change-Id: I3adcb1b3d4af37ebcbc3bee19cc1845885d08600 Signed-off-by: Ramy Elgammal <ramy.elgammal@arm.com> Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/6553 Reviewed-by: SiCong Li <sicong.li@arm.com> Tested-by: Arm Jenkins <bsgcomp@arm.com> Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
-rw-r--r--arm_compute/core/experimental/IPostOp.h3
-rw-r--r--src/core/CL/CLUtils.cpp14
-rw-r--r--src/core/CL/cl_kernels/common/experimental/gemm_fused_post_ops/fp_elementwise_op_helpers.h12
-rw-r--r--src/core/experimental/PostOp.h49
-rw-r--r--src/gpu/cl/kernels/ClGemmMatrixMultiplyNativeKernel.cpp10
-rw-r--r--src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedKernel.cpp11
-rw-r--r--src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedOnlyRhsKernel.cpp10
-rw-r--r--tests/validation/CL/GEMMMatrixMultiplyNative.cpp29
-rw-r--r--tests/validation/CL/GEMMMatrixMultiplyReshaped.cpp28
-rw-r--r--tests/validation/CL/GEMMMatrixMultiplyReshapedOnlyRHS.cpp27
-rw-r--r--tests/validation/reference/PostOps.cpp16
-rw-r--r--utils/TypePrinter.h11
12 files changed, 211 insertions, 9 deletions
diff --git a/arm_compute/core/experimental/IPostOp.h b/arm_compute/core/experimental/IPostOp.h
index 4fac4c88e9..178c83aa75 100644
--- a/arm_compute/core/experimental/IPostOp.h
+++ b/arm_compute/core/experimental/IPostOp.h
@@ -37,6 +37,7 @@ enum class PostOpType
{
Activation,
Eltwise_Add,
+ Eltwise_PRelu
};
/** An ordered sequence of type of Post Ops */
using PostOpTypeSequence = std::vector<PostOpType>;
@@ -167,4 +168,4 @@ private:
} // namespace experimental
} // namespace arm_compute
-#endif //ARM_COMPUTE_EXPERIMENTAL_IPOSTOP \ No newline at end of file
+#endif //ARM_COMPUTE_EXPERIMENTAL_IPOSTOP
diff --git a/src/core/CL/CLUtils.cpp b/src/core/CL/CLUtils.cpp
index 748b0f55a1..88b31c8349 100644
--- a/src/core/CL/CLUtils.cpp
+++ b/src/core/CL/CLUtils.cpp
@@ -151,6 +151,20 @@ void PostOpCLKernelUtils::set_post_ops_cl_build_options(CLBuildOptions &build_op
++arg_id;
}
}
+ else if(post_op->type() == experimental::PostOpType::Eltwise_PRelu)
+ {
+ size_t arg_id = 1;
+ const auto eltwise_op = slot_prefix + "_ELTWISE_OP=PRELU" + "_X_POS_" + support::cpp11::to_string(post_op->prev_dst_pos());
+ build_opts.add_option(eltwise_op);
+ for(const auto &tensor : post_op->arguments())
+ {
+ const auto height = slot_prefix + "_ELTWISE_ARG" + support::cpp11::to_string(arg_id) + "_HEIGHT=" + support::cpp11::to_string((*tensor)->dimension(1));
+ const auto width = slot_prefix + "_ELTWISE_ARG" + support::cpp11::to_string(arg_id) + "_WIDTH=" + support::cpp11::to_string((*tensor)->dimension(0));
+ build_opts.add_option(height);
+ build_opts.add_option(width);
+ ++arg_id;
+ }
+ }
}
}
diff --git a/src/core/CL/cl_kernels/common/experimental/gemm_fused_post_ops/fp_elementwise_op_helpers.h b/src/core/CL/cl_kernels/common/experimental/gemm_fused_post_ops/fp_elementwise_op_helpers.h
index 9ddf51a13c..b584251c2a 100644
--- a/src/core/CL/cl_kernels/common/experimental/gemm_fused_post_ops/fp_elementwise_op_helpers.h
+++ b/src/core/CL/cl_kernels/common/experimental/gemm_fused_post_ops/fp_elementwise_op_helpers.h
@@ -45,7 +45,13 @@
#if VEC_SIZE == 1
#define PRELU_X_POS_0(x, y) (x > 0 ? x : x * y)
#else // VEC_SIZE == 1
+
+#if defined(MIXED_PRECISION)
+#define PRELU_X_POS_0(x, y) (select(y * x, x, CONVERT((x > (DATA_TYPE_ACCUMULATOR)0), SELECT_VEC_DATA_TYPE(DATA_TYPE_ACCUMULATOR, VEC_SIZE))))
+#else // MIXED_PRECISION
#define PRELU_X_POS_0(x, y) (select(y * x, x, CONVERT((x > (DATA_TYPE)0), SELECT_VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE))))
+#endif // MIXED_PRECISION
+
#endif // VEC_SIZE == 1
#define DIV_X_POS_0(x, y) (x / y)
#define AND_X_POS_0(x, y) (CONVERT((x && y), VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE)) & ((VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE))1))
@@ -60,7 +66,13 @@
#if VEC_SIZE == 1
#define PRELU_X_POS_1(x, y) (y > 0 ? y : y * x)
#else // VEC_SIZE == 1
+
+#if defined(MIXED_PRECISION)
+#define PRELU_X_POS_1(x, y) (select(x * y, y, CONVERT((y > (DATA_TYPE_ACCUMULATOR)0), SELECT_VEC_DATA_TYPE(DATA_TYPE_ACCUMULATOR, VEC_SIZE))))
+#else // MIXED_PRECISION
#define PRELU_X_POS_1(x, y) (select(x * y, y, CONVERT((y > (DATA_TYPE)0), SELECT_VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE))))
+#endif // MIXED_PRECISION
+
#endif // VEC_SIZE == 1
#define DIV_X_POS_1(x, y) (y / x)
#define AND_X_POS_1(x, y) AND_X_POS_0(x, y)
diff --git a/src/core/experimental/PostOp.h b/src/core/experimental/PostOp.h
index 7d62bd95e1..b29f67ec5c 100644
--- a/src/core/experimental/PostOp.h
+++ b/src/core/experimental/PostOp.h
@@ -116,6 +116,47 @@ public:
ConvertPolicy _policy;
};
+template <typename TensorRelatedT>
+struct PostOpEltwisePRelu : public IPostOp<TensorRelatedT>
+{
+public:
+ PostOpEltwisePRelu(TensorRelatedT alpha_param, int prev_dst_pos, ConvertPolicy policy)
+ : _alpha_param{ alpha_param },
+ _prev_dst_pos{ prev_dst_pos },
+ _policy{ policy }
+ {
+ }
+ // NOTE: PostOps do not own any resources pointed to by TensorRelatedT if it's a pointer type, thus allow shallow copy
+ ~PostOpEltwisePRelu() override = default;
+ PostOpEltwisePRelu(const PostOpEltwisePRelu &) = default;
+ PostOpEltwisePRelu &operator=(const PostOpEltwisePRelu &) = default;
+ PostOpEltwisePRelu(PostOpEltwisePRelu &&) = default;
+ PostOpEltwisePRelu &operator=(PostOpEltwisePRelu &&) = default;
+ int prev_dst_pos() const override
+ {
+ return _prev_dst_pos;
+ }
+ PostOpType type() const override
+ {
+ return PostOpType::Eltwise_PRelu;
+ }
+ std::vector<TensorRelatedT *> arguments() override
+ {
+ return { &_alpha_param };
+ }
+ std::vector<const TensorRelatedT *> arguments() const override
+ {
+ return { &_alpha_param };
+ }
+ std::unique_ptr<IPostOp<TensorRelatedT>> clone() const override
+ {
+ return std::make_unique<PostOpEltwisePRelu<TensorRelatedT>>(*this);
+ }
+ TensorRelatedT _alpha_param;
+ int _prev_dst_pos;
+ ConvertPolicy _policy;
+};
+
/** Transform a PostOpList of type FromTensorT to one of type ToTensorT */
template <typename FromTensorT, typename ToTensorT>
PostOpList<ToTensorT> transform_post_op_list_arguments(const PostOpList<FromTensorT> &post_ops, std::function<ToTensorT(FromTensorT)> transform_arg)
@@ -138,6 +179,12 @@ PostOpList<ToTensorT> transform_post_op_list_arguments(const PostOpList<FromTens
transformed_post_ops.template push_back_op<PostOpEltwiseAdd<ToTensorT>>(transform_arg(_post_op->_addend), _post_op->_prev_dst_pos, _post_op->_policy);
break;
}
+ case PostOpType::Eltwise_PRelu:
+ {
+ const auto _post_op = utils::cast::polymorphic_downcast<const PostOpEltwisePRelu<FromTensorT> *>(post_op.get());
+ transformed_post_ops.template push_back_op<PostOpEltwisePRelu<ToTensorT>>(transform_arg(_post_op->_alpha_param), _post_op->_prev_dst_pos, _post_op->_policy);
+ break;
+ }
default:
{
ARM_COMPUTE_ERROR("Unsupported PostOpType");
@@ -168,4 +215,4 @@ PostOpTypeSequence get_post_op_sequence(const PostOpList<T> &post_ops)
} // namespace experimental
} // namespace arm_compute
-#endif //ARM_COMPUTE_EXPERIMENTAL_POSTOP \ No newline at end of file
+#endif //ARM_COMPUTE_EXPERIMENTAL_POSTOP
diff --git a/src/gpu/cl/kernels/ClGemmMatrixMultiplyNativeKernel.cpp b/src/gpu/cl/kernels/ClGemmMatrixMultiplyNativeKernel.cpp
index c3efc24fa9..350312f6fe 100644
--- a/src/gpu/cl/kernels/ClGemmMatrixMultiplyNativeKernel.cpp
+++ b/src/gpu/cl/kernels/ClGemmMatrixMultiplyNativeKernel.cpp
@@ -56,10 +56,18 @@ const auto post_op_utils = experimental::PostOpCLKernelUtils(
// PostOp sequence -> {Kernel Postfix, PostOp Slots}
{ {}, { "", {} } },
{ { experimental::PostOpType::Activation }, { "", { 1 } } },
+
{ { experimental::PostOpType::Eltwise_Add }, { "_post_act_eltwise_op_act", { 2 } } },
+ { { experimental::PostOpType::Eltwise_PRelu }, { "_post_act_eltwise_op_act", { 2 } } },
+
{ { experimental::PostOpType::Activation, experimental::PostOpType::Eltwise_Add }, { "_post_act_eltwise_op_act", { 1, 2 } } },
+ { { experimental::PostOpType::Activation, experimental::PostOpType::Eltwise_PRelu }, { "_post_act_eltwise_op_act", { 1, 2 } } },
+
{ { experimental::PostOpType::Eltwise_Add, experimental::PostOpType::Activation }, { "_post_act_eltwise_op_act", { 2, 3 } } },
- { { experimental::PostOpType::Activation, experimental::PostOpType::Eltwise_Add, experimental::PostOpType::Activation }, { "_post_act_eltwise_op_act", { 1, 2, 3 } } }
+ { { experimental::PostOpType::Eltwise_PRelu, experimental::PostOpType::Activation }, { "_post_act_eltwise_op_act", { 2, 3 } } },
+
+ { { experimental::PostOpType::Activation, experimental::PostOpType::Eltwise_Add, experimental::PostOpType::Activation }, { "_post_act_eltwise_op_act", { 1, 2, 3 } } },
+ { { experimental::PostOpType::Activation, experimental::PostOpType::Eltwise_PRelu, experimental::PostOpType::Activation }, { "_post_act_eltwise_op_act", { 1, 2, 3 } } }
});
Status validate_arguments(const ITensorInfo *src0, const ITensorInfo *src1, const ITensorInfo *src2, const ITensorInfo *dst, float alpha, float beta, const GEMMLHSMatrixInfo &lhs_info,
diff --git a/src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedKernel.cpp b/src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedKernel.cpp
index 8ee72d3f03..52c8cd4fd5 100644
--- a/src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedKernel.cpp
+++ b/src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedKernel.cpp
@@ -57,11 +57,20 @@ const auto post_op_utils = experimental::PostOpCLKernelUtils(
// PostOp sequence -> {Kernel Postfix, PostOp Slots}
{ {}, { "", {} } },
{ { experimental::PostOpType::Activation }, { "", { 1 } } },
+
{ { experimental::PostOpType::Eltwise_Add }, { "_post_act_eltwise_op_act", { 2 } } },
+ { { experimental::PostOpType::Eltwise_PRelu }, { "_post_act_eltwise_op_act", { 2 } } },
+
{ { experimental::PostOpType::Activation, experimental::PostOpType::Eltwise_Add }, { "_post_act_eltwise_op_act", { 1, 2 } } },
+ { { experimental::PostOpType::Activation, experimental::PostOpType::Eltwise_PRelu }, { "_post_act_eltwise_op_act", { 1, 2 } } },
+
{ { experimental::PostOpType::Eltwise_Add, experimental::PostOpType::Activation }, { "_post_act_eltwise_op_act", { 2, 3 } } },
- { { experimental::PostOpType::Activation, experimental::PostOpType::Eltwise_Add, experimental::PostOpType::Activation }, { "_post_act_eltwise_op_act", { 1, 2, 3 } } }
+ { { experimental::PostOpType::Eltwise_PRelu, experimental::PostOpType::Activation }, { "_post_act_eltwise_op_act", { 2, 3 } } },
+
+ { { experimental::PostOpType::Activation, experimental::PostOpType::Eltwise_Add, experimental::PostOpType::Activation }, { "_post_act_eltwise_op_act", { 1, 2, 3 } } },
+ { { experimental::PostOpType::Activation, experimental::PostOpType::Eltwise_PRelu, experimental::PostOpType::Activation }, { "_post_act_eltwise_op_act", { 1, 2, 3 } } }
});
+
Status validate_arguments(const ITensorInfo *src0, const ITensorInfo *src1, const ITensorInfo *src2, const ITensorInfo *dst, float alpha, float beta, const GEMMLHSMatrixInfo &lhs_info,
const GEMMRHSMatrixInfo &rhs_info,
const GEMMKernelInfo &gemm_info)
diff --git a/src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedOnlyRhsKernel.cpp b/src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedOnlyRhsKernel.cpp
index 260ed134e4..633c2630cf 100644
--- a/src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedOnlyRhsKernel.cpp
+++ b/src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedOnlyRhsKernel.cpp
@@ -50,10 +50,18 @@ const auto post_op_utils = experimental::PostOpCLKernelUtils(
// PostOp sequence -> {Kernel Postfix, PostOp Slots}
{ {}, { "", {} } },
{ { experimental::PostOpType::Activation }, { "", { 1 } } },
+
{ { experimental::PostOpType::Eltwise_Add }, { "_post_act_eltwise_op_act", { 2 } } },
+ { { experimental::PostOpType::Eltwise_PRelu }, { "_post_act_eltwise_op_act", { 2 } } },
+
{ { experimental::PostOpType::Activation, experimental::PostOpType::Eltwise_Add }, { "_post_act_eltwise_op_act", { 1, 2 } } },
+ { { experimental::PostOpType::Activation, experimental::PostOpType::Eltwise_PRelu }, { "_post_act_eltwise_op_act", { 1, 2 } } },
+
{ { experimental::PostOpType::Eltwise_Add, experimental::PostOpType::Activation }, { "_post_act_eltwise_op_act", { 2, 3 } } },
- { { experimental::PostOpType::Activation, experimental::PostOpType::Eltwise_Add, experimental::PostOpType::Activation }, { "_post_act_eltwise_op_act", { 1, 2, 3 } } }
+ { { experimental::PostOpType::Eltwise_PRelu, experimental::PostOpType::Activation }, { "_post_act_eltwise_op_act", { 2, 3 } } },
+
+ { { experimental::PostOpType::Activation, experimental::PostOpType::Eltwise_Add, experimental::PostOpType::Activation }, { "_post_act_eltwise_op_act", { 1, 2, 3 } } },
+ { { experimental::PostOpType::Activation, experimental::PostOpType::Eltwise_PRelu, experimental::PostOpType::Activation }, { "_post_act_eltwise_op_act", { 1, 2, 3 } } }
});
Status validate_arguments(const ITensorInfo *src0, const ITensorInfo *src1, const ITensorInfo *src2, const ITensorInfo *dst, float alpha, float beta,
diff --git a/tests/validation/CL/GEMMMatrixMultiplyNative.cpp b/tests/validation/CL/GEMMMatrixMultiplyNative.cpp
index e3f151a2ca..54e9d32afc 100644
--- a/tests/validation/CL/GEMMMatrixMultiplyNative.cpp
+++ b/tests/validation/CL/GEMMMatrixMultiplyNative.cpp
@@ -179,13 +179,38 @@ experimental::PostOpList<PostOpArgBroadcast> post_ops_3()
ConvertPolicy::SATURATE);
return post_ops;
}
-
+// To test that the output of the main op is the first parameter in prelu post op
+experimental::PostOpList<PostOpArgBroadcast> post_ops_4()
+{
+ experimental::PostOpList<PostOpArgBroadcast> post_ops{};
+ post_ops.push_back_op<experimental::PostOpAct<PostOpArgBroadcast>>(ActivationLayerInfo{ActivationLayerInfo::ActivationFunction::LINEAR, 0.5F, 0.0F});
+ post_ops.push_back_op<experimental::PostOpEltwisePRelu<PostOpArgBroadcast>>(
+ std::make_tuple(false, false, true), // If true, broadcast in corresponding dim: 0, 1 or 2
+ 0,
+ ConvertPolicy::SATURATE);
+ post_ops.push_back_op<experimental::PostOpAct<PostOpArgBroadcast>>(ActivationLayerInfo{ActivationLayerInfo::ActivationFunction::RELU, 2.1F, 1.3F});
+ return post_ops;
+}
+// To test that the output of the main op is the second parameter in prelu post op i.e. it is the alpha_param
+experimental::PostOpList<PostOpArgBroadcast> post_ops_5()
+{
+ experimental::PostOpList<PostOpArgBroadcast> post_ops{};
+ post_ops.push_back_op<experimental::PostOpAct<PostOpArgBroadcast>>(ActivationLayerInfo{ActivationLayerInfo::ActivationFunction::LINEAR, 0.5F, 0.0F});
+ post_ops.push_back_op<experimental::PostOpEltwisePRelu<PostOpArgBroadcast>>(
+ std::make_tuple(false, false, false), // If true, broadcast in corresponding dim: 0, 1 or 2
+ 1,
+ ConvertPolicy::SATURATE);
+ post_ops.push_back_op<experimental::PostOpAct<PostOpArgBroadcast>>(ActivationLayerInfo{ActivationLayerInfo::ActivationFunction::RELU, 2.1F, 1.3F});
+ return post_ops;
+}
/** Different Post Op Lists */
const auto post_op_lists = framework::dataset::make("post_op_lists", {
post_ops_1(),
post_ops_2(),
post_ops_3(),
-} );
+ post_ops_4(),
+ post_ops_5()
+ } );
bool is_post_op_list_valid(unsigned int m, unsigned int n, unsigned int k, unsigned int batch, DataType data_type, const experimental::PostOpList<ITensorInfo*>& post_ops)
{
diff --git a/tests/validation/CL/GEMMMatrixMultiplyReshaped.cpp b/tests/validation/CL/GEMMMatrixMultiplyReshaped.cpp
index a598780bf6..bedd0f5bfb 100644
--- a/tests/validation/CL/GEMMMatrixMultiplyReshaped.cpp
+++ b/tests/validation/CL/GEMMMatrixMultiplyReshaped.cpp
@@ -216,11 +216,37 @@ experimental::PostOpList<PostOpArgBroadcast> post_ops_3()
ConvertPolicy::SATURATE);
return post_ops;
}
+// To test that the output of the main op is the first parameter in prelu post op
+experimental::PostOpList<PostOpArgBroadcast> post_ops_4()
+{
+ experimental::PostOpList<PostOpArgBroadcast> post_ops{};
+ post_ops.push_back_op<experimental::PostOpAct<PostOpArgBroadcast>>(ActivationLayerInfo{ActivationLayerInfo::ActivationFunction::LINEAR, 0.5F, 0.0F});
+ post_ops.push_back_op<experimental::PostOpEltwisePRelu<PostOpArgBroadcast>>(
+ std::make_tuple(false, false, true), // If true, broadcast in corresponding dim: 0, 1 or 2
+ 0,
+ ConvertPolicy::SATURATE);
+ post_ops.push_back_op<experimental::PostOpAct<PostOpArgBroadcast>>(ActivationLayerInfo{ActivationLayerInfo::ActivationFunction::RELU, 2.1F, 1.3F});
+ return post_ops;
+}
+// To test that the output of the main op is the second parameter in prelu post op i.e. it is the alpha_param
+experimental::PostOpList<PostOpArgBroadcast> post_ops_5()
+{
+ experimental::PostOpList<PostOpArgBroadcast> post_ops{};
+ post_ops.push_back_op<experimental::PostOpAct<PostOpArgBroadcast>>(ActivationLayerInfo{ActivationLayerInfo::ActivationFunction::LINEAR, 0.5F, 0.0F});
+ post_ops.push_back_op<experimental::PostOpEltwisePRelu<PostOpArgBroadcast>>(
+ std::make_tuple(false, false, false), // If true, broadcast in corresponding dim: 0, 1 or 2
+ 1,
+ ConvertPolicy::SATURATE);
+ post_ops.push_back_op<experimental::PostOpAct<PostOpArgBroadcast>>(ActivationLayerInfo{ActivationLayerInfo::ActivationFunction::RELU, 2.1F, 1.3F});
+ return post_ops;
+}
/** Different Post Op Lists */
const auto post_op_lists = framework::dataset::make("post_op_lists", {
post_ops_1(),
post_ops_2(),
post_ops_3(),
+ post_ops_4(),
+ post_ops_5()
} );
bool is_post_op_list_valid(unsigned int m, unsigned int n, unsigned int k, unsigned int batch, DataType data_type, const experimental::PostOpList<ITensorInfo*>& post_ops)
@@ -479,7 +505,7 @@ TEST_CASE(BroadcastInXDimOnly, framework::DatasetMode::ALL)
ARM_COMPUTE_EXPECT(is_post_op_list_valid(m, n, k, batch, data_type, post_ops) == false, framework::LogLevel::ERRORS);
}
-TEST_SUITE_END() // Invalid
+TEST_SUITE_END() // Invalid
TEST_SUITE(Valid)
TEST_CASE(EmptyPostOpList, framework::DatasetMode::ALL)
{
diff --git a/tests/validation/CL/GEMMMatrixMultiplyReshapedOnlyRHS.cpp b/tests/validation/CL/GEMMMatrixMultiplyReshapedOnlyRHS.cpp
index ca8b21cd0d..4c482b49aa 100644
--- a/tests/validation/CL/GEMMMatrixMultiplyReshapedOnlyRHS.cpp
+++ b/tests/validation/CL/GEMMMatrixMultiplyReshapedOnlyRHS.cpp
@@ -196,12 +196,37 @@ experimental::PostOpList<PostOpArgBroadcast> post_ops_3()
ConvertPolicy::SATURATE);
return post_ops;
}
-
+// To test that the output of the main op is the first parameter in prelu post op
+experimental::PostOpList<PostOpArgBroadcast> post_ops_4()
+{
+ experimental::PostOpList<PostOpArgBroadcast> post_ops{};
+ post_ops.push_back_op<experimental::PostOpAct<PostOpArgBroadcast>>(ActivationLayerInfo{ActivationLayerInfo::ActivationFunction::LINEAR, 0.5F, 0.0F});
+ post_ops.push_back_op<experimental::PostOpEltwisePRelu<PostOpArgBroadcast>>(
+ std::make_tuple(false, false, true), // If true, broadcast in corresponding dim: 0, 1 or 2
+ 0,
+ ConvertPolicy::SATURATE);
+ post_ops.push_back_op<experimental::PostOpAct<PostOpArgBroadcast>>(ActivationLayerInfo{ActivationLayerInfo::ActivationFunction::RELU, 2.1F, 1.3F});
+ return post_ops;
+}
+// To test that the output of the main op is the second parameter in prelu post op i.e. it is the alpha_param
+experimental::PostOpList<PostOpArgBroadcast> post_ops_5()
+{
+ experimental::PostOpList<PostOpArgBroadcast> post_ops{};
+ post_ops.push_back_op<experimental::PostOpAct<PostOpArgBroadcast>>(ActivationLayerInfo{ActivationLayerInfo::ActivationFunction::LINEAR, 0.5F, 0.0F});
+ post_ops.push_back_op<experimental::PostOpEltwisePRelu<PostOpArgBroadcast>>(
+ std::make_tuple(false, false, false), // If true, broadcast in corresponding dim: 0, 1 or 2
+ 1,
+ ConvertPolicy::SATURATE);
+ post_ops.push_back_op<experimental::PostOpAct<PostOpArgBroadcast>>(ActivationLayerInfo{ActivationLayerInfo::ActivationFunction::RELU, 2.1F, 1.3F});
+ return post_ops;
+}
/** Different Post Op Lists */
const auto post_op_lists = framework::dataset::make("post_op_lists", {
post_ops_1(),
post_ops_2(),
post_ops_3(),
+ post_ops_4(),
+ post_ops_5()
} );
bool is_post_op_list_valid(unsigned int m, unsigned int n, unsigned int k, unsigned int batch, DataType data_type, const experimental::PostOpList<ITensorInfo*>& post_ops)
diff --git a/tests/validation/reference/PostOps.cpp b/tests/validation/reference/PostOps.cpp
index 1a8fb990c8..a81b1c1905 100644
--- a/tests/validation/reference/PostOps.cpp
+++ b/tests/validation/reference/PostOps.cpp
@@ -59,6 +59,22 @@ SimpleTensor<T> post_ops(const SimpleTensor<T> &a, experimental::PostOpList<Simp
dst = reference::arithmetic_operation(ArithmeticOperation::ADD, dst, _post_op->_addend, dst, _post_op->_policy);
break;
}
+ case experimental::PostOpType::Eltwise_PRelu:
+ {
+ const auto _post_op = utils::cast::polymorphic_downcast<const experimental::PostOpEltwisePRelu<SimpleTensor<T>> *>(post_op.get());
+
+ // If previous main operation output is the first pRelu argument, then pass it as src1 parameter of the arithmetic operation
+ if(_post_op->_prev_dst_pos == 0)
+ {
+ dst = reference::arithmetic_operation(ArithmeticOperation::PRELU, dst, _post_op->_alpha_param, dst, _post_op->_policy);
+ }
+ // If previous main operation output is the second pRelu argument, then pass it as src2 parameter of the arithmetic operation
+ else if(_post_op->_prev_dst_pos == 1)
+ {
+ dst = reference::arithmetic_operation(ArithmeticOperation::PRELU, _post_op->_alpha_param, dst, dst, _post_op->_policy);
+ }
+ break;
+ }
default:
{
ARM_COMPUTE_ERROR("Unsupported PostOpType");
diff --git a/utils/TypePrinter.h b/utils/TypePrinter.h
index 950d32284a..785b41fc62 100644
--- a/utils/TypePrinter.h
+++ b/utils/TypePrinter.h
@@ -161,6 +161,11 @@ inline ::std::ostream &operator<<(::std::ostream &os, experimental::PostOpType p
os << "Eltwise_Add";
break;
}
+ case experimental::PostOpType::Eltwise_PRelu:
+ {
+ os << "Eltwise_PRelu";
+ break;
+ }
default:
{
ARM_COMPUTE_ERROR("Unsupported PostOpType");
@@ -208,6 +213,12 @@ inline ::std::ostream &operator<<(::std::ostream &os, const experimental::IPostO
os << "convert_policy=" << _post_op->_policy;
break;
}
+ case experimental::PostOpType::Eltwise_PRelu:
+ {
+ const auto _post_op = utils::cast::polymorphic_downcast<const experimental::PostOpEltwisePRelu<T> *>(&post_op);
+ os << "convert_policy=" << _post_op->_policy;
+ break;
+ }
default:
{
ARM_COMPUTE_ERROR("Unsupported PostOpType");