author    Georgios Pinitas <georgios.pinitas@arm.com>    2017-11-17 10:55:00 +0000
committer Anthony Barbier <anthony.barbier@arm.com>      2018-11-02 16:35:24 +0000
commit    540d008180a39a84770e51b4bc891cd2ff85980e (patch)
tree      500faadfafde7b92bdf7cb46c01501d2db01728f /tests/validation/fixtures/DirectConvolutionLayerFixture.h
parent    43ce898014faabbe8ab1cdf714b7aad3d3c9b2a9 (diff)
COMPMID-556: Fixes bias in CLDirectConvolutionLayer to be int32.
Biases were incorrectly passed as uchar8 in asymmetric quantized
calculations while they should be in int32.

Change-Id: I461f5e4ef6eb44374c1094378a1bfe9ade5e247d
Reviewed-on: http://mpd-gerrit.cambridge.arm.com/96244
Tested-by: Kaizen <jeremy.johnson+kaizengerrit@arm.com>
Reviewed-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
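The diff below applies this by mapping the fixture's bias element type with std::conditional and the bias tensor's DataType with is_data_type_quantized_asymmetric. As background on why an int32 bias is required, here is a minimal sketch (the function name and signature are illustrative, not ComputeLibrary API): products of two uint8 operands are widened to int32 before accumulation, so the bias must be added in the same int32 accumulator domain rather than as a uint8 value.

// Minimal sketch, not ComputeLibrary code: asymmetric quantized (QASYMM8)
// arithmetic accumulates widened int32 products, so the bias must be int32.
#include <cstddef>
#include <cstdint>

int32_t quantized_dot_plus_bias(const uint8_t *in, int32_t in_offset,
                                const uint8_t *wgt, int32_t wgt_offset,
                                size_t n, int32_t bias)
{
    int32_t acc = bias; // the int32 bias seeds the int32 accumulator
    for(size_t i = 0; i < n; ++i)
    {
        acc += (static_cast<int32_t>(in[i]) - in_offset) * (static_cast<int32_t>(wgt[i]) - wgt_offset);
    }
    return acc; // requantized back to uint8 in a later stage
}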
Diffstat (limited to 'tests/validation/fixtures/DirectConvolutionLayerFixture.h')
-rw-r--r--   tests/validation/fixtures/DirectConvolutionLayerFixture.h | 22 +++++++++++++---------
1 file changed, 13 insertions(+), 9 deletions(-)
diff --git a/tests/validation/fixtures/DirectConvolutionLayerFixture.h b/tests/validation/fixtures/DirectConvolutionLayerFixture.h
index e302657158..279a4897eb 100644
--- a/tests/validation/fixtures/DirectConvolutionLayerFixture.h
+++ b/tests/validation/fixtures/DirectConvolutionLayerFixture.h
@@ -44,6 +44,9 @@ template <typename TensorType, typename AccessorType, typename FunctionType, typ
class DirectConvolutionValidationGenericFixture : public framework::Fixture
{
public:
+ using TBias = typename std::conditional<std::is_same<typename std::decay<T>::type, uint8_t>::value, int32_t, T>::type;
+
+public:
template <typename...>
void setup(TensorShape input_shape, int stride_x, int stride_y, int pad_x, int pad_y, unsigned int kernel_size, unsigned int num_kernels,
DataType data_type, int fractional_bits, QuantizationInfo quantization_info)
@@ -55,10 +58,11 @@ public:
const TensorShape weights_shape(kernel_size, kernel_size, input_shape.z(), num_kernels);
const TensorShape bias_shape(num_kernels);
const PadStrideInfo info(stride_x, stride_y, pad_x, pad_y, DimensionRoundingType::FLOOR);
- const TensorShape output_shape = get_output_shape(input_shape, weights_shape, info);
+ const TensorShape output_shape = get_output_shape(input_shape, weights_shape, info);
+ const DataType bias_data_type = is_data_type_quantized_asymmetric(data_type) ? DataType::S32 : data_type;
- _target = compute_target(input_shape, weights_shape, bias_shape, output_shape, info, data_type, fractional_bits, quantization_info);
- _reference = compute_reference(input_shape, weights_shape, bias_shape, output_shape, info, data_type, fractional_bits, quantization_info);
+ _target = compute_target(input_shape, weights_shape, bias_shape, output_shape, info, data_type, bias_data_type, fractional_bits, quantization_info);
+ _reference = compute_reference(input_shape, weights_shape, bias_shape, output_shape, info, data_type, bias_data_type, fractional_bits, quantization_info);
}
protected:
@@ -86,12 +90,12 @@ protected:
}
TensorType compute_target(const TensorShape &input_shape, const TensorShape &weights_shape, const TensorShape &bias_shape, const TensorShape &output_shape, const PadStrideInfo &info,
- DataType data_type, int fixed_point_position, QuantizationInfo quantization_info)
+ DataType data_type, DataType bias_data_type, int fixed_point_position, QuantizationInfo quantization_info)
{
// Create tensors
TensorType src = create_tensor<TensorType>(input_shape, data_type, 1, fixed_point_position, quantization_info);
TensorType weights = create_tensor<TensorType>(weights_shape, data_type, 1, fixed_point_position, quantization_info);
- TensorType bias = create_tensor<TensorType>(bias_shape, data_type, 1, fixed_point_position, quantization_info);
+ TensorType bias = create_tensor<TensorType>(bias_shape, bias_data_type, 1, fixed_point_position, quantization_info);
TensorType dst = create_tensor<TensorType>(output_shape, data_type, 1, fixed_point_position, quantization_info);
// Create and configure function
@@ -126,12 +130,12 @@ protected:
}
SimpleTensor<T> compute_reference(const TensorShape &input_shape, const TensorShape &weights_shape, const TensorShape &bias_shape, const TensorShape &output_shape, const PadStrideInfo &info,
- DataType data_type, int fixed_point_position, QuantizationInfo quantization_info)
+ DataType data_type, DataType bias_data_type, int fixed_point_position, QuantizationInfo quantization_info)
{
// Create reference
- SimpleTensor<T> src{ input_shape, data_type, 1, fixed_point_position, quantization_info };
- SimpleTensor<T> weights{ weights_shape, data_type, 1, fixed_point_position, quantization_info };
- SimpleTensor<T> bias{ bias_shape, data_type, 1, fixed_point_position, quantization_info };
+ SimpleTensor<T> src{ input_shape, data_type, 1, fixed_point_position, quantization_info };
+ SimpleTensor<T> weights{ weights_shape, data_type, 1, fixed_point_position, quantization_info };
+ SimpleTensor<TBias> bias{ bias_shape, bias_data_type, 1, fixed_point_position, quantization_info };
// Fill reference
fill(src, 0);
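
The TBias alias introduced at the top of this patch can be exercised in isolation. A minimal standalone sketch (rewritten here as a template alias purely for illustration, outside the fixture class) showing the uint8_t to int32_t mapping:

// Standalone sketch of the TBias mapping added above: uint8_t element types
// get an int32_t bias, every other element type keeps its own type.
#include <cstdint>
#include <type_traits>

template <typename T>
using TBias = typename std::conditional<
    std::is_same<typename std::decay<T>::type, uint8_t>::value, int32_t, T>::type;

static_assert(std::is_same<TBias<uint8_t>, int32_t>::value, "QASYMM8 bias is S32");
static_assert(std::is_same<TBias<float>, float>::value, "FP32 bias stays FP32");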