about summary refs log tree commit diff
path: root/src/cpu/kernels/CpuGemmMatrixAdditionKernel.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/cpu/kernels/CpuGemmMatrixAdditionKernel.cpp')
-rw-r--r-- src/cpu/kernels/CpuGemmMatrixAdditionKernel.cpp 39
1 files changed, 15 insertions, 24 deletions
diff --git a/src/cpu/kernels/CpuGemmMatrixAdditionKernel.cpp b/src/cpu/kernels/CpuGemmMatrixAdditionKernel.cpp
index 6399ebbef4..fb1b70b91f 100644
--- a/src/cpu/kernels/CpuGemmMatrixAdditionKernel.cpp
+++ b/src/cpu/kernels/CpuGemmMatrixAdditionKernel.cpp
@@ -26,11 +26,12 @@
#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/Types.h"
#include "arm_compute/core/Validate.h"
-#include "src/core/CPP/Validate.h"
-#include "src/core/NEON/NEFixedPoint.h"
+
#include "src/core/common/Registrars.h"
+#include "src/core/CPP/Validate.h"
#include "src/core/helpers/AutoConfiguration.h"
#include "src/core/helpers/WindowHelpers.h"
+#include "src/core/NEON/NEFixedPoint.h"
#include "src/cpu/kernels/gemm_matrix_add/list.h"
namespace arm_compute
{
@@ -40,24 +41,12 @@ namespace kernels
{
namespace
{
-static const std::vector<CpuGemmMatrixAdditionKernel::GemmMatrixAddKernel> available_kernels =
-{
- {
- "neon_fp32_gemm_matrix_add",
- [](const DataTypeISASelectorData & data)
- {
- return (data.dt == DataType::F32);
- },
- REGISTER_FP32_NEON(neon_fp32_gemm_matrix_add)
- },
- {
- "neon_fp16_gemm_matrix_add",
- [](const DataTypeISASelectorData & data)
- {
- return (data.dt == DataType::F16) && data.isa.fp16;
- },
- REGISTER_FP16_NEON(neon_fp16_gemm_matrix_add)
- },
+static const std::vector<CpuGemmMatrixAdditionKernel::GemmMatrixAddKernel> available_kernels = {
+ {"neon_fp32_gemm_matrix_add", [](const DataTypeISASelectorData &data) { return (data.dt == DataType::F32); },
+ REGISTER_FP32_NEON(neon_fp32_gemm_matrix_add)},
+ {"neon_fp16_gemm_matrix_add",
+ [](const DataTypeISASelectorData &data) { return (data.dt == DataType::F16) && data.isa.fp16; },
+ REGISTER_FP16_NEON(neon_fp16_gemm_matrix_add)},
};
} // namespace
@@ -71,7 +60,8 @@ void CpuGemmMatrixAdditionKernel::configure(const ITensorInfo *src, ITensorInfo
ARM_COMPUTE_ERROR_THROW_ON(CpuGemmMatrixAdditionKernel::validate(src, dst, beta));
_beta = beta;
- const auto uk = CpuGemmMatrixAdditionKernel::get_implementation(DataTypeISASelectorData{ src->data_type(), CPUInfo::get().get_isa() });
+ const auto uk = CpuGemmMatrixAdditionKernel::get_implementation(
+ DataTypeISASelectorData{src->data_type(), CPUInfo::get().get_isa()});
ARM_COMPUTE_ERROR_ON_NULLPTR(uk);
_func = uk->ukernel;
// Configure kernel window
@@ -87,7 +77,7 @@ Status CpuGemmMatrixAdditionKernel::validate(const ITensorInfo *src, const ITens
ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(src);
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(src, 1, DataType::F16, DataType::F32);
- if(dst->total_size() > 0)
+ if (dst->total_size() > 0)
{
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(src, dst);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(src, dst);
@@ -105,7 +95,7 @@ void CpuGemmMatrixAdditionKernel::run_op(ITensorPack &tensors, const Window &win
const ITensor *src = tensors.get_const_tensor(TensorType::ACL_SRC);
ITensor *dst = tensors.get_tensor(TensorType::ACL_DST);
- if(_beta != 0.0f)
+ if (_beta != 0.0f)
{
(*_func)(src, dst, window, _beta);
}
@@ -116,7 +106,8 @@ const char *CpuGemmMatrixAdditionKernel::name() const
return "CpuGemmMatrixAdditionKernel";
}
-const std::vector<CpuGemmMatrixAdditionKernel::GemmMatrixAddKernel> &CpuGemmMatrixAdditionKernel::get_available_kernels()
+const std::vector<CpuGemmMatrixAdditionKernel::GemmMatrixAddKernel> &
+CpuGemmMatrixAdditionKernel::get_available_kernels()
{
return available_kernels;
}