author    Mohammed Suhail Munshi <MohammedSuhail.Munshi@arm.com>  2023-06-27 14:25:58 +0100
committer Mohmun02 <MohammedSuhail.Munshi@arm.com>                2023-07-11 08:53:19 +0000
commit    8e2dedea8550b1c18c3bbeead8c972f661dcfac8 (patch)
tree      61cd0326b9690e343d62a5c72d935fcd68017eb9 /tests
parent    5ff480265a110ea1f2ce24491e082f52348b0f92 (diff)
Add Bias to MatMul Kernels and add support for use in Fully Connected Layer
Resolves: [COMPMID-6316]

Signed-off-by: Mohammed Suhail Munshi <MohammedSuhail.Munshi@arm.com>
Change-Id: I08e6bac9e6b46b76978da0dc6a48ccfe3dde5086
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/9833
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Gunes Bayir <gunes.bayir@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Benchmark: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'tests')
-rw-r--r--  tests/validation/CL/MatMulKernel.cpp             |  49
-rw-r--r--  tests/validation/CL/MatMulLowpNativeKernel.cpp   |  92
-rw-r--r--  tests/validation/CL/MatMulNativeMMULKernel.cpp   | 154
-rw-r--r--  tests/validation/fixtures/MatMulKernelFixture.h  | 103
4 files changed, 279 insertions, 119 deletions
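The change threads an optional bias through each kernel's validate()/configure() entry points: a third TensorInfo pointer sits between the rhs and output arguments, and passing nullptr keeps the old biasless behaviour. A minimal sketch of a caller against the new signature (header locations and the build setup are assumptions; shapes and the brace-init of MatMulKernelInfo come from the test vectors below):

    // Sketch only: header paths assumed from the ComputeLibrary tree.
    #include "arm_compute/core/KernelDescriptors.h" // MatMulKernelInfo (assumed location)
    #include "arm_compute/core/TensorInfo.h"
    #include "src/gpu/cl/kernels/ClMatMulNativeKernel.h"

    using namespace arm_compute;

    bool matmul_with_bias_is_valid()
    {
        // lhs [K=8, M=4], rhs [N=2, K=8], bias: 1D, length == dst dimension 0 (N)
        const TensorInfo lhs(TensorShape(8U, 4U), 1, DataType::F32);
        const TensorInfo rhs(TensorShape(2U, 8U), 1, DataType::F32);
        const TensorInfo bia(TensorShape(2U), 1, DataType::F32);
        TensorInfo       dst{};

        const MatMulKernelInfo info{ false /*adj_lhs*/, false /*adj_rhs*/, 1, 1, 1, false /*export_rhs_to_cl_image*/ };

        // Pass nullptr in place of &bia to validate the biasless path.
        const Status status = opencl::kernels::ClMatMulNativeKernel::validate(&lhs, &rhs, &bia, &dst, info);
        return bool(status);
    }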
diff --git a/tests/validation/CL/MatMulKernel.cpp b/tests/validation/CL/MatMulKernel.cpp
index ff872aaa0a..b47f8bc924 100644
--- a/tests/validation/CL/MatMulKernel.cpp
+++ b/tests/validation/CL/MatMulKernel.cpp
@@ -75,6 +75,9 @@ const auto k0_values_nightly_lhs_t_rhs_nt = framework::dataset::make("K0", { 1,
template <typename T>
using CLMatMulKernelFixture = MatMulKernelValidationFixture<T, ClMatMulNativeKernel>;
+template <typename T>
+using CLMatMulKernelBiasFixture = MatMulKernelWithBiasValidation<T, ClMatMulNativeKernel>;
+
TEST_SUITE(CL)
TEST_SUITE(MatMulKernel)
TEST_SUITE(Validate)
@@ -162,7 +165,7 @@ TEST_CASE(SupportedBlockSizes, framework::DatasetMode::ALL)
for(auto &pair : supported_block_sizes)
{
TensorInfo output_info;
- Status status = ClMatMulNativeKernel::validate(&lhs_info, &rhs_info, &output_info, pair.first);
+ Status status = ClMatMulNativeKernel::validate(&lhs_info, &rhs_info, nullptr, &output_info, pair.first);
if(!pair.first.export_rhs_to_cl_image || export_to_cl_image_supported)
{
@@ -222,7 +225,7 @@ TEST_CASE(ExportToCLImage, framework::DatasetMode::ALL)
};
TensorInfo output_info;
- Status status = ClMatMulNativeKernel::validate(&lhs_info, &rhs_info, &output_info, matmul_kernel_info);
+ Status status = ClMatMulNativeKernel::validate(&lhs_info, &rhs_info, nullptr, &output_info, matmul_kernel_info);
const bool expected = std::get<4>(tuple);
ARM_COMPUTE_EXPECT(bool(status) == expected, framework::LogLevel::ERRORS);
@@ -233,22 +236,25 @@ TEST_CASE(ExportToCLImage, framework::DatasetMode::ALL)
TEST_CASE(ValidateInputShapes, framework::DatasetMode::ALL)
{
// Configurations are assumed to be Nt/Nt, but will be transposed inside the test to test other configurations
- using ShapeConfigurationTuple = std::tuple<TensorShape, TensorShape, bool>;
+ using ShapeConfigurationTuple = std::tuple<TensorShape, TensorShape, TensorShape, bool>;
const std::vector<ShapeConfigurationTuple> shape_configurations =
{
- { TensorShape(5U, 1U), TensorShape(3U, 5U), true },
- { TensorShape(10U, 12U), TensorShape(3U, 10U), true },
- { TensorShape(8U, 4U), TensorShape(2U, 8U), true },
- { TensorShape(8U, 4U), TensorShape(2U, 5U), false }, // Mismatch in the K dimension
- { TensorShape(5U, 0U), TensorShape(2U, 5U), false }, // Invalid dimension
- { TensorShape(5U, 4U, 3U, 4U, 5U, 6U), TensorShape(2U, 5U, 3U, 4U, 5U, 6U), true },
- { TensorShape(5U, 4U, 3U, 4U, 5U, 1U), TensorShape(2U, 5U, 3U, 4U, 5U, 6U), false }, // no batch broadcasting
- { TensorShape(5U, 4U, 3U, 4U, 9U, 6U), TensorShape(2U, 5U, 3U, 4U, 5U, 6U), false }, // mismatch in batch dimension
+ { TensorShape(5U, 1U), TensorShape(3U, 5U), TensorShape(3U), true },
+ { TensorShape(10U, 12U), TensorShape(3U, 10U), TensorShape(3U), true },
+ { TensorShape(8U, 4U), TensorShape(2U, 8U), TensorShape(2U), true },
+ { TensorShape(8U, 4U), TensorShape(2U, 5U), TensorShape(2U), false }, // Mismatch in the K dimension
+ { TensorShape(5U, 0U), TensorShape(2U, 5U), TensorShape(2U), false }, // Invalid dimension
+ { TensorShape(5U, 4U, 3U, 4U, 5U, 6U), TensorShape(2U, 5U, 3U, 4U, 5U, 6U), TensorShape(2U), true },
+ { TensorShape(5U, 4U, 3U, 4U, 5U, 1U), TensorShape(2U, 5U, 3U, 4U, 5U, 6U), TensorShape(2U), false }, // no batch broadcasting
+ { TensorShape(5U, 4U, 3U, 4U, 9U, 6U), TensorShape(2U, 5U, 3U, 4U, 5U, 6U), TensorShape(2U), false }, // mismatch in batch dimension
+ { TensorShape(5U, 1U), TensorShape(3U, 5U), TensorShape(1U), false }, // Unsupported bias broadcasting.
+ { TensorShape(5U, 1U), TensorShape(3U, 5U), TensorShape(3U, 3U), false }, // 2D bias is unsupported.
+ { TensorShape(5U, 1U), TensorShape(3U, 5U), TensorShape(6U), false }, // bias first dimension != dst first dimension
};
for(auto &tuple : shape_configurations)
{
- const bool expected = std::get<2>(tuple);
+ const bool expected = std::get<3>(tuple);
for(bool adj_lhs :
{
@@ -262,6 +268,7 @@ TEST_CASE(ValidateInputShapes, framework::DatasetMode::ALL)
{
TensorShape lhs_shape = std::get<0>(tuple);
TensorShape rhs_shape = std::get<1>(tuple);
+ TensorShape bia_shape = std::get<2>(tuple);
if(adj_lhs)
{
@@ -275,11 +282,12 @@ TEST_CASE(ValidateInputShapes, framework::DatasetMode::ALL)
const TensorInfo lhs_info = TensorInfo(lhs_shape, 1, DataType::F32);
const TensorInfo rhs_info = TensorInfo(rhs_shape, 1, DataType::F32);
+ const TensorInfo bia_info = TensorInfo(bia_shape, 1, DataType::F32);
TensorInfo output_info;
MatMulKernelInfo matmul_kernel_info{ adj_lhs, adj_rhs, 1, 1, 1, false /* export_rhs_to_cl_image */ };
- Status status = ClMatMulNativeKernel::validate(&lhs_info, &rhs_info, &output_info, matmul_kernel_info);
+ Status status = ClMatMulNativeKernel::validate(&lhs_info, &rhs_info, &bia_info, &output_info, matmul_kernel_info);
ARM_COMPUTE_EXPECT(bool(status) == expected, framework::LogLevel::ERRORS);
}
}
@@ -322,7 +330,7 @@ TEST_CASE(ValidateDataTypes, framework::DatasetMode::ALL)
const TensorInfo rhs_info(shape, 1, std::get<1>(tuple));
TensorInfo output_info(shape, 1, std::get<2>(tuple));
- Status status = ClMatMulNativeKernel::validate(&lhs_info, &rhs_info, &output_info, matmul_kernel_info);
+ Status status = ClMatMulNativeKernel::validate(&lhs_info, &rhs_info, nullptr, &output_info, matmul_kernel_info);
ARM_COMPUTE_EXPECT(bool(status) == expected, framework::LogLevel::ERRORS);
}
}
@@ -356,6 +364,19 @@ FIXTURE_DATA_TEST_CASE(RunSmall, CLMatMulKernelFixture<float>, framework::Datase
// Validate output
validate(CLAccessor(_target), _reference, tolerance_f32, 0.f, abs_tolerance_f32);
}
+FIXTURE_DATA_TEST_CASE(RunWithBias, CLMatMulKernelBiasFixture<float>, framework::DatasetMode::ALL, combine(combine(combine(combine(combine(combine(combine(datasets::SmallMatMulDataset(),
+ framework::dataset::make("TransposeA", { false, true })),
+ framework::dataset::make("TransposeB", { false, true })),
+ m0_values_precommit),
+ n0_values_precommit),
+ k0_values_precommit),
+ framework::dataset::make("ExportRhsToCLImage", { false })),
+ framework::dataset::make("DataType", DataType::F32)))
+{
+ // Validate output
+ validate(CLAccessor(_target), _reference, tolerance_f32, 0.f, abs_tolerance_f32);
+}
FIXTURE_DATA_TEST_CASE(RunLargeNoTranspose, CLMatMulKernelFixture<float>, framework::DatasetMode::NIGHTLY, combine(combine(combine(combine(combine(combine(combine(datasets::LargeMatMulDataset(),
framework::dataset::make("TransposeA", { false })),
framework::dataset::make("TransposeB", { false })),
diff --git a/tests/validation/CL/MatMulLowpNativeKernel.cpp b/tests/validation/CL/MatMulLowpNativeKernel.cpp
index fd7a4cb156..90eee4fb82 100644
--- a/tests/validation/CL/MatMulLowpNativeKernel.cpp
+++ b/tests/validation/CL/MatMulLowpNativeKernel.cpp
@@ -49,6 +49,9 @@ constexpr AbsoluteTolerance<float> tolerance_quant(1); /**< Tolerance value for
template <typename T>
using CLMatMulLowpNativeKernelFixture = MatMulKernelValidationFixture<T, ClMatMulLowpNativeKernel>;
+template <typename T>
+using CLMatMulLowpKernelWithBiasFixture = MatMulKernelWithBiasValidation<T, ClMatMulLowpNativeKernel>;
+
/** M0 values to test --precommit*/
const auto m0_values_precommit = framework::dataset::make("M0", { 1, 3 });
@@ -103,7 +106,7 @@ TEST_CASE(SupportedKernelConfigurations, framework::DatasetMode::ALL)
for(auto &pair : supported_block_sizes)
{
TensorInfo output_info;
- Status status = ClMatMulLowpNativeKernel::validate(&lhs_info, &rhs_info, &output_info, pair.first);
+ Status status = ClMatMulLowpNativeKernel::validate(&lhs_info, &rhs_info, nullptr, &output_info, pair.first);
ARM_COMPUTE_EXPECT(bool(status) == pair.second, framework::LogLevel::ERRORS);
}
@@ -112,22 +115,24 @@ TEST_CASE(SupportedKernelConfigurations, framework::DatasetMode::ALL)
TEST_CASE(ValidateInputShapes, framework::DatasetMode::ALL)
{
// Configurations are assumed to be Nt/Nt, but will be transposed inside the test to test other configurations
- using ShapeConfigurationTuple = std::tuple<TensorShape, TensorShape, bool>;
+ using ShapeConfigurationTuple = std::tuple<TensorShape, TensorShape, TensorShape, bool>;
const std::vector<ShapeConfigurationTuple> shape_configurations =
{
- { TensorShape(5U, 1U), TensorShape(3U, 5U), true },
- { TensorShape(10U, 12U), TensorShape(3U, 10U), true },
- { TensorShape(8U, 4U), TensorShape(2U, 8U), true },
- { TensorShape(8U, 4U), TensorShape(2U, 5U), false }, // Mismatch in the K dimension
- { TensorShape(5U, 0U), TensorShape(2U, 5U), false }, // Invalid dimension
- { TensorShape(5U, 4U, 3U, 4U, 5U, 6U), TensorShape(2U, 5U, 3U, 4U, 5U, 6U), true },
- { TensorShape(5U, 4U, 3U, 4U, 5U, 1U), TensorShape(2U, 5U, 3U, 4U, 5U, 6U), false }, // no batch broadcasting
- { TensorShape(5U, 4U, 3U, 4U, 9U, 6U), TensorShape(2U, 5U, 3U, 4U, 5U, 6U), false }, // mismatch in batch dimension
+ { TensorShape(5U, 1U), TensorShape(3U, 5U), TensorShape(3U), true },
+ { TensorShape(10U, 12U), TensorShape(3U, 10U), TensorShape(3U), true },
+ { TensorShape(8U, 4U), TensorShape(2U, 8U), TensorShape(2U), true },
+ { TensorShape(8U, 4U), TensorShape(2U, 5U), TensorShape(2U), false }, // Mismatch in the K dimension
+ { TensorShape(5U, 0U), TensorShape(2U, 5U), TensorShape(2U), false }, // Invalid dimension
+ { TensorShape(5U, 4U, 3U, 4U, 5U, 6U), TensorShape(2U, 5U, 3U, 4U, 5U, 6U), TensorShape(2U), true },
+ { TensorShape(5U, 4U, 3U, 4U, 5U, 1U), TensorShape(2U, 5U, 3U, 4U, 5U, 6U), TensorShape(2U), false }, // no batch broadcasting
+ { TensorShape(5U, 4U, 3U, 4U, 9U, 6U), TensorShape(2U, 5U, 3U, 4U, 5U, 6U), TensorShape(2U), false }, // mismatch in batch dimension
+ { TensorShape(5U, 1U), TensorShape(3U, 5U), TensorShape(1U), false }, // invalid broadcast of bias
+ { TensorShape(5U, 1U), TensorShape(3U, 5U), TensorShape(3U, 3U), false }, // 2d bias is invalid
};
for(auto &tuple : shape_configurations)
{
- const bool expected = std::get<2>(tuple);
+ const bool expected = std::get<3>(tuple);
for(bool adj_lhs :
{
@@ -141,6 +146,7 @@ TEST_CASE(ValidateInputShapes, framework::DatasetMode::ALL)
{
TensorShape lhs_shape = std::get<0>(tuple);
TensorShape rhs_shape = std::get<1>(tuple);
+ TensorShape bia_shape = std::get<2>(tuple);
if(adj_lhs)
{
@@ -154,11 +160,12 @@ TEST_CASE(ValidateInputShapes, framework::DatasetMode::ALL)
const TensorInfo lhs_info = TensorInfo(lhs_shape, 1, DataType::QASYMM8_SIGNED);
const TensorInfo rhs_info = TensorInfo(rhs_shape, 1, DataType::QASYMM8_SIGNED);
+ const TensorInfo bia_info = TensorInfo(bia_shape, 1, DataType::S32);
TensorInfo output_info;
MatMulKernelInfo matmul_kernel_info{ adj_lhs, adj_rhs, 1, 1, 1, false /* export_rhs_to_cl_image */ };
- Status status = ClMatMulLowpNativeKernel::validate(&lhs_info, &rhs_info, &output_info, matmul_kernel_info);
+ Status status = ClMatMulLowpNativeKernel::validate(&lhs_info, &rhs_info, &bia_info, &output_info, matmul_kernel_info);
ARM_COMPUTE_EXPECT(bool(status) == expected, framework::LogLevel::ERRORS);
}
}
@@ -167,41 +174,44 @@ TEST_CASE(ValidateInputShapes, framework::DatasetMode::ALL)
TEST_CASE(ValidateDataTypes, framework::DatasetMode::ALL)
{
- using DataTypeConfigurationTuple = std::tuple<DataType, DataType, DataType, bool>;
+ using DataTypeConfigurationTuple = std::tuple<DataType, DataType, DataType, DataType, bool>;
const std::vector<DataTypeConfigurationTuple> data_type_configurations =
{
- { DataType::F32, DataType::F32, DataType::F32, false }, // no floating point types
- { DataType::F16, DataType::F16, DataType::F16, false }, // no floating point types
- { DataType::F64, DataType::F64, DataType::F64, false }, // no double precision
- { DataType::QASYMM8, DataType::QASYMM8, DataType::QASYMM8, true },
- { DataType::QASYMM8_SIGNED, DataType::QASYMM8_SIGNED, DataType::QASYMM8_SIGNED, true },
- { DataType::QSYMM8_PER_CHANNEL, DataType::QSYMM8_PER_CHANNEL, DataType::QSYMM8_PER_CHANNEL, false }, // only qasymm8/qasymm8_signed is supported
- { DataType::QASYMM16, DataType::QASYMM16, DataType::QASYMM16, false }, // only qasymm8/qasymm8_signed is supported
- { DataType::QSYMM16, DataType::QSYMM16, DataType::QSYMM16, false }, // only qasymm8/qasymm8_signed is supported
- { DataType::QSYMM8, DataType::QSYMM8, DataType::QSYMM8, false }, // only qasymm8/qasymm8_signed is supported
- { DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::QASYMM8, false }, // no mixed data types
- { DataType::S64, DataType::S64, DataType::S64, false }, // no integral types
- { DataType::S32, DataType::S32, DataType::S32, false }, // no integral types
- { DataType::S16, DataType::S16, DataType::S16, false }, // no integral types
- { DataType::S8, DataType::S8, DataType::S8, false }, // no integral types
- { DataType::U64, DataType::U64, DataType::U64, false }, // no integral types
- { DataType::U32, DataType::U32, DataType::U32, false }, // no integral types
- { DataType::U16, DataType::U16, DataType::U16, false }, // no integral types
- { DataType::U8, DataType::U8, DataType::U8, false }, // no integral types
+ { DataType::F32, DataType::F32, DataType::F32, DataType::F32, false }, // no floating point types
+ { DataType::F16, DataType::F16, DataType::F16, DataType::F16, false }, // no floating point types
+ { DataType::F64, DataType::F64, DataType::F64, DataType::F64, false }, // no double precision
+ { DataType::QASYMM8, DataType::QASYMM8, DataType::S32, DataType::QASYMM8, true },
+ { DataType::QASYMM8_SIGNED, DataType::QASYMM8_SIGNED, DataType::S32, DataType::QASYMM8_SIGNED, true },
+ { DataType::QSYMM8_PER_CHANNEL, DataType::QSYMM8_PER_CHANNEL, DataType::S32, DataType::QSYMM8_PER_CHANNEL, false }, // only qasymm8/qasymm8_signed is supported
+ { DataType::QASYMM16, DataType::QASYMM16, DataType::S32, DataType::QASYMM16, false }, // only qasymm8/qasymm8_signed is supported
+ { DataType::QSYMM16, DataType::QSYMM16, DataType::S32, DataType::QSYMM16, false }, // only qasymm8/qasymm8_signed is supported
+ { DataType::QSYMM8, DataType::QSYMM8, DataType::S32, DataType::QSYMM8, false }, // only qasymm8/qasymm8_signed is supported
+ { DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::S32, DataType::QASYMM8, false }, // no mixed data types
+ { DataType::S64, DataType::S64, DataType::S64, DataType::S64, false }, // no integral types
+ { DataType::S32, DataType::S32, DataType::S32, DataType::S32, false }, // no integral types
+ { DataType::S16, DataType::S16, DataType::S16, DataType::S16, false }, // no integral types
+ { DataType::S8, DataType::S8, DataType::S8, DataType::S8, false }, // no integral types
+ { DataType::U64, DataType::U64, DataType::U64, DataType::U64, false }, // no integral types
+ { DataType::U32, DataType::U32, DataType::U32, DataType::U32, false }, // no integral types
+ { DataType::U16, DataType::U16, DataType::U16, DataType::U16, false }, // no integral types
+ { DataType::U8, DataType::U8, DataType::U8, DataType::U8, false }, // no integral types
+ { DataType::QASYMM8, DataType::QASYMM8, DataType::F32, DataType::QASYMM8, false } // Only S32 bias is supported
};
// It's enough to test a single shape and block size configuration while checking data types
- const TensorShape shape = TensorShape(10U, 10U);
+ const TensorShape shape = TensorShape(10U, 10U);
+ const TensorShape bia_shape = TensorShape(10U);
const MatMulKernelInfo matmul_kernel_info{ false, false, 1, 1, 1, false };
for(auto &tuple : data_type_configurations)
{
- const bool expected = std::get<3>(tuple);
+ const bool expected = std::get<4>(tuple);
const TensorInfo lhs_info(shape, 1, std::get<0>(tuple));
const TensorInfo rhs_info(shape, 1, std::get<1>(tuple));
- TensorInfo output_info(shape, 1, std::get<2>(tuple));
+ const TensorInfo bia_info(bia_shape, 1, std::get<2>(tuple));
+ TensorInfo output_info(shape, 1, std::get<3>(tuple));
- Status status = ClMatMulLowpNativeKernel::validate(&lhs_info, &rhs_info, &output_info, matmul_kernel_info);
+ Status status = ClMatMulLowpNativeKernel::validate(&lhs_info, &rhs_info, &bia_info, &output_info, matmul_kernel_info);
ARM_COMPUTE_EXPECT(bool(status) == expected, framework::LogLevel::ERRORS);
}
}
@@ -234,6 +244,18 @@ FIXTURE_DATA_TEST_CASE(RunSmall, CLMatMulLowpNativeKernelFixture<int8_t>, framew
// Validate output
validate(CLAccessor(_target), _reference, tolerance_quant);
}
+FIXTURE_DATA_TEST_CASE(RunWithBias, CLMatMulLowpKernelWithBiasFixture<int8_t>, framework::DatasetMode::ALL, combine(combine(combine(combine(combine(combine(combine(datasets::SmallMatMulDataset(),
+ framework::dataset::make("TransposeA", { true, false })),
+ framework::dataset::make("TransposeB", { true, false })),
+ m0_values_precommit),
+ n0_values_precommit),
+ k0_values_precommit),
+ framework::dataset::make("ExportRhsToCLImage", { false })),
+ framework::dataset::make("DataType", DataType::QASYMM8_SIGNED)))
+{
+ // Validate output
+ validate(CLAccessor(_target), _reference, tolerance_quant);
+}
FIXTURE_DATA_TEST_CASE(RunLargeNoTranspose, CLMatMulLowpNativeKernelFixture<int8_t>, framework::DatasetMode::NIGHTLY,
combine(combine(combine(combine(combine(combine(combine(datasets::LargeMatMulDataset(),
framework::dataset::make("TransposeA", { false })),
diff --git a/tests/validation/CL/MatMulNativeMMULKernel.cpp b/tests/validation/CL/MatMulNativeMMULKernel.cpp
index b63af75169..70c80985db 100644
--- a/tests/validation/CL/MatMulNativeMMULKernel.cpp
+++ b/tests/validation/CL/MatMulNativeMMULKernel.cpp
@@ -70,6 +70,9 @@ const auto k0_value = framework::dataset::make("K0", { 1 });
template <typename T>
using CLMatMulNativeMMULKernelFixture = MatMulKernelValidationFixture<T, ClMatMulNativeMMULKernel, true /*use_mmul*/>;
+template <typename T>
+using CLMatMulKernelBiasFixture = MatMulKernelWithBiasValidation<T, ClMatMulNativeMMULKernel, true /*use_mmul*/>;
+
TEST_SUITE(CL)
TEST_SUITE(MatMulNativeMMULKernel)
TEST_SUITE(Validate)
@@ -117,7 +120,7 @@ TEST_CASE(SupportedBlockSizes, framework::DatasetMode::ALL)
{ MatMulKernelInfo(true, true, 3, 7, 1), false }, // N0 not in {1, 2, 3, 4, 8, 16}
{ MatMulKernelInfo(true, true, 6, 3, 1), false }, // M0 not in {1, 2, 3, 4, 8, 16}
{ MatMulKernelInfo(true, true, 5, 3, 1), false }, // M0 not in {1, 2, 3, 4, 8, 16}
- { MatMulKernelInfo(true, true, 4, 8, 2), false }, // K0 is not 1
+ { MatMulKernelInfo(true, true, 4, 8, 2), false }, // K0 is not 1
{ MatMulKernelInfo(true, true, 4, 8, 1), true },
{ MatMulKernelInfo(true, true, 3, 3, 1), true },
{ MatMulKernelInfo(true, true, 16, 4, 1), true },
@@ -132,7 +135,7 @@ TEST_CASE(SupportedBlockSizes, framework::DatasetMode::ALL)
for(auto &pair : supported_block_sizes)
{
TensorInfo output_info;
- Status status = ClMatMulNativeMMULKernel::validate(&lhs_info, &rhs_info, &output_info, pair.first);
+ Status status = ClMatMulNativeMMULKernel::validate(&lhs_info, &rhs_info, nullptr, &output_info, pair.first);
ARM_COMPUTE_EXPECT(bool(status) == pair.second, framework::LogLevel::ERRORS);
}
}
@@ -148,28 +151,30 @@ TEST_CASE(ValidateInputShapes, framework::DatasetMode::ALL)
if(arm_matrix_multiply_supported(CLKernelLibrary::get().get_device()))
{
// Configurations are assumed to be Nt/Nt, but will be transposed inside the test to test other configurations
- using ShapeConfigurationTuple = std::tuple<TensorShape, TensorShape, bool>;
+ using ShapeConfigurationTuple = std::tuple<TensorShape, TensorShape, TensorShape, bool>; // lhs, rhs, bias, result
const std::vector<ShapeConfigurationTuple> shape_configurations =
{
- { TensorShape(4U, 1U), TensorShape(3U, 4U), true },
- { TensorShape(12U, 12U), TensorShape(3U, 12U), true },
- { TensorShape(8U, 4U), TensorShape(2U, 8U), true },
- { TensorShape(8U, 4U), TensorShape(2U, 4U), false }, // Mismatch in the K dimension
- { TensorShape(5U, 0U), TensorShape(2U, 5U), false }, // Invalid dimension
- { TensorShape(5U, 7U), TensorShape(2U, 5U), false }, // K not a multiple of 4 (MMUL_K0)
- { TensorShape(8U, 4U, 3U, 4U, 5U, 6U), TensorShape(2U, 8U, 3U, 4U, 5U, 6U), true },
- { TensorShape(5U, 4U, 3U, 4U, 5U, 1U), TensorShape(2U, 5U, 3U, 4U, 5U, 6U), false }, // No batch broadcasting
- { TensorShape(5U, 4U, 3U, 4U, 9U, 6U), TensorShape(2U, 5U, 3U, 4U, 5U, 6U), false }, // Mismatch in batch dimension
+ { TensorShape(4U, 1U), TensorShape(3U, 4U), TensorShape(3U), true },
+ { TensorShape(12U, 12U), TensorShape(3U, 12U), TensorShape(3U), true },
+ { TensorShape(8U, 4U), TensorShape(2U, 8U), TensorShape(2U), true },
+ { TensorShape(8U, 4U), TensorShape(2U, 4U), TensorShape(2U), false }, // Mismatch in the K dimension
+ { TensorShape(5U, 0U), TensorShape(2U, 5U), TensorShape(2U), false }, // Invalid dimension
+ { TensorShape(5U, 7U), TensorShape(2U, 5U), TensorShape(2U), false }, // K not a multiple of 4 (MMUL_K0)
+ { TensorShape(8U, 4U, 3U, 4U, 5U, 6U), TensorShape(2U, 8U, 3U, 4U, 5U, 6U), TensorShape(2U), true },
+ { TensorShape(5U, 4U, 3U, 4U, 5U, 1U), TensorShape(2U, 5U, 3U, 4U, 5U, 6U), TensorShape(2U), false }, // No batch broadcasting
+ { TensorShape(5U, 4U, 3U, 4U, 9U, 6U), TensorShape(2U, 5U, 3U, 4U, 5U, 6U), TensorShape(2U), false }, // Mismatch in batch dimension
+        { TensorShape(4U, 1U), TensorShape(3U, 4U), TensorShape(1U), false }, // Bias first dimension != dst first dimension.
+ { TensorShape(4U, 1U), TensorShape(3U, 4U), TensorShape(5U, 6U), false }, // Bias is 2d which is invalid.
};
for(auto &tuple : shape_configurations)
{
- const bool expected = std::get<2>(tuple);
+ const bool expected = std::get<3>(tuple);
- for(bool adj_lhs :
- {
- false, true
- })
+ for(bool adj_lhs :
+ {
+ false, true
+ })
{
for(bool adj_rhs :
{
@@ -178,6 +183,7 @@ TEST_CASE(ValidateInputShapes, framework::DatasetMode::ALL)
{
TensorShape lhs_shape = std::get<0>(tuple);
TensorShape rhs_shape = std::get<1>(tuple);
+ TensorShape bia_shape = std::get<2>(tuple);
if(adj_lhs)
{
@@ -191,11 +197,12 @@ TEST_CASE(ValidateInputShapes, framework::DatasetMode::ALL)
const TensorInfo lhs_info = TensorInfo(lhs_shape, 1, DataType::F32);
const TensorInfo rhs_info = TensorInfo(rhs_shape, 1, DataType::F32);
+ const TensorInfo bia_info = TensorInfo(bia_shape, 1, DataType::F32);
TensorInfo output_info;
MatMulKernelInfo matmul_kernel_info{ adj_lhs, adj_rhs, 1, 1, 1, false /* export_rhs_to_cl_image */ };
- Status status = ClMatMulNativeMMULKernel::validate(&lhs_info, &rhs_info, &output_info, matmul_kernel_info);
+ Status status = ClMatMulNativeMMULKernel::validate(&lhs_info, &rhs_info, &bia_info, &output_info, matmul_kernel_info);
ARM_COMPUTE_EXPECT(bool(status) == expected, framework::LogLevel::ERRORS);
}
}
@@ -213,40 +220,44 @@ TEST_CASE(ValidateDataTypes, framework::DatasetMode::ALL)
if(arm_matrix_multiply_supported(CLKernelLibrary::get().get_device()))
{
// Configurations are assumed to be Nt/Nt, but will be transposed inside the test to test other configurations
- using DataTypeConfigurationTuple = std::tuple<DataType, DataType, DataType, bool>;
+ using DataTypeConfigurationTuple = std::tuple<DataType, DataType, DataType, DataType, bool>;
const std::vector<DataTypeConfigurationTuple> data_type_configurations =
{
- { DataType::F32, DataType::F32, DataType::F32, true },
- { DataType::F16, DataType::F16, DataType::F16, true },
- { DataType::F16, DataType::F32, DataType::F32, false }, // no mixed precision
- { DataType::F64, DataType::F64, DataType::F64, false }, // no double precision
- { DataType::QASYMM8, DataType::QASYMM8, DataType::QASYMM8, false }, // no quantized types
- { DataType::QASYMM8_SIGNED, DataType::QASYMM8_SIGNED, DataType::QASYMM8_SIGNED, false }, // no quantized types
- { DataType::QSYMM8_PER_CHANNEL, DataType::QSYMM8_PER_CHANNEL, DataType::QSYMM8_PER_CHANNEL, false }, // no quantized types
- { DataType::QASYMM16, DataType::QASYMM16, DataType::QASYMM16, false }, // no quantized types
- { DataType::QSYMM16, DataType::QSYMM16, DataType::QSYMM16, false }, // no quantized types
- { DataType::QSYMM8, DataType::QSYMM8, DataType::QSYMM8, false }, // no quantized types
- { DataType::S64, DataType::S64, DataType::S64, false }, // no integral types
- { DataType::S32, DataType::S32, DataType::S32, false }, // no integral types
- { DataType::S16, DataType::S16, DataType::S16, false }, // no integral types
- { DataType::S8, DataType::S8, DataType::S8, false }, // no integral types
- { DataType::U64, DataType::U64, DataType::U64, false }, // no integral types
- { DataType::U32, DataType::U32, DataType::U32, false }, // no integral types
- { DataType::U16, DataType::U16, DataType::U16, false }, // no integral types
- { DataType::U8, DataType::U8, DataType::U8, false }, // no integral types
+ { DataType::F32, DataType::F32, DataType::F32, DataType::F32, true },
+ { DataType::F16, DataType::F16, DataType::F16, DataType::F16, true },
+ { DataType::F32, DataType::F32, DataType::F16, DataType::F32, false }, // incorrect bias type
+ { DataType::F16, DataType::F32, DataType::F32, DataType::F32, false }, // no mixed precision
+ { DataType::F64, DataType::F64, DataType::F64, DataType::F64, false }, // no double precision
+ { DataType::QASYMM8, DataType::QASYMM8, DataType::S32, DataType::QASYMM8, false }, // no quantized types
+ { DataType::QASYMM8_SIGNED, DataType::QASYMM8_SIGNED, DataType::S32, DataType::QASYMM8_SIGNED, false }, // no quantized types
+ { DataType::QSYMM8_PER_CHANNEL, DataType::QSYMM8_PER_CHANNEL, DataType::S32, DataType::QSYMM8_PER_CHANNEL, false }, // no quantized types
+ { DataType::QASYMM16, DataType::QASYMM16, DataType::S32, DataType::QASYMM16, false }, // no quantized types
+ { DataType::QSYMM16, DataType::QSYMM16, DataType::S32, DataType::QSYMM16, false }, // no quantized types
+ { DataType::QSYMM8, DataType::QSYMM8, DataType::S32, DataType::QSYMM8, false }, // no quantized types
+ { DataType::S64, DataType::S64, DataType::S64, DataType::S64, false }, // no integral types
+ { DataType::S32, DataType::S32, DataType::S32, DataType::S32, false }, // no integral types
+ { DataType::S16, DataType::S16, DataType::S16, DataType::S16, false }, // no integral types
+ { DataType::S8, DataType::S8, DataType::S8, DataType::S8, false }, // no integral types
+ { DataType::U64, DataType::U64, DataType::U64, DataType::U64, false }, // no integral types
+ { DataType::U32, DataType::U32, DataType::U32, DataType::U32, false }, // no integral types
+ { DataType::U16, DataType::U16, DataType::U16, DataType::U16, false }, // no integral types
+ { DataType::U8, DataType::U8, DataType::U8, DataType::U8, false }, // no integral types
};
- const TensorShape shape = TensorShape(8U, 8U);
+ const TensorShape shape = TensorShape(8U, 8U);
+ const TensorShape bia_shape = TensorShape(8U);
const MatMulKernelInfo matmul_kernel_info{ false, false, 1, 1, 1, false };
for(auto &tuple : data_type_configurations)
{
- const bool expected = std::get<3>(tuple);
+ const bool expected = std::get<4>(tuple);
const TensorInfo lhs_info(shape, 1, std::get<0>(tuple));
const TensorInfo rhs_info(shape, 1, std::get<1>(tuple));
- TensorInfo output_info(shape, 1, std::get<2>(tuple));
+ const TensorInfo bia_info(bia_shape, 1, std::get<2>(tuple));
+ TensorInfo output_info(shape, 1, std::get<3>(tuple));
- Status status = ClMatMulNativeMMULKernel::validate(&lhs_info, &rhs_info, &output_info, matmul_kernel_info);
+ Status status = ClMatMulNativeMMULKernel::validate(&lhs_info, &rhs_info, &bia_info, &output_info, matmul_kernel_info);
ARM_COMPUTE_EXPECT(bool(status) == expected, framework::LogLevel::ERRORS);
}
}
@@ -292,7 +303,23 @@ FIXTURE_DATA_TEST_CASE(RunSmall, CLMatMulNativeMMULKernelFixture<float>, framewo
validate(CLAccessor(_target), _reference, tolerance_f32, 0.f, abs_tolerance_f32);
}
}
-FIXTURE_DATA_TEST_CASE(RunLargeNoTranspose, CLMatMulNativeMMULKernelFixture<float>, framework::DatasetMode::NIGHTLY, combine(combine(combine(combine(combine(combine(combine(datasets::LargeMatMulMMULDataset(),
+FIXTURE_DATA_TEST_CASE(RunWithBias, CLMatMulKernelBiasFixture<float>, framework::DatasetMode::ALL, combine(combine(combine(combine(combine(combine(combine(datasets::SmallMatMulMMULDataset(),
+ framework::dataset::make("TransposeA", { false, true })),
+ framework::dataset::make("TransposeB", { false, true })),
+ m0_values_precommit),
+ n0_values_precommit),
+ k0_value),
+ framework::dataset::make("ExportRhsToCLImage", { false })),
+ framework::dataset::make("DataType", DataType::F32)))
+{
+ // Validate output
+ if(_device_supports_mmul)
+ {
+ validate(CLAccessor(_target), _reference, tolerance_f32, 0.f, abs_tolerance_f32);
+ }
+}
+FIXTURE_DATA_TEST_CASE(RunLargeNoTranspose, CLMatMulNativeMMULKernelFixture<float>, framework::DatasetMode::NIGHTLY,
+ combine(combine(combine(combine(combine(combine(combine(datasets::LargeMatMulMMULDataset(),
framework::dataset::make("TransposeA", { false })),
framework::dataset::make("TransposeB", { false })),
m0_values_nightly_lhs_nt),
@@ -308,7 +335,8 @@ FIXTURE_DATA_TEST_CASE(RunLargeNoTranspose, CLMatMulNativeMMULKernelFixture<floa
}
}
-FIXTURE_DATA_TEST_CASE(RunLargeRhsTranspose, CLMatMulNativeMMULKernelFixture<float>, framework::DatasetMode::NIGHTLY, combine(combine(combine(combine(combine(combine(combine(datasets::LargeMatMulMMULDataset(),
+FIXTURE_DATA_TEST_CASE(RunLargeRhsTranspose, CLMatMulNativeMMULKernelFixture<float>, framework::DatasetMode::NIGHTLY,
+ combine(combine(combine(combine(combine(combine(combine(datasets::LargeMatMulMMULDataset(),
framework::dataset::make("TransposeA", { false })),
framework::dataset::make("TransposeB", { true })),
m0_values_nightly_lhs_nt),
@@ -323,14 +351,15 @@ FIXTURE_DATA_TEST_CASE(RunLargeRhsTranspose, CLMatMulNativeMMULKernelFixture<flo
validate(CLAccessor(_target), _reference, tolerance_f32, 0.f, abs_tolerance_f32);
}
}
-FIXTURE_DATA_TEST_CASE(RunLargeLhsTransposed, CLMatMulNativeMMULKernelFixture<float>, framework::DatasetMode::NIGHTLY, combine(combine(combine(combine(combine(combine(combine(datasets::LargeMatMulMMULDataset(),
- framework::dataset::make("TransposeA", { true })),
- framework::dataset::make("TransposeB", { false })),
- m0_values_nightly_lhs_t),
- n0_values_nightly_rhs_nt),
- k0_value),
- framework::dataset::make("ExportRhsToCLImage", { false })),
- framework::dataset::make("DataType", DataType::F32)))
+FIXTURE_DATA_TEST_CASE(RunLargeLhsTransposed, CLMatMulNativeMMULKernelFixture<float>, framework::DatasetMode::NIGHTLY,
+ combine(combine(combine(combine(combine(combine(combine(datasets::LargeMatMulMMULDataset(),
+ framework::dataset::make("TransposeA", { true })),
+ framework::dataset::make("TransposeB", { false })),
+ m0_values_nightly_lhs_t),
+ n0_values_nightly_rhs_nt),
+ k0_value),
+ framework::dataset::make("ExportRhsToCLImage", { false })),
+ framework::dataset::make("DataType", DataType::F32)))
{
// Validate output
@@ -395,7 +424,8 @@ FIXTURE_DATA_TEST_CASE(RunSmall, CLMatMulNativeMMULKernelFixture<half>, framewor
validate(CLAccessor(_target), _reference, tolerance_f16, 0.f, abs_tolerance_f16);
}
}
-FIXTURE_DATA_TEST_CASE(RunLargeNoTranspose, CLMatMulNativeMMULKernelFixture<half>, framework::DatasetMode::NIGHTLY, combine(combine(combine(combine(combine(combine(combine(datasets::LargeMatMulMMULDataset(),
+FIXTURE_DATA_TEST_CASE(RunLargeNoTranspose, CLMatMulNativeMMULKernelFixture<half>, framework::DatasetMode::NIGHTLY,
+ combine(combine(combine(combine(combine(combine(combine(datasets::LargeMatMulMMULDataset(),
framework::dataset::make("TransposeA", { false })),
framework::dataset::make("TransposeB", { false })),
m0_values_nightly_lhs_nt),
@@ -410,7 +440,8 @@ FIXTURE_DATA_TEST_CASE(RunLargeNoTranspose, CLMatMulNativeMMULKernelFixture<half
validate(CLAccessor(_target), _reference, tolerance_f16, 0.f, abs_tolerance_f16);
}
}
-FIXTURE_DATA_TEST_CASE(RunLargeRhsTranspose, CLMatMulNativeMMULKernelFixture<half>, framework::DatasetMode::NIGHTLY, combine(combine(combine(combine(combine(combine(combine(datasets::LargeMatMulMMULDataset(),
+FIXTURE_DATA_TEST_CASE(RunLargeRhsTranspose, CLMatMulNativeMMULKernelFixture<half>, framework::DatasetMode::NIGHTLY,
+ combine(combine(combine(combine(combine(combine(combine(datasets::LargeMatMulMMULDataset(),
framework::dataset::make("TransposeA", { false })),
framework::dataset::make("TransposeB", { true })),
m0_values_nightly_lhs_nt),
@@ -425,14 +456,15 @@ FIXTURE_DATA_TEST_CASE(RunLargeRhsTranspose, CLMatMulNativeMMULKernelFixture<hal
validate(CLAccessor(_target), _reference, tolerance_f16, 0.f, abs_tolerance_f16);
}
}
-FIXTURE_DATA_TEST_CASE(RunLargeLhsTransposed, CLMatMulNativeMMULKernelFixture<half>, framework::DatasetMode::NIGHTLY, combine(combine(combine(combine(combine(combine(combine(datasets::LargeMatMulMMULDataset(),
- framework::dataset::make("TransposeA", { true })),
- framework::dataset::make("TransposeB", { false })),
- m0_values_nightly_lhs_t),
- n0_values_nightly_rhs_nt),
- k0_value),
- framework::dataset::make("ExportRhsToCLImage", { false })),
- framework::dataset::make("DataType", DataType::F16)))
+FIXTURE_DATA_TEST_CASE(RunLargeLhsTransposed, CLMatMulNativeMMULKernelFixture<half>, framework::DatasetMode::NIGHTLY,
+ combine(combine(combine(combine(combine(combine(combine(datasets::LargeMatMulMMULDataset(),
+ framework::dataset::make("TransposeA", { true })),
+ framework::dataset::make("TransposeB", { false })),
+ m0_values_nightly_lhs_t),
+ n0_values_nightly_rhs_nt),
+ k0_value),
+ framework::dataset::make("ExportRhsToCLImage", { false })),
+ framework::dataset::make("DataType", DataType::F16)))
{
// Validate output
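Beyond the shared bias rules, the MMUL variant keeps its own constraints, visible in the vectors above: K0 must be 1 and K must be a multiple of MMUL_K0 (4). A hypothetical pre-check restating them:

    // Hypothetical pre-check; MMUL_K0 == 4 comes from the "K not a multiple of 4 (MMUL_K0)" vector.
    #include <cstddef>

    bool mmul_config_ok(std::size_t K, int K0)
    {
        constexpr std::size_t MMUL_K0 = 4;
        return (K0 == 1) && (K % MMUL_K0 == 0);
    }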
diff --git a/tests/validation/fixtures/MatMulKernelFixture.h b/tests/validation/fixtures/MatMulKernelFixture.h
index 59bcfe5b2d..88fdf8b291 100644
--- a/tests/validation/fixtures/MatMulKernelFixture.h
+++ b/tests/validation/fixtures/MatMulKernelFixture.h
@@ -36,7 +36,7 @@
#include "tests/validation/reference/GEMMLowp.h"
#include "tests/validation/reference/Permute.h"
#include "tests/validation/reference/ReshapeLayer.h"
-
+#include <cmath>
#include <random>
namespace arm_compute
@@ -48,12 +48,16 @@ namespace validation
using namespace arm_compute::opencl::kernels;
template <typename T, typename KernelType, bool use_mmul = false>
-class MatMulKernelValidationFixture : public framework::Fixture
+class MatMulKernelGenericValidationFixture : public framework::Fixture
{
public:
template <typename...>
- void setup(TensorShape shape_a, TensorShape shape_b, TensorShape output_shape, bool pretranspose_a, bool pretranspose_b, int M0, int N0, int K0, bool export_rhs_to_cl_image, DataType data_type)
+ void setup(TensorShape shape_a, TensorShape shape_b, TensorShape output_shape, bool pretranspose_a, bool pretranspose_b, int M0, int N0, int K0, bool export_rhs_to_cl_image, DataType data_type,
+ bool enable_bias)
{
+ // Flag to create a bias
+ _enable_bias = enable_bias;
+
// For brevity, the input shapes are assumed to be not-transposed for both Lhs and Rhs matrices.
QuantizationInfo lhs_q_info;
QuantizationInfo rhs_q_info;
@@ -138,6 +142,16 @@ protected:
}
}
+ template <typename U>
+ void fill_bias_s32(U &&tensor, int i, const UniformQuantizationInfo &q_info)
+ {
+ // For quantized cases, fill the S32 bias according to the following to avoid saturation of test cases.
+ // The following code limits size of bias values to within expected range of output quantization.
+        const unsigned int bound = std::abs(q_info.scale * 256); // 256 is the number of values representable in an 8-bit datatype
+ std::uniform_int_distribution<int32_t> distribution(-(bound / 10), (bound / 10));
+ library->fill(tensor, distribution, i);
+ }
+
template <typename U, typename D>
void fill_constant(U &&tensor, D value)
{
@@ -156,12 +170,15 @@ protected:
matmul_info.k0 = K0;
matmul_info.export_rhs_to_cl_image = export_rhs_to_cl_image;
+ bool is_quantized = is_data_type_quantized(data_type);
+
// Create tensors
- CLTensor a = create_tensor<CLTensor>(shape_a, data_type, 1, lhs_q_info);
- CLTensor b = create_tensor<CLTensor>(shape_b, data_type, 1, rhs_q_info);
- CLTensor dst = create_tensor<CLTensor>(output_shape, data_type, 1, dst_q_info);
+ CLTensor a = create_tensor<CLTensor>(shape_a, data_type, 1, lhs_q_info);
+ CLTensor b = create_tensor<CLTensor>(shape_b, data_type, 1, rhs_q_info);
+ CLTensor bias = create_tensor<CLTensor>(output_shape[0], (is_quantized) ? DataType::S32 : data_type, 1, dst_q_info);
+ CLTensor dst = create_tensor<CLTensor>(output_shape, data_type, 1, dst_q_info);
- matMul.configure(a.info(), b.info(), dst.info(), matmul_info);
+ matMul.configure(a.info(), b.info(), (_enable_bias) ? bias.info() : nullptr, dst.info(), matmul_info);
ARM_COMPUTE_ASSERT(a.info()->is_resizable());
ARM_COMPUTE_ASSERT(b.info()->is_resizable());
ARM_COMPUTE_ASSERT(dst.info()->is_resizable());
@@ -184,6 +201,22 @@ protected:
{ ACL_SRC_1, &b },
{ ACL_DST, &dst }
});
+
+ if(_enable_bias)
+ {
+ // Allocate, fill and add bias to TensorPack obj
+ bias.allocator()->allocate();
+ if(is_quantized)
+ {
+ fill_bias_s32(CLAccessor(bias), 2, dst_q_info.uniform());
+ }
+ else
+ {
+ fill(CLAccessor(bias), 2);
+ }
+ tensors_pack.add_tensor(ACL_SRC_2, &bias);
+ }
+
matMul.run(tensors_pack);
return dst;
@@ -252,9 +285,21 @@ protected:
template <typename U = T>
typename std::enable_if < std::is_same<U, float>::value || std::is_same<U, half>::value, SimpleTensor<U >>::type gemm_reference(SimpleTensor<U> &a, SimpleTensor<U> &b, SimpleTensor<U> &c)
{
+ // Fill bias, then copy first dimension into subsequent dimensions to mimic broadcast
+ // of bias tensor from shape [dst.dimension(0)] to [dst.tensor_shape()] in target kernel
+ if(_enable_bias)
+ {
+ fill(c, 2);
+ const int n = c.shape().x();
+ const int other_dims = c.shape().collapsed_from(1)[1];
+ for(int i = 1; i < other_dims; ++i) // For all data, copy first n elements into remaining batches
+ {
+ memcpy(c.data() + i * n, c.data(), n * sizeof(T));
+ }
+ }
// Setting beta to 0 will effectively disable C for the
// computation of the reference: alpha * A * B + 0 * C
- return reference::gemm<U>(a, b, c, 1.0f, 0.f);
+ return reference::gemm<U>(a, b, c, 1.0f, (_enable_bias) ? 1.0f : 0.f);
}
template <typename U = T>
@@ -276,19 +321,59 @@ protected:
constexpr int32_t gemmlowp_max_bound = std::numeric_limits<int32_t>::max();
SimpleTensor<int> bias{ c.shape(), DataType::S32 };
- fill_constant(bias, static_cast<int32_t>(0));
+ if(_enable_bias)
+ {
+ // Identical to float implementation, fill and copy values of bias first dimension
+ fill_bias_s32(bias, 2, cq);
+ const int n = bias.shape().x();
+ const int other_dims = bias.shape().collapsed_from(1)[1];
+ const unsigned int dt_size = sizeof(int32_t);
+ for(int i = 1; i < other_dims; ++i)
+ {
+ memcpy(bias.data() + i * n, bias.data(), n * dt_size);
+ }
+ }
+ else
+ {
+ fill_constant(bias, static_cast<int32_t>(0)); // effectively disable bias
+ }
const SimpleTensor<U> final_result = reference::gemmlowp_quantize_down_scale_by_fixedpoint<int32_t, U>(result, bias,
gemmlowp_multipliers, gemmlowp_shifts, gemmlowp_offset, gemmlowp_min_bound, gemmlowp_max_bound);
+
return final_result;
}
CLTensor _target{};
SimpleTensor<T> _reference{};
+ bool _enable_bias{ false };
bool _device_supports_export_to_cl_image{ true };
bool _device_supports_mmul{ true };
};
+template <typename T, typename KernelType, bool use_mmul = false>
+class MatMulKernelValidationFixture : public MatMulKernelGenericValidationFixture<T, KernelType, use_mmul>
+{
+public:
+ template <typename...>
+ void setup(TensorShape shape_a, TensorShape shape_b, TensorShape output_shape, bool pretranspose_a, bool pretranspose_b, int M0, int N0, int K0, bool export_rhs_to_cl_image, DataType data_type)
+ {
+ MatMulKernelGenericValidationFixture<T, KernelType, use_mmul>::setup(shape_a, shape_b, output_shape, pretranspose_a, pretranspose_b, M0, N0, K0, export_rhs_to_cl_image, data_type,
+ false /* enable bias */);
+ }
+};
+
+template <typename T, typename KernelType, bool use_mmul = false>
+class MatMulKernelWithBiasValidation : public MatMulKernelGenericValidationFixture<T, KernelType, use_mmul>
+{
+public:
+ template <typename...>
+ void setup(TensorShape shape_a, TensorShape shape_b, TensorShape output_shape, bool pretranspose_a, bool pretranspose_b, int M0, int N0, int K0, bool export_rhs_to_cl_image, DataType data_type)
+ {
+ MatMulKernelGenericValidationFixture<T, KernelType, use_mmul>::setup(shape_a, shape_b, output_shape, pretranspose_a, pretranspose_b, M0, N0, K0, export_rhs_to_cl_image, data_type,
+ true /* enable bias */);
+ }
+};
} // namespace validation
} // namespace test
} // namespace arm_compute
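On the reference side, the fixture fills only the first n bias values and then replicates them across the collapsed higher dimensions, so the reference GEMM (run with beta = 1) sees the same per-column bias the kernel broadcasts from shape [dst.dimension(0)]. A standalone sketch of that replication step:

    // Standalone sketch of the broadcast used in gemm_reference above:
    // copy the first n elements into every remaining batch.
    #include <cstring>
    #include <vector>

    void broadcast_bias(std::vector<float> &c, int n)
    {
        const int other_dims = static_cast<int>(c.size()) / n; // collapsed batch count
        for(int i = 1; i < other_dims; ++i)
        {
            std::memcpy(c.data() + i * n, c.data(), static_cast<std::size_t>(n) * sizeof(float));
        }
    }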