aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJeremy Johnson <jeremy.johnson@arm.com>2024-04-30 13:56:20 +0100
committerJeremy Johnson <jeremy.johnson@arm.com>2024-05-01 11:08:27 +0100
commit29b02017bf9e0ac381da068e1819632a56cf9966 (patch)
tree3aff90b874a84bf89819078d28b746c16607a349
parent8ded90d7894528f4bb0418213db08e947d3a6def (diff)
downloadreference_model-29b02017bf9e0ac381da068e1819632a56cf9966.tar.gz
Fix MAXIMUM/MINIMUM handling of NaNs and zeroes
Change FP_SPECIAL testing to be used for DOT_PRODUCT cases only. Use default EXACT matching - where zeroes of different signs will be ignored when testing for equality Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com> Change-Id: I0461c42258611cae597f693507075b3ef15fbe19
-rw-r--r--reference_model/src/ops/ewise_binary.cc31
-rw-r--r--verif/generator/tosa_test_gen.py29
2 files changed, 49 insertions, 11 deletions
diff --git a/reference_model/src/ops/ewise_binary.cc b/reference_model/src/ops/ewise_binary.cc
index 8cc1319..d4a9f2f 100644
--- a/reference_model/src/ops/ewise_binary.cc
+++ b/reference_model/src/ops/ewise_binary.cc
@@ -411,6 +411,22 @@ int OpMaximum<Rank, Dtype>::register_fcn()
case TOSA_REF_TYPE_BF16:
case TOSA_REF_TYPE_FP32:
case TOSA_REF_TYPE_FP64:
+ this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType {
+ if (isnan(a))
+ {
+ return a;
+ }
+ else if (isnan(b))
+ {
+ return b;
+ }
+ else
+ {
+ return a > b ? a : b;
+ }
+ };
+ break;
+
case TOSA_REF_TYPE_INT32:
this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a > b ? a : b; };
break;
@@ -430,6 +446,21 @@ int OpMinimum<Rank, Dtype>::register_fcn()
case TOSA_REF_TYPE_BF16:
case TOSA_REF_TYPE_FP32:
case TOSA_REF_TYPE_FP64:
+ this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType {
+ if (isnan(a))
+ {
+ return a;
+ }
+ else if (isnan(b))
+ {
+ return b;
+ }
+ else
+ {
+ return a < b ? a : b;
+ }
+ };
+ break;
case TOSA_REF_TYPE_INT32:
this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a < b ? a : b; };
break;
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py
index 88dd17a..cbac081 100644
--- a/verif/generator/tosa_test_gen.py
+++ b/verif/generator/tosa_test_gen.py
@@ -260,6 +260,11 @@ class TosaTestGen:
# Data type is needed for all FP runs, as refmodel precise mode produces FP64
"data_type": gtu.DTYPE_ATTRIBUTES[outputTensor.dtype]["json"],
}
+
+ op_compliance = op.get("compliance", {})
+ mode = None
+
+ # Check what data generation we have done
if argsDict["dg_type"] == gtu.DataGenType.DOT_PRODUCT:
mode = gtu.ComplianceMode.DOT_PRODUCT
compliance_tens["dot_product_info"] = {
@@ -268,12 +273,10 @@ class TosaTestGen:
int(argsDict["ksb"]) if "ksb" in argsDict else int(argsDict["ks"])
),
}
- elif argsDict["dg_type"] == gtu.DataGenType.FP_SPECIAL:
- mode = gtu.ComplianceMode.FP_SPECIAL
- elif "compliance" in op and "ulp" in op["compliance"]:
+ elif "ulp" in op_compliance:
mode = gtu.ComplianceMode.ULP
compliance_tens["ulp_info"] = {"ulp": op["compliance"]["ulp"]}
- elif "compliance" in op and "relative" in op["compliance"]:
+ elif "relative" in op_compliance:
mode = gtu.ComplianceMode.RELATIVE
compliance_tens["relative_info"] = {
"max": argsDict["max_abs_value"],
@@ -284,26 +287,30 @@ class TosaTestGen:
compliance_tens["reduce_product_info"] = {"n": argsDict["n"]}
elif op["op"] in (Op.EXP, Op.POW, Op.TANH, Op.SIGMOID):
mode = gtu.ComplianceMode.ABS_ERROR
- if "compliance" in op and "abs_error_lower_bound" in op["compliance"]:
+ if "abs_error_lower_bound" in op_compliance:
compliance_tens["abs_error_info"] = {
"lower_bound": op["compliance"]["abs_error_lower_bound"]
}
elif op["op"] in (Op.SIN, Op.COS):
mode = gtu.ComplianceMode.ABS_ERROR
- if "compliance" in op:
- normal_divisor = op["compliance"].get("abs_error_normal_divisor", 1)
- bound_addition = op["compliance"].get("abs_error_bound_addition", 0)
- else:
- normal_divisor = 1
- bound_addition = 0
+ normal_divisor = op_compliance.get("abs_error_normal_divisor", 1)
+ bound_addition = op_compliance.get("abs_error_bound_addition", 0)
compliance_tens["abs_error_info"] = {
"normal_divisor": normal_divisor,
"bound_as_magnitude": True,
"bound_addition": bound_addition,
}
+ elif argsDict["dg_type"] == gtu.DataGenType.FP_SPECIAL:
+ if gtu.ComplianceMode.DOT_PRODUCT in op["data_gen"][inputType]:
+ # Use special mode that only checks for matching inf/nan/zeroes
+ # as normal values need statistical analysis
+ mode = gtu.ComplianceMode.FP_SPECIAL
+ else:
+ mode = gtu.ComplianceMode.EXACT
else:
mode = gtu.ComplianceMode.EXACT
+
compliance_tens["mode"] = gtu.ComplianceMode(mode).name
return compliance_tens