aboutsummaryrefslogtreecommitdiff
path: root/arm_compute/core/Validate.h
diff options
context:
space:
mode:
Diffstat (limited to 'arm_compute/core/Validate.h')
-rw-r--r--arm_compute/core/Validate.h51
1 files changed, 51 insertions, 0 deletions
diff --git a/arm_compute/core/Validate.h b/arm_compute/core/Validate.h
index b04d293c4c..08e872fd90 100644
--- a/arm_compute/core/Validate.h
+++ b/arm_compute/core/Validate.h
@@ -393,6 +393,57 @@ inline arm_compute::Status error_on_mismatching_shapes(const char *function, con
#define ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(...) \
ARM_COMPUTE_RETURN_ON_ERROR(::arm_compute::error_on_mismatching_shapes(__func__, __FILE__, __LINE__, __VA_ARGS__))
+/** Return an error if the passed tensor infos have different data layouts
+ *
+ * @param[in] function Function in which the error occurred.
+ * @param[in] file Name of the file where the error occurred.
+ * @param[in] line Line on which the error occurred.
+ * @param[in] tensor_info The first tensor info to be compared.
+ * @param[in] tensor_infos (Optional) Further allowed tensor infos.
+ *
+ * @return Status
+ */
+template <typename... Ts>
+inline arm_compute::Status error_on_mismatching_data_layouts(const char *function, const char *file, const int line,
+ const ITensorInfo *tensor_info, Ts... tensor_infos)
+{
+ ARM_COMPUTE_RETURN_ERROR_ON_LOC(tensor_info == nullptr, function, file, line);
+ ARM_COMPUTE_RETURN_ON_ERROR(::arm_compute::error_on_nullptr(function, file, line, std::forward<Ts>(tensor_infos)...));
+
+ DataLayout &&tensor_data_layout = tensor_info->data_layout();
+ const std::array<const ITensorInfo *, sizeof...(Ts)> tensors_infos_array{ { std::forward<Ts>(tensor_infos)... } };
+ ARM_COMPUTE_RETURN_ERROR_ON_LOC_MSG(std::any_of(tensors_infos_array.begin(), tensors_infos_array.end(), [&](const ITensorInfo * tensor_info_obj)
+ {
+ return tensor_info_obj->data_layout() != tensor_data_layout;
+ }),
+ function, file, line, "Tensors have different data layouts");
+ return arm_compute::Status{};
+}
+/** Return an error if the passed tensors have different data layouts
+ *
+ * @param[in] function Function in which the error occurred.
+ * @param[in] file Name of the file where the error occurred.
+ * @param[in] line Line on which the error occurred.
+ * @param[in] tensor The first tensor to be compared.
+ * @param[in] tensors (Optional) Further allowed tensors.
+ *
+ * @return Status
+ */
+template <typename... Ts>
+inline arm_compute::Status error_on_mismatching_data_layouts(const char *function, const char *file, const int line,
+ const ITensor *tensor, Ts... tensors)
+{
+ ARM_COMPUTE_RETURN_ERROR_ON_LOC(tensor == nullptr, function, file, line);
+ ARM_COMPUTE_RETURN_ON_ERROR(::arm_compute::error_on_nullptr(function, file, line, std::forward<Ts>(tensors)...));
+ ARM_COMPUTE_RETURN_ON_ERROR(::arm_compute::error_on_mismatching_data_layouts(function, file, line, tensor->info(),
+ detail::get_tensor_info_t<ITensorInfo *>()(tensors)...));
+ return arm_compute::Status{};
+}
+#define ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_LAYOUT(...) \
+ ARM_COMPUTE_ERROR_THROW_ON(::arm_compute::error_on_mismatching_data_layouts(__func__, __FILE__, __LINE__, __VA_ARGS__))
+#define ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_LAYOUT(...) \
+ ARM_COMPUTE_RETURN_ON_ERROR(::arm_compute::error_on_mismatching_data_layouts(__func__, __FILE__, __LINE__, __VA_ARGS__))
+
/** Return an error if the passed two tensor infos have different data types
*
* @param[in] function Function in which the error occurred.