author     Michael Tyler <michael.tyler@arm.com>   2024-06-04 15:47:37 +0100
committer  Michael Tyler <michael.tyler@arm.com>   2024-06-25 09:10:13 +0000
commit     fc94f4d23abd4bc427b701f54ad85282e9ec7872 (patch)
tree       5e2980599256e2b2f4374e5beb61596fc95c9d5a /tests
parent     c2237ec4094c7824f8f7e61bc89504d01c5b59ff (diff)
Update CPU kernels and add mixed sign GEMM support
- Add support for mixed sign quantized convolution.
- Add support for mixed sign dequantized GEMM.
- Add SME FP16 GEMV kernel.
- Change SME vector length function to use RDSVL instead of static variable.
- Add GEMM dilation support internally (not exposed yet).
- Remove unused "get_default_activation_values" functions.
- Add SVE fixed format interleaved BF16 DOT kernel.
- Updates and optimizations to assembly kernels.

Resolves COMPMID-6926

Change-Id: I227f502502611d4cc4111c89e30c53ce94079544
Signed-off-by: Michael Tyler <michael.tyler@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/11570
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Gunes Bayir <gunes.bayir@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Benchmark: Arm Jenkins <bsgcomp@arm.com>
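The first two bullets are the headline change. As a minimal sketch (the helper name is illustrative and not part of this patch), the mixed-sign support surfaces through the existing public validate() API; the shapes and quantization parameters below are lifted from the Dequant/Validate test case added in tests/validation/NEON/GEMMLowp.cpp:

#include "arm_compute/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.h"

// Illustrative helper: returns true when a mixed-sign dequantized GEMM
// (QASYMM8 LHS, QASYMM8_SIGNED RHS, F32 output) is supported on this build.
bool mixed_sign_dequantized_gemm_supported()
{
    using namespace arm_compute;

    const TensorInfo a(TensorShape(16U, 32U), 1, DataType::QASYMM8, QuantizationInfo(1.f / 255, 10));
    const TensorInfo b(TensorShape(64U, 16U), 1, DataType::QASYMM8_SIGNED, QuantizationInfo(1.f / 256, 10));
    const TensorInfo dst(TensorShape(64U, 32U), 1, DataType::F32);

    // validate() returns a Status that converts to false when the configuration
    // is rejected; the new tests expect this combination to pass and the
    // reverse (QASYMM8_SIGNED input with QASYMM8 weights) to fail.
    return bool(NEGEMMLowpMatrixMultiplyCore::validate(&a, &b, nullptr, &dst));
}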
Diffstat (limited to 'tests')
-rw-r--r--  tests/validation/NEON/ConvolutionLayer.cpp          | 101
-rw-r--r--  tests/validation/NEON/GEMMLowp.cpp                  |  36
-rw-r--r--  tests/validation/fixtures/ConvolutionLayerFixture.h |  27
-rw-r--r--  tests/validation/fixtures/GEMMLowpFixture.h         |  50
4 files changed, 193 insertions(+), 21 deletions(-)
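One bullet that leaves no trace in this tests/ diff is the RDSVL change: the SME vector length query now reads the register each time instead of relying on a static variable, and that code lives in the assembly backend. A hedged sketch of the idea, with a hypothetical helper name, assuming a toolchain that accepts SME instructions (e.g. -march=armv9-a+sme):

#include <cstdint>

// Hypothetical illustration of the RDSVL approach: query the streaming
// vector length on every call rather than caching it in a static variable.
static inline uint64_t sme_streaming_vector_length_bytes()
{
    uint64_t vl;
    // RDSVL Xd, #imm yields imm * (streaming vector length in bytes) and is
    // valid whether or not the core is currently in streaming mode.
    __asm__ volatile("rdsvl %0, #1" : "=r"(vl));
    return vl;
}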
diff --git a/tests/validation/NEON/ConvolutionLayer.cpp b/tests/validation/NEON/ConvolutionLayer.cpp
index d739d4e1a4..7eada81ce5 100644
--- a/tests/validation/NEON/ConvolutionLayer.cpp
+++ b/tests/validation/NEON/ConvolutionLayer.cpp
@@ -147,6 +147,45 @@ const auto QuantizationData = make("QuantizationInfo",
TEST_SUITE(NEON)
TEST_SUITE(ConvolutionLayer)
+DATA_TEST_CASE(SupportedTypes, framework::DatasetMode::ALL, zip(
+ make("DataType", {
+ DataType::F32,
+ DataType::QASYMM8,
+ DataType::QASYMM8,
+ DataType::QASYMM8_SIGNED
+ }),
+ make("WeightsDataType", {
+ DataType::F32,
+ DataType::QASYMM8,
+ DataType::QASYMM8_SIGNED,
+ DataType::QASYMM8
+ }),
+ make("Expected",
+ {
+ true,
+ true,
+ true,
+ false
+ })),
+data_type_const, weights_data_type_const, expected_const)
+{
+ TensorInfo input_info = TensorInfo(TensorShape(3U, 3U, 1U), 1, data_type_const);
+ TensorInfo weights_info = TensorInfo(TensorShape(2U, 2U, 1U, 1U), 1, weights_data_type_const);
+ TensorInfo output_info = TensorInfo(TensorShape(2U, 2U, 1U), 1, data_type_const);
+
+ input_info.set_quantization_info(arm_compute::QuantizationInfo(1, 0));
+ weights_info.set_quantization_info(arm_compute::QuantizationInfo(1, 0));
+ output_info.set_quantization_info(arm_compute::QuantizationInfo(1, 0));
+
+ Status status = NEConvolutionLayer::validate(
+ &input_info,
+ &weights_info,
+ nullptr,
+ &output_info,
+ PadStrideInfo());
+
+ ARM_COMPUTE_EXPECT(bool(status) == expected_const, framework::LogLevel::ERRORS);
+}
// *INDENT-OFF*
// clang-format off
@@ -257,7 +296,7 @@ TEST_CASE(MemoryInjection, framework::DatasetMode::ALL)
for(size_t i = 0; i < result_0.info()->tensor_shape().total_size(); ++i)
{
- ARM_COMPUTE_EXPECT(((float *)result_0.buffer())[i] == ((float *)result_1.buffer())[i], framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(reinterpret_cast<float *>(result_0.buffer())[i] == reinterpret_cast<float *>(result_1.buffer())[i], framework::LogLevel::ERRORS);
}
}
@@ -303,7 +342,7 @@ TEST_CASE(MultipleExecutionWithConfigure, framework::DatasetMode::ALL)
for(size_t i = 0; i < result_0.info()->tensor_shape().total_size(); ++i)
{
- ARM_COMPUTE_EXPECT(((float *)result_0.buffer())[i] == ((float *)result_1.buffer())[i], framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(reinterpret_cast<float *>(result_0.buffer())[i] == reinterpret_cast<float *>(result_1.buffer())[i], framework::LogLevel::ERRORS);
}
}
@@ -580,7 +619,7 @@ FIXTURE_DATA_TEST_CASE(RunSmall, NEWinogradConvolutionLayerFixture<float>, frame
/// It's enough to run the activations for a single weight/input combination and data type because
/// the activation function is called on top of the winograd output as a separate operator
-/// TODO: Enable after COMPMID-6573 is resolved
+/// TODO(COMPMID-6573): Enable after COMPMID-6573 is resolved
FIXTURE_DATA_TEST_CASE(RunActivations, NEWinogradConvolutionLayerFixture<float>, framework::DatasetMode::DISABLED,
combine(
make("Input", TensorShape(3U, 3U, 32U)),
@@ -1119,7 +1158,7 @@ TEST_CASE(MemoryInjection, framework::DatasetMode::ALL)
auto result_1 = run_conv();
for(size_t i = 0; i < result_0.info()->tensor_shape().total_size(); ++i)
{
- ARM_COMPUTE_EXPECT(((float *)result_0.buffer())[i] == ((float *)result_1.buffer())[i], framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(reinterpret_cast<float *>(result_0.buffer())[i] == reinterpret_cast<float *>(result_1.buffer())[i], framework::LogLevel::ERRORS);
}
}
@@ -1160,7 +1199,7 @@ TEST_CASE(MultipleExecutionWithConfigure, framework::DatasetMode::ALL)
auto result_1 = run_conv();
for(size_t i = 0; i < result_0.info()->tensor_shape().total_size(); ++i)
{
- ARM_COMPUTE_EXPECT(((float *)result_0.buffer())[i] == ((float *)result_1.buffer())[i], framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(reinterpret_cast<float *>(result_0.buffer())[i] == reinterpret_cast<float *>(result_1.buffer())[i], framework::LogLevel::ERRORS);
}
}
@@ -1251,12 +1290,14 @@ FIXTURE_DATA_TEST_CASE(RunVeryLarge, NEGEMMConvolutionLayerFixture<float>, frame
TEST_SUITE_END() // FP32
TEST_SUITE_END() // Float
-// TODO: COMPMID-6596 Extend quantized tests with at least one suite where the weight is padded (the legacy case, see floating point's RunPaddedWeights)
+// TODO(COMPMID-6573): Extend quantized tests with at least one suite where the weight is padded (the legacy case, see floating point's RunPaddedWeights)
template <typename T>
using NEGEMMConvolutionLayerQuantizedFixture = ConvolutionValidationQuantizedFixture<Tensor, Accessor, NEConvolutionLayer, T>;
template <typename T>
using NEGEMMConvolutionLayerQuantizedMixedDataLayoutFixture = ConvolutionValidationQuantizedFixture<Tensor, Accessor, NEConvolutionLayer, T, true>;
+using NEGEMMConvolutionLayerQuantizedMixedSignFixture = ConvolutionValidationQuantizedMixedTypeFixture<Tensor, Accessor, NEConvolutionLayer, uint8_t, int8_t>;
+
template <typename T>
using NEGEMMConvolutionLayerQuantizedPerChannelFixture = ConvolutionValidationQuantizedPerChannelFixture<Tensor, Accessor, NEConvolutionLayer, T, int8_t>;
@@ -1332,6 +1373,50 @@ FIXTURE_DATA_TEST_CASE(RunMixedDataLayout, NEGEMMConvolutionLayerQuantizedFixtur
}
TEST_SUITE_END() // QASYMM8_SIGNED
+TEST_SUITE(QASYMM8_MIXED)
+FIXTURE_DATA_TEST_CASE(
+ RunSmall,
+ NEGEMMConvolutionLayerQuantizedMixedSignFixture,
+ framework::DatasetMode::ALL,
+ combine(combine(combine(combine(combine(combine(combine(datasets::SmallConvolutionLayerDataset(),
+ framework::dataset::make("ReshapeWeights", {true})),
+ framework::dataset::make("DataType", DataType::QASYMM8)),
+ framework::dataset::make("WeightsDataType", DataType::QASYMM8_SIGNED)),
+ framework::dataset::make("DataLayout", {DataLayout::NCHW, DataLayout::NHWC})),
+ framework::dataset::make("QuantizationInfoIfActivationEnabled",
+{QuantizationInfo(2.f / 255.f, 10)})),
+framework::dataset::make("WeightQuantizationInfoIfActivationEnabled",
+{QuantizationInfo(2.f / 255.f, 10)})),
+QuantizedActivationFunctionsDataset))
+{
+ // Validate output
+ validate(Accessor(_target), _reference, tolerance_qasymm8);
+}
+FIXTURE_DATA_TEST_CASE(
+ RunMixedDataLayout,
+ NEGEMMConvolutionLayerQuantizedMixedSignFixture,
+ framework::DatasetMode::ALL,
+ combine(
+ framework::dataset::make("Input", TensorShape(23U, 27U, 5U)),
+ framework::dataset::make("Weights", TensorShape(3U, 3U, 5U, 2U)),
+ framework::dataset::make("Bias", TensorShape(2U)),
+ framework::dataset::make("Output", TensorShape(11U, 25U, 2U)),
+ framework::dataset::make("PadStrideInfo", PadStrideInfo(2, 1, 0, 0)),
+ framework::dataset::make("Dilation", Size2D(1, 1)),
+ framework::dataset::make("ReshapeWeights", {true}),
+ framework::dataset::make("DataType", DataType::QASYMM8),
+ framework::dataset::make("WeightsDataType", DataType::QASYMM8_SIGNED),
+ framework::dataset::make("DataLayout", {DataLayout::NCHW, DataLayout::NHWC}),
+ framework::dataset::make("QuantizationInfoIfActivationEnabled", {QuantizationInfo(2.f / 255.f, 10)}),
+ framework::dataset::make("WeightQuantizationInfoIfActivationEnabled", {QuantizationInfo(2.f / 255.f, 10)}),
+ QuantizedActivationFunctionsDataset)
+ )
+{
+ // Validate output
+ validate(Accessor(_target), _reference, tolerance_qasymm8);
+}
+TEST_SUITE_END() // QASYMM8_MIXED
+
TEST_SUITE(QSYMM8_PER_CHANNEL)
FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMConvolutionLayerQuantizedPerChannelFixture<uint8_t>, framework::DatasetMode::ALL,
combine(combine(combine(combine(combine(combine(datasets::SmallConvolutionLayerDataset(),
@@ -1436,7 +1521,7 @@ TEST_CASE(MemoryInjection, framework::DatasetMode::ALL)
auto result_1 = run_conv();
for(size_t i = 0; i < result_0.info()->tensor_shape().total_size(); ++i)
{
- ARM_COMPUTE_EXPECT(((float *)result_0.buffer())[i] == ((float *)result_1.buffer())[i], framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(reinterpret_cast<float *>(result_0.buffer())[i] == reinterpret_cast<float *>(result_1.buffer())[i], framework::LogLevel::ERRORS);
}
}
@@ -1476,7 +1561,7 @@ TEST_CASE(MultipleExecutionWithConfigure, framework::DatasetMode::ALL)
auto result_1 = run_conv();
for(size_t i = 0; i < result_0.info()->tensor_shape().total_size(); ++i)
{
- ARM_COMPUTE_EXPECT(((float *)result_0.buffer())[i] == ((float *)result_1.buffer())[i], framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(reinterpret_cast<float *>(result_0.buffer())[i] == reinterpret_cast<float *>(result_1.buffer())[i], framework::LogLevel::ERRORS);
}
}
diff --git a/tests/validation/NEON/GEMMLowp.cpp b/tests/validation/NEON/GEMMLowp.cpp
index d25f43a330..61202ee2b7 100644
--- a/tests/validation/NEON/GEMMLowp.cpp
+++ b/tests/validation/NEON/GEMMLowp.cpp
@@ -141,20 +141,23 @@ DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(
TensorInfo(TensorShape(20U, 13U), 1, DataType::QASYMM8, QuantizationInfo(1.f/255, 10)), // Invalid dimensions
TensorInfo(TensorShape(21U, 13U), 1, DataType::QASYMM8, QuantizationInfo(1.f/255, 10)), // Invalid dimensions
TensorInfo(TensorShape(16U, 32U), 1, DataType::QASYMM8, QuantizationInfo(1.f/255, 10)),
+ TensorInfo(TensorShape(16U, 32U), 1, DataType::QASYMM8_SIGNED, QuantizationInfo(1.f/255, 10)), // Invalid types
}),
make("InputBInfo",{ TensorInfo(TensorShape(33U, 21U), 1, DataType::QASYMM8, QuantizationInfo(1.f/256, 10)),
TensorInfo(TensorShape(33U, 21U), 1, DataType::QASYMM8, QuantizationInfo(1.f/256, 10)),
TensorInfo(TensorShape(33U, 21U), 1, DataType::QASYMM8, QuantizationInfo(1.f/256, 10)),
TensorInfo(TensorShape(33U, 21U), 1, DataType::QASYMM8, QuantizationInfo(1.f/256, 10)),
TensorInfo(TensorShape(64U, 16U), 1, DataType::QASYMM8, QuantizationInfo(1.f/256, 10)),
+ TensorInfo(TensorShape(64U, 16U), 1, DataType::QASYMM8, QuantizationInfo(1.f/256, 10)),
}),
make("OutputInfo",{ TensorInfo(TensorShape(33U, 13U), 1, DataType::S32),
TensorInfo(TensorShape(33U, 13U), 1, DataType::S32),
TensorInfo(TensorShape(33U, 13U), 1, DataType::S32),
TensorInfo(TensorShape(8U, 11U), 1, DataType::S32),
TensorInfo(TensorShape(64U, 32U), 1, DataType::S32),
+ TensorInfo(TensorShape(64U, 32U), 1, DataType::S32),
}),
- make("Expected", { true, false, false, false, true })),
+ make("Expected", { true, false, false, false, true, false })),
a_info, b_info, output_info, expected)
{
// Lock tensors
@@ -359,10 +362,39 @@ TEST_SUITE_END() // DynamicQuantization
#ifdef __aarch64__
// Dequant tests involve returning F32 from the MatrixMultiplyCore kernels and are only implemented in aarch64
TEST_SUITE(Dequant)
+DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(
+ make("InputAInfo", {
+ TensorInfo(TensorShape(16U, 32U), 1, DataType::QASYMM8, QuantizationInfo(1.f/255, 10)),
+ TensorInfo(TensorShape(16U, 32U), 1, DataType::QASYMM8_SIGNED, QuantizationInfo(1.f/255, 10)),
+ TensorInfo(TensorShape(16U, 32U), 1, DataType::QASYMM8_SIGNED, QuantizationInfo(1.f/255, 10)), // Invalid types
+ }),
+ make("InputBInfo",{
+ TensorInfo(TensorShape(64U, 16U), 1, DataType::QASYMM8_SIGNED, QuantizationInfo(1.f/256, 10)),
+ TensorInfo(TensorShape(64U, 16U), 1, DataType::QASYMM8_SIGNED, QuantizationInfo(1.f/256, 10)),
+ TensorInfo(TensorShape(64U, 16U), 1, DataType::QASYMM8, QuantizationInfo(1.f/256, 10)),
+ }),
+ make("OutputInfo",{
+ TensorInfo(TensorShape(64U, 32U), 1, DataType::F32),
+ TensorInfo(TensorShape(64U, 32U), 1, DataType::F32),
+ TensorInfo(TensorShape(64U, 32U), 1, DataType::F32),
+ }),
+ make("Expected", { true, true, false })),
+ a_info, b_info, output_info, expected)
+{
+ // Lock tensors
+ Status status = NEGEMMLowpMatrixMultiplyCore::validate(&a_info.clone()->set_is_resizable(false),
+ &b_info.clone()->set_is_resizable(false),
+ nullptr,
+ &output_info.clone()->set_is_resizable(false));
+ ARM_COMPUTE_EXPECT(bool(status) == expected, framework::LogLevel::ERRORS);
+}
+
constexpr AbsoluteTolerance<float> tolerance_dequantized(0.01f);
FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMLowpDequantizedMatrixMultiplyValidationFixture, framework::DatasetMode::ALL,
combine(
datasets::SmallGEMMLowpDataset(),
+ make("DataTypeA", {DataType::QASYMM8_SIGNED, DataType::QASYMM8}),
+ make("DataTypeB", DataType::QASYMM8_SIGNED),
make("accumulate", {true, false})
))
{
@@ -373,6 +405,8 @@ FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMLowpDequantizedMatrixMultiplyValidationFi
FIXTURE_DATA_TEST_CASE(RunLarge, NEGEMMLowpDequantizedMatrixMultiplyValidationFixture, framework::DatasetMode::NIGHTLY,
combine(
datasets::LargeGEMMLowpDataset(),
+ make("DataTypeA", {DataType::QASYMM8_SIGNED, DataType::QASYMM8}),
+ make("DataTypeB", DataType::QASYMM8_SIGNED),
make("accumulate", {false})
))
{
diff --git a/tests/validation/fixtures/ConvolutionLayerFixture.h b/tests/validation/fixtures/ConvolutionLayerFixture.h
index 0622e5e6f0..939ac032cd 100644
--- a/tests/validation/fixtures/ConvolutionLayerFixture.h
+++ b/tests/validation/fixtures/ConvolutionLayerFixture.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2023 Arm Limited.
+ * Copyright (c) 2017-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -480,6 +480,31 @@ public:
};
template <typename TensorType, typename AccessorType, typename FunctionType, typename T, typename TW>
+class ConvolutionValidationQuantizedMixedTypeFixture
+ : public ConvolutionValidationGenericFixture<TensorType, AccessorType, FunctionType, T, TW>
+{
+public:
+ void setup(TensorShape input_shape,
+ TensorShape weights_shape,
+ TensorShape bias_shape,
+ TensorShape output_shape,
+ PadStrideInfo info,
+ Size2D dilation,
+ bool reshape_weights,
+ DataType data_type,
+ DataType weights_data_type,
+ DataLayout data_layout,
+ QuantizationInfo quantization_info,
+ QuantizationInfo weight_quantization_info,
+ ActivationLayerInfo act_info)
+ {
+ ConvolutionValidationGenericFixture<TensorType, AccessorType, FunctionType, T, TW>::setup(
+ input_shape, weights_shape, bias_shape, output_shape, info, dilation, reshape_weights, data_type,
+ weights_data_type, data_layout, quantization_info, weight_quantization_info, act_info);
+ }
+};
+
+template <typename TensorType, typename AccessorType, typename FunctionType, typename T, typename TW>
class ConvolutionValidationQuantizedPerChannelFixture : public ConvolutionValidationGenericFixture<TensorType, AccessorType, FunctionType, T, TW>
{
public:
diff --git a/tests/validation/fixtures/GEMMLowpFixture.h b/tests/validation/fixtures/GEMMLowpFixture.h
index aa4eedb75d..7931d8467d 100644
--- a/tests/validation/fixtures/GEMMLowpFixture.h
+++ b/tests/validation/fixtures/GEMMLowpFixture.h
@@ -97,8 +97,7 @@ TensorType compute_gemmlowp_target(const TensorShape &shape_a, const TensorShape
bool accumulate = false, bool dynamic_qinfo = false, DataType data_type_output = DataType::UNKNOWN)
{
ARM_COMPUTE_ASSERT(is_data_type_quantized_asymmetric(data_type_a));
- ARM_COMPUTE_ASSERT(data_type_a == data_type_b);
- // If unknown, set to sensible defaults
+ // If unknown, set to sensible defaults
if (data_type_output == DataType::UNKNOWN) {
data_type_output = output_stage.type == GEMMLowpOutputStageType::NONE ? DataType::S32 : data_type_a;
}
@@ -185,7 +184,6 @@ SimpleTensor<int32_t> compute_gemmlowp_reference(const TensorShape &shape_a, con
DataType data_type_a = DataType::QASYMM8, DataType data_type_b = DataType::QASYMM8, const TensorFillInfo& finfo = TensorFillInfo())
{
ARM_COMPUTE_ASSERT(is_data_type_quantized_asymmetric(data_type_a));
- ARM_COMPUTE_ASSERT(data_type_a == data_type_b);
TensorShape shape_a_to_use = shape_a;
if(reinterpret_input_as_3d)
{
@@ -472,29 +470,59 @@ template <typename TensorType, typename AccessorType, typename FunctionType, boo
class GEMMLowpDequantizedMatrixMultiplyValidationFixture : public framework::Fixture
{
public:
- void setup(TensorShape shape_a, TensorShape shape_b, TensorShape shape_output, int32_t a_offset, int32_t b_offset, bool accumulate)
+ void setup(TensorShape shape_a, TensorShape shape_b, TensorShape shape_output, int32_t a_offset, int32_t b_offset, DataType data_type_a, DataType data_type_b, bool accumulate)
{
const bool dynamic_qinfo = false;
const auto a_qinfo = QuantizationInfo(1.0f / 255, a_offset);
const auto b_qinfo = QuantizationInfo(5.0f / 255, b_offset);
TensorFillInfo finfo;
- _target = compute_target(shape_a, shape_b, shape_output, a_qinfo, b_qinfo, finfo, accumulate, dynamic_qinfo);
- _reference = compute_reference(shape_a, shape_b, shape_output, a_qinfo, b_qinfo, finfo, accumulate, dynamic_qinfo);
+ _target = compute_target(shape_a, shape_b, shape_output, a_qinfo, b_qinfo, data_type_a, data_type_b, finfo,
+ accumulate, dynamic_qinfo);
+ _reference = compute_reference(shape_a, shape_b, shape_output, a_qinfo, b_qinfo, data_type_a, data_type_b,
+ finfo, accumulate, dynamic_qinfo);
}
protected:
- TensorType compute_target(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &shape_output, const QuantizationInfo& a_qinfo, const QuantizationInfo& b_qinfo, const TensorFillInfo& finfo, const bool accumulate, const bool dynamic_qinfo)
+ TensorType compute_target(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &shape_output, const QuantizationInfo& a_qinfo, const QuantizationInfo& b_qinfo, DataType data_type_a, DataType data_type_b, const TensorFillInfo& finfo, const bool accumulate, const bool dynamic_qinfo)
{
const auto output_qinfo = QuantizationInfo();
- return compute_gemmlowp_target<TensorType, AccessorType, FunctionType, reinterpret_input_as_3d, reinterpret_output_as_3d, int32_t, false, run_twice>(shape_a, shape_b, shape_output, a_qinfo, b_qinfo, output_qinfo, DataType::QASYMM8_SIGNED, DataType::QASYMM8_SIGNED, GEMMLowpOutputStageInfo(), false, finfo, accumulate, dynamic_qinfo, DataType::F32);
+ return compute_gemmlowp_target<TensorType, AccessorType, FunctionType, reinterpret_input_as_3d, reinterpret_output_as_3d, int32_t, false, run_twice>(shape_a, shape_b, shape_output, a_qinfo, b_qinfo, output_qinfo, data_type_a, data_type_b, GEMMLowpOutputStageInfo(), false, finfo, accumulate, dynamic_qinfo, DataType::F32);
}
- SimpleTensor<float> compute_reference(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &shape_output, const QuantizationInfo& a_qinfo, const QuantizationInfo& b_qinfo, const TensorFillInfo& finfo, bool accumulate, const bool dynamic_qinfo)
+ SimpleTensor<float> compute_reference(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &shape_output, const QuantizationInfo& a_qinfo, const QuantizationInfo& b_qinfo, DataType data_type_a, DataType data_type_b, const TensorFillInfo& finfo, bool accumulate, const bool dynamic_qinfo)
{
QuantizationInfo s32_ref_output_quant_info = QuantizationInfo(a_qinfo.uniform().scale * b_qinfo.uniform().scale, 0, dynamic_qinfo);
- SimpleTensor<int32_t> s32_ref_output = compute_gemmlowp_reference<reinterpret_input_as_3d, int8_t, int8_t, false, false, run_twice>(shape_a, shape_b, shape_output, a_qinfo, b_qinfo,
- DataType::QASYMM8_SIGNED, DataType::QASYMM8_SIGNED, finfo);
+ SimpleTensor<int32_t> s32_ref_output;
+ if (data_type_a == DataType::QASYMM8)
+ {
+ if (data_type_b == DataType::QASYMM8)
+ {
+ s32_ref_output = compute_gemmlowp_reference<reinterpret_input_as_3d, uint8_t, uint8_t, false, false, run_twice>(
+ shape_a, shape_b, shape_output, a_qinfo, b_qinfo, data_type_a, data_type_b, finfo);
+ }
+ else
+ {
+ ARM_COMPUTE_ERROR_ON(data_type_b != DataType::QASYMM8_SIGNED);
+ s32_ref_output = compute_gemmlowp_reference<reinterpret_input_as_3d, uint8_t, int8_t, false, false, run_twice>(
+ shape_a, shape_b, shape_output, a_qinfo, b_qinfo, data_type_a, data_type_b, finfo);
+ }
+ }
+ else
+ {
+ ARM_COMPUTE_ERROR_ON(data_type_a != DataType::QASYMM8_SIGNED);
+ if (data_type_b == DataType::QASYMM8)
+ {
+ ARM_COMPUTE_ERROR("QASYMM8_SIGNED input with QASYMM8 weights not supported");
+ }
+ else
+ {
+ ARM_COMPUTE_ERROR_ON(data_type_b != DataType::QASYMM8_SIGNED);
+ s32_ref_output = compute_gemmlowp_reference<reinterpret_input_as_3d, int8_t, int8_t, false, false, run_twice>(
+ shape_a, shape_b, shape_output, a_qinfo, b_qinfo, data_type_a, data_type_b, finfo);
+ }
+ }
+
s32_ref_output.quantization_info(s32_ref_output_quant_info);
SimpleTensor<float> f32_ref_output(s32_ref_output.shape(), DataType::F32);