diff options
Diffstat (limited to 'arm_compute/core/Validate.h')
-rw-r--r-- | arm_compute/core/Validate.h | 65 |
1 files changed, 65 insertions, 0 deletions
diff --git a/arm_compute/core/Validate.h b/arm_compute/core/Validate.h index 2ca9f6b64e..4f3b6102f5 100644 --- a/arm_compute/core/Validate.h +++ b/arm_compute/core/Validate.h @@ -477,6 +477,71 @@ inline arm_compute::Error error_on_mismatching_fixed_point(const char *function, #define ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(...) \ ARM_COMPUTE_RETURN_ON_ERROR(::arm_compute::error_on_mismatching_fixed_point(__func__, __FILE__, __LINE__, __VA_ARGS__)) +/** Return an error if the passed tensor infos have different asymmetric quantized data types or different quantization info + * + * @note: If the first tensor info doesn't have asymmetric quantized data type, the function returns without throwing an error + * + * @param[in] function Function in which the error occurred. + * @param[in] file Name of the file where the error occurred. + * @param[in] line Line on which the error occurred. + * @param[in] tensor_info_1 The first tensor info to be compared. + * @param[in] tensor_info_2 The second tensor info to be compared. + * @param[in] tensor_infos (Optional) Further allowed tensor infos. + * + * @return Error + */ +template <typename... Ts> +inline arm_compute::Error error_on_mismatching_quantization_info(const char *function, const char *file, const int line, + const ITensorInfo *tensor_info_1, const ITensorInfo *tensor_info_2, Ts... tensor_infos) +{ + DataType &&first_data_type = tensor_info_1->data_type(); + const QuantizationInfo first_quantization_info = tensor_info_1->quantization_info(); + + if(!is_data_type_quantized_asymmetric(first_data_type)) + { + return arm_compute::Error{}; + } + + const std::array < const ITensorInfo *, 1 + sizeof...(Ts) > tensor_infos_array{ { tensor_info_2, std::forward<Ts>(tensor_infos)... } }; + ARM_COMPUTE_RETURN_ERROR_ON_LOC_MSG(std::any_of(tensor_infos_array.begin(), tensor_infos_array.end(), [&](const ITensorInfo * tensor_info) + { + return tensor_info->data_type() != first_data_type; + }), + function, file, line, "Tensors have different asymmetric quantized data types"); + ARM_COMPUTE_RETURN_ERROR_ON_LOC_MSG(std::any_of(tensor_infos_array.begin(), tensor_infos_array.end(), [&](const ITensorInfo * tensor_info) + { + return tensor_info->quantization_info() != first_quantization_info; + }), + function, file, line, "Tensors have different quantization information"); + + return arm_compute::Error{}; +} +/** Return an error if the passed tensor have different asymmetric quantized data types or different quantization info + * + * @note: If the first tensor doesn't have asymmetric quantized data type, the function returns without throwing an error + * + * @param[in] function Function in which the error occurred. + * @param[in] file Name of the file where the error occurred. + * @param[in] line Line on which the error occurred. + * @param[in] tensor_1 The first tensor to be compared. + * @param[in] tensor_2 The second tensor to be compared. + * @param[in] tensors (Optional) Further allowed tensors. + * + * @return Error + */ +template <typename... Ts> +inline arm_compute::Error error_on_mismatching_quantization_info(const char *function, const char *file, const int line, + const ITensor *tensor_1, const ITensor *tensor_2, Ts... tensors) +{ + ARM_COMPUTE_RETURN_ON_ERROR(::arm_compute::error_on_mismatching_quantization_info(function, file, line, tensor_1->info(), tensor_2->info(), + detail::get_tensor_info_t<ITensorInfo *>()(tensors)...)); + return arm_compute::Error{}; +} +#define ARM_COMPUTE_ERROR_ON_MISMATCHING_QUANTIZATION_INFO(...) \ + ARM_COMPUTE_ERROR_THROW_ON(::arm_compute::error_on_mismatching_quantization_info(__func__, __FILE__, __LINE__, __VA_ARGS__)) +#define ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_QUANTIZATION_INFO(...) \ + ARM_COMPUTE_RETURN_ON_ERROR(::arm_compute::error_on_mismatching_quantization_info(__func__, __FILE__, __LINE__, __VA_ARGS__)) + /** Throw an error if the format of the passed tensor/multi-image does not match any of the formats provided. * * @param[in] function Function in which the error occurred. |