aboutsummaryrefslogtreecommitdiff
path: root/src/core/NEON/kernels/arm64/NEGEMMLowpAArch64Kernel.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/NEON/kernels/arm64/NEGEMMLowpAArch64Kernel.cpp')
-rw-r--r--src/core/NEON/kernels/arm64/NEGEMMLowpAArch64Kernel.cpp29
1 files changed, 15 insertions, 14 deletions
diff --git a/src/core/NEON/kernels/arm64/NEGEMMLowpAArch64Kernel.cpp b/src/core/NEON/kernels/arm64/NEGEMMLowpAArch64Kernel.cpp
index db37201687..38f82f0407 100644
--- a/src/core/NEON/kernels/arm64/NEGEMMLowpAArch64Kernel.cpp
+++ b/src/core/NEON/kernels/arm64/NEGEMMLowpAArch64Kernel.cpp
@@ -56,7 +56,7 @@ NEGEMMLowpAArch64Kernel::NEGEMMLowpAArch64Kernel()
{
}
-void gemm_interleaved_s8(const ITensor *input0, const ITensor *input1, ITensor *output, ITensor *workspace, float alpha, float beta, bool transform_0, bool transform_1, const Window &window,
+void gemm_interleaved_s8(const ITensor *input0, const ITensor *input1, ITensor *output, ITensor *workspace, float alpha, float beta, bool is_transposed_0, bool is_transposed_1, const Window &window,
const ThreadInfo &info)
{
const int lda = input0->info()->strides_in_bytes().y();
@@ -77,7 +77,7 @@ void gemm_interleaved_s8(const ITensor *input0, const ITensor *input1, ITensor *
Iterator in0(input0, window);
Iterator out(output, window);
- GemmInterleaved<gemm_s8_4x4, int8_t, int32_t> gemm(&info.cpu_info, M, N, K, !transform_1, !transform_1);
+ GemmInterleaved<gemm_s8_4x4, int8_t, int32_t> gemm(&info.cpu_info, M, N, K, is_transposed_0, is_transposed_1);
constexpr size_t alignment = 4096;
const size_t offset = (gemm.get_working_size() + alignment - 1) * info.thread_id;
@@ -99,7 +99,7 @@ void gemm_interleaved_s8(const ITensor *input0, const ITensor *input1, ITensor *
in0, out);
}
-void gemm_interleaved_u8(const ITensor *input0, const ITensor *input1, ITensor *output, ITensor *workspace, float alpha, float beta, bool transform_0, bool transform_1, const Window &window,
+void gemm_interleaved_u8(const ITensor *input0, const ITensor *input1, ITensor *output, ITensor *workspace, float alpha, float beta, bool is_transposed_0, bool is_transposed_1, const Window &window,
const ThreadInfo &info)
{
const int lda = input0->info()->strides_in_bytes().y();
@@ -120,7 +120,7 @@ void gemm_interleaved_u8(const ITensor *input0, const ITensor *input1, ITensor *
Iterator in0(input0, window);
Iterator out(output, window);
- GemmInterleaved<gemm_u8_4x4, uint8_t, uint32_t> gemm(&info.cpu_info, M, N, K, !transform_1, !transform_1);
+ GemmInterleaved<gemm_u8_4x4, uint8_t, uint32_t> gemm(&info.cpu_info, M, N, K, is_transposed_0, is_transposed_1);
constexpr size_t alignment = 4096;
const size_t offset = (gemm.get_working_size() + alignment - 1) * info.thread_id;
@@ -142,20 +142,21 @@ void gemm_interleaved_u8(const ITensor *input0, const ITensor *input1, ITensor *
in0, out);
}
-void NEGEMMLowpAArch64Kernel::internal_configure(const ITensor *input0, const ITensor *input1, ITensor *output, ITensor *workspace, float alpha, float beta, bool transform_0, bool transform_1)
+void NEGEMMLowpAArch64Kernel::internal_configure(const ITensor *input0, const ITensor *input1, ITensor *output, ITensor *workspace, float alpha, float beta, bool is_transposed_0,
+ bool is_transposed_1)
{
ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input0, 1, DataType::S8, DataType::U8);
ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::S32, DataType::U32);
ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input0, input1);
- _input0 = input0;
- _input1 = input1;
- _output = output;
- _workspace = workspace;
- _alpha = alpha;
- _beta = beta;
- _transform_0 = transform_0;
- _transform_1 = transform_1;
+ _input0 = input0;
+ _input1 = input1;
+ _output = output;
+ _workspace = workspace;
+ _alpha = alpha;
+ _beta = beta;
+ _is_transposed_0 = is_transposed_0;
+ _is_transposed_1 = is_transposed_1;
switch(input0->info()->data_type())
{
@@ -192,7 +193,7 @@ void NEGEMMLowpAArch64Kernel::run(const Window &window, const ThreadInfo &info)
ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
ARM_COMPUTE_ERROR_ON(_func == nullptr);
- (*_func)(_input0, _input1, _output, _workspace, _alpha, _beta, _transform_0, _transform_1, window, info);
+ (*_func)(_input0, _input1, _output, _workspace, _alpha, _beta, _is_transposed_0, _is_transposed_1, window, info);
}
} // namespace arm_compute
#endif /* ARM_COMPUTE_AARCH64_V8A */