aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/fixtures/GEMMLowpFixture.h
diff options
context:
space:
mode:
Diffstat (limited to 'tests/validation/fixtures/GEMMLowpFixture.h')
-rw-r--r--tests/validation/fixtures/GEMMLowpFixture.h29
1 files changed, 26 insertions, 3 deletions
diff --git a/tests/validation/fixtures/GEMMLowpFixture.h b/tests/validation/fixtures/GEMMLowpFixture.h
index 5e2154592e..3f83cc92f1 100644
--- a/tests/validation/fixtures/GEMMLowpFixture.h
+++ b/tests/validation/fixtures/GEMMLowpFixture.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2021 Arm Limited.
+ * Copyright (c) 2017-2022 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -160,7 +160,7 @@ TensorType compute_gemmlowp_target(const TensorShape &shape_a, const TensorShape
return output;
}
-template <bool reinterpret_input_as_3d, typename TI = uint8_t, typename TW = uint8_t>
+template <bool reinterpret_input_as_3d, typename TI = uint8_t, typename TW = uint8_t, bool pretranspose_A = false, bool pretranspose_B = false>
SimpleTensor<int32_t> compute_gemmlowp_reference(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &shape_output, int32_t a_offset, int32_t b_offset,
DataType data_type_a = DataType::QASYMM8, DataType data_type_b = DataType::QASYMM8, QuantizationInfo b_qinfo = QuantizationInfo())
{
@@ -175,10 +175,33 @@ SimpleTensor<int32_t> compute_gemmlowp_reference(const TensorShape &shape_a, con
SimpleTensor<TI> a{ shape_a_to_use, data_type_a, 1 };
SimpleTensor<TW> b{ shape_b, data_type_b, 1, data_type_b == DataType::QSYMM8_PER_CHANNEL ? b_qinfo : QuantizationInfo(1.0f / 255, b_offset) };
+ TensorShape shape_a_to_use_transposed{ shape_a_to_use };
+ TensorShape shape_b_transposed{ shape_b };
+
+ shape_a_to_use_transposed.set(0, shape_a_to_use[1]);
+ shape_a_to_use_transposed.set(1, shape_a_to_use[0]);
+ shape_b_transposed.set(0, shape_b[1]);
+ shape_b_transposed.set(1, shape_b[0]);
+
+ SimpleTensor<TI> a_transposed{ shape_a_to_use_transposed, data_type_a, 1 };
+ SimpleTensor<TW> b_transposed{ shape_b_transposed, data_type_b, 1, data_type_b == DataType::QSYMM8_PER_CHANNEL ? b_qinfo : QuantizationInfo(1.0f / 255, b_offset) };
+
// Fill reference
fill(a, 0);
fill(b, 1);
- return reference::gemmlowp_matrix_multiply_core<int32_t, TI, TW>(a, b, shape_output, a_offset, b_offset);
+
+ // Transpose reference if required
+ if(pretranspose_A)
+ {
+ transpose_matrix<TI>(a, a_transposed);
+ }
+
+ if(pretranspose_B)
+ {
+ transpose_matrix<TW>(b, b_transposed);
+ }
+
+ return reference::gemmlowp_matrix_multiply_core<int32_t, TI, TW>((pretranspose_A ? a_transposed : a), (pretranspose_B ? b_transposed : b), shape_output, a_offset, b_offset);
}
}