From f7f78ae236e623a57919f9450e8b2043e681ddb3 Mon Sep 17 00:00:00 2001 From: Jeremy Johnson Date: Wed, 25 May 2022 15:26:38 +0100 Subject: Add support for uint16_t to RESCALE Update ref-model RESCALE op to support UINT16 conversions Add testing for RESCALE UINT16 and ERROR_IFs Signed-off-by: Jeremy Johnson Change-Id: Ic6e6e53de1f0b054bedb9e6ba3856e7475498aba --- reference_model/src/ops/op_factory.cc | 4 + reference_model/src/ops/template_types.h | 22 ++++- reference_model/src/ops/type_conversion.cc | 24 ++++- reference_model/src/quant_util.h | 2 +- reference_model/src/tensor.cc | 3 + reference_model/src/tensor.h | 1 + thirdparty/serialization_lib | 2 +- verif/generator/tosa_arg_gen.py | 49 +++++++--- verif/generator/tosa_error_if.py | 138 +++++++++++++++++++---------- verif/generator/tosa_test_gen.py | 49 +++++++--- verif/generator/tosa_utils.py | 10 ++- 11 files changed, 226 insertions(+), 78 deletions(-) diff --git a/reference_model/src/ops/op_factory.cc b/reference_model/src/ops/op_factory.cc index 6edd63f..f7ded9a 100644 --- a/reference_model/src/ops/op_factory.cc +++ b/reference_model/src/ops/op_factory.cc @@ -396,7 +396,11 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT48, INT16); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT48, INT32); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, UINT8, INT8); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, UINT8, INT16); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, UINT16, INT16); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT8, UINT8); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT16, UINT8); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT16, UINT16); break; // custom diff --git a/reference_model/src/ops/template_types.h b/reference_model/src/ops/template_types.h index 0fe9a41..2bc7e04 100644 --- a/reference_model/src/ops/template_types.h +++ b/reference_model/src/ops/template_types.h @@ -23,7 +23,7 @@ using namespace tosa; namespace TosaReference { -// Shorter aliase templates for common Eigen::Tensor types +// Shorter alias templates for common Eigen::Tensor types template using ETensor0 = Eigen::Tensor; template @@ -89,6 +89,11 @@ struct GetEigenType using type = int32_t; }; template <> +struct GetEigenType +{ + using type = int32_t; +}; +template <> struct GetEigenType { using type = int32_t; @@ -121,6 +126,11 @@ struct GetNumBits static constexpr int32_t value = 8; }; template <> +struct GetNumBits +{ + static constexpr int32_t value = 16; +}; +template <> struct GetNumBits { static constexpr int32_t value = 4; @@ -158,6 +168,11 @@ struct GetQMin static constexpr int64_t value = 0L; }; template <> +struct GetQMin +{ + static constexpr int64_t value = 0L; +}; +template <> struct GetQMin { static constexpr int64_t value = -8L; @@ -194,6 +209,11 @@ struct GetQMax static constexpr int64_t value = 255L; }; template <> +struct GetQMax +{ + static constexpr int64_t value = 65535L; +}; +template <> struct GetQMax { static constexpr int64_t value = 7L; diff --git a/reference_model/src/ops/type_conversion.cc b/reference_model/src/ops/type_conversion.cc index e46ab38..7ee9692 100644 --- a/reference_model/src/ops/type_conversion.cc +++ b/reference_model/src/ops/type_conversion.cc @@ -64,15 +64,27 @@ int OpRescale::checkTensorAttributes() ASSERT_MEM(in && out); - if ((InDtype != DType_INT8) && (InDtype != DType_UINT8) && (attribute->input_zp() != 0)) + if ((InDtype != DType_INT8) && (InDtype != DType_UINT8) && (InDtype != DType_UINT16) && (attribute->input_zp() != 0)) { - printNodeValidationError("OpRescale: Input DType not INT8/UINT8 and zero point not 0"); + printNodeValidationError("OpRescale: Input DType not INT8/UINT8/UINT16 and zero point not 0"); return 1; } - if ((OutDtype != DType_INT8) && (OutDtype != DType_UINT8) && (attribute->output_zp() != 0)) + if ((OutDtype != DType_INT8) && (OutDtype != DType_UINT8) && (OutDtype != DType_UINT16) && (attribute->output_zp() != 0)) { - printNodeValidationError("OpRescale: Output DType not INT8/UINT8 and zero point not 0"); + printNodeValidationError("OpRescale: Output DType not INT8/UINT8/UINT16 and zero point not 0"); + return 1; + } + + if ((InDtype == DType_UINT16) && ((attribute->input_zp() != 0) && (attribute->input_zp() != 32768))) + { + printNodeValidationError("OpRescale: Input DType UINT16 and zero point not 0 or 32768"); + return 1; + } + + if ((OutDtype == DType_UINT16) && ((attribute->output_zp() != 0) && (attribute->output_zp() != 32768))) + { + printNodeValidationError("OpRescale: Output DType UINT16 and zero point not 0 or 32768"); return 1; } @@ -329,4 +341,8 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT48, INT8); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT48, INT16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT48, INT32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, UINT8, INT8); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, UINT8, INT16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, UINT16, INT16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT8, UINT8); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT16, UINT8); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT16, UINT16); diff --git a/reference_model/src/quant_util.h b/reference_model/src/quant_util.h index 8c1b391..3b7674d 100644 --- a/reference_model/src/quant_util.h +++ b/reference_model/src/quant_util.h @@ -114,7 +114,7 @@ public: static bool is_integer(DType dtype) { if (dtype == DType_INT4 || dtype == DType_INT8 || dtype == DType_UINT8 || dtype == DType_INT16 || - dtype == DType_INT32 || dtype == DType_INT48) + dtype == DType_UINT16 || dtype == DType_INT32 || dtype == DType_INT48) { return true; } diff --git a/reference_model/src/tensor.cc b/reference_model/src/tensor.cc index f2a3a98..36ace48 100644 --- a/reference_model/src/tensor.cc +++ b/reference_model/src/tensor.cc @@ -102,6 +102,7 @@ int TosaReference::Tensor::readFromNpyFile(const char* filename) case DType_INT4: case DType_INT8: case DType_INT16: + case DType_UINT16: i32databuf = (int32_t*)calloc(sizeof(int32_t), elements); ASSERT_MEM(i32databuf); @@ -157,6 +158,7 @@ int TosaReference::Tensor::readFromNpyFile(const char* filename) case DType_INT4: case DType_INT8: case DType_INT16: + case DType_UINT16: if (setTensorValueInt32(elements, i32databuf)) { free(i32databuf); @@ -225,6 +227,7 @@ int TosaReference::Tensor::writeToNpyFile(const char* filename) const case DType_INT4: case DType_INT8: case DType_INT16: + case DType_UINT16: i32databuf = (int32_t*)calloc(sizeof(int32_t), elements); ASSERT_MEM(i32databuf); diff --git a/reference_model/src/tensor.h b/reference_model/src/tensor.h index d857dc8..ede42a9 100644 --- a/reference_model/src/tensor.h +++ b/reference_model/src/tensor.h @@ -656,6 +656,7 @@ public: case DType_INT4: case DType_INT8: case DType_INT16: + case DType_UINT16: switch (rank) { case 0: diff --git a/thirdparty/serialization_lib b/thirdparty/serialization_lib index 9b22517..4102773 160000 --- a/thirdparty/serialization_lib +++ b/thirdparty/serialization_lib @@ -1 +1 @@ -Subproject commit 9b22517ba0cd6f767123583ce56e864f50e9d758 +Subproject commit 4102773d83e236448130b43b1747621ace00160f diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py index b1f8942..a741efb 100644 --- a/verif/generator/tosa_arg_gen.py +++ b/verif/generator/tosa_arg_gen.py @@ -1349,29 +1349,58 @@ class TosaArgGen: arg_list = [] # Enumerate the output types here - for dtype in [DType.UINT8, DType.INT8, DType.INT16, DType.INT32]: + for outDtype in [ + DType.UINT8, + DType.INT8, + DType.INT16, + DType.INT32, + DType.UINT16, + ]: if ( - dtype in [DType.UINT8, DType.INT8] + outDtype in [DType.UINT8, DType.INT8, DType.UINT16] and error_name == ErrorIf.OutputZeroPointNotZero ): continue + if ( + outDtype != DType.UINT16 + and error_name == ErrorIf.U16OutputZeroPointNotValid + ) or ( + inDtype != DType.UINT16 + and error_name == ErrorIf.U16InputZeroPointNotValid + ): + # ErrorIfs only valid with UINT16 + continue if ( inDtype == DType.UINT8 - and dtype != DType.INT8 + and outDtype not in [DType.INT8, DType.INT16] + and error_name != ErrorIf.WrongOutputType + ): + # The only output dtypes for UINT8 are INT8/INT16, skip all others + continue + if ( + inDtype not in [DType.INT8, DType.INT16] + and outDtype == DType.UINT8 + and error_name != ErrorIf.WrongOutputType + ): + # The only input dtypes for UINT8 are INT8/INT16, skip all others + continue + if ( + inDtype == DType.UINT16 + and outDtype != DType.INT16 and error_name != ErrorIf.WrongOutputType ): - # The only output dtype for UINT8 is INT8, skip all other combinations + # The only output dtype for UINT16 is INT16, skip all others continue if ( - inDtype != DType.INT8 - and dtype == DType.UINT8 + inDtype != DType.INT16 + and outDtype == DType.UINT16 and error_name != ErrorIf.WrongOutputType ): - # The only input dtype for UINT8 is INT8, skip all other combinations + # The only input dtype for UINT16 is INT16, skip all others continue if ( error_name == ErrorIf.WrongOutputType - and not TosaErrorIfArgGen.eiRescaleWrongOutputType(inDtype, dtype) + and not TosaErrorIfArgGen.eiRescaleWrongOutputType(inDtype, outDtype) ): continue @@ -1403,12 +1432,12 @@ class TosaArgGen: arg_list.append( ( "out{}_sc{}_dr{}_pc{}".format( - DTypeNames[dtype], + DTypeNames[outDtype], int(scale32), int(double_round), int(per_channel), ), - [dtype, scale32, double_round, per_channel], + [outDtype, scale32, double_round, per_channel], ) ) diff --git a/verif/generator/tosa_error_if.py b/verif/generator/tosa_error_if.py index e7e758f..1900d8a 100644 --- a/verif/generator/tosa_error_if.py +++ b/verif/generator/tosa_error_if.py @@ -68,6 +68,8 @@ class ErrorIf(object): InputListBodyGraphInputMismatch = "InputListBodyGraphInputMismatch" InputListBodyGraphOutputMismatch = "InputListBodyGraphOutputMismatch" CondGraphOutputNotMatchingBool = "CondGraphOutputNotMatchingBool" + U16InputZeroPointNotValid = "U16InputZeroPointNotValid" + U16OutputZeroPointNotValid = "U16OutputZeroPointNotValid" class TosaErrorIfArgGen: @@ -227,14 +229,26 @@ class TosaErrorIfArgGen: if input_dtype == DType.INT8: if output_dtype not in [DType.UINT8, DType.INT8, DType.INT16, DType.INT32]: return True - if input_dtype in [DType.INT16, DType.INT32]: + elif input_dtype == DType.INT16: + if output_dtype not in [ + DType.UINT8, + DType.INT8, + DType.UINT16, + DType.INT16, + DType.INT32, + ]: + return True + elif input_dtype == DType.INT32: if output_dtype not in [DType.INT8, DType.INT16, DType.INT32]: return True elif input_dtype == DType.INT48: if output_dtype not in [DType.INT8, DType.INT16, DType.INT32]: return True elif input_dtype == DType.UINT8: - if output_dtype != DType.INT8: + if output_dtype not in [DType.INT8, DType.INT16]: + return True + elif input_dtype == DType.UINT16: + if output_dtype != DType.INT16: return True return False @@ -418,23 +432,9 @@ class TosaErrorValidator: error_result = True elif op["op"] == Op.RESCALE: - if input_dtype == DType.INT8: - if output_dtype not in [ - DType.UINT8, - DType.INT8, - DType.INT16, - DType.INT32, - ]: - error_result = True - if input_dtype in [DType.INT16, DType.INT32]: - if output_dtype not in [DType.INT8, DType.INT16, DType.INT32]: - error_result = True - elif input_dtype == DType.INT48: - if output_dtype not in [DType.INT8, DType.INT16, DType.INT32]: - error_result = True - elif input_dtype == DType.UINT8: - if output_dtype != DType.INT8: - error_result = True + error_result = TosaErrorIfArgGen.eiRescaleWrongOutputType( + input_dtype, output_dtype + ) elif op["op"] in [Op.FULLY_CONNECTED, Op.MATMUL]: if ( @@ -997,13 +997,26 @@ class TosaErrorValidator: } return info_dict + @staticmethod + def _getZeroPoint(qinfo, index): + """Return zero point value from quantization info. + + Generally input_zp is index 0, output_zp is index 1 + """ + if isinstance(qinfo, tuple): + zero_point = qinfo[index] + else: + # For use: qinfo.ints[0][1] = input_zp, qinfo.ints[1][1] = output_zp + zero_point = qinfo.ints[index][1] + return zero_point + @staticmethod def evInputZeroPointNotZero(check=False, **kwargs): op = kwargs["op"] error_result = False # Quantizable types - qTypes = (DType.INT8, DType.UINT8) + qTypes = (DType.INT8, DType.UINT8, DType.UINT16) # This does not apply to quantizable types inputDtypes = [ @@ -1015,19 +1028,12 @@ class TosaErrorValidator: if check: input_dtype = kwargs["input_dtype"] - if isinstance(kwargs["qinfo"], tuple): - qinfo = kwargs["qinfo"] - input_zero_point = qinfo[0] - else: - # For use: qinfo.ints[0][1] = input_zp, qinfo.ints[1][1] = output_zp - qinfo = kwargs["qinfo"].ints - input_zero_point = qinfo[0][1] - + input_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 0) if op["op"] == Op.MATMUL: - qinfo = kwargs["qinfo"].ints + input2_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 1) for dtype, zp in ( - (kwargs["input_dtype"], qinfo[0][1]), - (kwargs["input2_dtype"], qinfo[1][1]), + (kwargs["input_dtype"], input_zero_point), + (kwargs["input2_dtype"], input2_zero_point), ): if dtype not in qTypes and zp != 0: error_result = True @@ -1059,9 +1065,7 @@ class TosaErrorValidator: if check: weight_dtype = kwargs["weight_dtype"] - # For use: qinfo.ints[0][1] = input_zp, qinfo.ints[1][1] = weight_zp - qinfo = kwargs["qinfo"].ints - weight_zero_point = qinfo[1][1] + weight_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 1) if weight_dtype != DType.INT8 and weight_zero_point != 0: error_result = True @@ -1076,11 +1080,9 @@ class TosaErrorValidator: @staticmethod def evOutputZeroPointNotZero(check=False, **kwargs): op = kwargs["op"] - inputDtypes = op["types"].copy() - if DType.INT8 in inputDtypes: - inputDtypes.remove(DType.INT8) - if DType.UINT8 in inputDtypes: - inputDtypes.remove(DType.UINT8) + inputDtypes = [ + t for t in op["types"] if t not in [DType.INT8, DType.UINT8, DType.UINT16] + ] error_name = ErrorIf.OutputZeroPointNotZero param_reqs = {"rank": None, "dtype": inputDtypes, "shape": None} @@ -1090,18 +1092,13 @@ class TosaErrorValidator: if check: input_dtype = kwargs["input_dtype"] output_dtype = kwargs["output_dtype"] - if isinstance(kwargs["qinfo"], tuple): - qinfo = kwargs["qinfo"] - output_zero_point = qinfo[1] - else: - # For use: qinfo.ints[0][1] = input_zp, qinfo.ints[1][1] = output_zp - qinfo = kwargs["qinfo"].ints - output_zero_point = qinfo[1][1] + output_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 1) if op["op"] == Op.AVG_POOL2D: if input_dtype != DType.INT8 and output_zero_point != 0: error_result = True elif ( - output_dtype not in [DType.INT8, DType.UINT8] and output_zero_point != 0 + output_dtype not in [DType.INT8, DType.UINT8, DType.UINT16] + and output_zero_point != 0 ): error_result = True @@ -1113,6 +1110,53 @@ class TosaErrorValidator: } return info_dict + @staticmethod + def evU16InputZeroPointNotValid(check=False, **kwargs): + error_name = ErrorIf.U16InputZeroPointNotValid + param_reqs = {"rank": None, "dtype": [DType.UINT16], "shape": None} + error_result = False + error_reason = "Input DType is UINT16 and zero point not 0 or 32678" + + if check: + input_dtype = kwargs["input_dtype"] + input_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 0) + error_result = input_dtype == DType.UINT16 and input_zero_point not in [ + 0, + 32768, + ] + + info_dict = { + "error_name": error_name, + "error_result": error_result, + "error_reason": error_reason, + "param_reqs": param_reqs, + } + return info_dict + + @staticmethod + def evU16OutputZeroPointNotValid(check=False, **kwargs): + error_name = ErrorIf.U16OutputZeroPointNotValid + param_reqs = {"rank": None, "dtype": None, "shape": None} + error_result = False + error_reason = "Output DType is UINT16 and zero point not 0 or 32678" + + if check: + output_dtype = kwargs["output_dtype"] + output_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 1) + + error_result = output_dtype == DType.UINT16 and output_zero_point not in [ + 0, + 32768, + ] + + info_dict = { + "error_name": error_name, + "error_result": error_result, + "error_reason": error_reason, + "param_reqs": param_reqs, + } + return info_dict + @staticmethod def evAxisSmallerZero(check=False, **kwargs): error_name = ErrorIf.AxisSmallerZero diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py index 7c2b9de..c9c6d7e 100644 --- a/verif/generator/tosa_test_gen.py +++ b/verif/generator/tosa_test_gen.py @@ -70,6 +70,8 @@ class TosaTestGen: return np.int32(self.rng.integers(low=0, high=256, size=shape)) elif dtype == DType.INT16: return np.int32(self.rng.integers(low=-32768, high=32768, size=shape)) + elif dtype == DType.UINT16: + return np.int32(self.rng.integers(low=0, high=65536, size=shape)) elif dtype == DType.INT32: return np.int32( self.rng.integers(low=-(1 << 31), high=(1 << 31), size=shape) @@ -169,6 +171,8 @@ class TosaTestGen: return "u8" elif t == DType.INT16: return "i16" + elif t == DType.UINT16: + return "u16" elif t == DType.INT32: return "i32" elif t == DType.INT48: @@ -188,6 +192,8 @@ class TosaTestGen: return 8 elif t == DType.INT16: return 16 + elif t == DType.UINT16: + return 16 elif t == DType.INT32: return 32 elif t == DType.INT48: @@ -1575,29 +1581,43 @@ class TosaTestGen: if val.dtype == DType.INT8: input_zp = self.randInt(-128, 128) - in_type_width = in_type_width + 1 + in_type_width += 1 elif val.dtype == DType.UINT8: input_zp = self.randInt(0, 256) - in_type_width = in_type_width + 1 - elif error_name == ErrorIf.InputZeroPointNotZero: + in_type_width += 1 + elif error_name in [ + ErrorIf.InputZeroPointNotZero, + ErrorIf.U16InputZeroPointNotValid, + ]: input_zp = self.randInt(-128, 128) if input_zp == 0: input_zp = input_zp + self.rng.integers(1, 10) - in_type_width = in_type_width + 1 + in_type_width += 1 + elif val.dtype == DType.UINT16: + # Must come after ErrorIf.U16InputZeroPointNotValid check + input_zp = self.rng.choice([0, 32768]) + in_type_width += 1 else: input_zp = 0 if out_dtype == DType.INT8: output_zp = self.randInt(-128, 128) - out_type_width = out_type_width + 1 + out_type_width += 1 elif out_dtype == DType.UINT8: output_zp = self.randInt(0, 256) - out_type_width = out_type_width + 1 - elif error_name == ErrorIf.OutputZeroPointNotZero: + out_type_width += 1 + elif error_name in [ + ErrorIf.OutputZeroPointNotZero, + ErrorIf.U16OutputZeroPointNotValid, + ]: output_zp = self.randInt(-128, 128) if output_zp == 0: output_zp = output_zp + self.rng.integers(1, 10) - out_type_width = out_type_width + 1 + out_type_width += 1 + elif out_dtype == DType.UINT16: + # Must come after ErrorIf.U16OutputZeroPointNotValid check + output_zp = self.rng.choice([0, 32768]) + out_type_width += 1 else: output_zp = 0 @@ -1631,7 +1651,7 @@ class TosaTestGen: # print('multiplier {} shift {} inzp {} outzp {}'.format(multiplier_arr, shift_arr, input_zp, output_zp)) if scale32 and error_name is None: - # Make sure random values are within apply_scale_32 speicification + # Make sure random values are within apply_scale_32 specification # REQUIRES(value >= (-1<<(shift-2)) && value < (1<<(shift-2)) assert val.placeholderFilename values = np.load( @@ -3642,10 +3662,19 @@ class TosaTestGen: TosaTensorValuesGen.tvgDefault, TosaArgGen.agRescale, ), - "types": [DType.UINT8, DType.INT8, DType.INT16, DType.INT32, DType.INT48], + "types": [ + DType.UINT8, + DType.INT8, + DType.INT16, + DType.INT32, + DType.INT48, + DType.UINT16, + ], "error_if_validators": ( TosaErrorValidator.evInputZeroPointNotZero, TosaErrorValidator.evOutputZeroPointNotZero, + TosaErrorValidator.evU16InputZeroPointNotValid, + TosaErrorValidator.evU16OutputZeroPointNotValid, TosaErrorValidator.evScaleTrue, TosaErrorValidator.evScaleNotTrue, TosaErrorValidator.evWrongInputType, diff --git a/verif/generator/tosa_utils.py b/verif/generator/tosa_utils.py index ca115a2..a4ef31a 100644 --- a/verif/generator/tosa_utils.py +++ b/verif/generator/tosa_utils.py @@ -59,9 +59,11 @@ def allDTypes(*, excludes=None): def usableDTypes(*, excludes=None): """Get a set of usable DType values, optionally excluding some values. - Excludes (DType.UNKNOWN, DType.UINT8) in addition to the excludes - specified by the caller, as the serializer lib does not support them. - If you wish to include 'UNKNOWN' or 'UINT8' use allDTypes instead. + Excludes uncommon types (DType.UNKNOWN, DType.UINT16, DType.UINT8) in + addition to the excludes specified by the caller, as the serializer lib + does not support them. + If you wish to include 'UNKNOWN', 'UINT8' or 'UINT16' use allDTypes + instead. Args: excludes: iterable of DType values (e.g. [DType.INT8, DType.BOOL]) @@ -69,7 +71,7 @@ def usableDTypes(*, excludes=None): Returns: A set of DType values """ - omit = {DType.UNKNOWN, DType.UINT8} + omit = {DType.UNKNOWN, DType.UINT8, DType.UINT16} omit.update(excludes if excludes else ()) return allDTypes(excludes=omit) -- cgit v1.2.1