aboutsummaryrefslogtreecommitdiff
path: root/src/core/CL/kernels/CLPixelWiseMultiplicationKernel.cpp
diff options
context:
space:
mode:
authorGeorgios Pinitas <georgios.pinitas@arm.com>2017-11-03 19:01:44 +0000
committerAnthony Barbier <anthony.barbier@arm.com>2018-11-02 16:35:24 +0000
commitf9d3a0a12ede4db89348fd924274c9acc6809bb2 (patch)
tree73080e4a230c4db248d6f9e16181eab09c8168dc /src/core/CL/kernels/CLPixelWiseMultiplicationKernel.cpp
parentd6afedc775220f17317f1835a4d18b72a54525de (diff)
downloadComputeLibrary-f9d3a0a12ede4db89348fd924274c9acc6809bb2.tar.gz
COMPMID-617: Add validation functions.
Added validation routines to the following kernels. -CLActivationLayer -CLBatchNormalizationLayer -CLArithmeticAddition -CLArithmeticSubtraction -CLPixelwiseMultiplication Change-Id: I0f3a03154f9e392279f715af656683cd0ad4cef5 Reviewed-on: http://mpd-gerrit.cambridge.arm.com/94595 Tested-by: Kaizen <jeremy.johnson+kaizengerrit@arm.com> Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
Diffstat (limited to 'src/core/CL/kernels/CLPixelWiseMultiplicationKernel.cpp')
-rw-r--r--src/core/CL/kernels/CLPixelWiseMultiplicationKernel.cpp52
1 files changed, 38 insertions, 14 deletions
diff --git a/src/core/CL/kernels/CLPixelWiseMultiplicationKernel.cpp b/src/core/CL/kernels/CLPixelWiseMultiplicationKernel.cpp
index 33c8b81c1d..5e35c8c1ff 100644
--- a/src/core/CL/kernels/CLPixelWiseMultiplicationKernel.cpp
+++ b/src/core/CL/kernels/CLPixelWiseMultiplicationKernel.cpp
@@ -64,20 +64,8 @@ void CLPixelWiseMultiplicationKernel::configure(const ICLTensor *input1, const I
}
}
- ARM_COMPUTE_ERROR_ON_MISMATCHING_SHAPES(input1, input2, output);
- ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::U8, DataType::QS8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32);
- ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input2, 1, DataType::U8, DataType::QS8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32);
- ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8, DataType::QS8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32);
- ARM_COMPUTE_ERROR_ON_MSG(output->info()->data_type() == DataType::U8 && (input1->info()->data_type() != DataType::U8 || input2->info()->data_type() != DataType::U8),
- "Output can only be U8 if both inputs are U8");
- ARM_COMPUTE_ERROR_ON_MSG(scale < 0, "Scale cannot be negative. ");
- if(is_data_type_fixed_point(input1->info()->data_type()))
- {
- // All data types must be all QS8 or all QS16
- ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input1, input2, output);
- ARM_COMPUTE_ERROR_ON_MISMATCHING_FIXED_POINT_POSITION(input1, input2, output);
- ARM_COMPUTE_ERROR_ON_MSG(scale != 1, "Unsupported scaling factor for QS8/QS16. Scale must be 1.");
- }
+ ARM_COMPUTE_ERROR_THROW_ON(CLPixelWiseMultiplicationKernel::validate(input1->info(), input2->info(), output->info(),
+ scale, overflow_policy, rounding_policy));
_input1 = input1;
_input2 = input2;
@@ -178,6 +166,42 @@ void CLPixelWiseMultiplicationKernel::configure(const ICLTensor *input1, const I
ICLKernel::configure(win);
}
+Error CLPixelWiseMultiplicationKernel::validate(const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output, float scale,
+ ConvertPolicy overflow_policy, RoundingPolicy rounding_policy)
+{
+ ARM_COMPUTE_UNUSED(overflow_policy);
+ ARM_COMPUTE_UNUSED(rounding_policy);
+
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::U8, DataType::QS8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input2, 1, DataType::U8, DataType::QS8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input1, input2);
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input1, input2);
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(scale < 0, "Scale cannot be negative.");
+
+ if(is_data_type_fixed_point(input1->data_type()))
+ {
+ // All data types must be all QS8 or all QS16
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input1, input2);
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(scale != 1, "Unsupported scaling factor for QS8/QS16. Scale must be 1.");
+ }
+
+ // Validate in case of configured output
+ if((output != nullptr) && (output->total_size() != 0))
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8, DataType::QS8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(output->data_type() == DataType::U8 && (input1->data_type() != DataType::U8 || input2->data_type() != DataType::U8),
+ "Output can only be U8 if both inputs are U8");
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input1, output);
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input1, output);
+ if(is_data_type_fixed_point(input1->data_type()))
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input1, output);
+ }
+ }
+
+ return Error{};
+}
+
void CLPixelWiseMultiplicationKernel::run(const Window &window, cl::CommandQueue &queue)
{
ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);