diff options
Diffstat (limited to 'arm_compute/core')
-rw-r--r-- | arm_compute/core/CL/kernels/CLReshapeLayerKernel.h | 2 | ||||
-rw-r--r-- | arm_compute/core/Types.h | 10 | ||||
-rw-r--r-- | arm_compute/core/Validate.h | 65 |
3 files changed, 76 insertions, 1 deletions
diff --git a/arm_compute/core/CL/kernels/CLReshapeLayerKernel.h b/arm_compute/core/CL/kernels/CLReshapeLayerKernel.h index d8ccfa88cb..044b5e7006 100644 --- a/arm_compute/core/CL/kernels/CLReshapeLayerKernel.h +++ b/arm_compute/core/CL/kernels/CLReshapeLayerKernel.h @@ -49,7 +49,7 @@ public: ~CLReshapeLayerKernel() = default; /** Set the input and output of the kernel * - * @param[in] input Source tensor. Data type supported: U8/S8/QS8/U16/S16/QS16/U32/S32/F16/F32 + * @param[in] input Source tensor. Data type supported: U8/S8/QS8/QASYMM8/U16/S16/QS16/U32/S32/F16/F32 * @param[out] output Destination tensor. Data type supported: Same as @p input */ void configure(const ICLTensor *input, ICLTensor *output); diff --git a/arm_compute/core/Types.h b/arm_compute/core/Types.h index eaff8fb709..6e7eb3c829 100644 --- a/arm_compute/core/Types.h +++ b/arm_compute/core/Types.h @@ -104,6 +104,16 @@ struct QuantizationInfo { } + bool operator==(const QuantizationInfo &other) + { + return scale == other.scale && offset == other.offset; + } + + bool operator!=(const QuantizationInfo &other) + { + return !(*this == other); + } + float scale; /**< scale */ int offset; /**< offset */ diff --git a/arm_compute/core/Validate.h b/arm_compute/core/Validate.h index 2ca9f6b64e..4f3b6102f5 100644 --- a/arm_compute/core/Validate.h +++ b/arm_compute/core/Validate.h @@ -477,6 +477,71 @@ inline arm_compute::Error error_on_mismatching_fixed_point(const char *function, #define ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(...) \ ARM_COMPUTE_RETURN_ON_ERROR(::arm_compute::error_on_mismatching_fixed_point(__func__, __FILE__, __LINE__, __VA_ARGS__)) +/** Return an error if the passed tensor infos have different asymmetric quantized data types or different quantization info + * + * @note: If the first tensor info doesn't have asymmetric quantized data type, the function returns without throwing an error + * + * @param[in] function Function in which the error occurred. + * @param[in] file Name of the file where the error occurred. + * @param[in] line Line on which the error occurred. + * @param[in] tensor_info_1 The first tensor info to be compared. + * @param[in] tensor_info_2 The second tensor info to be compared. + * @param[in] tensor_infos (Optional) Further allowed tensor infos. + * + * @return Error + */ +template <typename... Ts> +inline arm_compute::Error error_on_mismatching_quantization_info(const char *function, const char *file, const int line, + const ITensorInfo *tensor_info_1, const ITensorInfo *tensor_info_2, Ts... tensor_infos) +{ + DataType &&first_data_type = tensor_info_1->data_type(); + const QuantizationInfo first_quantization_info = tensor_info_1->quantization_info(); + + if(!is_data_type_quantized_asymmetric(first_data_type)) + { + return arm_compute::Error{}; + } + + const std::array < const ITensorInfo *, 1 + sizeof...(Ts) > tensor_infos_array{ { tensor_info_2, std::forward<Ts>(tensor_infos)... } }; + ARM_COMPUTE_RETURN_ERROR_ON_LOC_MSG(std::any_of(tensor_infos_array.begin(), tensor_infos_array.end(), [&](const ITensorInfo * tensor_info) + { + return tensor_info->data_type() != first_data_type; + }), + function, file, line, "Tensors have different asymmetric quantized data types"); + ARM_COMPUTE_RETURN_ERROR_ON_LOC_MSG(std::any_of(tensor_infos_array.begin(), tensor_infos_array.end(), [&](const ITensorInfo * tensor_info) + { + return tensor_info->quantization_info() != first_quantization_info; + }), + function, file, line, "Tensors have different quantization information"); + + return arm_compute::Error{}; +} +/** Return an error if the passed tensor have different asymmetric quantized data types or different quantization info + * + * @note: If the first tensor doesn't have asymmetric quantized data type, the function returns without throwing an error + * + * @param[in] function Function in which the error occurred. + * @param[in] file Name of the file where the error occurred. + * @param[in] line Line on which the error occurred. + * @param[in] tensor_1 The first tensor to be compared. + * @param[in] tensor_2 The second tensor to be compared. + * @param[in] tensors (Optional) Further allowed tensors. + * + * @return Error + */ +template <typename... Ts> +inline arm_compute::Error error_on_mismatching_quantization_info(const char *function, const char *file, const int line, + const ITensor *tensor_1, const ITensor *tensor_2, Ts... tensors) +{ + ARM_COMPUTE_RETURN_ON_ERROR(::arm_compute::error_on_mismatching_quantization_info(function, file, line, tensor_1->info(), tensor_2->info(), + detail::get_tensor_info_t<ITensorInfo *>()(tensors)...)); + return arm_compute::Error{}; +} +#define ARM_COMPUTE_ERROR_ON_MISMATCHING_QUANTIZATION_INFO(...) \ + ARM_COMPUTE_ERROR_THROW_ON(::arm_compute::error_on_mismatching_quantization_info(__func__, __FILE__, __LINE__, __VA_ARGS__)) +#define ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_QUANTIZATION_INFO(...) \ + ARM_COMPUTE_RETURN_ON_ERROR(::arm_compute::error_on_mismatching_quantization_info(__func__, __FILE__, __LINE__, __VA_ARGS__)) + /** Throw an error if the format of the passed tensor/multi-image does not match any of the formats provided. * * @param[in] function Function in which the error occurred. |