From a855af10a486c53c2271361cb87f349eca64b749 Mon Sep 17 00:00:00 2001
From: Giorgio Arena
Date: Mon, 16 Jul 2018 17:20:38 +0100
Subject: COMPMID-1401 Implement NEFullyConnectedLayer for QASYMM8

Change-Id: I0404df6d369855e2f458f2db8f26e81c80a1ee87
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/140148
Reviewed-by: Georgios Pinitas
Reviewed-by: Anthony Barbier
Reviewed-by: Gian Marco Iodice
Tested-by: Jenkins
---
 .../kernels/NEGEMMLowpOffsetContributionKernel.cpp |  10 +-
 .../NEON/functions/NEFullyConnectedLayer.cpp       | 412 +++++++++++----------
 src/runtime/NEON/functions/NEGEMM.cpp              | 154 +++++++-
 3 files changed, 357 insertions(+), 219 deletions(-)

(limited to 'src')

diff --git a/src/core/NEON/kernels/NEGEMMLowpOffsetContributionKernel.cpp b/src/core/NEON/kernels/NEGEMMLowpOffsetContributionKernel.cpp
index ee334dfca0..af84d024d5 100644
--- a/src/core/NEON/kernels/NEGEMMLowpOffsetContributionKernel.cpp
+++ b/src/core/NEON/kernels/NEGEMMLowpOffsetContributionKernel.cpp
@@ -193,11 +193,14 @@ void NEGEMMLowpOffsetContributionKernel::run(const Window &window, const ThreadI
         Window win_vector_sum_row(collapsed_window);
         win_vector_sum_row.set(Window::DimX, Window::Dimension(0, 0, 0));
         win_vector_sum_row.set(Window::DimY, Window::Dimension(0, 0, 0));
+        win_vector_sum_row.set(Window::DimZ, Window::Dimension(0, 0, 0));

         Iterator vector_sum_col(_vector_sum_col, win_vector_sum_col);
         Iterator vector_sum_row(_vector_sum_row, win_vector_sum_row);
         Iterator mm_result(_mm_result, window);

+        const size_t sum_row_stride_y = _vector_sum_row->info()->strides_in_bytes().y();
+
         execute_window_loop(collapsed_window, [&](const Coordinates & id)
         {
             // Compute the leftover term due to a_offset.
@@ -217,7 +220,7 @@ void NEGEMMLowpOffsetContributionKernel::run(const Window &window, const ThreadI
             a_offset_term_s32.val[3] = vmulq_n_s32(a_offset_term_s32.val[3], _a_offset);

             // Compute the leftover term due to b_offset.
-            int32x4_t b_offset_term_s32 = vld1q_dup_s32(reinterpret_cast<const int32_t *>(vector_sum_row.ptr()) + id.y());
+            int32x4_t b_offset_term_s32 = vld1q_dup_s32(reinterpret_cast<const int32_t *>(vector_sum_row.ptr() + id.z() * sum_row_stride_y) + id.y());
             b_offset_term_s32 = vmulq_n_s32(b_offset_term_s32, _b_offset);

             // Add a_offset_term_s32 and b_offset_term_s32
@@ -266,14 +269,17 @@ void NEGEMMLowpOffsetContributionKernel::run(const Window &window, const ThreadI
         Window win_vector_sum_row(collapsed_window);
         win_vector_sum_row.set(Window::DimX, Window::Dimension(0, 0, 0));
         win_vector_sum_row.set(Window::DimY, Window::Dimension(0, 0, 0));
+        win_vector_sum_row.set(Window::DimZ, Window::Dimension(0, 0, 0));

         Iterator vector_sum_row(_vector_sum_row, win_vector_sum_row);
         Iterator mm_result(_mm_result, window);

+        const size_t sum_row_stride_y = _vector_sum_row->info()->strides_in_bytes().y();
+
         execute_window_loop(window, [&](const Coordinates & id)
         {
             // Compute the leftover term due to b_offset.
- int32x4_t b_offset_term_s32 = vld1q_dup_s32(reinterpret_cast(vector_sum_row.ptr()) + id.y()); + int32x4_t b_offset_term_s32 = vld1q_dup_s32(reinterpret_cast(vector_sum_row.ptr() + id.z() * sum_row_stride_y) + id.y()); b_offset_term_s32 = vmulq_n_s32(b_offset_term_s32, _b_offset); int32x4x4_t in_s32 = diff --git a/src/runtime/NEON/functions/NEFullyConnectedLayer.cpp b/src/runtime/NEON/functions/NEFullyConnectedLayer.cpp index 1aab3a05e0..9d3cb31c9a 100644 --- a/src/runtime/NEON/functions/NEFullyConnectedLayer.cpp +++ b/src/runtime/NEON/functions/NEFullyConnectedLayer.cpp @@ -27,6 +27,7 @@ #include "arm_compute/core/Size2D.h" #include "arm_compute/core/Validate.h" #include "arm_compute/core/utils/misc/ShapeCalculator.h" +#include "arm_compute/core/utils/quantization/AsymmHelpers.h" #include "arm_compute/runtime/NEON/NEScheduler.h" #include @@ -35,121 +36,107 @@ using namespace arm_compute; using namespace arm_compute::misc::shape_calculator; -NEFullyConnectedLayerReshapeWeights::NEFullyConnectedLayerReshapeWeights(std::shared_ptr memory_manager) - : _memory_group(std::move(memory_manager)), _transpose_kernel(), _transpose1xW_kernel(), _transpose_output(), _transpose_weights(false), _is_batched_fc_layer(false) +namespace { -} - -void NEFullyConnectedLayerReshapeWeights::configure(const ITensor *input, ITensor *output, bool transpose_weights, bool is_batched_fc_layer) +Status validate_mm(const ITensorInfo &input, const ITensorInfo &weights, const ITensorInfo &output) { - ARM_COMPUTE_ERROR_ON_NULLPTR(input, output); - - // Perform validate step - ARM_COMPUTE_ERROR_THROW_ON(NEFullyConnectedLayerReshapeWeights::validate(input->info(), output->info(), transpose_weights, is_batched_fc_layer)); - - _transpose_weights = transpose_weights; - _is_batched_fc_layer = is_batched_fc_layer; - - // Check if we need to transpose the weights - if(_transpose_weights) + if(is_data_type_quantized_asymmetric(input.data_type())) { - if(_is_batched_fc_layer) - { - // Initialize the output tensor for transpose - _transpose_output.allocator()->init(input->info()->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(compute_transposed_shape(*input->info()))); - _memory_group.manage(&_transpose_output); - _transpose_kernel.configure(input, &_transpose_output); - - // Configure transpose 1xW kernel - _transpose1xW_kernel.configure(&_transpose_output, output); - - // Allocate temporary tensor used for transposing the weights - _transpose_output.allocator()->allocate(); - } - else - { - _transpose_kernel.configure(input, output); - } + // Since we need negative offsets for computing convolution, we need to change QuantizationInfo() + // Extract and negate input and weights offset + const QuantizationInfo input_quantization_info(input.quantization_info().scale, -input.quantization_info().offset); + const QuantizationInfo weights_quantization_info(weights.quantization_info().scale, -weights.quantization_info().offset); + + // Validate gemmlowp function + ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMLowpMatrixMultiplyCore::validate(&input.clone()->set_quantization_info(input_quantization_info), + &weights.clone()->set_quantization_info(weights_quantization_info), + &output)); } else { - if(_is_batched_fc_layer) - { - // Configure transpose 1xW kernel - _transpose1xW_kernel.configure(input, output); - } + ARM_COMPUTE_RETURN_ON_ERROR(NEGEMM::validate(&input, &weights, nullptr, &output, 1.f, 0.0f, GEMMInfo(false, false, true /* Reshape weights only for the first run */))); } + + return Status{}; } +} // namespace -Status 
NEFullyConnectedLayerReshapeWeights::validate(const ITensorInfo *input, const ITensorInfo *output, bool transpose_weights, bool is_batched_fc_layer) +void NEFullyConnectedLayerReshapeWeights::configure(const ITensor *input, ITensor *output) { - ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F16, DataType::F32); - ARM_COMPUTE_RETURN_ERROR_ON(input->num_dimensions() > 2); - ARM_COMPUTE_RETURN_ERROR_ON_MSG(!transpose_weights && !is_batched_fc_layer, "Configuration transpose_weights=false & is_batched_fc_layer=false not supported"); + auto k = arm_compute::support::cpp14::make_unique(); + k->configure(input, output); + _kernel = std::move(k); +} + +Status NEFullyConnectedLayerReshapeWeights::validate(const ITensorInfo *input, const ITensorInfo *output) +{ + return NETransposeKernel::validate(input, output); +} + +NEFullyConnectedLayer::NEFullyConnectedLayer(std::shared_ptr memory_manager) + : _memory_group(std::move(memory_manager)), _im2col_kernel(), _reshape_weights_function(), _mm_gemm(), _mm_gemmlowp(), _gemmlowp_output_stage(), _accumulate_biases_kernel(), _im2col_output(), + _gemmlowp_output(), _reshape_weights_output(), _original_weights(nullptr), _are_weights_reshaped(false), _is_fc_after_conv(false), _accumulate_biases(false), _is_quantized(false), _is_prepared(false) +{ +} - if(transpose_weights) +void NEFullyConnectedLayer::configure_mm(const ITensor *input, const ITensor *weights, ITensor *output) +{ + if(_is_quantized) { - if(is_batched_fc_layer) - { - std::unique_ptr use_output = output->clone(); - use_output->set_is_resizable(true).reset_padding().set_tensor_shape(compute_transposed_shape(*input)); + // Since we need negative offsets for computing convolution, we need to change QuantizationInfo() + // Extract and negate input and weights offset + const QuantizationInfo input_quantization_info = input->info()->quantization_info(); + const QuantizationInfo weights_quantization_info = weights->info()->quantization_info(); - ARM_COMPUTE_RETURN_ON_ERROR(NETransposeKernel::validate(input, use_output.get())); - ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMTranspose1xWKernel::validate(use_output.get(), output)); - } - else - { - ARM_COMPUTE_RETURN_ON_ERROR(NETransposeKernel::validate(input, output)); - } + input->info()->set_quantization_info(QuantizationInfo(input_quantization_info.scale, -input_quantization_info.offset)); + weights->info()->set_quantization_info(QuantizationInfo(weights_quantization_info.scale, -weights_quantization_info.offset)); + + // Configure gemmlowp function + _mm_gemmlowp.configure(input, weights, output); + + // Revert back QuantizatioInfo as input and weights could be used in other fully connected layers + input->info()->set_quantization_info(input_quantization_info); + weights->info()->set_quantization_info(weights_quantization_info); } else { - if(is_batched_fc_layer) - { - ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMTranspose1xWKernel::validate(input, output)); - } + // Configure matrix multiply kernel + _mm_gemm.configure(input, weights, nullptr, output, 1.f, 0.0f, GEMMInfo(false, false, true /* Reshape weights only for the first run */)); } - - return Status{}; } -void NEFullyConnectedLayerReshapeWeights::run() +void NEFullyConnectedLayer::configure_conv_fc(const ITensor *input, const ITensor *weights, ITensor *output) { - _memory_group.acquire(); + ARM_COMPUTE_ERROR_ON((weights->info()->dimension(1) != (input->info()->dimension(0) * input->info()->dimension(1) * input->info()->dimension(2)))); - if(_transpose_weights) - { - 
NEScheduler::get().schedule(&_transpose_kernel, Window::DimY); - } + // If the fully connected layer is called after a convolution layer, the input tensor must be linearized - if(_is_batched_fc_layer) - { - NEScheduler::get().schedule(&_transpose1xW_kernel, Window::DimY); - } + // Initialize output tensor for im2col + TensorShape shape_im2col = compute_im2col_fc_shape(input->info()); + _im2col_output.allocator()->init(input->info()->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(shape_im2col)); - _memory_group.release(); + // Configure im2col kernel + _memory_group.manage(&_im2col_output); + _im2col_kernel.configure(input, &_im2col_output, Size2D(1, 1), PadStrideInfo(1, 1, 0, 0), false, true); + + // Configure matrix multiply kernel + configure_mm(&_im2col_output, weights, output); + + // Allocate the output tensor for im2col once all the configure methods have been called + _im2col_output.allocator()->allocate(); } -NEFullyConnectedLayer::NEFullyConnectedLayer(std::shared_ptr memory_manager) - : _memory_group(std::move(memory_manager)), _im2col_kernel(), _reshape_weights_function(), _interleave4x4_kernel(), _mm_kernel(), _accumulate_biases_kernel(), _im2col_output(), - _interleave4x4_output(), _reshape_weights_output(), _original_weights(nullptr), _is_batched_fc_layer(false), _linearize_input(false), _accumulate_biases(false), _is_prepared(false) +void NEFullyConnectedLayer::configure_fc_fc(const ITensor *input, const ITensor *weights, ITensor *output) { + ARM_COMPUTE_ERROR_ON(input->info()->dimension(0) != weights->info()->dimension(1)); + + // Configure matrix multiply kernel + configure_mm(input, weights, output); } void NEFullyConnectedLayer::configure(const ITensor *input, const ITensor *weights, const ITensor *biases, ITensor *output, FullyConnectedLayerInfo fc_info) { - // With the Fully Connected layer we can have 4 different cases: - // 1) Convolution layer -> Fully Connected layer without batches - // 2) Fully Connected layer -> Fully Connected layer without batches - // 3) Convolution layer -> Fully Connected layer with batches - // 4) Fully Connected layer -> Fully Connected layer with batches - - // Expected shape before transpose and reshaping - // Input: In x B (In and B can be multi-dimensional) - // Weights: flat(In) x Out - // Biases: Out - // Output: Out x B (B can be multi-dimensional) ARM_COMPUTE_ERROR_ON_NULLPTR(input, weights, output); // Perform validate step @@ -159,155 +146,158 @@ void NEFullyConnectedLayer::configure(const ITensor *input, const ITensor *weigh output->info(), fc_info)); - const int num_batch_dimensions = std::max(0, static_cast(output->info()->tensor_shape().num_dimensions()) - 1); - const int num_input_dimensions = input->info()->tensor_shape().num_dimensions() - num_batch_dimensions; - const size_t linear_input_size = input->info()->tensor_shape().total_size_lower(num_input_dimensions); - - _original_weights = weights; - _linearize_input = (input->info()->tensor_shape().x() != linear_input_size) || (num_input_dimensions > 1 && linear_input_size == 1); - _accumulate_biases = biases != nullptr; - _is_batched_fc_layer = num_batch_dimensions > 0; - _is_prepared = fc_info.are_weights_reshaped || (!fc_info.transpose_weights && !_is_batched_fc_layer); + _are_weights_reshaped = fc_info.transpose_weights ? 
fc_info.are_weights_reshaped : true; + _is_fc_after_conv = true; + _accumulate_biases = false; + _is_quantized = is_data_type_quantized_asymmetric(input->info()->data_type()); + _original_weights = weights; - const size_t interleave_width = 16 / input->info()->element_size(); - const ITensor *weights_to_use = weights; - - if(!_is_prepared) + // Configure gemmlowp output + if(_is_quantized) { - weights_to_use = &_reshape_weights_output; - - _reshape_weights_output.allocator()->init(input->info()->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(compute_fully_connected_reshaped_weights_shape(weights->info(), - fc_info.transpose_weights, - _is_batched_fc_layer, interleave_width))); - - // Reshape the weights - _reshape_weights_function.configure(weights, &_reshape_weights_output, fc_info.transpose_weights, _is_batched_fc_layer); + _gemmlowp_output.allocator()->init(output->info()->clone()->set_is_resizable(true).reset_padding().set_data_type(DataType::S32)); } - const ITensor *multiply_input = input; - - if(_linearize_input) + // Configure accumulate biases kernel for non quantized asymmetric types + if(biases != nullptr && !_is_quantized) { - _im2col_output.allocator()->init(input->info()->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(compute_im2col_fc_shape(input->info(), num_input_dimensions))); - - // Configure im2col kernel - _memory_group.manage(&_im2col_output); - _im2col_kernel.configure(input, &_im2col_output, Size2D(1, 1), PadStrideInfo(1, 1, 0, 0), false, true); + _accumulate_biases = true; - multiply_input = &_im2col_output; + // Configure accumulate biases kernel + _accumulate_biases_kernel.configure(output, biases); } - int m = multiply_input->info()->dimension(1); - int k = multiply_input->info()->dimension(0); + // With the Fully Connected layer we can have 4 different cases: + // 1) Convolution layer -> Fully Connected layer without batches + // 2) Fully Connected layer -> Fully Connected layer without batches + // 3) Convolution layer -> Fully Connected layer with batches + // 4) Fully Connected layer -> Fully Connected layer with batches + + const ITensor *weights_to_use = weights; - if(_is_batched_fc_layer) + if(!_are_weights_reshaped) { - _interleave4x4_output.allocator()->init(input->info()->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(compute_interleaved_shape(*multiply_input->info()))); - - // Configure interleave4x4 kernel - _memory_group.manage(&_interleave4x4_output); - _interleave4x4_kernel.configure(multiply_input, &_interleave4x4_output); + weights_to_use = &_reshape_weights_output; - multiply_input = &_interleave4x4_output; + // Reshape the weights + _reshape_weights_function.configure(weights, &_reshape_weights_output); } - // Configure matrix multiply kernel - _mm_kernel.configure(multiply_input, weights_to_use, output, 1.0f, _is_batched_fc_layer, GEMMReshapeInfo(m, 0 /* no transpose */, k)); + // Check if we have a fully connected layer with batches + const bool is_batched_fc_layer = output->info()->dimension(1) > 1; - if(_accumulate_biases) + if(is_batched_fc_layer) { - // Configure accumulate biases kernel - _accumulate_biases_kernel.configure(output, biases); + _is_fc_after_conv = (TensorShape::num_max_dimensions >= 4) && (std::equal(input->info()->tensor_shape().cbegin() + 3, + input->info()->tensor_shape().cend(), + output->info()->tensor_shape().cbegin() + 1)); + } + else + { + _is_fc_after_conv = input->info()->num_dimensions() > 1; } - if(_linearize_input) + ITensor *tmp_output = (_is_quantized) ? 
&_gemmlowp_output : output; + if(_is_fc_after_conv) + { + // Fully Connected layer after a Convolution Layer without batches + configure_conv_fc(input, weights_to_use, tmp_output); + } + else { - _im2col_output.allocator()->allocate(); + // Fully Connected layer after a Fully Connected Layer without batches + configure_fc_fc(input, weights_to_use, tmp_output); } - if(_is_batched_fc_layer) + // Configure output stage for asymmetric quantized types + if(_is_quantized) { - _interleave4x4_output.allocator()->allocate(); + float multiplier = input->info()->quantization_info().scale * weights->info()->quantization_info().scale / output->info()->quantization_info().scale; + int output_multiplier, output_shift; + quantization::calculate_quantized_multiplier_less_than_one(multiplier, &output_multiplier, &output_shift); + _gemmlowp_output_stage.configure(&_gemmlowp_output, biases, output, output_multiplier, output_shift, output->info()->quantization_info().offset); + _gemmlowp_output.allocator()->allocate(); } + + _are_weights_reshaped = _are_weights_reshaped || fc_info.retain_internal_weights; } Status NEFullyConnectedLayer::validate(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output, FullyConnectedLayerInfo fc_info) { - ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F16, DataType::F32); + ARM_COMPUTE_UNUSED(fc_info.retain_internal_weights); + ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, weights, output); + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::F16, DataType::F32); ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, weights, output); - - const int num_batch_dimensions = std::max(0, static_cast(output->tensor_shape().num_dimensions()) - 1); - const int num_input_dimensions = input->tensor_shape().num_dimensions() - num_batch_dimensions; - const size_t linear_input_size = input->tensor_shape().total_size_lower(num_input_dimensions); - - const bool linearize_input = (input->tensor_shape().x() != linear_input_size) || (num_input_dimensions > 1 && linear_input_size == 1); - const bool accumulate_biases = biases != nullptr; - const bool is_batched_fc_layer = num_batch_dimensions > 0; - - ARM_COMPUTE_RETURN_ERROR_ON(input->tensor_shape().total_size_upper(num_input_dimensions) != output->tensor_shape().total_size_upper(1)); ARM_COMPUTE_RETURN_ERROR_ON(weights->num_dimensions() > 2); - const size_t interleave_width = 16 / input->element_size(); - const ITensorInfo *weights_to_use = weights; - std::unique_ptr reshape_weights_output = input->clone(); + bool weights_reshaped = fc_info.transpose_weights ? 
fc_info.are_weights_reshaped : true; + bool is_fc_after_conv = true; + bool is_quantized = is_data_type_quantized_asymmetric(input->data_type()); - if(!fc_info.are_weights_reshaped && (fc_info.transpose_weights || is_batched_fc_layer)) + const ITensorInfo &im2col_input = TensorInfo(input->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(compute_im2col_fc_shape(input))); + const ITensorInfo &reshaped_weights = TensorInfo(weights->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(compute_transposed_shape(*weights))); + const ITensorInfo &gemmlowp_output = TensorInfo(output->clone()->set_is_resizable(true).reset_padding().set_data_type(DataType::S32)); + + // Configure accumulate biases kernel for non quantized asymmetric types + if(biases != nullptr && !is_quantized) { - reshape_weights_output->set_tensor_shape(compute_fully_connected_reshaped_weights_shape(weights, fc_info.transpose_weights, is_batched_fc_layer, interleave_width)); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, biases); + ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMMatrixAccumulateBiasesKernel::validate(output, biases)); + } + + // With the Fully Connected layer we can have 4 different cases: + // 1) Convolution layer -> Fully Connected layer without batches + // 2) Fully Connected layer -> Fully Connected layer without batches + // 3) Convolution layer -> Fully Connected layer with batches + // 4) Fully Connected layer -> Fully Connected layer with batches - ARM_COMPUTE_RETURN_ON_ERROR(NEFullyConnectedLayerReshapeWeights::validate(weights, reshape_weights_output.get(), fc_info.transpose_weights, is_batched_fc_layer)); + const ITensorInfo *input_to_use = input; + const ITensorInfo *weights_to_use = weights; + const ITensorInfo *tmp_output = (is_quantized) ? 
&gemmlowp_output : output; - weights_to_use = reshape_weights_output.get(); + if(!weights_reshaped) + { + // Validate reshape weights kernel + ARM_COMPUTE_RETURN_ON_ERROR(NEFullyConnectedLayerReshapeWeights::validate(weights, &reshaped_weights)); + weights_to_use = &reshaped_weights; } - // Check correct shape of weights + // Check if we have a fully connected layer with batches + const bool is_batched_fc_layer = output->dimension(1) > 1; + if(is_batched_fc_layer) { - // Transpose + Transpose1xW - ARM_COMPUTE_RETURN_ERROR_ON(weights_to_use->tensor_shape().x() != linear_input_size * interleave_width); - ARM_COMPUTE_RETURN_ERROR_ON(weights_to_use->tensor_shape().y() != static_cast(std::ceil(static_cast(output->tensor_shape().x()) / interleave_width))); + is_fc_after_conv = (TensorShape::num_max_dimensions >= 4) && (std::equal(input->tensor_shape().cbegin() + 3, + input->tensor_shape().cend(), + output->tensor_shape().cbegin() + 1)); } else { - // Transpose - ARM_COMPUTE_RETURN_ERROR_ON(weights_to_use->tensor_shape().x() != output->tensor_shape().x()); - ARM_COMPUTE_RETURN_ERROR_ON(weights_to_use->tensor_shape().y() != linear_input_size); + is_fc_after_conv = input->num_dimensions() > 1; } - const ITensorInfo *multiply_input = input; - std::unique_ptr im2col_output = input->clone(); - std::unique_ptr interleave4x4_output = input->clone(); - - if(linearize_input) + if(is_fc_after_conv) { - im2col_output->set_tensor_shape(compute_im2col_fc_shape(input, num_input_dimensions)); + // Fully Connected layer after a Convolution Layer without batches + ARM_COMPUTE_RETURN_ERROR_ON((weights_to_use->dimension(1) != (input->dimension(0) * input->dimension(1) * input->dimension(2)))); - ARM_COMPUTE_RETURN_ON_ERROR(NEIm2ColKernel::validate(input, im2col_output.get(), Size2D(1, 1), PadStrideInfo(1, 1, 0, 0), false, true)); - - multiply_input = im2col_output.get(); + // Validate im2col kernel + ARM_COMPUTE_RETURN_ON_ERROR(NEIm2ColKernel::validate(input, &im2col_input, Size2D(1, 1), PadStrideInfo(1, 1, 0, 0), false, true)); + input_to_use = &im2col_input; } - - int m = multiply_input->dimension(1); - int k = multiply_input->dimension(0); - - if(is_batched_fc_layer) + else { - interleave4x4_output->set_tensor_shape(compute_interleaved_shape(*multiply_input)); - - ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMInterleave4x4Kernel::validate(multiply_input, interleave4x4_output.get())); - - multiply_input = interleave4x4_output.get(); + // Fully Connected layer after a Fully Connected Layer without batches + ARM_COMPUTE_RETURN_ERROR_ON(input->dimension(0) != weights_to_use->dimension(1)); } + // Validate matrix multiply kernel + ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(*input_to_use, *weights_to_use, *tmp_output)); - ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMMatrixMultiplyKernel::validate(multiply_input, weights_to_use, output, 1.0f, is_batched_fc_layer, GEMMReshapeInfo(m, 0 /* no transpose */, k))); - - if(accumulate_biases) + // Validate output stage for asymmetric quantized types + if(is_quantized) { - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, biases); - ARM_COMPUTE_RETURN_ERROR_ON(biases->tensor_shape().x() != output->tensor_shape().x()); - - ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMMatrixAccumulateBiasesKernel::validate(output, biases)); + ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPoint::validate(&gemmlowp_output, biases, output)); } return Status{}; @@ -320,24 +310,32 @@ void NEFullyConnectedLayer::run() _memory_group.acquire(); // Linearize input if it comes from a convolutional layer - 
if(_linearize_input) + if(_is_fc_after_conv) { NEScheduler::get().schedule(&_im2col_kernel, Window::DimY); } - // Interleave input - if(_is_batched_fc_layer) + // Run matrix multiply + if(_is_quantized) { - NEScheduler::get().schedule(&_interleave4x4_kernel, Window::DimY); + _mm_gemmlowp.run(); + } + else + { + _mm_gemm.run(); } - - // Run matrix multiply - NEScheduler::get().schedule(&_mm_kernel, _is_batched_fc_layer ? Window::DimY : Window::DimX); // Accumulate biases if provided - if(_accumulate_biases) + if(_is_quantized) { - NEScheduler::get().schedule(&_accumulate_biases_kernel, Window::DimY); + _gemmlowp_output_stage.run(); + } + else + { + if(_accumulate_biases) + { + NEScheduler::get().schedule(&_accumulate_biases_kernel, Window::DimY); + } } _memory_group.release(); @@ -345,16 +343,30 @@ void NEFullyConnectedLayer::run() void NEFullyConnectedLayer::prepare() { - // Reshape of the weights (happens only once) if(!_is_prepared) { - ARM_COMPUTE_ERROR_ON(!_original_weights->is_used()); - - // Run weights reshape, clean internal tensors and mark original weights tensor as unused - _reshape_weights_output.allocator()->allocate(); - _reshape_weights_function.run(); - _reshape_weights_function = NEFullyConnectedLayerReshapeWeights(); - _original_weights->mark_as_unused(); + // Reshape of the weights (happens only once) + if(!_are_weights_reshaped) + { + ARM_COMPUTE_ERROR_ON(!_original_weights->is_used()); + + // Run reshape weights kernel and mark weights as unused + _reshape_weights_output.allocator()->allocate(); + _reshape_weights_function.run(); + _original_weights->mark_as_unused(); + + // Prepare GEMM prepare and release unused weights + if(!_is_quantized) + { + _mm_gemm.prepare(); + if(!_reshape_weights_output.is_used()) + { + _reshape_weights_output.allocator()->free(); + } + } + + _are_weights_reshaped = true; + } _is_prepared = true; } diff --git a/src/runtime/NEON/functions/NEGEMM.cpp b/src/runtime/NEON/functions/NEGEMM.cpp index c958904b93..e47ef86a1c 100644 --- a/src/runtime/NEON/functions/NEGEMM.cpp +++ b/src/runtime/NEON/functions/NEGEMM.cpp @@ -23,12 +23,14 @@ */ #include "arm_compute/runtime/NEON/functions/NEGEMM.h" +#include "arm_compute/core/CPP/Validate.h" #include "arm_compute/core/Error.h" #include "arm_compute/core/Helpers.h" #include "arm_compute/core/ITensor.h" #include "arm_compute/core/TensorInfo.h" #include "arm_compute/core/Types.h" #include "arm_compute/core/Validate.h" +#include "arm_compute/core/utils/misc/ShapeCalculator.h" #include "arm_compute/runtime/NEON/NEScheduler.h" #include "arm_compute/runtime/NEON/functions/NEGEMMAssemblyDispatch.h" #include "arm_compute/runtime/TensorAllocator.h" @@ -36,6 +38,8 @@ #include +using namespace arm_compute::misc::shape_calculator; + namespace arm_compute { NEGEMM::NEGEMM(std::shared_ptr memory_manager) @@ -46,21 +50,7 @@ NEGEMM::NEGEMM(std::shared_ptr memory_manager) void NEGEMM::configure(const ITensor *a, const ITensor *b, const ITensor *c, ITensor *d, float alpha, float beta, const GEMMInfo &gemm_info) { - ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(a, 1, DataType::F32, DataType::F16); - ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(a, b, d); - ARM_COMPUTE_ERROR_ON_MSG(a->info()->dimension(0) != b->info()->dimension(1), "The product AB is defined only if the number of columns in A is equal to the number of rows in B"); - ARM_COMPUTE_ERROR_ON_MSG(gemm_info.is_a_reshaped(), "Matrix A already reshaped is not supported"); - ARM_COMPUTE_ERROR_ON_MSG(gemm_info.is_b_reshaped(), "Matrix B already reshaped is not 
supported"); - - if(c != nullptr) - { - ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(c, 1, DataType::F32, DataType::F16); - ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(a, c); - ARM_COMPUTE_ERROR_ON_MSG(a->info()->dimension(1) != c->info()->dimension(1), "The C matrix must have the same number of rows as the matrix A"); - ARM_COMPUTE_ERROR_ON_MSG(b->info()->dimension(0) != c->info()->dimension(0), "The C matrix must have the same number of columns as the matrix B"); - ARM_COMPUTE_ERROR_ON_MSG(c->info()->dimension(0) != d->info()->dimension(0), "The C matrix must have the same number of rows as the output matrix"); - ARM_COMPUTE_ERROR_ON_MSG(c->info()->dimension(1) != d->info()->dimension(1), "The C matrix must have the same number of columns as the output matrix"); - } + ARM_COMPUTE_ERROR_THROW_ON(NEGEMM::validate(a->info(), b->info(), (c != nullptr) ? c->info() : nullptr, d->info(), alpha, beta, gemm_info)); // Check if we need to reshape the matrix B only on the first run _is_prepared = false; @@ -68,7 +58,7 @@ void NEGEMM::configure(const ITensor *a, const ITensor *b, const ITensor *c, ITe _run_vector_matrix_multiplication = a->info()->dimension(1) < 2; _original_b = b; - bool run_optimised = a->info()->data_type() == DataType::F32 && (c == nullptr || beta == 0.f); + bool run_optimised = c == nullptr && bool(NEGEMMAssemblyDispatch::validate(a->info(), b->info(), d->info(), alpha, beta, _reshape_b_only_on_first_run)); if(run_optimised) { _asm_glue.configure(a, b, d, alpha, beta, _reshape_b_only_on_first_run); @@ -149,6 +139,137 @@ void NEGEMM::configure(const ITensor *a, const ITensor *b, const ITensor *c, ITe } } +Status NEGEMM::validate(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info) +{ + ARM_COMPUTE_UNUSED(alpha); + + ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(a); + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(a, 1, DataType::F32, DataType::F16); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(a, b, output); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->dimension(0) != b->dimension(1), "The product AB is defined only if the number of columns in A is equal to the number of rows in B"); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(gemm_info.is_a_reshaped(), "Matrix A already reshaped is not supported"); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(gemm_info.is_b_reshaped(), "Matrix B already reshaped is not supported"); + + if(c != nullptr) + { + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(c, 1, DataType::F32, DataType::F16); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(a, c); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->dimension(1) != c->dimension(1), "The C matrix must have the same number of rows as the matrix A"); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(b->dimension(0) != c->dimension(0), "The C matrix must have the same number of columns as the matrix B"); + } + + if(output->total_size() != 0) + { + ARM_COMPUTE_RETURN_ERROR_ON(b->dimension(0) != output->dimension(0)); + ARM_COMPUTE_RETURN_ERROR_ON(a->dimension(1) != output->dimension(1)); + } + + // Check if the first input tensor is a vector. 
+ const bool run_vector_matrix_multiplication = a->dimension(1) < 2; + // Check if we need to reshape the matrix A and matrix B + const bool run_interleave_transpose = !run_vector_matrix_multiplication && !(gemm_info.reshape_b_only_on_first_run()); + // Check if we need to run the optimized assembly kernel + const bool run_optimised = c == nullptr && bool(NEGEMMAssemblyDispatch::validate(a, b, output, alpha, beta, true)); + + const ITensorInfo *matrix_a_info = a; + const ITensorInfo *matrix_b_info = b; + + TensorInfo tmp_a_info{}; + TensorInfo tmp_b_info{}; + TensorInfo tmp_output_info = *output->clone(); + + // Arguments used by GEMMReshapeInfo + // If we pass the matrix A and matrix B reshaped to NEGEMMMatrixMultiplyKernel, we need to pass m, n, k, mult_transpose1xW_width and mult_interleave4x4_height to NEGEMMReshapeInfo + // in order to know how the matrices have been reshaped + const int m = a->dimension(1); + const int n = b->dimension(0); + const int k = a->dimension(0); + int mult_transpose1xW_width = 1; + int mult_interleave4x4_height = 1; + + const GEMMReshapeInfo reshape_info = GEMMReshapeInfo(m, n, k, mult_transpose1xW_width, mult_interleave4x4_height, gemm_info.depth_output_gemm3d()); + + // Initialize shapes + if(run_interleave_transpose) + { + matrix_a_info = &tmp_a_info; + auto_init_if_empty(tmp_a_info, a->clone()->set_tensor_shape(compute_interleaved_shape(*a, mult_interleave4x4_height))); + + matrix_b_info = &tmp_b_info; + auto_init_if_empty(tmp_b_info, b->clone()->set_tensor_shape(compute_transpose1xW_with_element_size_shape(*b, mult_transpose1xW_width))); + + auto_init_if_empty(tmp_output_info, matrix_a_info->clone()->set_tensor_shape(compute_mm_shape(*matrix_a_info, *matrix_b_info, run_interleave_transpose, reshape_info))); + } + + // Validate kernels + if(run_optimised && run_interleave_transpose) + { + /* Interleave */ + TensorShape tensor_shape0{ matrix_a_info->tensor_shape() }; + tensor_shape0.set(0, k); + tensor_shape0.set(1, m); + + const TensorInfo tensor_info0 = matrix_a_info->clone()->set_tensor_shape(tensor_shape0); + const TensorInfo tensor_info_reshaped0 = matrix_a_info->clone()->set_tensor_shape(compute_interleaved_shape(tensor_info0, mult_interleave4x4_height)); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(matrix_a_info, &tensor_info_reshaped0); + + if(n != 0) /* Transpose */ + { + TensorShape tensor_shape1{ matrix_b_info->tensor_shape() }; + tensor_shape1.set(0, n); + tensor_shape1.set(1, k); + + const TensorInfo tensor_info1 = matrix_b_info->clone()->set_tensor_shape(tensor_shape1); + const TensorInfo tensor_info_reshaped1 = matrix_b_info->clone()->set_tensor_shape(compute_transpose1xW_with_element_size_shape(tensor_info1, mult_transpose1xW_width)); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(matrix_b_info, &tensor_info_reshaped1); + } + + if(output->total_size() != 0) + { + if(n != 0) + { + ARM_COMPUTE_RETURN_ERROR_ON(tmp_output_info.dimension(0) != static_cast(n)); + } + ARM_COMPUTE_RETURN_ERROR_ON(tmp_output_info.dimension(1) != static_cast(m)); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(matrix_a_info, &tmp_output_info); + } + } + else if(run_vector_matrix_multiplication) + { + ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMMatrixMultiplyKernel::validate(a, b, output, alpha, false, reshape_info)); + + if(beta != 0 && c != nullptr) + { + // Validate matrix addition kernel + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(c, output); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(c, output); + } + } + else + { + if(run_interleave_transpose) 
+ { + // Validate interleave kernel + ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMInterleave4x4Kernel::validate(a, matrix_a_info)); + + // Validate transpose kernel + ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMTranspose1xWKernel::validate(b, matrix_b_info)); + } + + // Validate matrix multiply + ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMMatrixMultiplyKernel::validate(matrix_a_info, matrix_b_info, &tmp_output_info, alpha, run_interleave_transpose, reshape_info)); + + if(beta != 0 && c != nullptr) + { + // Validate matrix addition kernel + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(c, &tmp_output_info); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(c, &tmp_output_info); + } + } + + return Status{}; +} + void NEGEMM::run() { prepare(); @@ -196,7 +317,6 @@ void NEGEMM::prepare() ARM_COMPUTE_ERROR_ON(!_original_b->is_used()); _asm_glue.prepare(); - _original_b->mark_as_unused(); } else if(_reshape_b_only_on_first_run && !_run_vector_matrix_multiplication && !_asm_glue.is_configured()) { -- cgit v1.2.1
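
For illustration only, a minimal usage sketch of the QASYMM8 fully connected path this patch wires up. The shapes, scales and offsets below are made-up example values (not taken from the library's tests), and the sketch assumes the public Tensor / TensorInfo / QuantizationInfo API used elsewhere in this patch.

#include "arm_compute/core/Types.h"
#include "arm_compute/runtime/NEON/functions/NEFullyConnectedLayer.h"
#include "arm_compute/runtime/Tensor.h"

using namespace arm_compute;

int main()
{
    // Hypothetical sizes: 128 input features, 64 outputs, no batch dimension.
    Tensor src, weights, bias, dst;
    src.allocator()->init(TensorInfo(TensorShape(128U), 1, DataType::QASYMM8, QuantizationInfo(0.5f, 10)));
    weights.allocator()->init(TensorInfo(TensorShape(128U, 64U), 1, DataType::QASYMM8, QuantizationInfo(0.25f, 5)));
    bias.allocator()->init(TensorInfo(TensorShape(64U), 1, DataType::S32)); // quantized path takes S32 bias
    dst.allocator()->init(TensorInfo(TensorShape(64U), 1, DataType::QASYMM8, QuantizationInfo(1.f, 0)));

    // Configure before allocating, as usual for NEON functions.
    NEFullyConnectedLayer fc;
    fc.configure(&src, &weights, &bias, &dst, FullyConnectedLayerInfo());

    src.allocator()->allocate();
    weights.allocator()->allocate();
    bias.allocator()->allocate();
    dst.allocator()->allocate();

    // ... fill src/weights/bias with quantized data ...
    fc.run(); // internally runs gemmlowp plus the fixed-point output stage
    return 0;
}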
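The quantized output stage folds multiplier = input_scale * weights_scale / output_scale into an integer multiplier and a right shift. A rough standalone sketch of that decomposition follows; it mirrors in spirit what quantization::calculate_quantized_multiplier_less_than_one() does, not its exact implementation.

#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstdio>

// Decompose a real multiplier in (0, 1) into a Q0.31 fixed-point multiplier and a right shift,
// so that multiplier ~= quantized_multiplier * 2^-31 * 2^-right_shift.
void decompose_multiplier(double multiplier, int32_t *quantized_multiplier, int *right_shift)
{
    assert(multiplier > 0.0 && multiplier < 1.0);
    int          exponent    = 0;
    const double significand = std::frexp(multiplier, &exponent); // significand in [0.5, 1)
    *right_shift             = -exponent;
    int64_t q = static_cast<int64_t>(std::round(significand * (1ll << 31)));
    if(q == (1ll << 31)) // rounding can push the significand up to 1.0; renormalize
    {
        q /= 2;
        --(*right_shift);
    }
    *quantized_multiplier = static_cast<int32_t>(q);
}

int main()
{
    int32_t q_mult = 0;
    int     shift  = 0;
    decompose_multiplier(0.5 * 0.25 / 1.0, &q_mult, &shift); // scales from the usage sketch above
    std::printf("multiplier=0.125 -> quantized_multiplier=%d, right_shift=%d\n", q_mult, shift);
    return 0;
}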