diff options
Diffstat (limited to 'tests/validation/fixtures/DirectConvolutionLayerFixture.h')
-rw-r--r-- | tests/validation/fixtures/DirectConvolutionLayerFixture.h | 48 |
1 files changed, 28 insertions, 20 deletions
diff --git a/tests/validation/fixtures/DirectConvolutionLayerFixture.h b/tests/validation/fixtures/DirectConvolutionLayerFixture.h index fef9d2dc6e..ef7721dd5e 100644 --- a/tests/validation/fixtures/DirectConvolutionLayerFixture.h +++ b/tests/validation/fixtures/DirectConvolutionLayerFixture.h @@ -49,7 +49,7 @@ public: public: template <typename...> void setup(TensorShape input_shape, int stride_x, int stride_y, int pad_x, int pad_y, unsigned int kernel_size, unsigned int num_kernels, - DataType data_type, int fractional_bits, QuantizationInfo quantization_info) + DataType data_type, int fractional_bits, QuantizationInfo quantization_info, ActivationLayerInfo act_info) { _fractional_bits = fractional_bits; _quantization_info = quantization_info; @@ -61,13 +61,13 @@ public: const TensorShape output_shape = get_output_shape(input_shape, weights_shape, info); const DataType bias_data_type = is_data_type_quantized_asymmetric(data_type) ? DataType::S32 : data_type; - _target = compute_target(input_shape, weights_shape, bias_shape, output_shape, info, data_type, bias_data_type, fractional_bits, quantization_info); - _reference = compute_reference(input_shape, weights_shape, bias_shape, output_shape, info, data_type, bias_data_type, fractional_bits, quantization_info); + _target = compute_target(input_shape, weights_shape, bias_shape, output_shape, info, data_type, bias_data_type, fractional_bits, quantization_info, act_info); + _reference = compute_reference(input_shape, weights_shape, bias_shape, output_shape, info, data_type, bias_data_type, fractional_bits, quantization_info, act_info); } template <typename...> void setup(TensorShape input_shape, TensorShape weights_shape, TensorShape bias_shape, TensorShape output_shape, PadStrideInfo info, Size2D dilation, - DataType data_type, int fractional_bits, QuantizationInfo quantization_info) + DataType data_type, int fractional_bits, QuantizationInfo quantization_info, ActivationLayerInfo act_info) { ARM_COMPUTE_UNUSED(dilation); @@ -77,8 +77,8 @@ public: const DataType bias_data_type = is_data_type_quantized_asymmetric(data_type) ? DataType::S32 : data_type; - _target = compute_target(input_shape, weights_shape, bias_shape, output_shape, info, data_type, bias_data_type, fractional_bits, quantization_info); - _reference = compute_reference(input_shape, weights_shape, bias_shape, output_shape, info, data_type, bias_data_type, fractional_bits, quantization_info); + _target = compute_target(input_shape, weights_shape, bias_shape, output_shape, info, data_type, bias_data_type, fractional_bits, quantization_info, act_info); + _reference = compute_reference(input_shape, weights_shape, bias_shape, output_shape, info, data_type, bias_data_type, fractional_bits, quantization_info, act_info); } protected: @@ -112,7 +112,7 @@ protected: } TensorType compute_target(const TensorShape &input_shape, const TensorShape &weights_shape, const TensorShape &bias_shape, const TensorShape &output_shape, const PadStrideInfo &info, - DataType data_type, DataType bias_data_type, int fixed_point_position, QuantizationInfo quantization_info) + DataType data_type, DataType bias_data_type, int fixed_point_position, QuantizationInfo quantization_info, ActivationLayerInfo act_info) { // Create tensors TensorType src = create_tensor<TensorType>(input_shape, data_type, 1, fixed_point_position, quantization_info); @@ -122,7 +122,7 @@ protected: // Create and configure function FunctionType conv; - conv.configure(&src, &weights, &bias, &dst, info); + conv.configure(&src, &weights, &bias, &dst, info, act_info); ARM_COMPUTE_EXPECT(src.info()->is_resizable(), framework::LogLevel::ERRORS); ARM_COMPUTE_EXPECT(weights.info()->is_resizable(), framework::LogLevel::ERRORS); @@ -152,7 +152,7 @@ protected: } SimpleTensor<T> compute_reference(const TensorShape &input_shape, const TensorShape &weights_shape, const TensorShape &bias_shape, const TensorShape &output_shape, const PadStrideInfo &info, - DataType data_type, DataType bias_data_type, int fixed_point_position, QuantizationInfo quantization_info) + DataType data_type, DataType bias_data_type, int fixed_point_position, QuantizationInfo quantization_info, ActivationLayerInfo act_info) { // Create reference SimpleTensor<T> src{ input_shape, data_type, 1, fixed_point_position, quantization_info }; @@ -164,7 +164,9 @@ protected: fill(weights, 1); fill(bias, 2); - return reference::convolution_layer<T>(src, weights, bias, output_shape, info); + return (act_info.enabled()) ? reference::activation_layer<T>(reference::convolution_layer<T>(src, weights, bias, output_shape, info), + act_info) : + reference::convolution_layer<T>(src, weights, bias, output_shape, info); } TensorType _target{}; @@ -194,9 +196,10 @@ class DirectConvolutionValidationFixture : public DirectConvolutionValidationGen { public: template <typename...> - void setup(TensorShape input_shape, int stride_x, int stride_y, int pad_x, int pad_y, unsigned int kernel_size, unsigned int num_kernels, DataType data_type) + void setup(TensorShape input_shape, int stride_x, int stride_y, int pad_x, int pad_y, unsigned int kernel_size, unsigned int num_kernels, DataType data_type, ActivationLayerInfo act_info) { - DirectConvolutionValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(input_shape, stride_x, stride_y, pad_x, pad_y, kernel_size, num_kernels, data_type, 0, QuantizationInfo()); + DirectConvolutionValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(input_shape, stride_x, stride_y, pad_x, pad_y, kernel_size, num_kernels, data_type, 0, QuantizationInfo(), + act_info); } }; @@ -205,10 +208,11 @@ class DirectConvolutionValidationFixedPointFixture : public DirectConvolutionVal { public: template <typename...> - void setup(TensorShape input_shape, int stride_x, int stride_y, int pad_x, int pad_y, unsigned int kernel_size, unsigned int num_kernels, DataType data_type, int fractional_bits) + void setup(TensorShape input_shape, int stride_x, int stride_y, int pad_x, int pad_y, unsigned int kernel_size, unsigned int num_kernels, DataType data_type, int fractional_bits, + ActivationLayerInfo act_info) { DirectConvolutionValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(input_shape, stride_x, stride_y, pad_x, pad_y, kernel_size, num_kernels, data_type, fractional_bits, - QuantizationInfo()); + QuantizationInfo(), act_info); } }; @@ -217,9 +221,11 @@ class DirectConvolutionValidationQuantizedFixture : public DirectConvolutionVali { public: template <typename...> - void setup(TensorShape input_shape, int stride_x, int stride_y, int pad_x, int pad_y, unsigned int kernel_size, unsigned int num_kernels, DataType data_type, QuantizationInfo quantization_info) + void setup(TensorShape input_shape, int stride_x, int stride_y, int pad_x, int pad_y, unsigned int kernel_size, unsigned int num_kernels, DataType data_type, QuantizationInfo quantization_info, + ActivationLayerInfo act_info) { - DirectConvolutionValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(input_shape, stride_x, stride_y, pad_x, pad_y, kernel_size, num_kernels, data_type, 0, quantization_info); + DirectConvolutionValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(input_shape, stride_x, stride_y, pad_x, pad_y, kernel_size, num_kernels, data_type, 0, quantization_info, + act_info); } }; @@ -229,9 +235,10 @@ class DirectConvolutionValidationWithTensorShapesQuantizedFixture : public Direc public: template <typename...> void setup(TensorShape input_shape, TensorShape weights_shape, TensorShape bias_shape, TensorShape output_shape, PadStrideInfo info, Size2D dilation, - DataType data_type, QuantizationInfo quantization_info) + DataType data_type, QuantizationInfo quantization_info, ActivationLayerInfo act_info) { - DirectConvolutionValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(input_shape, weights_shape, bias_shape, output_shape, info, dilation, data_type, 0, quantization_info); + DirectConvolutionValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(input_shape, weights_shape, bias_shape, output_shape, info, dilation, data_type, 0, quantization_info, + act_info); } }; @@ -241,9 +248,10 @@ class DirectConvolutionValidationWithTensorShapesFixture : public DirectConvolut public: template <typename...> void setup(TensorShape input_shape, TensorShape weights_shape, TensorShape bias_shape, TensorShape output_shape, PadStrideInfo info, Size2D dilation, - DataType data_type) + DataType data_type, ActivationLayerInfo act_info) { - DirectConvolutionValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(input_shape, weights_shape, bias_shape, output_shape, info, dilation, data_type, 0, QuantizationInfo()); + DirectConvolutionValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(input_shape, weights_shape, bias_shape, output_shape, info, dilation, data_type, 0, QuantizationInfo(), + act_info); } }; |