From 08c5a06e2b49df0d7912deedd6d26d2c603cfe58 Mon Sep 17 00:00:00 2001 From: Georgios Pinitas Date: Thu, 14 Dec 2017 17:53:39 +0000 Subject: COMPMID-750: Fix assembly kernel interfaces Assembly kernel interfaces were wrongly translating the layout of the input matrices. Boolean flags transform0 and transform1 do not match the actual interface of the gemm assembly code which expects transpose0 and transposed1. Change-Id: Ia4df65a533834647fa63e78e8c897924793949df Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/113410 Tested-by: BSG Visual Compute Jenkins server to access repositories on http://mpd-gerrit.cambridge.arm.com Reviewed-by: Pablo Tello --- .../kernels/arm64/NEGEMMLowpAArch64A53Kernel.cpp | 31 ++++++++++++---------- 1 file changed, 17 insertions(+), 14 deletions(-) (limited to 'src/core/NEON/kernels/arm64/NEGEMMLowpAArch64A53Kernel.cpp') diff --git a/src/core/NEON/kernels/arm64/NEGEMMLowpAArch64A53Kernel.cpp b/src/core/NEON/kernels/arm64/NEGEMMLowpAArch64A53Kernel.cpp index e020cd9118..80606dcc07 100644 --- a/src/core/NEON/kernels/arm64/NEGEMMLowpAArch64A53Kernel.cpp +++ b/src/core/NEON/kernels/arm64/NEGEMMLowpAArch64A53Kernel.cpp @@ -56,7 +56,8 @@ NEGEMMLowpAArch64A53Kernel::NEGEMMLowpAArch64A53Kernel() { } -void gemm_interleaved_s16_12x8(const ITensor *input0, const ITensor *input1, ITensor *output, ITensor *workspace, float alpha, float beta, bool transform_0, bool transform_1, const Window &window, +void gemm_interleaved_s16_12x8(const ITensor *input0, const ITensor *input1, ITensor *output, ITensor *workspace, float alpha, float beta, bool is_transposed_0, bool is_transposed_1, + const Window &window, const ThreadInfo &info) { const int lda = input0->info()->strides_in_bytes().y(); @@ -77,7 +78,7 @@ void gemm_interleaved_s16_12x8(const ITensor *input0, const ITensor *input1, ITe Iterator in0(input0, window); Iterator out(output, window); - GemmInterleaved gemm(&info.cpu_info, M, N, K, !transform_1, !transform_1); + GemmInterleaved gemm(&info.cpu_info, M, N, K, is_transposed_0, is_transposed_1); constexpr size_t alignment = 4096; const size_t offset = (gemm.get_working_size() + alignment - 1) * info.thread_id; @@ -99,7 +100,8 @@ void gemm_interleaved_s16_12x8(const ITensor *input0, const ITensor *input1, ITe in0, out); } -void gemm_interleaved_u16_12x8(const ITensor *input0, const ITensor *input1, ITensor *output, ITensor *workspace, float alpha, float beta, bool transform_0, bool transform_1, const Window &window, +void gemm_interleaved_u16_12x8(const ITensor *input0, const ITensor *input1, ITensor *output, ITensor *workspace, float alpha, float beta, bool is_transposed_0, bool is_transposed_1, + const Window &window, const ThreadInfo &info) { const int lda = input0->info()->strides_in_bytes().y(); @@ -120,7 +122,7 @@ void gemm_interleaved_u16_12x8(const ITensor *input0, const ITensor *input1, ITe Iterator in0(input0, window); Iterator out(output, window); - GemmInterleaved gemm(&info.cpu_info, M, N, K, !transform_1, !transform_1); + GemmInterleaved gemm(&info.cpu_info, M, N, K, is_transposed_0, is_transposed_1); constexpr size_t alignment = 4096; const size_t offset = (gemm.get_working_size() + alignment - 1) * info.thread_id; @@ -142,20 +144,21 @@ void gemm_interleaved_u16_12x8(const ITensor *input0, const ITensor *input1, ITe in0, out); } -void NEGEMMLowpAArch64A53Kernel::internal_configure(const ITensor *input0, const ITensor *input1, ITensor *output, ITensor *workspace, float alpha, float beta, bool transform_0, bool transform_1) +void NEGEMMLowpAArch64A53Kernel::internal_configure(const ITensor *input0, const ITensor *input1, ITensor *output, ITensor *workspace, float alpha, float beta, bool is_transposed_0, + bool is_transposed_1) { ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input0, 1, DataType::S8, DataType::U8); ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::S32); ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input0, input1); - _input0 = input0; - _input1 = input1; - _output = output; - _workspace = workspace; - _alpha = alpha; - _beta = beta; - _transform_0 = transform_0; - _transform_1 = transform_1; + _input0 = input0; + _input1 = input1; + _output = output; + _workspace = workspace; + _alpha = alpha; + _beta = beta; + _is_transposed_0 = is_transposed_0; + _is_transposed_1 = is_transposed_1; switch(input0->info()->data_type()) { @@ -192,7 +195,7 @@ void NEGEMMLowpAArch64A53Kernel::run(const Window &window, const ThreadInfo &inf ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window); ARM_COMPUTE_ERROR_ON(_func == nullptr); - (*_func)(_input0, _input1, _output, _workspace, _alpha, _beta, _transform_0, _transform_1, window, info); + (*_func)(_input0, _input1, _output, _workspace, _alpha, _beta, _is_transposed_0, _is_transposed_1, window, info); } } // namespace arm_compute #endif /* ARM_COMPUTE_AARCH64_V8A */ -- cgit v1.2.1