aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/CL/GEMMMatrixMultiplyReshapedOnlyRHS.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'tests/validation/CL/GEMMMatrixMultiplyReshapedOnlyRHS.cpp')
-rw-r--r--tests/validation/CL/GEMMMatrixMultiplyReshapedOnlyRHS.cpp27
1 files changed, 26 insertions, 1 deletions
diff --git a/tests/validation/CL/GEMMMatrixMultiplyReshapedOnlyRHS.cpp b/tests/validation/CL/GEMMMatrixMultiplyReshapedOnlyRHS.cpp
index ca8b21cd0d..4c482b49aa 100644
--- a/tests/validation/CL/GEMMMatrixMultiplyReshapedOnlyRHS.cpp
+++ b/tests/validation/CL/GEMMMatrixMultiplyReshapedOnlyRHS.cpp
@@ -196,12 +196,37 @@ experimental::PostOpList<PostOpArgBroadcast> post_ops_3()
ConvertPolicy::SATURATE);
return post_ops;
}
-
+// To test that the output of the main op is the first parameter in prelu post op
+experimental::PostOpList<PostOpArgBroadcast> post_ops_4()
+{
+ experimental::PostOpList<PostOpArgBroadcast> post_ops{};
+ post_ops.push_back_op<experimental::PostOpAct<PostOpArgBroadcast>>(ActivationLayerInfo{ActivationLayerInfo::ActivationFunction::LINEAR, 0.5F, 0.0F});
+ post_ops.push_back_op<experimental::PostOpEltwisePRelu<PostOpArgBroadcast>>(
+ std::make_tuple(false, false, true), // If true, broadcast in corresponding dim: 0, 1 or 2
+ 0,
+ ConvertPolicy::SATURATE);
+ post_ops.push_back_op<experimental::PostOpAct<PostOpArgBroadcast>>(ActivationLayerInfo{ActivationLayerInfo::ActivationFunction::RELU, 2.1F, 1.3F});
+ return post_ops;
+}
+// To test that the output of the main op is the second parameter in prelu post op i.e. it is the alpha_param
+experimental::PostOpList<PostOpArgBroadcast> post_ops_5()
+{
+ experimental::PostOpList<PostOpArgBroadcast> post_ops{};
+ post_ops.push_back_op<experimental::PostOpAct<PostOpArgBroadcast>>(ActivationLayerInfo{ActivationLayerInfo::ActivationFunction::LINEAR, 0.5F, 0.0F});
+ post_ops.push_back_op<experimental::PostOpEltwisePRelu<PostOpArgBroadcast>>(
+ std::make_tuple(false, false, false), // If true, broadcast in corresponding dim: 0, 1 or 2
+ 1,
+ ConvertPolicy::SATURATE);
+ post_ops.push_back_op<experimental::PostOpAct<PostOpArgBroadcast>>(ActivationLayerInfo{ActivationLayerInfo::ActivationFunction::RELU, 2.1F, 1.3F});
+ return post_ops;
+}
/** Different Post Op Lists */
const auto post_op_lists = framework::dataset::make("post_op_lists", {
post_ops_1(),
post_ops_2(),
post_ops_3(),
+ post_ops_4(),
+ post_ops_5()
} );
bool is_post_op_list_valid(unsigned int m, unsigned int n, unsigned int k, unsigned int batch, DataType data_type, const experimental::PostOpList<ITensorInfo*>& post_ops)