author     Gunes Bayir <gunes.bayir@arm.com>    2024-07-02 15:45:01 +0100
committer  Gunes Bayir <gunes.bayir@arm.com>    2024-07-02 16:00:11 +0000
commit     a3f238a44d9f306c77be0177f13d22ae3f3bcc57 (patch)
tree       44bf40fb59fb8c4452d65d25e3a967c035bc6863 /src/cpu
parent     f92b0fffa0d32dc08340c1abfa1a7f09c6e53795 (diff)
download   ComputeLibrary-a3f238a44d9f306c77be0177f13d22ae3f3bcc57.tar.gz
Revert "Update CPU kernels and add mixed sign GEMM support"
This reverts commits fc94f4d23abd4bc427b701f54ad85282e9ec7872 and 5d6fff041ade7eb44af0945867212f3979be3d3e; both are reverted together because the latter fixes a build failure caused by the former.
Change-Id: I7d07fea8307e9a7033b30874bbb14ba9202b23d8
Signed-off-by: Gunes Bayir <gunes.bayir@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/11815
Benchmark: Arm Jenkins <bsgcomp@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Adnan AlSinan <adnan.alsinan@arm.com>
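
Note on the signedness-flip path that this revert leaves in place: when _flip_signedness is set, the QASYMM8 input is converted to QASYMM8_SIGNED (via CpuConvertQuantizedSignednessKernel, with an offset correction of 128) so that matched-sign kernels can run. The following is a minimal standalone sketch of the underlying arithmetic, using the common affine convention real = scale * (q - zero_point); the names and the exact sign bookkeeping are illustrative, not ACL's API:

    // Shifting every uint8_t value and the zero point by 128 represents exactly
    // the same real numbers in int8_t, so matched-sign kernels can be used.
    #include <cassert>
    #include <cstdint>

    struct QuantInfo
    {
        float        scale;
        std::int32_t zero_point;
    };

    // Flip one value: q_s8 = q_u8 - 128 (same bit pattern as q_u8 ^ 0x80).
    inline std::int8_t flip_value(std::uint8_t q_u8)
    {
        return static_cast<std::int8_t>(static_cast<std::int32_t>(q_u8) - 128);
    }

    // Flip the metadata to match: the zero point shifts by the same 128.
    inline QuantInfo flip_qinfo(const QuantInfo &u8_info)
    {
        return QuantInfo{u8_info.scale, u8_info.zero_point - 128};
    }

    int main()
    {
        const QuantInfo    u8_info{0.5f, 140};
        const std::uint8_t q_u8    = 200;                // real = 0.5 * (200 - 140) = 30.0
        const std::int8_t  q_s8    = flip_value(q_u8);   // 72
        const QuantInfo    s8_info = flip_qinfo(u8_info); // zero_point = 12
        // Same real value either way: 0.5 * (72 - 12) = 30.0
        assert(u8_info.scale * (q_u8 - u8_info.zero_point) ==
               s8_info.scale * (q_s8 - s8_info.zero_point));
        return 0;
    }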
Diffstat (limited to 'src/cpu')
-rw-r--r--  src/cpu/kernels/CpuGemmLowpMatrixMultiplyKernel.cpp      |   4
-rw-r--r--  src/cpu/kernels/CpuKernelSelectionTypes.h                |   2
-rw-r--r--  src/cpu/kernels/assembly/CpuGemmAssemblyWrapperKernel.h  |   8
-rw-r--r--  src/cpu/kernels/assembly/arm_gemm.hpp                    |  12
-rw-r--r--  src/cpu/kernels/assembly/convolution_parameters.hpp      |   2
-rw-r--r--  src/cpu/kernels/assembly/gemm_common.hpp                 |  18
-rw-r--r--  src/cpu/operators/CpuConv2d.h                            |   5
-rw-r--r--  src/cpu/operators/CpuGemmConv2d.h                        |   1
-rw-r--r--  src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.cpp      |  75
-rw-r--r--  src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.h        |   2
-rw-r--r--  src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp   | 270

11 files changed, 178 insertions, 221 deletions
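
The bulk of the revert below undoes a template-arity change: arm_gemm::GemmCommon and its wrappers go back from <TypeInput, TypeWeight, TypeOutput> to <TypeInput, TypeOutput>, so the B (weights) operand again shares the input element type. A simplified sketch of what that constraint means (illustrative types, not the real arm_gemm class):

    // Post-revert shape: B shares the input element type To.
    #include <cstdint>

    template <typename To, typename Tr>
    class GemmCommonSketch
    {
    protected:
        const To *_Aptr = nullptr;
        const To *_Bptr = nullptr; // was `const Tw *` while the reverted patch was in
        Tr       *_Cptr = nullptr;
    };

    // Matched-sign instantiations remain expressible:
    template class GemmCommonSketch<std::uint8_t, std::uint8_t>; // QASYMM8 x QASYMM8
    template class GemmCommonSketch<std::int8_t, std::int8_t>;  // QASYMM8_SIGNED x QASYMM8_SIGNED
    // A mixed u8-input / s8-weight GEMM has no valid instantiation in this shape,
    // which is why the QASYMM8 x QASYMM8_SIGNED rows disappear from the support
    // tables in the hunks below.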
diff --git a/src/cpu/kernels/CpuGemmLowpMatrixMultiplyKernel.cpp b/src/cpu/kernels/CpuGemmLowpMatrixMultiplyKernel.cpp
index 87340e566e..5b88735e7a 100644
--- a/src/cpu/kernels/CpuGemmLowpMatrixMultiplyKernel.cpp
+++ b/src/cpu/kernels/CpuGemmLowpMatrixMultiplyKernel.cpp
@@ -684,10 +684,6 @@ Status validate_arguments(const ITensorInfo *src0, const ITensorInfo *src1, cons
                                                          DataType::U8);
     ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(dst, 1, DataType::S32);
 
-    ARM_COMPUTE_RETURN_ERROR_ON_MSG(src0->data_type() == DataType::QASYMM8_SIGNED &&
-                                        src1->data_type() == DataType::QASYMM8,
-                                    "QASYMM8_SIGNED input with QASYMM8 weights not supported");
-
     TensorShape in0_shape = src0->tensor_shape();
     TensorShape in1_shape = src1->tensor_shape();
     TensorShape out_shape = dst->tensor_shape();
diff --git a/src/cpu/kernels/CpuKernelSelectionTypes.h b/src/cpu/kernels/CpuKernelSelectionTypes.h
index 03a474de53..7c1e4772a6 100644
--- a/src/cpu/kernels/CpuKernelSelectionTypes.h
+++ b/src/cpu/kernels/CpuKernelSelectionTypes.h
@@ -105,7 +105,7 @@ struct SoftmaxKernelDataTypeISASelectorData
     cpuinfo::CpuIsaInfo isa;
     bool                is_log;
     int                 axis;
-    uint64_t            sme2_vector_length;
+    unsigned long       sme2_vector_length;
 };
 
 // Selector pointer types
diff --git a/src/cpu/kernels/assembly/CpuGemmAssemblyWrapperKernel.h b/src/cpu/kernels/assembly/CpuGemmAssemblyWrapperKernel.h
index 72fafca1bb..e2a27675b3 100644
--- a/src/cpu/kernels/assembly/CpuGemmAssemblyWrapperKernel.h
+++ b/src/cpu/kernels/assembly/CpuGemmAssemblyWrapperKernel.h
@@ -52,7 +52,7 @@ namespace kernel
  *
  *
  */
-template <typename TypeInput, typename TypeWeight, typename TypeOutput>
+template <typename TypeInput, typename TypeOutput>
 class CpuGemmAssemblyWrapperKernel final : public INEKernel
 {
 public:
@@ -101,7 +101,7 @@ public:
      * @param[in] kernel          Pointer to an assembly kernel implementation.
      * @param[in] kernel_name_tag Tag to be attacehd to the kernel's name.
      */
-    void configure(arm_gemm::GemmCommon<TypeInput, TypeWeight, TypeOutput> *kernel, std::string kernel_name_tag)
+    void configure(arm_gemm::GemmCommon<TypeInput, TypeOutput> *kernel, std::string kernel_name_tag)
     {
         ARM_COMPUTE_ERROR_ON_NULLPTR((reinterpret_cast<void *>(kernel)));
         _kernel = kernel;
@@ -131,8 +131,8 @@ public:
     }
 
 private:
-    arm_gemm::GemmCommon<TypeInput, TypeWeight, TypeOutput> *_kernel;
-    std::string                                              _name;
+    arm_gemm::GemmCommon<TypeInput, TypeOutput> *_kernel;
+    std::string                                  _name;
 };
 } // namespace kernel
 } // namespace cpu
diff --git a/src/cpu/kernels/assembly/arm_gemm.hpp b/src/cpu/kernels/assembly/arm_gemm.hpp
index cbc8be416e..941fed0ba8 100644
--- a/src/cpu/kernels/assembly/arm_gemm.hpp
+++ b/src/cpu/kernels/assembly/arm_gemm.hpp
@@ -277,8 +277,8 @@ struct Nothing
 {
 };
 
-template <typename Tlop, typename Trop, typename Tret>
-using UniqueGemmCommon = std::unique_ptr<GemmCommon<Tlop, Trop, Tret>>;
+template <typename Top, typename Tret>
+using UniqueGemmCommon = std::unique_ptr<GemmCommon<Top, Tret>>;
 
 /* Low level API calls.
  * These are implemented as 'GemmArgs' versions, or with the arguments explicitly listed. */
@@ -288,13 +288,13 @@ using UniqueGemmCommon = std::unique_ptr<GemmCommon<Tlop, Trop, Tret>>;
 template <typename Top, typename Tret, class OutputStage = Nothing>
 KernelDescription get_gemm_method(const GemmArgs &args, const OutputStage & = {});
 
-template <typename Tlop, typename Trop, typename Tret, class OutputStage = Nothing>
-UniqueGemmCommon<Tlop, Trop, Tret> gemm(const GemmArgs &args, const OutputStage & = {});
+template <typename Top, typename Tret, class OutputStage = Nothing>
+UniqueGemmCommon<Top, Tret> gemm(const GemmArgs &args, const OutputStage & = {});
 
-template <typename Tlop, typename Trop, typename Tret, class OutputStage = Nothing>
+template <typename Top, typename Tret, class OutputStage = Nothing>
 std::vector<KernelDescription> get_compatible_kernels(const GemmArgs &args, const OutputStage & = {});
 
-template <typename Tlop, typename Trop, typename Tret, class OutputStage = Nothing>
+template <typename Top, typename Tret, class OutputStage = Nothing>
 bool has_opt_gemm(WeightFormat &weight_format, const GemmArgs &args, const OutputStage & = {});
 
 } // namespace arm_gemm
diff --git a/src/cpu/kernels/assembly/convolution_parameters.hpp b/src/cpu/kernels/assembly/convolution_parameters.hpp
index 09b73ca409..a6cf96344c 100644
--- a/src/cpu/kernels/assembly/convolution_parameters.hpp
+++ b/src/cpu/kernels/assembly/convolution_parameters.hpp
@@ -61,8 +61,6 @@ struct ConvolutionParameters
     int64_t output_stride_w;
     int64_t output_stride_h;
     // output_channels not included as they do not affect the input.
-    int64_t dilation_w;
-    int64_t dilation_h;
     int64_t padding_top;
     int64_t padding_left;
     float   padding_value;
diff --git a/src/cpu/kernels/assembly/gemm_common.hpp b/src/cpu/kernels/assembly/gemm_common.hpp
index f693021fcb..45d1e43274 100644
--- a/src/cpu/kernels/assembly/gemm_common.hpp
+++ b/src/cpu/kernels/assembly/gemm_common.hpp
@@ -189,7 +189,7 @@ public:
  * 'set_arrays' to capture the provided arguments in protected class
  * members, as essentially any implementation will need these.
  */
-template <typename To, typename Tw, typename Tr>
+template <typename To, typename Tr>
 class GemmCommon : public IGemmCommon
 {
 protected:
@@ -197,7 +197,7 @@ protected:
     int       _lda            = 0;
     int       _A_batch_stride = 0;
     int       _A_multi_stride = 0;
-    const Tw *_Bptr           = nullptr;
+    const To *_Bptr           = nullptr;
     int       _ldb            = 0;
     int       _B_multi_stride = 0;
     Tr       *_Cptr           = nullptr;
@@ -214,7 +214,7 @@ public:
                             const int lda,
                             const int A_batch_stride,
                             const int A_multi_stride,
-                            const Tw *B,
+                            const To *B,
                             const int ldb,
                             /* batches share B */
                             const int B_multi_stride,
                             Tr       *C,
@@ -254,7 +254,7 @@ public:
                     const void *bias,
                     /* no row or batch stride needed */
                     const int   bias_multi_stride) override
     {
-        set_arrays(static_cast<const To *>(A), lda, A_batch_stride, A_multi_stride, static_cast<const Tw *>(B), ldb,
+        set_arrays(static_cast<const To *>(A), lda, A_batch_stride, A_multi_stride, static_cast<const To *>(B), ldb,
                    B_multi_stride, static_cast<Tr *>(C), ldc, C_batch_stride, C_multi_stride,
                    static_cast<const Tr *>(bias), bias_multi_stride);
     }
@@ -262,17 +262,17 @@ public:
     /*** "Pretransposed" interface ***/
 
     /* Compute col sums over all columns */
-    virtual void requantize_bias(void *, const Tw *, const int, const int){};
+    virtual void requantize_bias(void *, const To *, const int, const int){};
 
     /* Perform pretranspose - the void * passed in must remain allocated for the duration of any execute calls. */
     /* Arguments are: output buffer pointer, source pointer, source row stride, source multi stride */
-    virtual void pretranspose_B_array(void *, const Tw *, const int, const int, bool){};
+    virtual void pretranspose_B_array(void *, const To *, const int, const int, bool){};
 
     /* Implementation of the void * overload which casts its arguments to the appropriate type. */
     void pretranspose_B_array_generic(
         void *out, const void *in, const int row_stride, const int multi_stride, bool transposed) override
     {
-        pretranspose_B_array(out, static_cast<const Tw *>(in), row_stride, multi_stride, transposed);
+        pretranspose_B_array(out, static_cast<const To *>(in), row_stride, multi_stride, transposed);
     }
 
     /* Threaded versions of the above.
@@ -280,7 +280,7 @@ public:
      * just calls the non-threaded functions to do the work.  This is valid as with window size of 1 the only
      * legal values for start and end are 0 and 1 respectively. */
     virtual void pretranspose_B_array_part(
-        void *out, const Tw *in, const int row_stride, const int multi_stride, bool transposed, size_t, size_t)
+        void *out, const To *in, const int row_stride, const int multi_stride, bool transposed, size_t, size_t)
     {
         pretranspose_B_array(out, in, row_stride, multi_stride, transposed);
     };
@@ -293,7 +293,7 @@ public:
                                           size_t      start,
                                           size_t      end) override
     {
-        pretranspose_B_array_part(out, static_cast<const Tw *>(in), row_stride, multi_stride, transposed, start, end);
+        pretranspose_B_array_part(out, static_cast<const To *>(in), row_stride, multi_stride, transposed, start, end);
     }
 
     /*** Indirect interface ***/
diff --git a/src/cpu/operators/CpuConv2d.h b/src/cpu/operators/CpuConv2d.h
index 0012ff6609..3f98e71896 100644
--- a/src/cpu/operators/CpuConv2d.h
+++ b/src/cpu/operators/CpuConv2d.h
@@ -85,7 +85,6 @@ public:
      * |F16            |F16                |F16      |F16            |
      * |F32            |F32                |F32      |F32            |
      * |QASYMM8        |QASYMM8            |S32      |QASYMM8        |
-     * |QASYMM8        |QASYMM8_SIGNED     |S32      |QASYMM8        |
      * |QASYMM8        |QSYMM8_PER_CHANNEL |S32      |QASYMM8        |
      * |QASYMM8_SIGNED |QASYMM8_SIGNED     |S32      |QASYMM8_SIGNED |
      * |QASYMM8_SIGNED |QSYMM8_PER_CHANNEL |S32      |QASYMM8_SIGNED |
@@ -94,7 +93,7 @@ public:
      *                       while every optional dimension from 4 and above represent a batch of inputs.
      *                       Data types supported: QASYMM8/QASYMM8_SIGNED/F16/F32.
      * @param[in]  weights   Weights tensor info. Weights are 4D tensor with dimensions [kernel_x, kernel_y, IFM, OFM].
-     *                       Data type supported: Same as @p src, also could be QSYMM8_PER_CHANNEL or QASYMM8_SIGNED if input is QASYMM8/QASYMM8_SIGNED.
+     *                       Data type supported: Same as @p src, also could be QSYMM8_PER_CHANNEL if input is QASYMM8/QASYMM8_SIGNED.
      * @param[in]  biases    Biases tensor info. Shared biases supported. Biases are 1D tensor with dimensions [OFM].
      *                       Data type supported: Same as @p src, except for input of QASYMM8/QASYMM8_SIGNED type where biases should be of S32 type.
      * @param[out] dst       Destination tensor info. 3 lower dimensions represent a single output [width, height, OFM], while the rest represent batch of outputs.
@@ -140,7 +139,7 @@ public:
      *                     while every optional dimension from 4 and above represent a batch of inputs.
      *                     Data types supported: QASYMM8/QASYMM8_SIGNED/F16/F32.
      * @param[in] weights  Weights tensor info. Weights are 4D tensor with dimensions [kernel_x, kernel_y, IFM, OFM].
-     *                     Data type supported:Same as @p src, also could be QSYMM8_PER_CHANNEL or QASYMM8_SIGNED if input is QASYMM8/QASYMM8_SIGNED.
+     *                     Data type supported:Same as @p src, also could be QSYMM8_PER_CHANNEL if input is QASYMM8/QASYMM8_SIGNED.
      * @param[in] dst      Destination tensor info. 3 lower dimensions represent a single output [width, height, OFM], while the rest represent batch of outputs.
      *                     Data types supported: Same as @p src.
      * @param[in] conv_info Contains padding and stride information described in @ref PadStrideInfo.
diff --git a/src/cpu/operators/CpuGemmConv2d.h b/src/cpu/operators/CpuGemmConv2d.h
index ae5023a71a..fa16ce860b 100644
--- a/src/cpu/operators/CpuGemmConv2d.h
+++ b/src/cpu/operators/CpuGemmConv2d.h
@@ -76,7 +76,6 @@ public:
      * |F32            |F32                |F32      |F32            |
      * |BFLOAT16       |BFLOAT16           |BFLOAT16 |BFLOAT16       |
      * |QASYMM8        |QASYMM8            |S32      |QASYMM8        |
-     * |QASYMM8        |QASYMM8_SIGNED     |S32      |QASYMM8        |
      * |QASYMM8        |QSYMM8_PER_CHANNEL |S32      |QASYMM8        |
      * |QASYMM8_SIGNED |QASYMM8_SIGNED     |S32      |QASYMM8_SIGNED |
      * |QASYMM8_SIGNED |QSYMM8_PER_CHANNEL |S32      |QASYMM8_SIGNED |
diff --git a/src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.cpp b/src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.cpp
index 1dbe3d8a31..f3396fbb5c 100644
--- a/src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.cpp
+++ b/src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.cpp
@@ -128,31 +128,24 @@ void CpuGemmLowpMatrixMultiplyCore::configure(
                                   _reshape_b_only_on_first_run;
     _gemm_info = gemm_info;
 
-    const ITensorInfo *a_to_use = a;
-
-    // Initialize assembly kernel meta-data
-    const cpu::AsmGemmInfo asm_info = init_assembly_metadata(gemm_info);
-
-    const int32_t                 offset_correction = 128;
-    const DataType                dt                = DataType::QASYMM8_SIGNED;
-    const UniformQuantizationInfo iqinfo            = a_to_use->quantization_info().uniform();
-
-    _signed_a = a_to_use->clone()->set_data_type(dt).set_quantization_info(
-        QuantizationInfo(iqinfo.scale, iqinfo.offset + offset_correction));
-
-    // If inputs are mixed-sign but this machine does not support mixed sign kernels,
-    // flip the sign so matched-sign kernels can be used.
-    if (!_flip_signedness && a->data_type() == DataType::QASYMM8 && b->data_type() == DataType::QASYMM8_SIGNED &&
-        !bool(CpuGemmAssemblyDispatch::validate(a_to_use, b, c, dst, asm_info)))
-    {
-        _flip_signedness = true;
-    }
+    // Offset kernel is need if offset is non-zero or it may change (i.e. dynamic).
+    // It is not needed if the datatype is symmetric, because there is no offset
+    bool a_offset_kernel_needed = _a_offset != 0 || a->quantization_info().is_dynamic();
+    bool b_offset_kernel_needed = _b_offset != 0 || b->quantization_info().is_dynamic();
 
     _asm_glue = std::make_unique<cpu::CpuGemmAssemblyDispatch>();
 
+    const ITensorInfo *a_to_use = a;
+
     // Convert to QASYMM8 -> QASYMM8_SIGNED and back
     if (_flip_signedness)
     {
+        const int32_t                 offset_correction = 128;
+        const DataType                dt                = DataType::QASYMM8_SIGNED;
+        const UniformQuantizationInfo iqinfo            = a_to_use->quantization_info().uniform();
+
+        _signed_a = a_to_use->clone()->set_data_type(dt).set_quantization_info(
+            QuantizationInfo(iqinfo.scale, iqinfo.offset + offset_correction));
         _convert_to_signed_asymm = std::make_unique<kernels::CpuConvertQuantizedSignednessKernel>();
         _convert_to_signed_asymm->configure(a_to_use, &_signed_a);
         a_to_use = &_signed_a;
@@ -173,11 +166,6 @@ void CpuGemmLowpMatrixMultiplyCore::configure(
         matrix_a = &_signed_a;
     }
 
-    // Offset kernel is need if offset is non-zero or it may change (i.e. dynamic).
-    // It is not needed if the datatype is symmetric, because there is no offset
-    bool a_offset_kernel_needed = _a_offset != 0 || a->quantization_info().is_dynamic();
-    bool b_offset_kernel_needed = _b_offset != 0 || b->quantization_info().is_dynamic();
-
     // If GEMMLowpOutputStage != NONE, fuse the offset contribution with the output stage
     if (info.gemmlowp_output_stage().type != GEMMLowpOutputStageType::NONE)
     {
@@ -185,6 +173,8 @@ void CpuGemmLowpMatrixMultiplyCore::configure(
         _mm_result_s32 = TensorInfo(dst->tensor_shape(), 1, DataType::S32);
     }
 
+    // Initialize assembly kernel meta-data
+    const cpu::AsmGemmInfo asm_info = init_assembly_metadata(gemm_info);
 #ifdef __aarch64__
     if (!(!b->are_values_constant() &&
           b->tensor_shape().z() > 1)) // Disable batch matmul as optimized GeMM handles batching differently.
@@ -385,6 +375,10 @@ Status CpuGemmLowpMatrixMultiplyCore::validate(const ITensorInfo *a,
     int32_t a_offset = a->quantization_info().uniform().offset;
     int32_t b_offset = b->quantization_info().uniform().offset;
 
+    // Offset kernel is need if offset is non-zero or it may change (i.e. dynamic).
+    bool a_offset_kernel_needed = a_offset != 0 || a->quantization_info().is_dynamic();
+    bool b_offset_kernel_needed = b_offset != 0 || b->quantization_info().is_dynamic();
+
     bool fuse_output_stage = info.gemmlowp_output_stage().type != GEMMLowpOutputStageType::NONE;
     if (fuse_output_stage)
     {
@@ -392,31 +386,19 @@ Status CpuGemmLowpMatrixMultiplyCore::validate(const ITensorInfo *a,
             a->clone()->set_tensor_shape(output->tensor_shape()).set_data_type(DataType::S32));
     }
 
-    // Initialize assembly kernel meta-data
-    const AsmGemmInfo asm_info = init_assembly_metadata(info);
-
     // Convert QASYMM8->QASYMM8_SIGNED
-    const int32_t                 offset_correction = 128;
-    const DataType                dt                = DataType::QASYMM8_SIGNED;
-    const UniformQuantizationInfo iqinfo            = a_to_use->quantization_info().uniform();
-
-    TensorInfo signed_a = a_to_use->clone()->set_data_type(dt).set_quantization_info(
-        QuantizationInfo(iqinfo.scale, iqinfo.offset + offset_correction));
+    TensorInfo signed_a{};
     TensorInfo signed_output{};
-
-    bool flip_signedness = is_data_type_quantized_per_channel(b->data_type()) &&
+    bool       flip_signedness = is_data_type_quantized_per_channel(b->data_type()) &&
                            (a->data_type() == DataType::QASYMM8) && info.reshape_b_only_on_first_run();
-
-    // If inputs are mixed-sign but this machine does not support mixed sign kernels,
-    // flip the sign so matched-sign kernels can be used.
-    if (!flip_signedness && a->data_type() == DataType::QASYMM8 && b->data_type() == DataType::QASYMM8_SIGNED &&
-        !bool(CpuGemmAssemblyDispatch::validate(a_to_use, b, c, output, asm_info)))
-    {
-        flip_signedness = true;
-    }
-
     if (flip_signedness)
     {
+        const int32_t                 offset_correction = 128;
+        const DataType                dt                = DataType::QASYMM8_SIGNED;
+        const UniformQuantizationInfo iqinfo            = a_to_use->quantization_info().uniform();
+
+        signed_a = a_to_use->clone()->set_data_type(dt).set_quantization_info(
+            QuantizationInfo(iqinfo.scale, iqinfo.offset + offset_correction));
         ARM_COMPUTE_RETURN_ON_ERROR(kernels::CpuConvertQuantizedSignednessKernel::validate(a_to_use, &signed_a));
         a_to_use = &signed_a;
         a_offset = signed_a.quantization_info().uniform().offset;
@@ -436,9 +418,8 @@ Status CpuGemmLowpMatrixMultiplyCore::validate(const ITensorInfo *a,
         matrix_a_info = &signed_a;
     }
 
-    // Offset kernel is need if offset is non-zero or it may change (i.e. dynamic).
-    bool a_offset_kernel_needed = a_offset != 0 || a->quantization_info().is_dynamic();
-    bool b_offset_kernel_needed = b_offset != 0 || b->quantization_info().is_dynamic();
+    // Initialize assembly kernel meta-data
+    const AsmGemmInfo asm_info = init_assembly_metadata(info);
 
     // Check if we need to run the optimized assembly kernel
     bool run_optimised = false;
diff --git a/src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.h b/src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.h
index 11fe6f9ef0..38121c9bb4 100644
--- a/src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.h
+++ b/src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.h
@@ -81,13 +81,11 @@ public:
      * |src0           |src1               |src2     |dst            |
      * |:--------------|:------------------|:--------|:--------------|
      * |QASYMM8        |QASYMM8            |S32      |QASYMM8        |
-     * |QASYMM8        |QASYMM8_SIGNED     |S32      |QASYMM8        |
      * |QASYMM8        |QSYMM8_PER_CHANNEL |S32      |QASYMM8        |
      * |QASYMM8        |QSYMM8             |S32      |QASYMM8        |
      * |QASYMM8        |QASYMM8            |S32      |S32            |
      * |QASYMM8        |QSYMM8_PER_CHANNEL |S32      |S32            |
      * |QASYMM8        |QSYMM8             |S32      |S32            |
-     * |QASYMM8        |QASYMM8_SIGNED     |F32      |F32            |
      * |QASYMM8_SIGNED |QASYMM8_SIGNED     |S32      |QASYMM8_SIGNED |
      * |QASYMM8_SIGNED |QSYMM8_PER_CHANNEL |S32      |QASYMM8_SIGNED |
      * |QASYMM8_SIGNED |QSYMM8             |S32      |QASYMM8_SIGNED |
diff --git a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
index 785837dbc6..fb9bc15212 100644
--- a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
+++ b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
@@ -45,7 +45,6 @@ namespace
 /** Run pretranspose_B_array in parallel (1D static scheduling)
  *
  * @tparam TypeInput
- * @tparam TypeWeight
  * @tparam TypeOutput
  *
  * @param[in] gemm_asm         GemmCommon kernel to run
 * @param[in] dst              Pretransposed B array
 * @param[in] src              B array to be pretransposed
 * @param[in] src_ld           Stride in y
 * @param[in] src_multi_stride Stride in z ("multi")
 * @param[in] num_threads      Number of threads to run this method. Must be >= 1
 */
-template <typename TypeInput, typename TypeWeight, typename TypeOutput>
-void run_parallel_pretranspose_B_array(arm_gemm::GemmCommon<TypeInput, TypeWeight, TypeOutput> *gemm_asm,
-                                       ITensor                                                 *dst,
-                                       const TypeWeight                                        *src,
-                                       int                                                      src_ld,
-                                       int                                                      src_multi_stride,
-                                       unsigned int                                             num_threads,
-                                       bool                                                     transpose)
+template <typename TypeInput, typename TypeOutput>
+void run_parallel_pretranspose_B_array(arm_gemm::GemmCommon<TypeInput, TypeOutput> *gemm_asm,
+                                       ITensor                                     *dst,
+                                       const TypeInput                             *src,
+                                       int                                          src_ld,
+                                       int                                          src_multi_stride,
+                                       unsigned int                                 num_threads,
+                                       bool                                         transpose)
 {
     ARM_COMPUTE_ERROR_ON(gemm_asm == nullptr);
     ARM_COMPUTE_ERROR_ON(num_threads == 0);
@@ -92,6 +91,14 @@ using namespace arm_compute::experimental;
 
 namespace
 {
+struct free_delete
+{
+    void operator()(void *x)
+    {
+        free(x);
+    }
+};
+
 struct Params
 {
     unsigned int M;
@@ -106,13 +113,14 @@
 Params extract_parameters(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *d, const AsmGemmInfo &info)
 {
     ARM_COMPUTE_ERROR_ON_NULLPTR(a, b, d);
-    Params p{/* M */ static_cast<unsigned int>(d->tensor_shape().y()),
-             /* N */ static_cast<unsigned int>(d->tensor_shape().x()),
-             /* K */ static_cast<unsigned int>(a->tensor_shape().x()),
-             /* batches */ 1,
-             /* multis */ 1,
-             /* sections */ 1,
-             /* indirect */ false};
+    Params p;
+    p.M        = d->tensor_shape().y();
+    p.K        = a->tensor_shape().x();
+    p.N        = d->tensor_shape().x();
+    p.batches  = 1;
+    p.multis   = 1;
+    p.sections = 1;
+    p.indirect = false;
 
     if (info.method == AsmConvMethod::Conv || info.method == AsmConvMethod::Indirect)
     {
@@ -164,10 +172,13 @@ IScheduler::Hints scheduling_hint_heuristic(arm_gemm::GemmMethod method, DataTyp
 }
 
 /** Fallback in case ACL doesn't have a function */
-template <typename TypeInput, typename TypeWeight, typename TypeOutput, class OutputStage = arm_gemm::Nothing>
+template <typename TypeInput, typename TypeOutput, class OutputStage = arm_gemm::Nothing>
 class Fallback : public CpuGemmAssemblyDispatch::IFallback
 {
 public:
+    /** Destructor */
+    ~Fallback() = default;
+
     /** Initialise the functions's input and output.
      *
      * @param[in] a         Input tensor containing the Matrix A.
@@ -211,9 +222,7 @@ public:
     bool isVarWeightsKernel() const override
     {
         if (!_gemm_kernel_asm)
-        {
             return false;
-        }
         const arm_compute::WeightFormat wf =
             assembly_utils::map_to_arm_compute_weight_format(_gemm_kernel_asm->get_config().weight_format);
         return wf != arm_compute::WeightFormat::UNSPECIFIED && wf != arm_compute::WeightFormat::ANY;
@@ -242,7 +251,7 @@ private:
     /** Operator to transpose B before gemm or pretranspose_B_array*/
     std::unique_ptr<CpuTranspose> _pre_pretranspose_b{nullptr};
     /** Assembly Gemm kernel */
-    std::shared_ptr<arm_gemm::GemmCommon<TypeInput, TypeWeight, TypeOutput>> _gemm_kernel_asm{nullptr};
+    std::shared_ptr<arm_gemm::GemmCommon<TypeInput, TypeOutput>> _gemm_kernel_asm{nullptr};
     /** Optimised Arm® Neon™ kernel */
     std::unique_ptr<INEKernel> _optimised_kernel{nullptr};
     /** Assembly GEMM workspace tensor info */
@@ -264,22 +273,22 @@ private:
     /** Per channel quantization multipliers */
     std::vector<int32_t> _multipliers{};
     /** Indirect buffer */
-    std::vector<const TypeInput *const *> _indirect_arg{};
-    std::vector<const TypeInput *>        _indirect_buf{};
-    std::vector<TypeInput>                _indirect_pad{};
-    arm_gemm::ConvolutionParameters       _cp{};
-    experimental::MemoryRequirements      _aux_mem{Count};
-    bool                                  _B_pretranspose_required{false};
-    bool                                  _is_b_constant{true};
-    bool                                  _is_c_constant{true};
-    bool                                  _run_pre_pretranspose_b{false};
-    bool                                  _B_pre_pretranspose_required{false};
+    std::unique_ptr<const TypeInput *const *, free_delete> _indirect_arg{};
+    std::unique_ptr<const TypeInput *, free_delete>        _indirect_buf{};
+    std::vector<TypeInput>                                 _indirect_pad{};
+    arm_gemm::ConvolutionParameters                        _cp{};
+    experimental::MemoryRequirements                       _aux_mem{Count};
+    bool                                                   _B_pretranspose_required{false};
+    bool                                                   _is_b_constant{true};
+    bool                                                   _is_c_constant{true};
+    bool                                                   _run_pre_pretranspose_b{false};
+    bool                                                   _B_pre_pretranspose_required{false};
 };
 
-template <typename TypeInput, typename TypeWeight, typename TypeOutput, class OutputStage>
+template <typename TypeInput, typename TypeOutput, class OutputStage>
 std::tuple<bool, const int32_t *, const int32_t *, const int32_t *>
-Fallback<TypeInput, TypeWeight, TypeOutput, OutputStage>::set_requantize_data(const std::vector<int32_t> &shifts,
-                                                                              const std::vector<int32_t> &multipliers)
+Fallback<TypeInput, TypeOutput, OutputStage>::set_requantize_data(const std::vector<int32_t> &shifts,
+                                                                  const std::vector<int32_t> &multipliers)
 {
     _multipliers = multipliers;
     _shifts      = shifts;
@@ -296,8 +305,8 @@
     return std::make_tuple(need_left, left_shifts.data(), right_shifts.data(), _multipliers.data());
 }
 
-template <typename TypeInput, typename TypeWeight, typename TypeOutput, class OutputStage>
-void Fallback<TypeInput, TypeWeight, TypeOutput, OutputStage>::prepare_indirect_buffer(ITensorPack &tensors)
+template <typename TypeInput, typename TypeOutput, class OutputStage>
+void Fallback<TypeInput, TypeOutput, OutputStage>::prepare_indirect_buffer(ITensorPack &tensors)
 {
     auto             a     = tensors.get_const_tensor(TensorType::ACL_SRC_0);
     const TypeInput *A_ptr = reinterpret_cast<TypeInput *>(a->buffer());
@@ -334,12 +343,14 @@
                         if (input_x < 0 || input_x >= _cp.input_width || input_y < 0 || input_y >= _cp.input_height)
                         {
-                            _indirect_buf[m * multi_stride + b * batch_stride + kernel_xy * output_hw + output_xy] =
+                            _indirect_buf
+                                .get()[m * multi_stride + b * batch_stride + kernel_xy * output_hw + output_xy] =
                                 _indirect_pad.data();
                         }
                         else
                         {
-                            _indirect_buf[m * multi_stride + b * batch_stride + kernel_xy * output_hw + output_xy] =
+                            _indirect_buf
+                                .get()[m * multi_stride + b * batch_stride + kernel_xy * output_hw + output_xy] =
                                 A_ptr + (m * multi_stride_A + b * batch_stride_A + input_xy * stride_A);
                         }
                     }
@@ -350,11 +361,11 @@
     }
 }
 
-template <typename TypeInput, typename TypeWeight, typename TypeOutput, class OutputStage>
-void Fallback<TypeInput, TypeWeight, TypeOutput, OutputStage>::configure_indirect(const ITensorInfo *a,
-                                                                                  const ITensorInfo *b,
-                                                                                  const ITensorInfo *d,
-                                                                                  const AsmGemmInfo &info)
+template <typename TypeInput, typename TypeOutput, class OutputStage>
+void Fallback<TypeInput, TypeOutput, OutputStage>::configure_indirect(const ITensorInfo *a,
+                                                                      const ITensorInfo *b,
+                                                                      const ITensorInfo *d,
+                                                                      const AsmGemmInfo &info)
 {
     ARM_COMPUTE_ERROR_ON(!(info.method == AsmConvMethod::Conv || info.method == AsmConvMethod::Indirect));
@@ -364,13 +375,13 @@
         zeropad = a->quantization_info().uniform().offset;
     }
 
-    const auto input_width    = static_cast<int64_t>(a->tensor_shape()[1]);
-    const auto input_height   = static_cast<int64_t>(a->tensor_shape()[2]);
-    const auto input_channels = static_cast<int64_t>(a->tensor_shape()[0]);
-    const auto kernel_width   = static_cast<int64_t>(b->tensor_shape()[2]);
-    const auto kernel_height  = static_cast<int64_t>(b->tensor_shape()[3]);
-    const auto output_width   = static_cast<int64_t>(d->tensor_shape()[1]);
-    const auto output_height  = static_cast<int64_t>(d->tensor_shape()[2]);
+    const int64_t input_width    = static_cast<int64_t>(a->tensor_shape()[1]);
+    const int64_t input_height   = static_cast<int64_t>(a->tensor_shape()[2]);
+    const int64_t input_channels = static_cast<int64_t>(a->tensor_shape()[0]);
+    const int64_t kernel_width   = static_cast<int64_t>(b->tensor_shape()[2]);
+    const int64_t kernel_height  = static_cast<int64_t>(b->tensor_shape()[3]);
+    const int64_t output_width   = static_cast<int64_t>(d->tensor_shape()[1]);
+    const int64_t output_height  = static_cast<int64_t>(d->tensor_shape()[2]);
 
     _cp = {input_width,
            input_height,
@@ -381,8 +392,6 @@ void Fallback<TypeInput, TypeWeight, TypeOutput, OutputStage>::configure_indirec
           output_height,
           info.ps_info.stride().first,
           info.ps_info.stride().second,
-           1,
-           1,
           info.padding_top,
           info.padding_left,
           zeropad};
@@ -405,8 +414,10 @@
     const int    multi_size   = batch_size * batches;
     const size_t multi_stride = multi_size / sizeof(TypeInputPtr);
 
-    _indirect_buf = std::vector<const TypeInput *>(multi_size * multis);
-    _indirect_arg = std::vector<const TypeInput *const *>(sizeof(TypeInput **) * kernel_hw * multis * batches);
+    _indirect_buf = std::unique_ptr<const TypeInput *, free_delete>(
+        reinterpret_cast<const TypeInput **>(malloc(multi_size * multis)));
+    _indirect_arg = std::unique_ptr<const TypeInput *const *, free_delete>(
+        reinterpret_cast<const TypeInput *const **>(malloc(sizeof(TypeInput **) * kernel_hw * multis * batches)));
     _indirect_pad = std::vector<TypeInput>(_cp.input_channels, TypeInput(zeropad));
 
     // Set indirect argument
@@ -417,28 +428,29 @@
         {
             for (int64_t kernel_xy = 0; kernel_xy < kernel_hw; kernel_xy++)
             {
-                _indirect_arg[pos++] = &_indirect_buf[m * multi_stride + b * batch_stride + kernel_xy * output_hw];
+                (_indirect_arg.get())[pos++] =
+                    _indirect_buf.get() + m * multi_stride + b * batch_stride + kernel_xy * output_hw;
             }
         }
     }
 
-    _gemm_kernel_asm->set_indirect_parameters(a->tensor_shape()[0], _indirect_arg.data());
+    _gemm_kernel_asm->set_indirect_parameters(a->tensor_shape()[0], _indirect_arg.get());
 }
 
-template <typename TypeInput, typename TypeWeight, typename TypeOutput, class OutputStage>
-void Fallback<TypeInput, TypeWeight, TypeOutput, OutputStage>::configure(const ITensorInfo *a,
-                                                                         const ITensorInfo *b,
-                                                                         const ITensorInfo *c,
-                                                                         ITensorInfo       *d,
-                                                                         arm_gemm::GemmArgs args,
-                                                                         const AsmGemmInfo &gemm_info,
-                                                                         const OutputStage &os)
+template <typename TypeInput, typename TypeOutput, class OutputStage>
+void Fallback<TypeInput, TypeOutput, OutputStage>::configure(const ITensorInfo *a,
+                                                             const ITensorInfo *b,
+                                                             const ITensorInfo *c,
+                                                             ITensorInfo       *d,
+                                                             arm_gemm::GemmArgs args,
+                                                             const AsmGemmInfo &gemm_info,
+                                                             const OutputStage &os)
 {
     _is_b_constant = b->are_values_constant();
     _is_c_constant = c ? c->are_values_constant() : true;
 
-    _gemm_kernel_asm = arm_gemm::gemm<TypeInput, TypeWeight, TypeOutput, OutputStage>(args, os);
+    _gemm_kernel_asm = arm_gemm::gemm<TypeInput, TypeOutput, OutputStage>(args, os);
     if (_gemm_kernel_asm == nullptr)
     {
         //configuration not supported: Leave function unconfigured:
@@ -448,7 +460,7 @@
     arm_gemm::GemmConfig gemm_cfg = _gemm_kernel_asm->get_config();
 
     // arm_compute wrapper for the Gemm object (see above)
-    auto acl_gemm_wrapper = std::make_unique<kernel::CpuGemmAssemblyWrapperKernel<TypeInput, TypeWeight, TypeOutput>>();
+    auto acl_gemm_wrapper = std::make_unique<kernel::CpuGemmAssemblyWrapperKernel<TypeInput, TypeOutput>>();
     ARM_COMPUTE_ERROR_ON(acl_gemm_wrapper == nullptr);
     acl_gemm_wrapper->configure(_gemm_kernel_asm.get(), gemm_cfg.filter);
     const size_t workspace_size = _gemm_kernel_asm->get_working_size();
@@ -537,8 +549,8 @@
     }
 }
 
-template <typename TypeInput, typename TypeWeight, typename TypeOutput, class OutputStage>
-void Fallback<TypeInput, TypeWeight, TypeOutput, OutputStage>::prepare(ITensorPack &tensors)
+template <typename TypeInput, typename TypeOutput, class OutputStage>
+void Fallback<TypeInput, TypeOutput, OutputStage>::prepare(ITensorPack &tensors)
 {
     if (!_is_prepared)
     {
@@ -576,17 +588,17 @@
             // Fixed format kernels need no pretranspose.
             ARM_COMPUTE_ERROR_ON(arm_compute::is_fixed_format(
                 assembly_utils::map_to_arm_compute_weight_format(_gemm_kernel_asm->get_config().weight_format)));
-            const int  ldb = b_to_use->info()->strides_in_bytes().y() / b_to_use->info()->element_size();
-            const auto in1_ptr = reinterpret_cast<const TypeWeight *>(
-                b_to_use->buffer() + b_to_use->info()->offset_first_element_in_bytes());
-            const int multi_stride_b = b_to_use->info()->strides_in_bytes().z() / b_to_use->info()->element_size();
+            const int  ldb = b_to_use->info()->strides_in_bytes().y() / b_to_use->info()->element_size();
+            const auto in1_ptr = reinterpret_cast<const TypeInput *>(b_to_use->buffer() +
+                                                                     b_to_use->info()->offset_first_element_in_bytes());
+            const int  multi_stride_b = b_to_use->info()->strides_in_bytes().z() / b_to_use->info()->element_size();
 
             CpuAuxTensorHandler pretranspose(offset_int_vec(Pretranspose), _pretranspose_info, tensors, false);
             ARM_COMPUTE_ERROR_ON(pretranspose.get()->buffer() == nullptr);
 
             const bool kernel_supports_transpose = _gemm_kernel_asm->B_pretranspose_supports_transpose();
-            run_parallel_pretranspose_B_array<TypeInput, TypeWeight, TypeOutput>(
+            run_parallel_pretranspose_B_array<TypeInput, TypeOutput>(
                 _gemm_kernel_asm.get(), pretranspose.get(), in1_ptr, ldb, multi_stride_b,
                 NEScheduler::get().num_threads(), _B_pre_pretranspose_required && kernel_supports_transpose);
@@ -604,20 +616,20 @@
     }
 }
 
-template <typename TypeInput, typename TypeWeight, typename TypeOutput, class OutputStage>
-bool Fallback<TypeInput, TypeWeight, TypeOutput, OutputStage>::is_configured() const
+template <typename TypeInput, typename TypeOutput, class OutputStage>
+bool Fallback<TypeInput, TypeOutput, OutputStage>::is_configured() const
 {
     return _optimised_kernel != nullptr;
 }
 
-template <typename TypeInput, typename TypeWeight, typename TypeOutput, class OutputStage>
-experimental::MemoryRequirements Fallback<TypeInput, TypeWeight, TypeOutput, OutputStage>::workspace() const
+template <typename TypeInput, typename TypeOutput, class OutputStage>
+experimental::MemoryRequirements Fallback<TypeInput, TypeOutput, OutputStage>::workspace() const
 {
     return _aux_mem;
 }
 
-template <typename TypeInput, typename TypeWeight, typename TypeOutput, class OutputStage>
-void Fallback<TypeInput, TypeWeight, TypeOutput, OutputStage>::run(ITensorPack &tensors)
+template <typename TypeInput, typename TypeOutput, class OutputStage>
+void Fallback<TypeInput, TypeOutput, OutputStage>::run(ITensorPack &tensors)
 {
     auto a = tensors.get_const_tensor(TensorType::ACL_SRC_0);
     auto b = tensors.get_const_tensor(TensorType::ACL_SRC_1);
@@ -651,8 +663,8 @@
     const int multi_stride_d = d->info()->strides_in_bytes()[d_multi_idx] / d->info()->element_size();
 
     auto in0_ptr = reinterpret_cast<const TypeInput *>(a->buffer() + a->info()->offset_first_element_in_bytes());
-    const TypeWeight *in1_ptr = nullptr;
-    auto              out_ptr = reinterpret_cast<TypeOutput *>(d->buffer() + d->info()->offset_first_element_in_bytes());
+    const TypeInput *in1_ptr = nullptr;
+    auto             out_ptr = reinterpret_cast<TypeOutput *>(d->buffer() + d->info()->offset_first_element_in_bytes());
 
     const ITensor *b_to_use = b;
@@ -674,8 +686,8 @@
     {
         ldb            = b_to_use->info()->strides_in_bytes().y() / b_to_use->info()->element_size();
         multi_stride_b = b_to_use->info()->strides_in_bytes().z() / b_to_use->info()->element_size();
-        in1_ptr        = reinterpret_cast<const TypeWeight *>(b_to_use->buffer() +
-                                                              b_to_use->info()->offset_first_element_in_bytes());
+        in1_ptr =
+            reinterpret_cast<const TypeInput *>(b_to_use->buffer() + b_to_use->info()->offset_first_element_in_bytes());
     }
 
     // If necessary, run pretranspose every time if either weights or biases are non-constant
@@ -694,8 +706,8 @@
             ARM_COMPUTE_ERROR_ON(arm_compute::is_fixed_format(
                 assembly_utils::map_to_arm_compute_weight_format(_gemm_kernel_asm->get_config().weight_format)));
             const int  ldb   = b_to_use->info()->strides_in_bytes().y() / b_to_use->info()->element_size();
-            const auto b_ptr = reinterpret_cast<const TypeWeight *>(b_to_use->buffer() +
-                                                                    b_to_use->info()->offset_first_element_in_bytes());
+            const auto b_ptr = reinterpret_cast<const TypeInput *>(b_to_use->buffer() +
+                                                                   b_to_use->info()->offset_first_element_in_bytes());
             const int multi_stride_b = b_to_use->info()->strides_in_bytes().z() / b_to_use->info()->element_size();
 
             CpuAuxTensorHandler pretranspose(offset_int_vec(Pretranspose), _pretranspose_info, tensors, true);
@@ -708,7 +720,7 @@
             else
             {
                 const bool kernel_supports_transpose = _gemm_kernel_asm->B_pretranspose_supports_transpose();
-                run_parallel_pretranspose_B_array<TypeInput, TypeWeight, TypeOutput>(
+                run_parallel_pretranspose_B_array<TypeInput, TypeOutput>(
                     _gemm_kernel_asm.get(), pretranspose.get(), b_ptr, ldb, multi_stride_b,
                     NEScheduler::get().num_threads(), _B_pre_pretranspose_required && kernel_supports_transpose);
             }
@@ -732,7 +744,7 @@
     if (split_dim != IScheduler::split_dimensions_all)
     {
         // Make sure the kernel does not expect more threads than we can actually spawn
-        const unsigned int num_iterations = _optimised_kernel->window().num_iterations(split_dim);
+        const unsigned int num_iterations = _optimised_kernel.get()->window().num_iterations(split_dim);
         num_threads                       = std::min(num_iterations, num_threads);
     }
     _gemm_kernel_asm->set_nthreads(num_threads);
@@ -763,7 +775,7 @@
     NEScheduler::get().schedule(_optimised_kernel.get(), scheduling_hint);
 }
 
-template <typename TypeInput, typename TypeWeight, typename TypeOutput>
+template <typename TypeInput, typename TypeOutput>
 void create_arm_gemm(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback> &arm_gemm,
                      const ITensorInfo                                   *a,
                      const ITensorInfo                                   *b,
@@ -782,12 +794,12 @@
                             info.fixed_format, info.fast_mode, info.accumulate, &cfg);
 
     // Create arm_gemm fallback
-    auto fallback = std::make_unique<Fallback<TypeInput, TypeWeight, TypeOutput>>();
+    auto fallback = std::make_unique<Fallback<TypeInput, TypeOutput>>();
     fallback->configure(a, b, c, d, args, info);
     arm_gemm = std::move(fallback);
 }
 
-template <typename TypeInput, typename TypeWeight, typename TypeOutput>
+template <typename TypeInput, typename TypeOutput>
 void create_arm_gemm_dequant(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback> &arm_gemm,
                              const ITensorInfo                                   *a,
                              const ITensorInfo                                   *b,
@@ -808,7 +820,7 @@
                             info.fixed_format, info.fast_mode, info.accumulate, &cfg);
 
     // Create arm_gemm fallback
-    auto fallback = std::make_unique<Fallback<TypeInput, TypeWeight, TypeOutput, arm_gemm::DequantizeFloat>>();
+    auto fallback = std::make_unique<Fallback<TypeInput, TypeOutput, arm_gemm::DequantizeFloat>>();
 
     // Configure requantization info
     const GEMMLowpOutputStageInfo os_info = info.output_stage;
@@ -820,7 +832,7 @@
     arm_gemm = std::move(fallback);
 }
 
-template <typename TypeInput, typename TypeWeight, typename TypeOutput>
+template <typename TypeInput, typename TypeOutput>
 void create_arm_gemm_quant(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback> &arm_gemm,
                            const ITensorInfo                                   *a,
                            const ITensorInfo                                   *b,
@@ -840,7 +852,7 @@
                             info.fixed_format, info.fast_mode, info.accumulate, &cfg);
 
     // Create arm_gemm fallback
-    auto fallback = std::make_unique<Fallback<TypeInput, TypeWeight, TypeOutput, arm_gemm::Requantize32>>();
+    auto fallback = std::make_unique<Fallback<TypeInput, TypeOutput, arm_gemm::Requantize32>>();
 
     // Configure requantization info
     const int32_t negation = info.negated_offsets ? 1 : -1;
@@ -893,12 +905,12 @@
     arm_gemm::WeightFormat arm_gemm_expected_wf = assembly_utils::map_to_arm_gemm_weight_format(expected_weight_format);
     arm_gemm::GemmArgs     args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, act, num_threads,
                                 info.fixed_format, info.fast_mode, info.accumulate, &cfg);
-    // TODO(COMPMID-6595): Incorporate info.transpose_b
+    // TODO: Incorporate info.transpose_b COMPMID-6595
     switch (a->data_type())
     {
         case DataType::F32:
             ARM_COMPUTE_RETURN_ERROR_ON_MSG(
-                !(arm_gemm::has_opt_gemm<float, float, float, arm_gemm::Nothing>(arm_gemm_expected_wf, args, {})),
+                !(arm_gemm::has_opt_gemm<float, float, arm_gemm::Nothing>(arm_gemm_expected_wf, args, {})),
                 "We could not find an optimized kernel for F32 input");
             break;
 #ifdef __aarch64__
@@ -907,22 +919,13 @@
             if (d->data_type() == DataType::S32)
             {
                 ARM_COMPUTE_RETURN_ERROR_ON_MSG(
-                    !(arm_gemm::has_opt_gemm<uint8_t, uint8_t, uint32_t, arm_gemm::Nothing>(arm_gemm_expected_wf, args,
-                                                                                            {})),
+                    !(arm_gemm::has_opt_gemm<uint8_t, uint32_t, arm_gemm::Nothing>(arm_gemm_expected_wf, args, {})),
                     "We could not find an optimized kernel for U8/QASYMM8 input and U32 output");
             }
-            else if (b->data_type() == DataType::QASYMM8_SIGNED)
-            {
-                ARM_COMPUTE_RETURN_ERROR_ON_MSG(
-                    !(arm_gemm::has_opt_gemm<uint8_t, int8_t, uint8_t, arm_gemm::Requantize32>(arm_gemm_expected_wf,
-                                                                                               args, {})),
-                    "We could not find an optimized kernel for U8 input with S8 weights and U8 output");
-            }
             else
             {
                 ARM_COMPUTE_RETURN_ERROR_ON_MSG(
-                    !(arm_gemm::has_opt_gemm<uint8_t, uint8_t, uint8_t, arm_gemm::Requantize32>(arm_gemm_expected_wf,
-                                                                                                args, {})),
+                    !(arm_gemm::has_opt_gemm<uint8_t, uint8_t, arm_gemm::Requantize32>(arm_gemm_expected_wf, args, {})),
                     "We could not find an optimized kernel for U8 input and U8 output");
             }
             break;
@@ -931,15 +934,13 @@
             if (d->data_type() == DataType::S32)
             {
                 ARM_COMPUTE_RETURN_ERROR_ON_MSG(
-                    !(arm_gemm::has_opt_gemm<int8_t, int8_t, int32_t, arm_gemm::Nothing>(arm_gemm_expected_wf, args,
-                                                                                         {})),
+                    !(arm_gemm::has_opt_gemm<int8_t, int32_t, arm_gemm::Nothing>(arm_gemm_expected_wf, args, {})),
                     "We could not find an optimized kernel for S8/QASYMM8_SIGNED input and S32 output");
             }
             else
             {
                 ARM_COMPUTE_RETURN_ERROR_ON_MSG(
-                    !(arm_gemm::has_opt_gemm<int8_t, int8_t, int8_t, arm_gemm::Requantize32>(arm_gemm_expected_wf, args,
-                                                                                             {})),
+                    !(arm_gemm::has_opt_gemm<int8_t, int8_t, arm_gemm::Requantize32>(arm_gemm_expected_wf, args, {})),
                     "We could not find an optimized kernel for S8 input and S8 output");
             }
             break;
@@ -951,15 +952,13 @@
             if (d->data_type() == DataType::BFLOAT16)
             {
                 ARM_COMPUTE_RETURN_ERROR_ON_MSG(
-                    !(arm_gemm::has_opt_gemm<bfloat16, bfloat16, bfloat16, arm_gemm::Nothing>(arm_gemm_expected_wf,
-                                                                                              args, {})),
+                    !(arm_gemm::has_opt_gemm<bfloat16, bfloat16, arm_gemm::Nothing>(arm_gemm_expected_wf, args, {})),
                     "We could not find an optimized kernel for BFLOAT16 input and BFLOAT16 output");
             }
             else
             {
                 ARM_COMPUTE_RETURN_ERROR_ON_MSG(
-                    !(arm_gemm::has_opt_gemm<bfloat16, bfloat16, float, arm_gemm::Nothing>(arm_gemm_expected_wf, args,
-                                                                                           {})),
+                    !(arm_gemm::has_opt_gemm<bfloat16, float, arm_gemm::Nothing>(arm_gemm_expected_wf, args, {})),
                     "We could not find an optimized kernel for BFLOAT16 input and F32 output");
             }
             break;
@@ -969,8 +968,7 @@
 #if defined(ENABLE_FP16_KERNELS)
         case DataType::F16:
             ARM_COMPUTE_RETURN_ERROR_ON_MSG(
-                !(arm_gemm::has_opt_gemm<float16_t, float16_t, float16_t, arm_gemm::Nothing>(arm_gemm_expected_wf, args,
-                                                                                             {})),
+                !(arm_gemm::has_opt_gemm<float16_t, float16_t, arm_gemm::Nothing>(arm_gemm_expected_wf, args, {})),
                 "We could not find an optimized kernel for F16 input and F16 output");
             break;
 #endif /* ENABLE_FP16_KERNELS */
@@ -1011,7 +1009,7 @@ Status CpuGemmAssemblyDispatch::validate(
         ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_NOT_IN(a, DataType::F32);
         ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_NOT_IN(b, DataType::BFLOAT16);
     }
-    else if (!(a->data_type() == DataType::QASYMM8 && b->data_type() == DataType::QASYMM8_SIGNED))
+    else
     {
         ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(a, b);
     }
@@ -1026,13 +1024,12 @@
                                     "Only U32 output supported for U8 input");
     ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::S8 && d->data_type() != DataType::S32,
                                     "Only S32 output supported for S8 input");
-    ARM_COMPUTE_RETURN_ERROR_ON_MSG(
-        a->data_type() == DataType::QASYMM8 &&
-            (d->data_type() != DataType::QASYMM8 && d->data_type() != DataType::S32 && d->data_type() != DataType::F32),
-        "Only QASYMM8/S32/F32 output supported for QASYMM8 input");
+    ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::QASYMM8 &&
+                                        (d->data_type() != DataType::QASYMM8 && d->data_type() != DataType::S32),
+                                    "Only QASYMM8/S32 output supported for QASYMM8 input");
     arm_compute::WeightFormat expected_weight_format = arm_compute::WeightFormat::UNSPECIFIED;
     const Status              ret = CpuGemmAssemblyDispatch::has_opt_impl(expected_weight_format, a, b, c, d, info);
-    if (bool(ret) && expected_weight_format != arm_compute::WeightFormat::ANY)
+    if ((bool)ret && expected_weight_format != arm_compute::WeightFormat::ANY)
     {
         // Correctness check: if the format expected by the kernel is
         // not "any", make sure that the one found matches the format
@@ -1065,44 +1062,33 @@ void CpuGemmAssemblyDispatch::configure(
     switch (a->data_type())
     {
         case DataType::F32:
-            create_arm_gemm<float, float, float>(_arm_gemm, a, b, c, d, act, info);
+            create_arm_gemm<float, float>(_arm_gemm, a, b, c, d, act, info);
             break;
 #ifdef __aarch64__
         case DataType::U8:
        case DataType::QASYMM8:
-            if (b->data_type() == DataType::S8 || b->data_type() == DataType::QASYMM8_SIGNED)
-            {
-                if (d->data_type() == DataType::F32)
-                {
-                    create_arm_gemm_dequant<uint8_t, int8_t, float>(_arm_gemm, a, b, c, d, act, info);
-                }
-                else
-                {
-                    create_arm_gemm_quant<uint8_t, int8_t, uint8_t>(_arm_gemm, a, b, c, d, act, info);
-                }
-            }
-            else if (d->data_type() == DataType::S32)
+            if (d->data_type() == DataType::S32)
             {
-                create_arm_gemm<uint8_t, uint8_t, uint32_t>(_arm_gemm, a, b, c, d, act, info);
+                create_arm_gemm<uint8_t, uint32_t>(_arm_gemm, a, b, c, d, act, info);
             }
             else
            {
-                create_arm_gemm_quant<uint8_t, uint8_t, uint8_t>(_arm_gemm, a, b, c, d, act, info);
+                create_arm_gemm_quant<uint8_t, uint8_t>(_arm_gemm, a, b, c, d, act, info);
            }
            break;
        case DataType::S8:
        case DataType::QASYMM8_SIGNED:
            if (d->data_type() == DataType::S32)
            {
-                create_arm_gemm<int8_t, int8_t, int32_t>(_arm_gemm, a, b, c, d, act, info);
+                create_arm_gemm<int8_t, int32_t>(_arm_gemm, a, b, c, d, act, info);
            }
            else if (d->data_type() == DataType::F32)
            {
-                create_arm_gemm_dequant<int8_t, int8_t, float>(_arm_gemm, a, b, c, d, act, info);
+                create_arm_gemm_dequant<int8_t, float>(_arm_gemm, a, b, c, d, act, info);
            }
            else
            {
-                create_arm_gemm_quant<int8_t, int8_t, int8_t>(_arm_gemm, a, b, c, d, act, info);
+                create_arm_gemm_quant<int8_t, int8_t>(_arm_gemm, a, b, c, d, act, info);
            }
            break;
 #endif /* __aarch64__ */
@@ -1110,17 +1096,17 @@
        case DataType::BFLOAT16:
            if (d->data_type() == DataType::BFLOAT16)
            {
-                create_arm_gemm<bfloat16, bfloat16, bfloat16>(_arm_gemm, a, b, c, d, act, info);
+                create_arm_gemm<bfloat16, bfloat16>(_arm_gemm, a, b, c, d, act, info);
            }
            else
            {
-                create_arm_gemm<bfloat16, bfloat16, float>(_arm_gemm, a, b, c, d, act, info);
+                create_arm_gemm<bfloat16, float>(_arm_gemm, a, b, c, d, act, info);
            }
            break;
 #endif /* defined(ARM_COMPUTE_ENABLE_BF16) */
 #ifdef ENABLE_FP16_KERNELS
        case DataType::F16:
-            create_arm_gemm<float16_t, float16_t, float16_t>(_arm_gemm, a, b, c, d, act, info);
+            create_arm_gemm<float16_t, float16_t>(_arm_gemm, a, b, c, d, act, info);
            break;
 #endif /* ENABLE_FP16_KERNELS */
        default:
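
One detail the revert restores in CpuGemmAssemblyDispatch.cpp is the malloc-backed indirect buffers owned through std::unique_ptr with the free_delete deleter, in place of std::vector. A standalone sketch of that ownership pattern, mirroring the free_delete struct in the hunks above (the buffer element type and size are illustrative):

    // A std::unique_ptr with a custom deleter releases malloc'd memory with
    // free() instead of delete[], keeping RAII ownership without std::vector's
    // value-initialization of every element.
    #include <cstdlib>
    #include <memory>

    struct free_delete
    {
        void operator()(void *x)
        {
            free(x);
        }
    };

    int main()
    {
        // 1024 raw pointer slots; contents are uninitialized, as with the
        // _indirect_buf allocation above, and are filled in before use.
        std::unique_ptr<const float *, free_delete> buf(
            static_cast<const float **>(malloc(1024 * sizeof(const float *))));
        buf.get()[0] = nullptr; // index through get(), as prepare_indirect_buffer() does
        return 0;               // free() runs automatically via free_delete
    }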