aboutsummaryrefslogtreecommitdiff
path: root/src/core/NEON/kernels/assembly/Helpers.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/NEON/kernels/assembly/Helpers.cpp')
-rw-r--r--src/core/NEON/kernels/assembly/Helpers.cpp13
1 files changed, 5 insertions, 8 deletions
diff --git a/src/core/NEON/kernels/assembly/Helpers.cpp b/src/core/NEON/kernels/assembly/Helpers.cpp
index 3d8d66d7fc..93ea6c8d5e 100644
--- a/src/core/NEON/kernels/assembly/Helpers.cpp
+++ b/src/core/NEON/kernels/assembly/Helpers.cpp
@@ -24,16 +24,13 @@
#include "arm_compute/core/NEON/kernels/assembly/Helpers.h"
-#include "arm_compute/core/NEON/kernels/assembly/arm_gemm.hpp"
-
namespace arm_compute
{
arm_gemm::KernelDescription get_gemm_info(DataType input_type,
const CPUInfo &ci,
const unsigned int num_threads,
const INEGEMMWrapperKernel::Params &p,
- float alpha,
- float beta,
+ arm_gemm::Activation activation,
bool pretranspose_hint)
{
switch(input_type)
@@ -42,25 +39,25 @@ arm_gemm::KernelDescription get_gemm_info(DataType in
case DataType::QASYMM8:
case DataType::U8:
{
- arm_gemm::GemmArgs<uint32_t> args(&ci, p.M, p.N, p.K, p.batches, p.multis, false, false, alpha, beta, num_threads, pretranspose_hint);
+ arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.batches, p.multis, false, false, activation, num_threads, pretranspose_hint);
return arm_gemm::get_gemm_method<uint8_t, uint32_t>(args);
}
case DataType::S8:
{
- arm_gemm::GemmArgs<int32_t> args(&ci, p.M, p.N, p.K, p.batches, p.multis, false, false, alpha, beta, num_threads, pretranspose_hint);
+ arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.batches, p.multis, false, false, activation, num_threads, pretranspose_hint);
return arm_gemm::get_gemm_method<int8_t, int32_t>(args);
}
#endif // __aarch64__
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
case DataType::F16:
{
- arm_gemm::GemmArgs<__fp16> args(&ci, p.M, p.N, p.K, p.batches, p.multis, false, false, alpha, beta, num_threads, pretranspose_hint);
+ arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.batches, p.multis, false, false, activation, num_threads, pretranspose_hint);
return arm_gemm::get_gemm_method<__fp16, __fp16>(args);
}
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
case DataType::F32:
{
- arm_gemm::GemmArgs<float> args(&ci, p.M, p.N, p.K, p.batches, p.multis, false, false, alpha, beta, num_threads, pretranspose_hint);
+ arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.batches, p.multis, false, false, activation, num_threads, pretranspose_hint);
return arm_gemm::get_gemm_method<float, float>(args);
}
default: