diff options
Diffstat (limited to 'arm_compute/core/Helpers.h')
-rw-r--r-- | arm_compute/core/Helpers.h | 20 |
1 files changed, 19 insertions, 1 deletions
diff --git a/arm_compute/core/Helpers.h b/arm_compute/core/Helpers.h index 6e4d987180..edb05e99a1 100644 --- a/arm_compute/core/Helpers.h +++ b/arm_compute/core/Helpers.h @@ -466,10 +466,15 @@ inline Strides compute_strides(const ITensorInfo &info) * @param[in] num_channels New number of channels. * @param[in] data_type New data type * @param[in] fixed_point_position New fixed point position + * @param[in] quantization_info (Optional) New quantization info * * @return True if the tensor info has been initialized */ -bool auto_init_if_empty(ITensorInfo &info, const TensorShape &shape, int num_channels, DataType data_type, int fixed_point_position); +bool auto_init_if_empty(ITensorInfo &info, + const TensorShape &shape, + int num_channels, DataType data_type, + int fixed_point_position, + QuantizationInfo quantization_info = QuantizationInfo()); /* Set the shape to the specified value if the current assignment is empty. * @@ -509,6 +514,17 @@ bool set_data_type_if_unknown(ITensorInfo &info, DataType data_type); * @return True if the fixed point position has been changed. */ bool set_fixed_point_position_if_zero(ITensorInfo &info, int fixed_point_position); + +/* Set the quantization info to the specified value if + * the current quantization info is empty and the data type of asymmetric quantized type + * + * @param[in,out] info Tensor info used to check and assign. + * @param[in] quantization_info Quantization info + * + * @return True if the quantization info has been changed. + */ +bool set_quantization_info_if_empty(ITensorInfo &info, QuantizationInfo quantization_info); + /** Helper function to calculate the Valid Region for Scale. * * @param[in] src_info Input tensor info used to check. @@ -520,6 +536,7 @@ bool set_fixed_point_position_if_zero(ITensorInfo &info, int fixed_point_positio * @return The corrispondent valid region */ ValidRegion calculate_valid_region_scale(const ITensorInfo &src_info, const TensorShape &dst_shape, InterpolationPolicy policy, BorderSize border_size, bool border_undefined); + /** Convert a linear index into n-dimensional coordinates. * * @param[in] shape Shape of the n-dimensional tensor. @@ -528,6 +545,7 @@ ValidRegion calculate_valid_region_scale(const ITensorInfo &src_info, const Tens * @return n-dimensional coordinates. */ inline Coordinates index2coords(const TensorShape &shape, int index); + /** Convert n-dimensional coordinates into a linear index. * * @param[in] shape Shape of the n-dimensional tensor. |