aboutsummaryrefslogtreecommitdiff
path: root/verif/generator/tosa_arg_gen.py
diff options
context:
space:
mode:
Diffstat (limited to 'verif/generator/tosa_arg_gen.py')
-rw-r--r--verif/generator/tosa_arg_gen.py162
1 files changed, 98 insertions, 64 deletions
diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py
index 32f4341..94b7172 100644
--- a/verif/generator/tosa_arg_gen.py
+++ b/verif/generator/tosa_arg_gen.py
@@ -628,6 +628,13 @@ class TosaTensorValuesGen:
return tens
+ # Default high value for random numbers
+ TVG_FLOAT_HIGH_VALUE = {
+ DType.FP32: (1 << 128) - (1 << (127 - 23)),
+ DType.FP16: (1 << 16) - (1 << (15 - 10)),
+ DType.BF16: (1 << 128) - (1 << (127 - 7)),
+ }
+
@staticmethod
def tvgLazyGenDefault(
testGen, opName, dtypeList, shapeList, argsDict, error_name=None
@@ -684,10 +691,13 @@ class TosaTensorValuesGen:
info = {}
# TODO - generate seed for this generator based on test
info["rng_seed"] = 42
- info["range"] = [
- str(v)
- for v in testGen.getDTypeRange(dtypeList[idx], high_inclusive=True)
- ]
+ if "data_range" in argsDict:
+ data_range = argsDict["data_range"]
+ else:
+ data_range = testGen.getDTypeRange(
+ dtypeList[idx], high_inclusive=True
+ )
+ info["range"] = [str(v) for v in data_range]
tens_meta["pseudo_random_info"] = info
elif dg_type == gtu.DataGenType.DOT_PRODUCT:
info = {}
@@ -950,80 +960,97 @@ class TosaTensorValuesGen:
testGen, op, dtypeList, shapeList, testArgs, error_name
)
+ # Set the data range to the square root of the largest value
+ TVG_FLOAT_HIGH_VALUE_MUL = {
+ DType.FP32: math.sqrt(TVG_FLOAT_HIGH_VALUE[DType.FP32]),
+ DType.FP16: math.sqrt(TVG_FLOAT_HIGH_VALUE[DType.FP16]),
+ DType.BF16: math.sqrt(TVG_FLOAT_HIGH_VALUE[DType.BF16]),
+ }
+
@staticmethod
- def tvgMul(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
- if error_name is None:
+ def tvgMul(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
+ if error_name is not None or dtypeList[0] in (
+ DType.FP16,
+ DType.BF16,
+ DType.FP32,
+ ):
+ # ERROR_IF or floating point test
+ if dtypeList[0] in TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_MUL:
+ data_range = testGen.getDTypeRange(dtypeList[0], high_inclusive=True)
+ high_val = TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_MUL[dtypeList[0]]
+ # Set the values to something that won't produce infinity whilst
+ # respecting the default ranges if less than the high value
+ argsDict["data_range"] = [
+ max(-high_val, data_range[0]),
+ min(high_val, data_range[1]),
+ ]
+ return TosaTensorValuesGen.tvgLazyGenDefault(
+ testGen, opName, dtypeList, shapeList, argsDict, error_name
+ )
+ else:
+ # Integer test
+ op = testGen.TOSA_OP_LIST[opName]
pCount, cCount = op["operands"]
assert (
pCount == 2 and cCount == 0
), "Op.MUL must have 2 placeholders, 0 consts"
- tens = []
- if dtypeList[0] in (DType.FP16, DType.BF16, DType.FP32):
- tens.extend(testGen.buildPlaceholderTensors(shapeList[:], dtypeList[:]))
- else:
- placeholders = []
-
- # Make sure multiply result in int32 range
- shift = testArgs[0]
- if dtypeList[0] == DType.INT8:
- num_bits = 8
- elif dtypeList[0] == DType.INT16:
- num_bits = 16
- elif dtypeList[0] == DType.INT32:
- num_bits = 32
- elif error_name == ErrorIf.WrongInputType:
- num_bits = 8
- else:
- raise Exception("OpMul: invalid input dtype")
+ tens_ser_list = []
- for idx, shape in enumerate(shapeList[:]):
- low = -(2 ** (num_bits - 1))
- high = (2 ** (num_bits - 1)) - 1
+ # Make sure multiply result in int32 range
+ shift = argsDict["shift"]
+ if dtypeList[0] == DType.INT8:
+ num_bits = 8
+ elif dtypeList[0] == DType.INT16:
+ num_bits = 16
+ elif dtypeList[0] == DType.INT32:
+ num_bits = 32
+ elif error_name == ErrorIf.WrongInputType:
+ num_bits = 8
+ else:
+ raise Exception("OpMul: invalid input dtype")
- a_arr = np.int32(
- testGen.rng.integers(low=low, high=high, size=shapeList[0])
- )
- b_arr = np.int32(
- testGen.rng.integers(low=low, high=high, size=shapeList[1])
- )
+ for idx, shape in enumerate(shapeList[:]):
+ low = -(2 ** (num_bits - 1))
+ high = (2 ** (num_bits - 1)) - 1
- i = 0
- while True:
+ a_arr = np.int32(
+ testGen.rng.integers(low=low, high=high, size=shapeList[0])
+ )
+ b_arr = np.int32(
+ testGen.rng.integers(low=low, high=high, size=shapeList[1])
+ )
- a_arr_64 = a_arr.astype(np.int64)
- b_arr_64 = b_arr.astype(np.int64)
+ i = 0
+ while True:
- if shift > 0:
- rounding = 1 << (shift - 1)
- result_arr = ((a_arr_64 * b_arr_64) + rounding) >> shift
- else:
- result_arr = a_arr_64 * b_arr_64
+ a_arr_64 = a_arr.astype(np.int64)
+ b_arr_64 = b_arr.astype(np.int64)
- if (result_arr > -(2**31)).all() and (
- result_arr <= ((2**31) - 1)
- ).all():
- break
-
- i = i + 1
- a_arr = a_arr // 2
- b_arr = b_arr // 2
+ if shift > 0:
+ rounding = 1 << (shift - 1)
+ result_arr = ((a_arr_64 * b_arr_64) + rounding) >> shift
+ else:
+ result_arr = a_arr_64 * b_arr_64
- placeholders.append(
- testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
- )
- placeholders.append(
- testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr)
- )
+ if (result_arr > -(2**31)).all() and (
+ result_arr <= ((2**31) - 1)
+ ).all():
+ break
- tens.extend(placeholders)
+ i = i + 1
+ a_arr = a_arr // 2
+ b_arr = b_arr // 2
- return tens
- else:
- return TosaTensorValuesGen.tvgDefault(
- testGen, op, dtypeList, shapeList, testArgs, error_name
+ tens_ser_list.append(
+ testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
+ )
+ tens_ser_list.append(
+ testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr)
)
+ return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
+
@staticmethod
def tvgConcat(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
count = len(shapeList) - testGen.args.num_const_inputs_concat
@@ -2076,11 +2103,18 @@ class TosaArgGen:
for p in range(testGen.args.num_rand_permutations):
shift = testGen.randInt(0, 32)
-
- arg_list.append(("perm{}_shift{}".format(p, shift), [shift]))
+ arg_list.append(("perm{}_shift{}".format(p, shift), {"shift": shift}))
else:
- arg_list.append(("perm0_shift0", [0]))
+ arg_list.append(("perm0_shift0", {"shift": 0}))
+ arg_list = TosaArgGen._add_data_generators(
+ testGen,
+ opName,
+ dtype,
+ arg_list,
+ error_name,
+ )
+ # Return list of tuples: (arg_str, args_dict)
return arg_list
@staticmethod