diff options
Diffstat (limited to 'tests/validation/fixtures/DeconvolutionLayerFixture.h')
-rw-r--r-- | tests/validation/fixtures/DeconvolutionLayerFixture.h | 45 |
1 files changed, 25 insertions, 20 deletions
diff --git a/tests/validation/fixtures/DeconvolutionLayerFixture.h b/tests/validation/fixtures/DeconvolutionLayerFixture.h index a25a65f997..b819e651ff 100644 --- a/tests/validation/fixtures/DeconvolutionLayerFixture.h +++ b/tests/validation/fixtures/DeconvolutionLayerFixture.h @@ -51,12 +51,13 @@ public: public: template <typename...> void setup(TensorShape input_shape, TensorShape weights_shape, TensorShape bias_shape, TensorShape output_shape, PadStrideInfo info, - DataType data_type, DataLayout data_layout, QuantizationInfo quantization_info, bool add_bias) + DataType data_type, DataLayout data_layout, QuantizationInfo input_quantization_info, QuantizationInfo output_quantization_info, bool add_bias) { - _data_type = data_type; - _bias_data_type = is_data_type_quantized_asymmetric(data_type) ? DataType::S32 : data_type; - _data_layout = data_layout; - _quantization_info = quantization_info; + _data_type = data_type; + _bias_data_type = is_data_type_quantized_asymmetric(data_type) ? DataType::S32 : data_type; + _data_layout = data_layout; + _input_quantization_info = input_quantization_info; + _output_quantization_info = output_quantization_info; _target = compute_target(input_shape, weights_shape, bias_shape, output_shape, info, add_bias); _reference = compute_reference(input_shape, weights_shape, bias_shape, output_shape, info, add_bias); @@ -126,10 +127,10 @@ protected: } // Create tensors - TensorType src = create_tensor<TensorType>(input_shape, _data_type, 1, _quantization_info, _data_layout); - TensorType weights = create_tensor<TensorType>(weights_shape, _data_type, 1, _quantization_info, _data_layout); - TensorType bias = create_tensor<TensorType>(bias_shape, _bias_data_type, 1, _quantization_info, _data_layout); - TensorType dst = create_tensor<TensorType>(output_shape, _data_type, 1, _quantization_info, _data_layout); + TensorType src = create_tensor<TensorType>(input_shape, _data_type, 1, _input_quantization_info, _data_layout); + TensorType weights = create_tensor<TensorType>(weights_shape, _data_type, 1, _input_quantization_info, _data_layout); + TensorType bias = create_tensor<TensorType>(bias_shape, _bias_data_type, 1, _input_quantization_info, _data_layout); + TensorType dst = create_tensor<TensorType>(output_shape, _data_type, 1, _output_quantization_info, _data_layout); // Create and configure function FunctionType conv; @@ -178,9 +179,9 @@ protected: const PadStrideInfo &info, bool add_bias) { // Create reference - SimpleTensor<T> src{ input_shape, _data_type, 1, _quantization_info }; - SimpleTensor<T> weights{ weights_shape, _data_type, 1, _quantization_info }; - SimpleTensor<TBias> bias{ bias_shape, _bias_data_type, 1, _quantization_info }; + SimpleTensor<T> src{ input_shape, _data_type, 1, _input_quantization_info }; + SimpleTensor<T> weights{ weights_shape, _data_type, 1, _input_quantization_info }; + SimpleTensor<TBias> bias{ bias_shape, _bias_data_type, 1, _input_quantization_info }; // Fill reference fill(src, 0); @@ -195,7 +196,7 @@ protected: fill_zeros(bias); } - return reference::deconvolution_layer<T>(src, weights, bias, output_shape, info); + return reference::deconvolution_layer<T>(src, weights, bias, output_shape, info, _output_quantization_info); } TensorType _target{}; @@ -203,7 +204,8 @@ protected: DataType _data_type{}; DataType _bias_data_type{}; DataLayout _data_layout{}; - QuantizationInfo _quantization_info{}; + QuantizationInfo _input_quantization_info{}; + QuantizationInfo _output_quantization_info{}; }; template <typename TensorType, typename AccessorType, typename FunctionType, typename T, unsigned int kernel_size_x, unsigned int kernel_size_y> @@ -222,7 +224,8 @@ public: TensorInfo input_info(input_shape, 1, data_type); TensorInfo weights_info(weights_shape, 1, data_type); TensorShape output_shape = compute_deconvolution_output_shape(out_dim, input_info, weights_info); - DeconvolutionLayerFixtureBase<TensorType, AccessorType, FunctionType, T>::setup(input_shape, weights_shape, bias_shape, output_shape, info, data_type, data_layout, QuantizationInfo(), add_bias); + DeconvolutionLayerFixtureBase<TensorType, AccessorType, FunctionType, T>::setup(input_shape, weights_shape, bias_shape, output_shape, info, data_type, data_layout, QuantizationInfo(), + QuantizationInfo(), add_bias); } }; @@ -242,7 +245,8 @@ public: TensorInfo input_info(input_shape, 1, data_type); TensorInfo weights_info(weights_shape, 1, data_type); TensorShape output_shape = compute_deconvolution_output_shape(out_dim, input_info, weights_info); - DeconvolutionLayerFixtureBase<TensorType, AccessorType, FunctionType, T>::setup(input_shape, weights_shape, bias_shape, output_shape, info, data_type, data_layout, QuantizationInfo(), add_bias); + DeconvolutionLayerFixtureBase<TensorType, AccessorType, FunctionType, T>::setup(input_shape, weights_shape, bias_shape, output_shape, info, data_type, data_layout, QuantizationInfo(), + QuantizationInfo(), add_bias); } }; @@ -252,17 +256,18 @@ class DeconvolutionValidationQuantizedFixture : public DeconvolutionLayerFixture public: template <typename...> void setup(TensorShape input_shape, unsigned int sx, unsigned int sy, unsigned int padx, unsigned int pady, - unsigned int num_kernels, DataType data_type, DataLayout data_layout, QuantizationInfo quantization_info, bool add_bias) + unsigned int num_kernels, DataType data_type, DataLayout data_layout, QuantizationInfo input_quantization_info, QuantizationInfo output_quantization_info, bool add_bias) { ARM_COMPUTE_ERROR_ON_MSG(kernel_size_x != kernel_size_y, "Only square kernels supported"); const TensorShape weights_shape(kernel_size_x, kernel_size_y, input_shape.z(), num_kernels); const TensorShape bias_shape(num_kernels); const PadStrideInfo info(sx, sy, padx, pady, DimensionRoundingType::CEIL); auto out_dim = deconvolution_output_dimensions(input_shape.x(), input_shape.y(), kernel_size_x, kernel_size_y, info); - TensorInfo input_info(input_shape, 1, data_type, quantization_info); - TensorInfo weights_info(weights_shape, 1, data_type, quantization_info); + TensorInfo input_info(input_shape, 1, data_type, input_quantization_info); + TensorInfo weights_info(weights_shape, 1, data_type, input_quantization_info); TensorShape output_shape = compute_deconvolution_output_shape(out_dim, input_info, weights_info); - DeconvolutionLayerFixtureBase<TensorType, AccessorType, FunctionType, T>::setup(input_shape, weights_shape, bias_shape, output_shape, info, data_type, data_layout, quantization_info, add_bias); + DeconvolutionLayerFixtureBase<TensorType, AccessorType, FunctionType, T>::setup(input_shape, weights_shape, bias_shape, output_shape, info, data_type, data_layout, input_quantization_info, + output_quantization_info, add_bias); } }; |