aboutsummaryrefslogtreecommitdiff
path: root/arm_compute/core/Validate.h
diff options
context:
space:
mode:
authorMichalis Spyrou <michalis.spyrou@arm.com>2018-03-01 12:03:14 +0000
committerAnthony Barbier <anthony.barbier@arm.com>2018-11-02 16:48:33 +0000
commit7a56925ad139d66356cf5b1e348c4ec318ad41b5 (patch)
tree793a2572d3e04bcdb602f67b9760dd889e9b6c22 /arm_compute/core/Validate.h
parent30a6342241e4a0d70c3b05f8149ccc40014a8d1d (diff)
downloadComputeLibrary-7a56925ad139d66356cf5b1e348c4ec318ad41b5.tar.gz
COMPMID-959 Add data layout check in Validate
Change-Id: I89474e7d65e67739eae8ec0b9968a1ee83e9dba6 Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/122802 Tested-by: Jenkins <bsgcomp@arm.com> Reviewed-by: Pablo Tello <pablo.tello@arm.com>
Diffstat (limited to 'arm_compute/core/Validate.h')
-rw-r--r--arm_compute/core/Validate.h51
1 files changed, 51 insertions, 0 deletions
diff --git a/arm_compute/core/Validate.h b/arm_compute/core/Validate.h
index b04d293c4c..08e872fd90 100644
--- a/arm_compute/core/Validate.h
+++ b/arm_compute/core/Validate.h
@@ -393,6 +393,57 @@ inline arm_compute::Status error_on_mismatching_shapes(const char *function, con
#define ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(...) \
ARM_COMPUTE_RETURN_ON_ERROR(::arm_compute::error_on_mismatching_shapes(__func__, __FILE__, __LINE__, __VA_ARGS__))
+/** Return an error if the passed tensor infos have different data layouts
+ *
+ * @param[in] function Function in which the error occurred.
+ * @param[in] file Name of the file where the error occurred.
+ * @param[in] line Line on which the error occurred.
+ * @param[in] tensor_info The first tensor info to be compared.
+ * @param[in] tensor_infos (Optional) Further allowed tensor infos.
+ *
+ * @return Status
+ */
+template <typename... Ts>
+inline arm_compute::Status error_on_mismatching_data_layouts(const char *function, const char *file, const int line,
+ const ITensorInfo *tensor_info, Ts... tensor_infos)
+{
+ ARM_COMPUTE_RETURN_ERROR_ON_LOC(tensor_info == nullptr, function, file, line);
+ ARM_COMPUTE_RETURN_ON_ERROR(::arm_compute::error_on_nullptr(function, file, line, std::forward<Ts>(tensor_infos)...));
+
+ DataLayout &&tensor_data_layout = tensor_info->data_layout();
+ const std::array<const ITensorInfo *, sizeof...(Ts)> tensors_infos_array{ { std::forward<Ts>(tensor_infos)... } };
+ ARM_COMPUTE_RETURN_ERROR_ON_LOC_MSG(std::any_of(tensors_infos_array.begin(), tensors_infos_array.end(), [&](const ITensorInfo * tensor_info_obj)
+ {
+ return tensor_info_obj->data_layout() != tensor_data_layout;
+ }),
+ function, file, line, "Tensors have different data layouts");
+ return arm_compute::Status{};
+}
+/** Return an error if the passed tensors have different data layouts
+ *
+ * @param[in] function Function in which the error occurred.
+ * @param[in] file Name of the file where the error occurred.
+ * @param[in] line Line on which the error occurred.
+ * @param[in] tensor The first tensor to be compared.
+ * @param[in] tensors (Optional) Further allowed tensors.
+ *
+ * @return Status
+ */
+template <typename... Ts>
+inline arm_compute::Status error_on_mismatching_data_layouts(const char *function, const char *file, const int line,
+ const ITensor *tensor, Ts... tensors)
+{
+ ARM_COMPUTE_RETURN_ERROR_ON_LOC(tensor == nullptr, function, file, line);
+ ARM_COMPUTE_RETURN_ON_ERROR(::arm_compute::error_on_nullptr(function, file, line, std::forward<Ts>(tensors)...));
+ ARM_COMPUTE_RETURN_ON_ERROR(::arm_compute::error_on_mismatching_data_layouts(function, file, line, tensor->info(),
+ detail::get_tensor_info_t<ITensorInfo *>()(tensors)...));
+ return arm_compute::Status{};
+}
+#define ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_LAYOUT(...) \
+ ARM_COMPUTE_ERROR_THROW_ON(::arm_compute::error_on_mismatching_data_layouts(__func__, __FILE__, __LINE__, __VA_ARGS__))
+#define ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_LAYOUT(...) \
+ ARM_COMPUTE_RETURN_ON_ERROR(::arm_compute::error_on_mismatching_data_layouts(__func__, __FILE__, __LINE__, __VA_ARGS__))
+
/** Return an error if the passed two tensor infos have different data types
*
* @param[in] function Function in which the error occurred.