diff options
Diffstat (limited to 'src/core/NEON/kernels/arm64/NEHGEMMAArch64FP16Kernel.cpp')
-rw-r--r-- | src/core/NEON/kernels/arm64/NEHGEMMAArch64FP16Kernel.cpp | 21 |
1 files changed, 11 insertions, 10 deletions
diff --git a/src/core/NEON/kernels/arm64/NEHGEMMAArch64FP16Kernel.cpp b/src/core/NEON/kernels/arm64/NEHGEMMAArch64FP16Kernel.cpp index 225630434b..38b9102c20 100644 --- a/src/core/NEON/kernels/arm64/NEHGEMMAArch64FP16Kernel.cpp +++ b/src/core/NEON/kernels/arm64/NEHGEMMAArch64FP16Kernel.cpp @@ -50,20 +50,21 @@ namespace arm_compute namespace arm_compute { -void NEHGEMMAArch64FP16Kernel::internal_configure(const ITensor *input0, const ITensor *input1, ITensor *output, ITensor *workspace, float alpha, float beta, bool transform_0, bool transform_1) +void NEHGEMMAArch64FP16Kernel::internal_configure(const ITensor *input0, const ITensor *input1, ITensor *output, ITensor *workspace, float alpha, float beta, bool is_transposed_0, + bool is_transposed_1) { ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input0, 1, DataType::F16); ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input0, input1, output); ARM_COMPUTE_ERROR_ON_MISMATCHING_FIXED_POINT(input0, input1, output); - _input0 = input0; - _input1 = input1; - _output = output; - _workspace = workspace; - _alpha = alpha; - _beta = beta; - _transform_0 = transform_0; - _transform_1 = transform_1; + _input0 = input0; + _input1 = input1; + _output = output; + _workspace = workspace; + _alpha = alpha; + _beta = beta; + _is_transposed_0 = is_transposed_0; + _is_transposed_1 = is_transposed_1; // Configure kernel window Window win = calculate_max_window(*output->info()); @@ -105,7 +106,7 @@ void NEHGEMMAArch64FP16Kernel::run(const Window &window, const ThreadInfo &info) Iterator in0(_input0, window); Iterator out(_output, window); - GemmInterleaved<hgemm_24x8, hgemm_24x8::operand_type, hgemm_24x8::result_type> gemm(&info.cpu_info, M, N, K, !_transform_0, !_transform_1); + GemmInterleaved<hgemm_24x8, hgemm_24x8::operand_type, hgemm_24x8::result_type> gemm(&info.cpu_info, M, N, K, _is_transposed_0, _is_transposed_1); constexpr size_t alignment = 4096; const size_t offset = (gemm.get_working_size() + alignment - 1) * info.thread_id; void *workspace = _workspace->buffer() + offset; |