From ea9e0dc18c408fecb6dc482b774bd900dd321610 Mon Sep 17 00:00:00 2001 From: Georgios Pinitas Date: Tue, 28 Aug 2018 16:24:56 +0100 Subject: COMPMID-1469: Add validate in NEGEMMMatrixAdditionKernel Change-Id: I228e2503eb40c12869fbd7e834ac1309aa613480 Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/145878 Reviewed-by: Giorgio Arena Tested-by: Jenkins --- .../NEON/kernels/NEGEMMMatrixAdditionKernel.cpp | 43 ++++++++++++++++------ 1 file changed, 31 insertions(+), 12 deletions(-) (limited to 'src/core/NEON/kernels') diff --git a/src/core/NEON/kernels/NEGEMMMatrixAdditionKernel.cpp b/src/core/NEON/kernels/NEGEMMMatrixAdditionKernel.cpp index cd6aa553db..757dbbc399 100644 --- a/src/core/NEON/kernels/NEGEMMMatrixAdditionKernel.cpp +++ b/src/core/NEON/kernels/NEGEMMMatrixAdditionKernel.cpp @@ -32,15 +32,27 @@ #include -using namespace arm_compute; - namespace arm_compute { -class Coordinates; -} // namespace arm_compute - namespace { +Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, float beta) +{ + ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, output); + ARM_COMPUTE_UNUSED(beta); + + ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(input); + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F16, DataType::F32); + + if(output->total_size() > 0) + { + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input, output); + } + + return Status{}; +} + void matrix_addition_f32(const ITensor *input, ITensor *output, const Window &window, float beta) { const float32x4_t beta_f32 = vdupq_n_f32(beta); @@ -101,12 +113,10 @@ NEGEMMMatrixAdditionKernel::NEGEMMMatrixAdditionKernel() void NEGEMMMatrixAdditionKernel::configure(const ITensor *input, ITensor *output, float beta) { - ARM_COMPUTE_ERROR_ON_CPU_F16_UNSUPPORTED(input); - ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F16, DataType::F32); - ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::F16, DataType::F32); - ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, output); - ARM_COMPUTE_ERROR_ON(input->info()->dimension(0) != output->info()->dimension(0)); - ARM_COMPUTE_ERROR_ON(input->info()->dimension(1) != output->info()->dimension(1)); + ARM_COMPUTE_ERROR_ON_NULLPTR(input, output); + + // Perform validation step + ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), output->info(), beta)); switch(input->info()->data_type()) { @@ -123,13 +133,21 @@ void NEGEMMMatrixAdditionKernel::configure(const ITensor *input, ITensor *output break; } + // Configure kernel window constexpr unsigned int num_elems_processed_per_iteration = 16; - INESimpleKernel::configure(input, output, num_elems_processed_per_iteration); _beta = beta; } +Status NEGEMMMatrixAdditionKernel::validate(const ITensorInfo *input, const ITensorInfo *output, float beta) +{ + constexpr unsigned int num_elems_processed_per_iteration = 16; + ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, output, beta)); + ARM_COMPUTE_RETURN_ON_ERROR(INESimpleKernel::validate(input->clone().get(), output->clone().get(), num_elems_processed_per_iteration)); + return Status{}; +} + void NEGEMMMatrixAdditionKernel::run(const Window &window, const ThreadInfo &info) { ARM_COMPUTE_UNUSED(info); @@ -141,3 +159,4 @@ void NEGEMMMatrixAdditionKernel::run(const Window &window, const ThreadInfo &inf (*_func)(_input, _output, window, _beta); } } +} // namespace arm_compute -- cgit v1.2.1