diff options
Diffstat (limited to 'src/core/NEON/kernels/NEGEMMMatrixAdditionKernel.cpp')
-rw-r--r-- | src/core/NEON/kernels/NEGEMMMatrixAdditionKernel.cpp | 43 |
1 files changed, 31 insertions, 12 deletions
diff --git a/src/core/NEON/kernels/NEGEMMMatrixAdditionKernel.cpp b/src/core/NEON/kernels/NEGEMMMatrixAdditionKernel.cpp index cd6aa553db..757dbbc399 100644 --- a/src/core/NEON/kernels/NEGEMMMatrixAdditionKernel.cpp +++ b/src/core/NEON/kernels/NEGEMMMatrixAdditionKernel.cpp @@ -32,15 +32,27 @@ #include <arm_neon.h> -using namespace arm_compute; - namespace arm_compute { -class Coordinates; -} // namespace arm_compute - namespace { +Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, float beta) +{ + ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, output); + ARM_COMPUTE_UNUSED(beta); + + ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(input); + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F16, DataType::F32); + + if(output->total_size() > 0) + { + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input, output); + } + + return Status{}; +} + void matrix_addition_f32(const ITensor *input, ITensor *output, const Window &window, float beta) { const float32x4_t beta_f32 = vdupq_n_f32(beta); @@ -101,12 +113,10 @@ NEGEMMMatrixAdditionKernel::NEGEMMMatrixAdditionKernel() void NEGEMMMatrixAdditionKernel::configure(const ITensor *input, ITensor *output, float beta) { - ARM_COMPUTE_ERROR_ON_CPU_F16_UNSUPPORTED(input); - ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F16, DataType::F32); - ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::F16, DataType::F32); - ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, output); - ARM_COMPUTE_ERROR_ON(input->info()->dimension(0) != output->info()->dimension(0)); - ARM_COMPUTE_ERROR_ON(input->info()->dimension(1) != output->info()->dimension(1)); + ARM_COMPUTE_ERROR_ON_NULLPTR(input, output); + + // Perform validation step + ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), output->info(), beta)); switch(input->info()->data_type()) { @@ -123,13 +133,21 @@ void NEGEMMMatrixAdditionKernel::configure(const ITensor *input, ITensor *output break; } + // Configure kernel window constexpr unsigned int num_elems_processed_per_iteration = 16; - INESimpleKernel::configure(input, output, num_elems_processed_per_iteration); _beta = beta; } +Status NEGEMMMatrixAdditionKernel::validate(const ITensorInfo *input, const ITensorInfo *output, float beta) +{ + constexpr unsigned int num_elems_processed_per_iteration = 16; + ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, output, beta)); + ARM_COMPUTE_RETURN_ON_ERROR(INESimpleKernel::validate(input->clone().get(), output->clone().get(), num_elems_processed_per_iteration)); + return Status{}; +} + void NEGEMMMatrixAdditionKernel::run(const Window &window, const ThreadInfo &info) { ARM_COMPUTE_UNUSED(info); @@ -141,3 +159,4 @@ void NEGEMMMatrixAdditionKernel::run(const Window &window, const ThreadInfo &inf (*_func)(_input, _output, window, _beta); } } +} // namespace arm_compute |