author     Michael Tyler <michael.tyler@arm.com>      2024-06-04 15:47:37 +0100
committer  Michael Tyler <michael.tyler@arm.com>      2024-06-25 09:10:13 +0000
commit     fc94f4d23abd4bc427b701f54ad85282e9ec7872 (patch)
tree       5e2980599256e2b2f4374e5beb61596fc95c9d5a /tests
parent     c2237ec4094c7824f8f7e61bc89504d01c5b59ff (diff)
download   ComputeLibrary-fc94f4d23abd4bc427b701f54ad85282e9ec7872.tar.gz
Update CPU kernels and add mixed-sign GEMM support
- Add support for mixed-sign quantized convolution.
- Add support for mixed-sign dequantized GEMM (a usage sketch follows this list).
- Add an SME FP16 GEMV kernel.
- Change the SME vector length function to use RDSVL instead of a static variable.
- Add GEMM dilation support internally (not exposed yet).
- Remove unused "get_default_activation_values" functions.
- Add an SVE fixed-format interleaved BF16 DOT kernel.
- Update and optimize assembly kernels.
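The mixed-sign paths are reached through the existing public validate() entry points rather than a new API. As a rough sketch (not part of the patch; shapes and quantization parameters are illustrative only), the dequantized mixed-sign GEMM case that the new Dequant/Validate test below asserts can be probed like this:

    #include "arm_compute/core/QuantizationInfo.h"
    #include "arm_compute/core/TensorInfo.h"
    #include "arm_compute/core/TensorShape.h"
    #include "arm_compute/core/Types.h"
    #include "arm_compute/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.h"

    using namespace arm_compute;

    // Ask the library whether a QASYMM8 LHS times a QASYMM8_SIGNED RHS with a
    // dequantized (F32) destination is supported on this target.
    bool mixed_sign_dequant_gemm_supported()
    {
        TensorInfo a(TensorShape(16U, 32U), 1, DataType::QASYMM8, QuantizationInfo(1.f / 255, 10));
        TensorInfo b(TensorShape(64U, 16U), 1, DataType::QASYMM8_SIGNED, QuantizationInfo(1.f / 256, 10));
        TensorInfo dst(TensorShape(64U, 32U), 1, DataType::F32); // F32 output selects the dequantized path

        Status status = NEGEMMLowpMatrixMultiplyCore::validate(&a, &b, nullptr, &dst);
        return bool(status);
    }

Whether a given mixed-sign combination actually runs still depends on the available assembly kernels, so validate() is the supported way to query it; that is exactly what the added test cases check.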
Resolves COMPMID-6926
Change-Id: I227f502502611d4cc4111c89e30c53ce94079544
Signed-off-by: Michael Tyler <michael.tyler@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/11570
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Gunes Bayir <gunes.bayir@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Benchmark: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'tests')
-rw-r--r--  tests/validation/NEON/ConvolutionLayer.cpp           101
-rw-r--r--  tests/validation/NEON/GEMMLowp.cpp                    36
-rw-r--r--  tests/validation/fixtures/ConvolutionLayerFixture.h   27
-rw-r--r--  tests/validation/fixtures/GEMMLowpFixture.h           50
4 files changed, 193 insertions(+), 21 deletions(-)
diff --git a/tests/validation/NEON/ConvolutionLayer.cpp b/tests/validation/NEON/ConvolutionLayer.cpp
index d739d4e1a4..7eada81ce5 100644
--- a/tests/validation/NEON/ConvolutionLayer.cpp
+++ b/tests/validation/NEON/ConvolutionLayer.cpp
@@ -147,6 +147,45 @@ const auto QuantizationData = make("QuantizationInfo",
 TEST_SUITE(NEON)
 TEST_SUITE(ConvolutionLayer)
+DATA_TEST_CASE(SupportedTypes, framework::DatasetMode::ALL, zip(
+    make("DataType", {
+        DataType::F32,
+        DataType::QASYMM8,
+        DataType::QASYMM8,
+        DataType::QASYMM8_SIGNED
+    }),
+    make("WeightsDataType", {
+        DataType::F32,
+        DataType::QASYMM8,
+        DataType::QASYMM8_SIGNED,
+        DataType::QASYMM8
+    }),
+    make("Expected",
+    {
+        true,
+        true,
+        true,
+        false
+    })),
+data_type_const, weights_data_type_const, expected_const)
+{
+    TensorInfo input_info = TensorInfo(TensorShape(3U, 3U, 1U), 1, data_type_const);
+    TensorInfo weights_info = TensorInfo(TensorShape(2U, 2U, 1U, 1U), 1, weights_data_type_const);
+    TensorInfo output_info = TensorInfo(TensorShape(2U, 2U, 1U), 1, data_type_const);
+
+    input_info.set_quantization_info(arm_compute::QuantizationInfo(1, 0));
+    weights_info.set_quantization_info(arm_compute::QuantizationInfo(1, 0));
+    output_info.set_quantization_info(arm_compute::QuantizationInfo(1, 0));
+
+    Status status = NEConvolutionLayer::validate(
+        &input_info,
+        &weights_info,
+        nullptr,
+        &output_info,
+        PadStrideInfo());
+
+    ARM_COMPUTE_EXPECT(bool(status) == expected_const, framework::LogLevel::ERRORS);
+}
 // *INDENT-OFF*
 // clang-format off
@@ -257,7 +296,7 @@ TEST_CASE(MemoryInjection, framework::DatasetMode::ALL)
     for(size_t i = 0; i < result_0.info()->tensor_shape().total_size(); ++i)
     {
-        ARM_COMPUTE_EXPECT(((float *)result_0.buffer())[i] == ((float *)result_1.buffer())[i], framework::LogLevel::ERRORS);
+        ARM_COMPUTE_EXPECT(reinterpret_cast<float *>(result_0.buffer())[i] == reinterpret_cast<float *>(result_1.buffer())[i], framework::LogLevel::ERRORS);
     }
 }
@@ -303,7 +342,7 @@ TEST_CASE(MultipleExecutionWithConfigure, framework::DatasetMode::ALL)
     for(size_t i = 0; i < result_0.info()->tensor_shape().total_size(); ++i)
     {
-        ARM_COMPUTE_EXPECT(((float *)result_0.buffer())[i] == ((float *)result_1.buffer())[i], framework::LogLevel::ERRORS);
+        ARM_COMPUTE_EXPECT(reinterpret_cast<float *>(result_0.buffer())[i] == reinterpret_cast<float *>(result_1.buffer())[i], framework::LogLevel::ERRORS);
     }
 }
@@ -580,7 +619,7 @@ FIXTURE_DATA_TEST_CASE(RunSmall, NEWinogradConvolutionLayerFixture<float>, frame
 /// It's enough to run the activations for a single weight/input combination and data type because
 /// activation function is called on top of the winograd output as a separate operator
-/// TODO: Enable after COMPMID-6573 is resolved
+/// TODO(COMPMID-6573): Enable after COMPMID-6573 is resolved
 FIXTURE_DATA_TEST_CASE(RunActivations, NEWinogradConvolutionLayerFixture<float>, framework::DatasetMode::DISABLED,
     combine(
         make("Input", TensorShape(3U, 3U, 32U)),
@@ -1119,7 +1158,7 @@ TEST_CASE(MemoryInjection, framework::DatasetMode::ALL)
     auto result_1 = run_conv();
     for(size_t i = 0; i < result_0.info()->tensor_shape().total_size(); ++i)
     {
-        ARM_COMPUTE_EXPECT(((float *)result_0.buffer())[i] == ((float *)result_1.buffer())[i], framework::LogLevel::ERRORS);
+        ARM_COMPUTE_EXPECT(reinterpret_cast<float *>(result_0.buffer())[i] == reinterpret_cast<float *>(result_1.buffer())[i], framework::LogLevel::ERRORS);
     }
 }
@@ -1160,7 +1199,7 @@ TEST_CASE(MultipleExecutionWithConfigure, framework::DatasetMode::ALL)
     auto result_1 = run_conv();
     for(size_t i = 0; i < result_0.info()->tensor_shape().total_size(); ++i)
     {
-        ARM_COMPUTE_EXPECT(((float *)result_0.buffer())[i] == ((float *)result_1.buffer())[i], framework::LogLevel::ERRORS);
+        ARM_COMPUTE_EXPECT(reinterpret_cast<float *>(result_0.buffer())[i] == reinterpret_cast<float *>(result_1.buffer())[i], framework::LogLevel::ERRORS);
     }
 }
@@ -1251,12 +1290,14 @@ FIXTURE_DATA_TEST_CASE(RunVeryLarge, NEGEMMConvolutionLayerFixture<float>, frame
 TEST_SUITE_END() // FP32
 TEST_SUITE_END() // Float
-// TODO: COMPMID-6596 Extend quantized tests with at least one suite where the weight is padded (the legacy case, see floating point's RunPaddedWeights)
+// TODO(COMPMID-6573): Extend quantized tests with at least one suite where the weight is padded (the legacy case, see floating point's RunPaddedWeights)
 template <typename T>
 using NEGEMMConvolutionLayerQuantizedFixture = ConvolutionValidationQuantizedFixture<Tensor, Accessor, NEConvolutionLayer, T>;
 template <typename T>
 using NEGEMMConvolutionLayerQuantizedMixedDataLayoutFixture = ConvolutionValidationQuantizedFixture<Tensor, Accessor, NEConvolutionLayer, T, true>;
+using NEGEMMConvolutionLayerQuantizedMixedSignFixture = ConvolutionValidationQuantizedMixedTypeFixture<Tensor, Accessor, NEConvolutionLayer, uint8_t, int8_t>;
+
 template <typename T>
 using NEGEMMConvolutionLayerQuantizedPerChannelFixture = ConvolutionValidationQuantizedPerChannelFixture<Tensor, Accessor, NEConvolutionLayer, T, int8_t>;
@@ -1332,6 +1373,50 @@ FIXTURE_DATA_TEST_CASE(RunMixedDataLayout, NEGEMMConvolutionLayerQuantizedFixtur
 }
 TEST_SUITE_END() // QASYMM8_SIGNED
+TEST_SUITE(QASYMM8_MIXED)
+FIXTURE_DATA_TEST_CASE(
+    RunSmall,
+    NEGEMMConvolutionLayerQuantizedMixedSignFixture,
+    framework::DatasetMode::ALL,
+    combine(combine(combine(combine(combine(combine(combine(datasets::SmallConvolutionLayerDataset(),
+                framework::dataset::make("ReshapeWeights", {true})),
+            framework::dataset::make("DataType", DataType::QASYMM8)),
+        framework::dataset::make("WeightsDataType", DataType::QASYMM8_SIGNED)),
+        framework::dataset::make("DataLayout", {DataLayout::NCHW, DataLayout::NHWC})),
+        framework::dataset::make("QuantizationInfoIfActivationEnabled",
+{QuantizationInfo(2.f / 255.f, 10)})),
+framework::dataset::make("WeightQuantizationInfoIfActivationEnabled",
+{QuantizationInfo(2.f / 255.f, 10)})),
+QuantizedActivationFunctionsDataset))
+{
+    // Validate output
+    validate(Accessor(_target), _reference, tolerance_qasymm8);
+}
+FIXTURE_DATA_TEST_CASE(
+    RunMixedDataLayout,
+    NEGEMMConvolutionLayerQuantizedMixedSignFixture,
+    framework::DatasetMode::ALL,
+    combine(
+        framework::dataset::make("Input", TensorShape(23U, 27U, 5U)),
+        framework::dataset::make("Weights", TensorShape(3U, 3U, 5U, 2U)),
+        framework::dataset::make("Bias", TensorShape(2U)),
+        framework::dataset::make("Output", TensorShape(11U, 25U, 2U)),
+        framework::dataset::make("PadStrideInfo", PadStrideInfo(2, 1, 0, 0)),
+        framework::dataset::make("Dilation", Size2D(1, 1)),
+        framework::dataset::make("ReshapeWeights", {true}),
+        framework::dataset::make("DataType", DataType::QASYMM8),
+        framework::dataset::make("WeightsDataType", DataType::QASYMM8_SIGNED),
+        framework::dataset::make("DataLayout", {DataLayout::NCHW, DataLayout::NHWC}),
+        framework::dataset::make("QuantizationInfoIfActivationEnabled", {QuantizationInfo(2.f / 255.f, 10)}),
+        framework::dataset::make("WeightQuantizationInfoIfActivationEnabled", {QuantizationInfo(2.f / 255.f, 10)}),
+        QuantizedActivationFunctionsDataset)
+    )
+{
+    // Validate output
+    validate(Accessor(_target), _reference, tolerance_qasymm8);
+}
+TEST_SUITE_END() // QASYMM8_MIXED
+
 TEST_SUITE(QSYMM8_PER_CHANNEL)
 FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMConvolutionLayerQuantizedPerChannelFixture<uint8_t>, framework::DatasetMode::ALL,
     combine(combine(combine(combine(combine(combine(datasets::SmallConvolutionLayerDataset(),
@@ -1436,7 +1521,7 @@ TEST_CASE(MemoryInjection, framework::DatasetMode::ALL)
     auto result_1 = run_conv();
     for(size_t i = 0; i < result_0.info()->tensor_shape().total_size(); ++i)
     {
-        ARM_COMPUTE_EXPECT(((float *)result_0.buffer())[i] == ((float *)result_1.buffer())[i], framework::LogLevel::ERRORS);
+        ARM_COMPUTE_EXPECT(reinterpret_cast<float *>(result_0.buffer())[i] == reinterpret_cast<float *>(result_1.buffer())[i], framework::LogLevel::ERRORS);
     }
 }
@@ -1476,7 +1561,7 @@ TEST_CASE(MultipleExecutionWithConfigure, framework::DatasetMode::ALL)
     auto result_1 = run_conv();
     for(size_t i = 0; i < result_0.info()->tensor_shape().total_size(); ++i)
     {
-        ARM_COMPUTE_EXPECT(((float *)result_0.buffer())[i] == ((float *)result_1.buffer())[i], framework::LogLevel::ERRORS);
+        ARM_COMPUTE_EXPECT(reinterpret_cast<float *>(result_0.buffer())[i] == reinterpret_cast<float *>(result_1.buffer())[i], framework::LogLevel::ERRORS);
     }
 }
diff --git a/tests/validation/NEON/GEMMLowp.cpp b/tests/validation/NEON/GEMMLowp.cpp
index d25f43a330..61202ee2b7 100644
--- a/tests/validation/NEON/GEMMLowp.cpp
+++ b/tests/validation/NEON/GEMMLowp.cpp
@@ -141,20 +141,23 @@ DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(
     TensorInfo(TensorShape(20U, 13U), 1, DataType::QASYMM8, QuantizationInfo(1.f/255, 10)), // Invalid dimensions
     TensorInfo(TensorShape(21U, 13U), 1, DataType::QASYMM8, QuantizationInfo(1.f/255, 10)), // Invalid dimensions
     TensorInfo(TensorShape(16U, 32U), 1, DataType::QASYMM8, QuantizationInfo(1.f/255, 10)),
+    TensorInfo(TensorShape(16U, 32U), 1, DataType::QASYMM8_SIGNED, QuantizationInfo(1.f/255, 10)), // Invalid types
     }),
     make("InputBInfo",{
     TensorInfo(TensorShape(33U, 21U), 1, DataType::QASYMM8, QuantizationInfo(1.f/256, 10)),
     TensorInfo(TensorShape(33U, 21U), 1, DataType::QASYMM8, QuantizationInfo(1.f/256, 10)),
     TensorInfo(TensorShape(33U, 21U), 1, DataType::QASYMM8, QuantizationInfo(1.f/256, 10)),
     TensorInfo(TensorShape(33U, 21U), 1, DataType::QASYMM8, QuantizationInfo(1.f/256, 10)),
     TensorInfo(TensorShape(64U, 16U), 1, DataType::QASYMM8, QuantizationInfo(1.f/256, 10)),
+    TensorInfo(TensorShape(64U, 16U), 1, DataType::QASYMM8, QuantizationInfo(1.f/256, 10)),
     }),
     make("OutputInfo",{
     TensorInfo(TensorShape(33U, 13U), 1, DataType::S32),
     TensorInfo(TensorShape(33U, 13U), 1, DataType::S32),
     TensorInfo(TensorShape(33U, 13U), 1, DataType::S32),
     TensorInfo(TensorShape(8U, 11U), 1, DataType::S32),
     TensorInfo(TensorShape(64U, 32U), 1, DataType::S32),
+    TensorInfo(TensorShape(64U, 32U), 1, DataType::S32),
     }),
-    make("Expected", { true, false, false, false, true })),
+    make("Expected", { true, false, false, false, true, false })),
     a_info, b_info, output_info, expected)
 {
     // Lock tensors
@@ -359,10 +362,39 @@ TEST_SUITE_END() // DynamicQuantization
 #ifdef __aarch64__
 // Deqaunt tests involve returning F32 from the MatrixMultiplyCore kernels and is only implemented in aarch64
 TEST_SUITE(Dequant)
+DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(
+    make("InputAInfo", {
+        TensorInfo(TensorShape(16U, 32U), 1, DataType::QASYMM8, QuantizationInfo(1.f/255, 10)),
+        TensorInfo(TensorShape(16U, 32U), 1, DataType::QASYMM8_SIGNED, QuantizationInfo(1.f/255, 10)),
+        TensorInfo(TensorShape(16U, 32U), 1, DataType::QASYMM8_SIGNED, QuantizationInfo(1.f/255, 10)), // Invalid types
+    }),
+    make("InputBInfo",{
+        TensorInfo(TensorShape(64U, 16U), 1, DataType::QASYMM8_SIGNED, QuantizationInfo(1.f/256, 10)),
+        TensorInfo(TensorShape(64U, 16U), 1, DataType::QASYMM8_SIGNED, QuantizationInfo(1.f/256, 10)),
+        TensorInfo(TensorShape(64U, 16U), 1, DataType::QASYMM8, QuantizationInfo(1.f/256, 10)),
+    }),
+    make("OutputInfo",{
+        TensorInfo(TensorShape(64U, 32U), 1, DataType::F32),
+        TensorInfo(TensorShape(64U, 32U), 1, DataType::F32),
+        TensorInfo(TensorShape(64U, 32U), 1, DataType::F32),
+    }),
+    make("Expected", { true, true, false })),
+    a_info, b_info, output_info, expected)
+{
+    // Lock tensors
+    Status status = NEGEMMLowpMatrixMultiplyCore::validate(&a_info.clone()->set_is_resizable(false),
+                                                           &b_info.clone()->set_is_resizable(false),
+                                                           nullptr,
+                                                           &output_info.clone()->set_is_resizable(false));
+    ARM_COMPUTE_EXPECT(bool(status) == expected, framework::LogLevel::ERRORS);
+}
+
 constexpr AbsoluteTolerance<float> tolerance_dequantized(0.01f);
 FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMLowpDequantizedMatrixMultiplyValidationFixture, framework::DatasetMode::ALL,
     combine(
         datasets::SmallGEMMLowpDataset(),
+        make("DataTypeA", {DataType::QASYMM8_SIGNED, DataType::QASYMM8}),
+        make("DataTypeB", DataType::QASYMM8_SIGNED),
         make("accumulate", {true, false})
     ))
 {
@@ -373,6 +405,8 @@ FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMLowpDequantizedMatrixMultiplyValidationFi
 FIXTURE_DATA_TEST_CASE(RunLarge, NEGEMMLowpDequantizedMatrixMultiplyValidationFixture, framework::DatasetMode::NIGHTLY,
     combine(
         datasets::LargeGEMMLowpDataset(),
+        make("DataTypeA", {DataType::QASYMM8_SIGNED, DataType::QASYMM8}),
+        make("DataTypeB", DataType::QASYMM8_SIGNED),
         make("accumulate", {false})
     ))
 {
diff --git a/tests/validation/fixtures/ConvolutionLayerFixture.h b/tests/validation/fixtures/ConvolutionLayerFixture.h
index 0622e5e6f0..939ac032cd 100644
--- a/tests/validation/fixtures/ConvolutionLayerFixture.h
+++ b/tests/validation/fixtures/ConvolutionLayerFixture.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2017-2023 Arm Limited.
+ * Copyright (c) 2017-2024 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -480,6 +480,31 @@ public:
 };
 template <typename TensorType, typename AccessorType, typename FunctionType, typename T, typename TW>
+class ConvolutionValidationQuantizedMixedTypeFixture
+    : public ConvolutionValidationGenericFixture<TensorType, AccessorType, FunctionType, T, TW>
+{
+public:
+    void setup(TensorShape input_shape,
+               TensorShape weights_shape,
+               TensorShape bias_shape,
+               TensorShape output_shape,
+               PadStrideInfo info,
+               Size2D dilation,
+               bool reshape_weights,
+               DataType data_type,
+               DataType weights_data_type,
+               DataLayout data_layout,
+               QuantizationInfo quantization_info,
+               QuantizationInfo weight_quantization_info,
+               ActivationLayerInfo act_info)
+    {
+        ConvolutionValidationGenericFixture<TensorType, AccessorType, FunctionType, T, TW>::setup(
+            input_shape, weights_shape, bias_shape, output_shape, info, dilation, reshape_weights, data_type,
+            weights_data_type, data_layout, quantization_info, weight_quantization_info, act_info);
+    }
+};
+
+template <typename TensorType, typename AccessorType, typename FunctionType, typename T, typename TW>
 class ConvolutionValidationQuantizedPerChannelFixture : public ConvolutionValidationGenericFixture<TensorType, AccessorType, FunctionType, T, TW>
 {
 public:
diff --git a/tests/validation/fixtures/GEMMLowpFixture.h b/tests/validation/fixtures/GEMMLowpFixture.h
index aa4eedb75d..7931d8467d 100644
--- a/tests/validation/fixtures/GEMMLowpFixture.h
+++ b/tests/validation/fixtures/GEMMLowpFixture.h
@@ -97,8 +97,7 @@ TensorType compute_gemmlowp_target(const TensorShape &shape_a, const TensorShape
     bool accumulate = false, bool dynamic_qinfo = false, DataType data_type_output = DataType::UNKNOWN)
 {
     ARM_COMPUTE_ASSERT(is_data_type_quantized_asymmetric(data_type_a));
-    ARM_COMPUTE_ASSERT(data_type_a == data_type_b);
-    // If unknown, set to sensible defaults
+    // If unknown, set to sensible defaults
     if (data_type_output == DataType::UNKNOWN)
     {
         data_type_output = output_stage.type == GEMMLowpOutputStageType::NONE ? DataType::S32 : data_type_a;
     }
@@ -185,7 +184,6 @@ SimpleTensor<int32_t> compute_gemmlowp_reference(const TensorShape &shape_a, con
     DataType data_type_a = DataType::QASYMM8, DataType data_type_b = DataType::QASYMM8, const TensorFillInfo& finfo = TensorFillInfo())
 {
     ARM_COMPUTE_ASSERT(is_data_type_quantized_asymmetric(data_type_a));
-    ARM_COMPUTE_ASSERT(data_type_a == data_type_b);
     TensorShape shape_a_to_use = shape_a;
     if(reinterpret_input_as_3d)
     {
@@ -472,29 +470,59 @@ template <typename TensorType, typename AccessorType, typename FunctionType, boo
 class GEMMLowpDequantizedMatrixMultiplyValidationFixture : public framework::Fixture
 {
 public:
-    void setup(TensorShape shape_a, TensorShape shape_b, TensorShape shape_output, int32_t a_offset, int32_t b_offset, bool accumulate)
+    void setup(TensorShape shape_a, TensorShape shape_b, TensorShape shape_output, int32_t a_offset, int32_t b_offset, DataType data_type_a, DataType data_type_b, bool accumulate)
     {
         const bool dynamic_qinfo = false;
         const auto a_qinfo = QuantizationInfo(1.0f / 255, a_offset);
         const auto b_qinfo = QuantizationInfo(5.0f / 255, b_offset);
         TensorFillInfo finfo;
-        _target = compute_target(shape_a, shape_b, shape_output, a_qinfo, b_qinfo, finfo, accumulate, dynamic_qinfo);
-        _reference = compute_reference(shape_a, shape_b, shape_output, a_qinfo, b_qinfo, finfo, accumulate, dynamic_qinfo);
+        _target = compute_target(shape_a, shape_b, shape_output, a_qinfo, b_qinfo, data_type_a, data_type_b, finfo,
+                                 accumulate, dynamic_qinfo);
+        _reference = compute_reference(shape_a, shape_b, shape_output, a_qinfo, b_qinfo, data_type_a, data_type_b,
+                                       finfo, accumulate, dynamic_qinfo);
     }
 protected:
-    TensorType compute_target(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &shape_output, const QuantizationInfo& a_qinfo, const QuantizationInfo& b_qinfo, const TensorFillInfo& finfo, const bool accumulate, const bool dynamic_qinfo)
+    TensorType compute_target(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &shape_output, const QuantizationInfo& a_qinfo, const QuantizationInfo& b_qinfo, DataType data_type_a, DataType data_type_b, const TensorFillInfo& finfo, const bool accumulate, const bool dynamic_qinfo)
     {
         const auto output_qinfo = QuantizationInfo();
-        return compute_gemmlowp_target<TensorType, AccessorType, FunctionType, reinterpret_input_as_3d, reinterpret_output_as_3d, int32_t, false, run_twice>(shape_a, shape_b, shape_output, a_qinfo, b_qinfo, output_qinfo, DataType::QASYMM8_SIGNED, DataType::QASYMM8_SIGNED, GEMMLowpOutputStageInfo(), false, finfo, accumulate, dynamic_qinfo, DataType::F32);
+        return compute_gemmlowp_target<TensorType, AccessorType, FunctionType, reinterpret_input_as_3d, reinterpret_output_as_3d, int32_t, false, run_twice>(shape_a, shape_b, shape_output, a_qinfo, b_qinfo, output_qinfo, data_type_a, data_type_b, GEMMLowpOutputStageInfo(), false, finfo, accumulate, dynamic_qinfo, DataType::F32);
     }
-    SimpleTensor<float> compute_reference(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &shape_output, const QuantizationInfo& a_qinfo, const QuantizationInfo& b_qinfo, const TensorFillInfo& finfo, bool accumulate, const bool dynamic_qinfo)
+    SimpleTensor<float> compute_reference(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &shape_output, const QuantizationInfo& a_qinfo, const QuantizationInfo& b_qinfo, DataType data_type_a, DataType data_type_b, const TensorFillInfo& finfo, bool accumulate, const bool dynamic_qinfo)
     {
         QuantizationInfo s32_ref_output_quant_info = QuantizationInfo(a_qinfo.uniform().scale * b_qinfo.uniform().scale, 0, dynamic_qinfo);
-        SimpleTensor<int32_t> s32_ref_output = compute_gemmlowp_reference<reinterpret_input_as_3d, int8_t, int8_t, false, false, run_twice>(shape_a, shape_b, shape_output, a_qinfo, b_qinfo,
-                                                                                                                                            DataType::QASYMM8_SIGNED, DataType::QASYMM8_SIGNED, finfo);
+        SimpleTensor<int32_t> s32_ref_output;
+        if (data_type_a == DataType::QASYMM8)
+        {
+            if (data_type_b == DataType::QASYMM8)
+            {
+                s32_ref_output = compute_gemmlowp_reference<reinterpret_input_as_3d, uint8_t, uint8_t, false, false, run_twice>(
+                    shape_a, shape_b, shape_output, a_qinfo, b_qinfo, data_type_a, data_type_b, finfo);
+            }
+            else
+            {
+                ARM_COMPUTE_ERROR_ON(data_type_b != DataType::QASYMM8_SIGNED);
+                s32_ref_output = compute_gemmlowp_reference<reinterpret_input_as_3d, uint8_t, int8_t, false, false, run_twice>(
+                    shape_a, shape_b, shape_output, a_qinfo, b_qinfo, data_type_a, data_type_b, finfo);
+            }
+        }
+        else
+        {
+            ARM_COMPUTE_ERROR_ON(data_type_a != DataType::QASYMM8_SIGNED);
+            if (data_type_b == DataType::QASYMM8)
+            {
+                ARM_COMPUTE_ERROR("QASYMM8_SIGNED input with QASYMM8 weights not supported");
+            }
+            else
+            {
+                ARM_COMPUTE_ERROR_ON(data_type_b != DataType::QASYMM8_SIGNED);
+                s32_ref_output = compute_gemmlowp_reference<reinterpret_input_as_3d, int8_t, int8_t, false, false, run_twice>(
+                    shape_a, shape_b, shape_output, a_qinfo, b_qinfo, data_type_a, data_type_b, finfo);
+            }
+        }
+
         s32_ref_output.quantization_info(s32_ref_output_quant_info);
         SimpleTensor<float> f32_ref_output(s32_ref_output.shape(), DataType::F32);