From 4a2051146f498cb9ec35d7213720540c5c3e81e2 Mon Sep 17 00:00:00 2001 From: evacha01 Date: Fri, 8 Mar 2024 16:39:24 +0000 Subject: SPECIAL data gen mode for FP16 and FP32 Signed-off-by: evacha01 Change-Id: I5a9a1c63345bd83ca04bc6c2a99b0ef3612971ee --- verif/generator/tosa_arg_gen.py | 35 +++++++++++++++++++++++++---------- 1 file changed, 25 insertions(+), 10 deletions(-) (limited to 'verif/generator/tosa_arg_gen.py') diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py index 8d6c8d7..5957a33 100644 --- a/verif/generator/tosa_arg_gen.py +++ b/verif/generator/tosa_arg_gen.py @@ -264,6 +264,9 @@ class TosaTensorGen: return [[]] * num_shapes shape = testGen.makeShape(rng, rank) + # Do not broadcast for some tests + if error_name is None and rng.randInt(high=100) < 10: + return [shape] * num_shapes shape_list = [] # Choose any one of the inputs to broadcast @@ -785,6 +788,10 @@ class TosaTensorValuesGen: "tensors": {}, } dg_tens_meta = tens_data["tensors"] + + fp_special_info = {} + fp_special_info["start_idx"] = int(rng.randInt()) + for idx, shape in enumerate(shapeList): tens_meta = {} @@ -858,6 +865,8 @@ class TosaTensorValuesGen: rng.randInt(0, gtu.DTYPE_ATTRIBUTES[dtypeList[idx]]["fullset"]) ) tens_meta["full_range_info"] = info + elif dg_type == gtu.DataGenType.FP_SPECIAL: + tens_meta["fp_special_info"] = fp_special_info else: # TODO - other data gen type assert False, "TODO: support other data gen types" @@ -1862,16 +1871,12 @@ class TosaArgGen: for dg_type in dataGenTypesList: for arg_str, args_dict in arg_list: gen_args_dict = args_dict.copy() + # Only create one test by default - no sets of tests + num_test_sets = 0 + if dg_type == gtu.DataGenType.PSEUDO_RANDOM: if error_name is None: - num_test_sets = ( - args_dict["num_test_sets"] - if "num_test_sets" in args_dict - else 0 - ) - else: - # Add single test for pseudo random - num_test_sets = 0 + num_test_sets = args_dict.get("num_test_sets", 0) elif dg_type == gtu.DataGenType.DOT_PRODUCT: # Extra tests for each dot product test set @@ -1900,13 +1905,23 @@ class TosaArgGen: f"Skipping {opName}{shape_info} as tensor data size too small for full range of values {tensor_size} < {gtu.DTYPE_ATTRIBUTES[dtype]['fullset']}" ) continue - # Large enough tensor data size for full range, add a single test - num_test_sets = 0 + # Large enough tensor data size for full range, add full test arg_str = f"{arg_str}_full" if arg_str else "full" gen_args_dict["tags"] = args_dict.get("tags", []) + [ "non_finite_fp_data" ] + elif dg_type == gtu.DataGenType.FP_SPECIAL: + shapes_set = {tuple(x) for x in shapeList} + if len(shapes_set) != 1: + logger.info( + f"Changing {opName} input shapes {shapes_set} - broadcasting incompatable with special test" + ) + shapeList = [np.int32(np.broadcast_shapes(*shapeList))] * len( + shapeList + ) + arg_str = f"{arg_str}_fs" if arg_str else "fs" + gen_args_dict["dg_type"] = dg_type if num_test_sets > 0: for s in range(0, num_test_sets): -- cgit v1.2.1