From 24dbc420aae556649f50e645bd94489dab2cc75a Mon Sep 17 00:00:00 2001 From: James Ward Date: Wed, 19 Oct 2022 12:20:31 +0100 Subject: Add BF16 support to reference model * Upgrade Eigen to 3.4.0 (for bfloat16 support) and add work- arounds for reduce.any() and reduce.all() bugs (introduced between 3.3.7 and 3.4.0) * Truncation to bfloat16 now performed in eval() methods Signed-off-by: James Ward Signed-off-by: Jeremy Johnson Change-Id: If5f5c988d76d3d30790acf3b97081726b89205fe --- verif/tests/test_tosa_refmodel.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) (limited to 'verif/tests/test_tosa_refmodel.py') diff --git a/verif/tests/test_tosa_refmodel.py b/verif/tests/test_tosa_refmodel.py index b608fd8..50ff1ab 100644 --- a/verif/tests/test_tosa_refmodel.py +++ b/verif/tests/test_tosa_refmodel.py @@ -47,6 +47,7 @@ REF_MODEL_TYPE_TO_OUT = { "int32": "i32", "fp32": "f32", "fp16": "f16", + "bf16": "bf16", } @@ -127,11 +128,13 @@ TEST_PARAMS = [ ("abs", "int32", 1), ("abs", "fp32", 1), ("abs", "fp16", 1), + ("abs", "bf16", 1), ("negate", "int8", 1), ("negate", "int16", 1), ("negate", "int32", 1), ("negate", "fp32", 1), ("negate", "fp16", 1), + ("negate", "bf16", 1), # One test per axis (shape dimensions) ("concat", "bool", SHAPE_DIMS), ("concat", "int8", SHAPE_DIMS), @@ -139,6 +142,7 @@ TEST_PARAMS = [ ("concat", "int32", SHAPE_DIMS), ("concat", "fp32", SHAPE_DIMS), ("concat", "fp16", SHAPE_DIMS), + ("concat", "bf16", SHAPE_DIMS), ] @@ -165,6 +169,9 @@ def test_refmodel_simple_op(tosaTest): # Generate TOSA test(s) (mostly should be single test) test_dirs = tosaTest.create_test() + # Indicate miscellaneous checks to run in tosa_check + misc_checks = [] + for test_dir in test_dirs: # Run ref model desc_file = test_dir / TEST_DESC_FILENAME @@ -227,8 +234,15 @@ def test_refmodel_simple_op(tosaTest): np.save(str(result_file), result) assert result_file.is_file() + # Ensure valid bf16 + if tosaTest.ref_model_type == "bf16": + misc_checks.append("bf16") + # Check Numpy result versus refmodel check_result, tolerance, msg = tosa_check( - str(result_file), str(ofm_file), test_name=test_dir.name + str(result_file), + str(ofm_file), + test_name=test_dir.name, + misc_checks=misc_checks, ) assert check_result == TosaResult.PASS -- cgit v1.2.1