diff options
Diffstat (limited to 'verif/checker')
-rw-r--r-- | verif/checker/tosa_result_checker.py | 13 |
1 files changed, 13 insertions, 0 deletions
diff --git a/verif/checker/tosa_result_checker.py b/verif/checker/tosa_result_checker.py index 212c809..4d6d345 100644 --- a/verif/checker/tosa_result_checker.py +++ b/verif/checker/tosa_result_checker.py @@ -13,6 +13,7 @@ from checker.color_print import print_color from checker.verifier import VerifierError from checker.verifier import VerifierLibrary from generator.tosa_utils import float32_is_valid_bfloat16 +from generator.tosa_utils import float32_is_valid_float8 from schemavalidation.schemavalidation import TestDescSchemaValidator @@ -195,6 +196,18 @@ def test_check( "reference_result: {ref_res_is_bf16}; test_result: {test_res_is_bf16}" ) return (TestResult.INCORRECT_FORMAT, 0.0, msg) + if "fp8e4m3" in misc_checks or "fp8e5m2" in misc_checks: + # Ensure floats are valid float8 values + test_res_is_fp8 = all([float32_is_valid_float8(f) for f in test_result.flat]) + ref_res_is_fp8 = all( + [float32_is_valid_float8(f) for f in reference_result.flat] + ) + if not (test_res_is_fp8 and ref_res_is_fp8): + msg = ( + "All output values must be valid float8. " + "reference_result: {ref_res_is_float8}; test_result: {test_res_is_float8}" + ) + return (TestResult.INCORRECT_FLOAT, 0.0, msg) # for quantized test, allow +-(quantize_tolerance) error if reference_result.dtype in ( |