aboutsummaryrefslogtreecommitdiff
path: root/src/core/NEON/kernels/NEGEMMMatrixAdditionKernel.cpp
diff options
context:
space:
mode:
authorGeorgios Pinitas <georgios.pinitas@arm.com>2018-08-28 16:24:56 +0100
committerAnthony Barbier <anthony.barbier@arm.com>2018-11-02 16:54:54 +0000
commitea9e0dc18c408fecb6dc482b774bd900dd321610 (patch)
treeb6e67a6559b53b5d4d97f77251d83ac73a6e55a5 /src/core/NEON/kernels/NEGEMMMatrixAdditionKernel.cpp
parent84797636b0ad44c16838df4177cf5a05aa929781 (diff)
downloadComputeLibrary-ea9e0dc18c408fecb6dc482b774bd900dd321610.tar.gz
COMPMID-1469: Add validate in NEGEMMMatrixAdditionKernel
Change-Id: I228e2503eb40c12869fbd7e834ac1309aa613480 Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/145878 Reviewed-by: Giorgio Arena <giorgio.arena@arm.com> Tested-by: Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src/core/NEON/kernels/NEGEMMMatrixAdditionKernel.cpp')
-rw-r--r--src/core/NEON/kernels/NEGEMMMatrixAdditionKernel.cpp43
1 files changed, 31 insertions, 12 deletions
diff --git a/src/core/NEON/kernels/NEGEMMMatrixAdditionKernel.cpp b/src/core/NEON/kernels/NEGEMMMatrixAdditionKernel.cpp
index cd6aa553db..757dbbc399 100644
--- a/src/core/NEON/kernels/NEGEMMMatrixAdditionKernel.cpp
+++ b/src/core/NEON/kernels/NEGEMMMatrixAdditionKernel.cpp
@@ -32,15 +32,27 @@
#include <arm_neon.h>
-using namespace arm_compute;
-
namespace arm_compute
{
-class Coordinates;
-} // namespace arm_compute
-
namespace
{
+Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, float beta)
+{
+ ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, output);
+ ARM_COMPUTE_UNUSED(beta);
+
+ ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(input);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F16, DataType::F32);
+
+ if(output->total_size() > 0)
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input, output);
+ }
+
+ return Status{};
+}
+
void matrix_addition_f32(const ITensor *input, ITensor *output, const Window &window, float beta)
{
const float32x4_t beta_f32 = vdupq_n_f32(beta);
@@ -101,12 +113,10 @@ NEGEMMMatrixAdditionKernel::NEGEMMMatrixAdditionKernel()
void NEGEMMMatrixAdditionKernel::configure(const ITensor *input, ITensor *output, float beta)
{
- ARM_COMPUTE_ERROR_ON_CPU_F16_UNSUPPORTED(input);
- ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F16, DataType::F32);
- ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::F16, DataType::F32);
- ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
- ARM_COMPUTE_ERROR_ON(input->info()->dimension(0) != output->info()->dimension(0));
- ARM_COMPUTE_ERROR_ON(input->info()->dimension(1) != output->info()->dimension(1));
+ ARM_COMPUTE_ERROR_ON_NULLPTR(input, output);
+
+ // Perform validation step
+ ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), output->info(), beta));
switch(input->info()->data_type())
{
@@ -123,13 +133,21 @@ void NEGEMMMatrixAdditionKernel::configure(const ITensor *input, ITensor *output
break;
}
+ // Configure kernel window
constexpr unsigned int num_elems_processed_per_iteration = 16;
-
INESimpleKernel::configure(input, output, num_elems_processed_per_iteration);
_beta = beta;
}
+Status NEGEMMMatrixAdditionKernel::validate(const ITensorInfo *input, const ITensorInfo *output, float beta)
+{
+ constexpr unsigned int num_elems_processed_per_iteration = 16;
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, output, beta));
+ ARM_COMPUTE_RETURN_ON_ERROR(INESimpleKernel::validate(input->clone().get(), output->clone().get(), num_elems_processed_per_iteration));
+ return Status{};
+}
+
void NEGEMMMatrixAdditionKernel::run(const Window &window, const ThreadInfo &info)
{
ARM_COMPUTE_UNUSED(info);
@@ -141,3 +159,4 @@ void NEGEMMMatrixAdditionKernel::run(const Window &window, const ThreadInfo &inf
(*_func)(_input, _output, window, _beta);
}
}
+} // namespace arm_compute