aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/fixtures/GEMMLowpFixture.h
diff options
context:
space:
mode:
authorRamy Elgammal <ramy.elgammal@arm.com>2022-09-08 11:30:08 +0100
committerGunes Bayir <gunes.bayir@arm.com>2022-09-21 08:28:49 +0000
commita77c6d774053672b7bf0261e1a7a229bb6be5f21 (patch)
tree853427c98338094d25dc8c468c6fbb8d05797d04 /tests/validation/fixtures/GEMMLowpFixture.h
parent047e5d058804e0107c77ee4e4384db55e7dd6bcf (diff)
downloadComputeLibrary-a77c6d774053672b7bf0261e1a7a229bb6be5f21.tar.gz
Add test for ClGemmLowpMatrixMultiplyCore to test a batched matrix multiplication with variable input tensors
Resolves: COMPMID-5506 Signed-off-by: Ramy Elgammal <ramy.elgammal@arm.com> Change-Id: I8345a3b7a83ef46f9ec7a77197cc65c933ec9ac6 Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/8239 Comments-Addressed: Arm Jenkins <bsgcomp@arm.com> Tested-by: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Gian Marco Iodice <gianmarco.iodice@arm.com> Benchmark: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'tests/validation/fixtures/GEMMLowpFixture.h')
-rw-r--r--tests/validation/fixtures/GEMMLowpFixture.h37
1 file changed, 29 insertions, 8 deletions
diff --git a/tests/validation/fixtures/GEMMLowpFixture.h b/tests/validation/fixtures/GEMMLowpFixture.h
index 6d073cd361..f1ec81aae6 100644
--- a/tests/validation/fixtures/GEMMLowpFixture.h
+++ b/tests/validation/fixtures/GEMMLowpFixture.h
@@ -26,8 +26,8 @@
#include "arm_compute/core/utils/quantization/AsymmHelpers.h"
#include "tests/framework/Fixture.h"
-#include "tests/validation/reference/GEMMLowp.h"
#include "tests/validation/Validation.h"
+#include "tests/validation/reference/GEMMLowp.h"
namespace arm_compute
{
@@ -85,7 +85,7 @@ void fill(U &&tensor, int i)
}
}
-template <typename TensorType, typename AccessorType, typename FunctionType, bool reinterpret_input_as_3d, bool reinterpret_output_as_3d, typename OutputType, bool is_fused = false>
+template <typename TensorType, typename AccessorType, typename FunctionType, bool reinterpret_input_as_3d, bool reinterpret_output_as_3d, typename OutputType, bool is_fused = false, bool run_twice = false>
TensorType compute_gemmlowp_target(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &shape_output, int32_t a_offset, int32_t b_offset,
GEMMLowpOutputStageInfo output_stage = GEMMLowpOutputStageInfo(), DataType data_type_a = DataType::QASYMM8, DataType data_type_b = DataType::QASYMM8,
QuantizationInfo b_qinfo = QuantizationInfo(), bool reshape_b_only_on_first_run = false)
@@ -146,12 +146,25 @@ TensorType compute_gemmlowp_target(const TensorShape &shape_a, const TensorShape
ARM_COMPUTE_ASSERT(!bias.info()->is_resizable());
fill(AccessorType(bias), 2);
}
+
+ // Run with variable inputs.
+ if(run_twice)
+ {
+ gemmlowp.run();
+ fill(AccessorType(a), 3); // Fill tensors with new seed after run
+ fill(AccessorType(b), 4);
+ if(is_fused)
+ {
+ fill(AccessorType(bias), 5);
+ }
+ }
+
// Compute GEMM function
gemmlowp.run();
return output;
}
-template <bool reinterpret_input_as_3d, typename TI = uint8_t, typename TW = uint8_t, bool pretranspose_A = false, bool pretranspose_B = false>
+template <bool reinterpret_input_as_3d, typename TI = uint8_t, typename TW = uint8_t, bool pretranspose_A = false, bool pretranspose_B = false, bool run_twice = false>
SimpleTensor<int32_t> compute_gemmlowp_reference(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &shape_output, int32_t a_offset, int32_t b_offset,
DataType data_type_a = DataType::QASYMM8, DataType data_type_b = DataType::QASYMM8, QuantizationInfo b_qinfo = QuantizationInfo())
{
@@ -196,11 +209,19 @@ SimpleTensor<int32_t> compute_gemmlowp_reference(const TensorShape &shape_a, con
transpose_matrix<TW>(b, b_transposed);
}
+ // Run with variable inputs.
+ if(run_twice)
+ {
+ reference::gemmlowp_matrix_multiply_core<int32_t, TI, TW>((pretranspose_A ? a_transposed : a), (pretranspose_B ? b_transposed : b), shape_output, a_offset, b_offset);
+ fill((pretranspose_A) ? a_transposed : a, 3);
+ fill((pretranspose_B) ? b_transposed : b, 4);
+ }
+
return reference::gemmlowp_matrix_multiply_core<int32_t, TI, TW>((pretranspose_A ? a_transposed : a), (pretranspose_B ? b_transposed : b), shape_output, a_offset, b_offset);
}
}
-template <typename TensorType, typename AccessorType, typename FunctionType, bool reinterpret_input_as_3d = false, bool reinterpret_output_as_3d = false>
+template <typename TensorType, typename AccessorType, typename FunctionType, bool reinterpret_input_as_3d = false, bool reinterpret_output_as_3d = false, bool run_twice = false>
class GEMMLowpMatrixMultiplyCoreValidationFixture : public framework::Fixture
{
public:
@@ -214,12 +235,12 @@ public:
protected:
TensorType compute_target(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &shape_output, int32_t a_offset, int32_t b_offset)
{
- return compute_gemmlowp_target<TensorType, AccessorType, FunctionType, reinterpret_input_as_3d, reinterpret_output_as_3d, int32_t>(shape_a, shape_b, shape_output, a_offset, b_offset);
+ return compute_gemmlowp_target<TensorType, AccessorType, FunctionType, reinterpret_input_as_3d, reinterpret_output_as_3d, int32_t, false, run_twice>(shape_a, shape_b, shape_output, a_offset, b_offset);
}
SimpleTensor<int32_t> compute_reference(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &shape_output, int32_t a_offset, int32_t b_offset)
{
- return compute_gemmlowp_reference<reinterpret_input_as_3d>(shape_a, shape_b, shape_output, a_offset, b_offset);
+ return compute_gemmlowp_reference<reinterpret_input_as_3d, uint8_t, uint8_t, false, false, run_twice>(shape_a, shape_b, shape_output, a_offset, b_offset);
}
TensorType _target{};
@@ -1395,7 +1416,7 @@ public:
broadcast_bias ? 1 : m,
broadcast_bias ? 1 : batch_size);
- _target = compute_target(lhs_shape, rhs_shape, bias_shape, lhs_info, rhs_info, data_type, output_stage, a_offset, b_offset);
+ _target = compute_target(lhs_shape, rhs_shape, bias_shape, lhs_info, rhs_info, data_type, output_stage, a_offset, b_offset);
if(gemm_validated == true)
{
_reference = compute_reference(lhs_shape, rhs_shape, bias_shape, data_type, output_stage, a_offset, b_offset);
@@ -1584,7 +1605,7 @@ public:
const TensorShape lhs_shape(k, m, batch_size);
const TensorShape rhs_shape(n, k, batch_size);
- _target = compute_target(lhs_shape, rhs_shape, lhs_info, rhs_info, data_type);
+ _target = compute_target(lhs_shape, rhs_shape, lhs_info, rhs_info, data_type);
if(gemm_validated == true)
{
_reference = compute_reference(lhs_shape, rhs_shape, data_type);