diff options
Diffstat (limited to 'src/cpu/kernels/CpuGemmLowpOffsetContributionOutputStageKernel.cpp')
-rw-r--r-- | src/cpu/kernels/CpuGemmLowpOffsetContributionOutputStageKernel.cpp | 51 |
1 files changed, 32 insertions, 19 deletions
diff --git a/src/cpu/kernels/CpuGemmLowpOffsetContributionOutputStageKernel.cpp b/src/cpu/kernels/CpuGemmLowpOffsetContributionOutputStageKernel.cpp index 89aa36486c..c69af55f57 100644 --- a/src/cpu/kernels/CpuGemmLowpOffsetContributionOutputStageKernel.cpp +++ b/src/cpu/kernels/CpuGemmLowpOffsetContributionOutputStageKernel.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2021 Arm Limited. + * Copyright (c) 2019-2021, 2023 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -475,9 +475,18 @@ inline void run_offset_contribution_output_stage_window_symm(const int32_t *vect template <typename T> void run_offset_contribution_output_stage(const Window &window, const ITensor *mm_result, const ITensor *vector_sum_col, const ITensor *vector_sum_row, const ITensor *bias, ITensor *output, - int32_t a_offset, int32_t b_offset, int32_t k_offset, bool slide_vector_sum_col, + int32_t a_offset, int32_t b_offset, int32_t k_offset, bool is_vector_sum_col_batched, GEMMLowpOutputStageInfo output_stage, bool is_gemm3d, bool is_bounded_relu, bool is_fixed_point) { + // Semantics of XYZW Explained for each tensor + // + // | Tensor | XYZW when is_gemm3d == false | XYZW when is_gemm3d == true | + // ------------------------------------------------------------------------------------------------------------------- + // | mm_result | x -> width, y -> height, z -> batch | x -> width, y -> height, z -> depth, w -> batch | + // | collapsed window | x -> width, y -> height, z -> batch | x -> width, y -> height, z -> depth * batch | + // | vector_sum_row | x -> height, y -> batch | x -> height * depth, y -> batch | + // | Vector_sum_col | x -> width, y -> batch | x -> width, y -> batch | + using ExactTagType = typename wrapper::traits::neon_bitvector_tag_t<T, wrapper::traits::BitWidth::W128>; using Typer = VectorTyper<T>; @@ -517,8 +526,8 @@ void run_offset_contribution_output_stage(const Window &window, const size_t sum_row_stride_y = vector_sum_row->info()->strides_in_bytes().y(); - // Offset in case vector_sum_col is batched - const int vector_sum_col_batch_offset = slide_vector_sum_col ? vector_sum_col->info()->strides_in_bytes().z() : 0; + // Offset in case vector_sum_col is batched in y dimension + const int vector_sum_col_stride_batch = is_vector_sum_col_batched ? vector_sum_col->info()->strides_in_bytes().y() : 0; if(bias != nullptr) { @@ -526,7 +535,7 @@ void run_offset_contribution_output_stage(const Window &window, execute_window_loop(collapsed_window, [&](const Coordinates & id) { const int batch_id = id.z() / depth_input; - const auto vector_sum_col_ptr = reinterpret_cast<const int32_t *>(vector_sum_col_it.ptr() + batch_id * vector_sum_col_batch_offset); + const auto vector_sum_col_ptr = reinterpret_cast<const int32_t *>(vector_sum_col_it.ptr() + batch_id * vector_sum_col_stride_batch); const auto vector_sum_row_ptr = reinterpret_cast<const int32_t *>(vector_sum_row_it.ptr() + batch_id * sum_row_stride_y) + id.y() + (id.z() % depth_input) * height_input; run_offset_contribution_output_stage_window<Typer>(vector_sum_col_ptr, vector_sum_row_ptr, reinterpret_cast<const int32_t *>(bias_it.ptr()), @@ -544,7 +553,7 @@ void run_offset_contribution_output_stage(const Window &window, execute_window_loop(collapsed_window, [&](const Coordinates & id) { const int batch_id = id.z() / depth_input; - const auto vector_sum_col_ptr = reinterpret_cast<const int32_t *>(vector_sum_col_it.ptr() + batch_id * vector_sum_col_batch_offset); + const auto vector_sum_col_ptr = reinterpret_cast<const int32_t *>(vector_sum_col_it.ptr() + batch_id * vector_sum_col_stride_batch); const auto vector_sum_row_ptr = reinterpret_cast<const int32_t *>(vector_sum_row_it.ptr() + batch_id * sum_row_stride_y) + id.y() + (id.z() % depth_input) * height_input; run_offset_contribution_output_stage_window<Typer>(vector_sum_col_ptr, vector_sum_row_ptr, nullptr, mm_result_it, out_it, @@ -603,8 +612,8 @@ void run_offset_contribution_output_stage(const Window &window, Iterator vector_sum_col_it = get_vector_sum_col_it(collapsed_window, vector_sum_col); - // Offset in case vector_sum_col is batched - const int vector_sum_col_batch_offset = slide_vector_sum_col ? vector_sum_col->info()->strides_in_bytes().z() : 0; + // Offset in case vector_sum_col is batched in y dimension + const int vector_sum_col_stride_batch = is_vector_sum_col_batched ? vector_sum_col->info()->strides_in_bytes().y() : 0; if(bias != nullptr) { @@ -612,7 +621,7 @@ void run_offset_contribution_output_stage(const Window &window, execute_window_loop(collapsed_window, [&](const Coordinates & id) { const int batch_id = id.z() / depth_input; - const auto vector_sum_col_ptr = reinterpret_cast<const int32_t *>(vector_sum_col_it.ptr() + batch_id * vector_sum_col_batch_offset); + const auto vector_sum_col_ptr = reinterpret_cast<const int32_t *>(vector_sum_col_it.ptr() + batch_id * vector_sum_col_stride_batch); run_offset_contribution_output_stage_window<Typer>(vector_sum_col_ptr, nullptr, reinterpret_cast<const int32_t *>(bias_it.ptr()), mm_result_it, out_it, result_offset_s32, result_shift_s32, @@ -627,7 +636,7 @@ void run_offset_contribution_output_stage(const Window &window, execute_window_loop(collapsed_window, [&](const Coordinates & id) { const int batch_id = id.z() / depth_input; - const auto vector_sum_col_ptr = reinterpret_cast<const int32_t *>(vector_sum_col_it.ptr() + batch_id * vector_sum_col_batch_offset); + const auto vector_sum_col_ptr = reinterpret_cast<const int32_t *>(vector_sum_col_it.ptr() + batch_id * vector_sum_col_stride_batch); run_offset_contribution_output_stage_window<Typer>(vector_sum_col_ptr, nullptr, nullptr, mm_result_it, out_it, result_offset_s32, result_shift_s32, min_vec, max_vec, a_offset, b_offset, k_offset, @@ -670,7 +679,7 @@ void run_offset_contribution_output_stage(const Window &window, void run_offset_contribution_output_stage_symm(const Window &window, const ITensor *mm_result, const ITensor *vector_sum_col, const ITensor *vector_sum_row, const ITensor *bias, ITensor *output, - int32_t a_offset, int32_t b_offset, int32_t k_offset, bool slide_vector_sum_col, + int32_t a_offset, int32_t b_offset, int32_t k_offset, bool is_vector_sum_col_batched, GEMMLowpOutputStageInfo output_stage, bool is_gemm3d, bool is_bounded_relu, bool is_fixed_point) { ARM_COMPUTE_UNUSED(vector_sum_row, b_offset, k_offset); @@ -705,8 +714,8 @@ void run_offset_contribution_output_stage_symm(const Window &window, Iterator vector_sum_col_it = get_vector_sum_col_it(collapsed_window, vector_sum_col); - // Offset in case vector_sum_col is batched - const int vector_sum_col_batch_offset = slide_vector_sum_col ? vector_sum_col->info()->strides_in_bytes().z() : 0; + // Offset in case vector_sum_col is batched in y dimension + const int vector_sum_col_stride_batch = is_vector_sum_col_batched ? vector_sum_col->info()->strides_in_bytes().y() : 0; if(bias != nullptr) { @@ -714,7 +723,7 @@ void run_offset_contribution_output_stage_symm(const Window &window, execute_window_loop(collapsed_window, [&](const Coordinates & id) { const int batch_id = id.z() / depth_input; - const auto vector_sum_col_ptr = reinterpret_cast<const int32_t *>(vector_sum_col_it.ptr() + batch_id * vector_sum_col_batch_offset); + const auto vector_sum_col_ptr = reinterpret_cast<const int32_t *>(vector_sum_col_it.ptr() + batch_id * vector_sum_col_stride_batch); run_offset_contribution_output_stage_window_symm(vector_sum_col_ptr, reinterpret_cast<const int32_t *>(bias_it.ptr()), mm_result_it, out_it, result_multipliers, result_shifts, result_offset_s32, min_s8, max_s8, @@ -728,7 +737,7 @@ void run_offset_contribution_output_stage_symm(const Window &window, execute_window_loop(collapsed_window, [&](const Coordinates & id) { const int batch_id = id.z() / depth_input; - const auto vector_sum_col_ptr = reinterpret_cast<const int32_t *>(vector_sum_col_it.ptr() + batch_id * vector_sum_col_batch_offset); + const auto vector_sum_col_ptr = reinterpret_cast<const int32_t *>(vector_sum_col_it.ptr() + batch_id * vector_sum_col_stride_batch); run_offset_contribution_output_stage_window_symm(vector_sum_col_ptr, nullptr, mm_result_it, out_it, result_multipliers, result_shifts, result_offset_s32, min_s8, max_s8, @@ -792,6 +801,7 @@ Status validate_arguments(const ITensorInfo *mm_result, const ITensorInfo *vecto { ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(vector_sum_col, 1, DataType::S32); ARM_COMPUTE_RETURN_ERROR_ON(vector_sum_col->dimension(0) != mm_result->dimension(0)); + ARM_COMPUTE_RETURN_ERROR_ON(vector_sum_col->num_dimensions() > 2); } // If b_offset == 0, vector_sum_row can be a nullptr @@ -827,6 +837,9 @@ Status validate_arguments(const ITensorInfo *mm_result, const ITensorInfo *vecto "vector_sum_col tensor must have the same number of batches of vector_sum_row_shape or the number of batches must be set to 1"); } } + + // Check Tensor Rank of vector_sum_row + ARM_COMPUTE_RETURN_ERROR_ON(vector_sum_row->num_dimensions() > 2); } if(output->total_size() != 0) @@ -860,7 +873,7 @@ void CpuGemmLowpOffsetContributionOutputStageKernel::configure(const ITensorInfo // Check if vector_sum_col_shape should be slidden or not // Don't slide vector_sum_col_shape along the y dimension if vector_sum_col_shape has just 1 dimension and vector_sum_row_shape more than 1 // This scenario can happen when the the matrix multiplication is used to perform a convolution operation - _slide_vector_sum_col = vector_sum_col->tensor_shape().num_dimensions() > 1; + _is_vector_sum_col_batched = vector_sum_col->tensor_shape().num_dimensions() > 1; } // Output auto inizialitation if not yet initialized @@ -919,19 +932,19 @@ void CpuGemmLowpOffsetContributionOutputStageKernel::run_op(ITensorPack &tensors if(is_symm) { - run_offset_contribution_output_stage_symm(window, mm_result, vector_sum_col, vector_sum_row, bias, dst, _a_offset, _b_offset, _k_offset, _slide_vector_sum_col, _output_stage, + run_offset_contribution_output_stage_symm(window, mm_result, vector_sum_col, vector_sum_row, bias, dst, _a_offset, _b_offset, _k_offset, _is_vector_sum_col_batched, _output_stage, reinterpret_as_3d, is_bounded_relu, is_fixed_point); } else { if(is_signed) { - run_offset_contribution_output_stage<int8_t>(window, mm_result, vector_sum_col, vector_sum_row, bias, dst, _a_offset, _b_offset, _k_offset, _slide_vector_sum_col, _output_stage, + run_offset_contribution_output_stage<int8_t>(window, mm_result, vector_sum_col, vector_sum_row, bias, dst, _a_offset, _b_offset, _k_offset, _is_vector_sum_col_batched, _output_stage, reinterpret_as_3d, is_bounded_relu, is_fixed_point); } else { - run_offset_contribution_output_stage<uint8_t>(window, mm_result, vector_sum_col, vector_sum_row, bias, dst, _a_offset, _b_offset, _k_offset, _slide_vector_sum_col, _output_stage, + run_offset_contribution_output_stage<uint8_t>(window, mm_result, vector_sum_col, vector_sum_row, bias, dst, _a_offset, _b_offset, _k_offset, _is_vector_sum_col_batched, _output_stage, reinterpret_as_3d, is_bounded_relu, is_fixed_point); } } |