From 24dbc420aae556649f50e645bd94489dab2cc75a Mon Sep 17 00:00:00 2001 From: James Ward Date: Wed, 19 Oct 2022 12:20:31 +0100 Subject: Add BF16 support to reference model * Upgrade Eigen to 3.4.0 (for bfloat16 support) and add work- arounds for reduce.any() and reduce.all() bugs (introduced between 3.3.7 and 3.4.0) * Truncation to bfloat16 now performed in eval() methods Signed-off-by: James Ward Signed-off-by: Jeremy Johnson Change-Id: If5f5c988d76d3d30790acf3b97081726b89205fe --- verif/generator/tosa_error_if.py | 35 ++++++++++++++++++++++++++++++----- 1 file changed, 30 insertions(+), 5 deletions(-) (limited to 'verif/generator/tosa_error_if.py') diff --git a/verif/generator/tosa_error_if.py b/verif/generator/tosa_error_if.py index abe1a97..a850699 100644 --- a/verif/generator/tosa_error_if.py +++ b/verif/generator/tosa_error_if.py @@ -158,6 +158,15 @@ class TosaErrorIfArgGen: DType.INT48, DType.FP32, ) + elif dtype == DType.BF16: + incorrect_types = ( + DType.INT4, + DType.INT8, + DType.INT16, + DType.INT32, + DType.INT48, + DType.FP32, + ) elif dtype == DType.FP32: incorrect_types = ( DType.INT4, @@ -299,8 +308,8 @@ class TosaErrorIfArgGen: @staticmethod def eiCastErrorIf(testGen, input_dtype): - if input_dtype in [DType.BOOL, DType.FP16, DType.FP32]: - outputDType = [DType.BOOL, DType.INT48, DType.FP16, DType.FP32] + if input_dtype in [DType.BOOL, DType.FP16, DType.BF16, DType.FP32]: + outputDType = [DType.BOOL, DType.INT48, DType.FP16, DType.BF16, DType.FP32] elif input_dtype in [DType.INT8, DType.INT16, DType.INT32]: outputDType = [DType.INT48] else: @@ -425,6 +434,7 @@ class TosaErrorValidator: and output_dtype != DType.INT48 ) or (input_dtype == DType.FP16 and output_dtype != DType.FP16) + or (input_dtype == DType.BF16 and output_dtype != DType.BF16) or (input_dtype == DType.FP32 and output_dtype != DType.FP32) ): error_result = True @@ -442,25 +452,29 @@ class TosaErrorValidator: input_dtype == DType.FP16 and output_dtype not in (DType.FP16, DType.FP32) ) + or (input_dtype == DType.BF16 and output_dtype != DType.FP32) or (input_dtype == DType.FP32 and output_dtype != DType.FP32) ): error_result = True elif op["op"] == Op.ARGMAX: if ( - input_dtype in [DType.INT8, DType.INT16, DType.FP16, DType.FP32] + input_dtype + in [DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32] and output_dtype != DType.INT32 ): error_result = True elif op["op"] == Op.MUL: if ( - input_dtype not in (DType.FP16, DType.FP32) + input_dtype not in (DType.FP16, DType.BF16, DType.FP32) and output_dtype != DType.INT32 ): error_result = True elif input_dtype == DType.FP16 and output_dtype != DType.FP16: error_result = True + elif input_dtype == DType.BF16 and output_dtype != DType.BF16: + error_result = True elif input_dtype == DType.FP32 and output_dtype != DType.FP32: error_result = True @@ -489,6 +503,7 @@ class TosaErrorValidator: DType.INT32, DType.FP32, DType.FP16, + DType.BF16, ] ) or ( @@ -500,6 +515,7 @@ class TosaErrorValidator: DType.INT32, DType.FP32, DType.FP16, + DType.BF16, ] ) or ( @@ -511,12 +527,17 @@ class TosaErrorValidator: DType.INT16, DType.FP32, DType.FP16, + DType.BF16, ] ) or ( input_dtype == DType.FP16 and output_dtype not in [DType.INT8, DType.INT16, DType.INT32] ) + or ( + input_dtype == DType.BF16 + and output_dtype not in [DType.INT8, DType.INT16, DType.INT32] + ) or ( input_dtype == DType.FP32 and output_dtype not in [DType.INT8, DType.INT16, DType.INT32] @@ -537,6 +558,8 @@ class TosaErrorValidator: and output_dtype != DType.INT48 or input_dtype == DType.FP16 and output_dtype not in (DType.FP16, DType.FP32) + or input_dtype == DType.BF16 + and output_dtype != DType.FP32 or input_dtype == DType.FP32 and output_dtype != DType.FP32 ): @@ -2316,12 +2339,14 @@ class TosaInvalidValidator: not (input_dtype == DType.INT8 and output_dtype == DType.INT32) and not (input_dtype == DType.INT16 and output_dtype == DType.INT48) and not (input_dtype == DType.FP16 and output_dtype == DType.FP16) + and not (input_dtype == DType.BF16 and output_dtype == DType.BF16) and not (input_dtype == DType.FP32 and output_dtype == DType.FP32) ) elif mode == ResizeMode.NEAREST: # Invalid output data type / Invalid input datatype return (input_dtype != output_dtype) or ( - input_dtype not in [DType.INT8, DType.INT16, DType.FP16, DType.FP32] + input_dtype + not in [DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32] ) else: # Invalid resize mode -- cgit v1.2.1