aboutsummaryrefslogtreecommitdiff
path: root/src/core/NEON/kernels/assembly/Helpers.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/NEON/kernels/assembly/Helpers.cpp')
-rw-r--r-- src/core/NEON/kernels/assembly/Helpers.cpp | 100
1 file changed, 28 insertions, 72 deletions
diff --git a/src/core/NEON/kernels/assembly/Helpers.cpp b/src/core/NEON/kernels/assembly/Helpers.cpp
index 09ac08c0a4..3d8d66d7fc 100644
--- a/src/core/NEON/kernels/assembly/Helpers.cpp
+++ b/src/core/NEON/kernels/assembly/Helpers.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018 ARM Limited.
+ * Copyright (c) 2018-2019 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -24,91 +24,47 @@
#include "arm_compute/core/NEON/kernels/assembly/Helpers.h"
-#include "NEGEMMInterleavedStrategies.h"
+#include "arm_compute/core/NEON/kernels/assembly/arm_gemm.hpp"
namespace arm_compute
{
-namespace
-{
-template <typename InputType, bool use_dot = false>
-BlockSizes calculate_block_sizes_template(const CPUInfo &ci, unsigned int M, unsigned int N, unsigned int K)
-{
- using strategy = typename Kernel<InputType, use_dot>::strategy;
- return calculate_block_sizes<strategy>(ci, M, N, K);
-}
-} // namespace
-
-const char *get_strategy_name(DataType input_type, bool use_dot)
+arm_gemm::KernelDescription get_gemm_info(DataType input_type,
+ const CPUInfo &ci,
+ const unsigned int num_threads,
+ const INEGEMMWrapperKernel::Params &p,
+ float alpha,
+ float beta,
+ bool pretranspose_hint)
{
switch(input_type)
{
- case DataType::F32:
- return Kernel<float>::name;
#ifdef __aarch64__
- case DataType::U8:
case DataType::QASYMM8:
- if(use_dot)
- {
- return Kernel<uint8_t, true>::name;
- }
- else
- {
- return Kernel<uint8_t, false>::name;
- }
- case DataType::S8:
- if(use_dot)
- {
- return Kernel<int8_t, true>::name;
- }
- else
- {
- return Kernel<int8_t, false>::name;
- }
-#endif /* __aarch64__ */
-#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
- case DataType::F16:
- return Kernel<__fp16>::name;
-#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
- default:
- ARM_COMPUTE_ERROR("DataType not supported");
- break;
- }
-}
-
-BlockSizes calculate_block_sizes_from_data_type(const CPUInfo &ci, unsigned int M, unsigned int N, unsigned int K, DataType input_type, bool use_dot)
-{
- switch(input_type)
- {
- case DataType::F32:
- return calculate_block_sizes_template<float>(ci, M, N, K);
-#ifdef __aarch64__
case DataType::U8:
- case DataType::QASYMM8:
- if(use_dot)
- {
- return calculate_block_sizes_template<uint8_t, true>(ci, M, N, K);
- }
- else
- {
- return calculate_block_sizes_template<uint8_t, false>(ci, M, N, K);
- }
+ {
+ arm_gemm::GemmArgs<uint32_t> args(&ci, p.M, p.N, p.K, p.batches, p.multis, false, false, alpha, beta, num_threads, pretranspose_hint);
+ return arm_gemm::get_gemm_method<uint8_t, uint32_t>(args);
+ }
case DataType::S8:
- if(use_dot)
- {
- return calculate_block_sizes_template<int8_t, true>(ci, M, N, K);
- }
- else
- {
- return calculate_block_sizes_template<int8_t, false>(ci, M, N, K);
- }
-#endif /* __aarch64__ */
+ {
+ arm_gemm::GemmArgs<int32_t> args(&ci, p.M, p.N, p.K, p.batches, p.multis, false, false, alpha, beta, num_threads, pretranspose_hint);
+ return arm_gemm::get_gemm_method<int8_t, int32_t>(args);
+ }
+#endif // __aarch64__
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
case DataType::F16:
- return calculate_block_sizes_template<__fp16>(ci, M, N, K);
+ {
+ arm_gemm::GemmArgs<__fp16> args(&ci, p.M, p.N, p.K, p.batches, p.multis, false, false, alpha, beta, num_threads, pretranspose_hint);
+ return arm_gemm::get_gemm_method<__fp16, __fp16>(args);
+ }
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
+ case DataType::F32:
+ {
+ arm_gemm::GemmArgs<float> args(&ci, p.M, p.N, p.K, p.batches, p.multis, false, false, alpha, beta, num_threads, pretranspose_hint);
+ return arm_gemm::get_gemm_method<float, float>(args);
+ }
default:
- ARM_COMPUTE_ERROR("DataType not supported");
- break;
+ return arm_gemm::KernelDescription();
}
}
} // namespace arm_compute