diff options
Diffstat (limited to 'src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp')
-rw-r--r-- | src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp | 270 |
1 files changed, 128 insertions, 142 deletions
diff --git a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp index 785837dbc6..fb9bc15212 100644 --- a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp +++ b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp @@ -45,7 +45,6 @@ namespace /** Run pretranspose_B_array in parallel (1D static scheduling) * * @tparam TypeInput - * @tparam TypeWeight * @tparam TypeOutput * * @param[in] gemm_asm GemmCommon kernel to run @@ -55,14 +54,14 @@ namespace * @param[in] src_multi_stride Stride in z ("multi") * @param[in] num_threads Number of threads to run this method. Must be >= 1 */ -template <typename TypeInput, typename TypeWeight, typename TypeOutput> -void run_parallel_pretranspose_B_array(arm_gemm::GemmCommon<TypeInput, TypeWeight, TypeOutput> *gemm_asm, - ITensor *dst, - const TypeWeight *src, - int src_ld, - int src_multi_stride, - unsigned int num_threads, - bool transpose) +template <typename TypeInput, typename TypeOutput> +void run_parallel_pretranspose_B_array(arm_gemm::GemmCommon<TypeInput, TypeOutput> *gemm_asm, + ITensor *dst, + const TypeInput *src, + int src_ld, + int src_multi_stride, + unsigned int num_threads, + bool transpose) { ARM_COMPUTE_ERROR_ON(gemm_asm == nullptr); ARM_COMPUTE_ERROR_ON(num_threads == 0); @@ -92,6 +91,14 @@ using namespace arm_compute::experimental; namespace { +struct free_delete +{ + void operator()(void *x) + { + free(x); + } +}; + struct Params { unsigned int M; @@ -106,13 +113,14 @@ struct Params Params extract_parameters(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *d, const AsmGemmInfo &info) { ARM_COMPUTE_ERROR_ON_NULLPTR(a, b, d); - Params p{/* M */ static_cast<unsigned int>(d->tensor_shape().y()), - /* N */ static_cast<unsigned int>(d->tensor_shape().x()), - /* K */ static_cast<unsigned int>(a->tensor_shape().x()), - /* batches */ 1, - /* multis */ 1, - /* sections */ 1, - /* indirect */ false}; + Params p; + p.M = d->tensor_shape().y(); + p.K = a->tensor_shape().x(); + p.N = d->tensor_shape().x(); + p.batches = 1; + p.multis = 1; + p.sections = 1; + p.indirect = false; if (info.method == AsmConvMethod::Conv || info.method == AsmConvMethod::Indirect) { @@ -164,10 +172,13 @@ IScheduler::Hints scheduling_hint_heuristic(arm_gemm::GemmMethod method, DataTyp } /** Fallback in case ACL doesn't have a function */ -template <typename TypeInput, typename TypeWeight, typename TypeOutput, class OutputStage = arm_gemm::Nothing> +template <typename TypeInput, typename TypeOutput, class OutputStage = arm_gemm::Nothing> class Fallback : public CpuGemmAssemblyDispatch::IFallback { public: + /** Destructor */ + ~Fallback() = default; + /** Initialise the functions's input and output. * * @param[in] a Input tensor containing the Matrix A. @@ -211,9 +222,7 @@ public: bool isVarWeightsKernel() const override { if (!_gemm_kernel_asm) - { return false; - } const arm_compute::WeightFormat wf = assembly_utils::map_to_arm_compute_weight_format(_gemm_kernel_asm->get_config().weight_format); return wf != arm_compute::WeightFormat::UNSPECIFIED && wf != arm_compute::WeightFormat::ANY; @@ -242,7 +251,7 @@ private: /** Operator to transpose B before gemm or pretranspose_B_array*/ std::unique_ptr<CpuTranspose> _pre_pretranspose_b{nullptr}; /** Assembly Gemm kernel */ - std::shared_ptr<arm_gemm::GemmCommon<TypeInput, TypeWeight, TypeOutput>> _gemm_kernel_asm{nullptr}; + std::shared_ptr<arm_gemm::GemmCommon<TypeInput, TypeOutput>> _gemm_kernel_asm{nullptr}; /** Optimised Arm® Neon™ kernel */ std::unique_ptr<INEKernel> _optimised_kernel{nullptr}; /** Assembly GEMM workspace tensor info */ @@ -264,22 +273,22 @@ private: /** Per channel quantization multipliers */ std::vector<int32_t> _multipliers{}; /** Indirect buffer */ - std::vector<const TypeInput *const *> _indirect_arg{}; - std::vector<const TypeInput *> _indirect_buf{}; - std::vector<TypeInput> _indirect_pad{}; - arm_gemm::ConvolutionParameters _cp{}; - experimental::MemoryRequirements _aux_mem{Count}; - bool _B_pretranspose_required{false}; - bool _is_b_constant{true}; - bool _is_c_constant{true}; - bool _run_pre_pretranspose_b{false}; - bool _B_pre_pretranspose_required{false}; + std::unique_ptr<const TypeInput *const *, free_delete> _indirect_arg{}; + std::unique_ptr<const TypeInput *, free_delete> _indirect_buf{}; + std::vector<TypeInput> _indirect_pad{}; + arm_gemm::ConvolutionParameters _cp{}; + experimental::MemoryRequirements _aux_mem{Count}; + bool _B_pretranspose_required{false}; + bool _is_b_constant{true}; + bool _is_c_constant{true}; + bool _run_pre_pretranspose_b{false}; + bool _B_pre_pretranspose_required{false}; }; -template <typename TypeInput, typename TypeWeight, typename TypeOutput, class OutputStage> +template <typename TypeInput, typename TypeOutput, class OutputStage> std::tuple<bool, const int32_t *, const int32_t *, const int32_t *> -Fallback<TypeInput, TypeWeight, TypeOutput, OutputStage>::set_requantize_data(const std::vector<int32_t> &shifts, - const std::vector<int32_t> &multipliers) +Fallback<TypeInput, TypeOutput, OutputStage>::set_requantize_data(const std::vector<int32_t> &shifts, + const std::vector<int32_t> &multipliers) { _multipliers = multipliers; _shifts = shifts; @@ -296,8 +305,8 @@ Fallback<TypeInput, TypeWeight, TypeOutput, OutputStage>::set_requantize_data(co return std::make_tuple(need_left, left_shifts.data(), right_shifts.data(), _multipliers.data()); } -template <typename TypeInput, typename TypeWeight, typename TypeOutput, class OutputStage> -void Fallback<TypeInput, TypeWeight, TypeOutput, OutputStage>::prepare_indirect_buffer(ITensorPack &tensors) +template <typename TypeInput, typename TypeOutput, class OutputStage> +void Fallback<TypeInput, TypeOutput, OutputStage>::prepare_indirect_buffer(ITensorPack &tensors) { auto a = tensors.get_const_tensor(TensorType::ACL_SRC_0); const TypeInput *A_ptr = reinterpret_cast<TypeInput *>(a->buffer()); @@ -334,12 +343,14 @@ void Fallback<TypeInput, TypeWeight, TypeOutput, OutputStage>::prepare_indirect_ if (input_x < 0 || input_x >= _cp.input_width || input_y < 0 || input_y >= _cp.input_height) { - _indirect_buf[m * multi_stride + b * batch_stride + kernel_xy * output_hw + output_xy] = + _indirect_buf + .get()[m * multi_stride + b * batch_stride + kernel_xy * output_hw + output_xy] = _indirect_pad.data(); } else { - _indirect_buf[m * multi_stride + b * batch_stride + kernel_xy * output_hw + output_xy] = + _indirect_buf + .get()[m * multi_stride + b * batch_stride + kernel_xy * output_hw + output_xy] = A_ptr + (m * multi_stride_A + b * batch_stride_A + input_xy * stride_A); } } @@ -350,11 +361,11 @@ void Fallback<TypeInput, TypeWeight, TypeOutput, OutputStage>::prepare_indirect_ } } -template <typename TypeInput, typename TypeWeight, typename TypeOutput, class OutputStage> -void Fallback<TypeInput, TypeWeight, TypeOutput, OutputStage>::configure_indirect(const ITensorInfo *a, - const ITensorInfo *b, - const ITensorInfo *d, - const AsmGemmInfo &info) +template <typename TypeInput, typename TypeOutput, class OutputStage> +void Fallback<TypeInput, TypeOutput, OutputStage>::configure_indirect(const ITensorInfo *a, + const ITensorInfo *b, + const ITensorInfo *d, + const AsmGemmInfo &info) { ARM_COMPUTE_ERROR_ON(!(info.method == AsmConvMethod::Conv || info.method == AsmConvMethod::Indirect)); @@ -364,13 +375,13 @@ void Fallback<TypeInput, TypeWeight, TypeOutput, OutputStage>::configure_indirec zeropad = a->quantization_info().uniform().offset; } - const auto input_width = static_cast<int64_t>(a->tensor_shape()[1]); - const auto input_height = static_cast<int64_t>(a->tensor_shape()[2]); - const auto input_channels = static_cast<int64_t>(a->tensor_shape()[0]); - const auto kernel_width = static_cast<int64_t>(b->tensor_shape()[2]); - const auto kernel_height = static_cast<int64_t>(b->tensor_shape()[3]); - const auto output_width = static_cast<int64_t>(d->tensor_shape()[1]); - const auto output_height = static_cast<int64_t>(d->tensor_shape()[2]); + const int64_t input_width = static_cast<int64_t>(a->tensor_shape()[1]); + const int64_t input_height = static_cast<int64_t>(a->tensor_shape()[2]); + const int64_t input_channels = static_cast<int64_t>(a->tensor_shape()[0]); + const int64_t kernel_width = static_cast<int64_t>(b->tensor_shape()[2]); + const int64_t kernel_height = static_cast<int64_t>(b->tensor_shape()[3]); + const int64_t output_width = static_cast<int64_t>(d->tensor_shape()[1]); + const int64_t output_height = static_cast<int64_t>(d->tensor_shape()[2]); _cp = {input_width, input_height, @@ -381,8 +392,6 @@ void Fallback<TypeInput, TypeWeight, TypeOutput, OutputStage>::configure_indirec output_height, info.ps_info.stride().first, info.ps_info.stride().second, - 1, - 1, info.padding_top, info.padding_left, zeropad}; @@ -405,8 +414,10 @@ void Fallback<TypeInput, TypeWeight, TypeOutput, OutputStage>::configure_indirec const int multi_size = batch_size * batches; const size_t multi_stride = multi_size / sizeof(TypeInputPtr); - _indirect_buf = std::vector<const TypeInput *>(multi_size * multis); - _indirect_arg = std::vector<const TypeInput *const *>(sizeof(TypeInput **) * kernel_hw * multis * batches); + _indirect_buf = std::unique_ptr<const TypeInput *, free_delete>( + reinterpret_cast<const TypeInput **>(malloc(multi_size * multis))); + _indirect_arg = std::unique_ptr<const TypeInput *const *, free_delete>( + reinterpret_cast<const TypeInput *const **>(malloc(sizeof(TypeInput **) * kernel_hw * multis * batches))); _indirect_pad = std::vector<TypeInput>(_cp.input_channels, TypeInput(zeropad)); // Set indirect argument @@ -417,28 +428,29 @@ void Fallback<TypeInput, TypeWeight, TypeOutput, OutputStage>::configure_indirec { for (int64_t kernel_xy = 0; kernel_xy < kernel_hw; kernel_xy++) { - _indirect_arg[pos++] = &_indirect_buf[m * multi_stride + b * batch_stride + kernel_xy * output_hw]; + (_indirect_arg.get())[pos++] = + _indirect_buf.get() + m * multi_stride + b * batch_stride + kernel_xy * output_hw; } } } - _gemm_kernel_asm->set_indirect_parameters(a->tensor_shape()[0], _indirect_arg.data()); + _gemm_kernel_asm->set_indirect_parameters(a->tensor_shape()[0], _indirect_arg.get()); } } -template <typename TypeInput, typename TypeWeight, typename TypeOutput, class OutputStage> -void Fallback<TypeInput, TypeWeight, TypeOutput, OutputStage>::configure(const ITensorInfo *a, - const ITensorInfo *b, - const ITensorInfo *c, - ITensorInfo *d, - arm_gemm::GemmArgs args, - const AsmGemmInfo &gemm_info, - const OutputStage &os) +template <typename TypeInput, typename TypeOutput, class OutputStage> +void Fallback<TypeInput, TypeOutput, OutputStage>::configure(const ITensorInfo *a, + const ITensorInfo *b, + const ITensorInfo *c, + ITensorInfo *d, + arm_gemm::GemmArgs args, + const AsmGemmInfo &gemm_info, + const OutputStage &os) { _is_b_constant = b->are_values_constant(); _is_c_constant = c ? c->are_values_constant() : true; - _gemm_kernel_asm = arm_gemm::gemm<TypeInput, TypeWeight, TypeOutput, OutputStage>(args, os); + _gemm_kernel_asm = arm_gemm::gemm<TypeInput, TypeOutput, OutputStage>(args, os); if (_gemm_kernel_asm == nullptr) { //configuration not supported: Leave function unconfigured: @@ -448,7 +460,7 @@ void Fallback<TypeInput, TypeWeight, TypeOutput, OutputStage>::configure(const I arm_gemm::GemmConfig gemm_cfg = _gemm_kernel_asm->get_config(); // arm_compute wrapper for the Gemm object (see above) - auto acl_gemm_wrapper = std::make_unique<kernel::CpuGemmAssemblyWrapperKernel<TypeInput, TypeWeight, TypeOutput>>(); + auto acl_gemm_wrapper = std::make_unique<kernel::CpuGemmAssemblyWrapperKernel<TypeInput, TypeOutput>>(); ARM_COMPUTE_ERROR_ON(acl_gemm_wrapper == nullptr); acl_gemm_wrapper->configure(_gemm_kernel_asm.get(), gemm_cfg.filter); const size_t workspace_size = _gemm_kernel_asm->get_working_size(); @@ -537,8 +549,8 @@ void Fallback<TypeInput, TypeWeight, TypeOutput, OutputStage>::configure(const I } } -template <typename TypeInput, typename TypeWeight, typename TypeOutput, class OutputStage> -void Fallback<TypeInput, TypeWeight, TypeOutput, OutputStage>::prepare(ITensorPack &tensors) +template <typename TypeInput, typename TypeOutput, class OutputStage> +void Fallback<TypeInput, TypeOutput, OutputStage>::prepare(ITensorPack &tensors) { if (!_is_prepared) { @@ -576,17 +588,17 @@ void Fallback<TypeInput, TypeWeight, TypeOutput, OutputStage>::prepare(ITensorPa // Fixed format kernels need no pretranspose. ARM_COMPUTE_ERROR_ON(arm_compute::is_fixed_format( assembly_utils::map_to_arm_compute_weight_format(_gemm_kernel_asm->get_config().weight_format))); - const int ldb = b_to_use->info()->strides_in_bytes().y() / b_to_use->info()->element_size(); - const auto in1_ptr = reinterpret_cast<const TypeWeight *>( - b_to_use->buffer() + b_to_use->info()->offset_first_element_in_bytes()); - const int multi_stride_b = b_to_use->info()->strides_in_bytes().z() / b_to_use->info()->element_size(); + const int ldb = b_to_use->info()->strides_in_bytes().y() / b_to_use->info()->element_size(); + const auto in1_ptr = reinterpret_cast<const TypeInput *>(b_to_use->buffer() + + b_to_use->info()->offset_first_element_in_bytes()); + const int multi_stride_b = b_to_use->info()->strides_in_bytes().z() / b_to_use->info()->element_size(); CpuAuxTensorHandler pretranspose(offset_int_vec(Pretranspose), _pretranspose_info, tensors, false); ARM_COMPUTE_ERROR_ON(pretranspose.get()->buffer() == nullptr); const bool kernel_supports_transpose = _gemm_kernel_asm->B_pretranspose_supports_transpose(); - run_parallel_pretranspose_B_array<TypeInput, TypeWeight, TypeOutput>( + run_parallel_pretranspose_B_array<TypeInput, TypeOutput>( _gemm_kernel_asm.get(), pretranspose.get(), in1_ptr, ldb, multi_stride_b, NEScheduler::get().num_threads(), _B_pre_pretranspose_required && kernel_supports_transpose); @@ -604,20 +616,20 @@ void Fallback<TypeInput, TypeWeight, TypeOutput, OutputStage>::prepare(ITensorPa } } -template <typename TypeInput, typename TypeWeight, typename TypeOutput, class OutputStage> -bool Fallback<TypeInput, TypeWeight, TypeOutput, OutputStage>::is_configured() const +template <typename TypeInput, typename TypeOutput, class OutputStage> +bool Fallback<TypeInput, TypeOutput, OutputStage>::is_configured() const { return _optimised_kernel != nullptr; } -template <typename TypeInput, typename TypeWeight, typename TypeOutput, class OutputStage> -experimental::MemoryRequirements Fallback<TypeInput, TypeWeight, TypeOutput, OutputStage>::workspace() const +template <typename TypeInput, typename TypeOutput, class OutputStage> +experimental::MemoryRequirements Fallback<TypeInput, TypeOutput, OutputStage>::workspace() const { return _aux_mem; } -template <typename TypeInput, typename TypeWeight, typename TypeOutput, class OutputStage> -void Fallback<TypeInput, TypeWeight, TypeOutput, OutputStage>::run(ITensorPack &tensors) +template <typename TypeInput, typename TypeOutput, class OutputStage> +void Fallback<TypeInput, TypeOutput, OutputStage>::run(ITensorPack &tensors) { auto a = tensors.get_const_tensor(TensorType::ACL_SRC_0); auto b = tensors.get_const_tensor(TensorType::ACL_SRC_1); @@ -651,8 +663,8 @@ void Fallback<TypeInput, TypeWeight, TypeOutput, OutputStage>::run(ITensorPack & const int multi_stride_d = d->info()->strides_in_bytes()[d_multi_idx] / d->info()->element_size(); auto in0_ptr = reinterpret_cast<const TypeInput *>(a->buffer() + a->info()->offset_first_element_in_bytes()); - const TypeWeight *in1_ptr = nullptr; - auto out_ptr = reinterpret_cast<TypeOutput *>(d->buffer() + d->info()->offset_first_element_in_bytes()); + const TypeInput *in1_ptr = nullptr; + auto out_ptr = reinterpret_cast<TypeOutput *>(d->buffer() + d->info()->offset_first_element_in_bytes()); const ITensor *b_to_use = b; @@ -674,8 +686,8 @@ void Fallback<TypeInput, TypeWeight, TypeOutput, OutputStage>::run(ITensorPack & { ldb = b_to_use->info()->strides_in_bytes().y() / b_to_use->info()->element_size(); multi_stride_b = b_to_use->info()->strides_in_bytes().z() / b_to_use->info()->element_size(); - in1_ptr = reinterpret_cast<const TypeWeight *>(b_to_use->buffer() + - b_to_use->info()->offset_first_element_in_bytes()); + in1_ptr = + reinterpret_cast<const TypeInput *>(b_to_use->buffer() + b_to_use->info()->offset_first_element_in_bytes()); } // If necessary, run pretranspose every time if either weights or biases are non-constant @@ -694,8 +706,8 @@ void Fallback<TypeInput, TypeWeight, TypeOutput, OutputStage>::run(ITensorPack & ARM_COMPUTE_ERROR_ON(arm_compute::is_fixed_format( assembly_utils::map_to_arm_compute_weight_format(_gemm_kernel_asm->get_config().weight_format))); const int ldb = b_to_use->info()->strides_in_bytes().y() / b_to_use->info()->element_size(); - const auto b_ptr = reinterpret_cast<const TypeWeight *>(b_to_use->buffer() + - b_to_use->info()->offset_first_element_in_bytes()); + const auto b_ptr = reinterpret_cast<const TypeInput *>(b_to_use->buffer() + + b_to_use->info()->offset_first_element_in_bytes()); const int multi_stride_b = b_to_use->info()->strides_in_bytes().z() / b_to_use->info()->element_size(); CpuAuxTensorHandler pretranspose(offset_int_vec(Pretranspose), _pretranspose_info, tensors, true); @@ -708,7 +720,7 @@ void Fallback<TypeInput, TypeWeight, TypeOutput, OutputStage>::run(ITensorPack & else { const bool kernel_supports_transpose = _gemm_kernel_asm->B_pretranspose_supports_transpose(); - run_parallel_pretranspose_B_array<TypeInput, TypeWeight, TypeOutput>( + run_parallel_pretranspose_B_array<TypeInput, TypeOutput>( _gemm_kernel_asm.get(), pretranspose.get(), b_ptr, ldb, multi_stride_b, NEScheduler::get().num_threads(), _B_pre_pretranspose_required && kernel_supports_transpose); } @@ -732,7 +744,7 @@ void Fallback<TypeInput, TypeWeight, TypeOutput, OutputStage>::run(ITensorPack & if (split_dim != IScheduler::split_dimensions_all) { // Make sure the kernel does not expect more threads than we can actually spawn - const unsigned int num_iterations = _optimised_kernel->window().num_iterations(split_dim); + const unsigned int num_iterations = _optimised_kernel.get()->window().num_iterations(split_dim); num_threads = std::min(num_iterations, num_threads); } _gemm_kernel_asm->set_nthreads(num_threads); @@ -763,7 +775,7 @@ void Fallback<TypeInput, TypeWeight, TypeOutput, OutputStage>::run(ITensorPack & NEScheduler::get().schedule(_optimised_kernel.get(), scheduling_hint); } -template <typename TypeInput, typename TypeWeight, typename TypeOutput> +template <typename TypeInput, typename TypeOutput> void create_arm_gemm(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback> &arm_gemm, const ITensorInfo *a, const ITensorInfo *b, @@ -782,12 +794,12 @@ void create_arm_gemm(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback> &arm_ge info.fixed_format, info.fast_mode, info.accumulate, &cfg); // Create arm_gemm fallback - auto fallback = std::make_unique<Fallback<TypeInput, TypeWeight, TypeOutput>>(); + auto fallback = std::make_unique<Fallback<TypeInput, TypeOutput>>(); fallback->configure(a, b, c, d, args, info); arm_gemm = std::move(fallback); } -template <typename TypeInput, typename TypeWeight, typename TypeOutput> +template <typename TypeInput, typename TypeOutput> void create_arm_gemm_dequant(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback> &arm_gemm, const ITensorInfo *a, const ITensorInfo *b, @@ -808,7 +820,7 @@ void create_arm_gemm_dequant(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback> info.fixed_format, info.fast_mode, info.accumulate, &cfg); // Create arm_gemm fallback - auto fallback = std::make_unique<Fallback<TypeInput, TypeWeight, TypeOutput, arm_gemm::DequantizeFloat>>(); + auto fallback = std::make_unique<Fallback<TypeInput, TypeOutput, arm_gemm::DequantizeFloat>>(); // Configure requantization info const GEMMLowpOutputStageInfo os_info = info.output_stage; @@ -820,7 +832,7 @@ void create_arm_gemm_dequant(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback> arm_gemm = std::move(fallback); } -template <typename TypeInput, typename TypeWeight, typename TypeOutput> +template <typename TypeInput, typename TypeOutput> void create_arm_gemm_quant(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback> &arm_gemm, const ITensorInfo *a, const ITensorInfo *b, @@ -840,7 +852,7 @@ void create_arm_gemm_quant(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback> & info.fixed_format, info.fast_mode, info.accumulate, &cfg); // Create arm_gemm fallback - auto fallback = std::make_unique<Fallback<TypeInput, TypeWeight, TypeOutput, arm_gemm::Requantize32>>(); + auto fallback = std::make_unique<Fallback<TypeInput, TypeOutput, arm_gemm::Requantize32>>(); // Configure requantization info const int32_t negation = info.negated_offsets ? 1 : -1; @@ -893,12 +905,12 @@ Status CpuGemmAssemblyDispatch::has_opt_impl(arm_compute::WeightFormat &expected arm_gemm::WeightFormat arm_gemm_expected_wf = assembly_utils::map_to_arm_gemm_weight_format(expected_weight_format); arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, act, num_threads, info.fixed_format, info.fast_mode, info.accumulate, &cfg); - // TODO(COMPMID-6595): Incorporate info.transpose_b + // TODO: Incorporate info.transpose_b COMPMID-6595 switch (a->data_type()) { case DataType::F32: ARM_COMPUTE_RETURN_ERROR_ON_MSG( - !(arm_gemm::has_opt_gemm<float, float, float, arm_gemm::Nothing>(arm_gemm_expected_wf, args, {})), + !(arm_gemm::has_opt_gemm<float, float, arm_gemm::Nothing>(arm_gemm_expected_wf, args, {})), "We could not find an optimized kernel for F32 input"); break; #ifdef __aarch64__ @@ -907,22 +919,13 @@ Status CpuGemmAssemblyDispatch::has_opt_impl(arm_compute::WeightFormat &expected if (d->data_type() == DataType::S32) { ARM_COMPUTE_RETURN_ERROR_ON_MSG( - !(arm_gemm::has_opt_gemm<uint8_t, uint8_t, uint32_t, arm_gemm::Nothing>(arm_gemm_expected_wf, args, - {})), + !(arm_gemm::has_opt_gemm<uint8_t, uint32_t, arm_gemm::Nothing>(arm_gemm_expected_wf, args, {})), "We could not find an optimized kernel for U8/QASYMM8 input and U32 output"); } - else if (b->data_type() == DataType::QASYMM8_SIGNED) - { - ARM_COMPUTE_RETURN_ERROR_ON_MSG( - !(arm_gemm::has_opt_gemm<uint8_t, int8_t, uint8_t, arm_gemm::Requantize32>(arm_gemm_expected_wf, - args, {})), - "We could not find an optimized kernel for U8 input with S8 weights and U8 output"); - } else { ARM_COMPUTE_RETURN_ERROR_ON_MSG( - !(arm_gemm::has_opt_gemm<uint8_t, uint8_t, uint8_t, arm_gemm::Requantize32>(arm_gemm_expected_wf, - args, {})), + !(arm_gemm::has_opt_gemm<uint8_t, uint8_t, arm_gemm::Requantize32>(arm_gemm_expected_wf, args, {})), "We could not find an optimized kernel for U8 input and U8 output"); } break; @@ -931,15 +934,13 @@ Status CpuGemmAssemblyDispatch::has_opt_impl(arm_compute::WeightFormat &expected if (d->data_type() == DataType::S32) { ARM_COMPUTE_RETURN_ERROR_ON_MSG( - !(arm_gemm::has_opt_gemm<int8_t, int8_t, int32_t, arm_gemm::Nothing>(arm_gemm_expected_wf, args, - {})), + !(arm_gemm::has_opt_gemm<int8_t, int32_t, arm_gemm::Nothing>(arm_gemm_expected_wf, args, {})), "We could not find an optimized kernel for S8/QASYMM8_SIGNED input and S32 output"); } else { ARM_COMPUTE_RETURN_ERROR_ON_MSG( - !(arm_gemm::has_opt_gemm<int8_t, int8_t, int8_t, arm_gemm::Requantize32>(arm_gemm_expected_wf, args, - {})), + !(arm_gemm::has_opt_gemm<int8_t, int8_t, arm_gemm::Requantize32>(arm_gemm_expected_wf, args, {})), "We could not find an optimized kernel for S8 input and S8 output"); } break; @@ -951,15 +952,13 @@ Status CpuGemmAssemblyDispatch::has_opt_impl(arm_compute::WeightFormat &expected if (d->data_type() == DataType::BFLOAT16) { ARM_COMPUTE_RETURN_ERROR_ON_MSG( - !(arm_gemm::has_opt_gemm<bfloat16, bfloat16, bfloat16, arm_gemm::Nothing>(arm_gemm_expected_wf, - args, {})), + !(arm_gemm::has_opt_gemm<bfloat16, bfloat16, arm_gemm::Nothing>(arm_gemm_expected_wf, args, {})), "We could not find an optimized kernel for BFLOAT16 input and BFLOAT16 output"); } else { ARM_COMPUTE_RETURN_ERROR_ON_MSG( - !(arm_gemm::has_opt_gemm<bfloat16, bfloat16, float, arm_gemm::Nothing>(arm_gemm_expected_wf, args, - {})), + !(arm_gemm::has_opt_gemm<bfloat16, float, arm_gemm::Nothing>(arm_gemm_expected_wf, args, {})), "We could not find an optimized kernel for BFLOAT16 input and F32 output"); } break; @@ -969,8 +968,7 @@ Status CpuGemmAssemblyDispatch::has_opt_impl(arm_compute::WeightFormat &expected #if defined(ENABLE_FP16_KERNELS) case DataType::F16: ARM_COMPUTE_RETURN_ERROR_ON_MSG( - !(arm_gemm::has_opt_gemm<float16_t, float16_t, float16_t, arm_gemm::Nothing>(arm_gemm_expected_wf, args, - {})), + !(arm_gemm::has_opt_gemm<float16_t, float16_t, arm_gemm::Nothing>(arm_gemm_expected_wf, args, {})), "We could not find an optimized kernel for F16 input and F16 output"); break; #endif /* ENABLE_FP16_KERNELS */ @@ -1011,7 +1009,7 @@ Status CpuGemmAssemblyDispatch::validate( ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_NOT_IN(a, DataType::F32); ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_NOT_IN(b, DataType::BFLOAT16); } - else if (!(a->data_type() == DataType::QASYMM8 && b->data_type() == DataType::QASYMM8_SIGNED)) + else { ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(a, b); } @@ -1026,13 +1024,12 @@ Status CpuGemmAssemblyDispatch::validate( "Only U32 output supported for U8 input"); ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::S8 && d->data_type() != DataType::S32, "Only S32 output supported for S8 input"); - ARM_COMPUTE_RETURN_ERROR_ON_MSG( - a->data_type() == DataType::QASYMM8 && - (d->data_type() != DataType::QASYMM8 && d->data_type() != DataType::S32 && d->data_type() != DataType::F32), - "Only QASYMM8/S32/F32 output supported for QASYMM8 input"); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::QASYMM8 && + (d->data_type() != DataType::QASYMM8 && d->data_type() != DataType::S32), + "Only QASYMM8/S32 output supported for QASYMM8 input"); arm_compute::WeightFormat expected_weight_format = arm_compute::WeightFormat::UNSPECIFIED; const Status ret = CpuGemmAssemblyDispatch::has_opt_impl(expected_weight_format, a, b, c, d, info); - if (bool(ret) && expected_weight_format != arm_compute::WeightFormat::ANY) + if ((bool)ret && expected_weight_format != arm_compute::WeightFormat::ANY) { // Correctness check: if the format expected by the kernel is // not "any", make sure that the one found matches the format @@ -1065,44 +1062,33 @@ void CpuGemmAssemblyDispatch::configure( switch (a->data_type()) { case DataType::F32: - create_arm_gemm<float, float, float>(_arm_gemm, a, b, c, d, act, info); + create_arm_gemm<float, float>(_arm_gemm, a, b, c, d, act, info); break; #ifdef __aarch64__ case DataType::U8: case DataType::QASYMM8: - if (b->data_type() == DataType::S8 || b->data_type() == DataType::QASYMM8_SIGNED) - { - if (d->data_type() == DataType::F32) - { - create_arm_gemm_dequant<uint8_t, int8_t, float>(_arm_gemm, a, b, c, d, act, info); - } - else - { - create_arm_gemm_quant<uint8_t, int8_t, uint8_t>(_arm_gemm, a, b, c, d, act, info); - } - } - else if (d->data_type() == DataType::S32) + if (d->data_type() == DataType::S32) { - create_arm_gemm<uint8_t, uint8_t, uint32_t>(_arm_gemm, a, b, c, d, act, info); + create_arm_gemm<uint8_t, uint32_t>(_arm_gemm, a, b, c, d, act, info); } else { - create_arm_gemm_quant<uint8_t, uint8_t, uint8_t>(_arm_gemm, a, b, c, d, act, info); + create_arm_gemm_quant<uint8_t, uint8_t>(_arm_gemm, a, b, c, d, act, info); } break; case DataType::S8: case DataType::QASYMM8_SIGNED: if (d->data_type() == DataType::S32) { - create_arm_gemm<int8_t, int8_t, int32_t>(_arm_gemm, a, b, c, d, act, info); + create_arm_gemm<int8_t, int32_t>(_arm_gemm, a, b, c, d, act, info); } else if (d->data_type() == DataType::F32) { - create_arm_gemm_dequant<int8_t, int8_t, float>(_arm_gemm, a, b, c, d, act, info); + create_arm_gemm_dequant<int8_t, float>(_arm_gemm, a, b, c, d, act, info); } else { - create_arm_gemm_quant<int8_t, int8_t, int8_t>(_arm_gemm, a, b, c, d, act, info); + create_arm_gemm_quant<int8_t, int8_t>(_arm_gemm, a, b, c, d, act, info); } break; #endif /* __aarch64__ */ @@ -1110,17 +1096,17 @@ void CpuGemmAssemblyDispatch::configure( case DataType::BFLOAT16: if (d->data_type() == DataType::BFLOAT16) { - create_arm_gemm<bfloat16, bfloat16, bfloat16>(_arm_gemm, a, b, c, d, act, info); + create_arm_gemm<bfloat16, bfloat16>(_arm_gemm, a, b, c, d, act, info); } else { - create_arm_gemm<bfloat16, bfloat16, float>(_arm_gemm, a, b, c, d, act, info); + create_arm_gemm<bfloat16, float>(_arm_gemm, a, b, c, d, act, info); } break; #endif /* defined(ARM_COMPUTE_ENABLE_BF16) */ #ifdef ENABLE_FP16_KERNELS case DataType::F16: - create_arm_gemm<float16_t, float16_t, float16_t>(_arm_gemm, a, b, c, d, act, info); + create_arm_gemm<float16_t, float16_t>(_arm_gemm, a, b, c, d, act, info); break; #endif /* ENABLE_FP16_KERNELS */ default: |