author    Adnan AlSinan <adnan.alsinan@arm.com>  2022-09-07 13:54:53 +0100
committer Adnan AlSinan <adnan.alsinan@arm.com>  2022-09-12 13:12:38 +0000
commit    26c9d1a787bccbe0b6c6749b80e2f5030395bda6 (patch)
tree      a3ec80da0862e616e1163ff92e8587c290848a2e
parent    4478e1cb2d7be9190147be597c3cfbf4c6f99f09 (diff)
download  ComputeLibrary-26c9d1a787bccbe0b6c6749b80e2f5030395bda6.tar.gz
Add test for NEGEMM to exercise a batched matrix multiplication with variable input tensors

- Add a CPU test for a batched matrix multiplication with variable input tensors
- Disable the assembly kernel when the _reshape_b_only_on_first_run flag is not set

Resolves COMPMID-5501

Signed-off-by: Adnan AlSinan <adnan.alsinan@arm.com>
Change-Id: If96b182584617806a9dfe597dbfaf05241b123c2
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/8234
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Benchmark: Arm Jenkins <bsgcomp@arm.com>
-rw-r--r--  src/cpu/operators/CpuGemm.cpp            |  2
-rw-r--r--  tests/datasets/SmallGEMMDataset.h        |  9
-rw-r--r--  tests/validation/NEON/GEMM.cpp           | 24
-rw-r--r--  tests/validation/fixtures/GEMMFixture.h  |  5
4 files changed, 17 insertions(+), 23 deletions(-)
diff --git a/src/cpu/operators/CpuGemm.cpp b/src/cpu/operators/CpuGemm.cpp
index f6582c73f8..a17e4f31d5 100644
--- a/src/cpu/operators/CpuGemm.cpp
+++ b/src/cpu/operators/CpuGemm.cpp
@@ -65,7 +65,7 @@ void CpuGemm::configure(const ITensorInfo *a, const ITensorInfo *b, const ITenso
const cpu::AsmGemmInfo asm_info = init_assembly_metadata(gemm_info);
const bool is_c_bias = gemm_info.reshape_b_only_on_first_run();
- bool run_optimised = bool(cpu::CpuGemmAssemblyDispatch::validate(a, b, (is_c_bias) ? c : nullptr, d, asm_info));
+ bool run_optimised = bool(cpu::CpuGemmAssemblyDispatch::validate(a, b, (is_c_bias) ? c : nullptr, d, asm_info)) && gemm_info.reshape_b_only_on_first_run();
// Check if we need to reshape the matrix B only on the first run
_is_prepared = false;
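Note: the one-line change above gates the optimised path on the reshape-B flag. A minimal standalone sketch of that selection rule (the helper name below is hypothetical, not ACL code):

    // Illustrative only: the assembly dispatch is taken when validation succeeds
    // AND B is reshaped only on the first run, i.e. B is effectively constant.
    static bool select_assembly_path(bool asm_validate_ok, bool reshape_b_only_on_first_run)
    {
        // With variable B tensors (flag == false) the assembly kernel is skipped
        // and the generic CpuGemm path performs the batched matrix multiplication.
        return asm_validate_ok && reshape_b_only_on_first_run;
    }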
diff --git a/tests/datasets/SmallGEMMDataset.h b/tests/datasets/SmallGEMMDataset.h
index 811064bb67..fabddb2ca0 100644
--- a/tests/datasets/SmallGEMMDataset.h
+++ b/tests/datasets/SmallGEMMDataset.h
@@ -51,15 +51,6 @@ public:
}
};
-class SmallBatchedGEMMDataset final : public GEMMDataset
-{
-public:
- SmallBatchedGEMMDataset()
- {
- add_config(TensorShape(2U, 4U, 1U, 3U), TensorShape(5U, 2U, 3U), TensorShape(5U), TensorShape(5U, 4U, 1U, 3U), 1.0f, 0.0f);
- }
-};
-
class SmallGEMMOutput3DDataset final : public GEMMDataset
{
public:
diff --git a/tests/validation/NEON/GEMM.cpp b/tests/validation/NEON/GEMM.cpp
index 5f6e75b705..0e07371281 100644
--- a/tests/validation/NEON/GEMM.cpp
+++ b/tests/validation/NEON/GEMM.cpp
@@ -319,7 +319,7 @@ template <typename T>
using NEGEMMFixture = GEMMValidationFixture<Tensor, Accessor, NEGEMM, T>;
template <typename T>
-using NEGEMMFixtureDisabledC = GEMMValidationFixture<Tensor, Accessor, NEGEMM, T, true>;
+using NEBatchedMatMulFixture = GEMMValidationFixture<Tensor, Accessor, NEGEMM, T, true, false, false, false, false, true>;
TEST_SUITE(Float)
DATA_TEST_CASE(ValidateZeroPadding, framework::DatasetMode::ALL, zip(framework::dataset::make("In0", { TensorShape(21U, 13U),
@@ -379,10 +379,12 @@ FIXTURE_DATA_TEST_CASE(RunLarge, NEGEMMFixture<float>, framework::DatasetMode::N
// Validate output
validate(Accessor(_target), _reference, tolerance_f);
}
-TEST_SUITE(DisabledC)
-FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMFixtureDisabledC<float>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallGEMMDataset(),
- framework::dataset::make("ReshapeWeights", { true, false })),
+TEST_SUITE(BATCHED_MATMUL)
+
+TEST_SUITE(FP32)
+FIXTURE_DATA_TEST_CASE(RunSmall, NEBatchedMatMulFixture<float>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallBatchedMatMulDataset(),
+ framework::dataset::make("ReshapeWeights", { false })),
framework::dataset::make("DataType", DataType::F32)))
{
// Validate output
@@ -390,16 +392,18 @@ FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMFixtureDisabledC<float>, framework::Datas
}
TEST_SUITE_END()
-TEST_SUITE(BatchedGEMMDisabledC)
-FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMFixtureDisabledC<float>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallBatchedGEMMDataset(),
- framework::dataset::make("ReshapeWeights", { true, false })),
-
- framework::dataset::make("DataType", DataType::F32)))
+#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+TEST_SUITE(FP16)
+FIXTURE_DATA_TEST_CASE(RunSmall, NEBatchedMatMulFixture<half>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallBatchedMatMulDataset(),
+ framework::dataset::make("ReshapeWeights", { false })),
+ framework::dataset::make("DataType", DataType::F16)))
{
// Validate output
- validate(Accessor(_target), _reference, tolerance_f);
+ validate(Accessor(_target), _reference, rel_tolerance_f16, tolerance_num, abs_tolerance_f16);
}
TEST_SUITE_END()
+#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
+TEST_SUITE_END()
TEST_SUITE_END()
TEST_SUITE_END()
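The new BATCHED_MATMUL suites drive NEGEMM with ReshapeWeights = false, i.e. B may change between runs. A rough usage sketch along those lines, reusing the shapes from the removed SmallBatchedGEMMDataset entry and standard arm_compute runtime calls (illustrative only, not the fixture code; the GEMMInfo arguments shown are an assumption about its positional constructor):

    #include "arm_compute/core/TensorInfo.h"
    #include "arm_compute/core/Types.h"
    #include "arm_compute/runtime/NEON/functions/NEGEMM.h"
    #include "arm_compute/runtime/Tensor.h"

    using namespace arm_compute;

    int main()
    {
        // Batched shapes in ACL order (width, height, [depth,] batches).
        Tensor a, b, d;
        a.allocator()->init(TensorInfo(TensorShape(2U, 4U, 1U, 3U), 1, DataType::F32));
        b.allocator()->init(TensorInfo(TensorShape(5U, 2U, 3U), 1, DataType::F32));
        d.allocator()->init(TensorInfo(TensorShape(5U, 4U, 1U, 3U), 1, DataType::F32));

        // reshape_b_only_on_first_run = false: B is variable between runs, so the
        // assembly fast path is skipped (see the CpuGemm.cpp change above).
        GEMMInfo info(false /* is_a_reshaped */, false /* is_b_reshaped */,
                      false /* reshape_b_only_on_first_run */);

        NEGEMM gemm;
        gemm.configure(&a, &b, nullptr, &d, 1.0f, 0.0f, info);

        a.allocator()->allocate();
        b.allocator()->allocate();
        d.allocator()->allocate();

        // ... fill a and b, then run; refill b with new values and run again.
        gemm.run();
        gemm.run();
        return 0;
    }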
diff --git a/tests/validation/fixtures/GEMMFixture.h b/tests/validation/fixtures/GEMMFixture.h
index 0682337c82..5dc2711753 100644
--- a/tests/validation/fixtures/GEMMFixture.h
+++ b/tests/validation/fixtures/GEMMFixture.h
@@ -152,7 +152,6 @@ protected:
DataType data_type)
{
TensorShape shape_a_to_use = shape_a;
-
if(reinterpret_input_as_3d)
{
// Collapse the second and third dimension if the input is 3D
@@ -213,13 +212,13 @@ protected:
reference::gemm<T>((pretranspose_a) ? a_transposed : a, (pretranspose_b) ? b_transposed : b, c, alpha, disable_c ? 0.f : beta);
fill((pretranspose_a) ? a_transposed : a, 3);
fill((pretranspose_b) ? b_transposed : b, 4);
- fill(c , 5);
+ fill(c, 5);
}
// Setting beta to 0 will effectively disable C for the
// computation of the reference: alpha * A * B + 0 * C
// Use transposed tensors if boolean enabled else use original tensors
- auto r = reference::gemm<T>((pretranspose_a) ? a_transposed : a, (pretranspose_b) ? b_transposed : b, c, alpha, disable_c ? 0.f : beta);
+ auto r = reference::gemm<T>((pretranspose_a) ? a_transposed : a, (pretranspose_b) ? b_transposed : b, c, alpha, disable_c ? 0.f : beta);
return r;
}