From eb8bd81a625f0f87080dbde55b434362ad57324a Mon Sep 17 00:00:00 2001
From: SiCongLi
Date: Fri, 29 Oct 2021 15:05:49 +0100
Subject: Fix dst "widening" validation

* Auto-initialize the dst tensor before checking for PostOp shape
  compliance, so that we catch the invalid case of a post op argument
  "widening" the dst tensor shape

* Rework the post op validate test cases to be more readable

Partially resolves: COMPMID-4435
Change-Id: I79943994182942f962e4d59a7fa0d6f017ae9ac7
Signed-off-by: SiCongLi
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/6548
Reviewed-by: Gian Marco Iodice
Comments-Addressed: Arm Jenkins
Tested-by: Arm Jenkins
---
 tests/validation/CL/GEMMMatrixMultiplyReshaped.cpp | 272 +++++++++++----------
 1 file changed, 146 insertions(+), 126 deletions(-)

(limited to 'tests/validation/CL/GEMMMatrixMultiplyReshaped.cpp')
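Because this view is limited to the test file, the kernel-side change from the first bullet point is not shown below. A minimal sketch of the idea, assuming the library's public `misc::shape_calculator::compute_mm_shape` helper and its internal `auto_init_if_empty` helper; `validate_with_post_ops` and `validate_post_op_shapes` are hypothetical stand-ins, not the functions actually patched:

    #include "arm_compute/core/KernelDescriptors.h"
    #include "arm_compute/core/TensorInfo.h"
    #include "arm_compute/core/utils/misc/ShapeCalculator.h"
    #include "src/core/helpers/AutoConfiguration.h" // internal helper header (assumption)

    using namespace arm_compute;

    // Hypothetical stand-in for the kernel's internal post op shape check
    Status validate_post_op_shapes(const ITensorInfo &dst, const GEMMKernelInfo &gemm_info);

    Status validate_with_post_ops(const ITensorInfo *src0, const ITensorInfo *src1,
                                  const ITensorInfo *dst, const GEMMKernelInfo &gemm_info)
    {
        // Infer the GEMM output shape from the inputs...
        const TensorShape out_shape = misc::shape_calculator::compute_mm_shape(*src0, *src1, gemm_info);

        // ...and auto-initialize a local copy of dst with it, so that a dst whose
        // shape is still to be inferred has a concrete shape *before* the post op
        // arguments are checked against it.
        auto dst_to_validate = dst->clone();
        auto_init_if_empty(*dst_to_validate, dst->clone()->set_tensor_shape(out_shape));

        // A post op argument that exceeds dst in any non-broadcast dimension
        // ("widening" dst) is now caught here instead of slipping through.
        return validate_post_op_shapes(*dst_to_validate, gemm_info);
    }

The point is the ordering: previously dst could still be empty at validation time, so there was nothing concrete to compare the post op argument shapes against, and an argument larger than the eventual dst could pass.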
diff --git a/tests/validation/CL/GEMMMatrixMultiplyReshaped.cpp b/tests/validation/CL/GEMMMatrixMultiplyReshaped.cpp
index b13c380470..a598780bf6 100644
--- a/tests/validation/CL/GEMMMatrixMultiplyReshaped.cpp
+++ b/tests/validation/CL/GEMMMatrixMultiplyReshaped.cpp
@@ -185,11 +185,6 @@ const auto lhs_transpose_values = framework::dataset::make("lhs_transpose", { fa
 
 /** Post Ops */
 using PostOpArgBroadcast = CLGEMMMatrixMultiplyReshapedWithPostOpsFixture<float>::PostOpArgBroadcast;
-experimental::PostOpList<PostOpArgBroadcast> empty_post_ops()
-{
-    return experimental::PostOpList<PostOpArgBroadcast>{};
-}
-
 experimental::PostOpList<PostOpArgBroadcast> post_ops_1()
 {
     experimental::PostOpList<PostOpArgBroadcast> post_ops{};
@@ -221,20 +216,6 @@ experimental::PostOpList<PostOpArgBroadcast> post_ops_3()
                                                             ConvertPolicy::SATURATE);
     return post_ops;
 }
-experimental::PostOpList<PostOpArgBroadcast> invalid_post_ops_1()
-{
-    experimental::PostOpList<PostOpArgBroadcast> post_ops{};
-    post_ops.push_back_op<experimental::PostOpEltwiseAdd<PostOpArgBroadcast>>(
-        std::make_tuple(true, true, false),  // If broadcast in dims 0, 1 and 2
-        1,
-        ConvertPolicy::SATURATE);
-    post_ops.push_back_op<experimental::PostOpEltwiseAdd<PostOpArgBroadcast>>(
-        std::make_tuple(false, true, false), // If broadcast in dims 0, 1 and 2
-        0,
-        ConvertPolicy::SATURATE);
-    return post_ops;
-}
-
 /** Different Post Op Lists */
 const auto post_op_lists = framework::dataset::make("post_op_lists", {
     post_ops_1(),
@@ -242,6 +223,42 @@ const auto post_op_lists = framework::dataset::make("post_op_lists", {
     post_ops_3(),
 } );
 
+bool is_post_op_list_valid(unsigned int m, unsigned int n, unsigned int k, unsigned int batch, DataType data_type, const experimental::PostOpList<ITensorInfo *> &post_ops)
+{
+    const auto lhs_info = GEMMLHSMatrixInfo(4, 4, 1, false, true);
+    const auto rhs_info = GEMMRHSMatrixInfo(4, 4, 1, true, true, false);
+
+    // Create TensorInfo for post op arguments
+    TensorInfo input0_info(TensorShape(k, m, batch), 1, data_type);
+    TensorInfo input1_info(TensorShape(n, k, batch), 1, data_type);
+    TensorInfo input2_info(TensorShape(n), 1, data_type);
+    TensorInfo output_info(TensorShape(n, m, batch), 1, data_type);
+
+    const TensorInfo reshaped_input0_info = input0_info.clone()->set_tensor_shape(misc::shape_calculator::compute_lhs_reshaped_shape(input0_info, lhs_info));
+    const TensorInfo reshaped_input1_info = input1_info.clone()->set_tensor_shape(misc::shape_calculator::compute_rhs_reshaped_shape(input1_info, rhs_info));
+
+    GEMMKernelInfo gemm_info(m, n, k, 0 /**< Depth of the output tensor in case is reinterpreted as 3D */,
+                             false /**< reinterpret the input as 3D */,
+                             true /**< Flag used to broadcast the bias addition */,
+                             false /**< wider accumm */,
+                             false /**< has pad y */,
+                             ActivationLayerInfo::ActivationFunction::IDENTITY,
+                             1 /**< Multiplication factor for the width of the 1xW transposed block */,
+                             1 /**< Multiplication factor for the height of the 4x4 interleaved block */,
+                             lhs_info,
+                             rhs_info,
+                             0 /**< Offset to be added to each element of the matrix A */,
+                             0 /**< Offset to be added to each element of the matrix B */,
+                             post_ops);
+    return bool(ClGemmMatrixMultiplyReshapedKernel::validate(&reshaped_input0_info.clone()->set_is_resizable(true),
+                                                             &reshaped_input1_info.clone()->set_is_resizable(true),
+                                                             &input2_info.clone()->set_is_resizable(true),
+                                                             &output_info.clone()->set_is_resizable(true), 1.f, 1.f,
+                                                             lhs_info,
+                                                             rhs_info,
+                                                             gemm_info));
+}
+
 } // namespace
 
 TEST_SUITE(CL)
@@ -406,116 +423,119 @@ DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(zip(zip(zip(zi
                                                            rhs_info,
                                                            gemm_info)) == expected, framework::LogLevel::ERRORS);
 }
-DATA_TEST_CASE(ValidateFusedPosOps, framework::DatasetMode::ALL, zip(zip(zip(zip(zip(zip(zip(zip(
-               framework::dataset::make("Input0Info", { TensorInfo(TensorShape(64U, 5U, 2U), 1, DataType::F32), // OK. Empty post ops
-                                                        TensorInfo(TensorShape(64U, 5U, 2U), 1, DataType::F32), // Invalid post op sequences
-                                                        TensorInfo(TensorShape(64U, 5U, 2U), 1, DataType::F32), // OK. Supported post ops
-
-}),
-               framework::dataset::make("Input1Info",{ TensorInfo(TensorShape(64U, 6U, 2U), 1, DataType::F32),
-                                                       TensorInfo(TensorShape(64U, 6U, 2U), 1, DataType::F32),
-                                                       TensorInfo(TensorShape(64U, 6U, 2U), 1, DataType::F32),
-
-})),
-               framework::dataset::make("Input2Info", { TensorInfo(TensorShape(21U), 1, DataType::F32),
-                                                        TensorInfo(TensorShape(21U), 1, DataType::F32),
-                                                        TensorInfo(TensorShape(21U), 1, DataType::F32),
-
-})),
-               framework::dataset::make("OutputInfo",{ TensorInfo(TensorShape(21U,17U,2U), 1, DataType::F32),
-                                                       TensorInfo(TensorShape(21U,17U,2U), 1, DataType::F32),
-                                                       TensorInfo(TensorShape(21U,17U,2U), 1, DataType::F32),
-
-})),
-               framework::dataset::make("LHSMInfo",{
-                                                     GEMMLHSMatrixInfo(4,4,1,false,true),
-                                                     GEMMLHSMatrixInfo(4,4,1,false,true),
-                                                     GEMMLHSMatrixInfo(4,4,1,false,true),
-
-})),
-               framework::dataset::make("RHSMInfo",{
-                                                     GEMMRHSMatrixInfo(4,4,1,true,true,false),
-                                                     GEMMRHSMatrixInfo(4,4,1,true,true,false),
-                                                     GEMMRHSMatrixInfo(4,4,1,true,true,false),
-
-})),
-               framework::dataset::make("GEMMInfo",{
-                                                     GEMMKernelInfo( 17 /**< [...]
-                                                     GEMMKernelInfo( 17 /**< [...]
-                                                     [...]
-    std::vector<TensorInfo> post_op_tensor_infos;
-    auto populated_post_ops = experimental::transform_post_op_list_arguments<PostOpArgBroadcast, ITensorInfo *>(post_ops,
-                              [&output_info, &post_op_tensor_infos](auto broadcast){
-                                  post_op_tensor_infos.emplace_back(TensorShape{
-                                      std::get<0>(broadcast) ? 1 : output_info.dimension(0),
-                                      std::get<1>(broadcast) ? 1 : output_info.dimension(1),
-                                      std::get<2>(broadcast) ? 1 : output_info.dimension(2)
-                                  }, 1, output_info.data_type());
-                                  return &post_op_tensor_infos.back();
-                              });
-    GEMMKernelInfo gemm_info_with_post_ops(std::move(gemm_info));
-    gemm_info_with_post_ops.post_ops = populated_post_ops;
-    ARM_COMPUTE_EXPECT(bool(ClGemmMatrixMultiplyReshapedKernel::validate(&input0_info.clone()->set_is_resizable(true),
-                                                                         &input1_info.clone()->set_is_resizable(true),
-                                                                         &input2_info.clone()->set_is_resizable(true),
-                                                                         &output_info.clone()->set_is_resizable(true), 1.f, 1.f,
-                                                                         lhs_info,
-                                                                         rhs_info,
-                                                                         gemm_info_with_post_ops)) == expected, framework::LogLevel::ERRORS);
-}
+TEST_SUITE(ValidateFusedPostOps)
+TEST_SUITE(Invalid)
+TEST_CASE(UnsupportedPostOpSequence, framework::DatasetMode::ALL)
+{
+    const auto data_type = DataType::F32;
+    const unsigned int m = 17;
+    const unsigned int n = 1;
+    const unsigned int k = 13;
+    const unsigned int batch = 2;
+    TensorShape post_op_arg_shape(n, m, batch);
+    TensorInfo post_op_arg_info(post_op_arg_shape, 1, data_type);
+    auto post_op_arg1_info = post_op_arg_info.clone();
+
+    experimental::PostOpList<ITensorInfo *> post_ops{};
+    post_ops.push_back_op<experimental::PostOpEltwiseAdd<ITensorInfo *>>(
+        &post_op_arg_info,
+        1,
+        ConvertPolicy::SATURATE);
+    post_ops.push_back_op<experimental::PostOpEltwiseAdd<ITensorInfo *>>(
+        post_op_arg1_info.get(),
+        0,
+        ConvertPolicy::SATURATE);
+
+    ARM_COMPUTE_EXPECT(is_post_op_list_valid(m, n, k, batch, data_type, post_ops) == false, framework::LogLevel::ERRORS);
+}
+TEST_CASE(OutputWidened, framework::DatasetMode::ALL)
+{
+    // Invalid broadcast: post op tensors "widen" the output tensor
+    const auto data_type = DataType::F32;
+    const unsigned int m = 17;
+    const unsigned int n = 1;
+    const unsigned int k = 13;
+    const unsigned int batch = 2;
+    TensorShape post_op_arg_shape(n + 4, m, batch); // output's X dimension (n) is "widened", which is not allowed
+    TensorInfo post_op_arg_info(post_op_arg_shape, 1, data_type);
+    experimental::PostOpList<ITensorInfo *> post_ops{};
+    post_ops.push_back_op<experimental::PostOpEltwiseAdd<ITensorInfo *>>(&post_op_arg_info, 0, ConvertPolicy::SATURATE);
+
+    ARM_COMPUTE_EXPECT(is_post_op_list_valid(m, n, k, batch, data_type, post_ops) == false, framework::LogLevel::ERRORS);
+}
+TEST_CASE(BroadcastInXDimOnly, framework::DatasetMode::ALL)
+{
+    // Invalid broadcast: post op tensors broadcast in the first dimension (X) only
+    const auto data_type = DataType::F32;
+    const unsigned int m = 22;
+    const unsigned int n = 16;
+    const unsigned int k = 15;
+    const unsigned int batch = 3;
+    TensorShape post_op_arg_shape(1, m, batch);
+    TensorInfo post_op_arg_info(post_op_arg_shape, 1, data_type);
+    experimental::PostOpList<ITensorInfo *> post_ops{};
+    post_ops.push_back_op<experimental::PostOpEltwiseAdd<ITensorInfo *>>(&post_op_arg_info, 0, ConvertPolicy::SATURATE);
+
+    ARM_COMPUTE_EXPECT(is_post_op_list_valid(m, n, k, batch, data_type, post_ops) == false, framework::LogLevel::ERRORS);
+}
+TEST_SUITE_END() // Invalid
+TEST_SUITE(Valid)
+TEST_CASE(EmptyPostOpList, framework::DatasetMode::ALL)
+{
+    const auto data_type = DataType::F32;
+    const unsigned int m = 22;
+    const unsigned int n = 16;
+    const unsigned int k = 15;
+    const unsigned int batch = 3;
+    experimental::PostOpList<ITensorInfo *> post_ops{};
+
+    ARM_COMPUTE_EXPECT(is_post_op_list_valid(m, n, k, batch, data_type, post_ops) == true, framework::LogLevel::ERRORS);
+}
+TEST_CASE(BroadcastInYDimOnly, framework::DatasetMode::ALL)
+{
+    const auto data_type = DataType::F32;
+    const unsigned int m = 22;
+    const unsigned int n = 16;
+    const unsigned int k = 15;
+    const unsigned int batch = 3;
+    TensorShape post_op_arg_shape(n, 1, batch);
+    TensorInfo post_op_arg_info(post_op_arg_shape, 1, data_type);
+    experimental::PostOpList<ITensorInfo *> post_ops{};
+    post_ops.push_back_op<experimental::PostOpEltwiseAdd<ITensorInfo *>>(&post_op_arg_info, 0, ConvertPolicy::SATURATE);
+
+    ARM_COMPUTE_EXPECT(is_post_op_list_valid(m, n, k, batch, data_type, post_ops) == true, framework::LogLevel::ERRORS);
+}
+TEST_CASE(BroadcastInBothXandYDims, framework::DatasetMode::ALL)
+{
+    const auto data_type = DataType::F32;
+    const unsigned int m = 22;
+    const unsigned int n = 16;
+    const unsigned int k = 15;
+    const unsigned int batch = 3;
+    TensorShape post_op_arg_shape(1, 1, batch);
+    TensorInfo post_op_arg_info(post_op_arg_shape, 1, data_type);
+    experimental::PostOpList<ITensorInfo *> post_ops{};
+    post_ops.push_back_op<experimental::PostOpEltwiseAdd<ITensorInfo *>>(&post_op_arg_info, 0, ConvertPolicy::SATURATE);
+
+    ARM_COMPUTE_EXPECT(is_post_op_list_valid(m, n, k, batch, data_type, post_ops) == true, framework::LogLevel::ERRORS);
+}
+TEST_CASE(BroadcastInAllDims, framework::DatasetMode::ALL)
+{
+    const auto data_type = DataType::F32;
+    const unsigned int m = 22;
+    const unsigned int n = 16;
+    const unsigned int k = 15;
+    const unsigned int batch = 3;
+    TensorShape post_op_arg_shape(1, 1, 1);
+    TensorInfo post_op_arg_info(post_op_arg_shape, 1, data_type);
+    experimental::PostOpList<ITensorInfo *> post_ops{};
+    post_ops.push_back_op<experimental::PostOpEltwiseAdd<ITensorInfo *>>(&post_op_arg_info, 0, ConvertPolicy::SATURATE);
+
+    ARM_COMPUTE_EXPECT(is_post_op_list_valid(m, n, k, batch, data_type, post_ops) == true, framework::LogLevel::ERRORS);
+}
+TEST_SUITE_END() // Valid
+TEST_SUITE_END() // ValidateFusedPostOps
 
 TEST_SUITE(Float)
 TEST_SUITE(FP32)
--
cgit v1.2.1
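Taken together, the Invalid and Valid cases above pin down the broadcast rule the reworked validation enforces for a fused post op argument against a dst of shape (n, m, batch): no dimension may exceed dst's, and broadcasting in X alone is unsupported. A plain-C++ paraphrase of what these cases sample (illustrative only — `post_op_arg_shape_ok` is not a library function, and it assumes the kernel's rule is exactly what the seven cases exercise):

    #include <array>
    #include <cstddef>

    // dst and arg hold (x, y, z) = (n, m, batch) extents.
    bool post_op_arg_shape_ok(const std::array<unsigned int, 3> &dst,
                              const std::array<unsigned int, 3> &arg)
    {
        for(std::size_t d = 0; d < 3; ++d)
        {
            // Each dimension must match dst or be 1 (broadcast); anything
            // larger would "widen" dst, which the patched check rejects.
            if(arg[d] != dst[d] && arg[d] != 1U)
            {
                return false; // e.g. OutputWidened: (n + 4, m, batch)
            }
        }
        // Broadcasting in X alone is not supported by this kernel: if X is
        // broadcast while dst is wider than 1, Y must be broadcast too.
        const bool broadcast_x = (arg[0] == 1U && dst[0] != 1U);
        const bool broadcast_y = (arg[1] == 1U && dst[1] != 1U);
        return !(broadcast_x && !broadcast_y); // BroadcastInXDimOnly -> false
    }

    // Mirrors the cases above, with dst = (n, m, batch) = (16, 22, 3):
    //   post_op_arg_shape_ok({16, 22, 3}, {1, 22, 3}) == false // BroadcastInXDimOnly
    //   post_op_arg_shape_ok({16, 22, 3}, {16, 1, 3}) == true  // BroadcastInYDimOnly
    //   post_op_arg_shape_ok({16, 22, 3}, {1, 1, 3})  == true  // BroadcastInBothXandYDims
    //   post_op_arg_shape_ok({16, 22, 3}, {1, 1, 1})  == true  // BroadcastInAllDims

UnsupportedPostOpSequence is orthogonal to this rule: it rejects the order of the fused ops, not their shapes.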