diff options
Diffstat (limited to 'tests/validation/fixtures')
-rw-r--r-- | tests/validation/fixtures/GEMMLowpFixture.h | 37 |
1 file changed, 29 insertions, 8 deletions
diff --git a/tests/validation/fixtures/GEMMLowpFixture.h b/tests/validation/fixtures/GEMMLowpFixture.h index 6d073cd361..f1ec81aae6 100644 --- a/tests/validation/fixtures/GEMMLowpFixture.h +++ b/tests/validation/fixtures/GEMMLowpFixture.h @@ -26,8 +26,8 @@ #include "arm_compute/core/utils/quantization/AsymmHelpers.h" #include "tests/framework/Fixture.h" -#include "tests/validation/reference/GEMMLowp.h" #include "tests/validation/Validation.h" +#include "tests/validation/reference/GEMMLowp.h" namespace arm_compute { @@ -85,7 +85,7 @@ void fill(U &&tensor, int i) } } -template <typename TensorType, typename AccessorType, typename FunctionType, bool reinterpret_input_as_3d, bool reinterpret_output_as_3d, typename OutputType, bool is_fused = false> +template <typename TensorType, typename AccessorType, typename FunctionType, bool reinterpret_input_as_3d, bool reinterpret_output_as_3d, typename OutputType, bool is_fused = false, bool run_twice = false> TensorType compute_gemmlowp_target(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &shape_output, int32_t a_offset, int32_t b_offset, GEMMLowpOutputStageInfo output_stage = GEMMLowpOutputStageInfo(), DataType data_type_a = DataType::QASYMM8, DataType data_type_b = DataType::QASYMM8, QuantizationInfo b_qinfo = QuantizationInfo(), bool reshape_b_only_on_first_run = false) @@ -146,12 +146,25 @@ TensorType compute_gemmlowp_target(const TensorShape &shape_a, const TensorShape ARM_COMPUTE_ASSERT(!bias.info()->is_resizable()); fill(AccessorType(bias), 2); } + + // Run with variable inputs. 
+ if(run_twice) + { + gemmlowp.run(); + fill(AccessorType(a), 3); // Fill tensors with new seed after run + fill(AccessorType(b), 4); + if(is_fused) + { + fill(AccessorType(bias), 5); + } + } + // Compute GEMM function gemmlowp.run(); return output; } -template <bool reinterpret_input_as_3d, typename TI = uint8_t, typename TW = uint8_t, bool pretranspose_A = false, bool pretranspose_B = false> +template <bool reinterpret_input_as_3d, typename TI = uint8_t, typename TW = uint8_t, bool pretranspose_A = false, bool pretranspose_B = false, bool run_twice = false> SimpleTensor<int32_t> compute_gemmlowp_reference(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &shape_output, int32_t a_offset, int32_t b_offset, DataType data_type_a = DataType::QASYMM8, DataType data_type_b = DataType::QASYMM8, QuantizationInfo b_qinfo = QuantizationInfo()) { @@ -196,11 +209,19 @@ SimpleTensor<int32_t> compute_gemmlowp_reference(const TensorShape &shape_a, con transpose_matrix<TW>(b, b_transposed); } + // Run with variable inputs. + if(run_twice) + { + reference::gemmlowp_matrix_multiply_core<int32_t, TI, TW>((pretranspose_A ? a_transposed : a), (pretranspose_B ? b_transposed : b), shape_output, a_offset, b_offset); + fill((pretranspose_A) ? a_transposed : a, 3); + fill((pretranspose_B) ? b_transposed : b, 4); + } + return reference::gemmlowp_matrix_multiply_core<int32_t, TI, TW>((pretranspose_A ? a_transposed : a), (pretranspose_B ? 
b_transposed : b), shape_output, a_offset, b_offset); } } -template <typename TensorType, typename AccessorType, typename FunctionType, bool reinterpret_input_as_3d = false, bool reinterpret_output_as_3d = false> +template <typename TensorType, typename AccessorType, typename FunctionType, bool reinterpret_input_as_3d = false, bool reinterpret_output_as_3d = false, bool run_twice = false> class GEMMLowpMatrixMultiplyCoreValidationFixture : public framework::Fixture { public: @@ -214,12 +235,12 @@ public: protected: TensorType compute_target(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &shape_output, int32_t a_offset, int32_t b_offset) { - return compute_gemmlowp_target<TensorType, AccessorType, FunctionType, reinterpret_input_as_3d, reinterpret_output_as_3d, int32_t>(shape_a, shape_b, shape_output, a_offset, b_offset); + return compute_gemmlowp_target<TensorType, AccessorType, FunctionType, reinterpret_input_as_3d, reinterpret_output_as_3d, int32_t, false, run_twice>(shape_a, shape_b, shape_output, a_offset, b_offset); } SimpleTensor<int32_t> compute_reference(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &shape_output, int32_t a_offset, int32_t b_offset) { - return compute_gemmlowp_reference<reinterpret_input_as_3d>(shape_a, shape_b, shape_output, a_offset, b_offset); + return compute_gemmlowp_reference<reinterpret_input_as_3d, uint8_t, uint8_t, false, false, run_twice>(shape_a, shape_b, shape_output, a_offset, b_offset); } TensorType _target{}; @@ -1395,7 +1416,7 @@ public: broadcast_bias ? 1 : m, broadcast_bias ? 
1 : batch_size); - _target = compute_target(lhs_shape, rhs_shape, bias_shape, lhs_info, rhs_info, data_type, output_stage, a_offset, b_offset); + _target = compute_target(lhs_shape, rhs_shape, bias_shape, lhs_info, rhs_info, data_type, output_stage, a_offset, b_offset); if(gemm_validated == true) { _reference = compute_reference(lhs_shape, rhs_shape, bias_shape, data_type, output_stage, a_offset, b_offset); @@ -1584,7 +1605,7 @@ public: const TensorShape lhs_shape(k, m, batch_size); const TensorShape rhs_shape(n, k, batch_size); - _target = compute_target(lhs_shape, rhs_shape, lhs_info, rhs_info, data_type); + _target = compute_target(lhs_shape, rhs_shape, lhs_info, rhs_info, data_type); if(gemm_validated == true) { _reference = compute_reference(lhs_shape, rhs_shape, data_type); |