aboutsummaryrefslogtreecommitdiff
path: root/verif/generator
diff options
context:
space:
mode:
Diffstat (limited to 'verif/generator')
-rw-r--r--verif/generator/tosa_arg_gen.py64
-rw-r--r--verif/generator/tosa_test_gen.py60
-rw-r--r--verif/generator/tosa_verif_build_tests.py7
3 files changed, 98 insertions, 33 deletions
diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py
index 41b0936..6de771d 100644
--- a/verif/generator/tosa_arg_gen.py
+++ b/verif/generator/tosa_arg_gen.py
@@ -1506,6 +1506,21 @@ class TosaTensorValuesGen:
)
@staticmethod
+ def tvgTable(testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None):
+ # Use supported type for table data on ERROR_IF
+ dtypeList[1] = (
+ dtypeList[0] if error_name != ErrorIf.WrongInputType else DType.INT8
+ )
+
+ table_values = argsDict["table"]
+ shapeList[1] = [len(table_values)]
+ argsDict["fixed_data"] = [None, table_values]
+
+ return TosaTensorValuesGen.tvgLazyGenDefault(
+ testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
+ )
+
+ @staticmethod
def tvgResize(
testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
):
@@ -1875,11 +1890,31 @@ class TosaArgGen:
dataGenTypesList = testGen.TOSA_OP_LIST[opName]["data_gen"].get(
dtype, (gtu.DataGenType.PSEUDO_RANDOM,)
)
-
else:
# Error test or No data generator types listed - assume random
dataGenTypesList = (gtu.DataGenType.PSEUDO_RANDOM,)
+ def check_min_size(opName, shape, min_size, reason):
+ # Check tensor size meets minimum requirements
+ tensor_size = gtu.product(shape)
+ if tensor_size < min_size:
+ shape_info = " ({})".format(shape)
+ logger.info(
+ f"Skipping {opName}{shape_info} as tensor data size too small for {reason} values {tensor_size} < {min_size}"
+ )
+ return False
+ return True
+
+ def update_data_gen(testGen, opName, dtype, dgt_remove):
+ # Remove special data generator to limit number of tests
+ assert "data_gen" in testGen.TOSA_OP_LIST[opName]
+ assert dtype in testGen.TOSA_OP_LIST[opName]["data_gen"]
+ data_gen = testGen.TOSA_OP_LIST[opName]["data_gen"].copy()
+ dgt_list = list(data_gen[dtype])
+ dgt_list.remove(dgt_remove)
+ data_gen[dtype] = tuple(dgt_list)
+ testGen.TOSA_OP_LIST[opName]["data_gen"] = data_gen
+
# Expand arg list with other data generator types
new_arg_list = []
for dg_type in dataGenTypesList:
@@ -1912,20 +1947,33 @@ class TosaArgGen:
num_test_sets = testGen.TOSA_MI_DOT_PRODUCT_TEST_SETS
elif dg_type == gtu.DataGenType.FULL_RANGE:
- tensor_size = gtu.product(shapeList[0])
- if tensor_size < gtu.DTYPE_ATTRIBUTES[dtype]["fullset"]:
- shape_info = " ({})".format(shapeList[0])
- logger.info(
- f"Skipping {opName}{shape_info} as tensor data size too small for full range of values {tensor_size} < {gtu.DTYPE_ATTRIBUTES[dtype]['fullset']}"
- )
+ if testGen.args.no_special_tests:
+ continue
+ if not check_min_size(
+ opName,
+ shapeList[0],
+ gtu.DTYPE_ATTRIBUTES[dtype]["fullset"],
+ "full range of",
+ ):
continue
# Large enough tensor data size for full range, add full test
arg_str = f"{arg_str}_full" if arg_str else "full"
gen_args_dict["tags"] = args_dict.get("tags", []) + [
"non_finite_fp_data"
]
+ # Create one special test per data type
+ update_data_gen(testGen, opName, dtype, dg_type)
elif dg_type == gtu.DataGenType.FP_SPECIAL:
+ if testGen.args.no_special_tests:
+ continue
+ if not check_min_size(
+ opName,
+ shapeList[0],
+ testGen.TOSA_MI_FP_SPECIAL_MIN_SIZE,
+ "FP special",
+ ):
+ continue
shapes_set = {tuple(x) for x in shapeList}
if len(shapes_set) != 1:
logger.info(
@@ -1938,6 +1986,8 @@ class TosaArgGen:
gen_args_dict["tags"] = args_dict.get("tags", []) + [
"non_finite_fp_data"
]
+ # Create one special test per data type
+ update_data_gen(testGen, opName, dtype, dg_type)
gen_args_dict["dg_type"] = dg_type
if num_test_sets > 0:
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py
index 88dd17a..7f76429 100644
--- a/verif/generator/tosa_test_gen.py
+++ b/verif/generator/tosa_test_gen.py
@@ -44,6 +44,9 @@ class TosaTestGen:
TOSA_MI_DOT_PRODUCT_TEST_SETS = 6
TOSA_MI_DOT_PRODUCT_MIN = 1000
+ # Minimum tensor size for the FP special tests
+ TOSA_MI_FP_SPECIAL_MIN_SIZE = 20
+
def __init__(self, args):
self.args = args
self.basePath = args.output_dir
@@ -260,6 +263,11 @@ class TosaTestGen:
# Data type is needed for all FP runs, as refmodel precise mode produces FP64
"data_type": gtu.DTYPE_ATTRIBUTES[outputTensor.dtype]["json"],
}
+
+ op_compliance = op.get("compliance", {})
+ mode = None
+
+ # Check what data generation we have done
if argsDict["dg_type"] == gtu.DataGenType.DOT_PRODUCT:
mode = gtu.ComplianceMode.DOT_PRODUCT
compliance_tens["dot_product_info"] = {
@@ -268,12 +276,10 @@ class TosaTestGen:
int(argsDict["ksb"]) if "ksb" in argsDict else int(argsDict["ks"])
),
}
- elif argsDict["dg_type"] == gtu.DataGenType.FP_SPECIAL:
- mode = gtu.ComplianceMode.FP_SPECIAL
- elif "compliance" in op and "ulp" in op["compliance"]:
+ elif "ulp" in op_compliance:
mode = gtu.ComplianceMode.ULP
compliance_tens["ulp_info"] = {"ulp": op["compliance"]["ulp"]}
- elif "compliance" in op and "relative" in op["compliance"]:
+ elif "relative" in op_compliance:
mode = gtu.ComplianceMode.RELATIVE
compliance_tens["relative_info"] = {
"max": argsDict["max_abs_value"],
@@ -284,26 +290,30 @@ class TosaTestGen:
compliance_tens["reduce_product_info"] = {"n": argsDict["n"]}
elif op["op"] in (Op.EXP, Op.POW, Op.TANH, Op.SIGMOID):
mode = gtu.ComplianceMode.ABS_ERROR
- if "compliance" in op and "abs_error_lower_bound" in op["compliance"]:
+ if "abs_error_lower_bound" in op_compliance:
compliance_tens["abs_error_info"] = {
"lower_bound": op["compliance"]["abs_error_lower_bound"]
}
elif op["op"] in (Op.SIN, Op.COS):
mode = gtu.ComplianceMode.ABS_ERROR
- if "compliance" in op:
- normal_divisor = op["compliance"].get("abs_error_normal_divisor", 1)
- bound_addition = op["compliance"].get("abs_error_bound_addition", 0)
- else:
- normal_divisor = 1
- bound_addition = 0
+ normal_divisor = op_compliance.get("abs_error_normal_divisor", 1)
+ bound_addition = op_compliance.get("abs_error_bound_addition", 0)
compliance_tens["abs_error_info"] = {
"normal_divisor": normal_divisor,
"bound_as_magnitude": True,
"bound_addition": bound_addition,
}
+ elif argsDict["dg_type"] == gtu.DataGenType.FP_SPECIAL:
+ if gtu.ComplianceMode.DOT_PRODUCT in op["data_gen"][inputType]:
+ # Use special mode that only checks for matching inf/nan/zeroes
+ # as normal values need statistical analysis
+ mode = gtu.ComplianceMode.FP_SPECIAL
+ else:
+ mode = gtu.ComplianceMode.EXACT
else:
mode = gtu.ComplianceMode.EXACT
+
compliance_tens["mode"] = gtu.ComplianceMode(mode).name
return compliance_tens
@@ -569,16 +579,13 @@ class TosaTestGen:
error_name=None,
qinfo=None,
):
- assert len(inputs) == 1
+ assert len(inputs) == 2
a = inputs[0]
- table = args_dict["table"]
+ table = inputs[1]
result_tensor = OutputShaper.tableOp(self.ser, rng, a, error_name)
- attr = ts.TosaSerializerAttribute()
- attr.TableAttribute(table)
-
# Invalidate Input/Output list for error if checks.
- input_list = [a.name]
+ input_list = [a.name, table.name]
output_list = [result_tensor.name]
pCount, cCount = op["operands"]
num_operands = pCount + cCount
@@ -601,7 +608,7 @@ class TosaTestGen:
):
return None
- self.ser.addOperator(op["op"], input_list, output_list, attr)
+ self.ser.addOperator(op["op"], input_list, output_list)
compliance = self.tensorComplianceMetaData(
op, a.dtype, args_dict, result_tensor, error_name
@@ -3360,6 +3367,7 @@ class TosaTestGen:
}
EW_UNARY_DATAGEN = {
DType.FP16: (gtu.DataGenType.PSEUDO_RANDOM, gtu.DataGenType.FULL_RANGE),
+ DType.FP32: (gtu.DataGenType.PSEUDO_RANDOM, gtu.DataGenType.FP_SPECIAL),
}
PR_FS_DATAGEN = {
DType.FP16: (gtu.DataGenType.PSEUDO_RANDOM, gtu.DataGenType.FP_SPECIAL),
@@ -3647,7 +3655,7 @@ class TosaTestGen:
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
),
- "data_gen": PSEUDO_RANDOM_DATAGEN,
+ "data_gen": PR_FS_DATAGEN,
},
"sigmoid": {
"op": Op.SIGMOID,
@@ -3665,7 +3673,7 @@ class TosaTestGen:
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
),
- "data_gen": PSEUDO_RANDOM_DATAGEN,
+ "data_gen": PR_FS_DATAGEN,
},
"tanh": {
"op": Op.TANH,
@@ -3683,7 +3691,7 @@ class TosaTestGen:
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
),
- "data_gen": PSEUDO_RANDOM_DATAGEN,
+ "data_gen": PR_FS_DATAGEN,
"compliance": {
"abs_error_lower_bound": 0.5,
},
@@ -3704,7 +3712,7 @@ class TosaTestGen:
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
),
- "data_gen": PSEUDO_RANDOM_DATAGEN,
+ "data_gen": PR_FS_DATAGEN,
"compliance": {"ulp": 5},
},
# Elementwise Binary Operators
@@ -4042,11 +4050,11 @@ class TosaTestGen:
# Use the automatic generation functions to create the input array
# but create the table tensor in the build function, as it may be
# a different type from the input
- "operands": (1, 0),
+ "operands": (2, 0),
"build_fcn": (
build_table,
TosaTensorGen.tgBasic,
- TosaTensorValuesGen.tvgLazyGenDefault,
+ TosaTensorValuesGen.tvgTable,
TosaArgGen.agTable,
),
"types": [DType.INT8, DType.INT16],
@@ -4145,7 +4153,7 @@ class TosaTestGen:
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
),
- "data_gen": PSEUDO_RANDOM_DATAGEN,
+ "data_gen": EW_UNARY_DATAGEN,
"compliance": {
"abs_error_normal_divisor": 2,
"abs_error_bound_addition": 1,
@@ -4299,7 +4307,7 @@ class TosaTestGen:
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
),
- "data_gen": PSEUDO_RANDOM_DATAGEN,
+ "data_gen": EW_UNARY_DATAGEN,
"compliance": {"abs_error_normal_divisor": 2},
},
# Elementwise Ternary operators
diff --git a/verif/generator/tosa_verif_build_tests.py b/verif/generator/tosa_verif_build_tests.py
index a46b061..dbd46b5 100644
--- a/verif/generator/tosa_verif_build_tests.py
+++ b/verif/generator/tosa_verif_build_tests.py
@@ -287,6 +287,13 @@ def parseArgs(argv):
help="enables test selection, this is the selection criteria to use from the selection config",
)
+ filter_group.add_argument(
+ "--no-special-tests",
+ dest="no_special_tests",
+ action="store_true",
+ help="Do not produce special 'full range' or 'FP special' tests",
+ )
+
parser.add_argument(
"--list-tests",
dest="list_tests",