From b3182b19251cd010baad8252e7607de7059ac986 Mon Sep 17 00:00:00 2001 From: Gian Marco Iodice Date: Fri, 4 Sep 2020 08:44:52 +0100 Subject: COMPMID-3157: Remove padding from NEGEMMTranspose1xWKernel - Remove padding from NEGEMMTranspose1xWKernel - Extend test for validating zero padding requirement Change-Id: I9ce4ca95a500229b045dc140cfff21fdf7373700 Signed-off-by: Gian Marco Iodice Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/3920 Tested-by: Arm Jenkins Reviewed-by: Michalis Spyrou Comments-Addressed: Arm Jenkins --- tests/validation/NEON/GEMM.cpp | 61 +++++++++++++++++++++++++++++------------- 1 file changed, 42 insertions(+), 19 deletions(-) (limited to 'tests') diff --git a/tests/validation/NEON/GEMM.cpp b/tests/validation/NEON/GEMM.cpp index 9105638a6e..dfac72f3a5 100644 --- a/tests/validation/NEON/GEMM.cpp +++ b/tests/validation/NEON/GEMM.cpp @@ -68,27 +68,23 @@ const auto data_interleave = framework::dataset::make("M", 8, 12) * framework::d const auto data_transpose = framework::dataset::make("M", 8, 14) * framework::dataset::make("N", 7, 14); /** Zero padding test */ -bool validate_zero_padding(unsigned int m_value, unsigned int k_value) +template +bool validate_zero_padding(unsigned int dim0_value, unsigned int dim1_value) { - const unsigned int M = m_value; - const unsigned int K = k_value; - - const TensorShape lhs_shape(K, M); - const TensorShape lhs_shape_reshaped(K * 4, std::ceil(M / 4.0f)); + const TensorShape in_shape(dim0_value, dim1_value); // Create tensors - Tensor lhs = create_tensor(lhs_shape, DataType::U32); - Tensor dst = create_tensor(lhs_shape_reshaped, DataType::U32); + Tensor in = create_tensor(in_shape, DataType::U32); + Tensor dst; - ARM_COMPUTE_EXPECT(lhs.info()->is_resizable(), framework::LogLevel::ERRORS); - ARM_COMPUTE_EXPECT(dst.info()->is_resizable(), framework::LogLevel::ERRORS); + ARM_COMPUTE_EXPECT(in.info()->is_resizable(), framework::LogLevel::ERRORS); // Validate zero-padding - NEGEMMInterleave4x4Kernel lhs_reshape; + FunctionType func; - lhs_reshape.configure(&lhs, &dst); + func.configure(&in, &dst); - return lhs.info()->padding().empty(); + return in.info()->padding().empty(); } } // namespace @@ -97,15 +93,42 @@ TEST_SUITE(NEON) TEST_SUITE(GEMM) TEST_SUITE(TRANSPOSE_1XW) -using NEGEMMTranspose1xW = NESynthetizeFunctionWithZeroConstantBorder; -using NEGEMMTranspose1xWFixture = GEMMTranspose1xWValidationFixture; -TEST_SUITE(FP32) -FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMTranspose1xWFixture, framework::DatasetMode::PRECOMMIT, data_transpose * framework::dataset::make("DataType", DataType::F32)) +using NEGEMMTranspose1xW = NESynthetizeFunctionWithZeroConstantBorder; +DATA_TEST_CASE(ValidateZeroPadding, framework::DatasetMode::ALL, zip( + framework::dataset::make("N", { 1, 23, 63, 101 }), + framework::dataset::make("K", { 1, 47, 29, 27 })), + n_value, k_value) +{ + bool status = validate_zero_padding(n_value, k_value); + ARM_COMPUTE_EXPECT(status, framework::LogLevel::ERRORS); +} + +TEST_SUITE(U32) +using NEGEMMTranspose1xWFixture = GEMMTranspose1xWValidationFixture; +FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMTranspose1xWFixture, framework::DatasetMode::PRECOMMIT, data_transpose * framework::dataset::make("DataType", DataType::U32)) +{ + // Validate output + validate(Accessor(_target), _reference); +} +TEST_SUITE_END() // U32 + +TEST_SUITE(U16) +using NEGEMMTranspose1xWFixture = GEMMTranspose1xWValidationFixture; +FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMTranspose1xWFixture, framework::DatasetMode::PRECOMMIT, data_transpose * framework::dataset::make("DataType", DataType::U16)) +{ + // Validate output + validate(Accessor(_target), _reference); +} +TEST_SUITE_END() // U16 + +TEST_SUITE(U8) +using NEGEMMTranspose1xWFixture = GEMMTranspose1xWValidationFixture; +FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMTranspose1xWFixture, framework::DatasetMode::PRECOMMIT, data_transpose * framework::dataset::make("DataType", DataType::U8)) { // Validate output validate(Accessor(_target), _reference); } -TEST_SUITE_END() // FP32 +TEST_SUITE_END() // U8 TEST_SUITE_END() // TRANSPOSE_1XW @@ -117,7 +140,7 @@ DATA_TEST_CASE(ValidateZeroPadding, framework::DatasetMode::ALL, zip( framework::dataset::make("K", { 1, 47, 29, 27 })), m_value, k_value) { - bool status = validate_zero_padding(m_value, k_value); + bool status = validate_zero_padding(m_value, k_value); ARM_COMPUTE_EXPECT(status, framework::LogLevel::ERRORS); } -- cgit v1.2.1