diff options
author | Jeremy Johnson <jeremy.johnson@arm.com> | 2024-04-30 13:56:20 +0100 |
---|---|---|
committer | Jeremy Johnson <jeremy.johnson@arm.com> | 2024-05-01 11:08:27 +0100 |
commit | 29b02017bf9e0ac381da068e1819632a56cf9966 (patch) | |
tree | 3aff90b874a84bf89819078d28b746c16607a349 /verif | |
parent | 8ded90d7894528f4bb0418213db08e947d3a6def (diff) | |
download | reference_model-29b02017bf9e0ac381da068e1819632a56cf9966.tar.gz |
Fix MAXIMUM/MINIMUM handling of NaNs and zeroes
Change FP_SPECIAL testing to be used for DOT_PRODUCT cases only.
Use default EXACT matching - where zeroes of different signs will
be ignored when testing for equality
Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com>
Change-Id: I0461c42258611cae597f693507075b3ef15fbe19
Diffstat (limited to 'verif')
-rw-r--r-- | verif/generator/tosa_test_gen.py | 29 |
1 files changed, 18 insertions, 11 deletions
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py index 88dd17a..cbac081 100644 --- a/verif/generator/tosa_test_gen.py +++ b/verif/generator/tosa_test_gen.py @@ -260,6 +260,11 @@ class TosaTestGen: # Data type is needed for all FP runs, as refmodel precise mode produces FP64 "data_type": gtu.DTYPE_ATTRIBUTES[outputTensor.dtype]["json"], } + + op_compliance = op.get("compliance", {}) + mode = None + + # Check what data generation we have done if argsDict["dg_type"] == gtu.DataGenType.DOT_PRODUCT: mode = gtu.ComplianceMode.DOT_PRODUCT compliance_tens["dot_product_info"] = { @@ -268,12 +273,10 @@ class TosaTestGen: int(argsDict["ksb"]) if "ksb" in argsDict else int(argsDict["ks"]) ), } - elif argsDict["dg_type"] == gtu.DataGenType.FP_SPECIAL: - mode = gtu.ComplianceMode.FP_SPECIAL - elif "compliance" in op and "ulp" in op["compliance"]: + elif "ulp" in op_compliance: mode = gtu.ComplianceMode.ULP compliance_tens["ulp_info"] = {"ulp": op["compliance"]["ulp"]} - elif "compliance" in op and "relative" in op["compliance"]: + elif "relative" in op_compliance: mode = gtu.ComplianceMode.RELATIVE compliance_tens["relative_info"] = { "max": argsDict["max_abs_value"], @@ -284,26 +287,30 @@ class TosaTestGen: compliance_tens["reduce_product_info"] = {"n": argsDict["n"]} elif op["op"] in (Op.EXP, Op.POW, Op.TANH, Op.SIGMOID): mode = gtu.ComplianceMode.ABS_ERROR - if "compliance" in op and "abs_error_lower_bound" in op["compliance"]: + if "abs_error_lower_bound" in op_compliance: compliance_tens["abs_error_info"] = { "lower_bound": op["compliance"]["abs_error_lower_bound"] } elif op["op"] in (Op.SIN, Op.COS): mode = gtu.ComplianceMode.ABS_ERROR - if "compliance" in op: - normal_divisor = op["compliance"].get("abs_error_normal_divisor", 1) - bound_addition = op["compliance"].get("abs_error_bound_addition", 0) - else: - normal_divisor = 1 - bound_addition = 0 + normal_divisor = op_compliance.get("abs_error_normal_divisor", 1) + bound_addition = op_compliance.get("abs_error_bound_addition", 0) compliance_tens["abs_error_info"] = { "normal_divisor": normal_divisor, "bound_as_magnitude": True, "bound_addition": bound_addition, } + elif argsDict["dg_type"] == gtu.DataGenType.FP_SPECIAL: + if gtu.ComplianceMode.DOT_PRODUCT in op["data_gen"][inputType]: + # Use special mode that only checks for matching inf/nan/zeroes + # as normal values need statistical analysis + mode = gtu.ComplianceMode.FP_SPECIAL + else: + mode = gtu.ComplianceMode.EXACT else: mode = gtu.ComplianceMode.EXACT + compliance_tens["mode"] = gtu.ComplianceMode(mode).name return compliance_tens |