aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/fixtures/GEMMLowpFixture.h
diff options
context:
space:
mode:
authorGiorgio Arena <giorgio.arena@arm.com>2021-06-09 15:23:06 +0100
committerGiorgio Arena <giorgio.arena@arm.com>2021-06-24 14:16:56 +0000
commit5f6fdc1d5f2d2b5ae843ddfafaffde7787516843 (patch)
treea5539d09e19e5afda3f820c029b71fc718609d12 /tests/validation/fixtures/GEMMLowpFixture.h
parent561c176598cd14245e2e7918fdf136d1c888d1da (diff)
downloadComputeLibrary-5f6fdc1d5f2d2b5ae843ddfafaffde7787516843.tar.gz
Rework gemmlowp reshaped_only_rhs using the new macros
Resolve COMPMID-4416 Change-Id: I83cdf0de7adaf4d465ffebd494ab913182072485 Signed-off-by: Giorgio Arena <giorgio.arena@arm.com> Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/5788 Tested-by: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Gian Marco Iodice <gianmarco.iodice@arm.com> Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'tests/validation/fixtures/GEMMLowpFixture.h')
-rw-r--r--tests/validation/fixtures/GEMMLowpFixture.h31
1 files changed, 23 insertions, 8 deletions
diff --git a/tests/validation/fixtures/GEMMLowpFixture.h b/tests/validation/fixtures/GEMMLowpFixture.h
index 5cf210bab4..ab9d35de0f 100644
--- a/tests/validation/fixtures/GEMMLowpFixture.h
+++ b/tests/validation/fixtures/GEMMLowpFixture.h
@@ -97,7 +97,7 @@ void fill(U &&tensor, int i)
template <typename TensorType, typename AccessorType, typename FunctionType, bool reinterpret_input_as_3d, bool reinterpret_output_as_3d, typename OutputType, bool is_fused = false>
TensorType compute_gemmlowp_target(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &shape_output, int32_t a_offset, int32_t b_offset,
GEMMLowpOutputStageInfo output_stage = GEMMLowpOutputStageInfo(), DataType data_type_a = DataType::QASYMM8, DataType data_type_b = DataType::QASYMM8,
- QuantizationInfo b_qinfo = QuantizationInfo())
+ QuantizationInfo b_qinfo = QuantizationInfo(), bool reshape_b_only_on_first_run = false)
{
// Create tensors
DataType data_type_output = output_stage.type == GEMMLowpOutputStageType::NONE ? DataType::S32 : data_type_a;
@@ -126,7 +126,8 @@ TensorType compute_gemmlowp_target(const TensorShape &shape_a, const TensorShape
// Create and configure function
// The GEMMinfo includes the values of the depth in case of reinterpreted 3d input/output
FunctionType gemmlowp;
- gemmlowp.configure(&a, &b, is_fused ? &bias : nullptr, &output, GEMMInfo(false, false, false, (reinterpret_output_as_3d ? shape_output[2] : 0), reinterpret_input_as_3d, false, output_stage));
+ gemmlowp.configure(&a, &b, is_fused ? &bias : nullptr, &output, GEMMInfo(false, false, reshape_b_only_on_first_run, (reinterpret_output_as_3d ? shape_output[2] : 0), reinterpret_input_as_3d, false,
+ output_stage));
ARM_COMPUTE_ASSERT(a.info()->is_resizable());
ARM_COMPUTE_ASSERT(b.info()->is_resizable());
@@ -208,11 +209,12 @@ protected:
};
template <typename TensorType, typename AccessorType, typename FunctionType, bool reinterpret_input_as_3d = false, bool reinterpret_output_as_3d = false, typename TI = uint8_t, typename TW = uint8_t>
-class GEMMLowpMatrixMultiplyCoreFusedOffsetOutputValidationFixture : public framework::Fixture
+class GEMMLowpMatrixMultiplyCoreFusedOffsetOutputGenericValidationFixture : public framework::Fixture
{
public:
template <typename...>
- void setup(TensorShape shape_a, TensorShape shape_b, TensorShape shape_output, int32_t a_offset, int32_t b_offset, GEMMLowpOutputStageInfo output_stage, DataType data_type_b)
+ void setup(TensorShape shape_a, TensorShape shape_b, TensorShape shape_output, int32_t a_offset, int32_t b_offset, GEMMLowpOutputStageInfo output_stage, DataType data_type_b,
+ bool reshape_b_only_on_first_run)
{
ARM_COMPUTE_ASSERT(output_stage.type != GEMMLowpOutputStageType::NONE);
DataType data_type_a = data_type_b == DataType::QASYMM8_SIGNED ? DataType::QASYMM8_SIGNED : DataType::QASYMM8;
@@ -232,21 +234,21 @@ public:
}
_reference = compute_reference(shape_a, shape_b, shape_output, a_offset, 0, output_stage, data_type_a, data_type_b, QuantizationInfo(scales));
- _target = compute_target(shape_a, shape_b, shape_output, a_offset, 0, output_stage, data_type_a, data_type_b, QuantizationInfo(scales));
+ _target = compute_target(shape_a, shape_b, shape_output, a_offset, 0, output_stage, data_type_a, data_type_b, QuantizationInfo(scales), reshape_b_only_on_first_run);
}
else
{
_reference = compute_reference(shape_a, shape_b, shape_output, a_offset, b_offset, output_stage, data_type_a, data_type_b, QuantizationInfo());
- _target = compute_target(shape_a, shape_b, shape_output, a_offset, b_offset, output_stage, data_type_a, data_type_b, QuantizationInfo());
+ _target = compute_target(shape_a, shape_b, shape_output, a_offset, b_offset, output_stage, data_type_a, data_type_b, QuantizationInfo(), reshape_b_only_on_first_run);
}
}
protected:
TensorType compute_target(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &shape_output, int32_t a_offset, int32_t b_offset, GEMMLowpOutputStageInfo output_stage,
- DataType data_type_a, DataType data_type_b, QuantizationInfo b_qinfo)
+ DataType data_type_a, DataType data_type_b, QuantizationInfo b_qinfo, bool reshape_b_only_on_first_run = false)
{
return compute_gemmlowp_target<TensorType, AccessorType, FunctionType, reinterpret_input_as_3d, reinterpret_output_as_3d, qasymm8_t, true>(shape_a, shape_b, shape_output, a_offset, b_offset,
- output_stage, data_type_a, data_type_b, b_qinfo);
+ output_stage, data_type_a, data_type_b, b_qinfo, reshape_b_only_on_first_run);
}
SimpleTensor<TI> compute_reference(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &shape_output, int32_t a_offset, int32_t b_offset,
@@ -277,6 +279,19 @@ protected:
SimpleTensor<TI> _reference{};
};
+template <typename TensorType, typename AccessorType, typename FunctionType, bool reinterpret_input_as_3d = false, bool reinterpret_output_as_3d = false, typename TI = uint8_t, typename TW = uint8_t>
+class GEMMLowpMatrixMultiplyCoreFusedOffsetOutputValidationFixture : public
+ GEMMLowpMatrixMultiplyCoreFusedOffsetOutputGenericValidationFixture<TensorType, AccessorType, FunctionType, reinterpret_input_as_3d, reinterpret_output_as_3d, TI, TW>
+{
+public:
+ template <typename...>
+ void setup(TensorShape shape_a, TensorShape shape_b, TensorShape shape_output, int32_t a_offset, int32_t b_offset, GEMMLowpOutputStageInfo output_stage, DataType data_type_b)
+ {
+ GEMMLowpMatrixMultiplyCoreFusedOffsetOutputGenericValidationFixture<TensorType, AccessorType, FunctionType, reinterpret_input_as_3d, reinterpret_output_as_3d, TI, TW>::setup(shape_a, shape_b,
+ shape_output, a_offset, b_offset, output_stage, data_type_b, false);
+ }
+};
+
template <typename TensorType, typename AccessorType, typename FunctionType>
class GEMMLowpQuantizeDownInt32ToUint8ScaleValidationFixture : public framework::Fixture
{