aboutsummaryrefslogtreecommitdiff
path: root/arm_compute/core/Validate.h
diff options
context:
space:
mode:
Diffstat (limited to 'arm_compute/core/Validate.h')
-rw-r--r--arm_compute/core/Validate.h19
1 files changed, 10 insertions, 9 deletions
diff --git a/arm_compute/core/Validate.h b/arm_compute/core/Validate.h
index 48eba70adf..dd71f2c714 100644
--- a/arm_compute/core/Validate.h
+++ b/arm_compute/core/Validate.h
@@ -249,29 +249,30 @@ void error_on_mismatching_shapes(const char *function, const char *file, const i
* @param[in] function Function in which the error occurred.
* @param[in] file Name of the file where the error occurred.
* @param[in] line Line on which the error occurred.
- * @param[in] tensor_1 The first tensor to be compared.
- * @param[in] tensor_2 The second tensor to be compared.
+ * @param[in] tensor The first tensor to be compared.
* @param[in] tensors (Optional) Further allowed tensors.
*/
template <typename... Ts>
void error_on_mismatching_data_types(const char *function, const char *file, const int line,
- const ITensor *tensor_1, const ITensor *tensor_2, Ts... tensors)
+ const ITensor *tensor, Ts... tensors)
{
ARM_COMPUTE_UNUSED(function);
ARM_COMPUTE_UNUSED(file);
ARM_COMPUTE_UNUSED(line);
- ARM_COMPUTE_UNUSED(tensor_1);
- ARM_COMPUTE_UNUSED(tensor_2);
+ ARM_COMPUTE_UNUSED(tensor);
- DataType &&first_data_type = tensor_1->info()->data_type();
- ARM_COMPUTE_UNUSED(first_data_type);
+ ARM_COMPUTE_ERROR_ON_LOC(tensor == nullptr, function, file, line);
+
+ DataType &&tensor_data_type = tensor->info()->data_type();
+ ARM_COMPUTE_UNUSED(tensor_data_type);
const std::array<const ITensor *, sizeof...(Ts)> tensors_array{ { std::forward<Ts>(tensors)... } };
ARM_COMPUTE_UNUSED(tensors_array);
- ARM_COMPUTE_ERROR_ON_LOC_MSG(tensor_2->info()->data_type() != first_data_type || std::any_of(tensors_array.begin(), tensors_array.end(), [&](const ITensor * tensor)
+ ARM_COMPUTE_ERROR_ON_LOC_MSG(std::any_of(tensors_array.begin(), tensors_array.end(), [&](const ITensor * tensor_obj)
{
- return tensor->info()->data_type() != first_data_type;
+ ARM_COMPUTE_ERROR_ON_LOC(tensor_obj == nullptr, function, file, line);
+ return tensor_obj->info()->data_type() != tensor_data_type;
}),
function, file, line, "Tensors have different data types");
}