From 2c34b4616a10539211e7006bc43f3c71e86c30bb Mon Sep 17 00:00:00 2001 From: Won Jeon Date: Tue, 6 Feb 2024 18:37:00 +0000 Subject: Add support for FP8 to reference model Signed-off-by: Won Jeon Change-Id: I99b70f94aff2ccd4af64875697e124eb60bc5b08 --- verif/checker/tosa_result_checker.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) (limited to 'verif/checker/tosa_result_checker.py') diff --git a/verif/checker/tosa_result_checker.py b/verif/checker/tosa_result_checker.py index 212c809..4d6d345 100644 --- a/verif/checker/tosa_result_checker.py +++ b/verif/checker/tosa_result_checker.py @@ -13,6 +13,7 @@ from checker.color_print import print_color from checker.verifier import VerifierError from checker.verifier import VerifierLibrary from generator.tosa_utils import float32_is_valid_bfloat16 +from generator.tosa_utils import float32_is_valid_float8 from schemavalidation.schemavalidation import TestDescSchemaValidator @@ -195,6 +196,18 @@ def test_check( "reference_result: {ref_res_is_bf16}; test_result: {test_res_is_bf16}" ) return (TestResult.INCORRECT_FORMAT, 0.0, msg) + if "fp8e4m3" in misc_checks or "fp8e5m2" in misc_checks: + # Ensure floats are valid float8 values + test_res_is_fp8 = all([float32_is_valid_float8(f) for f in test_result.flat]) + ref_res_is_fp8 = all( + [float32_is_valid_float8(f) for f in reference_result.flat] + ) + if not (test_res_is_fp8 and ref_res_is_fp8): + msg = ( + "All output values must be valid float8. " + "reference_result: {ref_res_is_float8}; test_result: {test_res_is_float8}" + ) + return (TestResult.INCORRECT_FLOAT, 0.0, msg) # for quantized test, allow +-(quantize_tolerance) error if reference_result.dtype in ( -- cgit v1.2.1