From 2bbd96457e3740fd9df5556607514b5e80a25720 Mon Sep 17 00:00:00 2001
From: Gian Marco Iodice
Date: Tue, 4 Jul 2017 16:46:32 +0100
Subject: COMPMID-436, COMPMID-437 - Port NEConvolutionLayer &
 NEFullyConnectedLayer to support 16 bit fixed point

Change-Id: I69edf2dac242f941bac95c8479d921e7be6abca7
Reviewed-on: http://mpd-gerrit.cambridge.arm.com/79725
Tested-by: Kaizen
Reviewed-by: Pablo Tello
---
 .../kernels/NEGEMMMatrixAccumulateBiasesKernel.h   |  2 +-
 arm_compute/core/NEON/kernels/NEIm2ColKernel.h     |  2 +-
 arm_compute/core/NEON/kernels/NETransposeKernel.h  |  4 ++--
 .../core/NEON/kernels/NEWeightsReshapeKernel.h     |  2 +-
 .../runtime/NEON/functions/NEConvolutionLayer.h    |  4 ++--
 .../runtime/NEON/functions/NEFullyConnectedLayer.h |  4 ++--
 src/core/NEON/kernels/NECol2ImKernel.cpp           |  3 ++-
 .../kernels/NEGEMMMatrixAccumulateBiasesKernel.cpp | 19 +++++++++++++++++--
 src/core/NEON/kernels/NEIm2ColKernel.cpp           | 21 ++++++++++++++++++---
 src/core/NEON/kernels/NETransposeKernel.cpp        |  3 ++-
 src/core/NEON/kernels/NEWeightsReshapeKernel.cpp   | 12 ++++++------
 src/runtime/NEON/functions/NEConvolutionLayer.cpp  |  5 ++---
 .../NEON/functions/NEFullyConnectedLayer.cpp       |  7 +++----
 tests/validation/NEON/ConvolutionLayer.cpp         | 12 ++++++------
 tests/validation/NEON/FullyConnectedLayer.cpp      | 12 ++++++------
 tests/validation/TensorOperations.h                | 13 +++++++------
 16 files changed, 78 insertions(+), 47 deletions(-)

diff --git a/arm_compute/core/NEON/kernels/NEGEMMMatrixAccumulateBiasesKernel.h b/arm_compute/core/NEON/kernels/NEGEMMMatrixAccumulateBiasesKernel.h
index c0ecafcd39..1eed4e7a84 100644
--- a/arm_compute/core/NEON/kernels/NEGEMMMatrixAccumulateBiasesKernel.h
+++ b/arm_compute/core/NEON/kernels/NEGEMMMatrixAccumulateBiasesKernel.h
@@ -47,7 +47,7 @@ public:
     ~NEGEMMMatrixAccumulateBiasesKernel() = default;
     /** Set the accumulate buffer and the biases of the kernel.
      *
-     * @param[in, out] accum  The accumulate tensor to convert. Data type supported: QS8/F32
+     * @param[in, out] accum  The accumulate tensor to convert. Data type supported: QS8/QS16/F32
      * @param[in]      biases The shared biases tensor to append. It must be 1D Tensor. Data type supported: Same as @p input
      */
     void configure(ITensor *accum, const ITensor *biases);
diff --git a/arm_compute/core/NEON/kernels/NEIm2ColKernel.h b/arm_compute/core/NEON/kernels/NEIm2ColKernel.h
index 9b8b98b388..87d7cc0a8b 100644
--- a/arm_compute/core/NEON/kernels/NEIm2ColKernel.h
+++ b/arm_compute/core/NEON/kernels/NEIm2ColKernel.h
@@ -73,7 +73,7 @@ public:
     /** Set the input and output of the kernel.
      *
      * @param[in]  input       The input tensor to convert. 3 lower dimensions represent a single input [width, height, IFM],
-     *                         while every optional dimension from 4 and above represent a batch of inputs. Data types supported: QS8/F16/F32
+     *                         while every optional dimension from 4 and above represent a batch of inputs. Data types supported: QS8/QS16/F16/F32
     * @param[out] output      The output tensor. Data types supported: Same as @p input
     * @param[in]  kernel_dims The kernel dimensions (width and height).
     * @param[in]  conv_info   Contains padding and stride information described in @ref PadStrideInfo.
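A note on the QS8/QS16 data types these hunks add: a QSn tensor stores each value as an n-bit signed integer scaled by 2^fixed_point_position. The sketch below illustrates that convention for QS16; the helper names are mine, not library API.

    #include <algorithm>
    #include <cstdint>

    // Illustrative QS16 conversions: a qint16_t holds
    // value * 2^fixed_point_position, saturated to the int16_t range.
    int16_t float_to_qs16(float value, int fixed_point_position)
    {
        const float scaled = value * static_cast<float>(1 << fixed_point_position);
        return static_cast<int16_t>(std::min(std::max(scaled, -32768.0f), 32767.0f));
    }

    float qs16_to_float(int16_t raw, int fixed_point_position)
    {
        return static_cast<float>(raw) / static_cast<float>(1 << fixed_point_position);
    }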
diff --git a/arm_compute/core/NEON/kernels/NETransposeKernel.h b/arm_compute/core/NEON/kernels/NETransposeKernel.h
index ac9449ff92..2f757f18eb 100644
--- a/arm_compute/core/NEON/kernels/NETransposeKernel.h
+++ b/arm_compute/core/NEON/kernels/NETransposeKernel.h
@@ -53,7 +53,7 @@ public:
 
     /** Initialise the kernel's input and output.
      *
-     * @param[in]  input  Input tensor. Data types supported: U8/S8/QS8/U16/S16/F16/U32/S32/F32
+     * @param[in]  input  Input tensor. Data types supported: U8/S8/QS8/U16/S16/QS16/F16/U32/S32/F32
      * @param[out] output Output tensor. Data type supported: Same as @p input
      */
     void configure(const ITensor *input, ITensor *output);
@@ -64,7 +64,7 @@ public:
 private:
     /** Common signature for all the transpose functions
      *
-     * @param[in]  input  An input tensor. Data types supported: U8/S8/QS8/U16/S16/F16/U32/S32/F32
+     * @param[in]  input  An input tensor. Data types supported: U8/S8/QS8/U16/S16/QS16/F16/U32/S32/F32
      * @param[out] output The output tensor. Data type supported: same as @p input
      * @param[in]  window Region on which to execute the kernel.
      */
diff --git a/arm_compute/core/NEON/kernels/NEWeightsReshapeKernel.h b/arm_compute/core/NEON/kernels/NEWeightsReshapeKernel.h
index cad2d00b1f..6b76d19314 100644
--- a/arm_compute/core/NEON/kernels/NEWeightsReshapeKernel.h
+++ b/arm_compute/core/NEON/kernels/NEWeightsReshapeKernel.h
@@ -71,7 +71,7 @@ public:
     /** Set the input and output of the kernel.
      *
      * @param[in]  input  The input tensor to convert. Weights are 4D tensor with dimensions [kernel_x, kernel_y, IFM, OFM] if shared,
-     *                    and 5D tensor with dimensions [kernel_x, kernel_y, IFM, OFM, num_patches] if unshared. Data types supported: QS8/F32
+     *                    and 5D tensor with dimensions [kernel_x, kernel_y, IFM, OFM, num_patches] if unshared. Data types supported: QS8/QS16/F32
      * @param[in]  bias   The shared biases tensor to append. Bias is 1D tensor with dimensions [OFM] if shared and 2D tensor with
      *                    dimensions [OFM, num_patches] if unshared. Data types supported: Same as @p input
      * @param[out] output The output tensor. Data types supported: Same as @p input
diff --git a/arm_compute/runtime/NEON/functions/NEConvolutionLayer.h b/arm_compute/runtime/NEON/functions/NEConvolutionLayer.h
index a8fff8d047..1bd7e6a95f 100644
--- a/arm_compute/runtime/NEON/functions/NEConvolutionLayer.h
+++ b/arm_compute/runtime/NEON/functions/NEConvolutionLayer.h
@@ -51,7 +51,7 @@ public:
     NEConvolutionLayerReshapeWeights();
     /** Set the input and output tensors.
      *
-     * @param[in]  weights      Weights tensor. Weights are 4D tensor with dimensions [kernel_x, kernel_y, IFM, OFM]. Data type supported: QS8/F32.
+     * @param[in]  weights      Weights tensor. Weights are 4D tensor with dimensions [kernel_x, kernel_y, IFM, OFM]. Data type supported: QS8/QS16/F32.
      * @param[in]  biases       Biases tensor. Shared biases supported. Biases are 1D tensor with dimensions [OFM]. Data type supported: Same as @p weights.
      * @param[out] output       Destination tensor. Data types supported: Same as @p weights.
      * @param[in]  transpose1xW True if the weights are to undergo a 1xW transposition after reshaping (in case of GEMM operation), false otherwise.
@@ -84,7 +84,7 @@ public:
      *
      * @param[in]  input        Source tensor. 3 lower dimensions represent a single input [width, height, IFM],
      *                          while every optional dimension from 4 and above represent a batch of inputs.
-     *                          Data types supported: QS8/F32.
+     *                          Data types supported: QS8/QS16/F32.
      * @param[in]  weights      Weights tensor. Weights are 4D tensor with dimensions [kernel_x, kernel_y, IFM, OFM]. Data type supported: Same as @p input.
      * @param[in]  biases       Biases tensor. Shared biases supported. Biases are 1D tensor with dimensions [OFM]. Data type supported: Same as @p input.
      * @param[out] output       Destination tensor. 3 lower dimensions represent a single output [width, height, OFM], while the rest represent batch of outputs.
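With the headers above extended, a QS16 convolution is configured like any other data type; only the DataType and the fixed point position in the TensorInfo change. A minimal usage sketch: the shapes and the position 4 (Q11.4) are made up for illustration, and allocation and run() are elided.

    #include "arm_compute/runtime/NEON/functions/NEConvolutionLayer.h"
    #include "arm_compute/runtime/Tensor.h"

    using namespace arm_compute;

    void configure_qs16_convolution_example()
    {
        Tensor src, weights, biases, dst;
        src.allocator()->init(TensorInfo(TensorShape(32U, 32U, 3U), 1, DataType::QS16, 4));
        weights.allocator()->init(TensorInfo(TensorShape(3U, 3U, 3U, 16U), 1, DataType::QS16, 4));
        biases.allocator()->init(TensorInfo(TensorShape(16U), 1, DataType::QS16, 4));
        dst.allocator()->init(TensorInfo(TensorShape(30U, 30U, 16U), 1, DataType::QS16, 4));

        NEConvolutionLayer conv;
        conv.configure(&src, &weights, &biases, &dst, PadStrideInfo(1, 1, 0, 0));
        // Allocate all tensors, fill them, then call conv.run().
    }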
diff --git a/arm_compute/runtime/NEON/functions/NEFullyConnectedLayer.h b/arm_compute/runtime/NEON/functions/NEFullyConnectedLayer.h
index 33ec4ef721..af571d1057 100644
--- a/arm_compute/runtime/NEON/functions/NEFullyConnectedLayer.h
+++ b/arm_compute/runtime/NEON/functions/NEFullyConnectedLayer.h
@@ -50,7 +50,7 @@ public:
     NEFullyConnectedLayerReshapeWeights();
     /** Set the input and output tensors.
      *
-     * @param[in]  input               Weights tensor. The weights must be 2 dimensional. Data types supported: QS8/F32.
+     * @param[in]  input               Weights tensor. The weights must be 2 dimensional. Data types supported: QS8/QS16/F32.
      * @param[out] output              Destination tensor. Data type supported: Same as @p input.
      * @param[in]  transpose_weights   True if the weights must be transposed. Data types supported: Same as @p weights.
      * @param[in]  is_batched_fc_layer True if it is a batched fully connected layer
@@ -84,7 +84,7 @@ public:
     NEFullyConnectedLayer();
     /** Set the input and output tensors.
      *
-     * @param[in]  input   Source tensor. Data type supported: QS8/F32.
+     * @param[in]  input   Source tensor. Data type supported: QS8/QS16/F32.
      * @param[in]  weights Weights tensor. The weights must be 2 dimensional. Data type supported: Same as @p input.
      * @param[in]  biases  Bias tensor. Can be nullptr. Data type supported: Same as @p input.
      * @param[out] output  Destination tensor. Data type supported: Same as @p input.
diff --git a/src/core/NEON/kernels/NECol2ImKernel.cpp b/src/core/NEON/kernels/NECol2ImKernel.cpp
index e9a73607e6..95a9364082 100644
--- a/src/core/NEON/kernels/NECol2ImKernel.cpp
+++ b/src/core/NEON/kernels/NECol2ImKernel.cpp
@@ -69,7 +69,8 @@ NECol2ImKernel::NECol2ImKernel()
 
 void NECol2ImKernel::configure(const ITensor *input, ITensor *output, std::pair<unsigned int, unsigned int> convolved_dims)
 {
-    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::U8, DataType::S8, DataType::QS8, DataType::U16, DataType::S16, DataType::U32, DataType::S32, DataType::F16, DataType::F32);
+    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::U8, DataType::S8, DataType::QS8, DataType::U16, DataType::S16, DataType::QS16, DataType::U32, DataType::S32, DataType::F16,
+                                                  DataType::F32);
     ARM_COMPUTE_ERROR_ON_NULLPTR(output);
 
     TensorShape output_shape = input->info()->tensor_shape();
diff --git a/src/core/NEON/kernels/NEGEMMMatrixAccumulateBiasesKernel.cpp b/src/core/NEON/kernels/NEGEMMMatrixAccumulateBiasesKernel.cpp
index 7a3bae50c0..826a386557 100644
--- a/src/core/NEON/kernels/NEGEMMMatrixAccumulateBiasesKernel.cpp
+++ b/src/core/NEON/kernels/NEGEMMMatrixAccumulateBiasesKernel.cpp
@@ -45,9 +45,9 @@ NEGEMMMatrixAccumulateBiasesKernel::NEGEMMMatrixAccumulateBiasesKernel()
 
 void NEGEMMMatrixAccumulateBiasesKernel::configure(ITensor *accum, const ITensor *biases)
 {
-    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(accum, 1, DataType::QS8, DataType::F32);
-    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(biases, 1, DataType::QS8, DataType::F32);
+    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(accum, 1, DataType::QS8, DataType::QS16, DataType::F32);
     ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(biases, accum);
+    ARM_COMPUTE_ERROR_ON_MISMATCHING_FIXED_POINT_POSITION(biases, accum);
     ARM_COMPUTE_ERROR_ON(biases->info()->num_dimensions() != 1);
 
     _biases = biases;
@@ -121,6 +121,21 @@ void NEGEMMMatrixAccumulateBiasesKernel::run(const Window &window)
             in0_out, in1);
             break;
         }
+        case DataType::QS16:
+        {
+            execute_window_loop(window, [&](const Coordinates & id)
+            {
+                qint16x8x2_t       accum  = vld2q_s16(reinterpret_cast<const qint16_t *>(in0_out.ptr()));
+                const qint16x8x2_t biases = vld2q_s16(reinterpret_cast<const qint16_t *>(in1.ptr()));
+
+                accum.val[0] = vqaddq_qs16(accum.val[0], biases.val[0]);
+                accum.val[1] = vqaddq_qs16(accum.val[1], biases.val[1]);
+
+                vst2q_s16(reinterpret_cast<qint16_t *>(in0_out.ptr()), accum);
+            },
+            in0_out, in1);
+            break;
+        }
         default:
             ARM_COMPUTE_ERROR("Data type not supported");
             break;
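The new QS16 case mirrors the QS8 one: vld2q_s16 pulls in two de-interleaved registers of eight qint16_t lanes each (sixteen values per iteration), and vqaddq_qs16 does the addition. Because the configure step above requires accumulator and biases to share one fixed point position, fixed point addition reduces to saturating integer addition on the raw values. A scalar per-lane sketch of that semantics, as I read it (illustrative, not library code):

    #include <cstdint>

    // Per-lane reference for the saturating fixed point add: with a shared
    // fixed point position, adding raw values adds the represented values,
    // so only saturation to the int16_t range is needed.
    int16_t qs16_add(int16_t a, int16_t b)
    {
        const int32_t sum = static_cast<int32_t>(a) + static_cast<int32_t>(b);
        if(sum > INT16_MAX) { return INT16_MAX; }
        if(sum < INT16_MIN) { return INT16_MIN; }
        return static_cast<int16_t>(sum);
    }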
diff --git a/src/core/NEON/kernels/NEIm2ColKernel.cpp b/src/core/NEON/kernels/NEIm2ColKernel.cpp
index 8c9d12c57c..5bb8b1c22a 100644
--- a/src/core/NEON/kernels/NEIm2ColKernel.cpp
+++ b/src/core/NEON/kernels/NEIm2ColKernel.cpp
@@ -134,10 +134,14 @@ inline void linearize_volume(const uint8_t *const in_ptr,
     // Append 1 if the convolution layer has biases
     if(has_bias)
     {
-        if(std::is_same<qint8_t, T>::value)
+        if(std::is_same<T, qint8_t>::value)
         {
             *out_ptr = scvt_qs8_f32(1.0f, fixed_point_position);
         }
+        else if(std::is_same<T, qint16_t>::value)
+        {
+            *out_ptr = scvt_qs16_f32(1.0f, fixed_point_position);
+        }
         else
         {
             *out_ptr = static_cast<T>(1);
         }
@@ -249,10 +253,14 @@ void NEIm2ColKernel::run_reduced(const Window &window)
         // Add bias
         if(_has_bias)
         {
-            if(std::is_same<qint8_t, T>::value)
+            if(std::is_same<T, qint8_t>::value)
             {
                 *(reinterpret_cast<T *>(out_ptr) + out_width - 1) = scvt_qs8_f32(1.0f, _input->info()->fixed_point_position());
             }
+            else if(std::is_same<T, qint16_t>::value)
+            {
+                *(reinterpret_cast<T *>(out_ptr) + out_width - 1) = scvt_qs16_f32(1.0f, _input->info()->fixed_point_position());
+            }
             else
             {
                 *(reinterpret_cast<T *>(out_ptr) + out_width - 1) = static_cast<T>(1);
             }
@@ -269,8 +277,9 @@ NEIm2ColKernel::NEIm2ColKernel()
 
 void NEIm2ColKernel::configure(const ITensor *input, ITensor *output, const Size2D &kernel_dims, const PadStrideInfo &conv_info, bool has_bias)
 {
-    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F16, DataType::F32, DataType::QS8);
+    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F16, DataType::F32, DataType::QS8, DataType::QS16);
     ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
+    ARM_COMPUTE_ERROR_ON_MISMATCHING_FIXED_POINT_POSITION(input, output);
 
     _input  = input;
     _output = output;
@@ -309,6 +318,9 @@ void NEIm2ColKernel::configure(const ITensor *input, ITensor *output, const Size
             case DataType::QS8:
                 _func = &NEIm2ColKernel::run_reduced<qint8_t>;
                 break;
+            case DataType::QS16:
+                _func = &NEIm2ColKernel::run_reduced<qint16_t>;
+                break;
             default:
                 ARM_COMPUTE_ERROR("Data type not supported");
                 break;
@@ -329,6 +341,9 @@ void NEIm2ColKernel::configure(const ITensor *input, ITensor *output, const Size
             case DataType::QS8:
                 _func = ((pad_x == 0) && (pad_y == 0)) ? &NEIm2ColKernel::run_generic<qint8_t, false> : &NEIm2ColKernel::run_generic<qint8_t, true>;
                 break;
+            case DataType::QS16:
+                _func = ((pad_x == 0) && (pad_y == 0)) ? &NEIm2ColKernel::run_generic<qint16_t, false> : &NEIm2ColKernel::run_generic<qint16_t, true>;
+                break;
             default:
                 ARM_COMPUTE_ERROR("Data type not supported");
                 break;
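The bias handling above relies on the usual im2col trick: when the layer has biases, every im2col row is extended with the constant one so that the following GEMM folds the bias into the dot product. In fixed point that constant must be encoded at the tensor's own scale, which is why the hunks call scvt_qs16_f32(1.0f, fixed_point_position) rather than static_cast<T>(1). A small illustration (the helper name is mine):

    #include <cstdint>

    // Encoding 1.0f at a given fixed point position: raw = 1 << fp.
    // At fp = 4 this returns 16; a plain static_cast<int16_t>(1) would
    // instead encode the value 1/16.
    int16_t one_in_qs16(int fixed_point_position)
    {
        return static_cast<int16_t>(1 << fixed_point_position);
    }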
diff --git a/src/core/NEON/kernels/NETransposeKernel.cpp b/src/core/NEON/kernels/NETransposeKernel.cpp
index a990e9068e..732a0ef4f6 100644
--- a/src/core/NEON/kernels/NETransposeKernel.cpp
+++ b/src/core/NEON/kernels/NETransposeKernel.cpp
@@ -179,7 +179,8 @@ NETransposeKernel::NETransposeKernel()
 
 void NETransposeKernel::configure(const ITensor *input, ITensor *output)
 {
-    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::U8, DataType::S8, DataType::QS8, DataType::U16, DataType::S16, DataType::U32, DataType::S32, DataType::F16, DataType::F32);
+    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::U8, DataType::S8, DataType::QS8, DataType::U16, DataType::S16, DataType::QS16, DataType::U32, DataType::S32, DataType::F16,
+                                                  DataType::F32);
     ARM_COMPUTE_ERROR_ON_NULLPTR(output);
 
     TensorShape output_shape{ input->info()->tensor_shape() };
diff --git a/src/core/NEON/kernels/NEWeightsReshapeKernel.cpp b/src/core/NEON/kernels/NEWeightsReshapeKernel.cpp
index ac688e1381..d685ec7962 100644
--- a/src/core/NEON/kernels/NEWeightsReshapeKernel.cpp
+++ b/src/core/NEON/kernels/NEWeightsReshapeKernel.cpp
@@ -95,7 +95,7 @@ NEWeightsReshapeKernel::NEWeightsReshapeKernel()
 
 void NEWeightsReshapeKernel::configure(const ITensor *input, const ITensor *bias, ITensor *output)
 {
-    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::F16, DataType::F32);
+    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QS16, DataType::F16, DataType::F32);
     ARM_COMPUTE_ERROR_ON_NULLPTR(output);
 
     const int fixed_point_position = input->info()->fixed_point_position();
@@ -129,26 +129,26 @@ void NEWeightsReshapeKernel::configure(const ITensor *input, const ITensor *bias
     _bias   = bias;
     _output = output;
 
-    switch(_input->info()->data_type())
+    switch(_input->info()->element_size())
     {
-        case DataType::F32:
+        case 4:
         {
             _func = &weights_reshape<uint32_t>;
             break;
         }
-        case DataType::F16:
+        case 2:
         {
             _func = &weights_reshape<uint16_t>;
             break;
         }
-        case DataType::QS8:
+        case 1:
         {
             _func = &weights_reshape<uint8_t>;
             break;
         }
         default:
         {
-            ARM_COMPUTE_ERROR_ON("Data type not supported");
+            ARM_COMPUTE_ERROR_ON("Element size not supported");
             break;
         }
     }
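Switching the dispatch from data_type() to element_size() is what lets QS16 reuse the existing paths: weights reshaping is a pure memory rearrangement, so one template instantiation per element width covers every type of that width (QS16 and F16 both take the 2-byte path, bit patterns copied verbatim). A standalone sketch of the idea, not the kernel's actual code:

    #include <cstddef>
    #include <cstdint>

    // One copy routine per element width; the kernel's placement logic is
    // omitted because only the width of the moved elements matters here.
    template <typename U>
    void reshape_copy(const U *src, U *dst, std::size_t n)
    {
        for(std::size_t i = 0; i < n; ++i)
        {
            dst[i] = src[i];
        }
    }

    void dispatch_by_element_size(const void *src, void *dst, std::size_t n, std::size_t element_size)
    {
        switch(element_size)
        {
            case 4: reshape_copy(static_cast<const uint32_t *>(src), static_cast<uint32_t *>(dst), n); break;
            case 2: reshape_copy(static_cast<const uint16_t *>(src), static_cast<uint16_t *>(dst), n); break;
            case 1: reshape_copy(static_cast<const uint8_t *>(src), static_cast<uint8_t *>(dst), n); break;
        }
    }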
diff --git a/src/runtime/NEON/functions/NEConvolutionLayer.cpp b/src/runtime/NEON/functions/NEConvolutionLayer.cpp
index dc8652747f..f6481f1918 100644
--- a/src/runtime/NEON/functions/NEConvolutionLayer.cpp
+++ b/src/runtime/NEON/functions/NEConvolutionLayer.cpp
@@ -41,14 +41,13 @@ NEConvolutionLayerReshapeWeights::NEConvolutionLayerReshapeWeights()
 
 void NEConvolutionLayerReshapeWeights::configure(const ITensor *weights, const ITensor *biases, ITensor *output, bool transpose1xW)
 {
-    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(weights, 1, DataType::QS8, DataType::F16, DataType::F32);
+    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(weights, 1, DataType::QS8, DataType::QS16, DataType::F16, DataType::F32);
     ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(weights, output);
     ARM_COMPUTE_ERROR_ON_MISMATCHING_FIXED_POINT(weights, output);
     ARM_COMPUTE_ERROR_ON(weights->info()->num_dimensions() > 4);
 
     if(biases != nullptr)
     {
-        ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(biases, 1, DataType::QS8, DataType::F16, DataType::F32);
         ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(weights, biases);
         ARM_COMPUTE_ERROR_ON_MISMATCHING_FIXED_POINT(weights, biases);
         ARM_COMPUTE_ERROR_ON(biases->info()->dimension(0) != weights->info()->dimension(3));
@@ -96,7 +95,7 @@ NEConvolutionLayer::NEConvolutionLayer()
 
 void NEConvolutionLayer::configure(const ITensor *input, const ITensor *weights, const ITensor *biases, ITensor *output, const PadStrideInfo &conv_info, const WeightsInfo &weights_info)
 {
-    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::F16, DataType::F32);
+    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QS16, DataType::F16, DataType::F32);
     ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, weights, output);
     ARM_COMPUTE_ERROR_ON_MISMATCHING_FIXED_POINT(input, weights, output);
     ARM_COMPUTE_ERROR_ON(!weights_info.are_reshaped() && weights->info()->dimension(2) != input->info()->dimension(2));
diff --git a/src/runtime/NEON/functions/NEFullyConnectedLayer.cpp b/src/runtime/NEON/functions/NEFullyConnectedLayer.cpp
index 6e27ed344a..eb84ccaddc 100644
--- a/src/runtime/NEON/functions/NEFullyConnectedLayer.cpp
+++ b/src/runtime/NEON/functions/NEFullyConnectedLayer.cpp
@@ -39,7 +39,7 @@ NEFullyConnectedLayerReshapeWeights::NEFullyConnectedLayerReshapeWeights()
 
 void NEFullyConnectedLayerReshapeWeights::configure(const ITensor *input, ITensor *output, bool transpose_weights, bool is_batched_fc_layer)
 {
-    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::F32);
+    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QS16, DataType::F32);
     ARM_COMPUTE_ERROR_ON(output == nullptr);
     ARM_COMPUTE_ERROR_ON(input->info()->num_dimensions() != 2);
     ARM_COMPUTE_ERROR_ON((transpose_weights == false) && (is_batched_fc_layer == false));
@@ -196,10 +196,9 @@ void NEFullyConnectedLayer::configure_fc_fc_nb(const ITensor *input, const ITens
 
 void NEFullyConnectedLayer::configure(const ITensor *input, const ITensor *weights, const ITensor *biases, ITensor *output, bool transpose_weights, bool are_weights_reshaped)
 {
-    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::F32);
-    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(weights, 1, DataType::QS8, DataType::F32);
-    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::QS8, DataType::F32);
+    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QS16, DataType::F32);
     ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, weights, output);
+    ARM_COMPUTE_ERROR_ON_MISMATCHING_FIXED_POINT_POSITION(input, weights, output);
     ARM_COMPUTE_ERROR_ON(weights->info()->num_dimensions() != 2);
 
     const DataType dt = input->info()->data_type();
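The ARM_COMPUTE_ERROR_ON_MISMATCHING_FIXED_POINT_POSITION checks added here matter because the kernels operate on raw integers, and raw arithmetic only corresponds to arithmetic on the represented values when all operands share one scale. A worked illustration of what the check prevents:

    #include <cstdint>

    // Valid only when both raws share one fixed point position:
    // 0.5 at fp=4 is raw 8, while 0.5 at fp=6 is raw 32; raw 8 + raw 32
    // = raw 40 decodes to 2.5 (fp=4) or 0.625 (fp=6) -- never the correct 1.0.
    int16_t add_raw_same_scale(int16_t a_raw, int16_t b_raw)
    {
        return static_cast<int16_t>(a_raw + b_raw);
    }

Rejecting mismatched positions at configure time turns that silent corruption into an immediate error.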
diff --git a/tests/validation/NEON/ConvolutionLayer.cpp b/tests/validation/NEON/ConvolutionLayer.cpp
index 128fb8e842..1cf630a473 100644
--- a/tests/validation/NEON/ConvolutionLayer.cpp
+++ b/tests/validation/NEON/ConvolutionLayer.cpp
@@ -46,7 +46,7 @@ const float tolerance_f32 = 1e-03f; /**< Tolerance value for comparing reference
 #ifdef ARM_COMPUTE_ENABLE_FP16
 const float tolerance_f16 = 0.01f; /**< Tolerance value for comparing reference's output against implementation's output for DataType::F16 */
 #endif /* ARM_COMPUTE_ENABLE_FP16 */
-const float tolerance_qs8 = 3.0f; /**< Tolerance value for comparing reference's output against implementation's output for DataType::QS8 */
+const float tolerance_q = 1.0f;   /**< Tolerance value for comparing reference's output against implementation's output for fixed point data types */
 
 Tensor compute_convolution_layer(const TensorShape &input_shape, const TensorShape &weights_shape, const TensorShape &bias_shape, const TensorShape &output_shape, DataType dt, const
                                  PadStrideInfo &conv_info, int fixed_point_position)
@@ -101,7 +101,7 @@ BOOST_AUTO_TEST_SUITE(GEMM)
 
 BOOST_TEST_DECORATOR(*boost::unit_test::label("precommit") * boost::unit_test::label("nightly"))
 BOOST_DATA_TEST_CASE(Configuration,
-                     AlexNetConvolutionLayerDataset() * boost::unit_test::data::make({ DataType::F32, DataType::QS8 }),
+                     AlexNetConvolutionLayerDataset() * boost::unit_test::data::make({ DataType::F32, DataType::QS8, DataType::QS16 }),
                      conv_set, dt)
 {
     // Set fixed point position data type allowed
@@ -188,7 +188,7 @@ BOOST_AUTO_TEST_SUITE_END()
 BOOST_AUTO_TEST_SUITE(Quantized)
 BOOST_TEST_DECORATOR(*boost::unit_test::label("precommit"))
 BOOST_DATA_TEST_CASE(SmallConvolutionLayer,
-                     SmallConvolutionLayerDataset() * boost::unit_test::data::make(DataType::QS8) * boost::unit_test::data::xrange(4, 7),
+                     SmallConvolutionLayerDataset() * boost::unit_test::data::make({ DataType::QS8, DataType::QS16 }) * boost::unit_test::data::xrange(4, 7),
                      conv_set, dt, fixed_point_position)
 {
     // Compute function
@@ -198,12 +198,12 @@ BOOST_DATA_TEST_CASE(SmallConvolutionLayer,
     RawTensor ref_dst = Reference::compute_reference_convolution_layer(conv_set.src_shape, conv_set.weights_shape, conv_set.bias_shape, conv_set.dst_shape, dt, conv_set.info, fixed_point_position);
 
     // Validate output
-    validate(NEAccessor(dst), ref_dst, tolerance_qs8);
+    validate(NEAccessor(dst), ref_dst, tolerance_q);
 }
 
 BOOST_TEST_DECORATOR(*boost::unit_test::label("nightly"))
 BOOST_DATA_TEST_CASE(LargeConvolutionLayer,
-                     AlexNetConvolutionLayerDataset() * boost::unit_test::data::make(DataType::QS8) * boost::unit_test::data::xrange(4, 7),
+                     AlexNetConvolutionLayerDataset() * boost::unit_test::data::make({ DataType::QS8, DataType::QS16 }) * boost::unit_test::data::xrange(4, 7),
                      conv_set, dt, fixed_point_position)
 {
     // Compute function
@@ -213,7 +213,7 @@ BOOST_DATA_TEST_CASE(LargeConvolutionLayer,
     RawTensor ref_dst = Reference::compute_reference_convolution_layer(conv_set.src_shape, conv_set.weights_shape, conv_set.bias_shape, conv_set.dst_shape, dt, conv_set.info, fixed_point_position);
 
     // Validate output
-    validate(NEAccessor(dst), ref_dst, tolerance_qs8);
+    validate(NEAccessor(dst), ref_dst, tolerance_q);
 }
 BOOST_AUTO_TEST_SUITE_END()
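Note that the unified tolerance_q = 1.0f also tightens the convolution tolerance from three units to one. On my reading (hedged, since it depends on how validate() compares fixed point outputs, which I take to be in raw units), the tests accept an off-by-one in the raw representation, so the real-value slack shrinks as the position grows across the swept range:

    #include <cstdio>

    // Real-value slack of a one-raw-unit tolerance over the swept positions,
    // assuming validate() compares fixed point tensors in raw units.
    int main()
    {
        for(int fp = 4; fp < 7; ++fp) // mirrors boost's xrange(4, 7)
        {
            std::printf("fp=%d: one raw step = %.6f\n", fp, 1.0 / (1 << fp));
        }
        return 0;
    }

QS16 gets the same one-step budget as QS8, just on a much finer grid.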
diff --git a/tests/validation/NEON/FullyConnectedLayer.cpp b/tests/validation/NEON/FullyConnectedLayer.cpp
index ae0d94a53c..87e0071007 100644
--- a/tests/validation/NEON/FullyConnectedLayer.cpp
+++ b/tests/validation/NEON/FullyConnectedLayer.cpp
@@ -44,7 +44,7 @@ using namespace arm_compute::test::validation;
 namespace
 {
 const float tolerance_f32 = 1e-03f; /**< Tolerance value for comparing reference's output against implementation's output for DataType::F32 */
-const float tolerance_qs8 = 1.0f;  /**< Tolerance value for comparing reference's output against implementation's output for DataType::QS8 */
+const float tolerance_q = 1.0f;    /**< Tolerance value for comparing reference's output against implementation's output for fixed point data types */
 
 Tensor compute_fully_connected_layer(const TensorShape &input_shape, const TensorShape &weights_shape, const TensorShape &bias_shape, const TensorShape &output_shape, DataType dt,
                                      bool transpose_weights, int fixed_point_position)
@@ -109,7 +109,7 @@ BOOST_AUTO_TEST_SUITE(FullyConnectedLayer)
 
 BOOST_TEST_DECORATOR(*boost::unit_test::label("precommit") * boost::unit_test::label("nightly"))
 BOOST_DATA_TEST_CASE(Configuration,
-                     SmallFullyConnectedLayerDataset() * boost::unit_test::data::make({ DataType::F32, DataType::QS8 }),
+                     SmallFullyConnectedLayerDataset() * boost::unit_test::data::make({ DataType::F32, DataType::QS8, DataType::QS16 }),
                      fc_set, dt)
 {
     // Set fixed point position data type allowed
@@ -188,7 +188,7 @@ BOOST_AUTO_TEST_SUITE_END()
 BOOST_AUTO_TEST_SUITE(Quantized)
 BOOST_TEST_DECORATOR(*boost::unit_test::label("precommit"))
 BOOST_DATA_TEST_CASE(RunSmall,
-                     SmallFullyConnectedLayerDataset() * boost::unit_test::data::make({ DataType::QS8 }) * boost::unit_test::data::xrange(4, 7),
+                     SmallFullyConnectedLayerDataset() * boost::unit_test::data::make({ DataType::QS8, DataType::QS16 }) * boost::unit_test::data::xrange(4, 7),
                      fc_set, dt, fixed_point_position)
 {
     // Compute function
@@ -198,12 +198,12 @@ BOOST_DATA_TEST_CASE(RunSmall,
     RawTensor ref_dst = Reference::compute_reference_fully_connected_layer(fc_set.src_shape, fc_set.weights_shape, fc_set.bias_shape, fc_set.dst_shape, dt, fc_set.transpose_weights, fixed_point_position);
 
     // Validate output
-    validate(NEAccessor(dst), ref_dst, tolerance_qs8);
+    validate(NEAccessor(dst), ref_dst, tolerance_q);
 }
 
 BOOST_TEST_DECORATOR(*boost::unit_test::label("nightly"))
 BOOST_DATA_TEST_CASE(RunLarge,
-                     LargeFullyConnectedLayerDataset() * boost::unit_test::data::make({ DataType::QS8 }) * boost::unit_test::data::xrange(4, 7),
+                     LargeFullyConnectedLayerDataset() * boost::unit_test::data::make({ DataType::QS8, DataType::QS16 }) * boost::unit_test::data::xrange(4, 7),
                      fc_set, dt, fixed_point_position)
 {
     // Compute function
@@ -213,7 +213,7 @@ BOOST_DATA_TEST_CASE(RunLarge,
     RawTensor ref_dst = Reference::compute_reference_fully_connected_layer(fc_set.src_shape, fc_set.weights_shape, fc_set.bias_shape, fc_set.dst_shape, dt, fc_set.transpose_weights, fixed_point_position);
 
     // Validate output
-    validate(NEAccessor(dst), ref_dst, tolerance_qs8);
+    validate(NEAccessor(dst), ref_dst, tolerance_q);
 }
 BOOST_AUTO_TEST_SUITE_END()
diff --git a/tests/validation/TensorOperations.h b/tests/validation/TensorOperations.h
index 488ffa90d9..0502f53186 100644
--- a/tests/validation/TensorOperations.h
+++ b/tests/validation/TensorOperations.h
@@ -158,7 +158,7 @@ void convolution3d(const T *in, const T *weights, const T *bias, T *out, int xi,
     *out = res.raw();
 }
 
-template <typename T>
+template <typename T, typename std::enable_if<std::is_floating_point<T>::value, int>::type * = nullptr>
 void vector_matrix_multiply(const T *in, const T *weights, const T *bias, T *out, int cols_weights, int rows_weights, uint8_t fixed_point_position)
 {
     for(int x = 0; x < cols_weights; ++x)
@@ -172,11 +172,12 @@ void vector_matrix_multiply(const T *in, const T *weights, const T *bias, T *out
     }
 }
 
-template <>
-void vector_matrix_multiply(const int8_t *in, const int8_t *weights, const int8_t *bias, int8_t *out, int cols_weights, int rows_weights, uint8_t fixed_point_position)
+// Vector matrix multiply for fixed point type
+template <typename T, typename std::enable_if<std::is_integral<T>::value, int>::type * = nullptr>
+void vector_matrix_multiply(const T *in, const T *weights, const T *bias, T *out, int cols_weights, int rows_weights, uint8_t fixed_point_position)
 {
     using namespace fixed_point_arithmetic;
-    using promoted_type = typename fixed_point_arithmetic::traits::promote<int8_t>::type;
+    using promoted_type = typename fixed_point_arithmetic::traits::promote<T>::type;
 
     for(int x = 0; x < cols_weights; ++x)
     {
@@ -192,10 +193,10 @@ void vector_matrix_multiply(const int8_t *in, const int8_t *weights, const int8_
         }
 
         // Get the bias
-        const fixed_point<promoted_type> b(bias[x], fixed_point_position, true);
+        const fixed_point<T> b(bias[x], fixed_point_position, true);
 
         // Convert back and accumulate the bias
-        fixed_point<int8_t> res(acc);
+        fixed_point<T> res(acc);
         res = res + b;
 
         // Store the result
-- 
cgit v1.2.1
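A closing note on the TensorOperations.h change: the reference implementation replaces an int8_t-only specialization with two SFINAE-gated overloads of one function template, so qint8_t and qint16_t share a single fixed point path that accumulates in a promoted type. A standalone sketch of that dispatch pattern (promote here is my illustration, standing in for the library's traits::promote, which as I understand it maps int8_t to int16_t and int16_t to int32_t):

    #include <cstdint>
    #include <type_traits>

    template <typename T> struct promote;
    template <> struct promote<int8_t>  { using type = int16_t; };
    template <> struct promote<int16_t> { using type = int32_t; };

    // Two overloads of one template, selected by SFINAE on the element type,
    // mirroring the floating point / fixed point split adopted above.
    template <typename T, typename std::enable_if<std::is_floating_point<T>::value, int>::type * = nullptr>
    const char *multiply_kind(T)
    {
        return "floating point path";
    }

    template <typename T, typename std::enable_if<std::is_integral<T>::value, int>::type * = nullptr>
    const char *multiply_kind(T)
    {
        return "fixed point path, accumulating in promote<T>::type";
    }

    // multiply_kind(1.0f)      -> floating point path
    // multiply_kind(int16_t{1}) -> fixed point path

Compared with the old full specialization, adding the next fixed point width now costs one promote entry instead of a duplicated function body.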