From 08ddd7b1c6c6c08361115142eb58e43267d5f264 Mon Sep 17 00:00:00 2001 From: Gian Marco Iodice Date: Wed, 19 Dec 2018 10:01:18 +0000 Subject: COMPMID-1834: Add transpose support to CLGEMMReshapeLHSMatrixKernel Change-Id: I913a7297a0c34a05b1d37eab1489b430423700e8 Reviewed-on: https://review.mlplatform.org/417 Tested-by: Arm Jenkins Reviewed-by: Michele Di Giorgio --- tests/validation/CL/GEMMReshapeLHSMatrix.cpp | 188 +++++++++++++++------------ 1 file changed, 105 insertions(+), 83 deletions(-) (limited to 'tests') diff --git a/tests/validation/CL/GEMMReshapeLHSMatrix.cpp b/tests/validation/CL/GEMMReshapeLHSMatrix.cpp index ea6589df22..a2dd9c2a0a 100644 --- a/tests/validation/CL/GEMMReshapeLHSMatrix.cpp +++ b/tests/validation/CL/GEMMReshapeLHSMatrix.cpp @@ -42,21 +42,40 @@ namespace test { namespace validation { -namespace -{ +using namespace arm_compute::misc::shape_calculator; + +// Initialize the output tensor with zero and fill the border with zero +using CLGEMMReshapeLHSMatrix = CLSynthetizeFunctionInitOutputWithZeroAndWithZeroConstantBorder; + +template +using CLGEMMReshapeLHSMatrixFixture = GEMMReshapeLHSMatrixValidationFixture; + +// Fixture to use when the input has to be reinterpreted as 3D +template +using CLGEMMReshapeLHSMatrix3DFixture = GEMMReshapeLHSMatrixValidationFixture; + // *INDENT-OFF* // clang-format off /** Data types */ + +namespace +{ const auto data_types = framework::dataset::make("DataType", { DataType::QASYMM8, DataType::F16, DataType::F32 }); /** Batch size values to test */ const auto b_values = framework::dataset::make("batchsize", 1, 3); -/** M0 values to test */ -const auto m0_values = framework::dataset::make("M0", 2, 9); +/** M0 values to test - Precommit */ +const auto m0_values_precommit = framework::dataset::make("M0", { 2, 4, 5 }); -/** K0 values to test */ -const auto k0_values = framework::dataset::make("K0", { 2, 4, 8, 16 }); +/** K0 values to test - Precommit */ +const auto k0_values_precommit = framework::dataset::make("K0", { 2, 4 }); + +/** M0 values to test - Precommit */ +const auto m0_values_nightly = framework::dataset::make("M0", 2, 9); + +/** K0 values to test - Precommit */ +const auto k0_values_nightly = framework::dataset::make("K0", { 2, 4, 8, 16 }); /** V0 values to test */ const auto v0_values = framework::dataset::make("V0", 1, 4); @@ -65,33 +84,10 @@ const auto v0_values = framework::dataset::make("V0", 1, 4); const auto i_values = framework::dataset::make("interleave", { true, false }); /** Transpose values to test */ -const auto t_values = framework::dataset::make("transpose", { false }); -} // namespace - -using namespace arm_compute::misc::shape_calculator; +const auto t_values = framework::dataset::make("transpose", { true, false }); -// Initialize the output tensor with zero and fill the border with zero -using CLGEMMReshapeLHSMatrix = CLSynthetizeFunctionInitOutputWithZeroAndWithZeroConstantBorder; - -template -using CLGEMMReshapeLHSMatrixFixture = GEMMReshapeLHSMatrixValidationFixture; - -// Fixture to use when the input has to be reinterpreted as 3D -template -using CLGEMMReshapeLHSMatrix3DFixture = GEMMReshapeLHSMatrixValidationFixture; - -TEST_SUITE(CL) -TEST_SUITE(GEMMReshapeLHSMatrix) - -DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(combine(combine(combine(combine(combine(combine(datasets::SmallGEMMReshape2DShapes(), - b_values), - data_types), - m0_values), - k0_values), - v0_values), - i_values), - t_values), -shape_in, b_value, data_type, m0_value, k0_value, v0_value, i_value, t_value) +/** Configuration test */ +void validate_configuration(TensorShape shape_in, unsigned int b_value, DataType data_type, unsigned int m0_value, unsigned int k0_value, unsigned int v0_value, bool i_value, bool t_value, bool reinterpret_input_as_3d) { GEMMLHSMatrixInfo lhs_info; lhs_info.m0 = m0_value; @@ -100,8 +96,10 @@ shape_in, b_value, data_type, m0_value, k0_value, v0_value, i_value, t_value) lhs_info.interleave = i_value; lhs_info.transpose = t_value; - const TensorShape shape_src(shape_in[0], shape_in[1], b_value); - const TensorShape shape_dst = compute_lhs_reshaped_shape(TensorInfo(shape_src, 1, data_type), lhs_info, false); + TensorShape shape_src = shape_in; + shape_src.set(shape_src.num_dimensions(), b_value); + + const TensorShape shape_dst = compute_lhs_reshaped_shape(TensorInfo(shape_src, 1, data_type), lhs_info, reinterpret_input_as_3d); // Create tensors CLTensor src = create_tensor(shape_src, data_type); @@ -112,7 +110,37 @@ shape_in, b_value, data_type, m0_value, k0_value, v0_value, i_value, t_value) // Create and configure function CLGEMMReshapeLHSMatrixKernel reshape_lhs; - reshape_lhs.configure(&src, &dst, lhs_info, false); + reshape_lhs.configure(&src, &dst, lhs_info, reinterpret_input_as_3d); +} +} // namespace + +TEST_SUITE(CL) +TEST_SUITE(GEMMReshapeLHSMatrix) + +DATA_TEST_CASE(ConfigurationSmall, framework::DatasetMode::ALL, combine(combine(combine(combine(combine(combine(combine(datasets::SmallGEMMReshape2DShapes(), + b_values), + data_types), + m0_values_precommit), + k0_values_precommit), + v0_values), + i_values), + t_values), +shape_in, b_value, data_type, m0_value, k0_value, v0_value, i_value, t_value) +{ + validate_configuration(shape_in, b_value, data_type, m0_value, k0_value, v0_value, i_value, t_value, false); +} + +DATA_TEST_CASE(ConfigurationLarge, framework::DatasetMode::NIGHTLY, combine(combine(combine(combine(combine(combine(combine(datasets::LargeGEMMReshape2DShapes(), + b_values), + data_types), + m0_values_nightly), + k0_values_nightly), + v0_values), + i_values), + t_values), +shape_in, b_value, data_type, m0_value, k0_value, v0_value, i_value, t_value) +{ + validate_configuration(shape_in, b_value, data_type, m0_value, k0_value, v0_value, i_value, t_value, false); } TEST_SUITE(S32) @@ -120,8 +148,8 @@ FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMReshapeLHSMatrixFixture, framework:: combine(combine(combine(combine(combine(combine(combine(datasets::SmallGEMMReshape2DShapes(), b_values), framework::dataset::make("DataType", DataType::S32)), - m0_values), - k0_values), + m0_values_precommit), + k0_values_precommit), v0_values), i_values), t_values)) @@ -134,8 +162,8 @@ FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMReshapeLHSMatrixFixture, framework:: combine(combine(combine(combine(combine(combine(combine(datasets::LargeGEMMReshape2DShapes(), b_values), framework::dataset::make("DataType", DataType::S32)), - m0_values), - k0_values), + m0_values_nightly), + k0_values_nightly), v0_values), i_values), t_values)) @@ -150,8 +178,8 @@ FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMReshapeLHSMatrixFixture, framework combine(combine(combine(combine(combine(combine(combine(datasets::SmallGEMMReshape2DShapes(), b_values), framework::dataset::make("DataType", DataType::S16)), - m0_values), - k0_values), + m0_values_precommit), + k0_values_precommit), v0_values), i_values), t_values)) @@ -164,8 +192,8 @@ FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMReshapeLHSMatrixFixture, framework combine(combine(combine(combine(combine(combine(combine(datasets::LargeGEMMReshape2DShapes(), b_values), framework::dataset::make("DataType", DataType::S16)), - m0_values), - k0_values), + m0_values_nightly), + k0_values_nightly), v0_values), i_values), t_values)) @@ -180,8 +208,8 @@ FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMReshapeLHSMatrixFixture, framework: combine(combine(combine(combine(combine(combine(combine(datasets::SmallGEMMReshape2DShapes(), b_values), framework::dataset::make("DataType", DataType::S8)), - m0_values), - k0_values), + m0_values_precommit), + k0_values_precommit), v0_values), i_values), t_values)) @@ -194,8 +222,8 @@ FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMReshapeLHSMatrixFixture, framework: combine(combine(combine(combine(combine(combine(combine(datasets::LargeGEMMReshape2DShapes(), b_values), framework::dataset::make("DataType", DataType::S8)), - m0_values), - k0_values), + m0_values_nightly), + k0_values_nightly), v0_values), i_values), t_values)) @@ -205,37 +233,31 @@ FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMReshapeLHSMatrixFixture, framework: } TEST_SUITE_END() // S8 -TEST_SUITE(REINTERPRET_INPUT_AS_3D) -DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(combine(combine(combine(combine(combine(combine(datasets::SmallGEMMReshape3DShapes(), +TEST_SUITE(ReinterpretInputAs3D) +DATA_TEST_CASE(ConfigurationSmall, framework::DatasetMode::ALL, combine(combine(combine(combine(combine(combine(combine(datasets::SmallGEMMReshape3DShapes(), b_values), data_types), - m0_values), - k0_values), + m0_values_precommit), + k0_values_precommit), v0_values), i_values), t_values), shape_in, b_value, data_type, m0_value, k0_value, v0_value, i_value, t_value) { - GEMMLHSMatrixInfo lhs_info; - lhs_info.m0 = m0_value; - lhs_info.k0 = k0_value; - lhs_info.v0 = v0_value; - lhs_info.interleave = i_value; - lhs_info.transpose = t_value; - - const TensorShape shape_src(shape_in[0], shape_in[1], shape_in[2], b_value); - const TensorShape shape_dst = compute_lhs_reshaped_shape(TensorInfo(shape_src, 1, data_type), lhs_info, true); - - // Create tensors - CLTensor src = create_tensor(shape_src, data_type); - CLTensor dst = create_tensor(shape_dst, data_type); - - ARM_COMPUTE_EXPECT(src.info()->is_resizable(), framework::LogLevel::ERRORS); - ARM_COMPUTE_EXPECT(dst.info()->is_resizable(), framework::LogLevel::ERRORS); + validate_configuration(shape_in, b_value, data_type, m0_value, k0_value, v0_value, i_value, t_value, true); +} - // Create and configure function - CLGEMMReshapeLHSMatrixKernel reshape_lhs; - reshape_lhs.configure(&src, &dst, lhs_info, true); +DATA_TEST_CASE(ConfigurationLarge, framework::DatasetMode::NIGHTLY, combine(combine(combine(combine(combine(combine(combine(datasets::LargeGEMMReshape3DShapes(), + b_values), + data_types), + m0_values_nightly), + k0_values_nightly), + v0_values), + i_values), + t_values), +shape_in, b_value, data_type, m0_value, k0_value, v0_value, i_value, t_value) +{ + validate_configuration(shape_in, b_value, data_type, m0_value, k0_value, v0_value, i_value, t_value, true); } TEST_SUITE(S32) @@ -243,8 +265,8 @@ FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMReshapeLHSMatrix3DFixture, framework combine(combine(combine(combine(combine(combine(combine(datasets::SmallGEMMReshape3DShapes(), b_values), framework::dataset::make("DataType", DataType::S32)), - m0_values), - k0_values), + m0_values_precommit), + k0_values_precommit), v0_values), i_values), t_values)) @@ -257,8 +279,8 @@ FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMReshapeLHSMatrix3DFixture, framework combine(combine(combine(combine(combine(combine(combine(datasets::LargeGEMMReshape3DShapes(), b_values), framework::dataset::make("DataType", DataType::S32)), - m0_values), - k0_values), + m0_values_nightly), + k0_values_nightly), v0_values), i_values), t_values)) @@ -273,8 +295,8 @@ FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMReshapeLHSMatrix3DFixture, framewo combine(combine(combine(combine(combine(combine(combine(datasets::SmallGEMMReshape3DShapes(), b_values), framework::dataset::make("DataType", DataType::S16)), - m0_values), - k0_values), + m0_values_precommit), + k0_values_precommit), v0_values), i_values), t_values)) @@ -287,8 +309,8 @@ FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMReshapeLHSMatrix3DFixture, framewo combine(combine(combine(combine(combine(combine(combine(datasets::LargeGEMMReshape3DShapes(), b_values), framework::dataset::make("DataType", DataType::S16)), - m0_values), - k0_values), + m0_values_nightly), + k0_values_nightly), v0_values), i_values), t_values)) @@ -303,8 +325,8 @@ FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMReshapeLHSMatrix3DFixture, framewor combine(combine(combine(combine(combine(combine(combine(datasets::SmallGEMMReshape3DShapes(), b_values), framework::dataset::make("DataType", DataType::S8)), - m0_values), - k0_values), + m0_values_precommit), + k0_values_precommit), v0_values), i_values), t_values)) @@ -317,8 +339,8 @@ FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMReshapeLHSMatrix3DFixture, framewor combine(combine(combine(combine(combine(combine(combine(datasets::LargeGEMMReshape3DShapes(), b_values), framework::dataset::make("DataType", DataType::S8)), - m0_values), - k0_values), + m0_values_nightly), + k0_values_nightly), v0_values), i_values), t_values)) @@ -327,7 +349,7 @@ FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMReshapeLHSMatrix3DFixture, framewor validate(CLAccessor(_target), _reference); } TEST_SUITE_END() // S8 -TEST_SUITE_END() // REINTERPRET_INPUT_AS_3D +TEST_SUITE_END() // ReinterpretInputAs3D TEST_SUITE_END() // GEMMReshapeLHSMatrix TEST_SUITE_END() // CL } // namespace validation -- cgit v1.2.1