diff options
Diffstat (limited to 'verif/checker/tosa_result_checker.py')
-rw-r--r-- | verif/checker/tosa_result_checker.py | 22 |
1 files changed, 21 insertions, 1 deletions
diff --git a/verif/checker/tosa_result_checker.py b/verif/checker/tosa_result_checker.py index 8ae3218..b7a76b6 100644 --- a/verif/checker/tosa_result_checker.py +++ b/verif/checker/tosa_result_checker.py @@ -9,6 +9,7 @@ from enum import unique from pathlib import Path import numpy as np +from generator.tosa_utils import float32_is_valid_bfloat16 ################################## color_printing = True @@ -63,7 +64,12 @@ TestResultErrorStr = [ def test_check( - reference, result, test_name="test", quantize_tolerance=0, float_tolerance=1e-3 + reference, + result, + test_name="test", + quantize_tolerance=0, + float_tolerance=1e-3, + misc_checks=[], ): """Check if the result is the same as the expected reference.""" if not os.path.isfile(reference): @@ -111,6 +117,20 @@ def test_check( ) return (TestResult.MISMATCH, 0.0, msg) + # Perform miscellaneous checks + if "bf16" in misc_checks: + # Ensure floats are valid bfloat16 values + test_res_is_bf16 = all([float32_is_valid_bfloat16(f) for f in test_result.flat]) + ref_res_is_bf16 = all( + [float32_is_valid_bfloat16(f) for f in reference_result.flat] + ) + if not (test_res_is_bf16 and ref_res_is_bf16): + msg = ( + "All output values must be valid bfloat16. " + "reference_result: {ref_res_is_bf16}; test_result: {test_res_is_bf16}" + ) + return (TestResult.INCORRECT_FORMAT, 0.0, msg) + # for quantized test, allow +-(quantize_tolerance) error if reference_result.dtype == np.int32 or reference_result.dtype == np.int64: |