From 7a56925ad139d66356cf5b1e348c4ec318ad41b5 Mon Sep 17 00:00:00 2001 From: Michalis Spyrou Date: Thu, 1 Mar 2018 12:03:14 +0000 Subject: COMPMID-959 Add data layout check in Validate Change-Id: I89474e7d65e67739eae8ec0b9968a1ee83e9dba6 Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/122802 Tested-by: Jenkins Reviewed-by: Pablo Tello --- arm_compute/core/Validate.h | 51 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 51 insertions(+) (limited to 'arm_compute/core/Validate.h') diff --git a/arm_compute/core/Validate.h b/arm_compute/core/Validate.h index b04d293c4c..08e872fd90 100644 --- a/arm_compute/core/Validate.h +++ b/arm_compute/core/Validate.h @@ -393,6 +393,57 @@ inline arm_compute::Status error_on_mismatching_shapes(const char *function, con #define ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(...) \ ARM_COMPUTE_RETURN_ON_ERROR(::arm_compute::error_on_mismatching_shapes(__func__, __FILE__, __LINE__, __VA_ARGS__)) +/** Return an error if the passed tensor infos have different data layouts + * + * @param[in] function Function in which the error occurred. + * @param[in] file Name of the file where the error occurred. + * @param[in] line Line on which the error occurred. + * @param[in] tensor_info The first tensor info to be compared. + * @param[in] tensor_infos (Optional) Further allowed tensor infos. + * + * @return Status + */ +template +inline arm_compute::Status error_on_mismatching_data_layouts(const char *function, const char *file, const int line, + const ITensorInfo *tensor_info, Ts... tensor_infos) +{ + ARM_COMPUTE_RETURN_ERROR_ON_LOC(tensor_info == nullptr, function, file, line); + ARM_COMPUTE_RETURN_ON_ERROR(::arm_compute::error_on_nullptr(function, file, line, std::forward(tensor_infos)...)); + + DataLayout &&tensor_data_layout = tensor_info->data_layout(); + const std::array tensors_infos_array{ { std::forward(tensor_infos)... } }; + ARM_COMPUTE_RETURN_ERROR_ON_LOC_MSG(std::any_of(tensors_infos_array.begin(), tensors_infos_array.end(), [&](const ITensorInfo * tensor_info_obj) + { + return tensor_info_obj->data_layout() != tensor_data_layout; + }), + function, file, line, "Tensors have different data layouts"); + return arm_compute::Status{}; +} +/** Return an error if the passed tensors have different data layouts + * + * @param[in] function Function in which the error occurred. + * @param[in] file Name of the file where the error occurred. + * @param[in] line Line on which the error occurred. + * @param[in] tensor The first tensor to be compared. + * @param[in] tensors (Optional) Further allowed tensors. + * + * @return Status + */ +template +inline arm_compute::Status error_on_mismatching_data_layouts(const char *function, const char *file, const int line, + const ITensor *tensor, Ts... tensors) +{ + ARM_COMPUTE_RETURN_ERROR_ON_LOC(tensor == nullptr, function, file, line); + ARM_COMPUTE_RETURN_ON_ERROR(::arm_compute::error_on_nullptr(function, file, line, std::forward(tensors)...)); + ARM_COMPUTE_RETURN_ON_ERROR(::arm_compute::error_on_mismatching_data_layouts(function, file, line, tensor->info(), + detail::get_tensor_info_t()(tensors)...)); + return arm_compute::Status{}; +} +#define ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_LAYOUT(...) \ + ARM_COMPUTE_ERROR_THROW_ON(::arm_compute::error_on_mismatching_data_layouts(__func__, __FILE__, __LINE__, __VA_ARGS__)) +#define ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_LAYOUT(...) \ + ARM_COMPUTE_RETURN_ON_ERROR(::arm_compute::error_on_mismatching_data_layouts(__func__, __FILE__, __LINE__, __VA_ARGS__)) + /** Return an error if the passed two tensor infos have different data types * * @param[in] function Function in which the error occurred. -- cgit v1.2.1