diff options
 arm_compute/core/Types.h                    | 23 +++++++++++++++++---
 tests/validation/Helpers.cpp                |  4 +++-
 tests/validation/fixtures/GEMMLowpFixture.h | 29 +++++++++++++++++++---
 3 files changed, 50 insertions(+), 6 deletions(-)
diff --git a/arm_compute/core/Types.h b/arm_compute/core/Types.h index 473102a95a..1548816e91 100644 --- a/arm_compute/core/Types.h +++ b/arm_compute/core/Types.h @@ -2089,7 +2089,8 @@ public: _fast_math(false), _fp_mixed_precision(false), _broadcast_bias(false), - _pretranspose_B(true), + _pretranspose_A(false), + _pretranspose_B(false), _activation_info(), _post_ops() { @@ -2124,7 +2125,8 @@ public: _fast_math(fast_math), _fp_mixed_precision(fp_mixed_precision), _broadcast_bias(broadcast_bias), - _pretranspose_B(reshape_b_only_on_first_run), + _pretranspose_A(false), + _pretranspose_B(false), _activation_info(activation_info), _post_ops(post_ops) { @@ -2227,6 +2229,22 @@ public: { return _broadcast_bias; }; + /** Flag which specifies whether A should be pre-transposed if supported. + * + * @return True if A should be pre-transposed else false. + */ + bool pretranspose_A() const + { + return _pretranspose_A; + }; + /** Set pre-transpose A flag + * + * @param[in] flag Flag to set + */ + void set_pretranspose_A(bool flag) + { + _pretranspose_A = flag; + } /** Flag which specifies whether b should be pre-transposed if supported. * * @return True if b should be pre-transposed else false. @@ -2287,6 +2305,7 @@ private: bool _fast_math; bool _fp_mixed_precision; bool _broadcast_bias; + bool _pretranspose_A; bool _pretranspose_B; ActivationLayerInfo _activation_info; experimental::PostOpList<ITensorInfo *> _post_ops; diff --git a/tests/validation/Helpers.cpp b/tests/validation/Helpers.cpp index 237a5a517c..be194dd266 100644 --- a/tests/validation/Helpers.cpp +++ b/tests/validation/Helpers.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2021 Arm Limited. + * Copyright (c) 2017-2022 Arm Limited. 
* * SPDX-License-Identifier: MIT * @@ -385,6 +385,8 @@ template void transpose_matrix(const SimpleTensor<half> &in, SimpleTensor<half> template void transpose_matrix(const SimpleTensor<int> &in, SimpleTensor<int> &out); template void transpose_matrix(const SimpleTensor<short> &in, SimpleTensor<short> &out); template void transpose_matrix(const SimpleTensor<char> &in, SimpleTensor<char> &out); +template void transpose_matrix(const SimpleTensor<int8_t> &in, SimpleTensor<int8_t> &out); +template void transpose_matrix(const SimpleTensor<uint8_t> &in, SimpleTensor<uint8_t> &out); template void matrix_multiply(const SimpleTensor<float> &a, const SimpleTensor<float> &b, SimpleTensor<float> &out); template void matrix_multiply(const SimpleTensor<half> &a, const SimpleTensor<half> &b, SimpleTensor<half> &out); diff --git a/tests/validation/fixtures/GEMMLowpFixture.h b/tests/validation/fixtures/GEMMLowpFixture.h index 5e2154592e..3f83cc92f1 100644 --- a/tests/validation/fixtures/GEMMLowpFixture.h +++ b/tests/validation/fixtures/GEMMLowpFixture.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2021 Arm Limited. + * Copyright (c) 2017-2022 Arm Limited. 
* * SPDX-License-Identifier: MIT * @@ -160,7 +160,7 @@ TensorType compute_gemmlowp_target(const TensorShape &shape_a, const TensorShape return output; } -template <bool reinterpret_input_as_3d, typename TI = uint8_t, typename TW = uint8_t> +template <bool reinterpret_input_as_3d, typename TI = uint8_t, typename TW = uint8_t, bool pretranspose_A = false, bool pretranspose_B = false> SimpleTensor<int32_t> compute_gemmlowp_reference(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &shape_output, int32_t a_offset, int32_t b_offset, DataType data_type_a = DataType::QASYMM8, DataType data_type_b = DataType::QASYMM8, QuantizationInfo b_qinfo = QuantizationInfo()) { @@ -175,10 +175,33 @@ SimpleTensor<int32_t> compute_gemmlowp_reference(const TensorShape &shape_a, con SimpleTensor<TI> a{ shape_a_to_use, data_type_a, 1 }; SimpleTensor<TW> b{ shape_b, data_type_b, 1, data_type_b == DataType::QSYMM8_PER_CHANNEL ? b_qinfo : QuantizationInfo(1.0f / 255, b_offset) }; + TensorShape shape_a_to_use_transposed{ shape_a_to_use }; + TensorShape shape_b_transposed{ shape_b }; + + shape_a_to_use_transposed.set(0, shape_a_to_use[1]); + shape_a_to_use_transposed.set(1, shape_a_to_use[0]); + shape_b_transposed.set(0, shape_b[1]); + shape_b_transposed.set(1, shape_b[0]); + + SimpleTensor<TI> a_transposed{ shape_a_to_use_transposed, data_type_a, 1 }; + SimpleTensor<TW> b_transposed{ shape_b_transposed, data_type_b, 1, data_type_b == DataType::QSYMM8_PER_CHANNEL ? b_qinfo : QuantizationInfo(1.0f / 255, b_offset) }; + // Fill reference fill(a, 0); fill(b, 1); - return reference::gemmlowp_matrix_multiply_core<int32_t, TI, TW>(a, b, shape_output, a_offset, b_offset); + + // Transpose reference if required + if(pretranspose_A) + { + transpose_matrix<TI>(a, a_transposed); + } + + if(pretranspose_B) + { + transpose_matrix<TW>(b, b_transposed); + } + + return reference::gemmlowp_matrix_multiply_core<int32_t, TI, TW>((pretranspose_A ? a_transposed : a), (pretranspose_B ? 
b_transposed : b), shape_output, a_offset, b_offset); } }