aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/cpu/kernels/CpuScaleKernel.cpp4
-rw-r--r--src/gpu/cl/kernels/ClScaleKernel.cpp1
2 files changed, 3 insertions, 2 deletions
diff --git a/src/cpu/kernels/CpuScaleKernel.cpp b/src/cpu/kernels/CpuScaleKernel.cpp
index b8bb5ad18a..4f01c794cf 100644
--- a/src/cpu/kernels/CpuScaleKernel.cpp
+++ b/src/cpu/kernels/CpuScaleKernel.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2016-2022 Arm Limited.
+ * Copyright (c) 2016-2023 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -136,10 +136,10 @@ Status validate_arguments(const ITensorInfo *src, const ITensorInfo *dx, const I
const auto *uk = CpuScaleKernel::get_implementation(ScaleKernelDataTypeISASelectorData{ src->data_type(), CPUInfo::get().get_isa(), info.interpolation_policy });
ARM_COMPUTE_RETURN_ERROR_ON(uk == nullptr || uk->ukernel == nullptr);
-
ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(dst);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(src, dst);
ARM_COMPUTE_RETURN_ERROR_ON(dst == src);
+ ARM_COMPUTE_RETURN_ERROR_ON(src->num_channels()!=1);
ARM_COMPUTE_RETURN_ERROR_ON(info.sampling_policy != SamplingPolicy::CENTER && info.sampling_policy != SamplingPolicy::TOP_LEFT);
ARM_COMPUTE_UNUSED(info.constant_border_value);
ARM_COMPUTE_RETURN_ERROR_ON_MSG(info.use_padding, "Padding is not supported");
diff --git a/src/gpu/cl/kernels/ClScaleKernel.cpp b/src/gpu/cl/kernels/ClScaleKernel.cpp
index 910287194e..d31c387ee5 100644
--- a/src/gpu/cl/kernels/ClScaleKernel.cpp
+++ b/src/gpu/cl/kernels/ClScaleKernel.cpp
@@ -64,6 +64,7 @@ Status validate_arguments(const ITensorInfo *src, const ITensorInfo *dst, const
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(src, dst);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_QUANTIZATION_INFO(src, dst);
ARM_COMPUTE_RETURN_ERROR_ON(dst == src);
+ ARM_COMPUTE_RETURN_ERROR_ON(src->num_channels()!=1);
ARM_COMPUTE_RETURN_ERROR_ON(info.align_corners && !arm_compute::scale_utils::is_align_corners_allowed_sampling_policy(info.sampling_policy));
ARM_COMPUTE_RETURN_ERROR_ON(is_data_type_quantized(src->data_type()) && !is_data_type_quantized_asymmetric(src->data_type()));