author     Adnan AlSinan <adnan.alsinan@arm.com>  2022-05-05 11:13:19 +0100
committer  Adnan AlSinan <adnan.alsinan@arm.com>  2022-05-06 09:47:31 +0000
commit     c5849580868b7ad101572f1b90c87f3daa06385d (patch)
tree       6e07bcf11c5265d18010c26cbe0840d6943a9246
parent     638b7e4f6b1125b74f27f90dea2cd23eca52bfe8 (diff)
Extend GemmLowp reference to support BATCH MATMUL for quantized data type
- Extends GEMMInfo class to include flags for transposing A and B.
- Extends GEMMLowp fixtures to have an option for transposing A and B.

Resolves COMPMID-5075

Signed-off-by: Adnan AlSinan <adnan.alsinan@arm.com>
Change-Id: If5e4b7e2b7b19ca30808a78a9641d8ba3f176a26
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/7458
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
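A minimal usage sketch (not part of this patch) of the new GEMMInfo pre-transpose flags added in arm_compute/core/Types.h below; it only exercises the added pretranspose_A getter/setter and the existing pretranspose_B() accessor:

#include "arm_compute/core/Types.h"

int main()
{
    arm_compute::GEMMInfo gemm_info{};

    // Both flags now default to false (see the constructor changes below).
    const bool a_default = gemm_info.pretranspose_A(); // false
    const bool b_default = gemm_info.pretranspose_B(); // false

    // Request that A be pre-transposed if supported by the backend.
    gemm_info.set_pretranspose_A(true);

    return (a_default || b_default || !gemm_info.pretranspose_A()) ? 1 : 0;
}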
-rw-r--r--  arm_compute/core/Types.h                     23
-rw-r--r--  tests/validation/Helpers.cpp                  4
-rw-r--r--  tests/validation/fixtures/GEMMLowpFixture.h  29
3 files changed, 50 insertions, 6 deletions
diff --git a/arm_compute/core/Types.h b/arm_compute/core/Types.h
index 473102a95a..1548816e91 100644
--- a/arm_compute/core/Types.h
+++ b/arm_compute/core/Types.h
@@ -2089,7 +2089,8 @@ public:
_fast_math(false),
_fp_mixed_precision(false),
_broadcast_bias(false),
- _pretranspose_B(true),
+ _pretranspose_A(false),
+ _pretranspose_B(false),
_activation_info(),
_post_ops()
{
@@ -2124,7 +2125,8 @@ public:
_fast_math(fast_math),
_fp_mixed_precision(fp_mixed_precision),
_broadcast_bias(broadcast_bias),
- _pretranspose_B(reshape_b_only_on_first_run),
+ _pretranspose_A(false),
+ _pretranspose_B(false),
_activation_info(activation_info),
_post_ops(post_ops)
{
@@ -2227,6 +2229,22 @@ public:
{
return _broadcast_bias;
};
+ /** Flag which specifies whether A should be pre-transposed if supported.
+ *
+ * @return True if A should be pre-transposed else false.
+ */
+ bool pretranspose_A() const
+ {
+ return _pretranspose_A;
+ };
+ /** Set pre-transpose A flag
+ *
+ * @param[in] flag Flag to set
+ */
+ void set_pretranspose_A(bool flag)
+ {
+ _pretranspose_A = flag;
+ }
/** Flag which specifies whether b should be pre-transposed if supported.
*
* @return True if b should be pre-transposed else false.
@@ -2287,6 +2305,7 @@ private:
bool _fast_math;
bool _fp_mixed_precision;
bool _broadcast_bias;
+ bool _pretranspose_A;
bool _pretranspose_B;
ActivationLayerInfo _activation_info;
experimental::PostOpList<ITensorInfo *> _post_ops;
diff --git a/tests/validation/Helpers.cpp b/tests/validation/Helpers.cpp
index 237a5a517c..be194dd266 100644
--- a/tests/validation/Helpers.cpp
+++ b/tests/validation/Helpers.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2021 Arm Limited.
+ * Copyright (c) 2017-2022 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -385,6 +385,8 @@ template void transpose_matrix(const SimpleTensor<half> &in, SimpleTensor<half>
template void transpose_matrix(const SimpleTensor<int> &in, SimpleTensor<int> &out);
template void transpose_matrix(const SimpleTensor<short> &in, SimpleTensor<short> &out);
template void transpose_matrix(const SimpleTensor<char> &in, SimpleTensor<char> &out);
+template void transpose_matrix(const SimpleTensor<int8_t> &in, SimpleTensor<int8_t> &out);
+template void transpose_matrix(const SimpleTensor<uint8_t> &in, SimpleTensor<uint8_t> &out);
template void matrix_multiply(const SimpleTensor<float> &a, const SimpleTensor<float> &b, SimpleTensor<float> &out);
template void matrix_multiply(const SimpleTensor<half> &a, const SimpleTensor<half> &b, SimpleTensor<half> &out);
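The two explicit instantiations added above are what allow the fixture below to transpose quantized SimpleTensor data. A minimal sketch of such a call, assuming the validation-framework headers (SimpleTensor, TensorShape, transpose_matrix) are in scope:

using namespace arm_compute;
using namespace arm_compute::test;
using namespace arm_compute::test::validation;

SimpleTensor<uint8_t> in{ TensorShape(4U, 3U), DataType::QASYMM8, 1 };
SimpleTensor<uint8_t> out{ TensorShape(3U, 4U), DataType::QASYMM8, 1 };
transpose_matrix<uint8_t>(in, out); // links thanks to the uint8_t instantiation added above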
diff --git a/tests/validation/fixtures/GEMMLowpFixture.h b/tests/validation/fixtures/GEMMLowpFixture.h
index 5e2154592e..3f83cc92f1 100644
--- a/tests/validation/fixtures/GEMMLowpFixture.h
+++ b/tests/validation/fixtures/GEMMLowpFixture.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2021 Arm Limited.
+ * Copyright (c) 2017-2022 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -160,7 +160,7 @@ TensorType compute_gemmlowp_target(const TensorShape &shape_a, const TensorShape
return output;
}
-template <bool reinterpret_input_as_3d, typename TI = uint8_t, typename TW = uint8_t>
+template <bool reinterpret_input_as_3d, typename TI = uint8_t, typename TW = uint8_t, bool pretranspose_A = false, bool pretranspose_B = false>
SimpleTensor<int32_t> compute_gemmlowp_reference(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &shape_output, int32_t a_offset, int32_t b_offset,
DataType data_type_a = DataType::QASYMM8, DataType data_type_b = DataType::QASYMM8, QuantizationInfo b_qinfo = QuantizationInfo())
{
@@ -175,10 +175,33 @@ SimpleTensor<int32_t> compute_gemmlowp_reference(const TensorShape &shape_a, con
SimpleTensor<TI> a{ shape_a_to_use, data_type_a, 1 };
SimpleTensor<TW> b{ shape_b, data_type_b, 1, data_type_b == DataType::QSYMM8_PER_CHANNEL ? b_qinfo : QuantizationInfo(1.0f / 255, b_offset) };
+ TensorShape shape_a_to_use_transposed{ shape_a_to_use };
+ TensorShape shape_b_transposed{ shape_b };
+
+ shape_a_to_use_transposed.set(0, shape_a_to_use[1]);
+ shape_a_to_use_transposed.set(1, shape_a_to_use[0]);
+ shape_b_transposed.set(0, shape_b[1]);
+ shape_b_transposed.set(1, shape_b[0]);
+
+ SimpleTensor<TI> a_transposed{ shape_a_to_use_transposed, data_type_a, 1 };
+ SimpleTensor<TW> b_transposed{ shape_b_transposed, data_type_b, 1, data_type_b == DataType::QSYMM8_PER_CHANNEL ? b_qinfo : QuantizationInfo(1.0f / 255, b_offset) };
+
// Fill reference
fill(a, 0);
fill(b, 1);
- return reference::gemmlowp_matrix_multiply_core<int32_t, TI, TW>(a, b, shape_output, a_offset, b_offset);
+
+ // Transpose reference if required
+ if(pretranspose_A)
+ {
+ transpose_matrix<TI>(a, a_transposed);
+ }
+
+ if(pretranspose_B)
+ {
+ transpose_matrix<TW>(b, b_transposed);
+ }
+
+ return reference::gemmlowp_matrix_multiply_core<int32_t, TI, TW>((pretranspose_A ? a_transposed : a), (pretranspose_B ? b_transposed : b), shape_output, a_offset, b_offset);
}
}
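The fixture call sites are not part of this diff; the following is a hypothetical sketch of how the new template parameters might be selected to obtain a reference result with A pre-transposed, assuming GEMMLowpFixture.h and the validation framework are in scope (the int8_t instantiation of transpose_matrix added in Helpers.cpp covers this signed 8-bit case):

// Hypothetical helper, not in the patch.
// Template arguments: <reinterpret_input_as_3d, TI, TW, pretranspose_A, pretranspose_B>
SimpleTensor<int32_t> reference_with_transposed_a(const TensorShape &shape_a,
                                                  const TensorShape &shape_b,
                                                  const TensorShape &shape_output,
                                                  int32_t a_offset, int32_t b_offset)
{
    return compute_gemmlowp_reference<false, int8_t, int8_t, true, false>(
        shape_a, shape_b, shape_output, a_offset, b_offset,
        DataType::QASYMM8_SIGNED, DataType::QASYMM8_SIGNED);
}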