diff options
Diffstat (limited to 'tests/validation/fixtures/MatMulFixture.h')
-rw-r--r-- | tests/validation/fixtures/MatMulFixture.h | 60 |
1 files changed, 45 insertions, 15 deletions
diff --git a/tests/validation/fixtures/MatMulFixture.h b/tests/validation/fixtures/MatMulFixture.h index 2f94c1f9d2..3e4cac5e34 100644 --- a/tests/validation/fixtures/MatMulFixture.h +++ b/tests/validation/fixtures/MatMulFixture.h @@ -112,14 +112,14 @@ protected: // Configure MatMulInfo class MatMulInfo mm_info; - mm_info.adj_lhs(transpose_a).adj_rhs(transpose_b).fused_activation(act_info); + mm_info.adj_lhs(transpose_a).adj_rhs(transpose_b); // Ensure values are dynamic a.info()->set_are_values_constant(false); b.info()->set_are_values_constant(false); // Configure operator - matmul.configure(&a, &b, &dst, mm_info, settings); + matmul.configure(&a, &b, &dst, mm_info, settings, act_info); // Assertions ARM_COMPUTE_ASSERT(a.info()->is_resizable()); @@ -162,8 +162,8 @@ protected: } template <typename TT> - typename std::enable_if<!std::is_integral<TT>::value, SimpleTensor<TT>>::type - compute_reference_gemm(const SimpleTensor<TT> &a, const SimpleTensor<TT> &b, const SimpleTensor<TT> &c, float alpha, float beta, const QuantizationInfo &o_qinfo) + typename std::enable_if < !std::is_integral<TT>::value, SimpleTensor<TT >>::type + compute_reference_gemm(const SimpleTensor<TT> &a, const SimpleTensor<TT> &b, const SimpleTensor<TT> &c, float alpha, float beta, const QuantizationInfo &o_qinfo) { ARM_COMPUTE_UNUSED(o_qinfo); @@ -172,7 +172,7 @@ protected: template <typename TT> typename std::enable_if<std::is_integral<TT>::value, SimpleTensor<TT>>::type - compute_reference_gemm(const SimpleTensor<TT> &a, const SimpleTensor<TT> &b, const SimpleTensor<TT> &c, float alpha, float beta, const QuantizationInfo &o_qinfo) + compute_reference_gemm(const SimpleTensor<TT> &a, const SimpleTensor<TT> &b, const SimpleTensor<TT> &c, float alpha, float beta, const QuantizationInfo &o_qinfo) { ARM_COMPUTE_UNUSED(alpha, beta); @@ -183,18 +183,18 @@ protected: const auto multiplier = aq.scale * bq.scale / oq.scale; int32_t output_multiplier = 0; - int32_t output_shift = 0; + int32_t output_shift = 0; quantization::calculate_quantized_multiplier(multiplier, &output_multiplier, &output_shift); std::vector<int32_t> output_multipliers{ output_multiplier }; std::vector<int32_t> output_shifts{ output_shift }; //The lhs and rhs offsets are negated here to keep the reference aligned with the function implementation where the lhs and rhs offsets are also negated. const auto tmp = reference::gemmlowp_matrix_multiply_core<int32_t>( - a, b, c.shape(), -aq.offset, -bq.offset); + a, b, c.shape(), -aq.offset, -bq.offset); auto output = reference::gemmlowp_quantize_down_scale_by_fixedpoint<int32_t, TT>( - tmp, output_multipliers, output_shifts, oq.offset, - std::numeric_limits<int32_t>::lowest(), std::numeric_limits<int32_t>::max()); + tmp, output_multipliers, output_shifts, oq.offset, + std::numeric_limits<int32_t>::lowest(), std::numeric_limits<int32_t>::max()); output.quantization_info(o_qinfo); return output; @@ -280,6 +280,30 @@ public: }; template <typename TensorType, typename AccessorType, typename FunctionType, typename Settings, typename T> +class MatMulValidationWithDynamicTensorsFixture : public MatMulGenericValidationFixture<TensorType, AccessorType, FunctionType, Settings, T> +{ +public: + template <typename...> + void setup(TensorShape shape_a, TensorShape shape_b, TensorShape output_shape, bool transpose_a, bool transpose_b, DataType data_type, ActivationLayerInfo act_info, int num_extra_runs) + { + MatMulGenericValidationFixture<TensorType, AccessorType, FunctionType, Settings, T>::setup(shape_a, shape_b, output_shape, transpose_a, transpose_b, data_type, act_info, num_extra_runs, Settings()); + } +}; + +template <typename TensorType, typename AccessorType, typename FunctionType, typename Settings, typename T> +class QuantizedMatMulValidationFixture : public MatMulGenericValidationFixture<TensorType, AccessorType, FunctionType, Settings, T> +{ +public: + template <typename...> + void setup(TensorShape shape_a, TensorShape shape_b, TensorShape output_shape, bool transpose_a, bool transpose_b, DataType data_type, ActivationLayerInfo act_info, int num_extra_runs, + QuantizationInfo a_qinfo, QuantizationInfo b_qinfo, QuantizationInfo o_qinfo) + { + MatMulGenericValidationFixture<TensorType, AccessorType, FunctionType, Settings, T>::setup(shape_a, shape_b, output_shape, transpose_a, transpose_b, data_type, act_info, num_extra_runs, Settings(), + a_qinfo, b_qinfo, o_qinfo); + } +}; + +template <typename TensorType, typename AccessorType, typename FunctionType, typename Settings, typename T> class MatMulValidationWithActivationFixture : public MatMulGenericValidationFixture<TensorType, AccessorType, FunctionType, Settings, T> { public: @@ -291,24 +315,30 @@ public: }; template <typename TensorType, typename AccessorType, typename FunctionType, typename Settings, typename T> -class MatMulValidationWithDynamicTensorsFixture : public MatMulGenericValidationFixture<TensorType, AccessorType, FunctionType, Settings, T> +class MatMulValidationWithActivationAlphaBetaFixture : public MatMulGenericValidationFixture<TensorType, AccessorType, FunctionType, Settings, T> { public: template <typename...> - void setup(TensorShape shape_a, TensorShape shape_b, TensorShape output_shape, bool transpose_a, bool transpose_b, DataType data_type, ActivationLayerInfo act_info, int num_extra_runs) + void setup(TensorShape shape_a, TensorShape shape_b, TensorShape output_shape, bool transpose_a, bool transpose_b, DataType data_type, ActivationLayerInfo::ActivationFunction function, + float alpha_beta) { - MatMulGenericValidationFixture<TensorType, AccessorType, FunctionType, Settings, T>::setup(shape_a, shape_b, output_shape, transpose_a, transpose_b, data_type, act_info, num_extra_runs, Settings()); + ActivationLayerInfo act_info(function, alpha_beta, alpha_beta); + MatMulGenericValidationFixture<TensorType, AccessorType, FunctionType, Settings, T>::setup(shape_a, shape_b, output_shape, transpose_a, transpose_b, data_type, act_info, 0, Settings()); } }; template <typename TensorType, typename AccessorType, typename FunctionType, typename Settings, typename T> -class QuantizedMatMulValidationFixture : public MatMulGenericValidationFixture<TensorType, AccessorType, FunctionType, Settings, T> +class QuantizedMatMulValidationWithActivationFixture : public MatMulGenericValidationFixture<TensorType, AccessorType, FunctionType, Settings, T> { public: template <typename...> - void setup(TensorShape shape_a, TensorShape shape_b, TensorShape output_shape, bool transpose_a, bool transpose_b, DataType data_type, ActivationLayerInfo act_info, int num_extra_runs, QuantizationInfo a_qinfo, QuantizationInfo b_qinfo, QuantizationInfo o_qinfo) + void setup(TensorShape shape_a, TensorShape shape_b, TensorShape output_shape, bool transpose_a, bool transpose_b, DataType data_type, ActivationLayerInfo::ActivationFunction function, + float alpha_beta, int num_extra_runs, + QuantizationInfo a_qinfo, QuantizationInfo b_qinfo, QuantizationInfo o_qinfo) { - MatMulGenericValidationFixture<TensorType, AccessorType, FunctionType, Settings, T>::setup(shape_a, shape_b, output_shape, transpose_a, transpose_b, data_type, act_info, num_extra_runs, Settings(), a_qinfo, b_qinfo, o_qinfo); + ActivationLayerInfo act_info(function, alpha_beta, alpha_beta); + MatMulGenericValidationFixture<TensorType, AccessorType, FunctionType, Settings, T>::setup(shape_a, shape_b, output_shape, transpose_a, transpose_b, data_type, act_info, num_extra_runs, Settings(), + a_qinfo, b_qinfo, o_qinfo); } }; |