From 45bcc3a1c287a208098ae99288273a5129ddd5eb Mon Sep 17 00:00:00 2001 From: Georgios Pinitas Date: Wed, 29 Nov 2017 11:06:49 +0000 Subject: COMPMID-661: QASYMM8 support for fully connected layer. Change-Id: I70e04d3a175ba366432ada98e9ca893c9f81b260 Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/111094 Reviewed-by: Gian Marco Iodice Tested-by: BSG Visual Compute Jenkins server to access repositories on http://mpd-gerrit.cambridge.arm.com Reviewed-by: Anthony Barbier --- .../fixtures/FullyConnectedLayerFixture.h | 97 ++++++++++++++++------ 1 file changed, 71 insertions(+), 26 deletions(-) (limited to 'tests/validation/fixtures/FullyConnectedLayerFixture.h') diff --git a/tests/validation/fixtures/FullyConnectedLayerFixture.h b/tests/validation/fixtures/FullyConnectedLayerFixture.h index b19c40d5ea..dba20bb375 100644 --- a/tests/validation/fixtures/FullyConnectedLayerFixture.h +++ b/tests/validation/fixtures/FullyConnectedLayerFixture.h @@ -46,27 +46,43 @@ namespace test namespace validation { template -class FullyConnectedLayerValidationFixedPointFixture : public framework::Fixture +class FullyConnectedLayerValidationGenericFixture : public framework::Fixture { +public: + using TBias = typename std::conditional::type, uint8_t>::value, int32_t, T>::type; + public: template - void setup(TensorShape input_shape, TensorShape weights_shape, TensorShape bias_shape, TensorShape output_shape, bool transpose_weights, bool reshape_weights, DataType data_type, int fractional_bits) + void setup(TensorShape input_shape, TensorShape weights_shape, TensorShape bias_shape, TensorShape output_shape, bool transpose_weights, bool reshape_weights, + DataType data_type, int fractional_bits, QuantizationInfo quantization_info) { ARM_COMPUTE_UNUSED(weights_shape); ARM_COMPUTE_UNUSED(bias_shape); - _fractional_bits = fractional_bits; - _data_type = data_type; + _data_type = data_type; + _bias_data_type = is_data_type_quantized_asymmetric(data_type) ? DataType::S32 : data_type; + _fractional_bits = fractional_bits; + _quantization_info = quantization_info; - _target = compute_target(input_shape, weights_shape, bias_shape, output_shape, transpose_weights, reshape_weights, data_type, fractional_bits); - _reference = compute_reference(input_shape, weights_shape, bias_shape, output_shape, transpose_weights, reshape_weights, data_type, fractional_bits); + _target = compute_target(input_shape, weights_shape, bias_shape, output_shape, transpose_weights, reshape_weights); + _reference = compute_reference(input_shape, weights_shape, bias_shape, output_shape, transpose_weights, reshape_weights); } protected: template void fill(U &&tensor, int i) { - if(is_data_type_float(_data_type)) + if(is_data_type_quantized_asymmetric(_data_type)) + { + std::uniform_int_distribution distribution(0, 30); + library->fill(tensor, distribution, i); + } + else if(_data_type == DataType::S32) + { + std::uniform_int_distribution distribution(-50, 50); + library->fill(tensor, distribution, i); + } + else if(is_data_type_float(_data_type)) { std::uniform_real_distribution<> distribution(0.5f, 1.f); library->fill(tensor, distribution, i); @@ -78,7 +94,7 @@ protected: } TensorType compute_target(const TensorShape &input_shape, const TensorShape &weights_shape, const TensorShape &bias_shape, const TensorShape &output_shape, bool transpose_weights, - bool reshape_weights, DataType data_type, int fixed_point_position) + bool reshape_weights) { TensorShape reshaped_weights_shape(weights_shape); @@ -102,7 +118,7 @@ protected: // Transpose 1xW for batched version if(!reshape_weights && output_shape.y() > 1 && run_interleave) { - const int transpose_width = 16 / data_size_from_type(data_type); + const int transpose_width = 16 / data_size_from_type(_data_type); const float shape_x = reshaped_weights_shape.x(); reshaped_weights_shape.set(0, reshaped_weights_shape.y() * transpose_width); reshaped_weights_shape.set(1, static_cast(std::ceil(shape_x / transpose_width))); @@ -110,10 +126,10 @@ protected: } // Create tensors - TensorType src = create_tensor(input_shape, data_type, 1, fixed_point_position); - TensorType weights = create_tensor(reshaped_weights_shape, data_type, 1, fixed_point_position); - TensorType bias = create_tensor(bias_shape, data_type, 1, fixed_point_position); - TensorType dst = create_tensor(output_shape, data_type, 1, fixed_point_position); + TensorType src = create_tensor(input_shape, _data_type, 1, _fractional_bits, _quantization_info); + TensorType weights = create_tensor(reshaped_weights_shape, _data_type, 1, _fractional_bits, _quantization_info); + TensorType bias = create_tensor(bias_shape, _bias_data_type, 1, _fractional_bits, _quantization_info); + TensorType dst = create_tensor(output_shape, _data_type, 1, _fractional_bits, _quantization_info); // Create and configure function. FunctionType fc; @@ -142,7 +158,7 @@ protected: if(!reshape_weights || !transpose_weights) { TensorShape tmp_shape(weights_shape); - RawTensor tmp(tmp_shape, data_type, 1, fixed_point_position); + RawTensor tmp(tmp_shape, _data_type, 1, _fractional_bits); // Fill with original shape fill(tmp, 1); @@ -180,12 +196,12 @@ protected: } SimpleTensor compute_reference(const TensorShape &input_shape, const TensorShape &weights_shape, const TensorShape &bias_shape, const TensorShape &output_shape, bool transpose_weights, - bool reshape_weights, DataType data_type, int fixed_point_position = 0) + bool reshape_weights) { // Create reference - SimpleTensor src{ input_shape, data_type, 1, fixed_point_position }; - SimpleTensor weights{ weights_shape, data_type, 1, fixed_point_position }; - SimpleTensor bias{ bias_shape, data_type, 1, fixed_point_position }; + SimpleTensor src{ input_shape, _data_type, 1, _fractional_bits, _quantization_info }; + SimpleTensor weights{ weights_shape, _data_type, 1, _fractional_bits, _quantization_info }; + SimpleTensor bias{ bias_shape, _bias_data_type, 1, _fractional_bits, _quantization_info }; // Fill reference fill(src, 0); @@ -195,22 +211,51 @@ protected: return reference::fully_connected_layer(src, weights, bias, output_shape); } - TensorType _target{}; - SimpleTensor _reference{}; - int _fractional_bits{}; - DataType _data_type{}; + TensorType _target{}; + SimpleTensor _reference{}; + DataType _data_type{}; + DataType _bias_data_type{}; + int _fractional_bits{}; + QuantizationInfo _quantization_info{}; }; template -class FullyConnectedLayerValidationFixture : public FullyConnectedLayerValidationFixedPointFixture +class FullyConnectedLayerValidationFixture : public FullyConnectedLayerValidationGenericFixture { public: template void setup(TensorShape input_shape, TensorShape weights_shape, TensorShape bias_shape, TensorShape output_shape, bool transpose_weights, bool reshape_weights, DataType data_type) { - FullyConnectedLayerValidationFixedPointFixture::setup(input_shape, weights_shape, bias_shape, output_shape, transpose_weights, - reshape_weights, data_type, - 0); + FullyConnectedLayerValidationGenericFixture::setup(input_shape, weights_shape, bias_shape, output_shape, transpose_weights, + reshape_weights, data_type, + 0, QuantizationInfo()); + } +}; + +template +class FullyConnectedLayerValidationFixedPointFixture : public FullyConnectedLayerValidationGenericFixture +{ +public: + template + void setup(TensorShape input_shape, TensorShape weights_shape, TensorShape bias_shape, TensorShape output_shape, bool transpose_weights, bool reshape_weights, DataType data_type, int fractional_bits) + { + FullyConnectedLayerValidationGenericFixture::setup(input_shape, weights_shape, bias_shape, output_shape, transpose_weights, + reshape_weights, data_type, + fractional_bits, QuantizationInfo()); + } +}; + +template +class FullyConnectedLayerValidationQuantizedFixture : public FullyConnectedLayerValidationGenericFixture +{ +public: + template + void setup(TensorShape input_shape, TensorShape weights_shape, TensorShape bias_shape, TensorShape output_shape, bool transpose_weights, bool reshape_weights, DataType data_type, + QuantizationInfo quantization_info) + { + FullyConnectedLayerValidationGenericFixture::setup(input_shape, weights_shape, bias_shape, output_shape, transpose_weights, + reshape_weights, data_type, + 0, quantization_info); } }; } // namespace validation -- cgit v1.2.1