diff options
Diffstat (limited to 'verif/generator/tosa_test_gen.py')
-rw-r--r-- | verif/generator/tosa_test_gen.py | 29 |
1 files changed, 18 insertions, 11 deletions
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py index 88dd17a..cbac081 100644 --- a/verif/generator/tosa_test_gen.py +++ b/verif/generator/tosa_test_gen.py @@ -260,6 +260,11 @@ class TosaTestGen: # Data type is needed for all FP runs, as refmodel precise mode produces FP64 "data_type": gtu.DTYPE_ATTRIBUTES[outputTensor.dtype]["json"], } + + op_compliance = op.get("compliance", {}) + mode = None + + # Check what data generation we have done if argsDict["dg_type"] == gtu.DataGenType.DOT_PRODUCT: mode = gtu.ComplianceMode.DOT_PRODUCT compliance_tens["dot_product_info"] = { @@ -268,12 +273,10 @@ class TosaTestGen: int(argsDict["ksb"]) if "ksb" in argsDict else int(argsDict["ks"]) ), } - elif argsDict["dg_type"] == gtu.DataGenType.FP_SPECIAL: - mode = gtu.ComplianceMode.FP_SPECIAL - elif "compliance" in op and "ulp" in op["compliance"]: + elif "ulp" in op_compliance: mode = gtu.ComplianceMode.ULP compliance_tens["ulp_info"] = {"ulp": op["compliance"]["ulp"]} - elif "compliance" in op and "relative" in op["compliance"]: + elif "relative" in op_compliance: mode = gtu.ComplianceMode.RELATIVE compliance_tens["relative_info"] = { "max": argsDict["max_abs_value"], @@ -284,26 +287,30 @@ class TosaTestGen: compliance_tens["reduce_product_info"] = {"n": argsDict["n"]} elif op["op"] in (Op.EXP, Op.POW, Op.TANH, Op.SIGMOID): mode = gtu.ComplianceMode.ABS_ERROR - if "compliance" in op and "abs_error_lower_bound" in op["compliance"]: + if "abs_error_lower_bound" in op_compliance: compliance_tens["abs_error_info"] = { "lower_bound": op["compliance"]["abs_error_lower_bound"] } elif op["op"] in (Op.SIN, Op.COS): mode = gtu.ComplianceMode.ABS_ERROR - if "compliance" in op: - normal_divisor = op["compliance"].get("abs_error_normal_divisor", 1) - bound_addition = op["compliance"].get("abs_error_bound_addition", 0) - else: - normal_divisor = 1 - bound_addition = 0 + normal_divisor = op_compliance.get("abs_error_normal_divisor", 1) + bound_addition = op_compliance.get("abs_error_bound_addition", 0) compliance_tens["abs_error_info"] = { "normal_divisor": normal_divisor, "bound_as_magnitude": True, "bound_addition": bound_addition, } + elif argsDict["dg_type"] == gtu.DataGenType.FP_SPECIAL: + if gtu.ComplianceMode.DOT_PRODUCT in op["data_gen"][inputType]: + # Use special mode that only checks for matching inf/nan/zeroes + # as normal values need statistical analysis + mode = gtu.ComplianceMode.FP_SPECIAL + else: + mode = gtu.ComplianceMode.EXACT else: mode = gtu.ComplianceMode.EXACT + compliance_tens["mode"] = gtu.ComplianceMode(mode).name return compliance_tens |