aboutsummaryrefslogtreecommitdiff
path: root/verif/checker/tosa_result_checker.py
diff options
context:
space:
mode:
authorJames Ward <james.ward@arm.com>2022-10-19 12:20:31 +0100
committerJeremy Johnson <jeremy.johnson@arm.com>2022-11-09 12:19:51 +0000
commit24dbc420aae556649f50e645bd94489dab2cc75a (patch)
tree490345da43e9c5bae0f450ba05ffe85874077e0a /verif/checker/tosa_result_checker.py
parent3b0544c1e7463295c49a48a162ebb9a546326829 (diff)
downloadreference_model-24dbc420aae556649f50e645bd94489dab2cc75a.tar.gz
Add BF16 support to reference model
* Upgrade Eigen to 3.4.0 (for bfloat16 support) and add work- arounds for reduce.any() and reduce.all() bugs (introduced between 3.3.7 and 3.4.0) * Truncation to bfloat16 now performed in eval() methods Signed-off-by: James Ward <james.ward@arm.com> Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com> Change-Id: If5f5c988d76d3d30790acf3b97081726b89205fe
Diffstat (limited to 'verif/checker/tosa_result_checker.py')
-rw-r--r--verif/checker/tosa_result_checker.py22
1 files changed, 21 insertions, 1 deletions
diff --git a/verif/checker/tosa_result_checker.py b/verif/checker/tosa_result_checker.py
index 8ae3218..b7a76b6 100644
--- a/verif/checker/tosa_result_checker.py
+++ b/verif/checker/tosa_result_checker.py
@@ -9,6 +9,7 @@ from enum import unique
from pathlib import Path
import numpy as np
+from generator.tosa_utils import float32_is_valid_bfloat16
##################################
color_printing = True
@@ -63,7 +64,12 @@ TestResultErrorStr = [
def test_check(
- reference, result, test_name="test", quantize_tolerance=0, float_tolerance=1e-3
+ reference,
+ result,
+ test_name="test",
+ quantize_tolerance=0,
+ float_tolerance=1e-3,
+ misc_checks=[],
):
"""Check if the result is the same as the expected reference."""
if not os.path.isfile(reference):
@@ -111,6 +117,20 @@ def test_check(
)
return (TestResult.MISMATCH, 0.0, msg)
+ # Perform miscellaneous checks
+ if "bf16" in misc_checks:
+ # Ensure floats are valid bfloat16 values
+ test_res_is_bf16 = all([float32_is_valid_bfloat16(f) for f in test_result.flat])
+ ref_res_is_bf16 = all(
+ [float32_is_valid_bfloat16(f) for f in reference_result.flat]
+ )
+ if not (test_res_is_bf16 and ref_res_is_bf16):
+ msg = (
+ "All output values must be valid bfloat16. "
+ "reference_result: {ref_res_is_bf16}; test_result: {test_res_is_bf16}"
+ )
+ return (TestResult.INCORRECT_FORMAT, 0.0, msg)
+
# for quantized test, allow +-(quantize_tolerance) error
if reference_result.dtype == np.int32 or reference_result.dtype == np.int64: