From dbdea0d1c025b18d4d82c278c87454427918f5b4 Mon Sep 17 00:00:00 2001 From: Georgios Pinitas Date: Wed, 16 Oct 2019 19:21:40 +0100 Subject: COMPMID-2308: NEConvolutionLayer: support QUANT8_SYMM_PER_CHANNEL filters Change-Id: Ic1bf5f0d21ccd525f84213a360f7e199d7f50577 Signed-off-by: Georgios Pinitas Reviewed-on: https://review.mlplatform.org/c/2177 Tested-by: Arm Jenkins Reviewed-by: Michele Di Giorgio Comments-Addressed: Arm Jenkins --- tests/validation/Helpers.cpp | 18 +++++ tests/validation/Helpers.h | 18 +++++ tests/validation/NEON/ConvolutionLayer.cpp | 37 +++++++++++ .../validation/fixtures/ConvolutionLayerFixture.h | 77 +++++++++++++++++----- tests/validation/reference/Convolution3d.h | 51 +++++++++----- tests/validation/reference/ConvolutionLayer.cpp | 20 +++--- tests/validation/reference/ConvolutionLayer.h | 4 +- 7 files changed, 180 insertions(+), 45 deletions(-) (limited to 'tests/validation') diff --git a/tests/validation/Helpers.cpp b/tests/validation/Helpers.cpp index 4158793295..95a5548628 100644 --- a/tests/validation/Helpers.cpp +++ b/tests/validation/Helpers.cpp @@ -326,6 +326,24 @@ std::pair get_quantized_bounds(const QuantizationInfo &quant_info, flo return std::pair { min_bound, max_bound }; } +std::pair get_symm_quantized_per_channel_bounds(const QuantizationInfo &quant_info, float min, float max, size_t channel_id) +{ + ARM_COMPUTE_ERROR_ON_MSG(min > max, "min must be lower equal than max"); + + const int min_bound = quantize_qsymm8_per_channel(min, quant_info, channel_id); + const int max_bound = quantize_qsymm8_per_channel(max, quant_info, channel_id); + return std::pair { min_bound, max_bound }; +} + +std::pair get_asymm_quantized_per_channel_bounds(const QuantizationInfo &quant_info, float min, float max, size_t channel_id) +{ + ARM_COMPUTE_ERROR_ON_MSG(min > max, "min must be lower equal than max"); + + const int min_bound = quantize_qasymm8_per_channel(min, quant_info, channel_id); + const int max_bound = quantize_qasymm8_per_channel(max, quant_info, channel_id); + return std::pair { min_bound, max_bound }; +} + template void get_tile(const SimpleTensor &in, SimpleTensor &roi, const Coordinates &coord); template void get_tile(const SimpleTensor &in, SimpleTensor &roi, const Coordinates &coord); template void get_tile(const SimpleTensor &in, SimpleTensor &roi, const Coordinates &coord); diff --git a/tests/validation/Helpers.h b/tests/validation/Helpers.h index 2ee2dc7aab..2c1df39f14 100644 --- a/tests/validation/Helpers.h +++ b/tests/validation/Helpers.h @@ -276,6 +276,24 @@ void zeros(SimpleTensor &in, const Coordinates &anchor, const TensorShape &sh * @param[in] max Floating point maximum value to be quantized */ std::pair get_quantized_bounds(const QuantizationInfo &quant_info, float min, float max); + +/** Helper function to compute symmetric quantized min and max bounds + * + * @param[in] quant_info Quantization info to be used for conversion + * @param[in] min Floating point minimum value to be quantized + * @param[in] max Floating point maximum value to be quantized + * @param[in] channel_id Channel id for per channel quantization info. + */ +std::pair get_symm_quantized_per_channel_bounds(const QuantizationInfo &quant_info, float min, float max, size_t channel_id = 0); + +/** Helper function to compute asymmetric quantized min and max bounds + * + * @param[in] quant_info Quantization info to be used for conversion + * @param[in] min Floating point minimum value to be quantized + * @param[in] max Floating point maximum value to be quantized + * @param[in] channel_id Channel id for per channel quantization info. + */ +std::pair get_asymm_quantized_per_channel_bounds(const QuantizationInfo &quant_info, float min, float max, size_t channel_id = 0); } // namespace validation } // namespace test } // namespace arm_compute diff --git a/tests/validation/NEON/ConvolutionLayer.cpp b/tests/validation/NEON/ConvolutionLayer.cpp index ceecd58058..df52d8065b 100644 --- a/tests/validation/NEON/ConvolutionLayer.cpp +++ b/tests/validation/NEON/ConvolutionLayer.cpp @@ -74,6 +74,13 @@ const auto ActivationFunctionsDataset = framework::dataset::make("ActivationInfo ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU), ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::BOUNDED_RELU, 0.5f) }); + +const auto QuantizationData = framework::dataset::make("QuantizationInfo", +{ + QuantizationInfo(0.5f, 10), + QuantizationInfo(0.3f, 3), + QuantizationInfo(1.f, 10), +}); } // namespace TEST_SUITE(NEON) @@ -422,6 +429,9 @@ TEST_SUITE_END() // Float template using NEGEMMConvolutionLayerQuantizedFixture = ConvolutionValidationQuantizedFixture; +template +using NEGEMMConvolutionLayerQuantizedPerChannelFixture = ConvolutionValidationQuantizedPerChannelFixture; + const auto QuantizedActivationFunctionsDataset = framework::dataset::make("ActivationInfo", { ActivationLayerInfo(), @@ -451,6 +461,33 @@ FIXTURE_DATA_TEST_CASE(RunLarge, NEGEMMConvolutionLayerQuantizedFixture validate(Accessor(_target), _reference, tolerance_qasymm8); } TEST_SUITE_END() // QASYMM8 + +TEST_SUITE(QSYMM8_PER_CHANNEL) +FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMConvolutionLayerQuantizedPerChannelFixture, framework::DatasetMode::PRECOMMIT, + combine(combine(combine(combine(combine(combine(datasets::SmallConvolutionLayerReducedDataset(), + framework::dataset::make("ReshapeWeights", { true })), + framework::dataset::make("DataType", { DataType::QASYMM8 })), + framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })), + QuantizationData), + ActivationFunctionsDataset), + framework::dataset::make("WeightsDataType", { DataType::QSYMM8_PER_CHANNEL }))) +{ + // Validate output + validate(Accessor(_target), _reference, tolerance_qasymm8); +} +FIXTURE_DATA_TEST_CASE(RunLarge, NEGEMMConvolutionLayerQuantizedPerChannelFixture, framework::DatasetMode::NIGHTLY, + combine(combine(combine(combine(combine(combine(framework::dataset::concat(datasets::SmallConvolutionLayerDataset(), datasets::LargeConvolutionLayerDataset()), + framework::dataset::make("ReshapeWeights", { true })), + framework::dataset::make("DataType", { DataType::QASYMM8 })), + framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })), + QuantizationData), + QuantizedActivationFunctionsDataset), + framework::dataset::make("WeightsDataType", { DataType::QSYMM8_PER_CHANNEL }))) +{ + // Validate output + validate(Accessor(_target), _reference, tolerance_qasymm8); +} +TEST_SUITE_END() // QSYMM8_PER_CHANNEL TEST_SUITE_END() // Quantized TEST_SUITE_END() // GEMMConvolutionLayer diff --git a/tests/validation/fixtures/ConvolutionLayerFixture.h b/tests/validation/fixtures/ConvolutionLayerFixture.h index 52fa8da60b..c5cddc28db 100644 --- a/tests/validation/fixtures/ConvolutionLayerFixture.h +++ b/tests/validation/fixtures/ConvolutionLayerFixture.h @@ -48,7 +48,7 @@ namespace test { namespace validation { -template +template class ConvolutionValidationGenericFixture : public framework::Fixture { public: @@ -57,13 +57,15 @@ public: public: template void setup(TensorShape input_shape, TensorShape weights_shape, TensorShape bias_shape, TensorShape output_shape, PadStrideInfo info, Size2D dilation, bool reshape_weights, - DataType data_type, DataLayout data_layout, QuantizationInfo quantization_info, ActivationLayerInfo act_info) + DataType data_type, DataType weights_data_type, DataLayout data_layout, QuantizationInfo quantization_info, QuantizationInfo weight_quantization_info, ActivationLayerInfo act_info) { - _data_type = data_type; - _is_quantized = is_data_type_quantized_asymmetric(data_type); - _bias_data_type = _is_quantized ? DataType::S32 : data_type; - _quantization_info = quantization_info; - _data_layout = data_layout; + _data_type = data_type; + _weights_data_type = weights_data_type; + _is_quantized = is_data_type_quantized_asymmetric(data_type); + _bias_data_type = _is_quantized ? DataType::S32 : data_type; + _quantization_info = quantization_info; + _weight_quantization_info = weight_quantization_info; + _data_layout = data_layout; _target = compute_target(input_shape, weights_shape, bias_shape, output_shape, info, reshape_weights, dilation, act_info); _reference = compute_reference(input_shape, weights_shape, bias_shape, output_shape, info, dilation, act_info); @@ -82,6 +84,26 @@ protected: library->fill(tensor, distribution, i); break; } + case DataType::QSYMM8_PER_CHANNEL: + { + int min_bound = 128; + int max_bound = -127; + for(size_t i = 0; i < _weight_quantization_info.scale().size(); i++) + { + std::pair bounds = get_symm_quantized_per_channel_bounds(tensor.quantization_info(), -1.0f, 1.0f, i); + if(bounds.first < min_bound) + { + min_bound = bounds.first; + } + if(bounds.second > max_bound) + { + max_bound = bounds.second; + } + } + std::uniform_int_distribution distribution(min_bound, max_bound); + library->fill(tensor, distribution, i); + break; + } case DataType::S32: { std::uniform_int_distribution distribution(-100, 100); @@ -122,7 +144,7 @@ protected: // Create tensors TensorType src = create_tensor(input_shape, _data_type, 1, _quantization_info, _data_layout); - TensorType weights = create_tensor(reshaped_weights_shape, _data_type, 1, _quantization_info, _data_layout); + TensorType weights = create_tensor(reshaped_weights_shape, _weights_data_type, 1, _weight_quantization_info, _data_layout); TensorType bias = create_tensor(bias_shape, _bias_data_type, 1, _quantization_info, _data_layout); TensorType dst = create_tensor(output_shape, _data_type, 1, _quantization_info, _data_layout); @@ -166,7 +188,7 @@ protected: // Create reference SimpleTensor src{ input_shape, _data_type, 1, _quantization_info }; - SimpleTensor weights{ weights_shape, _data_type, 1, _quantization_info }; + SimpleTensor weights{ weights_shape, _weights_data_type, 1, _weight_quantization_info }; SimpleTensor bias{ bias_shape, _bias_data_type, 1, _quantization_info }; // Fill reference @@ -182,36 +204,59 @@ protected: TensorType _target{}; SimpleTensor _reference{}; DataType _data_type{}; + DataType _weights_data_type{}; DataType _bias_data_type{}; DataLayout _data_layout{}; QuantizationInfo _quantization_info{}; + QuantizationInfo _weight_quantization_info{}; bool _is_quantized = false; }; template -class ConvolutionValidationFixture : public ConvolutionValidationGenericFixture +class ConvolutionValidationFixture : public ConvolutionValidationGenericFixture { public: template void setup(TensorShape input_shape, TensorShape weights_shape, TensorShape bias_shape, TensorShape output_shape, PadStrideInfo info, Size2D dilation, bool reshape_weights, DataType data_type, DataLayout data_layout, ActivationLayerInfo act_info) { - ConvolutionValidationGenericFixture::setup(input_shape, weights_shape, bias_shape, output_shape, info, dilation, reshape_weights, - data_type, data_layout, - QuantizationInfo(), act_info); + ConvolutionValidationGenericFixture::setup(input_shape, weights_shape, bias_shape, output_shape, info, dilation, reshape_weights, + data_type, data_type, data_layout, + QuantizationInfo(), QuantizationInfo(), act_info); } }; template -class ConvolutionValidationQuantizedFixture : public ConvolutionValidationGenericFixture +class ConvolutionValidationQuantizedFixture : public ConvolutionValidationGenericFixture { public: template void setup(TensorShape input_shape, TensorShape weights_shape, TensorShape bias_shape, TensorShape output_shape, PadStrideInfo info, Size2D dilation, bool reshape_weights, DataType data_type, DataLayout data_layout, QuantizationInfo quantization_info, ActivationLayerInfo act_info) { - ConvolutionValidationGenericFixture::setup(input_shape, weights_shape, bias_shape, output_shape, info, dilation, reshape_weights, - data_type, data_layout, quantization_info, act_info); + ConvolutionValidationGenericFixture::setup(input_shape, weights_shape, bias_shape, output_shape, info, dilation, reshape_weights, + data_type, data_type, data_layout, quantization_info, quantization_info, act_info); + } +}; + +template +class ConvolutionValidationQuantizedPerChannelFixture : public ConvolutionValidationGenericFixture +{ +public: + template + void setup(TensorShape input_shape, TensorShape weights_shape, TensorShape bias_shape, TensorShape output_shape, PadStrideInfo info, Size2D dilation, bool reshape_weights, DataType data_type, + DataLayout data_layout, QuantizationInfo quantization_info, ActivationLayerInfo act_info, DataType weights_data_type) + { + std::vector weights_scales{}; + std::mt19937 gen(library->seed()); + std::uniform_real_distribution<> dis(0.01f, 1); + for(size_t i = 0; i < output_shape[2]; ++i) + { + weights_scales.push_back(dis(gen)); + } + ConvolutionValidationGenericFixture::setup(input_shape, weights_shape, bias_shape, output_shape, info, dilation, + reshape_weights, data_type, weights_data_type, data_layout, + quantization_info, QuantizationInfo(weights_scales), act_info); } }; } // namespace validation diff --git a/tests/validation/reference/Convolution3d.h b/tests/validation/reference/Convolution3d.h index 30be25f504..23918a4055 100644 --- a/tests/validation/reference/Convolution3d.h +++ b/tests/validation/reference/Convolution3d.h @@ -42,13 +42,16 @@ inline bool is_valid_pixel(int i, int min, int max) } // 3D convolution for floating point type -template < typename T, typename TB, typename std::enable_if < validation::is_floating_point::value &&validation::is_floating_point::value, int >::type = 0 > -inline void convolution3d(const SimpleTensor &in, const SimpleTensor &weights, const SimpleTensor &bias, SimpleTensor &out, +template < typename T, typename TW, typename TB, typename std::enable_if < validation::is_floating_point::value &&validation::is_floating_point::value + &&validation::is_floating_point::value, + int >::type = 0 > +inline void convolution3d(const SimpleTensor &in, const SimpleTensor &weights, const SimpleTensor &bias, SimpleTensor &out, int i_offset, int w_offset, int b_offset, int o_offset, - int xi, int yi, int width_in, int height_in, int depth_in, int width_weights, int height_weights, int dilation_x = 1, int dilation_y = 1) + int xi, int yi, int width_in, int height_in, int depth_in, int width_weights, int height_weights, int dilation_x = 1, int dilation_y = 1, int filter_id = 0) { + ARM_COMPUTE_UNUSED(filter_id); const T *in_ptr = in.data() + i_offset; - const T *w_ptr = weights.data() + w_offset; + const TW *w_ptr = weights.data() + w_offset; const TB *b_ptr = bias.data() + b_offset; T *out_ptr = out.data() + o_offset; @@ -77,8 +80,8 @@ inline void convolution3d(const SimpleTensor &in, const SimpleTensor &weig const int idx = xk + half_width_weights_start; const int idy = yk + half_height_weights_start; - const T i_value = in_ptr[offset_slice_in + xk * dilation_x + yk * dilation_y * width_in]; - const T w_value = w_ptr[idx + idy * width_weights + ifm * width_weights * height_weights]; + const T i_value = in_ptr[offset_slice_in + xk * dilation_x + yk * dilation_y * width_in]; + const TW w_value = w_ptr[idx + idy * width_weights + ifm * width_weights * height_weights]; acc += i_value * w_value; } @@ -91,13 +94,16 @@ inline void convolution3d(const SimpleTensor &in, const SimpleTensor &weig } // 3D convolution for QASYMM8 type -template < typename T, typename TB, typename std::enable_if < std::is_same::value &&std::is_same::value, int >::type = 0 > -inline void convolution3d(const SimpleTensor &in, const SimpleTensor &weights, const SimpleTensor &bias, SimpleTensor &out, +template < typename T, typename TW, typename TB, typename std::enable_if < std::is_same::value &&(std::is_same::value + || std::is_same::value) + &&std::is_same::value, + int >::type = 0 > +inline void convolution3d(const SimpleTensor &in, const SimpleTensor &weights, const SimpleTensor &bias, SimpleTensor &out, int i_offset, int w_offset, int b_offset, int o_offset, - int xi, int yi, int width_in, int height_in, int depth_in, int width_weights, int height_weights, int dilation_x = 1, int dilation_y = 1) + int xi, int yi, int width_in, int height_in, int depth_in, int width_weights, int height_weights, int dilation_x = 1, int dilation_y = 1, int filter_id = 0) { const T *in_ptr = in.data() + i_offset; - const T *w_ptr = weights.data() + w_offset; + const TW *w_ptr = weights.data() + w_offset; const TB *b_ptr = bias.data() + b_offset; T *out_ptr = out.data() + o_offset; @@ -107,10 +113,22 @@ inline void convolution3d(const SimpleTensor &in, const SimpleTensor &weig const int input_offset = -iq_info.offset; const float input_scale = iq_info.scale; - const int weights_offset = -wq_info.offset; - const float weights_scale = wq_info.scale; - const int output_offset = oq_info.offset; - const float output_scale = oq_info.scale; + int weights_offset = -wq_info.offset; + float weights_scale = wq_info.scale; + if(is_data_type_quantized_per_channel(weights.data_type())) + { + if(is_data_type_quantized_asymmetric(weights.data_type())) + { + weights_offset = weights.quantization_info().offset()[filter_id]; + } + else + { + weights_offset = 0; + } + weights_scale = weights.quantization_info().scale()[filter_id]; + } + const int output_offset = oq_info.offset; + const float output_scale = oq_info.scale; int output_multiplier = 0; int output_shift = 0; @@ -142,9 +160,8 @@ inline void convolution3d(const SimpleTensor &in, const SimpleTensor &weig const int idx = xk + half_width_weights_start; const int idy = yk + half_height_weights_start; - const uint8_t i_value = in_ptr[offset_slice_in + xk * dilation_x + yk * dilation_y * width_in]; - const uint8_t w_value = w_ptr[idx + idy * width_weights + ifm * width_weights * height_weights]; - + const int32_t i_value = in_ptr[offset_slice_in + xk * dilation_x + yk * dilation_y * width_in]; + const int32_t w_value = w_ptr[idx + idy * width_weights + ifm * width_weights * height_weights]; acc += (i_value + input_offset) * (w_value + weights_offset); } } diff --git a/tests/validation/reference/ConvolutionLayer.cpp b/tests/validation/reference/ConvolutionLayer.cpp index 69090117fe..4d2c1acb6f 100644 --- a/tests/validation/reference/ConvolutionLayer.cpp +++ b/tests/validation/reference/ConvolutionLayer.cpp @@ -45,8 +45,8 @@ namespace { } // namespace -template -SimpleTensor convolution_layer_nchw(const SimpleTensor &src, const SimpleTensor &weights, const SimpleTensor &bias, SimpleTensor &dst, const PadStrideInfo &info, +template +SimpleTensor convolution_layer_nchw(const SimpleTensor &src, const SimpleTensor &weights, const SimpleTensor &bias, SimpleTensor &dst, const PadStrideInfo &info, const Size2D &dilation, unsigned int num_groups) { ARM_COMPUTE_ERROR_ON((src.shape()[2] / num_groups) != weights.shape()[2]); @@ -73,7 +73,6 @@ SimpleTensor convolution_layer_nchw(const SimpleTensor &src, const SimpleT const int end_xi = output_wh.first * stride_xi; const int end_yi = output_wh.second * stride_yi; const int num_batches = src.shape().total_size() / (width_in * height_in * depth_in); - for(int r = 0; r < num_batches; ++r) { for(int yi = start_yi; yi < start_yi + end_yi; yi += stride_yi) @@ -100,17 +99,16 @@ SimpleTensor convolution_layer_nchw(const SimpleTensor &src, const SimpleT offset_in, offset_w, offset_b, offset_out, xi, yi, width_in, height_in, (depth_in / num_groups), - width_weights, height_weights, dilation.x(), dilation.y()); + width_weights, height_weights, dilation.x(), dilation.y(), ofm); } } } } } - return dst; } -template -SimpleTensor convolution_layer(const SimpleTensor &src, const SimpleTensor &weights, const SimpleTensor &bias, const TensorShape &output_shape, const PadStrideInfo &info, +template +SimpleTensor convolution_layer(const SimpleTensor &src, const SimpleTensor &weights, const SimpleTensor &bias, const TensorShape &output_shape, const PadStrideInfo &info, const Size2D &dilation, unsigned int num_groups, QuantizationInfo out_quant_info) { // if no explicit quantization has been set you the same as src @@ -123,9 +121,9 @@ SimpleTensor convolution_layer(const SimpleTensor &src, const SimpleTensor if(src.data_layout() == DataLayout::NHWC) { - SimpleTensor src_nchw = reference::permute(src, PermutationVector(1U, 2U, 0U)); - SimpleTensor weights_nchw = reference::permute(weights, PermutationVector(1U, 2U, 0U)); - SimpleTensor dst_nchw = reference::permute(dst, PermutationVector(1U, 2U, 0U)); + SimpleTensor src_nchw = reference::permute(src, PermutationVector(1U, 2U, 0U)); + SimpleTensor weights_nchw = reference::permute(weights, PermutationVector(1U, 2U, 0U)); + SimpleTensor dst_nchw = reference::permute(dst, PermutationVector(1U, 2U, 0U)); return reference::permute(convolution_layer_nchw(src_nchw, weights_nchw, bias, dst_nchw, info, dilation, num_groups), PermutationVector(2U, 0U, 1U)); } @@ -141,6 +139,8 @@ template SimpleTensor convolution_layer(const SimpleTensor &src, con const PadStrideInfo &info, const Size2D &dilation, unsigned int num_groups, QuantizationInfo out_quant_info); template SimpleTensor convolution_layer(const SimpleTensor &src, const SimpleTensor &weights, const SimpleTensor &bias, const TensorShape &output_shape, const PadStrideInfo &info, const Size2D &dilation, unsigned int num_groups, QuantizationInfo out_quant_info); +template SimpleTensor convolution_layer(const SimpleTensor &src, const SimpleTensor &weights, const SimpleTensor &bias, const TensorShape &output_shape, + const PadStrideInfo &info, const Size2D &dilation, unsigned int num_groups, QuantizationInfo out_quant_info); } // namespace reference } // namespace validation } // namespace test diff --git a/tests/validation/reference/ConvolutionLayer.h b/tests/validation/reference/ConvolutionLayer.h index c51a9b3ad7..8f41073fe2 100644 --- a/tests/validation/reference/ConvolutionLayer.h +++ b/tests/validation/reference/ConvolutionLayer.h @@ -35,8 +35,8 @@ namespace validation { namespace reference { -template -SimpleTensor convolution_layer(const SimpleTensor &src, const SimpleTensor &weights, const SimpleTensor &bias, const TensorShape &output_shape, const PadStrideInfo &info, +template +SimpleTensor convolution_layer(const SimpleTensor &src, const SimpleTensor &weights, const SimpleTensor &bias, const TensorShape &output_shape, const PadStrideInfo &info, const Size2D &dilation = Size2D(1U, 1U), unsigned int num_groups = 1, QuantizationInfo out_quant_info = QuantizationInfo()); } // namespace reference } // namespace validation -- cgit v1.2.1