aboutsummaryrefslogtreecommitdiff
path: root/src/core/CL/kernels/CLArithmeticSubtractionKernel.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/CL/kernels/CLArithmeticSubtractionKernel.cpp')
-rw-r--r--src/core/CL/kernels/CLArithmeticSubtractionKernel.cpp33
1 files changed, 22 insertions, 11 deletions
diff --git a/src/core/CL/kernels/CLArithmeticSubtractionKernel.cpp b/src/core/CL/kernels/CLArithmeticSubtractionKernel.cpp
index c5183af7d7..47d77ad8a9 100644
--- a/src/core/CL/kernels/CLArithmeticSubtractionKernel.cpp
+++ b/src/core/CL/kernels/CLArithmeticSubtractionKernel.cpp
@@ -61,17 +61,7 @@ void CLArithmeticSubtractionKernel::configure(const ICLTensor *input1, const ICL
}
}
- ARM_COMPUTE_ERROR_ON_MISMATCHING_SHAPES(input1, input2, output);
- ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::U8, DataType::QS8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32);
- ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input2, 1, DataType::U8, DataType::QS8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32);
- ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8, DataType::QS8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32);
- ARM_COMPUTE_ERROR_ON_MSG(output->info()->data_type() == DataType::U8 && (input1->info()->data_type() != DataType::U8 || input2->info()->data_type() != DataType::U8),
- "Output can only be U8 if both inputs are U8");
- if(is_data_type_fixed_point(input1->info()->data_type()) || is_data_type_fixed_point(input2->info()->data_type()) || is_data_type_fixed_point(output->info()->data_type()))
- {
- // Check that all data types are the same and all fixed-point positions are the same
- ARM_COMPUTE_ERROR_ON_MISMATCHING_FIXED_POINT(input1, input2, output);
- }
+ ARM_COMPUTE_ERROR_THROW_ON(CLArithmeticSubtractionKernel::validate(input1->info(), input2->info(), output->info(), policy));
_input1 = input1;
_input2 = input2;
@@ -111,6 +101,27 @@ void CLArithmeticSubtractionKernel::configure(const ICLTensor *input1, const ICL
ICLKernel::configure(win);
}
+Error CLArithmeticSubtractionKernel::validate(const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output, ConvertPolicy policy)
+{
+ ARM_COMPUTE_UNUSED(policy);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::U8, DataType::QS8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input2, 1, DataType::U8, DataType::QS8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input1, input2);
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input1, input2);
+
+ // Validate in case of configured output
+ if((output != nullptr) && (output->total_size() != 0))
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8, DataType::QS8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(output->data_type() == DataType::U8 && (input1->data_type() != DataType::U8 || input2->data_type() != DataType::U8),
+ "Output can only be U8 if both inputs are U8");
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input1, output);
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input1, output);
+ }
+
+ return Error{};
+}
+
void CLArithmeticSubtractionKernel::run(const Window &window, cl::CommandQueue &queue)
{
ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);