From 6782452c16a286a6dd4a071cfc70bbbcbabb20be Mon Sep 17 00:00:00 2001 From: Mohammed Suhail Munshi Date: Thu, 29 Sep 2022 13:07:21 +0100 Subject: Add test in GEMMLowp for batch matmul - Adds tests for batched matrix multiplication - Bugfix for issue : 3d tensors input tensors with offsets in GemmLowp results in mismatches Resolves : [COMPMID-5507] Signed-off-by: Mohammed Suhail Munshi Change-Id: I68e036fbca642c1841dd4321033045aadc8f5636 Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/c/VisualCompute/ComputeLibrary/+/461298 Comments-Addressed: bsgcomp Tested-by: bsgcomp Reviewed-by: Gunes Bayir Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/8482 Tested-by: Arm Jenkins Reviewed-by: Viet-Hoa Do Comments-Addressed: Arm Jenkins Benchmark: Arm Jenkins --- .../kernels/CpuGemmLowpOffsetContributionKernel.cpp | 20 ++++++++++++-------- tests/validation/NEON/GEMMLowp.cpp | 10 +++++++++- 2 files changed, 21 insertions(+), 9 deletions(-) diff --git a/src/cpu/kernels/CpuGemmLowpOffsetContributionKernel.cpp b/src/cpu/kernels/CpuGemmLowpOffsetContributionKernel.cpp index a9896772f6..a65f1a33de 100644 --- a/src/cpu/kernels/CpuGemmLowpOffsetContributionKernel.cpp +++ b/src/cpu/kernels/CpuGemmLowpOffsetContributionKernel.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2021 Arm Limited. + * Copyright (c) 2017-2022 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -108,7 +108,9 @@ void run_offset_contribution(const Window &window, const int window_end_x = window.x().end(); const int window_step_x = 16; - Iterator mm_result_it(mm_result, collapsed_window); + // if vector_sum_col is nullptr then stride_y is 0, else get stride_y + const size_t sum_col_stride_y = (vector_sum_col != nullptr) ? (vector_sum_col->info()->strides_in_bytes().y()) : 0; + Iterator mm_result_it(mm_result, collapsed_window); if((a_offset != 0) && (b_offset != 0) && (vector_sum_col != nullptr) && (vector_sum_row != nullptr)) // true, true { @@ -133,9 +135,10 @@ void run_offset_contribution(const Window &window, execute_window_loop(collapsed_window, [&](const Coordinates & id) { - const int batch_id = id.z() / depth_input; - auto vector_sum_col_ptr = reinterpret_cast(vector_sum_col_it.ptr() + batch_id * vector_sum_col_batch_offset); - auto mm_result_ptr = reinterpret_cast(mm_result_it.ptr()); + const int batch_id = id.z() / depth_input; + const size_t batch_offset_col = batch_id * (sum_col_stride_y ); + auto vector_sum_col_ptr = reinterpret_cast(vector_sum_col_it.ptr() + batch_offset_col + batch_id * vector_sum_col_batch_offset); + auto mm_result_ptr = reinterpret_cast(mm_result_it.ptr()); // Compute the leftover term due to b_offset. int32_t b_offset_term_s32 = *(reinterpret_cast(vector_sum_row_it.ptr() + batch_id * sum_row_stride_y) + id.y() + (id.z() % depth_input) * height_input); @@ -291,9 +294,10 @@ void run_offset_contribution(const Window &window, execute_window_loop(collapsed_window, [&](const Coordinates & id) { - const int batch_id = id.z() / depth_input; - auto vector_sum_col_ptr = reinterpret_cast(vector_sum_col_it.ptr() + batch_id * vector_sum_col_batch_offset); - auto mm_result_ptr = reinterpret_cast(mm_result_it.ptr()); + const int batch_id = id.z() / depth_input; + const size_t batch_offset_col = batch_id * (sum_col_stride_y ); // Value to offset vector_sum_col_ptr to allow for iteration of y values in tensor + auto vector_sum_col_ptr = reinterpret_cast(vector_sum_col_it.ptr() + batch_offset_col + batch_id * vector_sum_col_batch_offset); + auto mm_result_ptr = reinterpret_cast(mm_result_it.ptr()); int x = window_start_x; for(; x <= (window_end_x - window_step_x); x += window_step_x) diff --git a/tests/validation/NEON/GEMMLowp.cpp b/tests/validation/NEON/GEMMLowp.cpp index 7dd1a479fe..2dcc740b97 100644 --- a/tests/validation/NEON/GEMMLowp.cpp +++ b/tests/validation/NEON/GEMMLowp.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2021 Arm Limited. + * Copyright (c) 2017-2022 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -51,6 +51,7 @@ TEST_SUITE(NEON) TEST_SUITE(GEMMLowp) TEST_SUITE(MatrixMultiplyCore) using NEGEMMLowpMatrixMultiplyCoreFixture = GEMMLowpMatrixMultiplyCoreValidationFixture; +using NEGEMMLowpBatchedMatMulFixture = GEMMLowpMatrixMultiplyCoreValidationFixture; DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, framework::dataset::concat(datasets::SmallGEMMLowpDataset(), datasets::LargeGEMMLowpDataset()), shape_a, shape_b, shape_c, a_offset, b_offset) @@ -210,6 +211,13 @@ TEST_CASE(MultipleExecutionWithConfigure, framework::DatasetMode::ALL) } } +TEST_SUITE(BatchedMatMul) +FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMLowpBatchedMatMulFixture, framework::DatasetMode::ALL, datasets::SmallGEMMLowpBatchedMatMulDataset()) +{ + validate(Accessor(_target), _reference); +} +TEST_SUITE_END() // BatchedMatMul + FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMLowpMatrixMultiplyCoreFixture, framework::DatasetMode::ALL, datasets::SmallGEMMLowpDataset()) { // Validate output -- cgit v1.2.1