aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/NEON/GEMMLowp.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'tests/validation/NEON/GEMMLowp.cpp')
-rw-r--r--tests/validation/NEON/GEMMLowp.cpp41
1 files changed, 23 insertions, 18 deletions
diff --git a/tests/validation/NEON/GEMMLowp.cpp b/tests/validation/NEON/GEMMLowp.cpp
index 46058bd148..9c4d1741eb 100644
--- a/tests/validation/NEON/GEMMLowp.cpp
+++ b/tests/validation/NEON/GEMMLowp.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2023 Arm Limited.
+ * Copyright (c) 2017-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -50,9 +50,12 @@ namespace validation
TEST_SUITE(NEON)
TEST_SUITE(GEMMLowp)
TEST_SUITE(MatrixMultiplyCore)
+
using NEGEMMLowpMatrixMultiplyCoreFixture = GEMMLowpMatrixMultiplyCoreValidationFixture<Tensor, Accessor, NEGEMMLowpMatrixMultiplyCore>;
using NEGEMMLowpBatchedMatMulFixture = GEMMLowpMatrixMultiplyCoreValidationFixture<Tensor, Accessor, NEGEMMLowpMatrixMultiplyCore, false, false, true>;
+using framework::dataset::make;
+
DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, framework::dataset::concat(datasets::SmallGEMMLowpDataset(), datasets::LargeGEMMLowpDataset()),
shape_a, shape_b, shape_c, a_offset, b_offset)
{
@@ -80,26 +83,26 @@ DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, framework::dataset::c
// *INDENT-OFF*
// clang-format off
-DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(
- framework::dataset::make("InputAInfo", { TensorInfo(TensorShape(21U, 13U), 1, DataType::QASYMM8, QuantizationInfo(1.f/255, 10)), // Input not a multiple of 4
+DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(
+ make("InputAInfo", { TensorInfo(TensorShape(21U, 13U), 1, DataType::QASYMM8, QuantizationInfo(1.f/255, 10)), // Input not a multiple of 4
TensorInfo(TensorShape(21U, 13U), 1, DataType::S32), // Mismatching data type
TensorInfo(TensorShape(20U, 13U), 1, DataType::QASYMM8, QuantizationInfo(1.f/255, 10)), // Invalid dimensions
TensorInfo(TensorShape(21U, 13U), 1, DataType::QASYMM8, QuantizationInfo(1.f/255, 10)), // Invalid dimensions
TensorInfo(TensorShape(16U, 32U), 1, DataType::QASYMM8, QuantizationInfo(1.f/255, 10)),
}),
- framework::dataset::make("InputBInfo",{ TensorInfo(TensorShape(33U, 21U), 1, DataType::QASYMM8, QuantizationInfo(1.f/256, 10)),
+ make("InputBInfo",{ TensorInfo(TensorShape(33U, 21U), 1, DataType::QASYMM8, QuantizationInfo(1.f/256, 10)),
TensorInfo(TensorShape(33U, 21U), 1, DataType::QASYMM8, QuantizationInfo(1.f/256, 10)),
TensorInfo(TensorShape(33U, 21U), 1, DataType::QASYMM8, QuantizationInfo(1.f/256, 10)),
TensorInfo(TensorShape(33U, 21U), 1, DataType::QASYMM8, QuantizationInfo(1.f/256, 10)),
TensorInfo(TensorShape(64U, 16U), 1, DataType::QASYMM8, QuantizationInfo(1.f/256, 10)),
- })),
- framework::dataset::make("OutputInfo",{ TensorInfo(TensorShape(33U, 13U), 1, DataType::S32),
+ }),
+ make("OutputInfo",{ TensorInfo(TensorShape(33U, 13U), 1, DataType::S32),
TensorInfo(TensorShape(33U, 13U), 1, DataType::S32),
TensorInfo(TensorShape(33U, 13U), 1, DataType::S32),
TensorInfo(TensorShape(8U, 11U), 1, DataType::S32),
TensorInfo(TensorShape(64U, 32U), 1, DataType::S32),
- })),
- framework::dataset::make("Expected", { true, false, false, false, true })),
+ }),
+ make("Expected", { true, false, false, false, true })),
a_info, b_info, output_info, expected)
{
// Lock tensors
@@ -231,9 +234,9 @@ using NEGEMMLowpMatrixMultiplyCoreFusedOffsetOutputFixtureBatchedUnsigned =
TEST_SUITE(BatchedMatMul)
TEST_SUITE(QASYMM8)
FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMLowpMatrixMultiplyCoreFusedOffsetOutputFixtureBatchedUnsigned, framework::DatasetMode::ALL,
- combine(combine(datasets::SmallGEMMLowpFusedBatchedMatMulDatasetUnsigned(),
- framework::dataset::make("DataType", { DataType::QASYMM8 })),
- framework::dataset::make("bool", { false })))
+ combine(datasets::SmallGEMMLowpFusedBatchedMatMulDataset(),
+ make("DataType", { DataType::QASYMM8 }),
+ make("reshape_b_only_on_first_run", { false })))
{
validate(Accessor(_target), _reference, tolerance_batched);
}
@@ -243,9 +246,9 @@ using NEGEMMLowpMatrixMultiplyCoreFusedOffsetOutputFixtureBatchedSigned =
GEMMLowpMatrixMultiplyCoreFusedOffsetOutputGenericValidationFixture<Tensor, Accessor, NEGEMMLowpMatrixMultiplyCore, false, false, int8_t, int8_t, true>;
TEST_SUITE(QASYMM8_SIGNED)
FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMLowpMatrixMultiplyCoreFusedOffsetOutputFixtureBatchedSigned, framework::DatasetMode::ALL,
- combine(combine(datasets::SmallGEMMLowpFusedBatchedMatMulDatasetSigned(),
- framework::dataset::make("DataType", { DataType::QASYMM8_SIGNED })),
- framework::dataset::make("bool", { false })))
+ combine(datasets::SmallGEMMLowpFusedBatchedMatMulDataset(),
+ make("DataType", { DataType::QASYMM8_SIGNED }),
+ make("reshape_b_only_on_first_run", { false })))
{
validate(Accessor(_target), _reference, tolerance_batched);
}
@@ -256,15 +259,17 @@ using NEGEMMLowpMatrixMultiplyCoreFusedOffsetOutputFixture = GEMMLowpMatrixMulti
constexpr AbsoluteTolerance<float> tolerance_quant(1);
TEST_SUITE(FusedOffsetOutput)
-FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMLowpMatrixMultiplyCoreFusedOffsetOutputFixture, framework::DatasetMode::ALL, combine(datasets::SmallGEMMLowpFusedOffsetOutputUint8Dataset(),
- framework::dataset::make("DataType", { DataType::QASYMM8 })))
+FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMLowpMatrixMultiplyCoreFusedOffsetOutputFixture, framework::DatasetMode::ALL,
+ combine(datasets::SmallGEMMLowpFusedOffsetOutputUint8Dataset(),
+ make("DataType", { DataType::QASYMM8 })))
{
// Validate output
validate(Accessor(_target), _reference, tolerance_quant);
}
-FIXTURE_DATA_TEST_CASE(RunLarge, NEGEMMLowpMatrixMultiplyCoreFusedOffsetOutputFixture, framework::DatasetMode::NIGHTLY, combine(datasets::LargeGEMMLowpFusedOffsetOutputUint8Dataset(),
- framework::dataset::make("DataType", { DataType::QASYMM8 })))
+FIXTURE_DATA_TEST_CASE(RunLarge, NEGEMMLowpMatrixMultiplyCoreFusedOffsetOutputFixture, framework::DatasetMode::NIGHTLY,
+ combine(datasets::LargeGEMMLowpFusedOffsetOutputUint8Dataset(),
+ make("DataType", { DataType::QASYMM8 })))
{
// Validate output
validate(Accessor(_target), _reference, tolerance_quant);