author | Michael Tyler <michael.tyler@arm.com> | 2024-06-04 15:47:37 +0100
committer | Michael Tyler <michael.tyler@arm.com> | 2024-06-25 09:10:13 +0000
commit | fc94f4d23abd4bc427b701f54ad85282e9ec7872 (patch)
tree | 5e2980599256e2b2f4374e5beb61596fc95c9d5a /src/cpu/operators
parent | c2237ec4094c7824f8f7e61bc89504d01c5b59ff (diff)
download | ComputeLibrary-fc94f4d23abd4bc427b701f54ad85282e9ec7872.tar.gz
Update CPU kernels and add mixed sign GEMM support
- Add support for mixed sign quantized convolution.
- Add support for mixed sign dequantized GEMM (see the usage sketch below the commit message).
- Add SME FP16 GEMV kernel.
- Change the SME vector length function to use RDSVL instead of a static variable.
- Add GEMM dilation support internally (not exposed yet).
- Remove unused "get_default_activation_values" functions.
- Add SVE fixed format interleaved BF16 DOT kernel.
- Updates and optimizations to assembly kernels.
Resolves COMPMID-6926
Change-Id: I227f502502611d4cc4111c89e30c53ce94079544
Signed-off-by: Michael Tyler <michael.tyler@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/11570
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Gunes Bayir <gunes.bayir@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Benchmark: Arm Jenkins <bsgcomp@arm.com>
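
The mixed sign support described above is exposed through the existing public functions rather than a new API. As a hedged, minimal sketch (not code from this patch): the newly documented QASYMM8 LHS x QASYMM8_SIGNED RHS combination with an F32 (dequantized) destination could be exercised roughly as below. Shapes and quantization parameters are invented for illustration, and `validate()` is checked first because support depends on the target CPU and build.

```cpp
// Hedged usage sketch only; whether this exact combination validates depends on the target.
#include "arm_compute/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.h"
#include "arm_compute/runtime/Tensor.h"

using namespace arm_compute;

int main()
{
    Tensor a, b, dst;

    // LHS is 16x64 (shape is {K, M}), RHS is 64x32 (shape is {N, K}).
    a.allocator()->init(TensorInfo(TensorShape(64U, 16U), 1, DataType::QASYMM8,
                                   QuantizationInfo(0.02f, 128)));
    b.allocator()->init(TensorInfo(TensorShape(32U, 64U), 1, DataType::QASYMM8_SIGNED,
                                   QuantizationInfo(0.01f, 0)));
    dst.allocator()->init(TensorInfo(TensorShape(32U, 16U), 1, DataType::F32));

    // Check support before configuring, since the mixed-sign path is target-dependent.
    const Status status = NEGEMMLowpMatrixMultiplyCore::validate(a.info(), b.info(), nullptr, dst.info());
    if (!bool(status))
    {
        return 1; // Combination not supported on this target/build.
    }

    NEGEMMLowpMatrixMultiplyCore gemmlowp;
    gemmlowp.configure(&a, &b, nullptr, &dst);

    a.allocator()->allocate();
    b.allocator()->allocate();
    dst.allocator()->allocate();
    // A real example would fill a and b with quantized data here.
    gemmlowp.run();
    return 0;
}
```

The same mixed-sign combination is also added to the supported-data-type tables of CpuConv2d, CpuGemmConv2d and CpuGemmLowpMatrixMultiplyCore in the diff below.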
Diffstat (limited to 'src/cpu/operators')
-rw-r--r-- | src/cpu/operators/CpuConv2d.h | 13
-rw-r--r-- | src/cpu/operators/CpuGemmConv2d.h | 3
-rw-r--r-- | src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.cpp | 75
-rw-r--r-- | src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.h | 2
-rw-r--r-- | src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp | 270
5 files changed, 203 insertions, 160 deletions
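
A recurring change in the diff below is that the assembly dispatch templates gain a separate weight type, going from `<TypeInput, TypeOutput>` to `<TypeInput, TypeWeight, TypeOutput>`. The toy template below is an entirely hypothetical stand-in for `arm_gemm::GemmCommon` (a naive reference GEMM, not the library's kernel); it only illustrates why the extra parameter is needed: with mixed-sign GEMM the activations and weights no longer share an element type.

```cpp
// Hypothetical stand-in illustrating the <TypeInput, TypeWeight, TypeOutput> split.
#include <cstdint>
#include <vector>

template <typename TypeInput, typename TypeWeight, typename TypeOutput>
void naive_gemm(const TypeInput *a, const TypeWeight *b, TypeOutput *d, int M, int N, int K)
{
    for (int m = 0; m < M; ++m)
    {
        for (int n = 0; n < N; ++n)
        {
            int32_t acc = 0;
            for (int k = 0; k < K; ++k)
            {
                // Widen both operands so unsigned inputs and signed weights mix safely.
                acc += static_cast<int32_t>(a[m * K + k]) * static_cast<int32_t>(b[k * N + n]);
            }
            d[m * N + n] = static_cast<TypeOutput>(acc);
        }
    }
}

int main()
{
    // Mixed-sign case: unsigned activations, signed weights, 32-bit accumulators.
    const int M = 2, N = 4, K = 3;
    std::vector<uint8_t> a(M * K, 2);
    std::vector<int8_t>  b(K * N, -1);
    std::vector<int32_t> d(M * N, 0);
    naive_gemm<uint8_t, int8_t, int32_t>(a.data(), b.data(), d.data(), M, N, K);
    return d[0] == -6 ? 0 : 1; // Each output accumulates 2 * (-1) three times.
}
```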
diff --git a/src/cpu/operators/CpuConv2d.h b/src/cpu/operators/CpuConv2d.h index 71b9e15dc1..0012ff6609 100644 --- a/src/cpu/operators/CpuConv2d.h +++ b/src/cpu/operators/CpuConv2d.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2021, 2023 Arm Limited. + * Copyright (c) 2017-2021, 2023-2024 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -21,6 +21,10 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. */ + +#ifndef ACL_SRC_CPU_OPERATORS_CPUCONV2D_H +#define ACL_SRC_CPU_OPERATORS_CPUCONV2D_H + #include "arm_compute/function_info/ActivationLayerInfo.h" #include "src/core/common/Macros.h" @@ -81,6 +85,7 @@ public: * |F16 |F16 |F16 |F16 | * |F32 |F32 |F32 |F32 | * |QASYMM8 |QASYMM8 |S32 |QASYMM8 | + * |QASYMM8 |QASYMM8_SIGNED |S32 |QASYMM8 | * |QASYMM8 |QSYMM8_PER_CHANNEL |S32 |QASYMM8 | * |QASYMM8_SIGNED |QASYMM8_SIGNED |S32 |QASYMM8_SIGNED | * |QASYMM8_SIGNED |QSYMM8_PER_CHANNEL |S32 |QASYMM8_SIGNED | @@ -89,7 +94,7 @@ public: * while every optional dimension from 4 and above represent a batch of inputs. * Data types supported: QASYMM8/QASYMM8_SIGNED/F16/F32. * @param[in] weights Weights tensor info. Weights are 4D tensor with dimensions [kernel_x, kernel_y, IFM, OFM]. - * Data type supported: Same as @p src, also could be QSYMM8_PER_CHANNEL if input is QASYMM8/QASYMM8_SIGNED. + * Data type supported: Same as @p src, also could be QSYMM8_PER_CHANNEL or QASYMM8_SIGNED if input is QASYMM8/QASYMM8_SIGNED. * @param[in] biases Biases tensor info. Shared biases supported. Biases are 1D tensor with dimensions [OFM]. * Data type supported: Same as @p src, except for input of QASYMM8/QASYMM8_SIGNED type where biases should be of S32 type. * @param[out] dst Destination tensor info. 3 lower dimensions represent a single output [width, height, OFM], while the rest represent batch of outputs. @@ -135,7 +140,7 @@ public: * while every optional dimension from 4 and above represent a batch of inputs. * Data types supported: QASYMM8/QASYMM8_SIGNED/F16/F32. * @param[in] weights Weights tensor info. Weights are 4D tensor with dimensions [kernel_x, kernel_y, IFM, OFM]. - * Data type supported:Same as @p src, also could be QSYMM8_PER_CHANNEL if input is QASYMM8/QASYMM8_SIGNED. + * Data type supported:Same as @p src, also could be QSYMM8_PER_CHANNEL or QASYMM8_SIGNED if input is QASYMM8/QASYMM8_SIGNED. * @param[in] dst Destination tensor info. 3 lower dimensions represent a single output [width, height, OFM], while the rest represent batch of outputs. * Data types supported: Same as @p src. * @param[in] conv_info Contains padding and stride information described in @ref PadStrideInfo. @@ -167,3 +172,5 @@ private: }; } // namespace cpu } // namespace arm_compute + +#endif // ACL_SRC_CPU_OPERATORS_CPUCONV2D_H diff --git a/src/cpu/operators/CpuGemmConv2d.h b/src/cpu/operators/CpuGemmConv2d.h index 48a0d11107..ae5023a71a 100644 --- a/src/cpu/operators/CpuGemmConv2d.h +++ b/src/cpu/operators/CpuGemmConv2d.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2023 Arm Limited. + * Copyright (c) 2021-2024 Arm Limited. 
* * SPDX-License-Identifier: MIT * @@ -76,6 +76,7 @@ public: * |F32 |F32 |F32 |F32 | * |BFLOAT16 |BFLOAT16 |BFLOAT16 |BFLOAT16 | * |QASYMM8 |QASYMM8 |S32 |QASYMM8 | + * |QASYMM8 |QASYMM8_SIGNED |S32 |QASYMM8 | * |QASYMM8 |QSYMM8_PER_CHANNEL |S32 |QASYMM8 | * |QASYMM8_SIGNED |QASYMM8_SIGNED |S32 |QASYMM8_SIGNED | * |QASYMM8_SIGNED |QSYMM8_PER_CHANNEL |S32 |QASYMM8_SIGNED | diff --git a/src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.cpp b/src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.cpp index f3396fbb5c..1dbe3d8a31 100644 --- a/src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.cpp +++ b/src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.cpp @@ -128,24 +128,31 @@ void CpuGemmLowpMatrixMultiplyCore::configure( _reshape_b_only_on_first_run; _gemm_info = gemm_info; - // Offset kernel is need if offset is non-zero or it may change (i.e. dynamic). - // It is not needed if the datatype is symmetric, because there is no offset - bool a_offset_kernel_needed = _a_offset != 0 || a->quantization_info().is_dynamic(); - bool b_offset_kernel_needed = _b_offset != 0 || b->quantization_info().is_dynamic(); + const ITensorInfo *a_to_use = a; - _asm_glue = std::make_unique<cpu::CpuGemmAssemblyDispatch>(); + // Initialize assembly kernel meta-data + const cpu::AsmGemmInfo asm_info = init_assembly_metadata(gemm_info); - const ITensorInfo *a_to_use = a; + const int32_t offset_correction = 128; + const DataType dt = DataType::QASYMM8_SIGNED; + const UniformQuantizationInfo iqinfo = a_to_use->quantization_info().uniform(); + + _signed_a = a_to_use->clone()->set_data_type(dt).set_quantization_info( + QuantizationInfo(iqinfo.scale, iqinfo.offset + offset_correction)); + + // If inputs are mixed-sign but this machine does not support mixed sign kernels, + // flip the sign so matched-sign kernels can be used. + if (!_flip_signedness && a->data_type() == DataType::QASYMM8 && b->data_type() == DataType::QASYMM8_SIGNED && + !bool(CpuGemmAssemblyDispatch::validate(a_to_use, b, c, dst, asm_info))) + { + _flip_signedness = true; + } + + _asm_glue = std::make_unique<cpu::CpuGemmAssemblyDispatch>(); // Convert to QASYMM8 -> QASYMM8_SIGNED and back if (_flip_signedness) { - const int32_t offset_correction = 128; - const DataType dt = DataType::QASYMM8_SIGNED; - const UniformQuantizationInfo iqinfo = a_to_use->quantization_info().uniform(); - - _signed_a = a_to_use->clone()->set_data_type(dt).set_quantization_info( - QuantizationInfo(iqinfo.scale, iqinfo.offset + offset_correction)); _convert_to_signed_asymm = std::make_unique<kernels::CpuConvertQuantizedSignednessKernel>(); _convert_to_signed_asymm->configure(a_to_use, &_signed_a); a_to_use = &_signed_a; @@ -166,6 +173,11 @@ void CpuGemmLowpMatrixMultiplyCore::configure( matrix_a = &_signed_a; } + // Offset kernel is need if offset is non-zero or it may change (i.e. dynamic). 
+ // It is not needed if the datatype is symmetric, because there is no offset + bool a_offset_kernel_needed = _a_offset != 0 || a->quantization_info().is_dynamic(); + bool b_offset_kernel_needed = _b_offset != 0 || b->quantization_info().is_dynamic(); + // If GEMMLowpOutputStage != NONE, fuse the offset contribution with the output stage if (info.gemmlowp_output_stage().type != GEMMLowpOutputStageType::NONE) { @@ -173,8 +185,6 @@ void CpuGemmLowpMatrixMultiplyCore::configure( _mm_result_s32 = TensorInfo(dst->tensor_shape(), 1, DataType::S32); } - // Initialize assembly kernel meta-data - const cpu::AsmGemmInfo asm_info = init_assembly_metadata(gemm_info); #ifdef __aarch64__ if (!(!b->are_values_constant() && b->tensor_shape().z() > 1)) // Disable batch matmul as optimized GeMM handles batching differently. @@ -375,10 +385,6 @@ Status CpuGemmLowpMatrixMultiplyCore::validate(const ITensorInfo *a, int32_t a_offset = a->quantization_info().uniform().offset; int32_t b_offset = b->quantization_info().uniform().offset; - // Offset kernel is need if offset is non-zero or it may change (i.e. dynamic). - bool a_offset_kernel_needed = a_offset != 0 || a->quantization_info().is_dynamic(); - bool b_offset_kernel_needed = b_offset != 0 || b->quantization_info().is_dynamic(); - bool fuse_output_stage = info.gemmlowp_output_stage().type != GEMMLowpOutputStageType::NONE; if (fuse_output_stage) { @@ -386,19 +392,31 @@ Status CpuGemmLowpMatrixMultiplyCore::validate(const ITensorInfo *a, a->clone()->set_tensor_shape(output->tensor_shape()).set_data_type(DataType::S32)); } + // Initialize assembly kernel meta-data + const AsmGemmInfo asm_info = init_assembly_metadata(info); + // Convert QASYMM8->QASYMM8_SIGNED - TensorInfo signed_a{}; + const int32_t offset_correction = 128; + const DataType dt = DataType::QASYMM8_SIGNED; + const UniformQuantizationInfo iqinfo = a_to_use->quantization_info().uniform(); + + TensorInfo signed_a = a_to_use->clone()->set_data_type(dt).set_quantization_info( + QuantizationInfo(iqinfo.scale, iqinfo.offset + offset_correction)); TensorInfo signed_output{}; - bool flip_signedness = is_data_type_quantized_per_channel(b->data_type()) && + + bool flip_signedness = is_data_type_quantized_per_channel(b->data_type()) && (a->data_type() == DataType::QASYMM8) && info.reshape_b_only_on_first_run(); - if (flip_signedness) + + // If inputs are mixed-sign but this machine does not support mixed sign kernels, + // flip the sign so matched-sign kernels can be used. + if (!flip_signedness && a->data_type() == DataType::QASYMM8 && b->data_type() == DataType::QASYMM8_SIGNED && + !bool(CpuGemmAssemblyDispatch::validate(a_to_use, b, c, output, asm_info))) { - const int32_t offset_correction = 128; - const DataType dt = DataType::QASYMM8_SIGNED; - const UniformQuantizationInfo iqinfo = a_to_use->quantization_info().uniform(); + flip_signedness = true; + } - signed_a = a_to_use->clone()->set_data_type(dt).set_quantization_info( - QuantizationInfo(iqinfo.scale, iqinfo.offset + offset_correction)); + if (flip_signedness) + { ARM_COMPUTE_RETURN_ON_ERROR(kernels::CpuConvertQuantizedSignednessKernel::validate(a_to_use, &signed_a)); a_to_use = &signed_a; a_offset = signed_a.quantization_info().uniform().offset; @@ -418,8 +436,9 @@ Status CpuGemmLowpMatrixMultiplyCore::validate(const ITensorInfo *a, matrix_a_info = &signed_a; } - // Initialize assembly kernel meta-data - const AsmGemmInfo asm_info = init_assembly_metadata(info); + // Offset kernel is need if offset is non-zero or it may change (i.e. dynamic). 
+ bool a_offset_kernel_needed = a_offset != 0 || a->quantization_info().is_dynamic(); + bool b_offset_kernel_needed = b_offset != 0 || b->quantization_info().is_dynamic(); // Check if we need to run the optimized assembly kernel bool run_optimised = false; diff --git a/src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.h b/src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.h index 38121c9bb4..11fe6f9ef0 100644 --- a/src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.h +++ b/src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.h @@ -81,11 +81,13 @@ public: * |src0 |src1 |src2 |dst | * |:--------------|:------------------|:--------|:--------------| * |QASYMM8 |QASYMM8 |S32 |QASYMM8 | + * |QASYMM8 |QASYMM8_SIGNED |S32 |QASYMM8 | * |QASYMM8 |QSYMM8_PER_CHANNEL |S32 |QASYMM8 | * |QASYMM8 |QSYMM8 |S32 |QASYMM8 | * |QASYMM8 |QASYMM8 |S32 |S32 | * |QASYMM8 |QSYMM8_PER_CHANNEL |S32 |S32 | * |QASYMM8 |QSYMM8 |S32 |S32 | + * |QASYMM8 |QASYMM8_SIGNED |F32 |F32 | * |QASYMM8_SIGNED |QASYMM8_SIGNED |S32 |QASYMM8_SIGNED | * |QASYMM8_SIGNED |QSYMM8_PER_CHANNEL |S32 |QASYMM8_SIGNED | * |QASYMM8_SIGNED |QSYMM8 |S32 |QASYMM8_SIGNED | diff --git a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp index a4c856bb8f..156a798d50 100644 --- a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp +++ b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp @@ -45,6 +45,7 @@ namespace /** Run pretranspose_B_array in parallel (1D static scheduling) * * @tparam TypeInput + * @tparam TypeWeight * @tparam TypeOutput * * @param[in] gemm_asm GemmCommon kernel to run @@ -54,14 +55,14 @@ namespace * @param[in] src_multi_stride Stride in z ("multi") * @param[in] num_threads Number of threads to run this method. Must be >= 1 */ -template <typename TypeInput, typename TypeOutput> -void run_parallel_pretranspose_B_array(arm_gemm::GemmCommon<TypeInput, TypeOutput> *gemm_asm, - ITensor *dst, - const TypeInput *src, - int src_ld, - int src_multi_stride, - unsigned int num_threads, - bool transpose) +template <typename TypeInput, typename TypeWeight, typename TypeOutput> +void run_parallel_pretranspose_B_array(arm_gemm::GemmCommon<TypeInput, TypeWeight, TypeOutput> *gemm_asm, + ITensor *dst, + const TypeWeight *src, + int src_ld, + int src_multi_stride, + unsigned int num_threads, + bool transpose) { ARM_COMPUTE_ERROR_ON(gemm_asm == nullptr); ARM_COMPUTE_ERROR_ON(num_threads == 0); @@ -91,14 +92,6 @@ using namespace arm_compute::experimental; namespace { -struct free_delete -{ - void operator()(void *x) - { - free(x); - } -}; - struct Params { unsigned int M; @@ -113,14 +106,13 @@ struct Params Params extract_parameters(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *d, const AsmGemmInfo &info) { ARM_COMPUTE_ERROR_ON_NULLPTR(a, b, d); - Params p; - p.M = d->tensor_shape().y(); - p.K = a->tensor_shape().x(); - p.N = d->tensor_shape().x(); - p.batches = 1; - p.multis = 1; - p.sections = 1; - p.indirect = false; + Params p{/* M */ static_cast<unsigned int>(d->tensor_shape().y()), + /* N */ static_cast<unsigned int>(d->tensor_shape().x()), + /* K */ static_cast<unsigned int>(a->tensor_shape().x()), + /* batches */ 1, + /* multis */ 1, + /* sections */ 1, + /* indirect */ false}; if (info.method == AsmConvMethod::Conv || info.method == AsmConvMethod::Indirect) { @@ -172,13 +164,10 @@ IScheduler::Hints scheduling_hint_heuristic(arm_gemm::GemmMethod method, DataTyp } /** Fallback in case ACL doesn't have a function */ -template <typename TypeInput, typename TypeOutput, class 
OutputStage = arm_gemm::Nothing> +template <typename TypeInput, typename TypeWeight, typename TypeOutput, class OutputStage = arm_gemm::Nothing> class Fallback : public CpuGemmAssemblyDispatch::IFallback { public: - /** Destructor */ - ~Fallback() = default; - /** Initialise the functions's input and output. * * @param[in] a Input tensor containing the Matrix A. @@ -222,7 +211,9 @@ public: bool isVarWeightsKernel() const override { if (!_gemm_kernel_asm) + { return false; + } const arm_compute::WeightFormat wf = assembly_utils::map_to_arm_compute_weight_format(_gemm_kernel_asm->get_config().weight_format); return wf != arm_compute::WeightFormat::UNSPECIFIED && wf != arm_compute::WeightFormat::ANY; @@ -251,7 +242,7 @@ private: /** Operator to transpose B before gemm or pretranspose_B_array*/ std::unique_ptr<CpuTranspose> _pre_pretranspose_b{nullptr}; /** Assembly Gemm kernel */ - std::shared_ptr<arm_gemm::GemmCommon<TypeInput, TypeOutput>> _gemm_kernel_asm{nullptr}; + std::shared_ptr<arm_gemm::GemmCommon<TypeInput, TypeWeight, TypeOutput>> _gemm_kernel_asm{nullptr}; /** Optimised Arm® Neon™ kernel */ std::unique_ptr<INEKernel> _optimised_kernel{nullptr}; /** Assembly GEMM workspace tensor info */ @@ -273,22 +264,22 @@ private: /** Per channel quantization multipliers */ std::vector<int32_t> _multipliers{}; /** Indirect buffer */ - std::unique_ptr<const TypeInput *const *, free_delete> _indirect_arg{}; - std::unique_ptr<const TypeInput *, free_delete> _indirect_buf{}; - std::vector<TypeInput> _indirect_pad{}; - arm_gemm::ConvolutionParameters _cp{}; - experimental::MemoryRequirements _aux_mem{Count}; - bool _B_pretranspose_required{false}; - bool _is_b_constant{true}; - bool _is_c_constant{true}; - bool _run_pre_pretranspose_b{false}; - bool _B_pre_pretranspose_required{false}; + std::vector<const TypeInput *const *> _indirect_arg{}; + std::vector<const TypeInput *> _indirect_buf{}; + std::vector<TypeInput> _indirect_pad{}; + arm_gemm::ConvolutionParameters _cp{}; + experimental::MemoryRequirements _aux_mem{Count}; + bool _B_pretranspose_required{false}; + bool _is_b_constant{true}; + bool _is_c_constant{true}; + bool _run_pre_pretranspose_b{false}; + bool _B_pre_pretranspose_required{false}; }; -template <typename TypeInput, typename TypeOutput, class OutputStage> +template <typename TypeInput, typename TypeWeight, typename TypeOutput, class OutputStage> std::tuple<bool, const int32_t *, const int32_t *, const int32_t *> -Fallback<TypeInput, TypeOutput, OutputStage>::set_requantize_data(const std::vector<int32_t> &shifts, - const std::vector<int32_t> &multipliers) +Fallback<TypeInput, TypeWeight, TypeOutput, OutputStage>::set_requantize_data(const std::vector<int32_t> &shifts, + const std::vector<int32_t> &multipliers) { _multipliers = multipliers; _shifts = shifts; @@ -305,8 +296,8 @@ Fallback<TypeInput, TypeOutput, OutputStage>::set_requantize_data(const std::vec return std::make_tuple(need_left, left_shifts.data(), right_shifts.data(), _multipliers.data()); } -template <typename TypeInput, typename TypeOutput, class OutputStage> -void Fallback<TypeInput, TypeOutput, OutputStage>::prepare_indirect_buffer(ITensorPack &tensors) +template <typename TypeInput, typename TypeWeight, typename TypeOutput, class OutputStage> +void Fallback<TypeInput, TypeWeight, TypeOutput, OutputStage>::prepare_indirect_buffer(ITensorPack &tensors) { auto a = tensors.get_const_tensor(TensorType::ACL_SRC_0); const TypeInput *A_ptr = reinterpret_cast<TypeInput *>(a->buffer()); @@ -343,14 +334,12 @@ void 
Fallback<TypeInput, TypeOutput, OutputStage>::prepare_indirect_buffer(ITens if (input_x < 0 || input_x >= _cp.input_width || input_y < 0 || input_y >= _cp.input_height) { - _indirect_buf - .get()[m * multi_stride + b * batch_stride + kernel_xy * output_hw + output_xy] = + _indirect_buf[m * multi_stride + b * batch_stride + kernel_xy * output_hw + output_xy] = _indirect_pad.data(); } else { - _indirect_buf - .get()[m * multi_stride + b * batch_stride + kernel_xy * output_hw + output_xy] = + _indirect_buf[m * multi_stride + b * batch_stride + kernel_xy * output_hw + output_xy] = A_ptr + (m * multi_stride_A + b * batch_stride_A + input_xy * stride_A); } } @@ -361,11 +350,11 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::prepare_indirect_buffer(ITens } } -template <typename TypeInput, typename TypeOutput, class OutputStage> -void Fallback<TypeInput, TypeOutput, OutputStage>::configure_indirect(const ITensorInfo *a, - const ITensorInfo *b, - const ITensorInfo *d, - const AsmGemmInfo &info) +template <typename TypeInput, typename TypeWeight, typename TypeOutput, class OutputStage> +void Fallback<TypeInput, TypeWeight, TypeOutput, OutputStage>::configure_indirect(const ITensorInfo *a, + const ITensorInfo *b, + const ITensorInfo *d, + const AsmGemmInfo &info) { ARM_COMPUTE_ERROR_ON(!(info.method == AsmConvMethod::Conv || info.method == AsmConvMethod::Indirect)); @@ -375,13 +364,13 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::configure_indirect(const ITen zeropad = a->quantization_info().uniform().offset; } - const int64_t input_width = static_cast<int64_t>(a->tensor_shape()[1]); - const int64_t input_height = static_cast<int64_t>(a->tensor_shape()[2]); - const int64_t input_channels = static_cast<int64_t>(a->tensor_shape()[0]); - const int64_t kernel_width = static_cast<int64_t>(b->tensor_shape()[2]); - const int64_t kernel_height = static_cast<int64_t>(b->tensor_shape()[3]); - const int64_t output_width = static_cast<int64_t>(d->tensor_shape()[1]); - const int64_t output_height = static_cast<int64_t>(d->tensor_shape()[2]); + const auto input_width = static_cast<int64_t>(a->tensor_shape()[1]); + const auto input_height = static_cast<int64_t>(a->tensor_shape()[2]); + const auto input_channels = static_cast<int64_t>(a->tensor_shape()[0]); + const auto kernel_width = static_cast<int64_t>(b->tensor_shape()[2]); + const auto kernel_height = static_cast<int64_t>(b->tensor_shape()[3]); + const auto output_width = static_cast<int64_t>(d->tensor_shape()[1]); + const auto output_height = static_cast<int64_t>(d->tensor_shape()[2]); _cp = {input_width, input_height, @@ -392,6 +381,8 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::configure_indirect(const ITen output_height, info.ps_info.stride().first, info.ps_info.stride().second, + 1, + 1, info.padding_top, info.padding_left, zeropad}; @@ -414,10 +405,8 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::configure_indirect(const ITen const int multi_size = batch_size * batches; const size_t multi_stride = multi_size / sizeof(TypeInputPtr); - _indirect_buf = std::unique_ptr<const TypeInput *, free_delete>( - reinterpret_cast<const TypeInput **>(malloc(multi_size * multis))); - _indirect_arg = std::unique_ptr<const TypeInput *const *, free_delete>( - reinterpret_cast<const TypeInput *const **>(malloc(sizeof(TypeInput **) * kernel_hw * multis * batches))); + _indirect_buf = std::vector<const TypeInput *>(multi_size * multis); + _indirect_arg = std::vector<const TypeInput *const *>(sizeof(TypeInput **) * kernel_hw * multis * batches); 
_indirect_pad = std::vector<TypeInput>(_cp.input_channels, TypeInput(zeropad)); // Set indirect argument @@ -428,29 +417,28 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::configure_indirect(const ITen { for (int64_t kernel_xy = 0; kernel_xy < kernel_hw; kernel_xy++) { - (_indirect_arg.get())[pos++] = - _indirect_buf.get() + m * multi_stride + b * batch_stride + kernel_xy * output_hw; + _indirect_arg[pos++] = &_indirect_buf[m * multi_stride + b * batch_stride + kernel_xy * output_hw]; } } } - _gemm_kernel_asm->set_indirect_parameters(a->tensor_shape()[0], _indirect_arg.get()); + _gemm_kernel_asm->set_indirect_parameters(a->tensor_shape()[0], _indirect_arg.data()); } } -template <typename TypeInput, typename TypeOutput, class OutputStage> -void Fallback<TypeInput, TypeOutput, OutputStage>::configure(const ITensorInfo *a, - const ITensorInfo *b, - const ITensorInfo *c, - ITensorInfo *d, - arm_gemm::GemmArgs args, - const AsmGemmInfo &gemm_info, - const OutputStage &os) +template <typename TypeInput, typename TypeWeight, typename TypeOutput, class OutputStage> +void Fallback<TypeInput, TypeWeight, TypeOutput, OutputStage>::configure(const ITensorInfo *a, + const ITensorInfo *b, + const ITensorInfo *c, + ITensorInfo *d, + arm_gemm::GemmArgs args, + const AsmGemmInfo &gemm_info, + const OutputStage &os) { _is_b_constant = b->are_values_constant(); _is_c_constant = c ? c->are_values_constant() : true; - _gemm_kernel_asm = arm_gemm::gemm<TypeInput, TypeOutput, OutputStage>(args, os); + _gemm_kernel_asm = arm_gemm::gemm<TypeInput, TypeWeight, TypeOutput, OutputStage>(args, os); if (_gemm_kernel_asm == nullptr) { //configuration not supported: Leave function unconfigured: @@ -460,7 +448,7 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::configure(const ITensorInfo * arm_gemm::GemmConfig gemm_cfg = _gemm_kernel_asm->get_config(); // arm_compute wrapper for the Gemm object (see above) - auto acl_gemm_wrapper = std::make_unique<kernel::CpuGemmAssemblyWrapperKernel<TypeInput, TypeOutput>>(); + auto acl_gemm_wrapper = std::make_unique<kernel::CpuGemmAssemblyWrapperKernel<TypeInput, TypeWeight, TypeOutput>>(); ARM_COMPUTE_ERROR_ON(acl_gemm_wrapper == nullptr); acl_gemm_wrapper->configure(_gemm_kernel_asm.get(), gemm_cfg.filter); const size_t workspace_size = _gemm_kernel_asm->get_working_size(); @@ -549,8 +537,8 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::configure(const ITensorInfo * } } -template <typename TypeInput, typename TypeOutput, class OutputStage> -void Fallback<TypeInput, TypeOutput, OutputStage>::prepare(ITensorPack &tensors) +template <typename TypeInput, typename TypeWeight, typename TypeOutput, class OutputStage> +void Fallback<TypeInput, TypeWeight, TypeOutput, OutputStage>::prepare(ITensorPack &tensors) { if (!_is_prepared) { @@ -588,17 +576,17 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::prepare(ITensorPack &tensors) // Fixed format kernels need no pretranspose. 
ARM_COMPUTE_ERROR_ON(arm_compute::is_fixed_format( assembly_utils::map_to_arm_compute_weight_format(_gemm_kernel_asm->get_config().weight_format))); - const int ldb = b_to_use->info()->strides_in_bytes().y() / b_to_use->info()->element_size(); - const auto in1_ptr = reinterpret_cast<const TypeInput *>(b_to_use->buffer() + - b_to_use->info()->offset_first_element_in_bytes()); - const int multi_stride_b = b_to_use->info()->strides_in_bytes().z() / b_to_use->info()->element_size(); + const int ldb = b_to_use->info()->strides_in_bytes().y() / b_to_use->info()->element_size(); + const auto in1_ptr = reinterpret_cast<const TypeWeight *>( + b_to_use->buffer() + b_to_use->info()->offset_first_element_in_bytes()); + const int multi_stride_b = b_to_use->info()->strides_in_bytes().z() / b_to_use->info()->element_size(); CpuAuxTensorHandler pretranspose(offset_int_vec(Pretranspose), _pretranspose_info, tensors, false); ARM_COMPUTE_ERROR_ON(pretranspose.get()->buffer() == nullptr); const bool kernel_supports_transpose = _gemm_kernel_asm->B_pretranspose_supports_transpose(); - run_parallel_pretranspose_B_array<TypeInput, TypeOutput>( + run_parallel_pretranspose_B_array<TypeInput, TypeWeight, TypeOutput>( _gemm_kernel_asm.get(), pretranspose.get(), in1_ptr, ldb, multi_stride_b, NEScheduler::get().num_threads(), _B_pre_pretranspose_required && kernel_supports_transpose); @@ -616,20 +604,20 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::prepare(ITensorPack &tensors) } } -template <typename TypeInput, typename TypeOutput, class OutputStage> -bool Fallback<TypeInput, TypeOutput, OutputStage>::is_configured() const +template <typename TypeInput, typename TypeWeight, typename TypeOutput, class OutputStage> +bool Fallback<TypeInput, TypeWeight, TypeOutput, OutputStage>::is_configured() const { return _optimised_kernel != nullptr; } -template <typename TypeInput, typename TypeOutput, class OutputStage> -experimental::MemoryRequirements Fallback<TypeInput, TypeOutput, OutputStage>::workspace() const +template <typename TypeInput, typename TypeWeight, typename TypeOutput, class OutputStage> +experimental::MemoryRequirements Fallback<TypeInput, TypeWeight, TypeOutput, OutputStage>::workspace() const { return _aux_mem; } -template <typename TypeInput, typename TypeOutput, class OutputStage> -void Fallback<TypeInput, TypeOutput, OutputStage>::run(ITensorPack &tensors) +template <typename TypeInput, typename TypeWeight, typename TypeOutput, class OutputStage> +void Fallback<TypeInput, TypeWeight, TypeOutput, OutputStage>::run(ITensorPack &tensors) { auto a = tensors.get_const_tensor(TensorType::ACL_SRC_0); auto b = tensors.get_const_tensor(TensorType::ACL_SRC_1); @@ -663,8 +651,8 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::run(ITensorPack &tensors) const int multi_stride_d = d->info()->strides_in_bytes()[d_multi_idx] / d->info()->element_size(); auto in0_ptr = reinterpret_cast<const TypeInput *>(a->buffer() + a->info()->offset_first_element_in_bytes()); - const TypeInput *in1_ptr = nullptr; - auto out_ptr = reinterpret_cast<TypeOutput *>(d->buffer() + d->info()->offset_first_element_in_bytes()); + const TypeWeight *in1_ptr = nullptr; + auto out_ptr = reinterpret_cast<TypeOutput *>(d->buffer() + d->info()->offset_first_element_in_bytes()); const ITensor *b_to_use = b; @@ -686,8 +674,8 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::run(ITensorPack &tensors) { ldb = b_to_use->info()->strides_in_bytes().y() / b_to_use->info()->element_size(); multi_stride_b = 
b_to_use->info()->strides_in_bytes().z() / b_to_use->info()->element_size(); - in1_ptr = - reinterpret_cast<const TypeInput *>(b_to_use->buffer() + b_to_use->info()->offset_first_element_in_bytes()); + in1_ptr = reinterpret_cast<const TypeWeight *>(b_to_use->buffer() + + b_to_use->info()->offset_first_element_in_bytes()); } // If necessary, run pretranspose every time if either weights or biases are non-constant @@ -706,8 +694,8 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::run(ITensorPack &tensors) ARM_COMPUTE_ERROR_ON(arm_compute::is_fixed_format( assembly_utils::map_to_arm_compute_weight_format(_gemm_kernel_asm->get_config().weight_format))); const int ldb = b_to_use->info()->strides_in_bytes().y() / b_to_use->info()->element_size(); - const auto b_ptr = reinterpret_cast<const TypeInput *>(b_to_use->buffer() + - b_to_use->info()->offset_first_element_in_bytes()); + const auto b_ptr = reinterpret_cast<const TypeWeight *>(b_to_use->buffer() + + b_to_use->info()->offset_first_element_in_bytes()); const int multi_stride_b = b_to_use->info()->strides_in_bytes().z() / b_to_use->info()->element_size(); CpuAuxTensorHandler pretranspose(offset_int_vec(Pretranspose), _pretranspose_info, tensors, true); @@ -720,7 +708,7 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::run(ITensorPack &tensors) else { const bool kernel_supports_transpose = _gemm_kernel_asm->B_pretranspose_supports_transpose(); - run_parallel_pretranspose_B_array<TypeInput, TypeOutput>( + run_parallel_pretranspose_B_array<TypeInput, TypeWeight, TypeOutput>( _gemm_kernel_asm.get(), pretranspose.get(), b_ptr, ldb, multi_stride_b, NEScheduler::get().num_threads(), _B_pre_pretranspose_required && kernel_supports_transpose); } @@ -744,7 +732,7 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::run(ITensorPack &tensors) if (split_dim != IScheduler::split_dimensions_all) { // Make sure the kernel does not expect more threads than we can actually spawn - const unsigned int num_iterations = _optimised_kernel.get()->window().num_iterations(split_dim); + const unsigned int num_iterations = _optimised_kernel->window().num_iterations(split_dim); num_threads = std::min(num_iterations, num_threads); } _gemm_kernel_asm->set_nthreads(num_threads); @@ -775,7 +763,7 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::run(ITensorPack &tensors) NEScheduler::get().schedule(_optimised_kernel.get(), scheduling_hint); } -template <typename TypeInput, typename TypeOutput> +template <typename TypeInput, typename TypeWeight, typename TypeOutput> void create_arm_gemm(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback> &arm_gemm, const ITensorInfo *a, const ITensorInfo *b, @@ -794,12 +782,12 @@ void create_arm_gemm(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback> &arm_ge info.fixed_format, info.fast_mode, info.accumulate, &cfg); // Create arm_gemm fallback - auto fallback = std::make_unique<Fallback<TypeInput, TypeOutput>>(); + auto fallback = std::make_unique<Fallback<TypeInput, TypeWeight, TypeOutput>>(); fallback->configure(a, b, c, d, args, info); arm_gemm = std::move(fallback); } -template <typename TypeInput, typename TypeOutput> +template <typename TypeInput, typename TypeWeight, typename TypeOutput> void create_arm_gemm_dequant(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback> &arm_gemm, const ITensorInfo *a, const ITensorInfo *b, @@ -820,7 +808,7 @@ void create_arm_gemm_dequant(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback> info.fixed_format, info.fast_mode, info.accumulate, &cfg); // Create arm_gemm fallback - auto fallback 
= std::make_unique<Fallback<TypeInput, TypeOutput, arm_gemm::DequantizeFloat>>(); + auto fallback = std::make_unique<Fallback<TypeInput, TypeWeight, TypeOutput, arm_gemm::DequantizeFloat>>(); // Configure requantization info const GEMMLowpOutputStageInfo os_info = info.output_stage; @@ -832,7 +820,7 @@ void create_arm_gemm_dequant(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback> arm_gemm = std::move(fallback); } -template <typename TypeInput, typename TypeOutput> +template <typename TypeInput, typename TypeWeight, typename TypeOutput> void create_arm_gemm_quant(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback> &arm_gemm, const ITensorInfo *a, const ITensorInfo *b, @@ -852,7 +840,7 @@ void create_arm_gemm_quant(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback> & info.fixed_format, info.fast_mode, info.accumulate, &cfg); // Create arm_gemm fallback - auto fallback = std::make_unique<Fallback<TypeInput, TypeOutput, arm_gemm::Requantize32>>(); + auto fallback = std::make_unique<Fallback<TypeInput, TypeWeight, TypeOutput, arm_gemm::Requantize32>>(); // Configure requantization info const int32_t negation = info.negated_offsets ? 1 : -1; @@ -905,12 +893,12 @@ Status CpuGemmAssemblyDispatch::has_opt_impl(arm_compute::WeightFormat &expected arm_gemm::WeightFormat arm_gemm_expected_wf = assembly_utils::map_to_arm_gemm_weight_format(expected_weight_format); arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, act, num_threads, info.fixed_format, info.fast_mode, info.accumulate, &cfg); - // TODO: Incorporate info.transpose_b COMPMID-6595 + // TODO(COMPMID-6595): Incorporate info.transpose_b switch (a->data_type()) { case DataType::F32: ARM_COMPUTE_RETURN_ERROR_ON_MSG( - !(arm_gemm::has_opt_gemm<float, float, arm_gemm::Nothing>(arm_gemm_expected_wf, args, {})), + !(arm_gemm::has_opt_gemm<float, float, float, arm_gemm::Nothing>(arm_gemm_expected_wf, args, {})), "We could not find an optimized kernel for F32 input"); break; #ifdef __aarch64__ @@ -919,13 +907,22 @@ Status CpuGemmAssemblyDispatch::has_opt_impl(arm_compute::WeightFormat &expected if (d->data_type() == DataType::S32) { ARM_COMPUTE_RETURN_ERROR_ON_MSG( - !(arm_gemm::has_opt_gemm<uint8_t, uint32_t, arm_gemm::Nothing>(arm_gemm_expected_wf, args, {})), + !(arm_gemm::has_opt_gemm<uint8_t, uint8_t, uint32_t, arm_gemm::Nothing>(arm_gemm_expected_wf, args, + {})), "We could not find an optimized kernel for U8/QASYMM8 input and U32 output"); } + else if (b->data_type() == DataType::QASYMM8_SIGNED) + { + ARM_COMPUTE_RETURN_ERROR_ON_MSG( + !(arm_gemm::has_opt_gemm<uint8_t, int8_t, uint8_t, arm_gemm::Requantize32>(arm_gemm_expected_wf, + args, {})), + "We could not find an optimized kernel for U8 input with S8 weights and U8 output"); + } else { ARM_COMPUTE_RETURN_ERROR_ON_MSG( - !(arm_gemm::has_opt_gemm<uint8_t, uint8_t, arm_gemm::Requantize32>(arm_gemm_expected_wf, args, {})), + !(arm_gemm::has_opt_gemm<uint8_t, uint8_t, uint8_t, arm_gemm::Requantize32>(arm_gemm_expected_wf, + args, {})), "We could not find an optimized kernel for U8 input and U8 output"); } break; @@ -934,13 +931,15 @@ Status CpuGemmAssemblyDispatch::has_opt_impl(arm_compute::WeightFormat &expected if (d->data_type() == DataType::S32) { ARM_COMPUTE_RETURN_ERROR_ON_MSG( - !(arm_gemm::has_opt_gemm<int8_t, int32_t, arm_gemm::Nothing>(arm_gemm_expected_wf, args, {})), + !(arm_gemm::has_opt_gemm<int8_t, int8_t, int32_t, arm_gemm::Nothing>(arm_gemm_expected_wf, args, + {})), "We could not find an optimized kernel for S8/QASYMM8_SIGNED input and 
S32 output"); } else { ARM_COMPUTE_RETURN_ERROR_ON_MSG( - !(arm_gemm::has_opt_gemm<int8_t, int8_t, arm_gemm::Requantize32>(arm_gemm_expected_wf, args, {})), + !(arm_gemm::has_opt_gemm<int8_t, int8_t, int8_t, arm_gemm::Requantize32>(arm_gemm_expected_wf, args, + {})), "We could not find an optimized kernel for S8 input and S8 output"); } break; @@ -952,13 +951,15 @@ Status CpuGemmAssemblyDispatch::has_opt_impl(arm_compute::WeightFormat &expected if (d->data_type() == DataType::BFLOAT16) { ARM_COMPUTE_RETURN_ERROR_ON_MSG( - !(arm_gemm::has_opt_gemm<bfloat16, bfloat16, arm_gemm::Nothing>(arm_gemm_expected_wf, args, {})), + !(arm_gemm::has_opt_gemm<bfloat16, bfloat16, bfloat16, arm_gemm::Nothing>(arm_gemm_expected_wf, + args, {})), "We could not find an optimized kernel for BFLOAT16 input and BFLOAT16 output"); } else { ARM_COMPUTE_RETURN_ERROR_ON_MSG( - !(arm_gemm::has_opt_gemm<bfloat16, float, arm_gemm::Nothing>(arm_gemm_expected_wf, args, {})), + !(arm_gemm::has_opt_gemm<bfloat16, bfloat16, float, arm_gemm::Nothing>(arm_gemm_expected_wf, args, + {})), "We could not find an optimized kernel for BFLOAT16 input and F32 output"); } break; @@ -968,7 +969,8 @@ Status CpuGemmAssemblyDispatch::has_opt_impl(arm_compute::WeightFormat &expected #if defined(ENABLE_FP16_KERNELS) case DataType::F16: ARM_COMPUTE_RETURN_ERROR_ON_MSG( - !(arm_gemm::has_opt_gemm<float16_t, float16_t, arm_gemm::Nothing>(arm_gemm_expected_wf, args, {})), + !(arm_gemm::has_opt_gemm<float16_t, float16_t, float16_t, arm_gemm::Nothing>(arm_gemm_expected_wf, args, + {})), "We could not find an optimized kernel for F16 input and F16 output"); break; #endif /* ENABLE_FP16_KERNELS */ @@ -1009,7 +1011,7 @@ Status CpuGemmAssemblyDispatch::validate( ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_NOT_IN(a, DataType::F32); ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_NOT_IN(b, DataType::BFLOAT16); } - else + else if (!(a->data_type() == DataType::QASYMM8 && b->data_type() == DataType::QASYMM8_SIGNED)) { ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(a, b); } @@ -1024,12 +1026,13 @@ Status CpuGemmAssemblyDispatch::validate( "Only U32 output supported for U8 input"); ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::S8 && d->data_type() != DataType::S32, "Only S32 output supported for S8 input"); - ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::QASYMM8 && - (d->data_type() != DataType::QASYMM8 && d->data_type() != DataType::S32), - "Only QASYMM8/S32 output supported for QASYMM8 input"); + ARM_COMPUTE_RETURN_ERROR_ON_MSG( + a->data_type() == DataType::QASYMM8 && + (d->data_type() != DataType::QASYMM8 && d->data_type() != DataType::S32 && d->data_type() != DataType::F32), + "Only QASYMM8/S32/F32 output supported for QASYMM8 input"); arm_compute::WeightFormat expected_weight_format = arm_compute::WeightFormat::UNSPECIFIED; const Status ret = CpuGemmAssemblyDispatch::has_opt_impl(expected_weight_format, a, b, c, d, info); - if ((bool)ret && expected_weight_format != arm_compute::WeightFormat::ANY) + if (bool(ret) && expected_weight_format != arm_compute::WeightFormat::ANY) { // Correctness check: if the format expected by the kernel is // not "any", make sure that the one found matches the format @@ -1062,33 +1065,44 @@ void CpuGemmAssemblyDispatch::configure( switch (a->data_type()) { case DataType::F32: - create_arm_gemm<float, float>(_arm_gemm, a, b, c, d, act, info); + create_arm_gemm<float, float, float>(_arm_gemm, a, b, c, d, act, info); break; #ifdef __aarch64__ case DataType::U8: case DataType::QASYMM8: - if (d->data_type() == 
DataType::S32) + if (b->data_type() == DataType::S8 || b->data_type() == DataType::QASYMM8_SIGNED) + { + if (d->data_type() == DataType::F32) + { + create_arm_gemm_dequant<uint8_t, int8_t, float>(_arm_gemm, a, b, c, d, act, info); + } + else + { + create_arm_gemm_quant<uint8_t, int8_t, uint8_t>(_arm_gemm, a, b, c, d, act, info); + } + } + else if (d->data_type() == DataType::S32) { - create_arm_gemm<uint8_t, uint32_t>(_arm_gemm, a, b, c, d, act, info); + create_arm_gemm<uint8_t, uint8_t, uint32_t>(_arm_gemm, a, b, c, d, act, info); } else { - create_arm_gemm_quant<uint8_t, uint8_t>(_arm_gemm, a, b, c, d, act, info); + create_arm_gemm_quant<uint8_t, uint8_t, uint8_t>(_arm_gemm, a, b, c, d, act, info); } break; case DataType::S8: case DataType::QASYMM8_SIGNED: if (d->data_type() == DataType::S32) { - create_arm_gemm<int8_t, int32_t>(_arm_gemm, a, b, c, d, act, info); + create_arm_gemm<int8_t, int8_t, int32_t>(_arm_gemm, a, b, c, d, act, info); } else if (d->data_type() == DataType::F32) { - create_arm_gemm_dequant<int8_t, float>(_arm_gemm, a, b, c, d, act, info); + create_arm_gemm_dequant<int8_t, int8_t, float>(_arm_gemm, a, b, c, d, act, info); } else { - create_arm_gemm_quant<int8_t, int8_t>(_arm_gemm, a, b, c, d, act, info); + create_arm_gemm_quant<int8_t, int8_t, int8_t>(_arm_gemm, a, b, c, d, act, info); } break; #endif /* __aarch64__ */ @@ -1096,17 +1110,17 @@ void CpuGemmAssemblyDispatch::configure( case DataType::BFLOAT16: if (d->data_type() == DataType::BFLOAT16) { - create_arm_gemm<bfloat16, bfloat16>(_arm_gemm, a, b, c, d, act, info); + create_arm_gemm<bfloat16, bfloat16, bfloat16>(_arm_gemm, a, b, c, d, act, info); } else { - create_arm_gemm<bfloat16, float>(_arm_gemm, a, b, c, d, act, info); + create_arm_gemm<bfloat16, bfloat16, float>(_arm_gemm, a, b, c, d, act, info); } break; #endif /* defined(ARM_COMPUTE_ENABLE_BF16) */ #ifdef ENABLE_FP16_KERNELS case DataType::F16: - create_arm_gemm<float16_t, float16_t>(_arm_gemm, a, b, c, d, act, info); + create_arm_gemm<float16_t, float16_t, float16_t>(_arm_gemm, a, b, c, d, act, info); break; #endif /* ENABLE_FP16_KERNELS */ default: |