aboutsummaryrefslogtreecommitdiff
path: root/verif/checker/tosa_result_checker.py
diff options
context:
space:
mode:
Diffstat (limited to 'verif/checker/tosa_result_checker.py')
-rw-r--r--verif/checker/tosa_result_checker.py13
1 files changed, 13 insertions, 0 deletions
diff --git a/verif/checker/tosa_result_checker.py b/verif/checker/tosa_result_checker.py
index 212c809..4d6d345 100644
--- a/verif/checker/tosa_result_checker.py
+++ b/verif/checker/tosa_result_checker.py
@@ -13,6 +13,7 @@ from checker.color_print import print_color
from checker.verifier import VerifierError
from checker.verifier import VerifierLibrary
from generator.tosa_utils import float32_is_valid_bfloat16
+from generator.tosa_utils import float32_is_valid_float8
from schemavalidation.schemavalidation import TestDescSchemaValidator
@@ -195,6 +196,18 @@ def test_check(
"reference_result: {ref_res_is_bf16}; test_result: {test_res_is_bf16}"
)
return (TestResult.INCORRECT_FORMAT, 0.0, msg)
+ if "fp8e4m3" in misc_checks or "fp8e5m2" in misc_checks:
+ # Ensure floats are valid float8 values
+ test_res_is_fp8 = all([float32_is_valid_float8(f) for f in test_result.flat])
+ ref_res_is_fp8 = all(
+ [float32_is_valid_float8(f) for f in reference_result.flat]
+ )
+ if not (test_res_is_fp8 and ref_res_is_fp8):
+ msg = (
+ "All output values must be valid float8. "
+ "reference_result: {ref_res_is_float8}; test_result: {test_res_is_float8}"
+ )
+ return (TestResult.INCORRECT_FLOAT, 0.0, msg)
# for quantized test, allow +-(quantize_tolerance) error
if reference_result.dtype in (