From 6049edadf0c89a026b3fcd1927ee7531d3c40278 Mon Sep 17 00:00:00 2001
From: ramelg01
Date: Fri, 29 Oct 2021 10:52:53 +0100
Subject: Add PRelu to supported PostOps in:
 - ClGemmMatrixMultiplyReshapedKernel
 - ClGemmMatrixMultiplyNativeKernel
 - ClGemmMatrixMultiplyReshapedOnlyRhsKernel

Resolves: COMPMID-4713
Change-Id: I3adcb1b3d4af37ebcbc3bee19cc1845885d08600
Signed-off-by: Ramy Elgammal
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/6553
Reviewed-by: SiCong Li
Tested-by: Arm Jenkins
Comments-Addressed: Arm Jenkins
---
 tests/validation/CL/GEMMMatrixMultiplyNative.cpp   | 29 ++++++++++++++++++++--
 tests/validation/CL/GEMMMatrixMultiplyReshaped.cpp | 28 ++++++++++++++++++++-
 .../CL/GEMMMatrixMultiplyReshapedOnlyRHS.cpp       | 27 +++++++++++++++++++++++++-
 tests/validation/reference/PostOps.cpp             | 16 ++++++++++++
 4 files changed, 96 insertions(+), 4 deletions(-)

(limited to 'tests/validation')

diff --git a/tests/validation/CL/GEMMMatrixMultiplyNative.cpp b/tests/validation/CL/GEMMMatrixMultiplyNative.cpp
index e3f151a2ca..54e9d32afc 100644
--- a/tests/validation/CL/GEMMMatrixMultiplyNative.cpp
+++ b/tests/validation/CL/GEMMMatrixMultiplyNative.cpp
@@ -179,13 +179,38 @@ experimental::PostOpList<ITensorInfo*> post_ops_3()
         ConvertPolicy::SATURATE);
     return post_ops;
 }
-
+// To test that the output of the main op is the first parameter in the prelu post op
+experimental::PostOpList<ITensorInfo*> post_ops_4()
+{
+    experimental::PostOpList<ITensorInfo*> post_ops{};
+    post_ops.push_back_op<experimental::PostOpAct<ITensorInfo*>>(ActivationLayerInfo{ActivationLayerInfo::ActivationFunction::LINEAR, 0.5F, 0.0F});
+    post_ops.push_back_op<experimental::PostOpEltwisePRelu<ITensorInfo*>>(
+        std::make_tuple(false, false, true), // If true, broadcast in corresponding dim: 0, 1 or 2
+        0,
+        ConvertPolicy::SATURATE);
+    post_ops.push_back_op<experimental::PostOpAct<ITensorInfo*>>(ActivationLayerInfo{ActivationLayerInfo::ActivationFunction::RELU, 2.1F, 1.3F});
+    return post_ops;
+}
+// To test that the output of the main op is the second parameter in the prelu post op, i.e. it is the alpha_param
+experimental::PostOpList<ITensorInfo*> post_ops_5()
+{
+    experimental::PostOpList<ITensorInfo*> post_ops{};
+    post_ops.push_back_op<experimental::PostOpAct<ITensorInfo*>>(ActivationLayerInfo{ActivationLayerInfo::ActivationFunction::LINEAR, 0.5F, 0.0F});
+    post_ops.push_back_op<experimental::PostOpEltwisePRelu<ITensorInfo*>>(
+        std::make_tuple(false, false, false), // If true, broadcast in corresponding dim: 0, 1 or 2
+        1,
+        ConvertPolicy::SATURATE);
+    post_ops.push_back_op<experimental::PostOpAct<ITensorInfo*>>(ActivationLayerInfo{ActivationLayerInfo::ActivationFunction::RELU, 2.1F, 1.3F});
+    return post_ops;
+}
 /** Different Post Op Lists */
 const auto post_op_lists = framework::dataset::make("post_op_lists", {
     post_ops_1(),
     post_ops_2(),
     post_ops_3(),
-} );
+    post_ops_4(),
+    post_ops_5()
+ } );
 
 bool is_post_op_list_valid(unsigned int m, unsigned int n, unsigned int k, unsigned int batch, DataType data_type, const experimental::PostOpList<ITensorInfo*> &post_ops)
 {
diff --git a/tests/validation/CL/GEMMMatrixMultiplyReshaped.cpp b/tests/validation/CL/GEMMMatrixMultiplyReshaped.cpp
index a598780bf6..bedd0f5bfb 100644
--- a/tests/validation/CL/GEMMMatrixMultiplyReshaped.cpp
+++ b/tests/validation/CL/GEMMMatrixMultiplyReshaped.cpp
@@ -216,11 +216,37 @@ experimental::PostOpList<ITensorInfo*> post_ops_3()
         ConvertPolicy::SATURATE);
     return post_ops;
 }
+// To test that the output of the main op is the first parameter in the prelu post op
+experimental::PostOpList<ITensorInfo*> post_ops_4()
+{
+    experimental::PostOpList<ITensorInfo*> post_ops{};
+    post_ops.push_back_op<experimental::PostOpAct<ITensorInfo*>>(ActivationLayerInfo{ActivationLayerInfo::ActivationFunction::LINEAR, 0.5F, 0.0F});
+    post_ops.push_back_op<experimental::PostOpEltwisePRelu<ITensorInfo*>>(
+        std::make_tuple(false, false, true), // If true, broadcast in corresponding dim: 0, 1 or 2
+        0,
+        ConvertPolicy::SATURATE);
+    post_ops.push_back_op<experimental::PostOpAct<ITensorInfo*>>(ActivationLayerInfo{ActivationLayerInfo::ActivationFunction::RELU, 2.1F, 1.3F});
+    return post_ops;
+}
+// To test that the output of the main op is the second parameter in the prelu post op, i.e. it is the alpha_param
+experimental::PostOpList<ITensorInfo*> post_ops_5()
+{
+    experimental::PostOpList<ITensorInfo*> post_ops{};
+    post_ops.push_back_op<experimental::PostOpAct<ITensorInfo*>>(ActivationLayerInfo{ActivationLayerInfo::ActivationFunction::LINEAR, 0.5F, 0.0F});
+    post_ops.push_back_op<experimental::PostOpEltwisePRelu<ITensorInfo*>>(
+        std::make_tuple(false, false, false), // If true, broadcast in corresponding dim: 0, 1 or 2
+        1,
+        ConvertPolicy::SATURATE);
+    post_ops.push_back_op<experimental::PostOpAct<ITensorInfo*>>(ActivationLayerInfo{ActivationLayerInfo::ActivationFunction::RELU, 2.1F, 1.3F});
+    return post_ops;
+}
 /** Different Post Op Lists */
 const auto post_op_lists = framework::dataset::make("post_op_lists", {
     post_ops_1(),
     post_ops_2(),
     post_ops_3(),
+    post_ops_4(),
+    post_ops_5()
 } );
 
 bool is_post_op_list_valid(unsigned int m, unsigned int n, unsigned int k, unsigned int batch, DataType data_type, const experimental::PostOpList<ITensorInfo*> &post_ops)
@@ -479,7 +505,7 @@ TEST_CASE(BroadcastInXDimOnly, framework::DatasetMode::ALL)
 
     ARM_COMPUTE_EXPECT(is_post_op_list_valid(m, n, k, batch, data_type, post_ops) == false, framework::LogLevel::ERRORS);
 }
-TEST_SUITE_END() // Invalid 
+TEST_SUITE_END() // Invalid
 TEST_SUITE(Valid)
 TEST_CASE(EmptyPostOpList, framework::DatasetMode::ALL)
 {
diff --git a/tests/validation/CL/GEMMMatrixMultiplyReshapedOnlyRHS.cpp b/tests/validation/CL/GEMMMatrixMultiplyReshapedOnlyRHS.cpp
index ca8b21cd0d..4c482b49aa 100644
--- a/tests/validation/CL/GEMMMatrixMultiplyReshapedOnlyRHS.cpp
+++ b/tests/validation/CL/GEMMMatrixMultiplyReshapedOnlyRHS.cpp
@@ -196,12 +196,37 @@ experimental::PostOpList<ITensorInfo*> post_ops_3()
         ConvertPolicy::SATURATE);
     return post_ops;
 }
-
+// To test that the output of the main op is the first parameter in the prelu post op
+experimental::PostOpList<ITensorInfo*> post_ops_4()
+{
+    experimental::PostOpList<ITensorInfo*> post_ops{};
+    post_ops.push_back_op<experimental::PostOpAct<ITensorInfo*>>(ActivationLayerInfo{ActivationLayerInfo::ActivationFunction::LINEAR, 0.5F, 0.0F});
+    post_ops.push_back_op<experimental::PostOpEltwisePRelu<ITensorInfo*>>(
+        std::make_tuple(false, false, true), // If true, broadcast in corresponding dim: 0, 1 or 2
+        0,
+        ConvertPolicy::SATURATE);
+    post_ops.push_back_op<experimental::PostOpAct<ITensorInfo*>>(ActivationLayerInfo{ActivationLayerInfo::ActivationFunction::RELU, 2.1F, 1.3F});
+    return post_ops;
+}
+// To test that the output of the main op is the second parameter in the prelu post op, i.e. it is the alpha_param
+experimental::PostOpList<ITensorInfo*> post_ops_5()
+{
+    experimental::PostOpList<ITensorInfo*> post_ops{};
+    post_ops.push_back_op<experimental::PostOpAct<ITensorInfo*>>(ActivationLayerInfo{ActivationLayerInfo::ActivationFunction::LINEAR, 0.5F, 0.0F});
+    post_ops.push_back_op<experimental::PostOpEltwisePRelu<ITensorInfo*>>(
+        std::make_tuple(false, false, false), // If true, broadcast in corresponding dim: 0, 1 or 2
+        1,
+        ConvertPolicy::SATURATE);
+    post_ops.push_back_op<experimental::PostOpAct<ITensorInfo*>>(ActivationLayerInfo{ActivationLayerInfo::ActivationFunction::RELU, 2.1F, 1.3F});
+    return post_ops;
+}
 /** Different Post Op Lists */
 const auto post_op_lists = framework::dataset::make("post_op_lists", {
     post_ops_1(),
     post_ops_2(),
     post_ops_3(),
+    post_ops_4(),
+    post_ops_5()
 } );
 
 bool is_post_op_list_valid(unsigned int m, unsigned int n, unsigned int k, unsigned int batch, DataType data_type, const experimental::PostOpList<ITensorInfo*> &post_ops)
diff --git a/tests/validation/reference/PostOps.cpp b/tests/validation/reference/PostOps.cpp
index 1a8fb990c8..a81b1c1905 100644
--- a/tests/validation/reference/PostOps.cpp
+++ b/tests/validation/reference/PostOps.cpp
@@ -59,6 +59,22 @@ SimpleTensor<T> post_ops(const SimpleTensor<T> &a, experimental::PostOpList<SimpleTensor<T>> post_ops)
                 dst = reference::arithmetic_operation(ArithmeticOperation::ADD, _post_op->_addend, dst, dst, _post_op->_policy);
                 break;
             }
+            case experimental::PostOpType::Eltwise_PRelu:
+            {
+                const auto _post_op = utils::cast::polymorphic_downcast<const experimental::PostOpEltwisePRelu<SimpleTensor<T>> *>(post_op.get());
+
+                // If the previous main operation's output is the first pRelu argument, pass it as the src1 parameter of the arithmetic operation
+                if(_post_op->_prev_dst_pos == 0)
+                {
+                    dst = reference::arithmetic_operation(ArithmeticOperation::PRELU, dst, _post_op->_alpha_param, dst, _post_op->_policy);
+                }
+                // If the previous main operation's output is the second pRelu argument, pass it as the src2 parameter of the arithmetic operation
+                else if(_post_op->_prev_dst_pos == 1)
+                {
+                    dst = reference::arithmetic_operation(ArithmeticOperation::PRELU, _post_op->_alpha_param, dst, dst, _post_op->_policy);
+                }
+                break;
+            }
             default:
             {
                 ARM_COMPUTE_ERROR("Unsupported PostOpType");
--
cgit v1.2.1
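
Note on the argument ordering being tested: post_ops_4() and post_ops_5() differ only in the position index handed to the PRelu post op (0 vs 1), which selects whether the running output of the main GEMM feeds the value input or the alpha (slope) input of PRELU(x, alpha) = x >= 0 ? x : alpha * x. Below is a minimal standalone sketch of that dispatch, mirroring the if/else added to the reference in PostOps.cpp. It uses plain std::vector and hypothetical helper names (prelu, apply_prelu_post_op) rather than the Compute Library API, and it ignores the broadcast flags that the tests pass via std::make_tuple.

#include <cstddef>
#include <iostream>
#include <vector>

// Elementwise PRELU: 'x' is the value operand, 'alpha' the per-element slope.
static std::vector<float> prelu(const std::vector<float> &x, const std::vector<float> &alpha)
{
    std::vector<float> out(x.size());
    for(std::size_t i = 0; i < x.size(); ++i)
    {
        out[i] = (x[i] >= 0.0f) ? x[i] : alpha[i] * x[i];
    }
    return out;
}

// Mirrors the reference dispatch: prev_dst_pos picks which of the two PRELU
// arguments the previous operation's output occupies.
static std::vector<float> apply_prelu_post_op(const std::vector<float> &prev_dst,
                                              const std::vector<float> &param,
                                              int prev_dst_pos)
{
    // prev_dst_pos == 0: previous output is the first argument (the value),
    //                    so 'param' supplies the slopes.
    // prev_dst_pos == 1: previous output is the second argument, i.e. it acts
    //                    as the alpha_param and 'param' is the value.
    return (prev_dst_pos == 0) ? prelu(prev_dst, param) : prelu(param, prev_dst);
}

int main()
{
    const std::vector<float> dst{ -2.0f, 3.0f };    // stand-in for the GEMM output
    const std::vector<float> param{ -0.5f, 0.25f }; // tensor held by the post op

    for(float v : apply_prelu_post_op(dst, param, 0))
    {
        std::cout << v << ' '; // prints: 1 3
    }
    std::cout << '\n';
    for(float v : apply_prelu_post_op(dst, param, 1))
    {
        std::cout << v << ' '; // prints: 1 0.25
    }
    std::cout << '\n';
    return 0;
}

The asymmetry is why both fixtures are needed: swapping the operands changes the result whenever either tensor holds negative values, so a kernel that wired the previous output to the wrong PRELU slot could still pass a test suite that only ever exercised position 0.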