aboutsummaryrefslogtreecommitdiff
path: root/verif/generator/tosa_arg_gen.py
diff options
context:
space:
mode:
Diffstat (limited to 'verif/generator/tosa_arg_gen.py')
-rw-r--r--verif/generator/tosa_arg_gen.py35
1 files changed, 25 insertions, 10 deletions
diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py
index 8d6c8d7..5957a33 100644
--- a/verif/generator/tosa_arg_gen.py
+++ b/verif/generator/tosa_arg_gen.py
@@ -264,6 +264,9 @@ class TosaTensorGen:
return [[]] * num_shapes
shape = testGen.makeShape(rng, rank)
+ # Do not broadcast for some tests
+ if error_name is None and rng.randInt(high=100) < 10:
+ return [shape] * num_shapes
shape_list = []
# Choose any one of the inputs to broadcast
@@ -785,6 +788,10 @@ class TosaTensorValuesGen:
"tensors": {},
}
dg_tens_meta = tens_data["tensors"]
+
+ fp_special_info = {}
+ fp_special_info["start_idx"] = int(rng.randInt())
+
for idx, shape in enumerate(shapeList):
tens_meta = {}
@@ -858,6 +865,8 @@ class TosaTensorValuesGen:
rng.randInt(0, gtu.DTYPE_ATTRIBUTES[dtypeList[idx]]["fullset"])
)
tens_meta["full_range_info"] = info
+ elif dg_type == gtu.DataGenType.FP_SPECIAL:
+ tens_meta["fp_special_info"] = fp_special_info
else:
# TODO - other data gen type
assert False, "TODO: support other data gen types"
@@ -1862,16 +1871,12 @@ class TosaArgGen:
for dg_type in dataGenTypesList:
for arg_str, args_dict in arg_list:
gen_args_dict = args_dict.copy()
+ # Only create one test by default - no sets of tests
+ num_test_sets = 0
+
if dg_type == gtu.DataGenType.PSEUDO_RANDOM:
if error_name is None:
- num_test_sets = (
- args_dict["num_test_sets"]
- if "num_test_sets" in args_dict
- else 0
- )
- else:
- # Add single test for pseudo random
- num_test_sets = 0
+ num_test_sets = args_dict.get("num_test_sets", 0)
elif dg_type == gtu.DataGenType.DOT_PRODUCT:
# Extra tests for each dot product test set
@@ -1900,13 +1905,23 @@ class TosaArgGen:
f"Skipping {opName}{shape_info} as tensor data size too small for full range of values {tensor_size} < {gtu.DTYPE_ATTRIBUTES[dtype]['fullset']}"
)
continue
- # Large enough tensor data size for full range, add a single test
- num_test_sets = 0
+ # Large enough tensor data size for full range, add full test
arg_str = f"{arg_str}_full" if arg_str else "full"
gen_args_dict["tags"] = args_dict.get("tags", []) + [
"non_finite_fp_data"
]
+ elif dg_type == gtu.DataGenType.FP_SPECIAL:
+ shapes_set = {tuple(x) for x in shapeList}
+ if len(shapes_set) != 1:
+ logger.info(
+ f"Changing {opName} input shapes {shapes_set} - broadcasting incompatable with special test"
+ )
+ shapeList = [np.int32(np.broadcast_shapes(*shapeList))] * len(
+ shapeList
+ )
+ arg_str = f"{arg_str}_fs" if arg_str else "fs"
+
gen_args_dict["dg_type"] = dg_type
if num_test_sets > 0:
for s in range(0, num_test_sets):