diff options
Diffstat (limited to 'arm_compute/core/Validate.h')
-rw-r--r-- | arm_compute/core/Validate.h | 52 |
1 files changed, 52 insertions, 0 deletions
diff --git a/arm_compute/core/Validate.h b/arm_compute/core/Validate.h index 918c8e5fc3..dab4221a3b 100644 --- a/arm_compute/core/Validate.h +++ b/arm_compute/core/Validate.h @@ -693,6 +693,58 @@ inline arm_compute::Status error_on_data_type_not_in(const char *function, const #define ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_NOT_IN(t, ...) \ ARM_COMPUTE_RETURN_ON_ERROR(::arm_compute::error_on_data_type_not_in(__func__, __FILE__, __LINE__, t, __VA_ARGS__)) +/** Return an error if the data layout of the passed tensor info does not match any of the data layouts provided. + * + * @param[in] function Function in which the error occurred. + * @param[in] file Name of the file where the error occurred. + * @param[in] line Line on which the error occurred. + * @param[in] tensor_info Tensor info to validate. + * @param[in] dl First data layout allowed. + * @param[in] dls (Optional) Further allowed data layouts. + * + * @return Status + */ +template <typename T, typename... Ts> +inline arm_compute::Status error_on_data_layout_not_in(const char *function, const char *file, const int line, + const ITensorInfo *tensor_info, T &&dl, Ts &&... dls) +{ + ARM_COMPUTE_RETURN_ERROR_ON_LOC(tensor_info == nullptr, function, file, line); + + const DataLayout &tensor_dl = tensor_info->data_layout(); //NOLINT + ARM_COMPUTE_RETURN_ERROR_ON_LOC(tensor_dl == DataLayout::UNKNOWN, function, file, line); + + const std::array<T, sizeof...(Ts)> dls_array{ { std::forward<Ts>(dls)... } }; + ARM_COMPUTE_RETURN_ERROR_ON_LOC_MSG(tensor_dl != dl && std::none_of(dls_array.begin(), dls_array.end(), [&](const T & l) + { + return l == tensor_dl; + }), + function, file, line, "ITensor data layout %s not supported by this kernel", string_from_data_layout(tensor_dl).c_str()); + return arm_compute::Status{}; +} +/** Return an error if the data layout of the passed tensor does not match any of the data layout provided. + * + * @param[in] function Function in which the error occurred. + * @param[in] file Name of the file where the error occurred. + * @param[in] line Line on which the error occurred. + * @param[in] tensor Tensor to validate. + * @param[in] dl First data layout allowed. + * @param[in] dls (Optional) Further allowed data layouts. + * + * @return Status + */ +template <typename T, typename... Ts> +inline arm_compute::Status error_on_data_layout_not_in(const char *function, const char *file, const int line, + const ITensor *tensor, T &&dl, Ts &&... dls) +{ + ARM_COMPUTE_RETURN_ERROR_ON_LOC(tensor == nullptr, function, file, line); + ARM_COMPUTE_RETURN_ON_ERROR(::arm_compute::error_on_data_layout_not_in(function, file, line, tensor->info(), std::forward<T>(dl), std::forward<Ts>(dls)...)); + return arm_compute::Status{}; +} +#define ARM_COMPUTE_ERROR_ON_DATA_LAYOUT_NOT_IN(t, ...) \ + ARM_COMPUTE_ERROR_THROW_ON(::arm_compute::error_on_data_layout_not_in(__func__, __FILE__, __LINE__, t, __VA_ARGS__)) +#define ARM_COMPUTE_RETURN_ERROR_ON_DATA_LAYOUT_NOT_IN(t, ...) \ + ARM_COMPUTE_RETURN_ON_ERROR(::arm_compute::error_on_data_layout_not_in(__func__, __FILE__, __LINE__, t, __VA_ARGS__)) + /** Return an error if the data type or the number of channels of the passed tensor info does not match any of the data types and number of channels provided. * * @param[in] function Function in which the error occurred. |