Diffstat (limited to 'tests/validation/fixtures/MatMulFixture.h')
-rw-r--r-- tests/validation/fixtures/MatMulFixture.h | 60
1 file changed, 45 insertions(+), 15 deletions(-)
diff --git a/tests/validation/fixtures/MatMulFixture.h b/tests/validation/fixtures/MatMulFixture.h
index 2f94c1f9d2..3e4cac5e34 100644
--- a/tests/validation/fixtures/MatMulFixture.h
+++ b/tests/validation/fixtures/MatMulFixture.h
@@ -112,14 +112,14 @@ protected:
// Configure MatMulInfo class
MatMulInfo mm_info;
- mm_info.adj_lhs(transpose_a).adj_rhs(transpose_b).fused_activation(act_info);
+ mm_info.adj_lhs(transpose_a).adj_rhs(transpose_b);
// Ensure values are dynamic
a.info()->set_are_values_constant(false);
b.info()->set_are_values_constant(false);
// Configure operator
- matmul.configure(&a, &b, &dst, mm_info, settings);
+ matmul.configure(&a, &b, &dst, mm_info, settings, act_info);
// Assertions
ARM_COMPUTE_ASSERT(a.info()->is_resizable());
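
The hunk above moves the fused activation out of MatMulInfo and into the operator's configure() call. Below is a minimal caller-side sketch of the new pattern; it is not part of the patch, and the CPU backend types (NEMatMul, CpuMatMulSettings), the header paths and the shapes are assumptions chosen for illustration.

#include "arm_compute/core/Types.h"
#include "arm_compute/runtime/NEON/functions/NEMatMul.h"
#include "arm_compute/runtime/Tensor.h"

using namespace arm_compute;

void matmul_with_activation_sketch()
{
    // 4x8 LHS times 8x16 RHS -> 4x16 output (TensorShape is width-first).
    Tensor a, b, dst;
    a.allocator()->init(TensorInfo(TensorShape(8U, 4U), 1, DataType::F32));
    b.allocator()->init(TensorInfo(TensorShape(16U, 8U), 1, DataType::F32));
    dst.allocator()->init(TensorInfo(TensorShape(16U, 4U), 1, DataType::F32));

    MatMulInfo mm_info;
    mm_info.adj_lhs(false).adj_rhs(false); // transposition flags only; no activation here any more

    // The activation is now a separate argument to configure().
    const ActivationLayerInfo act_info(ActivationLayerInfo::ActivationFunction::RELU);

    NEMatMul matmul;
    matmul.configure(&a, &b, &dst, mm_info, CpuMatMulSettings(), act_info);

    a.allocator()->allocate();
    b.allocator()->allocate();
    dst.allocator()->allocate();
    matmul.run();
}
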
@@ -162,8 +162,8 @@ protected:
}
template <typename TT>
- typename std::enable_if<!std::is_integral<TT>::value, SimpleTensor<TT>>::type
- compute_reference_gemm(const SimpleTensor<TT> &a, const SimpleTensor<TT> &b, const SimpleTensor<TT> &c, float alpha, float beta, const QuantizationInfo &o_qinfo)
+ typename std::enable_if < !std::is_integral<TT>::value, SimpleTensor<TT >>::type
+ compute_reference_gemm(const SimpleTensor<TT> &a, const SimpleTensor<TT> &b, const SimpleTensor<TT> &c, float alpha, float beta, const QuantizationInfo &o_qinfo)
{
ARM_COMPUTE_UNUSED(o_qinfo);
@@ -172,7 +172,7 @@ protected:
template <typename TT>
typename std::enable_if<std::is_integral<TT>::value, SimpleTensor<TT>>::type
- compute_reference_gemm(const SimpleTensor<TT> &a, const SimpleTensor<TT> &b, const SimpleTensor<TT> &c, float alpha, float beta, const QuantizationInfo &o_qinfo)
+ compute_reference_gemm(const SimpleTensor<TT> &a, const SimpleTensor<TT> &b, const SimpleTensor<TT> &c, float alpha, float beta, const QuantizationInfo &o_qinfo)
{
ARM_COMPUTE_UNUSED(alpha, beta);
@@ -183,18 +183,18 @@ protected:
const auto multiplier = aq.scale * bq.scale / oq.scale;
int32_t output_multiplier = 0;
- int32_t output_shift = 0;
+ int32_t output_shift = 0;
quantization::calculate_quantized_multiplier(multiplier, &output_multiplier, &output_shift);
std::vector<int32_t> output_multipliers{ output_multiplier };
std::vector<int32_t> output_shifts{ output_shift };
//The lhs and rhs offsets are negated here to keep the reference aligned with the function implementation where the lhs and rhs offsets are also negated.
const auto tmp = reference::gemmlowp_matrix_multiply_core<int32_t>(
- a, b, c.shape(), -aq.offset, -bq.offset);
+ a, b, c.shape(), -aq.offset, -bq.offset);
auto output = reference::gemmlowp_quantize_down_scale_by_fixedpoint<int32_t, TT>(
- tmp, output_multipliers, output_shifts, oq.offset,
- std::numeric_limits<int32_t>::lowest(), std::numeric_limits<int32_t>::max());
+ tmp, output_multipliers, output_shifts, oq.offset,
+ std::numeric_limits<int32_t>::lowest(), std::numeric_limits<int32_t>::max());
output.quantization_info(o_qinfo);
return output;
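
For context on the quantized reference path above: the real-valued requantization scale aq.scale * bq.scale / oq.scale is decomposed into a 32-bit integer multiplier plus a power-of-two shift, so the down-scale stage works purely in integer arithmetic. The following standalone sketch shows that decomposition; it is illustrative only, and the shift sign convention and rounding details of the library's calculate_quantized_multiplier() may differ.

#include <cmath>
#include <cstdint>

void decompose_scale(double scale, int32_t &quantized_multiplier, int32_t &shift)
{
    if(scale == 0.0)
    {
        quantized_multiplier = 0;
        shift                = 0;
        return;
    }
    int          exponent    = 0;
    const double significand = std::frexp(scale, &exponent); // scale = significand * 2^exponent, significand in [0.5, 1)
    int64_t      q           = static_cast<int64_t>(std::llround(significand * (1LL << 31)));
    if(q == (1LL << 31)) // rounding pushed the significand up to 1.0
    {
        q >>= 1;
        ++exponent;
    }
    quantized_multiplier = static_cast<int32_t>(q); // Q0.31 fixed-point value
    shift                = exponent;                // scale ~= quantized_multiplier * 2^(shift - 31)
}
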
@@ -280,6 +280,30 @@ public:
};
template <typename TensorType, typename AccessorType, typename FunctionType, typename Settings, typename T>
+class MatMulValidationWithDynamicTensorsFixture : public MatMulGenericValidationFixture<TensorType, AccessorType, FunctionType, Settings, T>
+{
+public:
+ template <typename...>
+ void setup(TensorShape shape_a, TensorShape shape_b, TensorShape output_shape, bool transpose_a, bool transpose_b, DataType data_type, ActivationLayerInfo act_info, int num_extra_runs)
+ {
+ MatMulGenericValidationFixture<TensorType, AccessorType, FunctionType, Settings, T>::setup(shape_a, shape_b, output_shape, transpose_a, transpose_b, data_type, act_info, num_extra_runs, Settings());
+ }
+};
+
+template <typename TensorType, typename AccessorType, typename FunctionType, typename Settings, typename T>
+class QuantizedMatMulValidationFixture : public MatMulGenericValidationFixture<TensorType, AccessorType, FunctionType, Settings, T>
+{
+public:
+ template <typename...>
+ void setup(TensorShape shape_a, TensorShape shape_b, TensorShape output_shape, bool transpose_a, bool transpose_b, DataType data_type, ActivationLayerInfo act_info, int num_extra_runs,
+ QuantizationInfo a_qinfo, QuantizationInfo b_qinfo, QuantizationInfo o_qinfo)
+ {
+ MatMulGenericValidationFixture<TensorType, AccessorType, FunctionType, Settings, T>::setup(shape_a, shape_b, output_shape, transpose_a, transpose_b, data_type, act_info, num_extra_runs, Settings(),
+ a_qinfo, b_qinfo, o_qinfo);
+ }
+};
+
+template <typename TensorType, typename AccessorType, typename FunctionType, typename Settings, typename T>
class MatMulValidationWithActivationFixture : public MatMulGenericValidationFixture<TensorType, AccessorType, FunctionType, Settings, T>
{
public:
@@ -291,24 +315,30 @@ public:
};
template <typename TensorType, typename AccessorType, typename FunctionType, typename Settings, typename T>
-class MatMulValidationWithDynamicTensorsFixture : public MatMulGenericValidationFixture<TensorType, AccessorType, FunctionType, Settings, T>
+class MatMulValidationWithActivationAlphaBetaFixture : public MatMulGenericValidationFixture<TensorType, AccessorType, FunctionType, Settings, T>
{
public:
template <typename...>
- void setup(TensorShape shape_a, TensorShape shape_b, TensorShape output_shape, bool transpose_a, bool transpose_b, DataType data_type, ActivationLayerInfo act_info, int num_extra_runs)
+ void setup(TensorShape shape_a, TensorShape shape_b, TensorShape output_shape, bool transpose_a, bool transpose_b, DataType data_type, ActivationLayerInfo::ActivationFunction function,
+ float alpha_beta)
{
- MatMulGenericValidationFixture<TensorType, AccessorType, FunctionType, Settings, T>::setup(shape_a, shape_b, output_shape, transpose_a, transpose_b, data_type, act_info, num_extra_runs, Settings());
+ ActivationLayerInfo act_info(function, alpha_beta, alpha_beta);
+ MatMulGenericValidationFixture<TensorType, AccessorType, FunctionType, Settings, T>::setup(shape_a, shape_b, output_shape, transpose_a, transpose_b, data_type, act_info, 0, Settings());
}
};
template <typename TensorType, typename AccessorType, typename FunctionType, typename Settings, typename T>
-class QuantizedMatMulValidationFixture : public MatMulGenericValidationFixture<TensorType, AccessorType, FunctionType, Settings, T>
+class QuantizedMatMulValidationWithActivationFixture : public MatMulGenericValidationFixture<TensorType, AccessorType, FunctionType, Settings, T>
{
public:
template <typename...>
- void setup(TensorShape shape_a, TensorShape shape_b, TensorShape output_shape, bool transpose_a, bool transpose_b, DataType data_type, ActivationLayerInfo act_info, int num_extra_runs, QuantizationInfo a_qinfo, QuantizationInfo b_qinfo, QuantizationInfo o_qinfo)
+ void setup(TensorShape shape_a, TensorShape shape_b, TensorShape output_shape, bool transpose_a, bool transpose_b, DataType data_type, ActivationLayerInfo::ActivationFunction function,
+ float alpha_beta, int num_extra_runs,
+ QuantizationInfo a_qinfo, QuantizationInfo b_qinfo, QuantizationInfo o_qinfo)
{
- MatMulGenericValidationFixture<TensorType, AccessorType, FunctionType, Settings, T>::setup(shape_a, shape_b, output_shape, transpose_a, transpose_b, data_type, act_info, num_extra_runs, Settings(), a_qinfo, b_qinfo, o_qinfo);
+ ActivationLayerInfo act_info(function, alpha_beta, alpha_beta);
+ MatMulGenericValidationFixture<TensorType, AccessorType, FunctionType, Settings, T>::setup(shape_a, shape_b, output_shape, transpose_a, transpose_b, data_type, act_info, num_extra_runs, Settings(),
+ a_qinfo, b_qinfo, o_qinfo);
}
};
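
For reference, a minimal sketch of how the reworked quantized fixture could be instantiated. The CPU backend types (Tensor, Accessor, NEMatMul, CpuMatMulSettings), the include paths and the shape/quantization parameters are assumptions for illustration, not taken from the patch; in the test suite these fixtures are normally driven through dataset macros rather than called directly.

#include "arm_compute/runtime/NEON/functions/NEMatMul.h"
#include "tests/NEON/Accessor.h"
#include "tests/validation/fixtures/MatMulFixture.h"

using namespace arm_compute;
using namespace arm_compute::test;
using namespace arm_compute::test::validation;

void quantized_matmul_fixture_sketch()
{
    using Fixture = QuantizedMatMulValidationWithActivationFixture<Tensor, Accessor, NEMatMul, CpuMatMulSettings, uint8_t>;

    Fixture fixture;
    fixture.setup(TensorShape(8U, 4U),         // LHS: K x M
                  TensorShape(16U, 8U),        // RHS: N x K
                  TensorShape(16U, 4U),        // output: N x M
                  false, false,                // no transposes
                  DataType::QASYMM8,
                  ActivationLayerInfo::ActivationFunction::RELU,
                  0.f,                         // alpha_beta (ignored by RELU)
                  0,                           // no extra dynamic-tensor runs
                  QuantizationInfo(0.5f, 10),  // a_qinfo
                  QuantizationInfo(0.25f, 5),  // b_qinfo
                  QuantizationInfo(1.0f, 0));  // o_qinfo
}
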