diff options
Diffstat (limited to 'verif/generator/tosa_arg_gen.py')
-rw-r--r-- | verif/generator/tosa_arg_gen.py | 42 |
1 files changed, 23 insertions, 19 deletions
diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py index 79d4e78..f9499b5 100644 --- a/verif/generator/tosa_arg_gen.py +++ b/verif/generator/tosa_arg_gen.py @@ -1831,10 +1831,10 @@ class TosaArgGen: and "data_gen" in testGen.TOSA_OP_LIST[opName] and gtu.dtypeIsSupportedByCompliance(dtype) ): - if gtu.dtypeIsFloat(dtype): - dataGenTypesList = testGen.TOSA_OP_LIST[opName]["data_gen"]["fp"] - else: - dataGenTypesList = testGen.TOSA_OP_LIST[opName]["data_gen"]["int"] + dataGenTypesList = testGen.TOSA_OP_LIST[opName]["data_gen"].get( + dtype, (gtu.DataGenType.PSEUDO_RANDOM,) + ) + else: # Error test or No data generator types listed - assume random dataGenTypesList = (gtu.DataGenType.PSEUDO_RANDOM,) @@ -1843,16 +1843,7 @@ class TosaArgGen: new_arg_list = [] for dg_type in dataGenTypesList: for arg_str, args_dict in arg_list: - - if dg_type == gtu.DataGenType.FULL_RANGE: - tensor_size = gtu.product(shapeList[0]) - if tensor_size >= gtu.DTYPE_ATTRIBUTES[dtype]["fullset"]: - # Large enough tensor data size for full range, add a single test - num_test_sets = 0 - else: - # Not enough data size for full range of values, revert to random numbers - dg_type = gtu.DataGenType.PSEUDO_RANDOM - + gen_args_dict = args_dict.copy() if dg_type == gtu.DataGenType.PSEUDO_RANDOM: if error_name is None: num_test_sets = ( @@ -1883,18 +1874,31 @@ class TosaArgGen: num_test_sets = testGen.TOSA_MI_DOT_PRODUCT_TEST_SETS + elif dg_type == gtu.DataGenType.FULL_RANGE: + tensor_size = gtu.product(shapeList[0]) + if tensor_size < gtu.DTYPE_ATTRIBUTES[dtype]["fullset"]: + shape_info = " ({})".format(shapeList[0]) + logger.info( + f"Skipping {opName}{shape_info} as tensor data size too small for full range of values {tensor_size} < {gtu.DTYPE_ATTRIBUTES[dtype]['fullset']}" + ) + continue + # Large enough tensor data size for full range, add a single test + num_test_sets = 0 + arg_str = f"{arg_str}_full" if arg_str else "full" + gen_args_dict["tags"] = args_dict.get("tags", []) + [ + "non_finite_fp_data" + ] + + gen_args_dict["dg_type"] = dg_type if num_test_sets > 0: for s in range(0, num_test_sets): set_arg_str = f"{arg_str}_s{s}" if arg_str else f"s{s}" - set_args_dict = args_dict.copy() + set_args_dict = gen_args_dict.copy() set_args_dict["s"] = s - set_args_dict["dg_type"] = dg_type new_arg_list.append((set_arg_str, set_args_dict)) else: # Default is a single test - new_args_dict = args_dict.copy() - new_args_dict["dg_type"] = dg_type - new_arg_list.append((arg_str, new_args_dict)) + new_arg_list.append((arg_str, gen_args_dict)) return new_arg_list |