aboutsummaryrefslogtreecommitdiff
path: root/verif/checker/tosa_result_checker.py
diff options
context:
space:
mode:
Diffstat (limited to 'verif/checker/tosa_result_checker.py')
-rw-r--r--verif/checker/tosa_result_checker.py22
1 files changed, 21 insertions, 1 deletions
diff --git a/verif/checker/tosa_result_checker.py b/verif/checker/tosa_result_checker.py
index 8ae3218..b7a76b6 100644
--- a/verif/checker/tosa_result_checker.py
+++ b/verif/checker/tosa_result_checker.py
@@ -9,6 +9,7 @@ from enum import unique
from pathlib import Path
import numpy as np
+from generator.tosa_utils import float32_is_valid_bfloat16
##################################
color_printing = True
@@ -63,7 +64,12 @@ TestResultErrorStr = [
def test_check(
- reference, result, test_name="test", quantize_tolerance=0, float_tolerance=1e-3
+ reference,
+ result,
+ test_name="test",
+ quantize_tolerance=0,
+ float_tolerance=1e-3,
+ misc_checks=[],
):
"""Check if the result is the same as the expected reference."""
if not os.path.isfile(reference):
@@ -111,6 +117,20 @@ def test_check(
)
return (TestResult.MISMATCH, 0.0, msg)
+ # Perform miscellaneous checks
+ if "bf16" in misc_checks:
+ # Ensure floats are valid bfloat16 values
+ test_res_is_bf16 = all([float32_is_valid_bfloat16(f) for f in test_result.flat])
+ ref_res_is_bf16 = all(
+ [float32_is_valid_bfloat16(f) for f in reference_result.flat]
+ )
+ if not (test_res_is_bf16 and ref_res_is_bf16):
+ msg = (
+ "All output values must be valid bfloat16. "
+ "reference_result: {ref_res_is_bf16}; test_result: {test_res_is_bf16}"
+ )
+ return (TestResult.INCORRECT_FORMAT, 0.0, msg)
+
# for quantized test, allow +-(quantize_tolerance) error
if reference_result.dtype == np.int32 or reference_result.dtype == np.int64: