diff options
Diffstat (limited to 'tests/validation')
-rw-r--r-- | tests/validation/NEON/GEMM.cpp | 61 |
1 files changed, 42 insertions, 19 deletions
diff --git a/tests/validation/NEON/GEMM.cpp b/tests/validation/NEON/GEMM.cpp index 9105638a6e..dfac72f3a5 100644 --- a/tests/validation/NEON/GEMM.cpp +++ b/tests/validation/NEON/GEMM.cpp @@ -68,27 +68,23 @@ const auto data_interleave = framework::dataset::make("M", 8, 12) * framework::d const auto data_transpose = framework::dataset::make("M", 8, 14) * framework::dataset::make("N", 7, 14); /** Zero padding test */ -bool validate_zero_padding(unsigned int m_value, unsigned int k_value) +template <typename FunctionType> +bool validate_zero_padding(unsigned int dim0_value, unsigned int dim1_value) { - const unsigned int M = m_value; - const unsigned int K = k_value; - - const TensorShape lhs_shape(K, M); - const TensorShape lhs_shape_reshaped(K * 4, std::ceil(M / 4.0f)); + const TensorShape in_shape(dim0_value, dim1_value); // Create tensors - Tensor lhs = create_tensor<Tensor>(lhs_shape, DataType::U32); - Tensor dst = create_tensor<Tensor>(lhs_shape_reshaped, DataType::U32); + Tensor in = create_tensor<Tensor>(in_shape, DataType::U32); + Tensor dst; - ARM_COMPUTE_EXPECT(lhs.info()->is_resizable(), framework::LogLevel::ERRORS); - ARM_COMPUTE_EXPECT(dst.info()->is_resizable(), framework::LogLevel::ERRORS); + ARM_COMPUTE_EXPECT(in.info()->is_resizable(), framework::LogLevel::ERRORS); // Validate zero-padding - NEGEMMInterleave4x4Kernel lhs_reshape; + FunctionType func; - lhs_reshape.configure(&lhs, &dst); + func.configure(&in, &dst); - return lhs.info()->padding().empty(); + return in.info()->padding().empty(); } } // namespace @@ -97,15 +93,42 @@ TEST_SUITE(NEON) TEST_SUITE(GEMM) TEST_SUITE(TRANSPOSE_1XW) -using NEGEMMTranspose1xW = NESynthetizeFunctionWithZeroConstantBorder<NEGEMMTranspose1xWKernel, 4>; -using NEGEMMTranspose1xWFixture = GEMMTranspose1xWValidationFixture<Tensor, Accessor, NEGEMMTranspose1xW, float>; -TEST_SUITE(FP32) -FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMTranspose1xWFixture, framework::DatasetMode::PRECOMMIT, data_transpose * framework::dataset::make("DataType", DataType::F32)) +using NEGEMMTranspose1xW = NESynthetizeFunctionWithZeroConstantBorder<NEGEMMTranspose1xWKernel, 4>; +DATA_TEST_CASE(ValidateZeroPadding, framework::DatasetMode::ALL, zip( + framework::dataset::make("N", { 1, 23, 63, 101 }), + framework::dataset::make("K", { 1, 47, 29, 27 })), + n_value, k_value) +{ + bool status = validate_zero_padding<NEGEMMTranspose1xWKernel>(n_value, k_value); + ARM_COMPUTE_EXPECT(status, framework::LogLevel::ERRORS); +} + +TEST_SUITE(U32) +using NEGEMMTranspose1xWFixture = GEMMTranspose1xWValidationFixture<Tensor, Accessor, NEGEMMTranspose1xW, uint32_t>; +FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMTranspose1xWFixture, framework::DatasetMode::PRECOMMIT, data_transpose * framework::dataset::make("DataType", DataType::U32)) +{ + // Validate output + validate(Accessor(_target), _reference); +} +TEST_SUITE_END() // U32 + +TEST_SUITE(U16) +using NEGEMMTranspose1xWFixture = GEMMTranspose1xWValidationFixture<Tensor, Accessor, NEGEMMTranspose1xW, uint16_t>; +FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMTranspose1xWFixture, framework::DatasetMode::PRECOMMIT, data_transpose * framework::dataset::make("DataType", DataType::U16)) +{ + // Validate output + validate(Accessor(_target), _reference); +} +TEST_SUITE_END() // U16 + +TEST_SUITE(U8) +using NEGEMMTranspose1xWFixture = GEMMTranspose1xWValidationFixture<Tensor, Accessor, NEGEMMTranspose1xW, uint8_t>; +FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMTranspose1xWFixture, framework::DatasetMode::PRECOMMIT, data_transpose * framework::dataset::make("DataType", DataType::U8)) { // Validate output validate(Accessor(_target), _reference); } -TEST_SUITE_END() // FP32 +TEST_SUITE_END() // U8 TEST_SUITE_END() // TRANSPOSE_1XW @@ -117,7 +140,7 @@ DATA_TEST_CASE(ValidateZeroPadding, framework::DatasetMode::ALL, zip( framework::dataset::make("K", { 1, 47, 29, 27 })), m_value, k_value) { - bool status = validate_zero_padding(m_value, k_value); + bool status = validate_zero_padding<NEGEMMInterleave4x4Kernel>(m_value, k_value); ARM_COMPUTE_EXPECT(status, framework::LogLevel::ERRORS); } |