diff options
Diffstat (limited to 'src/core/NEON/kernels/assembly/Helpers.cpp')
-rw-r--r-- | src/core/NEON/kernels/assembly/Helpers.cpp | 100 |
1 files changed, 28 insertions, 72 deletions
diff --git a/src/core/NEON/kernels/assembly/Helpers.cpp b/src/core/NEON/kernels/assembly/Helpers.cpp index 09ac08c0a4..3d8d66d7fc 100644 --- a/src/core/NEON/kernels/assembly/Helpers.cpp +++ b/src/core/NEON/kernels/assembly/Helpers.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018 ARM Limited. + * Copyright (c) 2018-2019 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -24,91 +24,47 @@ #include "arm_compute/core/NEON/kernels/assembly/Helpers.h" -#include "NEGEMMInterleavedStrategies.h" +#include "arm_compute/core/NEON/kernels/assembly/arm_gemm.hpp" namespace arm_compute { -namespace -{ -template <typename InputType, bool use_dot = false> -BlockSizes calculate_block_sizes_template(const CPUInfo &ci, unsigned int M, unsigned int N, unsigned int K) -{ - using strategy = typename Kernel<InputType, use_dot>::strategy; - return calculate_block_sizes<strategy>(ci, M, N, K); -} -} // namespace - -const char *get_strategy_name(DataType input_type, bool use_dot) +arm_gemm::KernelDescription get_gemm_info(DataType input_type, + const CPUInfo &ci, + const unsigned int num_threads, + const INEGEMMWrapperKernel::Params &p, + float alpha, + float beta, + bool pretranspose_hint) { switch(input_type) { - case DataType::F32: - return Kernel<float>::name; #ifdef __aarch64__ - case DataType::U8: case DataType::QASYMM8: - if(use_dot) - { - return Kernel<uint8_t, true>::name; - } - else - { - return Kernel<uint8_t, false>::name; - } - case DataType::S8: - if(use_dot) - { - return Kernel<int8_t, true>::name; - } - else - { - return Kernel<int8_t, false>::name; - } -#endif /* __aarch64__ */ -#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC - case DataType::F16: - return Kernel<__fp16>::name; -#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ - default: - ARM_COMPUTE_ERROR("DataType not supported"); - break; - } -} - -BlockSizes calculate_block_sizes_from_data_type(const CPUInfo &ci, unsigned int M, unsigned int N, unsigned int K, DataType input_type, bool use_dot) -{ - switch(input_type) - { - case DataType::F32: - return calculate_block_sizes_template<float>(ci, M, N, K); -#ifdef __aarch64__ case DataType::U8: - case DataType::QASYMM8: - if(use_dot) - { - return calculate_block_sizes_template<uint8_t, true>(ci, M, N, K); - } - else - { - return calculate_block_sizes_template<uint8_t, false>(ci, M, N, K); - } + { + arm_gemm::GemmArgs<uint32_t> args(&ci, p.M, p.N, p.K, p.batches, p.multis, false, false, alpha, beta, num_threads, pretranspose_hint); + return arm_gemm::get_gemm_method<uint8_t, uint32_t>(args); + } case DataType::S8: - if(use_dot) - { - return calculate_block_sizes_template<int8_t, true>(ci, M, N, K); - } - else - { - return calculate_block_sizes_template<int8_t, false>(ci, M, N, K); - } -#endif /* __aarch64__ */ + { + arm_gemm::GemmArgs<int32_t> args(&ci, p.M, p.N, p.K, p.batches, p.multis, false, false, alpha, beta, num_threads, pretranspose_hint); + return arm_gemm::get_gemm_method<int8_t, int32_t>(args); + } +#endif // __aarch64__ #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC case DataType::F16: - return calculate_block_sizes_template<__fp16>(ci, M, N, K); + { + arm_gemm::GemmArgs<__fp16> args(&ci, p.M, p.N, p.K, p.batches, p.multis, false, false, alpha, beta, num_threads, pretranspose_hint); + return arm_gemm::get_gemm_method<__fp16, __fp16>(args); + } #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ + case DataType::F32: + { + arm_gemm::GemmArgs<float> args(&ci, p.M, p.N, p.K, p.batches, p.multis, false, false, alpha, beta, num_threads, pretranspose_hint); + return arm_gemm::get_gemm_method<float, float>(args); + } default: - ARM_COMPUTE_ERROR("DataType not supported"); - break; + return arm_gemm::KernelDescription(); } } } // namespace arm_compute |