aboutsummaryrefslogtreecommitdiff
path: root/tests
diff options
context:
space:
mode:
authorramelg01 <ramy.elgammal@arm.com>2021-10-29 10:52:53 +0100
committerramy.elgammal <ramy.elgammal@arm.com>2021-11-04 11:10:56 +0000
commit6049edadf0c89a026b3fcd1927ee7531d3c40278 (patch)
treec12fcea637e41cdb9e1f72dc734e4a87d0b31981 /tests
parent71cbd28b7cf5115b0451d43e5c84cce4ae4d8ec7 (diff)
downloadComputeLibrary-6049edadf0c89a026b3fcd1927ee7531d3c40278.tar.gz
Add PRelu to supported PostOps in:
- ClGemmMatrixMultiplyReshapedKernel - ClGemmMatrixMultiplyNativeKernel - ClGemmMatrixMultiplyReshapedOnlyRhsKernel Resolves: COMPMID-4713 Change-Id: I3adcb1b3d4af37ebcbc3bee19cc1845885d08600 Signed-off-by: Ramy Elgammal <ramy.elgammal@arm.com> Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/6553 Reviewed-by: SiCong Li <sicong.li@arm.com> Tested-by: Arm Jenkins <bsgcomp@arm.com> Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'tests')
-rw-r--r--tests/validation/CL/GEMMMatrixMultiplyNative.cpp29
-rw-r--r--tests/validation/CL/GEMMMatrixMultiplyReshaped.cpp28
-rw-r--r--tests/validation/CL/GEMMMatrixMultiplyReshapedOnlyRHS.cpp27
-rw-r--r--tests/validation/reference/PostOps.cpp16
4 files changed, 96 insertions, 4 deletions
diff --git a/tests/validation/CL/GEMMMatrixMultiplyNative.cpp b/tests/validation/CL/GEMMMatrixMultiplyNative.cpp
index e3f151a2ca..54e9d32afc 100644
--- a/tests/validation/CL/GEMMMatrixMultiplyNative.cpp
+++ b/tests/validation/CL/GEMMMatrixMultiplyNative.cpp
@@ -179,13 +179,38 @@ experimental::PostOpList<PostOpArgBroadcast> post_ops_3()
ConvertPolicy::SATURATE);
return post_ops;
}
-
+// To test that the output of the main op is the first parameter in prelu post op
+experimental::PostOpList<PostOpArgBroadcast> post_ops_4()
+{
+ experimental::PostOpList<PostOpArgBroadcast> post_ops{};
+ post_ops.push_back_op<experimental::PostOpAct<PostOpArgBroadcast>>(ActivationLayerInfo{ActivationLayerInfo::ActivationFunction::LINEAR, 0.5F, 0.0F});
+ post_ops.push_back_op<experimental::PostOpEltwisePRelu<PostOpArgBroadcast>>(
+ std::make_tuple(false, false, true), // If true, broadcast in corresponding dim: 0, 1 or 2
+ 0,
+ ConvertPolicy::SATURATE);
+ post_ops.push_back_op<experimental::PostOpAct<PostOpArgBroadcast>>(ActivationLayerInfo{ActivationLayerInfo::ActivationFunction::RELU, 2.1F, 1.3F});
+ return post_ops;
+}
+// To test that the output of the main op is the second parameter in prelu post op i.e. it is the alpha_param
+experimental::PostOpList<PostOpArgBroadcast> post_ops_5()
+{
+ experimental::PostOpList<PostOpArgBroadcast> post_ops{};
+ post_ops.push_back_op<experimental::PostOpAct<PostOpArgBroadcast>>(ActivationLayerInfo{ActivationLayerInfo::ActivationFunction::LINEAR, 0.5F, 0.0F});
+ post_ops.push_back_op<experimental::PostOpEltwisePRelu<PostOpArgBroadcast>>(
+ std::make_tuple(false, false, false), // If true, broadcast in corresponding dim: 0, 1 or 2
+ 1,
+ ConvertPolicy::SATURATE);
+ post_ops.push_back_op<experimental::PostOpAct<PostOpArgBroadcast>>(ActivationLayerInfo{ActivationLayerInfo::ActivationFunction::RELU, 2.1F, 1.3F});
+ return post_ops;
+}
/** Different Post Op Lists */
const auto post_op_lists = framework::dataset::make("post_op_lists", {
post_ops_1(),
post_ops_2(),
post_ops_3(),
-} );
+ post_ops_4(),
+ post_ops_5()
+ } );
bool is_post_op_list_valid(unsigned int m, unsigned int n, unsigned int k, unsigned int batch, DataType data_type, const experimental::PostOpList<ITensorInfo*>& post_ops)
{
diff --git a/tests/validation/CL/GEMMMatrixMultiplyReshaped.cpp b/tests/validation/CL/GEMMMatrixMultiplyReshaped.cpp
index a598780bf6..bedd0f5bfb 100644
--- a/tests/validation/CL/GEMMMatrixMultiplyReshaped.cpp
+++ b/tests/validation/CL/GEMMMatrixMultiplyReshaped.cpp
@@ -216,11 +216,37 @@ experimental::PostOpList<PostOpArgBroadcast> post_ops_3()
ConvertPolicy::SATURATE);
return post_ops;
}
+// To test that the output of the main op is the first parameter in prelu post op
+experimental::PostOpList<PostOpArgBroadcast> post_ops_4()
+{
+ experimental::PostOpList<PostOpArgBroadcast> post_ops{};
+ post_ops.push_back_op<experimental::PostOpAct<PostOpArgBroadcast>>(ActivationLayerInfo{ActivationLayerInfo::ActivationFunction::LINEAR, 0.5F, 0.0F});
+ post_ops.push_back_op<experimental::PostOpEltwisePRelu<PostOpArgBroadcast>>(
+ std::make_tuple(false, false, true), // If true, broadcast in corresponding dim: 0, 1 or 2
+ 0,
+ ConvertPolicy::SATURATE);
+ post_ops.push_back_op<experimental::PostOpAct<PostOpArgBroadcast>>(ActivationLayerInfo{ActivationLayerInfo::ActivationFunction::RELU, 2.1F, 1.3F});
+ return post_ops;
+}
+// To test that the output of the main op is the second parameter in prelu post op i.e. it is the alpha_param
+experimental::PostOpList<PostOpArgBroadcast> post_ops_5()
+{
+ experimental::PostOpList<PostOpArgBroadcast> post_ops{};
+ post_ops.push_back_op<experimental::PostOpAct<PostOpArgBroadcast>>(ActivationLayerInfo{ActivationLayerInfo::ActivationFunction::LINEAR, 0.5F, 0.0F});
+ post_ops.push_back_op<experimental::PostOpEltwisePRelu<PostOpArgBroadcast>>(
+ std::make_tuple(false, false, false), // If true, broadcast in corresponding dim: 0, 1 or 2
+ 1,
+ ConvertPolicy::SATURATE);
+ post_ops.push_back_op<experimental::PostOpAct<PostOpArgBroadcast>>(ActivationLayerInfo{ActivationLayerInfo::ActivationFunction::RELU, 2.1F, 1.3F});
+ return post_ops;
+}
/** Different Post Op Lists */
const auto post_op_lists = framework::dataset::make("post_op_lists", {
post_ops_1(),
post_ops_2(),
post_ops_3(),
+ post_ops_4(),
+ post_ops_5()
} );
bool is_post_op_list_valid(unsigned int m, unsigned int n, unsigned int k, unsigned int batch, DataType data_type, const experimental::PostOpList<ITensorInfo*>& post_ops)
@@ -479,7 +505,7 @@ TEST_CASE(BroadcastInXDimOnly, framework::DatasetMode::ALL)
ARM_COMPUTE_EXPECT(is_post_op_list_valid(m, n, k, batch, data_type, post_ops) == false, framework::LogLevel::ERRORS);
}
-TEST_SUITE_END() // Invalid
+TEST_SUITE_END() // Invalid
TEST_SUITE(Valid)
TEST_CASE(EmptyPostOpList, framework::DatasetMode::ALL)
{
diff --git a/tests/validation/CL/GEMMMatrixMultiplyReshapedOnlyRHS.cpp b/tests/validation/CL/GEMMMatrixMultiplyReshapedOnlyRHS.cpp
index ca8b21cd0d..4c482b49aa 100644
--- a/tests/validation/CL/GEMMMatrixMultiplyReshapedOnlyRHS.cpp
+++ b/tests/validation/CL/GEMMMatrixMultiplyReshapedOnlyRHS.cpp
@@ -196,12 +196,37 @@ experimental::PostOpList<PostOpArgBroadcast> post_ops_3()
ConvertPolicy::SATURATE);
return post_ops;
}
-
+// To test that the output of the main op is the first parameter in prelu post op
+experimental::PostOpList<PostOpArgBroadcast> post_ops_4()
+{
+ experimental::PostOpList<PostOpArgBroadcast> post_ops{};
+ post_ops.push_back_op<experimental::PostOpAct<PostOpArgBroadcast>>(ActivationLayerInfo{ActivationLayerInfo::ActivationFunction::LINEAR, 0.5F, 0.0F});
+ post_ops.push_back_op<experimental::PostOpEltwisePRelu<PostOpArgBroadcast>>(
+ std::make_tuple(false, false, true), // If true, broadcast in corresponding dim: 0, 1 or 2
+ 0,
+ ConvertPolicy::SATURATE);
+ post_ops.push_back_op<experimental::PostOpAct<PostOpArgBroadcast>>(ActivationLayerInfo{ActivationLayerInfo::ActivationFunction::RELU, 2.1F, 1.3F});
+ return post_ops;
+}
+// To test that the output of the main op is the second parameter in prelu post op i.e. it is the alpha_param
+experimental::PostOpList<PostOpArgBroadcast> post_ops_5()
+{
+ experimental::PostOpList<PostOpArgBroadcast> post_ops{};
+ post_ops.push_back_op<experimental::PostOpAct<PostOpArgBroadcast>>(ActivationLayerInfo{ActivationLayerInfo::ActivationFunction::LINEAR, 0.5F, 0.0F});
+ post_ops.push_back_op<experimental::PostOpEltwisePRelu<PostOpArgBroadcast>>(
+ std::make_tuple(false, false, false), // If true, broadcast in corresponding dim: 0, 1 or 2
+ 1,
+ ConvertPolicy::SATURATE);
+ post_ops.push_back_op<experimental::PostOpAct<PostOpArgBroadcast>>(ActivationLayerInfo{ActivationLayerInfo::ActivationFunction::RELU, 2.1F, 1.3F});
+ return post_ops;
+}
/** Different Post Op Lists */
const auto post_op_lists = framework::dataset::make("post_op_lists", {
post_ops_1(),
post_ops_2(),
post_ops_3(),
+ post_ops_4(),
+ post_ops_5()
} );
bool is_post_op_list_valid(unsigned int m, unsigned int n, unsigned int k, unsigned int batch, DataType data_type, const experimental::PostOpList<ITensorInfo*>& post_ops)
diff --git a/tests/validation/reference/PostOps.cpp b/tests/validation/reference/PostOps.cpp
index 1a8fb990c8..a81b1c1905 100644
--- a/tests/validation/reference/PostOps.cpp
+++ b/tests/validation/reference/PostOps.cpp
@@ -59,6 +59,22 @@ SimpleTensor<T> post_ops(const SimpleTensor<T> &a, experimental::PostOpList<Simp
dst = reference::arithmetic_operation(ArithmeticOperation::ADD, dst, _post_op->_addend, dst, _post_op->_policy);
break;
}
+ case experimental::PostOpType::Eltwise_PRelu:
+ {
+ const auto _post_op = utils::cast::polymorphic_downcast<const experimental::PostOpEltwisePRelu<SimpleTensor<T>> *>(post_op.get());
+
+ // If previous main operation output is the first pRelu argument, then pass it as src1 parameter of the arithmetic operation
+ if(_post_op->_prev_dst_pos == 0)
+ {
+ dst = reference::arithmetic_operation(ArithmeticOperation::PRELU, dst, _post_op->_alpha_param, dst, _post_op->_policy);
+ }
+ // If previous main operation output is the second pRelu argument, then pass it as src2 parameter of the arithmetic operation
+ else if(_post_op->_prev_dst_pos == 1)
+ {
+ dst = reference::arithmetic_operation(ArithmeticOperation::PRELU, _post_op->_alpha_param, dst, dst, _post_op->_policy);
+ }
+ break;
+ }
default:
{
ARM_COMPUTE_ERROR("Unsupported PostOpType");