From 29b02017bf9e0ac381da068e1819632a56cf9966 Mon Sep 17 00:00:00 2001 From: Jeremy Johnson Date: Tue, 30 Apr 2024 13:56:20 +0100 Subject: Fix MAXIMUM/MINIMUM handling of NaNs and zeroes Change FP_SPECIAL testing to be used for DOT_PRODUCT cases only. Use default EXACT matching - where zeroes of different signs will be ignored when testing for equality Signed-off-by: Jeremy Johnson Change-Id: I0461c42258611cae597f693507075b3ef15fbe19 --- verif/generator/tosa_test_gen.py | 29 ++++++++++++++++++----------- 1 file changed, 18 insertions(+), 11 deletions(-) (limited to 'verif') diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py index 88dd17a..cbac081 100644 --- a/verif/generator/tosa_test_gen.py +++ b/verif/generator/tosa_test_gen.py @@ -260,6 +260,11 @@ class TosaTestGen: # Data type is needed for all FP runs, as refmodel precise mode produces FP64 "data_type": gtu.DTYPE_ATTRIBUTES[outputTensor.dtype]["json"], } + + op_compliance = op.get("compliance", {}) + mode = None + + # Check what data generation we have done if argsDict["dg_type"] == gtu.DataGenType.DOT_PRODUCT: mode = gtu.ComplianceMode.DOT_PRODUCT compliance_tens["dot_product_info"] = { @@ -268,12 +273,10 @@ class TosaTestGen: int(argsDict["ksb"]) if "ksb" in argsDict else int(argsDict["ks"]) ), } - elif argsDict["dg_type"] == gtu.DataGenType.FP_SPECIAL: - mode = gtu.ComplianceMode.FP_SPECIAL - elif "compliance" in op and "ulp" in op["compliance"]: + elif "ulp" in op_compliance: mode = gtu.ComplianceMode.ULP compliance_tens["ulp_info"] = {"ulp": op["compliance"]["ulp"]} - elif "compliance" in op and "relative" in op["compliance"]: + elif "relative" in op_compliance: mode = gtu.ComplianceMode.RELATIVE compliance_tens["relative_info"] = { "max": argsDict["max_abs_value"], @@ -284,26 +287,30 @@ class TosaTestGen: compliance_tens["reduce_product_info"] = {"n": argsDict["n"]} elif op["op"] in (Op.EXP, Op.POW, Op.TANH, Op.SIGMOID): mode = gtu.ComplianceMode.ABS_ERROR - if "compliance" in op and "abs_error_lower_bound" in op["compliance"]: + if "abs_error_lower_bound" in op_compliance: compliance_tens["abs_error_info"] = { "lower_bound": op["compliance"]["abs_error_lower_bound"] } elif op["op"] in (Op.SIN, Op.COS): mode = gtu.ComplianceMode.ABS_ERROR - if "compliance" in op: - normal_divisor = op["compliance"].get("abs_error_normal_divisor", 1) - bound_addition = op["compliance"].get("abs_error_bound_addition", 0) - else: - normal_divisor = 1 - bound_addition = 0 + normal_divisor = op_compliance.get("abs_error_normal_divisor", 1) + bound_addition = op_compliance.get("abs_error_bound_addition", 0) compliance_tens["abs_error_info"] = { "normal_divisor": normal_divisor, "bound_as_magnitude": True, "bound_addition": bound_addition, } + elif argsDict["dg_type"] == gtu.DataGenType.FP_SPECIAL: + if gtu.ComplianceMode.DOT_PRODUCT in op["data_gen"][inputType]: + # Use special mode that only checks for matching inf/nan/zeroes + # as normal values need statistical analysis + mode = gtu.ComplianceMode.FP_SPECIAL + else: + mode = gtu.ComplianceMode.EXACT else: mode = gtu.ComplianceMode.EXACT + compliance_tens["mode"] = gtu.ComplianceMode(mode).name return compliance_tens -- cgit v1.2.1