aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/fixtures
diff options
context:
space:
mode:
Diffstat (limited to 'tests/validation/fixtures')
-rw-r--r--tests/validation/fixtures/ConvolutionLayerFixture.h27
-rw-r--r--tests/validation/fixtures/DirectConvolution3DFixture.h5
-rw-r--r--tests/validation/fixtures/GEMMLowpFixture.h50
3 files changed, 68 insertions, 14 deletions
diff --git a/tests/validation/fixtures/ConvolutionLayerFixture.h b/tests/validation/fixtures/ConvolutionLayerFixture.h
index 0622e5e6f0..939ac032cd 100644
--- a/tests/validation/fixtures/ConvolutionLayerFixture.h
+++ b/tests/validation/fixtures/ConvolutionLayerFixture.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2023 Arm Limited.
+ * Copyright (c) 2017-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -480,6 +480,31 @@ public:
};
template <typename TensorType, typename AccessorType, typename FunctionType, typename T, typename TW>
+class ConvolutionValidationQuantizedMixedTypeFixture
+ : public ConvolutionValidationGenericFixture<TensorType, AccessorType, FunctionType, T, TW>
+{
+public:
+ void setup(TensorShape input_shape,
+ TensorShape weights_shape,
+ TensorShape bias_shape,
+ TensorShape output_shape,
+ PadStrideInfo info,
+ Size2D dilation,
+ bool reshape_weights,
+ DataType data_type,
+ DataType weights_data_type,
+ DataLayout data_layout,
+ QuantizationInfo quantization_info,
+ QuantizationInfo weight_quantization_info,
+ ActivationLayerInfo act_info)
+ {
+ ConvolutionValidationGenericFixture<TensorType, AccessorType, FunctionType, T, TW>::setup(
+ input_shape, weights_shape, bias_shape, output_shape, info, dilation, reshape_weights, data_type,
+ weights_data_type, data_layout, quantization_info, weight_quantization_info, act_info);
+ }
+};
+
+template <typename TensorType, typename AccessorType, typename FunctionType, typename T, typename TW>
class ConvolutionValidationQuantizedPerChannelFixture : public ConvolutionValidationGenericFixture<TensorType, AccessorType, FunctionType, T, TW>
{
public:
diff --git a/tests/validation/fixtures/DirectConvolution3DFixture.h b/tests/validation/fixtures/DirectConvolution3DFixture.h
index e80ad2f54f..e27a41a23b 100644
--- a/tests/validation/fixtures/DirectConvolution3DFixture.h
+++ b/tests/validation/fixtures/DirectConvolution3DFixture.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021, 2023 Arm Limited.
+ * Copyright (c) 2021, 2023-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -46,6 +46,7 @@ class DirectConvolution3DValidationGenericFixture : public framework::Fixture
{
public:
using TBias = typename std::conditional < std::is_same<T, uint8_t>::value || std::is_same<T, int8_t>::value, int32_t, T >::type;
+ using TAcc = typename std::conditional < std::is_integral<T>::value, int32_t, float >::type;
void setup(const TensorShape &input_shape, int stride_x, int stride_y, int stride_z, int pad_x, int pad_y, int pad_z, unsigned int kernel_width, int kernel_height, int kernel_depth,
unsigned int num_kernels, bool has_bias, const ActivationLayerInfo &act_info, const DataType &data_type, const DataLayout &data_layout,
@@ -150,7 +151,7 @@ protected:
fill(bias, 2);
}
- return reference::activation_layer(reference::conv3d<T, TBias>(src, weights, bias, dst, conv3d_info), conv3d_info.act_info);
+ return reference::activation_layer(reference::conv3d<T, TBias, TAcc>(src, weights, bias, dst, conv3d_info), conv3d_info.act_info);
}
TensorType _target{};
diff --git a/tests/validation/fixtures/GEMMLowpFixture.h b/tests/validation/fixtures/GEMMLowpFixture.h
index aa4eedb75d..7931d8467d 100644
--- a/tests/validation/fixtures/GEMMLowpFixture.h
+++ b/tests/validation/fixtures/GEMMLowpFixture.h
@@ -97,8 +97,7 @@ TensorType compute_gemmlowp_target(const TensorShape &shape_a, const TensorShape
bool accumulate = false, bool dynamic_qinfo = false, DataType data_type_output = DataType::UNKNOWN)
{
ARM_COMPUTE_ASSERT(is_data_type_quantized_asymmetric(data_type_a));
- ARM_COMPUTE_ASSERT(data_type_a == data_type_b);
- // If unknown, set to sensible defaults
+ // If unknown, set to sensible defaults
if (data_type_output == DataType::UNKNOWN) {
data_type_output = output_stage.type == GEMMLowpOutputStageType::NONE ? DataType::S32 : data_type_a;
}
@@ -185,7 +184,6 @@ SimpleTensor<int32_t> compute_gemmlowp_reference(const TensorShape &shape_a, con
DataType data_type_a = DataType::QASYMM8, DataType data_type_b = DataType::QASYMM8, const TensorFillInfo& finfo = TensorFillInfo())
{
ARM_COMPUTE_ASSERT(is_data_type_quantized_asymmetric(data_type_a));
- ARM_COMPUTE_ASSERT(data_type_a == data_type_b);
TensorShape shape_a_to_use = shape_a;
if(reinterpret_input_as_3d)
{
@@ -472,29 +470,59 @@ template <typename TensorType, typename AccessorType, typename FunctionType, boo
class GEMMLowpDequantizedMatrixMultiplyValidationFixture : public framework::Fixture
{
public:
- void setup(TensorShape shape_a, TensorShape shape_b, TensorShape shape_output, int32_t a_offset, int32_t b_offset, bool accumulate)
+ void setup(TensorShape shape_a, TensorShape shape_b, TensorShape shape_output, int32_t a_offset, int32_t b_offset, DataType data_type_a, DataType data_type_b, bool accumulate)
{
const bool dynamic_qinfo = false;
const auto a_qinfo = QuantizationInfo(1.0f / 255, a_offset);
const auto b_qinfo = QuantizationInfo(5.0f / 255, b_offset);
TensorFillInfo finfo;
- _target = compute_target(shape_a, shape_b, shape_output, a_qinfo, b_qinfo, finfo, accumulate, dynamic_qinfo);
- _reference = compute_reference(shape_a, shape_b, shape_output, a_qinfo, b_qinfo, finfo, accumulate, dynamic_qinfo);
+ _target = compute_target(shape_a, shape_b, shape_output, a_qinfo, b_qinfo, data_type_a, data_type_b, finfo,
+ accumulate, dynamic_qinfo);
+ _reference = compute_reference(shape_a, shape_b, shape_output, a_qinfo, b_qinfo, data_type_a, data_type_b,
+ finfo, accumulate, dynamic_qinfo);
}
protected:
- TensorType compute_target(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &shape_output, const QuantizationInfo& a_qinfo, const QuantizationInfo& b_qinfo, const TensorFillInfo& finfo, const bool accumulate, const bool dynamic_qinfo)
+ TensorType compute_target(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &shape_output, const QuantizationInfo& a_qinfo, const QuantizationInfo& b_qinfo, DataType data_type_a, DataType data_type_b, const TensorFillInfo& finfo, const bool accumulate, const bool dynamic_qinfo)
{
const auto output_qinfo = QuantizationInfo();
- return compute_gemmlowp_target<TensorType, AccessorType, FunctionType, reinterpret_input_as_3d, reinterpret_output_as_3d, int32_t, false, run_twice>(shape_a, shape_b, shape_output, a_qinfo, b_qinfo, output_qinfo, DataType::QASYMM8_SIGNED, DataType::QASYMM8_SIGNED, GEMMLowpOutputStageInfo(), false, finfo, accumulate, dynamic_qinfo, DataType::F32);
+ return compute_gemmlowp_target<TensorType, AccessorType, FunctionType, reinterpret_input_as_3d, reinterpret_output_as_3d, int32_t, false, run_twice>(shape_a, shape_b, shape_output, a_qinfo, b_qinfo, output_qinfo, data_type_a, data_type_b, GEMMLowpOutputStageInfo(), false, finfo, accumulate, dynamic_qinfo, DataType::F32);
}
- SimpleTensor<float> compute_reference(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &shape_output, const QuantizationInfo& a_qinfo, const QuantizationInfo& b_qinfo, const TensorFillInfo& finfo, bool accumulate, const bool dynamic_qinfo)
+ SimpleTensor<float> compute_reference(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &shape_output, const QuantizationInfo& a_qinfo, const QuantizationInfo& b_qinfo, DataType data_type_a, DataType data_type_b, const TensorFillInfo& finfo, bool accumulate, const bool dynamic_qinfo)
{
QuantizationInfo s32_ref_output_quant_info = QuantizationInfo(a_qinfo.uniform().scale * b_qinfo.uniform().scale, 0, dynamic_qinfo);
- SimpleTensor<int32_t> s32_ref_output = compute_gemmlowp_reference<reinterpret_input_as_3d, int8_t, int8_t, false, false, run_twice>(shape_a, shape_b, shape_output, a_qinfo, b_qinfo,
- DataType::QASYMM8_SIGNED, DataType::QASYMM8_SIGNED, finfo);
+ SimpleTensor<int32_t> s32_ref_output;
+ if (data_type_a == DataType::QASYMM8)
+ {
+ if (data_type_b == DataType::QASYMM8)
+ {
+ s32_ref_output = compute_gemmlowp_reference<reinterpret_input_as_3d, uint8_t, uint8_t, false, false, run_twice>(
+ shape_a, shape_b, shape_output, a_qinfo, b_qinfo, data_type_a, data_type_b, finfo);
+ }
+ else
+ {
+ ARM_COMPUTE_ERROR_ON(data_type_b != DataType::QASYMM8_SIGNED);
+ s32_ref_output = compute_gemmlowp_reference<reinterpret_input_as_3d, uint8_t, int8_t, false, false, run_twice>(
+ shape_a, shape_b, shape_output, a_qinfo, b_qinfo, data_type_a, data_type_b, finfo);
+ }
+ }
+ else
+ {
+ ARM_COMPUTE_ERROR_ON(data_type_a != DataType::QASYMM8_SIGNED);
+ if (data_type_b == DataType::QASYMM8)
+ {
+ ARM_COMPUTE_ERROR("QASYMM8_SIGNED input with QASYMM8 weights not supported");
+ }
+ else
+ {
+ ARM_COMPUTE_ERROR_ON(data_type_b != DataType::QASYMM8_SIGNED);
+ s32_ref_output = compute_gemmlowp_reference<reinterpret_input_as_3d, int8_t, int8_t, false, false, run_twice>(
+ shape_a, shape_b, shape_output, a_qinfo, b_qinfo, data_type_a, data_type_b, finfo);
+ }
+ }
+
s32_ref_output.quantization_info(s32_ref_output_quant_info);
SimpleTensor<float> f32_ref_output(s32_ref_output.shape(), DataType::F32);