diff options
Diffstat (limited to 'tests/validation/fixtures/DirectConvolutionLayerTensorShiftFixture.h')
-rw-r--r-- | tests/validation/fixtures/DirectConvolutionLayerTensorShiftFixture.h | 48 |
1 file changed, 22 insertions, 26 deletions
diff --git a/tests/validation/fixtures/DirectConvolutionLayerTensorShiftFixture.h b/tests/validation/fixtures/DirectConvolutionLayerTensorShiftFixture.h index 09b6d830b4..144c7b7d0d 100644 --- a/tests/validation/fixtures/DirectConvolutionLayerTensorShiftFixture.h +++ b/tests/validation/fixtures/DirectConvolutionLayerTensorShiftFixture.h @@ -50,9 +50,8 @@ public: public: template <typename...> void setup(TensorShape input_shape, int stride_x, int stride_y, int pad_x, int pad_y, unsigned int kernel_size, unsigned int num_kernels, - DataType data_type, int fractional_bits, QuantizationInfo quantization_info) + DataType data_type, QuantizationInfo quantization_info) { - _fractional_bits = fractional_bits; _quantization_info = quantization_info; _data_type = data_type; @@ -62,24 +61,23 @@ public: const TensorShape output_shape = get_output_shape(input_shape, weights_shape, info); const DataType bias_data_type = is_data_type_quantized_asymmetric(data_type) ? DataType::S32 : data_type; - _target = compute_target(input_shape, weights_shape, bias_shape, output_shape, info, data_type, bias_data_type, fractional_bits, quantization_info); - _reference = compute_reference(input_shape, weights_shape, bias_shape, output_shape, info, data_type, bias_data_type, fractional_bits, quantization_info); + _target = compute_target(input_shape, weights_shape, bias_shape, output_shape, info, data_type, bias_data_type, quantization_info); + _reference = compute_reference(input_shape, weights_shape, bias_shape, output_shape, info, data_type, bias_data_type, quantization_info); } template <typename...> void setup(TensorShape input_shape, TensorShape weights_shape, TensorShape bias_shape, TensorShape output_shape, PadStrideInfo info, unsigned int dilation_x, unsigned int dilation_y, - DataType data_type, int fractional_bits, QuantizationInfo quantization_info) + DataType data_type, QuantizationInfo quantization_info) { ARM_COMPUTE_UNUSED(dilation_x, dilation_y); - _fractional_bits = 
fractional_bits; _quantization_info = quantization_info; _data_type = data_type; const DataType bias_data_type = is_data_type_quantized_asymmetric(data_type) ? DataType::S32 : data_type; - _target = compute_target(input_shape, weights_shape, bias_shape, output_shape, info, data_type, bias_data_type, fractional_bits, quantization_info); - _reference = compute_reference(input_shape, weights_shape, bias_shape, output_shape, info, data_type, bias_data_type, fractional_bits, quantization_info); + _target = compute_target(input_shape, weights_shape, bias_shape, output_shape, info, data_type, bias_data_type, quantization_info); + _reference = compute_reference(input_shape, weights_shape, bias_shape, output_shape, info, data_type, bias_data_type, quantization_info); } protected: @@ -113,16 +111,16 @@ protected: } TensorType compute_target(const TensorShape &input_shape, const TensorShape &weights_shape, const TensorShape &bias_shape, const TensorShape &output_shape, const PadStrideInfo &info, - DataType data_type, DataType bias_data_type, int fixed_point_position, QuantizationInfo quantization_info) + DataType data_type, DataType bias_data_type, QuantizationInfo quantization_info) { // Create tensors - TensorType src = create_tensor<TensorType>(input_shape, data_type, 1, fixed_point_position, quantization_info); - TensorType weights = create_tensor<TensorType>(weights_shape, data_type, 1, fixed_point_position, quantization_info); - TensorType bias = create_tensor<TensorType>(bias_shape, bias_data_type, 1, fixed_point_position, quantization_info); - TensorType dst = create_tensor<TensorType>(output_shape, data_type, 1, fixed_point_position, quantization_info); + TensorType src = create_tensor<TensorType>(input_shape, data_type, 1, quantization_info); + TensorType weights = create_tensor<TensorType>(weights_shape, data_type, 1, quantization_info); + TensorType bias = create_tensor<TensorType>(bias_shape, bias_data_type, 1, quantization_info); + TensorType dst = 
create_tensor<TensorType>(output_shape, data_type, 1, quantization_info); TensorShape output_shape1 = get_output_shape(output_shape, weights_shape, info); - TensorType dst1 = create_tensor<TensorType>(output_shape1, data_type, 1, fixed_point_position, quantization_info); + TensorType dst1 = create_tensor<TensorType>(output_shape1, data_type, 1, quantization_info); // Create and configure function FunctionType conv; @@ -164,14 +162,14 @@ protected: } SimpleTensor<T> compute_reference(const TensorShape &input_shape, const TensorShape &weights_shape, const TensorShape &bias_shape, const TensorShape &output_shape, const PadStrideInfo &info, - DataType data_type, DataType bias_data_type, int fixed_point_position, QuantizationInfo quantization_info) + DataType data_type, DataType bias_data_type, QuantizationInfo quantization_info) { // Create reference - SimpleTensor<T> src{ input_shape, data_type, 1, fixed_point_position, quantization_info }; - SimpleTensor<T> weights{ weights_shape, data_type, 1, fixed_point_position, quantization_info }; - SimpleTensor<TBias> bias{ bias_shape, bias_data_type, 1, fixed_point_position, quantization_info }; + SimpleTensor<T> src{ input_shape, data_type, 1, quantization_info }; + SimpleTensor<T> weights{ weights_shape, data_type, 1, quantization_info }; + SimpleTensor<TBias> bias{ bias_shape, bias_data_type, 1, quantization_info }; - SimpleTensor<T> dst{ output_shape, data_type, 1, fixed_point_position, quantization_info }; + SimpleTensor<T> dst{ output_shape, data_type, 1, quantization_info }; TensorShape output_shape1 = get_output_shape(output_shape, weights_shape, info); // Fill reference @@ -185,7 +183,6 @@ protected: TensorType _target{}; SimpleTensor<T> _reference{}; - int _fractional_bits{}; QuantizationInfo _quantization_info{}; DataType _data_type{}; @@ -212,7 +209,7 @@ public: template <typename...> void setup(TensorShape input_shape, int stride_x, int stride_y, int pad_x, int pad_y, unsigned int kernel_size, unsigned int 
num_kernels, DataType data_type) { - DirectConvolutionValidationGenericTensorShiftFixture<TensorType, AccessorType, FunctionType, T>::setup(input_shape, stride_x, stride_y, pad_x, pad_y, kernel_size, num_kernels, data_type, 0, + DirectConvolutionValidationGenericTensorShiftFixture<TensorType, AccessorType, FunctionType, T>::setup(input_shape, stride_x, stride_y, pad_x, pad_y, kernel_size, num_kernels, data_type, QuantizationInfo()); } }; @@ -222,10 +219,9 @@ class DirectConvolutionValidationFixedPointTensorShiftFixture : public DirectCon { public: template <typename...> - void setup(TensorShape input_shape, int stride_x, int stride_y, int pad_x, int pad_y, unsigned int kernel_size, unsigned int num_kernels, DataType data_type, int fractional_bits) + void setup(TensorShape input_shape, int stride_x, int stride_y, int pad_x, int pad_y, unsigned int kernel_size, unsigned int num_kernels, DataType data_type) { DirectConvolutionValidationGenericTensorShiftFixture<TensorType, AccessorType, FunctionType, T>::setup(input_shape, stride_x, stride_y, pad_x, pad_y, kernel_size, num_kernels, data_type, - fractional_bits, QuantizationInfo()); } }; @@ -237,7 +233,7 @@ public: template <typename...> void setup(TensorShape input_shape, int stride_x, int stride_y, int pad_x, int pad_y, unsigned int kernel_size, unsigned int num_kernels, DataType data_type, QuantizationInfo quantization_info) { - DirectConvolutionValidationGenericTensorShiftFixture<TensorType, AccessorType, FunctionType, T>::setup(input_shape, stride_x, stride_y, pad_x, pad_y, kernel_size, num_kernels, data_type, 0, + DirectConvolutionValidationGenericTensorShiftFixture<TensorType, AccessorType, FunctionType, T>::setup(input_shape, stride_x, stride_y, pad_x, pad_y, kernel_size, num_kernels, data_type, quantization_info); } }; @@ -250,7 +246,7 @@ public: void setup(TensorShape input_shape, TensorShape weights_shape, TensorShape bias_shape, TensorShape output_shape, PadStrideInfo info, unsigned int dilation_x, unsigned 
int dilation_y, DataType data_type, QuantizationInfo quantization_info) { - DirectConvolutionValidationGenericTensorShiftFixture<TensorType, AccessorType, FunctionType, T>::setup(input_shape, weights_shape, bias_shape, output_shape, info, dilation_x, dilation_y, data_type, 0, + DirectConvolutionValidationGenericTensorShiftFixture<TensorType, AccessorType, FunctionType, T>::setup(input_shape, weights_shape, bias_shape, output_shape, info, dilation_x, dilation_y, data_type, quantization_info); } }; @@ -263,7 +259,7 @@ public: void setup(TensorShape input_shape, TensorShape weights_shape, TensorShape bias_shape, TensorShape output_shape, PadStrideInfo info, unsigned int dilation_x, unsigned int dilation_y, DataType data_type) { - DirectConvolutionValidationGenericTensorShiftFixture<TensorType, AccessorType, FunctionType, T>::setup(input_shape, weights_shape, bias_shape, output_shape, info, dilation_x, dilation_y, data_type, 0, + DirectConvolutionValidationGenericTensorShiftFixture<TensorType, AccessorType, FunctionType, T>::setup(input_shape, weights_shape, bias_shape, output_shape, info, dilation_x, dilation_y, data_type, QuantizationInfo()); } }; |