diff options
author | Giorgio Arena <giorgio.arena@arm.com> | 2021-06-09 15:23:06 +0100 |
---|---|---|
committer | Giorgio Arena <giorgio.arena@arm.com> | 2021-06-24 14:16:56 +0000 |
commit | 5f6fdc1d5f2d2b5ae843ddfafaffde7787516843 (patch) | |
tree | a5539d09e19e5afda3f820c029b71fc718609d12 /tests/validation/fixtures | |
parent | 561c176598cd14245e2e7918fdf136d1c888d1da (diff) | |
download | ComputeLibrary-5f6fdc1d5f2d2b5ae843ddfafaffde7787516843.tar.gz |
Rework gemmlowp reshaped_only_rhs using the new macros
Resolve COMPMID-4416
Change-Id: I83cdf0de7adaf4d465ffebd494ab913182072485
Signed-off-by: Giorgio Arena <giorgio.arena@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/5788
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'tests/validation/fixtures')
-rw-r--r-- | tests/validation/fixtures/GEMMLowpFixture.h | 31 |
1 files changed, 23 insertions, 8 deletions
diff --git a/tests/validation/fixtures/GEMMLowpFixture.h b/tests/validation/fixtures/GEMMLowpFixture.h index 5cf210bab4..ab9d35de0f 100644 --- a/tests/validation/fixtures/GEMMLowpFixture.h +++ b/tests/validation/fixtures/GEMMLowpFixture.h @@ -97,7 +97,7 @@ void fill(U &&tensor, int i) template <typename TensorType, typename AccessorType, typename FunctionType, bool reinterpret_input_as_3d, bool reinterpret_output_as_3d, typename OutputType, bool is_fused = false> TensorType compute_gemmlowp_target(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &shape_output, int32_t a_offset, int32_t b_offset, GEMMLowpOutputStageInfo output_stage = GEMMLowpOutputStageInfo(), DataType data_type_a = DataType::QASYMM8, DataType data_type_b = DataType::QASYMM8, - QuantizationInfo b_qinfo = QuantizationInfo()) + QuantizationInfo b_qinfo = QuantizationInfo(), bool reshape_b_only_on_first_run = false) { // Create tensors DataType data_type_output = output_stage.type == GEMMLowpOutputStageType::NONE ? DataType::S32 : data_type_a; @@ -126,7 +126,8 @@ TensorType compute_gemmlowp_target(const TensorShape &shape_a, const TensorShape // Create and configure function // The GEMMinfo includes the values of the depth in case of reinterpreted 3d input/output FunctionType gemmlowp; - gemmlowp.configure(&a, &b, is_fused ? &bias : nullptr, &output, GEMMInfo(false, false, false, (reinterpret_output_as_3d ? shape_output[2] : 0), reinterpret_input_as_3d, false, output_stage)); + gemmlowp.configure(&a, &b, is_fused ? &bias : nullptr, &output, GEMMInfo(false, false, reshape_b_only_on_first_run, (reinterpret_output_as_3d ? shape_output[2] : 0), reinterpret_input_as_3d, false, + output_stage)); ARM_COMPUTE_ASSERT(a.info()->is_resizable()); ARM_COMPUTE_ASSERT(b.info()->is_resizable()); @@ -208,11 +209,12 @@ protected: }; template <typename TensorType, typename AccessorType, typename FunctionType, bool reinterpret_input_as_3d = false, bool reinterpret_output_as_3d = false, typename TI = uint8_t, typename TW = uint8_t> -class GEMMLowpMatrixMultiplyCoreFusedOffsetOutputValidationFixture : public framework::Fixture +class GEMMLowpMatrixMultiplyCoreFusedOffsetOutputGenericValidationFixture : public framework::Fixture { public: template <typename...> - void setup(TensorShape shape_a, TensorShape shape_b, TensorShape shape_output, int32_t a_offset, int32_t b_offset, GEMMLowpOutputStageInfo output_stage, DataType data_type_b) + void setup(TensorShape shape_a, TensorShape shape_b, TensorShape shape_output, int32_t a_offset, int32_t b_offset, GEMMLowpOutputStageInfo output_stage, DataType data_type_b, + bool reshape_b_only_on_first_run) { ARM_COMPUTE_ASSERT(output_stage.type != GEMMLowpOutputStageType::NONE); DataType data_type_a = data_type_b == DataType::QASYMM8_SIGNED ? DataType::QASYMM8_SIGNED : DataType::QASYMM8; @@ -232,21 +234,21 @@ public: } _reference = compute_reference(shape_a, shape_b, shape_output, a_offset, 0, output_stage, data_type_a, data_type_b, QuantizationInfo(scales)); - _target = compute_target(shape_a, shape_b, shape_output, a_offset, 0, output_stage, data_type_a, data_type_b, QuantizationInfo(scales)); + _target = compute_target(shape_a, shape_b, shape_output, a_offset, 0, output_stage, data_type_a, data_type_b, QuantizationInfo(scales), reshape_b_only_on_first_run); } else { _reference = compute_reference(shape_a, shape_b, shape_output, a_offset, b_offset, output_stage, data_type_a, data_type_b, QuantizationInfo()); - _target = compute_target(shape_a, shape_b, shape_output, a_offset, b_offset, output_stage, data_type_a, data_type_b, QuantizationInfo()); + _target = compute_target(shape_a, shape_b, shape_output, a_offset, b_offset, output_stage, data_type_a, data_type_b, QuantizationInfo(), reshape_b_only_on_first_run); } } protected: TensorType compute_target(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &shape_output, int32_t a_offset, int32_t b_offset, GEMMLowpOutputStageInfo output_stage, - DataType data_type_a, DataType data_type_b, QuantizationInfo b_qinfo) + DataType data_type_a, DataType data_type_b, QuantizationInfo b_qinfo, bool reshape_b_only_on_first_run = false) { return compute_gemmlowp_target<TensorType, AccessorType, FunctionType, reinterpret_input_as_3d, reinterpret_output_as_3d, qasymm8_t, true>(shape_a, shape_b, shape_output, a_offset, b_offset, - output_stage, data_type_a, data_type_b, b_qinfo); + output_stage, data_type_a, data_type_b, b_qinfo, reshape_b_only_on_first_run); } SimpleTensor<TI> compute_reference(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &shape_output, int32_t a_offset, int32_t b_offset, @@ -277,6 +279,19 @@ protected: SimpleTensor<TI> _reference{}; }; +template <typename TensorType, typename AccessorType, typename FunctionType, bool reinterpret_input_as_3d = false, bool reinterpret_output_as_3d = false, typename TI = uint8_t, typename TW = uint8_t> +class GEMMLowpMatrixMultiplyCoreFusedOffsetOutputValidationFixture : public + GEMMLowpMatrixMultiplyCoreFusedOffsetOutputGenericValidationFixture<TensorType, AccessorType, FunctionType, reinterpret_input_as_3d, reinterpret_output_as_3d, TI, TW> +{ +public: + template <typename...> + void setup(TensorShape shape_a, TensorShape shape_b, TensorShape shape_output, int32_t a_offset, int32_t b_offset, GEMMLowpOutputStageInfo output_stage, DataType data_type_b) + { + GEMMLowpMatrixMultiplyCoreFusedOffsetOutputGenericValidationFixture<TensorType, AccessorType, FunctionType, reinterpret_input_as_3d, reinterpret_output_as_3d, TI, TW>::setup(shape_a, shape_b, + shape_output, a_offset, b_offset, output_stage, data_type_b, false); + } +}; + template <typename TensorType, typename AccessorType, typename FunctionType> class GEMMLowpQuantizeDownInt32ToUint8ScaleValidationFixture : public framework::Fixture { |