diff options
author | James Ward <james.ward@arm.com> | 2022-10-19 12:20:31 +0100 |
---|---|---|
committer | Jeremy Johnson <jeremy.johnson@arm.com> | 2022-11-09 12:19:51 +0000 |
commit | 24dbc420aae556649f50e645bd94489dab2cc75a (patch) | |
tree | 490345da43e9c5bae0f450ba05ffe85874077e0a /verif/generator/tosa_error_if.py | |
parent | 3b0544c1e7463295c49a48a162ebb9a546326829 (diff) | |
download | reference_model-24dbc420aae556649f50e645bd94489dab2cc75a.tar.gz |
Add BF16 support to reference model
* Upgrade Eigen to 3.4.0 (for bfloat16 support) and add work-
arounds for reduce.any() and reduce.all() bugs (introduced
between 3.3.7 and 3.4.0)
* Truncation to bfloat16 now performed in eval() methods
Signed-off-by: James Ward <james.ward@arm.com>
Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com>
Change-Id: If5f5c988d76d3d30790acf3b97081726b89205fe
Diffstat (limited to 'verif/generator/tosa_error_if.py')
-rw-r--r-- | verif/generator/tosa_error_if.py | 35 |
1 files changed, 30 insertions, 5 deletions
diff --git a/verif/generator/tosa_error_if.py b/verif/generator/tosa_error_if.py index abe1a97..a850699 100644 --- a/verif/generator/tosa_error_if.py +++ b/verif/generator/tosa_error_if.py @@ -158,6 +158,15 @@ class TosaErrorIfArgGen: DType.INT48, DType.FP32, ) + elif dtype == DType.BF16: + incorrect_types = ( + DType.INT4, + DType.INT8, + DType.INT16, + DType.INT32, + DType.INT48, + DType.FP32, + ) elif dtype == DType.FP32: incorrect_types = ( DType.INT4, @@ -299,8 +308,8 @@ class TosaErrorIfArgGen: @staticmethod def eiCastErrorIf(testGen, input_dtype): - if input_dtype in [DType.BOOL, DType.FP16, DType.FP32]: - outputDType = [DType.BOOL, DType.INT48, DType.FP16, DType.FP32] + if input_dtype in [DType.BOOL, DType.FP16, DType.BF16, DType.FP32]: + outputDType = [DType.BOOL, DType.INT48, DType.FP16, DType.BF16, DType.FP32] elif input_dtype in [DType.INT8, DType.INT16, DType.INT32]: outputDType = [DType.INT48] else: @@ -425,6 +434,7 @@ class TosaErrorValidator: and output_dtype != DType.INT48 ) or (input_dtype == DType.FP16 and output_dtype != DType.FP16) + or (input_dtype == DType.BF16 and output_dtype != DType.BF16) or (input_dtype == DType.FP32 and output_dtype != DType.FP32) ): error_result = True @@ -442,25 +452,29 @@ class TosaErrorValidator: input_dtype == DType.FP16 and output_dtype not in (DType.FP16, DType.FP32) ) + or (input_dtype == DType.BF16 and output_dtype != DType.FP32) or (input_dtype == DType.FP32 and output_dtype != DType.FP32) ): error_result = True elif op["op"] == Op.ARGMAX: if ( - input_dtype in [DType.INT8, DType.INT16, DType.FP16, DType.FP32] + input_dtype + in [DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32] and output_dtype != DType.INT32 ): error_result = True elif op["op"] == Op.MUL: if ( - input_dtype not in (DType.FP16, DType.FP32) + input_dtype not in (DType.FP16, DType.BF16, DType.FP32) and output_dtype != DType.INT32 ): error_result = True elif input_dtype == DType.FP16 and output_dtype != DType.FP16: error_result = True + elif input_dtype == DType.BF16 and output_dtype != DType.BF16: + error_result = True elif input_dtype == DType.FP32 and output_dtype != DType.FP32: error_result = True @@ -489,6 +503,7 @@ class TosaErrorValidator: DType.INT32, DType.FP32, DType.FP16, + DType.BF16, ] ) or ( @@ -500,6 +515,7 @@ class TosaErrorValidator: DType.INT32, DType.FP32, DType.FP16, + DType.BF16, ] ) or ( @@ -511,6 +527,7 @@ class TosaErrorValidator: DType.INT16, DType.FP32, DType.FP16, + DType.BF16, ] ) or ( @@ -518,6 +535,10 @@ class TosaErrorValidator: and output_dtype not in [DType.INT8, DType.INT16, DType.INT32] ) or ( + input_dtype == DType.BF16 + and output_dtype not in [DType.INT8, DType.INT16, DType.INT32] + ) + or ( input_dtype == DType.FP32 and output_dtype not in [DType.INT8, DType.INT16, DType.INT32] ) @@ -537,6 +558,8 @@ class TosaErrorValidator: and output_dtype != DType.INT48 or input_dtype == DType.FP16 and output_dtype not in (DType.FP16, DType.FP32) + or input_dtype == DType.BF16 + and output_dtype != DType.FP32 or input_dtype == DType.FP32 and output_dtype != DType.FP32 ): @@ -2316,12 +2339,14 @@ class TosaInvalidValidator: not (input_dtype == DType.INT8 and output_dtype == DType.INT32) and not (input_dtype == DType.INT16 and output_dtype == DType.INT48) and not (input_dtype == DType.FP16 and output_dtype == DType.FP16) + and not (input_dtype == DType.BF16 and output_dtype == DType.BF16) and not (input_dtype == DType.FP32 and output_dtype == DType.FP32) ) elif mode == ResizeMode.NEAREST: # Invalid output data type / Invalid input datatype return (input_dtype != output_dtype) or ( - input_dtype not in [DType.INT8, DType.INT16, DType.FP16, DType.FP32] + input_dtype + not in [DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32] ) else: # Invalid resize mode |