aboutsummaryrefslogtreecommitdiff
path: root/verif/generator/tosa_test_gen.py
diff options
context:
space:
mode:
Diffstat (limited to 'verif/generator/tosa_test_gen.py')
-rw-r--r--verif/generator/tosa_test_gen.py29
1 files changed, 18 insertions, 11 deletions
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py
index 88dd17a..cbac081 100644
--- a/verif/generator/tosa_test_gen.py
+++ b/verif/generator/tosa_test_gen.py
@@ -260,6 +260,11 @@ class TosaTestGen:
# Data type is needed for all FP runs, as refmodel precise mode produces FP64
"data_type": gtu.DTYPE_ATTRIBUTES[outputTensor.dtype]["json"],
}
+
+ op_compliance = op.get("compliance", {})
+ mode = None
+
+ # Check what data generation we have done
if argsDict["dg_type"] == gtu.DataGenType.DOT_PRODUCT:
mode = gtu.ComplianceMode.DOT_PRODUCT
compliance_tens["dot_product_info"] = {
@@ -268,12 +273,10 @@ class TosaTestGen:
int(argsDict["ksb"]) if "ksb" in argsDict else int(argsDict["ks"])
),
}
- elif argsDict["dg_type"] == gtu.DataGenType.FP_SPECIAL:
- mode = gtu.ComplianceMode.FP_SPECIAL
- elif "compliance" in op and "ulp" in op["compliance"]:
+ elif "ulp" in op_compliance:
mode = gtu.ComplianceMode.ULP
compliance_tens["ulp_info"] = {"ulp": op["compliance"]["ulp"]}
- elif "compliance" in op and "relative" in op["compliance"]:
+ elif "relative" in op_compliance:
mode = gtu.ComplianceMode.RELATIVE
compliance_tens["relative_info"] = {
"max": argsDict["max_abs_value"],
@@ -284,26 +287,30 @@ class TosaTestGen:
compliance_tens["reduce_product_info"] = {"n": argsDict["n"]}
elif op["op"] in (Op.EXP, Op.POW, Op.TANH, Op.SIGMOID):
mode = gtu.ComplianceMode.ABS_ERROR
- if "compliance" in op and "abs_error_lower_bound" in op["compliance"]:
+ if "abs_error_lower_bound" in op_compliance:
compliance_tens["abs_error_info"] = {
"lower_bound": op["compliance"]["abs_error_lower_bound"]
}
elif op["op"] in (Op.SIN, Op.COS):
mode = gtu.ComplianceMode.ABS_ERROR
- if "compliance" in op:
- normal_divisor = op["compliance"].get("abs_error_normal_divisor", 1)
- bound_addition = op["compliance"].get("abs_error_bound_addition", 0)
- else:
- normal_divisor = 1
- bound_addition = 0
+ normal_divisor = op_compliance.get("abs_error_normal_divisor", 1)
+ bound_addition = op_compliance.get("abs_error_bound_addition", 0)
compliance_tens["abs_error_info"] = {
"normal_divisor": normal_divisor,
"bound_as_magnitude": True,
"bound_addition": bound_addition,
}
+ elif argsDict["dg_type"] == gtu.DataGenType.FP_SPECIAL:
+ if gtu.ComplianceMode.DOT_PRODUCT in op["data_gen"][inputType]:
+ # Use special mode that only checks for matching inf/nan/zeroes
+ # as normal values need statistical analysis
+ mode = gtu.ComplianceMode.FP_SPECIAL
+ else:
+ mode = gtu.ComplianceMode.EXACT
else:
mode = gtu.ComplianceMode.EXACT
+
compliance_tens["mode"] = gtu.ComplianceMode(mode).name
return compliance_tens