diff options
Diffstat (limited to 'verif/generator/tosa_arg_gen.py')
-rw-r--r-- | verif/generator/tosa_arg_gen.py | 162 |
1 files changed, 98 insertions, 64 deletions
diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py index 32f4341..94b7172 100644 --- a/verif/generator/tosa_arg_gen.py +++ b/verif/generator/tosa_arg_gen.py @@ -628,6 +628,13 @@ class TosaTensorValuesGen: return tens + # Default high value for random numbers + TVG_FLOAT_HIGH_VALUE = { + DType.FP32: (1 << 128) - (1 << (127 - 23)), + DType.FP16: (1 << 16) - (1 << (15 - 10)), + DType.BF16: (1 << 128) - (1 << (127 - 7)), + } + @staticmethod def tvgLazyGenDefault( testGen, opName, dtypeList, shapeList, argsDict, error_name=None @@ -684,10 +691,13 @@ class TosaTensorValuesGen: info = {} # TODO - generate seed for this generator based on test info["rng_seed"] = 42 - info["range"] = [ - str(v) - for v in testGen.getDTypeRange(dtypeList[idx], high_inclusive=True) - ] + if "data_range" in argsDict: + data_range = argsDict["data_range"] + else: + data_range = testGen.getDTypeRange( + dtypeList[idx], high_inclusive=True + ) + info["range"] = [str(v) for v in data_range] tens_meta["pseudo_random_info"] = info elif dg_type == gtu.DataGenType.DOT_PRODUCT: info = {} @@ -950,80 +960,97 @@ class TosaTensorValuesGen: testGen, op, dtypeList, shapeList, testArgs, error_name ) + # Set the data range to the square root of the largest value + TVG_FLOAT_HIGH_VALUE_MUL = { + DType.FP32: math.sqrt(TVG_FLOAT_HIGH_VALUE[DType.FP32]), + DType.FP16: math.sqrt(TVG_FLOAT_HIGH_VALUE[DType.FP16]), + DType.BF16: math.sqrt(TVG_FLOAT_HIGH_VALUE[DType.BF16]), + } + @staticmethod - def tvgMul(testGen, op, dtypeList, shapeList, testArgs, error_name=None): - if error_name is None: + def tvgMul(testGen, opName, dtypeList, shapeList, argsDict, error_name=None): + if error_name is not None or dtypeList[0] in ( + DType.FP16, + DType.BF16, + DType.FP32, + ): + # ERROR_IF or floating point test + if dtypeList[0] in TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_MUL: + data_range = testGen.getDTypeRange(dtypeList[0], high_inclusive=True) + high_val = TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_MUL[dtypeList[0]] + # Set the values to something that won't produce infinity whilst + # respecting the default ranges if less than the high value + argsDict["data_range"] = [ + max(-high_val, data_range[0]), + min(high_val, data_range[1]), + ] + return TosaTensorValuesGen.tvgLazyGenDefault( + testGen, opName, dtypeList, shapeList, argsDict, error_name + ) + else: + # Integer test + op = testGen.TOSA_OP_LIST[opName] pCount, cCount = op["operands"] assert ( pCount == 2 and cCount == 0 ), "Op.MUL must have 2 placeholders, 0 consts" - tens = [] - if dtypeList[0] in (DType.FP16, DType.BF16, DType.FP32): - tens.extend(testGen.buildPlaceholderTensors(shapeList[:], dtypeList[:])) - else: - placeholders = [] - - # Make sure multiply result in int32 range - shift = testArgs[0] - if dtypeList[0] == DType.INT8: - num_bits = 8 - elif dtypeList[0] == DType.INT16: - num_bits = 16 - elif dtypeList[0] == DType.INT32: - num_bits = 32 - elif error_name == ErrorIf.WrongInputType: - num_bits = 8 - else: - raise Exception("OpMul: invalid input dtype") + tens_ser_list = [] - for idx, shape in enumerate(shapeList[:]): - low = -(2 ** (num_bits - 1)) - high = (2 ** (num_bits - 1)) - 1 + # Make sure multiply result in int32 range + shift = argsDict["shift"] + if dtypeList[0] == DType.INT8: + num_bits = 8 + elif dtypeList[0] == DType.INT16: + num_bits = 16 + elif dtypeList[0] == DType.INT32: + num_bits = 32 + elif error_name == ErrorIf.WrongInputType: + num_bits = 8 + else: + raise Exception("OpMul: invalid input dtype") - a_arr = np.int32( - testGen.rng.integers(low=low, high=high, size=shapeList[0]) - ) - b_arr = np.int32( - testGen.rng.integers(low=low, high=high, size=shapeList[1]) - ) + for idx, shape in enumerate(shapeList[:]): + low = -(2 ** (num_bits - 1)) + high = (2 ** (num_bits - 1)) - 1 - i = 0 - while True: + a_arr = np.int32( + testGen.rng.integers(low=low, high=high, size=shapeList[0]) + ) + b_arr = np.int32( + testGen.rng.integers(low=low, high=high, size=shapeList[1]) + ) - a_arr_64 = a_arr.astype(np.int64) - b_arr_64 = b_arr.astype(np.int64) + i = 0 + while True: - if shift > 0: - rounding = 1 << (shift - 1) - result_arr = ((a_arr_64 * b_arr_64) + rounding) >> shift - else: - result_arr = a_arr_64 * b_arr_64 + a_arr_64 = a_arr.astype(np.int64) + b_arr_64 = b_arr.astype(np.int64) - if (result_arr > -(2**31)).all() and ( - result_arr <= ((2**31) - 1) - ).all(): - break - - i = i + 1 - a_arr = a_arr // 2 - b_arr = b_arr // 2 + if shift > 0: + rounding = 1 << (shift - 1) + result_arr = ((a_arr_64 * b_arr_64) + rounding) >> shift + else: + result_arr = a_arr_64 * b_arr_64 - placeholders.append( - testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr) - ) - placeholders.append( - testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr) - ) + if (result_arr > -(2**31)).all() and ( + result_arr <= ((2**31) - 1) + ).all(): + break - tens.extend(placeholders) + i = i + 1 + a_arr = a_arr // 2 + b_arr = b_arr // 2 - return tens - else: - return TosaTensorValuesGen.tvgDefault( - testGen, op, dtypeList, shapeList, testArgs, error_name + tens_ser_list.append( + testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr) + ) + tens_ser_list.append( + testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr) ) + return TosaTensorValuesGen.TVGInfo(tens_ser_list, None) + @staticmethod def tvgConcat(testGen, op, dtypeList, shapeList, testArgs, error_name=None): count = len(shapeList) - testGen.args.num_const_inputs_concat @@ -2076,11 +2103,18 @@ class TosaArgGen: for p in range(testGen.args.num_rand_permutations): shift = testGen.randInt(0, 32) - - arg_list.append(("perm{}_shift{}".format(p, shift), [shift])) + arg_list.append(("perm{}_shift{}".format(p, shift), {"shift": shift})) else: - arg_list.append(("perm0_shift0", [0])) + arg_list.append(("perm0_shift0", {"shift": 0})) + arg_list = TosaArgGen._add_data_generators( + testGen, + opName, + dtype, + arg_list, + error_name, + ) + # Return list of tuples: (arg_str, args_dict) return arg_list @staticmethod |