diff options
Diffstat (limited to 'arm_compute/core/Validate.h')
-rw-r--r-- | arm_compute/core/Validate.h | 51 |
1 files changed, 51 insertions, 0 deletions
diff --git a/arm_compute/core/Validate.h b/arm_compute/core/Validate.h index b04d293c4c..08e872fd90 100644 --- a/arm_compute/core/Validate.h +++ b/arm_compute/core/Validate.h @@ -393,6 +393,57 @@ inline arm_compute::Status error_on_mismatching_shapes(const char *function, con #define ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(...) \ ARM_COMPUTE_RETURN_ON_ERROR(::arm_compute::error_on_mismatching_shapes(__func__, __FILE__, __LINE__, __VA_ARGS__)) +/** Return an error if the passed tensor infos have different data layouts + * + * @param[in] function Function in which the error occurred. + * @param[in] file Name of the file where the error occurred. + * @param[in] line Line on which the error occurred. + * @param[in] tensor_info The first tensor info to be compared. + * @param[in] tensor_infos (Optional) Further allowed tensor infos. + * + * @return Status + */ +template <typename... Ts> +inline arm_compute::Status error_on_mismatching_data_layouts(const char *function, const char *file, const int line, + const ITensorInfo *tensor_info, Ts... tensor_infos) +{ + ARM_COMPUTE_RETURN_ERROR_ON_LOC(tensor_info == nullptr, function, file, line); + ARM_COMPUTE_RETURN_ON_ERROR(::arm_compute::error_on_nullptr(function, file, line, std::forward<Ts>(tensor_infos)...)); + + DataLayout &&tensor_data_layout = tensor_info->data_layout(); + const std::array<const ITensorInfo *, sizeof...(Ts)> tensors_infos_array{ { std::forward<Ts>(tensor_infos)... } }; + ARM_COMPUTE_RETURN_ERROR_ON_LOC_MSG(std::any_of(tensors_infos_array.begin(), tensors_infos_array.end(), [&](const ITensorInfo * tensor_info_obj) + { + return tensor_info_obj->data_layout() != tensor_data_layout; + }), + function, file, line, "Tensors have different data layouts"); + return arm_compute::Status{}; +} +/** Return an error if the passed tensors have different data layouts + * + * @param[in] function Function in which the error occurred. + * @param[in] file Name of the file where the error occurred. + * @param[in] line Line on which the error occurred. + * @param[in] tensor The first tensor to be compared. + * @param[in] tensors (Optional) Further allowed tensors. + * + * @return Status + */ +template <typename... Ts> +inline arm_compute::Status error_on_mismatching_data_layouts(const char *function, const char *file, const int line, + const ITensor *tensor, Ts... tensors) +{ + ARM_COMPUTE_RETURN_ERROR_ON_LOC(tensor == nullptr, function, file, line); + ARM_COMPUTE_RETURN_ON_ERROR(::arm_compute::error_on_nullptr(function, file, line, std::forward<Ts>(tensors)...)); + ARM_COMPUTE_RETURN_ON_ERROR(::arm_compute::error_on_mismatching_data_layouts(function, file, line, tensor->info(), + detail::get_tensor_info_t<ITensorInfo *>()(tensors)...)); + return arm_compute::Status{}; +} +#define ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_LAYOUT(...) \ + ARM_COMPUTE_ERROR_THROW_ON(::arm_compute::error_on_mismatching_data_layouts(__func__, __FILE__, __LINE__, __VA_ARGS__)) +#define ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_LAYOUT(...) \ + ARM_COMPUTE_RETURN_ON_ERROR(::arm_compute::error_on_mismatching_data_layouts(__func__, __FILE__, __LINE__, __VA_ARGS__)) + /** Return an error if the passed two tensor infos have different data types * * @param[in] function Function in which the error occurred. |