diff options
author | Michalis Spyrou <michalis.spyrou@arm.com> | 2018-03-01 12:03:14 +0000 |
---|---|---|
committer | Anthony Barbier <anthony.barbier@arm.com> | 2018-11-02 16:48:33 +0000 |
commit | 7a56925ad139d66356cf5b1e348c4ec318ad41b5 (patch) | |
tree | 793a2572d3e04bcdb602f67b9760dd889e9b6c22 | |
parent | 30a6342241e4a0d70c3b05f8149ccc40014a8d1d (diff) | |
download | ComputeLibrary-7a56925ad139d66356cf5b1e348c4ec318ad41b5.tar.gz |
COMPMID-959 Add data layout check in Validate
Change-Id: I89474e7d65e67739eae8ec0b9968a1ee83e9dba6
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/122802
Tested-by: Jenkins <bsgcomp@arm.com>
Reviewed-by: Pablo Tello <pablo.tello@arm.com>
-rw-r--r-- | arm_compute/core/Validate.h | 51 |
1 files changed, 51 insertions, 0 deletions
diff --git a/arm_compute/core/Validate.h b/arm_compute/core/Validate.h index b04d293c4c..08e872fd90 100644 --- a/arm_compute/core/Validate.h +++ b/arm_compute/core/Validate.h @@ -393,6 +393,57 @@ inline arm_compute::Status error_on_mismatching_shapes(const char *function, con #define ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(...) \ ARM_COMPUTE_RETURN_ON_ERROR(::arm_compute::error_on_mismatching_shapes(__func__, __FILE__, __LINE__, __VA_ARGS__)) +/** Return an error if the passed tensor infos have different data layouts + * + * @param[in] function Function in which the error occurred. + * @param[in] file Name of the file where the error occurred. + * @param[in] line Line on which the error occurred. + * @param[in] tensor_info The first tensor info to be compared. + * @param[in] tensor_infos (Optional) Further allowed tensor infos. + * + * @return Status + */ +template <typename... Ts> +inline arm_compute::Status error_on_mismatching_data_layouts(const char *function, const char *file, const int line, + const ITensorInfo *tensor_info, Ts... tensor_infos) +{ + ARM_COMPUTE_RETURN_ERROR_ON_LOC(tensor_info == nullptr, function, file, line); + ARM_COMPUTE_RETURN_ON_ERROR(::arm_compute::error_on_nullptr(function, file, line, std::forward<Ts>(tensor_infos)...)); + + DataLayout &&tensor_data_layout = tensor_info->data_layout(); + const std::array<const ITensorInfo *, sizeof...(Ts)> tensors_infos_array{ { std::forward<Ts>(tensor_infos)... } }; + ARM_COMPUTE_RETURN_ERROR_ON_LOC_MSG(std::any_of(tensors_infos_array.begin(), tensors_infos_array.end(), [&](const ITensorInfo * tensor_info_obj) + { + return tensor_info_obj->data_layout() != tensor_data_layout; + }), + function, file, line, "Tensors have different data layouts"); + return arm_compute::Status{}; +} +/** Return an error if the passed tensors have different data layouts + * + * @param[in] function Function in which the error occurred. + * @param[in] file Name of the file where the error occurred. + * @param[in] line Line on which the error occurred. + * @param[in] tensor The first tensor to be compared. + * @param[in] tensors (Optional) Further allowed tensors. + * + * @return Status + */ +template <typename... Ts> +inline arm_compute::Status error_on_mismatching_data_layouts(const char *function, const char *file, const int line, + const ITensor *tensor, Ts... tensors) +{ + ARM_COMPUTE_RETURN_ERROR_ON_LOC(tensor == nullptr, function, file, line); + ARM_COMPUTE_RETURN_ON_ERROR(::arm_compute::error_on_nullptr(function, file, line, std::forward<Ts>(tensors)...)); + ARM_COMPUTE_RETURN_ON_ERROR(::arm_compute::error_on_mismatching_data_layouts(function, file, line, tensor->info(), + detail::get_tensor_info_t<ITensorInfo *>()(tensors)...)); + return arm_compute::Status{}; +} +#define ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_LAYOUT(...) \ + ARM_COMPUTE_ERROR_THROW_ON(::arm_compute::error_on_mismatching_data_layouts(__func__, __FILE__, __LINE__, __VA_ARGS__)) +#define ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_LAYOUT(...) \ + ARM_COMPUTE_RETURN_ON_ERROR(::arm_compute::error_on_mismatching_data_layouts(__func__, __FILE__, __LINE__, __VA_ARGS__)) + /** Return an error if the passed two tensor infos have different data types * * @param[in] function Function in which the error occurred. |