aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/CL/GEMMMatrixMultiplyReshaped.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'tests/validation/CL/GEMMMatrixMultiplyReshaped.cpp')
-rw-r--r--tests/validation/CL/GEMMMatrixMultiplyReshaped.cpp28
1 files changed, 27 insertions, 1 deletions
diff --git a/tests/validation/CL/GEMMMatrixMultiplyReshaped.cpp b/tests/validation/CL/GEMMMatrixMultiplyReshaped.cpp
index a598780bf6..bedd0f5bfb 100644
--- a/tests/validation/CL/GEMMMatrixMultiplyReshaped.cpp
+++ b/tests/validation/CL/GEMMMatrixMultiplyReshaped.cpp
@@ -216,11 +216,37 @@ experimental::PostOpList<PostOpArgBroadcast> post_ops_3()
ConvertPolicy::SATURATE);
return post_ops;
}
+// To test that the output of the main op is the first parameter in prelu post op
+experimental::PostOpList<PostOpArgBroadcast> post_ops_4()
+{
+ experimental::PostOpList<PostOpArgBroadcast> post_ops{};
+ post_ops.push_back_op<experimental::PostOpAct<PostOpArgBroadcast>>(ActivationLayerInfo{ActivationLayerInfo::ActivationFunction::LINEAR, 0.5F, 0.0F});
+ post_ops.push_back_op<experimental::PostOpEltwisePRelu<PostOpArgBroadcast>>(
+ std::make_tuple(false, false, true), // If true, broadcast in corresponding dim: 0, 1 or 2
+ 0,
+ ConvertPolicy::SATURATE);
+ post_ops.push_back_op<experimental::PostOpAct<PostOpArgBroadcast>>(ActivationLayerInfo{ActivationLayerInfo::ActivationFunction::RELU, 2.1F, 1.3F});
+ return post_ops;
+}
+// To test that the output of the main op is the second parameter in prelu post op i.e. it is the alpha_param
+experimental::PostOpList<PostOpArgBroadcast> post_ops_5()
+{
+ experimental::PostOpList<PostOpArgBroadcast> post_ops{};
+ post_ops.push_back_op<experimental::PostOpAct<PostOpArgBroadcast>>(ActivationLayerInfo{ActivationLayerInfo::ActivationFunction::LINEAR, 0.5F, 0.0F});
+ post_ops.push_back_op<experimental::PostOpEltwisePRelu<PostOpArgBroadcast>>(
+ std::make_tuple(false, false, false), // If true, broadcast in corresponding dim: 0, 1 or 2
+ 1,
+ ConvertPolicy::SATURATE);
+ post_ops.push_back_op<experimental::PostOpAct<PostOpArgBroadcast>>(ActivationLayerInfo{ActivationLayerInfo::ActivationFunction::RELU, 2.1F, 1.3F});
+ return post_ops;
+}
/** Different Post Op Lists */
const auto post_op_lists = framework::dataset::make("post_op_lists", {
post_ops_1(),
post_ops_2(),
post_ops_3(),
+ post_ops_4(),
+ post_ops_5()
} );
bool is_post_op_list_valid(unsigned int m, unsigned int n, unsigned int k, unsigned int batch, DataType data_type, const experimental::PostOpList<ITensorInfo*>& post_ops)
@@ -479,7 +505,7 @@ TEST_CASE(BroadcastInXDimOnly, framework::DatasetMode::ALL)
ARM_COMPUTE_EXPECT(is_post_op_list_valid(m, n, k, batch, data_type, post_ops) == false, framework::LogLevel::ERRORS);
}
-TEST_SUITE_END() // Invalid
+TEST_SUITE_END() // Invalid
TEST_SUITE(Valid)
TEST_CASE(EmptyPostOpList, framework::DatasetMode::ALL)
{