aboutsummaryrefslogtreecommitdiff
path: root/src/core/NEON/kernels/arm64/NEHGEMMAArch64FP16Kernel.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/NEON/kernels/arm64/NEHGEMMAArch64FP16Kernel.cpp')
-rw-r--r--src/core/NEON/kernels/arm64/NEHGEMMAArch64FP16Kernel.cpp21
1 files changed, 11 insertions, 10 deletions
diff --git a/src/core/NEON/kernels/arm64/NEHGEMMAArch64FP16Kernel.cpp b/src/core/NEON/kernels/arm64/NEHGEMMAArch64FP16Kernel.cpp
index 225630434b..38b9102c20 100644
--- a/src/core/NEON/kernels/arm64/NEHGEMMAArch64FP16Kernel.cpp
+++ b/src/core/NEON/kernels/arm64/NEHGEMMAArch64FP16Kernel.cpp
@@ -50,20 +50,21 @@ namespace arm_compute
namespace arm_compute
{
-void NEHGEMMAArch64FP16Kernel::internal_configure(const ITensor *input0, const ITensor *input1, ITensor *output, ITensor *workspace, float alpha, float beta, bool transform_0, bool transform_1)
+void NEHGEMMAArch64FP16Kernel::internal_configure(const ITensor *input0, const ITensor *input1, ITensor *output, ITensor *workspace, float alpha, float beta, bool is_transposed_0,
+ bool is_transposed_1)
{
ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input0, 1, DataType::F16);
ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input0, input1, output);
ARM_COMPUTE_ERROR_ON_MISMATCHING_FIXED_POINT(input0, input1, output);
- _input0 = input0;
- _input1 = input1;
- _output = output;
- _workspace = workspace;
- _alpha = alpha;
- _beta = beta;
- _transform_0 = transform_0;
- _transform_1 = transform_1;
+ _input0 = input0;
+ _input1 = input1;
+ _output = output;
+ _workspace = workspace;
+ _alpha = alpha;
+ _beta = beta;
+ _is_transposed_0 = is_transposed_0;
+ _is_transposed_1 = is_transposed_1;
// Configure kernel window
Window win = calculate_max_window(*output->info());
@@ -105,7 +106,7 @@ void NEHGEMMAArch64FP16Kernel::run(const Window &window, const ThreadInfo &info)
Iterator in0(_input0, window);
Iterator out(_output, window);
- GemmInterleaved<hgemm_24x8, hgemm_24x8::operand_type, hgemm_24x8::result_type> gemm(&info.cpu_info, M, N, K, !_transform_0, !_transform_1);
+ GemmInterleaved<hgemm_24x8, hgemm_24x8::operand_type, hgemm_24x8::result_type> gemm(&info.cpu_info, M, N, K, _is_transposed_0, _is_transposed_1);
constexpr size_t alignment = 4096;
const size_t offset = (gemm.get_working_size() + alignment - 1) * info.thread_id;
void *workspace = _workspace->buffer() + offset;