diff options
Diffstat (limited to 'tests/validation/fixtures/DepthwiseConvolutionLayerFixture.h')
-rw-r--r-- | tests/validation/fixtures/DepthwiseConvolutionLayerFixture.h | 16 |
1 files changed, 13 insertions, 3 deletions
diff --git a/tests/validation/fixtures/DepthwiseConvolutionLayerFixture.h b/tests/validation/fixtures/DepthwiseConvolutionLayerFixture.h index 85930eb95e..f909885245 100644 --- a/tests/validation/fixtures/DepthwiseConvolutionLayerFixture.h +++ b/tests/validation/fixtures/DepthwiseConvolutionLayerFixture.h @@ -449,12 +449,22 @@ class DepthwiseConvolutionLayerValidationQuantizedPerChannelFixture : public Dep public: template <typename...> void setup(TensorShape in_shape, Size2D kernel_size, PadStrideInfo pad_stride_info, Size2D dilation, unsigned int depth_multiplier, DataType input_data_type, DataType weights_data_type, - QuantizationInfo input_quantization_info, QuantizationInfo weights_quantization_info, QuantizationInfo output_quantization_info, - DataLayout data_layout, ActivationLayerInfo act_info) + QuantizationInfo input_quantization_info, QuantizationInfo output_quantization_info, DataLayout data_layout, ActivationLayerInfo act_info) { + const float out_scale = output_quantization_info.uniform().scale; + const float in_scale = input_quantization_info.uniform().scale; + + std::vector<float> weights_scales{}; + std::mt19937 gen(library->seed()); + std::uniform_real_distribution<> dis(0.01f, out_scale / in_scale); + for(size_t i = 0; i < in_shape.z() * depth_multiplier; ++i) + { + weights_scales.push_back(dis(gen)); + } + DepthwiseConvolutionLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T, TW>::setup(in_shape, kernel_size, pad_stride_info, dilation, depth_multiplier, input_data_type, weights_data_type, - input_quantization_info, weights_quantization_info, output_quantization_info, + input_quantization_info, QuantizationInfo(weights_scales), output_quantization_info, data_layout, act_info); } }; |