author    Ramy Elgammal <ramy.elgammal@arm.com>   2022-09-08 11:30:08 +0100
committer Gunes Bayir <gunes.bayir@arm.com>       2022-09-21 08:28:49 +0000
commit    a77c6d774053672b7bf0261e1a7a229bb6be5f21 (patch)
tree      853427c98338094d25dc8c468c6fbb8d05797d04
parent    047e5d058804e0107c77ee4e4384db55e7dd6bcf (diff)
download  ComputeLibrary-a77c6d774053672b7bf0261e1a7a229bb6be5f21.tar.gz
Add test for ClGemmLowpMatrixMultiplyCore to validate batched matrix multiplication with variable input tensors

Resolves: COMPMID-5506
Signed-off-by: Ramy Elgammal <ramy.elgammal@arm.com>
Change-Id: I8345a3b7a83ef46f9ec7a77197cc65c933ec9ac6
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/8239
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
Benchmark: Arm Jenkins <bsgcomp@arm.com>
-rw-r--r--   tests/datasets/SmallGEMMLowpDataset.h          16
-rw-r--r--   tests/validation/CL/GEMMLowp.cpp               11
-rw-r--r--   tests/validation/fixtures/GEMMLowpFixture.h    37
3 files changed, 54 insertions, 10 deletions
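
The core of the change is a new run_twice path in the GEMMLowp fixture: the operator is configured and allocated once, run, the input tensors are refilled with new seeds (3, 4, and 5 for the bias), and run() is invoked a second time; only the second result is validated against a reference computed on the refilled inputs. Below is a minimal standalone C++ sketch of that pattern. It uses a toy NaiveGemmLowp stand-in rather than the library API, and the naive int32 accumulation (adding a_offset/b_offset before the multiply) is only an assumed illustration of what the reference computes.

#include <cassert>
#include <cstdint>
#include <random>
#include <vector>

// Toy stand-in for a configured GEMMLowp operator: configure() captures the
// tensors once; run() recomputes from whatever data the tensors currently
// hold, so re-running after a refill mirrors the fixture's run_twice path.
struct NaiveGemmLowp
{
    void configure(const std::vector<uint8_t> *a, const std::vector<uint8_t> *b, std::vector<int32_t> *c,
                   int m, int n, int k, int32_t a_offset, int32_t b_offset)
    {
        _a = a; _b = b; _c = c;
        _m = m; _n = n; _k = k;
        _a_offset = a_offset;
        _b_offset = b_offset;
    }
    void run() const
    {
        for(int i = 0; i < _m; ++i)
        {
            for(int j = 0; j < _n; ++j)
            {
                int32_t acc = 0;
                for(int p = 0; p < _k; ++p)
                {
                    acc += (static_cast<int32_t>((*_a)[i * _k + p]) + _a_offset) * (static_cast<int32_t>((*_b)[p * _n + j]) + _b_offset);
                }
                (*_c)[i * _n + j] = acc;
            }
        }
    }
    const std::vector<uint8_t> *_a{}, *_b{};
    std::vector<int32_t>       *_c{};
    int                         _m{}, _n{}, _k{};
    int32_t                     _a_offset{}, _b_offset{};
};

// Deterministic refill, standing in for fill(AccessorType(tensor), seed).
void fill_u8(std::vector<uint8_t> &t, int seed)
{
    std::mt19937                       rng(seed);
    std::uniform_int_distribution<int> dist(0, 255);
    for(auto &v : t)
    {
        v = static_cast<uint8_t>(dist(rng));
    }
}

int main()
{
    const int     m = 4, n = 5, k = 3;
    const int32_t a_offset = -2, b_offset = 13;

    std::vector<uint8_t> a(m * k), b(k * n);
    std::vector<int32_t> out(m * n), ref(m * n);

    NaiveGemmLowp gemmlowp;
    gemmlowp.configure(&a, &b, &out, m, n, k, a_offset, b_offset); // configured exactly once

    fill_u8(a, 0);
    fill_u8(b, 1);
    gemmlowp.run(); // first run; its result is intentionally discarded

    fill_u8(a, 3); // refill with new seeds, as the run_twice block does
    fill_u8(b, 4);
    gemmlowp.run(); // second run on the same configured operator

    NaiveGemmLowp reference;
    reference.configure(&a, &b, &ref, m, n, k, a_offset, b_offset);
    reference.run();
    assert(out == ref); // only the second run is validated

    return 0;
}
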
diff --git a/tests/datasets/SmallGEMMLowpDataset.h b/tests/datasets/SmallGEMMLowpDataset.h
index b8529ac165..929940d2d9 100644
--- a/tests/datasets/SmallGEMMLowpDataset.h
+++ b/tests/datasets/SmallGEMMLowpDataset.h
@@ -76,6 +76,22 @@ public:
add_config(TensorShape(16U, 16U, 5U, 3U), TensorShape(8U, 16U), TensorShape(8U, 16U, 5U, 3U), -9, 1);
}
};
+
+class SmallGEMMLowpBatchedMatMulDataset final : public GEMMLowpDataset
+{
+public:
+ SmallGEMMLowpBatchedMatMulDataset()
+ {
+ add_config(TensorShape(4U, 3U), TensorShape(2U, 4U), TensorShape(2U, 3U), 0, 0);
+ add_config(TensorShape(12U, 15U), TensorShape(7U, 12U), TensorShape(7U, 15U), 0, 0);
+ add_config(TensorShape(59U, 17U), TensorShape(36U, 59U), TensorShape(36U, 17U), -2, 13);
+ add_config(TensorShape(2U, 4U, 3U), TensorShape(5U, 2U, 3U), TensorShape(5U, 4U, 3U), -2, 0);
+ add_config(TensorShape(15U, 7U, 36U), TensorShape(29U, 15U, 36U), TensorShape(29U, 7U, 36U), -9, 1);
+ add_config(TensorShape(56U, 17U, 32U), TensorShape(5U, 56U, 32U), TensorShape(5U, 17U, 32U), -3, 2);
+ add_config(TensorShape(13U, 256U, 32U), TensorShape(19U, 13U, 32U), TensorShape(19U, 256U, 32U), 5, 13);
+ }
+};
+
} // namespace datasets
} // namespace test
} // namespace arm_compute
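
For readers of the new dataset, a hedged reading of one of the configs above, inferred from how the shared K dimension lines up across the three shapes (TensorShape lists dimensions innermost-first, i.e. width, height, batches), with add_config taking the A, B and output shapes followed by the a_offset and b_offset quantization offsets:

// Taking the fourth config above as an example:
add_config(TensorShape(2U, 4U, 3U),  // A      : K = 2, M = 4, batches = 3
           TensorShape(5U, 2U, 3U),  // B      : N = 5, K = 2, batches = 3
           TensorShape(5U, 4U, 3U),  // output : N = 5, M = 4, batches = 3
           -2, 0);                   // a_offset, b_offset
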
diff --git a/tests/validation/CL/GEMMLowp.cpp b/tests/validation/CL/GEMMLowp.cpp
index 52adb94c83..19e8eeb0f5 100644
--- a/tests/validation/CL/GEMMLowp.cpp
+++ b/tests/validation/CL/GEMMLowp.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2021 Arm Limited.
+ * Copyright (c) 2017-2022 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -53,6 +53,7 @@ TEST_SUITE(GEMMLowp)
TEST_SUITE(MatrixMultiplyCore)
using CLGEMMLowpMatrixMultiplyCoreFixture = GEMMLowpMatrixMultiplyCoreValidationFixture<CLTensor, CLAccessor, CLGEMMLowpMatrixMultiplyCore>;
+using CLGEMMLowpBatchedMatMulFixture = GEMMLowpMatrixMultiplyCoreValidationFixture<CLTensor, CLAccessor, CLGEMMLowpMatrixMultiplyCore, false, false, true>;
FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMLowpMatrixMultiplyCoreFixture, framework::DatasetMode::ALL, datasets::SmallGEMMLowpDataset())
{
@@ -65,7 +66,13 @@ FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMLowpMatrixMultiplyCoreFixture, framework:
// Validate output
validate(CLAccessor(_target), _reference);
}
-
+TEST_SUITE(BATCHED_MATMUL)
+FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMLowpBatchedMatMulFixture, framework::DatasetMode::ALL, datasets::SmallGEMMLowpBatchedMatMulDataset())
+{
+ // Validate output
+ validate(CLAccessor(_target), _reference);
+}
+TEST_SUITE_END() // BATCHED_MATMUL
TEST_SUITE(FusedOffsetOutput)
TEST_SUITE(QASYMM8)
using CLGEMMLowpMatrixMultiplyCoreFusedOffsetOutputUint8Fixture = GEMMLowpMatrixMultiplyCoreFusedOffsetOutputGenericValidationFixture<CLTensor, CLAccessor, CLGEMMLowpMatrixMultiplyCore>;
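
For clarity, the trailing template arguments of the new CLGEMMLowpBatchedMatMulFixture alias correspond to the fixture parameters extended in GEMMLowpFixture.h below; an annotated restatement (parameter names taken from that template):

using CLGEMMLowpBatchedMatMulFixture =
    GEMMLowpMatrixMultiplyCoreValidationFixture<CLTensor, CLAccessor, CLGEMMLowpMatrixMultiplyCore,
                                                false,  // reinterpret_input_as_3d
                                                false,  // reinterpret_output_as_3d
                                                true>;  // run_twice: exercise the variable-input re-run
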
diff --git a/tests/validation/fixtures/GEMMLowpFixture.h b/tests/validation/fixtures/GEMMLowpFixture.h
index 6d073cd361..f1ec81aae6 100644
--- a/tests/validation/fixtures/GEMMLowpFixture.h
+++ b/tests/validation/fixtures/GEMMLowpFixture.h
@@ -26,8 +26,8 @@
#include "arm_compute/core/utils/quantization/AsymmHelpers.h"
#include "tests/framework/Fixture.h"
-#include "tests/validation/reference/GEMMLowp.h"
#include "tests/validation/Validation.h"
+#include "tests/validation/reference/GEMMLowp.h"
namespace arm_compute
{
@@ -85,7 +85,7 @@ void fill(U &&tensor, int i)
}
}
-template <typename TensorType, typename AccessorType, typename FunctionType, bool reinterpret_input_as_3d, bool reinterpret_output_as_3d, typename OutputType, bool is_fused = false>
+template <typename TensorType, typename AccessorType, typename FunctionType, bool reinterpret_input_as_3d, bool reinterpret_output_as_3d, typename OutputType, bool is_fused = false, bool run_twice = false>
TensorType compute_gemmlowp_target(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &shape_output, int32_t a_offset, int32_t b_offset,
GEMMLowpOutputStageInfo output_stage = GEMMLowpOutputStageInfo(), DataType data_type_a = DataType::QASYMM8, DataType data_type_b = DataType::QASYMM8,
QuantizationInfo b_qinfo = QuantizationInfo(), bool reshape_b_only_on_first_run = false)
@@ -146,12 +146,25 @@ TensorType compute_gemmlowp_target(const TensorShape &shape_a, const TensorShape
ARM_COMPUTE_ASSERT(!bias.info()->is_resizable());
fill(AccessorType(bias), 2);
}
+
+ // Run with variable inputs.
+ if(run_twice)
+ {
+ gemmlowp.run();
+ fill(AccessorType(a), 3); // Fill tensors with new seed after run
+ fill(AccessorType(b), 4);
+ if(is_fused)
+ {
+ fill(AccessorType(bias), 5);
+ }
+ }
+
// Compute GEMM function
gemmlowp.run();
return output;
}
-template <bool reinterpret_input_as_3d, typename TI = uint8_t, typename TW = uint8_t, bool pretranspose_A = false, bool pretranspose_B = false>
+template <bool reinterpret_input_as_3d, typename TI = uint8_t, typename TW = uint8_t, bool pretranspose_A = false, bool pretranspose_B = false, bool run_twice = false>
SimpleTensor<int32_t> compute_gemmlowp_reference(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &shape_output, int32_t a_offset, int32_t b_offset,
DataType data_type_a = DataType::QASYMM8, DataType data_type_b = DataType::QASYMM8, QuantizationInfo b_qinfo = QuantizationInfo())
{
@@ -196,11 +209,19 @@ SimpleTensor<int32_t> compute_gemmlowp_reference(const TensorShape &shape_a, con
transpose_matrix<TW>(b, b_transposed);
}
+ // Run with variable inputs.
+ if(run_twice)
+ {
+ reference::gemmlowp_matrix_multiply_core<int32_t, TI, TW>((pretranspose_A ? a_transposed : a), (pretranspose_B ? b_transposed : b), shape_output, a_offset, b_offset);
+ fill((pretranspose_A) ? a_transposed : a, 3);
+ fill((pretranspose_B) ? b_transposed : b, 4);
+ }
+
return reference::gemmlowp_matrix_multiply_core<int32_t, TI, TW>((pretranspose_A ? a_transposed : a), (pretranspose_B ? b_transposed : b), shape_output, a_offset, b_offset);
}
}
-template <typename TensorType, typename AccessorType, typename FunctionType, bool reinterpret_input_as_3d = false, bool reinterpret_output_as_3d = false>
+template <typename TensorType, typename AccessorType, typename FunctionType, bool reinterpret_input_as_3d = false, bool reinterpret_output_as_3d = false, bool run_twice = false>
class GEMMLowpMatrixMultiplyCoreValidationFixture : public framework::Fixture
{
public:
@@ -214,12 +235,12 @@ public:
protected:
TensorType compute_target(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &shape_output, int32_t a_offset, int32_t b_offset)
{
- return compute_gemmlowp_target<TensorType, AccessorType, FunctionType, reinterpret_input_as_3d, reinterpret_output_as_3d, int32_t>(shape_a, shape_b, shape_output, a_offset, b_offset);
+ return compute_gemmlowp_target<TensorType, AccessorType, FunctionType, reinterpret_input_as_3d, reinterpret_output_as_3d, int32_t, false, run_twice>(shape_a, shape_b, shape_output, a_offset, b_offset);
}
SimpleTensor<int32_t> compute_reference(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &shape_output, int32_t a_offset, int32_t b_offset)
{
- return compute_gemmlowp_reference<reinterpret_input_as_3d>(shape_a, shape_b, shape_output, a_offset, b_offset);
+ return compute_gemmlowp_reference<reinterpret_input_as_3d, uint8_t, uint8_t, false, false, run_twice>(shape_a, shape_b, shape_output, a_offset, b_offset);
}
TensorType _target{};
@@ -1395,7 +1416,7 @@ public:
broadcast_bias ? 1 : m,
broadcast_bias ? 1 : batch_size);
- _target = compute_target(lhs_shape, rhs_shape, bias_shape, lhs_info, rhs_info, data_type, output_stage, a_offset, b_offset);
+ _target = compute_target(lhs_shape, rhs_shape, bias_shape, lhs_info, rhs_info, data_type, output_stage, a_offset, b_offset);
if(gemm_validated == true)
{
_reference = compute_reference(lhs_shape, rhs_shape, bias_shape, data_type, output_stage, a_offset, b_offset);
@@ -1584,7 +1605,7 @@ public:
const TensorShape lhs_shape(k, m, batch_size);
const TensorShape rhs_shape(n, k, batch_size);
- _target = compute_target(lhs_shape, rhs_shape, lhs_info, rhs_info, data_type);
+ _target = compute_target(lhs_shape, rhs_shape, lhs_info, rhs_info, data_type);
if(gemm_validated == true)
{
_reference = compute_reference(lhs_shape, rhs_shape, data_type);