diff options
author | Georgios Pinitas <georgios.pinitas@arm.com> | 2019-10-16 19:21:40 +0100 |
---|---|---|
committer | Michele Di Giorgio <michele.digiorgio@arm.com> | 2019-11-06 15:53:24 +0000 |
commit | dbdea0d1c025b18d4d82c278c87454427918f5b4 (patch) | |
tree | 68bc25452f5d5b41006fb507c41516446cf8e457 /tests/validation/fixtures | |
parent | 75d47330e7ca0325cf5d83711452f6aeb085998f (diff) | |
download | ComputeLibrary-dbdea0d1c025b18d4d82c278c87454427918f5b4.tar.gz |
COMPMID-2308: NEConvolutionLayer: support QUANT8_SYMM_PER_CHANNEL filters
Change-Id: Ic1bf5f0d21ccd525f84213a360f7e199d7f50577
Signed-off-by: Georgios Pinitas <georgios.pinitas@arm.com>
Reviewed-on: https://review.mlplatform.org/c/2177
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Michele Di Giorgio <michele.digiorgio@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'tests/validation/fixtures')
-rw-r--r-- | tests/validation/fixtures/ConvolutionLayerFixture.h | 77 |
1 files changed, 61 insertions, 16 deletions
diff --git a/tests/validation/fixtures/ConvolutionLayerFixture.h b/tests/validation/fixtures/ConvolutionLayerFixture.h index 52fa8da60b..c5cddc28db 100644 --- a/tests/validation/fixtures/ConvolutionLayerFixture.h +++ b/tests/validation/fixtures/ConvolutionLayerFixture.h @@ -48,7 +48,7 @@ namespace test { namespace validation { -template <typename TensorType, typename AccessorType, typename FunctionType, typename T> +template <typename TensorType, typename AccessorType, typename FunctionType, typename T, typename TW> class ConvolutionValidationGenericFixture : public framework::Fixture { public: @@ -57,13 +57,15 @@ public: public: template <typename...> void setup(TensorShape input_shape, TensorShape weights_shape, TensorShape bias_shape, TensorShape output_shape, PadStrideInfo info, Size2D dilation, bool reshape_weights, - DataType data_type, DataLayout data_layout, QuantizationInfo quantization_info, ActivationLayerInfo act_info) + DataType data_type, DataType weights_data_type, DataLayout data_layout, QuantizationInfo quantization_info, QuantizationInfo weight_quantization_info, ActivationLayerInfo act_info) { - _data_type = data_type; - _is_quantized = is_data_type_quantized_asymmetric(data_type); - _bias_data_type = _is_quantized ? DataType::S32 : data_type; - _quantization_info = quantization_info; - _data_layout = data_layout; + _data_type = data_type; + _weights_data_type = weights_data_type; + _is_quantized = is_data_type_quantized_asymmetric(data_type); + _bias_data_type = _is_quantized ? DataType::S32 : data_type; + _quantization_info = quantization_info; + _weight_quantization_info = weight_quantization_info; + _data_layout = data_layout; _target = compute_target(input_shape, weights_shape, bias_shape, output_shape, info, reshape_weights, dilation, act_info); _reference = compute_reference(input_shape, weights_shape, bias_shape, output_shape, info, dilation, act_info); @@ -82,6 +84,26 @@ protected: library->fill(tensor, distribution, i); break; } + case DataType::QSYMM8_PER_CHANNEL: + { + int min_bound = 128; + int max_bound = -127; + for(size_t i = 0; i < _weight_quantization_info.scale().size(); i++) + { + std::pair<int, int> bounds = get_symm_quantized_per_channel_bounds(tensor.quantization_info(), -1.0f, 1.0f, i); + if(bounds.first < min_bound) + { + min_bound = bounds.first; + } + if(bounds.second > max_bound) + { + max_bound = bounds.second; + } + } + std::uniform_int_distribution<int8_t> distribution(min_bound, max_bound); + library->fill(tensor, distribution, i); + break; + } case DataType::S32: { std::uniform_int_distribution<int32_t> distribution(-100, 100); @@ -122,7 +144,7 @@ protected: // Create tensors TensorType src = create_tensor<TensorType>(input_shape, _data_type, 1, _quantization_info, _data_layout); - TensorType weights = create_tensor<TensorType>(reshaped_weights_shape, _data_type, 1, _quantization_info, _data_layout); + TensorType weights = create_tensor<TensorType>(reshaped_weights_shape, _weights_data_type, 1, _weight_quantization_info, _data_layout); TensorType bias = create_tensor<TensorType>(bias_shape, _bias_data_type, 1, _quantization_info, _data_layout); TensorType dst = create_tensor<TensorType>(output_shape, _data_type, 1, _quantization_info, _data_layout); @@ -166,7 +188,7 @@ protected: // Create reference SimpleTensor<T> src{ input_shape, _data_type, 1, _quantization_info }; - SimpleTensor<T> weights{ weights_shape, _data_type, 1, _quantization_info }; + SimpleTensor<TW> weights{ weights_shape, _weights_data_type, 1, _weight_quantization_info }; SimpleTensor<TBias> bias{ bias_shape, _bias_data_type, 1, _quantization_info }; // Fill reference @@ -182,36 +204,59 @@ protected: TensorType _target{}; SimpleTensor<T> _reference{}; DataType _data_type{}; + DataType _weights_data_type{}; DataType _bias_data_type{}; DataLayout _data_layout{}; QuantizationInfo _quantization_info{}; + QuantizationInfo _weight_quantization_info{}; bool _is_quantized = false; }; template <typename TensorType, typename AccessorType, typename FunctionType, typename T> -class ConvolutionValidationFixture : public ConvolutionValidationGenericFixture<TensorType, AccessorType, FunctionType, T> +class ConvolutionValidationFixture : public ConvolutionValidationGenericFixture<TensorType, AccessorType, FunctionType, T, T> { public: template <typename...> void setup(TensorShape input_shape, TensorShape weights_shape, TensorShape bias_shape, TensorShape output_shape, PadStrideInfo info, Size2D dilation, bool reshape_weights, DataType data_type, DataLayout data_layout, ActivationLayerInfo act_info) { - ConvolutionValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(input_shape, weights_shape, bias_shape, output_shape, info, dilation, reshape_weights, - data_type, data_layout, - QuantizationInfo(), act_info); + ConvolutionValidationGenericFixture<TensorType, AccessorType, FunctionType, T, T>::setup(input_shape, weights_shape, bias_shape, output_shape, info, dilation, reshape_weights, + data_type, data_type, data_layout, + QuantizationInfo(), QuantizationInfo(), act_info); } }; template <typename TensorType, typename AccessorType, typename FunctionType, typename T> -class ConvolutionValidationQuantizedFixture : public ConvolutionValidationGenericFixture<TensorType, AccessorType, FunctionType, T> +class ConvolutionValidationQuantizedFixture : public ConvolutionValidationGenericFixture<TensorType, AccessorType, FunctionType, T, T> { public: template <typename...> void setup(TensorShape input_shape, TensorShape weights_shape, TensorShape bias_shape, TensorShape output_shape, PadStrideInfo info, Size2D dilation, bool reshape_weights, DataType data_type, DataLayout data_layout, QuantizationInfo quantization_info, ActivationLayerInfo act_info) { - ConvolutionValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(input_shape, weights_shape, bias_shape, output_shape, info, dilation, reshape_weights, - data_type, data_layout, quantization_info, act_info); + ConvolutionValidationGenericFixture<TensorType, AccessorType, FunctionType, T, T>::setup(input_shape, weights_shape, bias_shape, output_shape, info, dilation, reshape_weights, + data_type, data_type, data_layout, quantization_info, quantization_info, act_info); + } +}; + +template <typename TensorType, typename AccessorType, typename FunctionType, typename T, typename TW> +class ConvolutionValidationQuantizedPerChannelFixture : public ConvolutionValidationGenericFixture<TensorType, AccessorType, FunctionType, T, TW> +{ +public: + template <typename...> + void setup(TensorShape input_shape, TensorShape weights_shape, TensorShape bias_shape, TensorShape output_shape, PadStrideInfo info, Size2D dilation, bool reshape_weights, DataType data_type, + DataLayout data_layout, QuantizationInfo quantization_info, ActivationLayerInfo act_info, DataType weights_data_type) + { + std::vector<float> weights_scales{}; + std::mt19937 gen(library->seed()); + std::uniform_real_distribution<> dis(0.01f, 1); + for(size_t i = 0; i < output_shape[2]; ++i) + { + weights_scales.push_back(dis(gen)); + } + ConvolutionValidationGenericFixture<TensorType, AccessorType, FunctionType, T, TW>::setup(input_shape, weights_shape, bias_shape, output_shape, info, dilation, + reshape_weights, data_type, weights_data_type, data_layout, + quantization_info, QuantizationInfo(weights_scales), act_info); } }; } // namespace validation |