aboutsummaryrefslogtreecommitdiff
path: root/verif/checker/tosa_result_checker.py
diff options
context:
space:
mode:
authorWon Jeon <won.jeon@arm.com>2024-02-06 18:37:00 +0000
committerWon Jeon <won.jeon@arm.com>2024-02-21 19:38:55 +0000
commit2c34b4616a10539211e7006bc43f3c71e86c30bb (patch)
treeaa4043a610ecd4c6d35b876cfb013dbe7dd0ab01 /verif/checker/tosa_result_checker.py
parent587cc84c2b8c4b0d030b5e257c9a32461c0969b9 (diff)
downloadreference_model-2c34b4616a10539211e7006bc43f3c71e86c30bb.tar.gz
Add support for FP8 to reference model
Signed-off-by: Won Jeon <won.jeon@arm.com> Change-Id: I99b70f94aff2ccd4af64875697e124eb60bc5b08
Diffstat (limited to 'verif/checker/tosa_result_checker.py')
-rw-r--r--verif/checker/tosa_result_checker.py13
1 files changed, 13 insertions, 0 deletions
diff --git a/verif/checker/tosa_result_checker.py b/verif/checker/tosa_result_checker.py
index 212c809..4d6d345 100644
--- a/verif/checker/tosa_result_checker.py
+++ b/verif/checker/tosa_result_checker.py
@@ -13,6 +13,7 @@ from checker.color_print import print_color
from checker.verifier import VerifierError
from checker.verifier import VerifierLibrary
from generator.tosa_utils import float32_is_valid_bfloat16
+from generator.tosa_utils import float32_is_valid_float8
from schemavalidation.schemavalidation import TestDescSchemaValidator
@@ -195,6 +196,18 @@ def test_check(
"reference_result: {ref_res_is_bf16}; test_result: {test_res_is_bf16}"
)
return (TestResult.INCORRECT_FORMAT, 0.0, msg)
+ if "fp8e4m3" in misc_checks or "fp8e5m2" in misc_checks:
+ # Ensure floats are valid float8 values
+ test_res_is_fp8 = all([float32_is_valid_float8(f) for f in test_result.flat])
+ ref_res_is_fp8 = all(
+ [float32_is_valid_float8(f) for f in reference_result.flat]
+ )
+ if not (test_res_is_fp8 and ref_res_is_fp8):
+ msg = (
+ "All output values must be valid float8. "
+ "reference_result: {ref_res_is_float8}; test_result: {test_res_is_float8}"
+ )
+ return (TestResult.INCORRECT_FLOAT, 0.0, msg)
# for quantized test, allow +-(quantize_tolerance) error
if reference_result.dtype in (