aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/CL/MatMulKernel.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'tests/validation/CL/MatMulKernel.cpp')
-rw-r--r--tests/validation/CL/MatMulKernel.cpp49
1 files changed, 35 insertions, 14 deletions
diff --git a/tests/validation/CL/MatMulKernel.cpp b/tests/validation/CL/MatMulKernel.cpp
index ff872aaa0a..b47f8bc924 100644
--- a/tests/validation/CL/MatMulKernel.cpp
+++ b/tests/validation/CL/MatMulKernel.cpp
@@ -75,6 +75,9 @@ const auto k0_values_nightly_lhs_t_rhs_nt = framework::dataset::make("K0", { 1,
template <typename T>
using CLMatMulKernelFixture = MatMulKernelValidationFixture<T, ClMatMulNativeKernel>;
+template <typename T>
+using CLMatMulKernelBiasFixture = MatMulKernelWithBiasValidation<T, ClMatMulNativeKernel>;
+
TEST_SUITE(CL)
TEST_SUITE(MatMulKernel)
TEST_SUITE(Validate)
@@ -162,7 +165,7 @@ TEST_CASE(SupportedBlockSizes, framework::DatasetMode::ALL)
for(auto &pair : supported_block_sizes)
{
TensorInfo output_info;
- Status status = ClMatMulNativeKernel::validate(&lhs_info, &rhs_info, &output_info, pair.first);
+ Status status = ClMatMulNativeKernel::validate(&lhs_info, &rhs_info, nullptr, &output_info, pair.first);
if(!pair.first.export_rhs_to_cl_image || export_to_cl_image_supported)
{
@@ -222,7 +225,7 @@ TEST_CASE(ExportToCLImage, framework::DatasetMode::ALL)
};
TensorInfo output_info;
- Status status = ClMatMulNativeKernel::validate(&lhs_info, &rhs_info, &output_info, matmul_kernel_info);
+ Status status = ClMatMulNativeKernel::validate(&lhs_info, &rhs_info, nullptr, &output_info, matmul_kernel_info);
const bool expected = std::get<4>(tuple);
ARM_COMPUTE_EXPECT(bool(status) == expected, framework::LogLevel::ERRORS);
@@ -233,22 +236,25 @@ TEST_CASE(ExportToCLImage, framework::DatasetMode::ALL)
TEST_CASE(ValidateInputShapes, framework::DatasetMode::ALL)
{
// Configurations are assumed to be Nt/Nt, but will be transposed inside the test to test other configurations
- using ShapeConfigurationTuple = std::tuple<TensorShape, TensorShape, bool>;
+ using ShapeConfigurationTuple = std::tuple<TensorShape, TensorShape, TensorShape, bool>;
const std::vector<ShapeConfigurationTuple> shape_configurations =
{
- { TensorShape(5U, 1U), TensorShape(3U, 5U), true },
- { TensorShape(10U, 12U), TensorShape(3U, 10U), true },
- { TensorShape(8U, 4U), TensorShape(2U, 8U), true },
- { TensorShape(8U, 4U), TensorShape(2U, 5U), false }, // Mismatch in the K dimension
- { TensorShape(5U, 0U), TensorShape(2U, 5U), false }, // Invalid dimension
- { TensorShape(5U, 4U, 3U, 4U, 5U, 6U), TensorShape(2U, 5U, 3U, 4U, 5U, 6U), true },
- { TensorShape(5U, 4U, 3U, 4U, 5U, 1U), TensorShape(2U, 5U, 3U, 4U, 5U, 6U), false }, // no batch broadcasting
- { TensorShape(5U, 4U, 3U, 4U, 9U, 6U), TensorShape(2U, 5U, 3U, 4U, 5U, 6U), false }, // mismatch in batch dimension
+ { TensorShape(5U, 1U), TensorShape(3U, 5U), TensorShape(3U), true },
+ { TensorShape(10U, 12U), TensorShape(3U, 10U), TensorShape(3U), true },
+ { TensorShape(8U, 4U), TensorShape(2U, 8U), TensorShape(2U), true },
+ { TensorShape(8U, 4U), TensorShape(2U, 5U), TensorShape(2U), false }, // Mismatch in the K dimension
+ { TensorShape(5U, 0U), TensorShape(2U, 5U), TensorShape(2U), false }, // Invalid dimension
+ { TensorShape(5U, 4U, 3U, 4U, 5U, 6U), TensorShape(2U, 5U, 3U, 4U, 5U, 6U), TensorShape(2U), true },
+ { TensorShape(5U, 4U, 3U, 4U, 5U, 1U), TensorShape(2U, 5U, 3U, 4U, 5U, 6U), TensorShape(2U), false }, // no batch broadcasting
+ { TensorShape(5U, 4U, 3U, 4U, 9U, 6U), TensorShape(2U, 5U, 3U, 4U, 5U, 6U), TensorShape(2U), false }, // mismatch in batch dimension
+ { TensorShape(5U, 1U), TensorShape(3U, 5U), TensorShape(1U), false }, // Unsupported bias broadcasting.
+ { TensorShape(5U, 1U), TensorShape(3U, 5U), TensorShape(3U, 3U), false }, // 2D bias is unsupported.
+ { TensorShape(5U, 1U), TensorShape(3U, 5U), TensorShape(6U), false }, // bias first dimension != dst first dimension
};
for(auto &tuple : shape_configurations)
{
- const bool expected = std::get<2>(tuple);
+ const bool expected = std::get<3>(tuple);
for(bool adj_lhs :
{
@@ -262,6 +268,7 @@ TEST_CASE(ValidateInputShapes, framework::DatasetMode::ALL)
{
TensorShape lhs_shape = std::get<0>(tuple);
TensorShape rhs_shape = std::get<1>(tuple);
+ TensorShape bia_shape = std::get<2>(tuple);
if(adj_lhs)
{
@@ -275,11 +282,12 @@ TEST_CASE(ValidateInputShapes, framework::DatasetMode::ALL)
const TensorInfo lhs_info = TensorInfo(lhs_shape, 1, DataType::F32);
const TensorInfo rhs_info = TensorInfo(rhs_shape, 1, DataType::F32);
+ const TensorInfo bia_info = TensorInfo(bia_shape, 1, DataType::F32);
TensorInfo output_info;
MatMulKernelInfo matmul_kernel_info{ adj_lhs, adj_rhs, 1, 1, 1, false /* export_rhs_to_cl_image */ };
- Status status = ClMatMulNativeKernel::validate(&lhs_info, &rhs_info, &output_info, matmul_kernel_info);
+ Status status = ClMatMulNativeKernel::validate(&lhs_info, &rhs_info, &bia_info, &output_info, matmul_kernel_info);
ARM_COMPUTE_EXPECT(bool(status) == expected, framework::LogLevel::ERRORS);
}
}
@@ -322,7 +330,7 @@ TEST_CASE(ValidateDataTypes, framework::DatasetMode::ALL)
const TensorInfo rhs_info(shape, 1, std::get<1>(tuple));
TensorInfo output_info(shape, 1, std::get<2>(tuple));
- Status status = ClMatMulNativeKernel::validate(&lhs_info, &rhs_info, &output_info, matmul_kernel_info);
+ Status status = ClMatMulNativeKernel::validate(&lhs_info, &rhs_info, nullptr, &output_info, matmul_kernel_info);
ARM_COMPUTE_EXPECT(bool(status) == expected, framework::LogLevel::ERRORS);
}
}
@@ -356,6 +364,19 @@ FIXTURE_DATA_TEST_CASE(RunSmall, CLMatMulKernelFixture<float>, framework::Datase
// Validate output
validate(CLAccessor(_target), _reference, tolerance_f32, 0.f, abs_tolerance_f32);
}
+FIXTURE_DATA_TEST_CASE(RunWithBias, CLMatMulKernelBiasFixture<float>, framework::DatasetMode::ALL, combine(combine(combine(combine(combine(combine(combine(datasets::SmallMatMulDataset(),
+ framework::dataset::make("TransposeA", { false, true })),
+ framework::dataset::make("TransposeB", { false, true })),
+ m0_values_precommit),
+ n0_values_precommit),
+ k0_values_precommit),
+ framework::dataset::make("ExportRhsToCLImage", { false })),
+ framework::dataset::make("DataType", DataType::F32)))
+
+{
+ // Validate output
+ validate(CLAccessor(_target), _reference, tolerance_f32, 0.f, abs_tolerance_f32);
+}
FIXTURE_DATA_TEST_CASE(RunLargeNoTranspose, CLMatMulKernelFixture<float>, framework::DatasetMode::NIGHTLY, combine(combine(combine(combine(combine(combine(combine(datasets::LargeMatMulDataset(),
framework::dataset::make("TransposeA", { false })),
framework::dataset::make("TransposeB", { false })),