From 304dfdba67958f5987d88ad0ce538399c3e50bc8 Mon Sep 17 00:00:00 2001 From: Adnan AlSinan Date: Wed, 21 Sep 2022 13:20:45 +0100 Subject: Fix Batch Matmul nightly failure Resolves COMPMID-5601 Signed-off-by: Adnan AlSinan Change-Id: I1baf92c4751d784d017c0b2f7de1fc09e42ce69c Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/8309 Benchmark: Arm Jenkins Tested-by: Arm Jenkins Reviewed-by: Gian Marco Iodice Comments-Addressed: Arm Jenkins --- .../kernels/gemm_matrix_mul/generic/neon/impl.cpp | 2 +- tests/validation/NEON/GEMM.cpp | 23 +++++++++++----------- 2 files changed, 13 insertions(+), 12 deletions(-) diff --git a/src/cpu/kernels/gemm_matrix_mul/generic/neon/impl.cpp b/src/cpu/kernels/gemm_matrix_mul/generic/neon/impl.cpp index 37c296fe5e..300dc3ffc7 100644 --- a/src/cpu/kernels/gemm_matrix_mul/generic/neon/impl.cpp +++ b/src/cpu/kernels/gemm_matrix_mul/generic/neon/impl.cpp @@ -856,7 +856,7 @@ void matrix_matrix_multiply_f16(const ITensor *lhs, const ITensor *rhs, ITensor } // Set step_x and step_y for matrix B. Scale by a factor of 8 the X range as the input transposed matrix A has 8 times less the cols of the dst matrix win_b.set(Window::DimX, Window::Dimension(window.x().start() / 8, window.x().end() / 8, in_b_stride)); - win_b.set(Window::DimY, Window::Dimension(0, 1, 0)); + win_b.set(Window::DimY, Window::Dimension(0, 0, 0)); Iterator ina(lhs, win_a); Iterator inb(rhs, win_b); diff --git a/tests/validation/NEON/GEMM.cpp b/tests/validation/NEON/GEMM.cpp index 0e07371281..127979e60b 100644 --- a/tests/validation/NEON/GEMM.cpp +++ b/tests/validation/NEON/GEMM.cpp @@ -351,6 +351,18 @@ FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMFixture, framework::DatasetMode::PR // Validate output validate(Accessor(_target), _reference, rel_tolerance_f16, tolerance_num, abs_tolerance_f16); } + +TEST_SUITE(BATCHED_MATMUL) + +FIXTURE_DATA_TEST_CASE(RunSmall, NEBatchedMatMulFixture, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallBatchedMatMulDataset(), + framework::dataset::make("ReshapeWeights", { false })), + framework::dataset::make("DataType", DataType::F16))) +{ + // Validate output + validate(Accessor(_target), _reference, rel_tolerance_f16, tolerance_num, abs_tolerance_f16); +} +TEST_SUITE_END() + FIXTURE_DATA_TEST_CASE(RunLarge, NEGEMMFixture, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeGEMMDataset(), framework::dataset::make("ReshapeWeights", { true, false })), @@ -392,17 +404,6 @@ FIXTURE_DATA_TEST_CASE(RunSmall, NEBatchedMatMulFixture, framework::Datas } TEST_SUITE_END() -#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC -TEST_SUITE(FP16) -FIXTURE_DATA_TEST_CASE(RunSmall, NEBatchedMatMulFixture, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallBatchedMatMulDataset(), - framework::dataset::make("ReshapeWeights", { false })), - framework::dataset::make("DataType", DataType::F16))) -{ - // Validate output - validate(Accessor(_target), _reference, rel_tolerance_f16, tolerance_num, abs_tolerance_f16); -} -TEST_SUITE_END() -#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ TEST_SUITE_END() TEST_SUITE_END() -- cgit v1.2.1