diff options
Diffstat (limited to 'src/runtime')
-rw-r--r-- | src/runtime/NEON/functions/NEGEMMLowpAssemblyMatrixMultiplyCore.cpp | 24 |
1 files changed, 21 insertions, 3 deletions
diff --git a/src/runtime/NEON/functions/NEGEMMLowpAssemblyMatrixMultiplyCore.cpp b/src/runtime/NEON/functions/NEGEMMLowpAssemblyMatrixMultiplyCore.cpp index 0423777217..6e03ffa1bc 100644 --- a/src/runtime/NEON/functions/NEGEMMLowpAssemblyMatrixMultiplyCore.cpp +++ b/src/runtime/NEON/functions/NEGEMMLowpAssemblyMatrixMultiplyCore.cpp @@ -45,6 +45,7 @@ namespace arm_compute #include "arm_compute/core/NEON/kernels/assembly/kernels/a64_gemm_s16_12x8.hpp" #include "arm_compute/core/NEON/kernels/assembly/kernels/a64_gemm_s8_12x8.hpp" #include "arm_compute/core/NEON/kernels/assembly/kernels/a64_gemm_s8_4x4.hpp" +#include "arm_compute/core/NEON/kernels/assembly/kernels/a64_gemm_u16_12x8.hpp" #include "arm_compute/core/NEON/kernels/assembly/kernels/a64_gemm_u8_4x4.hpp" } // namespace arm_compute @@ -94,9 +95,26 @@ void NEGEMMLowpAssemblyMatrixMultiplyCore::configure(const ITensor *a, const ITe #elif defined(ARM_COMPUTE_AARCH64_V8A) if(ci.CPU == CPUTarget::A53) { - // Configure matrix multiply kernel - GemmInterleaved<gemm_s16_12x8, int8_t, int32_t> gemm(&ci, M, N, K, false, false); - _workspace.allocator()->init(TensorInfo(TensorShape{ (gemm.get_working_size() + workspace_alignment - 1) * NEScheduler::get().num_threads() }, 1, DataType::U8)); + switch(a->info()->data_type()) + { + case DataType::S8: + { + // Configure matrix multiply kernel + GemmInterleaved<gemm_s16_12x8, int8_t, int32_t> gemm(&ci, M, N, K, false, false); + _workspace.allocator()->init(TensorInfo(TensorShape{ (gemm.get_working_size() + workspace_alignment - 1) * NEScheduler::get().num_threads() }, 1, DataType::U8)); + } + break; + case DataType::U8: + { + // Configure matrix multiply kernel + GemmInterleaved<gemm_u16_12x8, uint8_t, uint32_t> gemm(&ci, M, N, K, false, false); + _workspace.allocator()->init(TensorInfo(TensorShape{ (gemm.get_working_size() + workspace_alignment - 1) * NEScheduler::get().num_threads() }, 1, DataType::U8)); + } + break; + default: + ARM_COMPUTE_ERROR("Datatype not supported"); + } + _memory_group.manage(&_workspace); // Configure matrix multiplication kernel auto k = arm_compute::support::cpp14::make_unique<NEGEMMLowpAArch64A53Kernel>(); |