diff options
Diffstat (limited to 'tests')
-rw-r--r-- | tests/validation/CL/GEMMReshapeLHSMatrix.cpp | 188 |
1 files changed, 105 insertions, 83 deletions
diff --git a/tests/validation/CL/GEMMReshapeLHSMatrix.cpp b/tests/validation/CL/GEMMReshapeLHSMatrix.cpp index ea6589df22..a2dd9c2a0a 100644 --- a/tests/validation/CL/GEMMReshapeLHSMatrix.cpp +++ b/tests/validation/CL/GEMMReshapeLHSMatrix.cpp @@ -42,21 +42,40 @@ namespace test { namespace validation { -namespace -{ +using namespace arm_compute::misc::shape_calculator; + +// Initialize the output tensor with zero and fill the border with zero +using CLGEMMReshapeLHSMatrix = CLSynthetizeFunctionInitOutputWithZeroAndWithZeroConstantBorder<CLGEMMReshapeLHSMatrixKernel, 16>; + +template <typename T> +using CLGEMMReshapeLHSMatrixFixture = GEMMReshapeLHSMatrixValidationFixture<CLTensor, CLAccessor, CLGEMMReshapeLHSMatrix, T, false>; + +// Fixture to use when the input has to be reinterpreted as 3D +template <typename T> +using CLGEMMReshapeLHSMatrix3DFixture = GEMMReshapeLHSMatrixValidationFixture<CLTensor, CLAccessor, CLGEMMReshapeLHSMatrix, T, true>; + // *INDENT-OFF* // clang-format off /** Data types */ + +namespace +{ const auto data_types = framework::dataset::make("DataType", { DataType::QASYMM8, DataType::F16, DataType::F32 }); /** Batch size values to test */ const auto b_values = framework::dataset::make("batchsize", 1, 3); -/** M0 values to test */ -const auto m0_values = framework::dataset::make("M0", 2, 9); +/** M0 values to test - Precommit */ +const auto m0_values_precommit = framework::dataset::make("M0", { 2, 4, 5 }); -/** K0 values to test */ -const auto k0_values = framework::dataset::make("K0", { 2, 4, 8, 16 }); +/** K0 values to test - Precommit */ +const auto k0_values_precommit = framework::dataset::make("K0", { 2, 4 }); + +/** M0 values to test - Precommit */ +const auto m0_values_nightly = framework::dataset::make("M0", 2, 9); + +/** K0 values to test - Precommit */ +const auto k0_values_nightly = framework::dataset::make("K0", { 2, 4, 8, 16 }); /** V0 values to test */ const auto v0_values = framework::dataset::make("V0", 1, 4); @@ -65,33 +84,10 @@ const auto v0_values = framework::dataset::make("V0", 1, 4); const auto i_values = framework::dataset::make("interleave", { true, false }); /** Transpose values to test */ -const auto t_values = framework::dataset::make("transpose", { false }); -} // namespace - -using namespace arm_compute::misc::shape_calculator; +const auto t_values = framework::dataset::make("transpose", { true, false }); -// Initialize the output tensor with zero and fill the border with zero -using CLGEMMReshapeLHSMatrix = CLSynthetizeFunctionInitOutputWithZeroAndWithZeroConstantBorder<CLGEMMReshapeLHSMatrixKernel, 16>; - -template <typename T> -using CLGEMMReshapeLHSMatrixFixture = GEMMReshapeLHSMatrixValidationFixture<CLTensor, CLAccessor, CLGEMMReshapeLHSMatrix, T, false>; - -// Fixture to use when the input has to be reinterpreted as 3D -template <typename T> -using CLGEMMReshapeLHSMatrix3DFixture = GEMMReshapeLHSMatrixValidationFixture<CLTensor, CLAccessor, CLGEMMReshapeLHSMatrix, T, true>; - -TEST_SUITE(CL) -TEST_SUITE(GEMMReshapeLHSMatrix) - -DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(combine(combine(combine(combine(combine(combine(datasets::SmallGEMMReshape2DShapes(), - b_values), - data_types), - m0_values), - k0_values), - v0_values), - i_values), - t_values), -shape_in, b_value, data_type, m0_value, k0_value, v0_value, i_value, t_value) +/** Configuration test */ +void validate_configuration(TensorShape shape_in, unsigned int b_value, DataType data_type, unsigned int m0_value, unsigned int k0_value, unsigned int v0_value, bool i_value, bool t_value, bool reinterpret_input_as_3d) { GEMMLHSMatrixInfo lhs_info; lhs_info.m0 = m0_value; @@ -100,8 +96,10 @@ shape_in, b_value, data_type, m0_value, k0_value, v0_value, i_value, t_value) lhs_info.interleave = i_value; lhs_info.transpose = t_value; - const TensorShape shape_src(shape_in[0], shape_in[1], b_value); - const TensorShape shape_dst = compute_lhs_reshaped_shape(TensorInfo(shape_src, 1, data_type), lhs_info, false); + TensorShape shape_src = shape_in; + shape_src.set(shape_src.num_dimensions(), b_value); + + const TensorShape shape_dst = compute_lhs_reshaped_shape(TensorInfo(shape_src, 1, data_type), lhs_info, reinterpret_input_as_3d); // Create tensors CLTensor src = create_tensor<CLTensor>(shape_src, data_type); @@ -112,7 +110,37 @@ shape_in, b_value, data_type, m0_value, k0_value, v0_value, i_value, t_value) // Create and configure function CLGEMMReshapeLHSMatrixKernel reshape_lhs; - reshape_lhs.configure(&src, &dst, lhs_info, false); + reshape_lhs.configure(&src, &dst, lhs_info, reinterpret_input_as_3d); +} +} // namespace + +TEST_SUITE(CL) +TEST_SUITE(GEMMReshapeLHSMatrix) + +DATA_TEST_CASE(ConfigurationSmall, framework::DatasetMode::ALL, combine(combine(combine(combine(combine(combine(combine(datasets::SmallGEMMReshape2DShapes(), + b_values), + data_types), + m0_values_precommit), + k0_values_precommit), + v0_values), + i_values), + t_values), +shape_in, b_value, data_type, m0_value, k0_value, v0_value, i_value, t_value) +{ + validate_configuration(shape_in, b_value, data_type, m0_value, k0_value, v0_value, i_value, t_value, false); +} + +DATA_TEST_CASE(ConfigurationLarge, framework::DatasetMode::NIGHTLY, combine(combine(combine(combine(combine(combine(combine(datasets::LargeGEMMReshape2DShapes(), + b_values), + data_types), + m0_values_nightly), + k0_values_nightly), + v0_values), + i_values), + t_values), +shape_in, b_value, data_type, m0_value, k0_value, v0_value, i_value, t_value) +{ + validate_configuration(shape_in, b_value, data_type, m0_value, k0_value, v0_value, i_value, t_value, false); } TEST_SUITE(S32) @@ -120,8 +148,8 @@ FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMReshapeLHSMatrixFixture<int>, framework:: combine(combine(combine(combine(combine(combine(combine(datasets::SmallGEMMReshape2DShapes(), b_values), framework::dataset::make("DataType", DataType::S32)), - m0_values), - k0_values), + m0_values_precommit), + k0_values_precommit), v0_values), i_values), t_values)) @@ -134,8 +162,8 @@ FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMReshapeLHSMatrixFixture<int>, framework:: combine(combine(combine(combine(combine(combine(combine(datasets::LargeGEMMReshape2DShapes(), b_values), framework::dataset::make("DataType", DataType::S32)), - m0_values), - k0_values), + m0_values_nightly), + k0_values_nightly), v0_values), i_values), t_values)) @@ -150,8 +178,8 @@ FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMReshapeLHSMatrixFixture<short>, framework combine(combine(combine(combine(combine(combine(combine(datasets::SmallGEMMReshape2DShapes(), b_values), framework::dataset::make("DataType", DataType::S16)), - m0_values), - k0_values), + m0_values_precommit), + k0_values_precommit), v0_values), i_values), t_values)) @@ -164,8 +192,8 @@ FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMReshapeLHSMatrixFixture<short>, framework combine(combine(combine(combine(combine(combine(combine(datasets::LargeGEMMReshape2DShapes(), b_values), framework::dataset::make("DataType", DataType::S16)), - m0_values), - k0_values), + m0_values_nightly), + k0_values_nightly), v0_values), i_values), t_values)) @@ -180,8 +208,8 @@ FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMReshapeLHSMatrixFixture<char>, framework: combine(combine(combine(combine(combine(combine(combine(datasets::SmallGEMMReshape2DShapes(), b_values), framework::dataset::make("DataType", DataType::S8)), - m0_values), - k0_values), + m0_values_precommit), + k0_values_precommit), v0_values), i_values), t_values)) @@ -194,8 +222,8 @@ FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMReshapeLHSMatrixFixture<char>, framework: combine(combine(combine(combine(combine(combine(combine(datasets::LargeGEMMReshape2DShapes(), b_values), framework::dataset::make("DataType", DataType::S8)), - m0_values), - k0_values), + m0_values_nightly), + k0_values_nightly), v0_values), i_values), t_values)) @@ -205,37 +233,31 @@ FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMReshapeLHSMatrixFixture<char>, framework: } TEST_SUITE_END() // S8 -TEST_SUITE(REINTERPRET_INPUT_AS_3D) -DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(combine(combine(combine(combine(combine(combine(datasets::SmallGEMMReshape3DShapes(), +TEST_SUITE(ReinterpretInputAs3D) +DATA_TEST_CASE(ConfigurationSmall, framework::DatasetMode::ALL, combine(combine(combine(combine(combine(combine(combine(datasets::SmallGEMMReshape3DShapes(), b_values), data_types), - m0_values), - k0_values), + m0_values_precommit), + k0_values_precommit), v0_values), i_values), t_values), shape_in, b_value, data_type, m0_value, k0_value, v0_value, i_value, t_value) { - GEMMLHSMatrixInfo lhs_info; - lhs_info.m0 = m0_value; - lhs_info.k0 = k0_value; - lhs_info.v0 = v0_value; - lhs_info.interleave = i_value; - lhs_info.transpose = t_value; - - const TensorShape shape_src(shape_in[0], shape_in[1], shape_in[2], b_value); - const TensorShape shape_dst = compute_lhs_reshaped_shape(TensorInfo(shape_src, 1, data_type), lhs_info, true); - - // Create tensors - CLTensor src = create_tensor<CLTensor>(shape_src, data_type); - CLTensor dst = create_tensor<CLTensor>(shape_dst, data_type); - - ARM_COMPUTE_EXPECT(src.info()->is_resizable(), framework::LogLevel::ERRORS); - ARM_COMPUTE_EXPECT(dst.info()->is_resizable(), framework::LogLevel::ERRORS); + validate_configuration(shape_in, b_value, data_type, m0_value, k0_value, v0_value, i_value, t_value, true); +} - // Create and configure function - CLGEMMReshapeLHSMatrixKernel reshape_lhs; - reshape_lhs.configure(&src, &dst, lhs_info, true); +DATA_TEST_CASE(ConfigurationLarge, framework::DatasetMode::NIGHTLY, combine(combine(combine(combine(combine(combine(combine(datasets::LargeGEMMReshape3DShapes(), + b_values), + data_types), + m0_values_nightly), + k0_values_nightly), + v0_values), + i_values), + t_values), +shape_in, b_value, data_type, m0_value, k0_value, v0_value, i_value, t_value) +{ + validate_configuration(shape_in, b_value, data_type, m0_value, k0_value, v0_value, i_value, t_value, true); } TEST_SUITE(S32) @@ -243,8 +265,8 @@ FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMReshapeLHSMatrix3DFixture<int>, framework combine(combine(combine(combine(combine(combine(combine(datasets::SmallGEMMReshape3DShapes(), b_values), framework::dataset::make("DataType", DataType::S32)), - m0_values), - k0_values), + m0_values_precommit), + k0_values_precommit), v0_values), i_values), t_values)) @@ -257,8 +279,8 @@ FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMReshapeLHSMatrix3DFixture<int>, framework combine(combine(combine(combine(combine(combine(combine(datasets::LargeGEMMReshape3DShapes(), b_values), framework::dataset::make("DataType", DataType::S32)), - m0_values), - k0_values), + m0_values_nightly), + k0_values_nightly), v0_values), i_values), t_values)) @@ -273,8 +295,8 @@ FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMReshapeLHSMatrix3DFixture<short>, framewo combine(combine(combine(combine(combine(combine(combine(datasets::SmallGEMMReshape3DShapes(), b_values), framework::dataset::make("DataType", DataType::S16)), - m0_values), - k0_values), + m0_values_precommit), + k0_values_precommit), v0_values), i_values), t_values)) @@ -287,8 +309,8 @@ FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMReshapeLHSMatrix3DFixture<short>, framewo combine(combine(combine(combine(combine(combine(combine(datasets::LargeGEMMReshape3DShapes(), b_values), framework::dataset::make("DataType", DataType::S16)), - m0_values), - k0_values), + m0_values_nightly), + k0_values_nightly), v0_values), i_values), t_values)) @@ -303,8 +325,8 @@ FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMReshapeLHSMatrix3DFixture<char>, framewor combine(combine(combine(combine(combine(combine(combine(datasets::SmallGEMMReshape3DShapes(), b_values), framework::dataset::make("DataType", DataType::S8)), - m0_values), - k0_values), + m0_values_precommit), + k0_values_precommit), v0_values), i_values), t_values)) @@ -317,8 +339,8 @@ FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMReshapeLHSMatrix3DFixture<char>, framewor combine(combine(combine(combine(combine(combine(combine(datasets::LargeGEMMReshape3DShapes(), b_values), framework::dataset::make("DataType", DataType::S8)), - m0_values), - k0_values), + m0_values_nightly), + k0_values_nightly), v0_values), i_values), t_values)) @@ -327,7 +349,7 @@ FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMReshapeLHSMatrix3DFixture<char>, framewor validate(CLAccessor(_target), _reference); } TEST_SUITE_END() // S8 -TEST_SUITE_END() // REINTERPRET_INPUT_AS_3D +TEST_SUITE_END() // ReinterpretInputAs3D TEST_SUITE_END() // GEMMReshapeLHSMatrix TEST_SUITE_END() // CL } // namespace validation |