diff options
author | James Ward <james.ward@arm.com> | 2022-10-19 12:20:31 +0100 |
---|---|---|
committer | Jeremy Johnson <jeremy.johnson@arm.com> | 2022-11-09 12:19:51 +0000 |
commit | 24dbc420aae556649f50e645bd94489dab2cc75a (patch) | |
tree | 490345da43e9c5bae0f450ba05ffe85874077e0a /verif/checker/tosa_result_checker.py | |
parent | 3b0544c1e7463295c49a48a162ebb9a546326829 (diff) | |
download | reference_model-24dbc420aae556649f50e645bd94489dab2cc75a.tar.gz |
Add BF16 support to reference model
* Upgrade Eigen to 3.4.0 (for bfloat16 support) and add work-
arounds for reduce.any() and reduce.all() bugs (introduced
between 3.3.7 and 3.4.0)
* Truncation to bfloat16 now performed in eval() methods
Signed-off-by: James Ward <james.ward@arm.com>
Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com>
Change-Id: If5f5c988d76d3d30790acf3b97081726b89205fe
Diffstat (limited to 'verif/checker/tosa_result_checker.py')
-rw-r--r-- | verif/checker/tosa_result_checker.py | 22 |
1 files changed, 21 insertions, 1 deletions
diff --git a/verif/checker/tosa_result_checker.py b/verif/checker/tosa_result_checker.py index 8ae3218..b7a76b6 100644 --- a/verif/checker/tosa_result_checker.py +++ b/verif/checker/tosa_result_checker.py @@ -9,6 +9,7 @@ from enum import unique from pathlib import Path import numpy as np +from generator.tosa_utils import float32_is_valid_bfloat16 ################################## color_printing = True @@ -63,7 +64,12 @@ TestResultErrorStr = [ def test_check( - reference, result, test_name="test", quantize_tolerance=0, float_tolerance=1e-3 + reference, + result, + test_name="test", + quantize_tolerance=0, + float_tolerance=1e-3, + misc_checks=[], ): """Check if the result is the same as the expected reference.""" if not os.path.isfile(reference): @@ -111,6 +117,20 @@ def test_check( ) return (TestResult.MISMATCH, 0.0, msg) + # Perform miscellaneous checks + if "bf16" in misc_checks: + # Ensure floats are valid bfloat16 values + test_res_is_bf16 = all([float32_is_valid_bfloat16(f) for f in test_result.flat]) + ref_res_is_bf16 = all( + [float32_is_valid_bfloat16(f) for f in reference_result.flat] + ) + if not (test_res_is_bf16 and ref_res_is_bf16): + msg = ( + "All output values must be valid bfloat16. " + "reference_result: {ref_res_is_bf16}; test_result: {test_res_is_bf16}" + ) + return (TestResult.INCORRECT_FORMAT, 0.0, msg) + # for quantized test, allow +-(quantize_tolerance) error if reference_result.dtype == np.int32 or reference_result.dtype == np.int64: |