From 24dbc420aae556649f50e645bd94489dab2cc75a Mon Sep 17 00:00:00 2001 From: James Ward Date: Wed, 19 Oct 2022 12:20:31 +0100 Subject: Add BF16 support to reference model * Upgrade Eigen to 3.4.0 (for bfloat16 support) and add work- arounds for reduce.any() and reduce.all() bugs (introduced between 3.3.7 and 3.4.0) * Truncation to bfloat16 now performed in eval() methods Signed-off-by: James Ward Signed-off-by: Jeremy Johnson Change-Id: If5f5c988d76d3d30790acf3b97081726b89205fe --- verif/checker/tosa_result_checker.py | 22 +++++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) (limited to 'verif/checker') diff --git a/verif/checker/tosa_result_checker.py b/verif/checker/tosa_result_checker.py index 8ae3218..b7a76b6 100644 --- a/verif/checker/tosa_result_checker.py +++ b/verif/checker/tosa_result_checker.py @@ -9,6 +9,7 @@ from enum import unique from pathlib import Path import numpy as np +from generator.tosa_utils import float32_is_valid_bfloat16 ################################## color_printing = True @@ -63,7 +64,12 @@ TestResultErrorStr = [ def test_check( - reference, result, test_name="test", quantize_tolerance=0, float_tolerance=1e-3 + reference, + result, + test_name="test", + quantize_tolerance=0, + float_tolerance=1e-3, + misc_checks=[], ): """Check if the result is the same as the expected reference.""" if not os.path.isfile(reference): @@ -111,6 +117,20 @@ def test_check( ) return (TestResult.MISMATCH, 0.0, msg) + # Perform miscellaneous checks + if "bf16" in misc_checks: + # Ensure floats are valid bfloat16 values + test_res_is_bf16 = all([float32_is_valid_bfloat16(f) for f in test_result.flat]) + ref_res_is_bf16 = all( + [float32_is_valid_bfloat16(f) for f in reference_result.flat] + ) + if not (test_res_is_bf16 and ref_res_is_bf16): + msg = ( + "All output values must be valid bfloat16. " + "reference_result: {ref_res_is_bf16}; test_result: {test_res_is_bf16}" + ) + return (TestResult.INCORRECT_FORMAT, 0.0, msg) + # for quantized test, allow +-(quantize_tolerance) error if reference_result.dtype == np.int32 or reference_result.dtype == np.int64: -- cgit v1.2.1