diff options
Diffstat (limited to 'src/core/NEON/kernels/arm64/NEGEMMLowpAArch64V8P4Kernel.cpp')
-rw-r--r-- | src/core/NEON/kernels/arm64/NEGEMMLowpAArch64V8P4Kernel.cpp | 21 |
1 files changed, 11 insertions, 10 deletions
diff --git a/src/core/NEON/kernels/arm64/NEGEMMLowpAArch64V8P4Kernel.cpp b/src/core/NEON/kernels/arm64/NEGEMMLowpAArch64V8P4Kernel.cpp index e996e571ab..099934a49d 100644 --- a/src/core/NEON/kernels/arm64/NEGEMMLowpAArch64V8P4Kernel.cpp +++ b/src/core/NEON/kernels/arm64/NEGEMMLowpAArch64V8P4Kernel.cpp @@ -84,20 +84,21 @@ std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input0, ITe namespace arm_compute { -void NEGEMMLowpAArch64V8P4Kernel::internal_configure(const ITensor *input0, const ITensor *input1, ITensor *output, ITensor *workspace, float alpha, float beta, bool transform_0, bool transform_1) +void NEGEMMLowpAArch64V8P4Kernel::internal_configure(const ITensor *input0, const ITensor *input1, ITensor *output, ITensor *workspace, float alpha, float beta, bool is_transposed_0, + bool is_transposed_1) { // Perform validate step ARM_COMPUTE_ERROR_ON_NULLPTR(input0, input1, output); ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input0->info(), input1->info(), output->info())); - _input0 = input0; - _input1 = input1; - _output = output; - _workspace = workspace; - _alpha = alpha; - _beta = beta; - _transform_0 = transform_0; - _transform_1 = transform_1; + _input0 = input0; + _input1 = input1; + _output = output; + _workspace = workspace; + _alpha = alpha; + _beta = beta; + _is_transposed_0 = is_transposed_0; + _is_transposed_1 = is_transposed_1; // Configure kernel window auto win_config = validate_and_configure_window(input0->info(), input1->info(), output->info()); @@ -136,7 +137,7 @@ void NEGEMMLowpAArch64V8P4Kernel::run(const Window &window, const ThreadInfo &in Iterator in0(_input0, window); Iterator out(_output, window); - GemmInterleaved<gemm_u8_12x8, gemm_u8_12x8::operand_type, gemm_u8_12x8::result_type> gemm(&info.cpu_info, M, N, K, !_transform_1, !_transform_1); + GemmInterleaved<gemm_u8_12x8, gemm_u8_12x8::operand_type, gemm_u8_12x8::result_type> gemm(&info.cpu_info, M, N, K, _is_transposed_0, _is_transposed_1); constexpr size_t alignment = 4096; const size_t offset = (gemm.get_working_size() + alignment - 1) * info.thread_id; |