aboutsummaryrefslogtreecommitdiff
path: root/verif/generator/tosa_arg_gen.py
diff options
context:
space:
mode:
Diffstat (limited to 'verif/generator/tosa_arg_gen.py')
-rw-r--r--verif/generator/tosa_arg_gen.py56
1 files changed, 49 insertions, 7 deletions
diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py
index cbfffae..c596645 100644
--- a/verif/generator/tosa_arg_gen.py
+++ b/verif/generator/tosa_arg_gen.py
@@ -828,6 +828,12 @@ class TosaTensorValuesGen:
if "axis" in argsDict:
info["axis"] = int(argsDict["axis"])
tens_meta["dot_product_info"] = info
+ elif dg_type == gtu.DataGenType.FULL_RANGE:
+ info = {}
+ info["start_val"] = int(
+ testGen.randInt(0, gtu.DTYPE_ATTRIBUTES[dtypeList[idx]]["fullset"])
+ )
+ tens_meta["full_range_info"] = info
else:
# TODO - other data gen type
assert False, "TODO: support other data gen types"
@@ -1795,7 +1801,7 @@ class TosaArgGen:
pass
@staticmethod
- def _add_data_generators(testGen, opName, dtype, arg_list, error_name):
+ def _add_data_generators(testGen, opName, shapeList, dtype, arg_list, error_name):
"""Add extra tests for each type of data generator for this op."""
if (
error_name is None
@@ -1820,7 +1826,16 @@ class TosaArgGen:
new_arg_list = []
for dg_type in dataGenTypesList:
for arg_str, args_dict in arg_list:
- args_dict["dg_type"] = dg_type
+
+ if dg_type == gtu.DataGenType.FULL_RANGE:
+ tensor_size = gtu.product(shapeList[0])
+ if tensor_size >= gtu.DTYPE_ATTRIBUTES[dtype]["fullset"]:
+ # Large enough tensor data size for full range, add a single test
+ num_test_sets = 0
+ else:
+ # Not enough data size for full range of values, revert to random numbers
+ dg_type = gtu.DataGenType.PSEUDO_RANDOM
+
if dg_type == gtu.DataGenType.PSEUDO_RANDOM:
if error_name is None:
num_test_sets = (
@@ -1829,6 +1844,7 @@ class TosaArgGen:
else 0
)
else:
+ # Add single test for pseudo random
num_test_sets = 0
elif dg_type == gtu.DataGenType.DOT_PRODUCT:
@@ -1852,13 +1868,16 @@ class TosaArgGen:
if num_test_sets > 0:
for s in range(0, num_test_sets):
- new_arg_str = f"{arg_str}_s{s}" if arg_str else f"s{s}"
- new_args_dict = args_dict.copy()
- new_args_dict["s"] = s
- new_arg_list.append((new_arg_str, new_args_dict))
+ set_arg_str = f"{arg_str}_s{s}" if arg_str else f"s{s}"
+ set_args_dict = args_dict.copy()
+ set_args_dict["s"] = s
+ set_args_dict["dg_type"] = dg_type
+ new_arg_list.append((set_arg_str, set_args_dict))
else:
# Default is a single test
- new_arg_list.append((arg_str, args_dict))
+ new_args_dict = args_dict.copy()
+ new_args_dict["dg_type"] = dg_type
+ new_arg_list.append((arg_str, new_args_dict))
return new_arg_list
@@ -1869,6 +1888,7 @@ class TosaArgGen:
arg_list = TosaArgGen._add_data_generators(
testGen,
opName,
+ shapeList,
dtype,
[("", {})],
error_name,
@@ -1883,6 +1903,7 @@ class TosaArgGen:
arg_list = TosaArgGen._add_data_generators(
testGen,
opName,
+ shapeList,
dtype,
[("", {"num_test_sets": 3})],
error_name,
@@ -1921,6 +1942,7 @@ class TosaArgGen:
arg_list = TosaArgGen._add_data_generators(
testGen,
opName,
+ shapeList,
dtype,
arg_list,
error_name,
@@ -2160,6 +2182,7 @@ class TosaArgGen:
arg_list = TosaArgGen._add_data_generators(
testGen,
opName,
+ shapeList,
dtypes[0],
arg_list,
error_name,
@@ -2194,6 +2217,7 @@ class TosaArgGen:
arg_list = TosaArgGen._add_data_generators(
testGen,
opName,
+ shapeList,
input_dtype,
arg_list,
error_name,
@@ -2246,6 +2270,7 @@ class TosaArgGen:
arg_list = TosaArgGen._add_data_generators(
testGen,
opName,
+ shapeList,
dtype,
arg_list,
error_name,
@@ -2402,6 +2427,7 @@ class TosaArgGen:
arg_list = TosaArgGen._add_data_generators(
testGen,
opName,
+ shapeList,
dtypes[0],
arg_list,
error_name,
@@ -2482,6 +2508,7 @@ class TosaArgGen:
arg_list = TosaArgGen._add_data_generators(
testGen,
opName,
+ shapeList,
dtype,
arg_list,
error_name,
@@ -2685,6 +2712,7 @@ class TosaArgGen:
arg_list = TosaArgGen._add_data_generators(
testGen,
opName,
+ shapeList,
dtype,
arg_list,
error_name,
@@ -2774,6 +2802,7 @@ class TosaArgGen:
arg_list = TosaArgGen._add_data_generators(
testGen,
opName,
+ shapeList,
dtype,
arg_list,
error_name,
@@ -2925,6 +2954,7 @@ class TosaArgGen:
arg_list = TosaArgGen._add_data_generators(
testGen,
opName,
+ shapeList,
inDtype,
arg_list,
error_name,
@@ -2947,6 +2977,7 @@ class TosaArgGen:
arg_list = TosaArgGen._add_data_generators(
testGen,
opName,
+ shapeList,
dtype,
arg_list,
error_name,
@@ -2967,6 +2998,7 @@ class TosaArgGen:
arg_list = TosaArgGen._add_data_generators(
testGen,
opName,
+ shapeList,
dtype,
arg_list,
error_name,
@@ -2994,6 +3026,7 @@ class TosaArgGen:
arg_list = TosaArgGen._add_data_generators(
testGen,
opName,
+ shapeList,
dtype,
arg_list,
error_name,
@@ -3019,6 +3052,7 @@ class TosaArgGen:
arg_list = TosaArgGen._add_data_generators(
testGen,
opName,
+ shapeList,
dtype,
arg_list,
error_name,
@@ -3091,6 +3125,7 @@ class TosaArgGen:
arg_list = TosaArgGen._add_data_generators(
testGen,
opName,
+ shapeList,
dtype,
arg_list,
error_name,
@@ -3137,6 +3172,7 @@ class TosaArgGen:
arg_list = TosaArgGen._add_data_generators(
testGen,
opName,
+ shapeList,
dtype,
arg_list,
error_name,
@@ -3179,6 +3215,7 @@ class TosaArgGen:
arg_list = TosaArgGen._add_data_generators(
testGen,
opName,
+ shapeList,
dtype,
arg_list,
error_name,
@@ -3214,6 +3251,7 @@ class TosaArgGen:
arg_list = TosaArgGen._add_data_generators(
testGen,
opName,
+ shapeList,
dtype,
arg_list,
error_name,
@@ -3547,6 +3585,7 @@ class TosaArgGen:
arg_list = TosaArgGen._add_data_generators(
testGen,
opName,
+ shapeList,
dtype,
arg_list,
error_name,
@@ -3586,6 +3625,7 @@ class TosaArgGen:
arg_list = TosaArgGen._add_data_generators(
testGen,
opName,
+ shapeList,
dtype,
arg_list,
error_name,
@@ -3606,6 +3646,7 @@ class TosaArgGen:
arg_list = TosaArgGen._add_data_generators(
testGen,
opName,
+ shapeList,
dtype,
arg_list,
error_name,
@@ -3624,6 +3665,7 @@ class TosaArgGen:
arg_list = TosaArgGen._add_data_generators(
testGen,
opName,
+ shapeList,
dtype,
arg_list,
error_name,