diff options
Diffstat (limited to 'tests/validation/fixtures/DirectConvolutionLayerFixture.h')
-rw-r--r-- | tests/validation/fixtures/DirectConvolutionLayerFixture.h | 22 |
1 file changed, 13 insertions, 9 deletions
diff --git a/tests/validation/fixtures/DirectConvolutionLayerFixture.h b/tests/validation/fixtures/DirectConvolutionLayerFixture.h
index e302657158..279a4897eb 100644
--- a/tests/validation/fixtures/DirectConvolutionLayerFixture.h
+++ b/tests/validation/fixtures/DirectConvolutionLayerFixture.h
@@ -44,6 +44,9 @@ template <typename TensorType, typename AccessorType, typename FunctionType, typ
 class DirectConvolutionValidationGenericFixture : public framework::Fixture
 {
 public:
+    using TBias = typename std::conditional<std::is_same<typename std::decay<T>::type, uint8_t>::value, int32_t, T>::type;
+
+public:
     template <typename...>
     void setup(TensorShape input_shape, int stride_x, int stride_y, int pad_x, int pad_y, unsigned int kernel_size, unsigned int num_kernels,
                DataType data_type, int fractional_bits, QuantizationInfo quantization_info)
@@ -55,10 +58,11 @@ public:
         const TensorShape   weights_shape(kernel_size, kernel_size, input_shape.z(), num_kernels);
         const TensorShape   bias_shape(num_kernels);
         const PadStrideInfo info(stride_x, stride_y, pad_x, pad_y, DimensionRoundingType::FLOOR);
-        const TensorShape output_shape = get_output_shape(input_shape, weights_shape, info);
+        const TensorShape output_shape   = get_output_shape(input_shape, weights_shape, info);
+        const DataType    bias_data_type = is_data_type_quantized_asymmetric(data_type) ? DataType::S32 : data_type;
 
-        _target    = compute_target(input_shape, weights_shape, bias_shape, output_shape, info, data_type, fractional_bits, quantization_info);
-        _reference = compute_reference(input_shape, weights_shape, bias_shape, output_shape, info, data_type, fractional_bits, quantization_info);
+        _target    = compute_target(input_shape, weights_shape, bias_shape, output_shape, info, data_type, bias_data_type, fractional_bits, quantization_info);
+        _reference = compute_reference(input_shape, weights_shape, bias_shape, output_shape, info, data_type, bias_data_type, fractional_bits, quantization_info);
     }
 
 protected:
@@ -86,12 +90,12 @@ protected:
     }
 
     TensorType compute_target(const TensorShape &input_shape, const TensorShape &weights_shape, const TensorShape &bias_shape, const TensorShape &output_shape, const PadStrideInfo &info,
-                              DataType data_type, int fixed_point_position, QuantizationInfo quantization_info)
+                              DataType data_type, DataType bias_data_type, int fixed_point_position, QuantizationInfo quantization_info)
     {
         // Create tensors
         TensorType src     = create_tensor<TensorType>(input_shape, data_type, 1, fixed_point_position, quantization_info);
         TensorType weights = create_tensor<TensorType>(weights_shape, data_type, 1, fixed_point_position, quantization_info);
-        TensorType bias    = create_tensor<TensorType>(bias_shape, data_type, 1, fixed_point_position, quantization_info);
+        TensorType bias    = create_tensor<TensorType>(bias_shape, bias_data_type, 1, fixed_point_position, quantization_info);
         TensorType dst     = create_tensor<TensorType>(output_shape, data_type, 1, fixed_point_position, quantization_info);
 
         // Create and configure function
@@ -126,12 +130,12 @@ protected:
     }
 
     SimpleTensor<T> compute_reference(const TensorShape &input_shape, const TensorShape &weights_shape, const TensorShape &bias_shape, const TensorShape &output_shape, const PadStrideInfo &info,
-                                      DataType data_type, int fixed_point_position, QuantizationInfo quantization_info)
+                                      DataType data_type, DataType bias_data_type, int fixed_point_position, QuantizationInfo quantization_info)
     {
         // Create reference
-        SimpleTensor<T> src{ input_shape, data_type, 1, fixed_point_position, quantization_info };
-        SimpleTensor<T> weights{ weights_shape, data_type, 1, fixed_point_position, quantization_info };
-        SimpleTensor<T> bias{ bias_shape, data_type, 1, fixed_point_position, quantization_info };
+        SimpleTensor<T>     src{ input_shape, data_type, 1, fixed_point_position, quantization_info };
+        SimpleTensor<T>     weights{ weights_shape, data_type, 1, fixed_point_position, quantization_info };
+        SimpleTensor<TBias> bias{ bias_shape, bias_data_type, 1, fixed_point_position, quantization_info };
 
         // Fill reference
         fill(src, 0);