diff options
author | Adnan AlSinan <adnan.alsinan@arm.com> | 2022-05-05 11:13:19 +0100 |
---|---|---|
committer | Adnan AlSinan <adnan.alsinan@arm.com> | 2022-05-06 09:47:31 +0000 |
commit | c5849580868b7ad101572f1b90c87f3daa06385d (patch) | |
tree | 6e07bcf11c5265d18010c26cbe0840d6943a9246 /tests/validation/fixtures | |
parent | 638b7e4f6b1125b74f27f90dea2cd23eca52bfe8 (diff) | |
download | ComputeLibrary-c5849580868b7ad101572f1b90c87f3daa06385d.tar.gz |
Extend GemmLowp reference to support BATCH MATMUL for quantized data types
- Extends GEMMInfo class to include flags for transposing A and B.
- Extends GEMMLowp fixtures to have an option for transposing A and B.
Resolves COMPMID-5075
Signed-off-by: Adnan AlSinan <adnan.alsinan@arm.com>
Change-Id: If5e4b7e2b7b19ca30808a78a9641d8ba3f176a26
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/7458
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
Diffstat (limited to 'tests/validation/fixtures')
-rw-r--r-- | tests/validation/fixtures/GEMMLowpFixture.h | 29 |
1 file changed, 26 insertions, 3 deletions
diff --git a/tests/validation/fixtures/GEMMLowpFixture.h b/tests/validation/fixtures/GEMMLowpFixture.h index 5e2154592e..3f83cc92f1 100644 --- a/tests/validation/fixtures/GEMMLowpFixture.h +++ b/tests/validation/fixtures/GEMMLowpFixture.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2021 Arm Limited. + * Copyright (c) 2017-2022 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -160,7 +160,7 @@ TensorType compute_gemmlowp_target(const TensorShape &shape_a, const TensorShape return output; } -template <bool reinterpret_input_as_3d, typename TI = uint8_t, typename TW = uint8_t> +template <bool reinterpret_input_as_3d, typename TI = uint8_t, typename TW = uint8_t, bool pretranspose_A = false, bool pretranspose_B = false> SimpleTensor<int32_t> compute_gemmlowp_reference(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &shape_output, int32_t a_offset, int32_t b_offset, DataType data_type_a = DataType::QASYMM8, DataType data_type_b = DataType::QASYMM8, QuantizationInfo b_qinfo = QuantizationInfo()) { @@ -175,10 +175,33 @@ SimpleTensor<int32_t> compute_gemmlowp_reference(const TensorShape &shape_a, con SimpleTensor<TI> a{ shape_a_to_use, data_type_a, 1 }; SimpleTensor<TW> b{ shape_b, data_type_b, 1, data_type_b == DataType::QSYMM8_PER_CHANNEL ? b_qinfo : QuantizationInfo(1.0f / 255, b_offset) }; + TensorShape shape_a_to_use_transposed{ shape_a_to_use }; + TensorShape shape_b_transposed{ shape_b }; + + shape_a_to_use_transposed.set(0, shape_a_to_use[1]); + shape_a_to_use_transposed.set(1, shape_a_to_use[0]); + shape_b_transposed.set(0, shape_b[1]); + shape_b_transposed.set(1, shape_b[0]); + + SimpleTensor<TI> a_transposed{ shape_a_to_use_transposed, data_type_a, 1 }; + SimpleTensor<TW> b_transposed{ shape_b_transposed, data_type_b, 1, data_type_b == DataType::QSYMM8_PER_CHANNEL ? 
b_qinfo : QuantizationInfo(1.0f / 255, b_offset) }; + // Fill reference fill(a, 0); fill(b, 1); - return reference::gemmlowp_matrix_multiply_core<int32_t, TI, TW>(a, b, shape_output, a_offset, b_offset); + + // Transpose reference if required + if(pretranspose_A) + { + transpose_matrix<TI>(a, a_transposed); + } + + if(pretranspose_B) + { + transpose_matrix<TW>(b, b_transposed); + } + + return reference::gemmlowp_matrix_multiply_core<int32_t, TI, TW>((pretranspose_A ? a_transposed : a), (pretranspose_B ? b_transposed : b), shape_output, a_offset, b_offset); } } |