diff options
Diffstat (limited to 'tests/validation/fixtures')
-rw-r--r-- | tests/validation/fixtures/DirectConvolutionLayerFixture.h | 75 |
1 files changed, 54 insertions, 21 deletions
diff --git a/tests/validation/fixtures/DirectConvolutionLayerFixture.h b/tests/validation/fixtures/DirectConvolutionLayerFixture.h index a709157c7b..e302657158 100644 --- a/tests/validation/fixtures/DirectConvolutionLayerFixture.h +++ b/tests/validation/fixtures/DirectConvolutionLayerFixture.h @@ -41,22 +41,24 @@ namespace test namespace validation { template <typename TensorType, typename AccessorType, typename FunctionType, typename T> -class DirectConvolutionValidationFixedPointFixture : public framework::Fixture +class DirectConvolutionValidationGenericFixture : public framework::Fixture { public: template <typename...> - void setup(TensorShape input_shape, int stride_x, int stride_y, int pad_x, int pad_y, unsigned int kernel_size, unsigned int num_kernels, DataType data_type, int fractional_bits) + void setup(TensorShape input_shape, int stride_x, int stride_y, int pad_x, int pad_y, unsigned int kernel_size, unsigned int num_kernels, + DataType data_type, int fractional_bits, QuantizationInfo quantization_info) { - _fractional_bits = fractional_bits; - _data_type = data_type; + _fractional_bits = fractional_bits; + _quantization_info = quantization_info; + _data_type = data_type; const TensorShape weights_shape(kernel_size, kernel_size, input_shape.z(), num_kernels); const TensorShape bias_shape(num_kernels); const PadStrideInfo info(stride_x, stride_y, pad_x, pad_y, DimensionRoundingType::FLOOR); const TensorShape output_shape = get_output_shape(input_shape, weights_shape, info); - _target = compute_target(input_shape, weights_shape, bias_shape, output_shape, info, data_type, fractional_bits); - _reference = compute_reference(input_shape, weights_shape, bias_shape, output_shape, info, data_type, fractional_bits); + _target = compute_target(input_shape, weights_shape, bias_shape, output_shape, info, data_type, fractional_bits, quantization_info); + _reference = compute_reference(input_shape, weights_shape, bias_shape, output_shape, info, data_type, fractional_bits, quantization_info); } protected: @@ -65,6 +67,12 @@ protected: { switch(tensor.data_type()) { + case DataType::QASYMM8: + { + std::uniform_int_distribution<uint8_t> distribution(0, 10); + library->fill(tensor, distribution, i); + break; + } case DataType::F16: case DataType::F32: { @@ -78,13 +86,13 @@ protected: } TensorType compute_target(const TensorShape &input_shape, const TensorShape &weights_shape, const TensorShape &bias_shape, const TensorShape &output_shape, const PadStrideInfo &info, - DataType data_type, int fixed_point_position) + DataType data_type, int fixed_point_position, QuantizationInfo quantization_info) { // Create tensors - TensorType src = create_tensor<TensorType>(input_shape, data_type, 1, fixed_point_position); - TensorType weights = create_tensor<TensorType>(weights_shape, data_type, 1, fixed_point_position); - TensorType bias = create_tensor<TensorType>(bias_shape, data_type, 1, fixed_point_position); - TensorType dst = create_tensor<TensorType>(output_shape, data_type, 1, fixed_point_position); + TensorType src = create_tensor<TensorType>(input_shape, data_type, 1, fixed_point_position, quantization_info); + TensorType weights = create_tensor<TensorType>(weights_shape, data_type, 1, fixed_point_position, quantization_info); + TensorType bias = create_tensor<TensorType>(bias_shape, data_type, 1, fixed_point_position, quantization_info); + TensorType dst = create_tensor<TensorType>(output_shape, data_type, 1, fixed_point_position, quantization_info); // Create and configure function FunctionType conv; @@ -118,12 +126,12 @@ protected: } SimpleTensor<T> compute_reference(const TensorShape &input_shape, const TensorShape &weights_shape, const TensorShape &bias_shape, const TensorShape &output_shape, const PadStrideInfo &info, - DataType data_type, int fixed_point_position) + DataType data_type, int fixed_point_position, QuantizationInfo quantization_info) { // Create reference - SimpleTensor<T> src{ input_shape, data_type, 1, fixed_point_position }; - SimpleTensor<T> weights{ weights_shape, data_type, 1, fixed_point_position }; - SimpleTensor<T> bias{ bias_shape, data_type, 1, fixed_point_position }; + SimpleTensor<T> src{ input_shape, data_type, 1, fixed_point_position, quantization_info }; + SimpleTensor<T> weights{ weights_shape, data_type, 1, fixed_point_position, quantization_info }; + SimpleTensor<T> bias{ bias_shape, data_type, 1, fixed_point_position, quantization_info }; // Fill reference fill(src, 0); @@ -133,10 +141,11 @@ protected: return reference::convolution_layer<T>(src, weights, bias, output_shape, info); } - TensorType _target{}; - SimpleTensor<T> _reference{}; - int _fractional_bits{}; - DataType _data_type{}; + TensorType _target{}; + SimpleTensor<T> _reference{}; + int _fractional_bits{}; + QuantizationInfo _quantization_info{}; + DataType _data_type{}; private: TensorShape get_output_shape(TensorShape in_shape, TensorShape kernel_shape, const PadStrideInfo &info) @@ -155,15 +164,39 @@ private: }; template <typename TensorType, typename AccessorType, typename FunctionType, typename T> -class DirectConvolutionValidationFixture : public DirectConvolutionValidationFixedPointFixture<TensorType, AccessorType, FunctionType, T> +class DirectConvolutionValidationFixture : public DirectConvolutionValidationGenericFixture<TensorType, AccessorType, FunctionType, T> { public: template <typename...> void setup(TensorShape input_shape, int stride_x, int stride_y, int pad_x, int pad_y, unsigned int kernel_size, unsigned int num_kernels, DataType data_type) { - DirectConvolutionValidationFixedPointFixture<TensorType, AccessorType, FunctionType, T>::setup(input_shape, stride_x, stride_y, pad_x, pad_y, kernel_size, num_kernels, data_type, 0); + DirectConvolutionValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(input_shape, stride_x, stride_y, pad_x, pad_y, kernel_size, num_kernels, data_type, 0, QuantizationInfo()); } }; + +template <typename TensorType, typename AccessorType, typename FunctionType, typename T> +class DirectConvolutionValidationFixedPointFixture : public DirectConvolutionValidationGenericFixture<TensorType, AccessorType, FunctionType, T> +{ +public: + template <typename...> + void setup(TensorShape input_shape, int stride_x, int stride_y, int pad_x, int pad_y, unsigned int kernel_size, unsigned int num_kernels, DataType data_type, int fractional_bits) + { + DirectConvolutionValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(input_shape, stride_x, stride_y, pad_x, pad_y, kernel_size, num_kernels, data_type, fractional_bits, + QuantizationInfo()); + } +}; + +template <typename TensorType, typename AccessorType, typename FunctionType, typename T> +class DirectConvolutionValidationQuantizedFixture : public DirectConvolutionValidationGenericFixture<TensorType, AccessorType, FunctionType, T> +{ +public: + template <typename...> + void setup(TensorShape input_shape, int stride_x, int stride_y, int pad_x, int pad_y, unsigned int kernel_size, unsigned int num_kernels, DataType data_type, QuantizationInfo quantization_info) + { + DirectConvolutionValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(input_shape, stride_x, stride_y, pad_x, pad_y, kernel_size, num_kernels, data_type, 0, quantization_info); + } +}; + } // namespace validation } // namespace test } // namespace arm_compute |