From bc2a3db54ecee48fe2236f7fc03da8fd07d81ca0 Mon Sep 17 00:00:00 2001 From: Jeremy Johnson Date: Tue, 27 Sep 2022 13:50:00 +0100 Subject: Rename FLOAT type to FP32 Update tensor operations naming to state input type as TxT in all cases. Effects CONV2D, CONV3D, DEPTHWISE_CONV2D, FULLY_CONNECTED, TRANSPOSE_CONV2D. Signed-off-by: Jeremy Johnson Change-Id: Ic959acfcb3aa0a910b33b774a5a85fac08219205 --- verif/generator/tosa_error_if.py | 56 +++++++++++++++++++++++----------------- 1 file changed, 33 insertions(+), 23 deletions(-) (limited to 'verif/generator/tosa_error_if.py') diff --git a/verif/generator/tosa_error_if.py b/verif/generator/tosa_error_if.py index a766803..abe1a97 100644 --- a/verif/generator/tosa_error_if.py +++ b/verif/generator/tosa_error_if.py @@ -119,7 +119,7 @@ class TosaErrorIfArgGen: DType.INT16, DType.INT32, DType.INT48, - DType.FLOAT, + DType.FP32, DType.FP16, ) elif mode == ResizeMode.NEAREST and dtype == DType.INT16: @@ -128,7 +128,7 @@ class TosaErrorIfArgGen: DType.INT8, DType.INT32, DType.INT48, - DType.FLOAT, + DType.FP32, DType.FP16, ) elif mode == ResizeMode.BILINEAR and dtype == DType.INT8: @@ -137,7 +137,7 @@ class TosaErrorIfArgGen: DType.INT8, DType.INT16, DType.INT48, - DType.FLOAT, + DType.FP32, DType.FP16, ) elif mode == ResizeMode.BILINEAR and dtype == DType.INT16: @@ -146,7 +146,7 @@ class TosaErrorIfArgGen: DType.INT8, DType.INT16, DType.INT32, - DType.FLOAT, + DType.FP32, DType.FP16, ) elif dtype == DType.FP16: @@ -156,9 +156,9 @@ class TosaErrorIfArgGen: DType.INT16, DType.INT32, DType.INT48, - DType.FLOAT, + DType.FP32, ) - elif dtype == DType.FLOAT: + elif dtype == DType.FP32: incorrect_types = ( DType.INT4, DType.INT8, @@ -299,8 +299,8 @@ class TosaErrorIfArgGen: @staticmethod def eiCastErrorIf(testGen, input_dtype): - if input_dtype in [DType.BOOL, DType.FP16, DType.FLOAT]: - outputDType = [DType.BOOL, DType.INT48, DType.FP16, DType.FLOAT] + if input_dtype in [DType.BOOL, DType.FP16, DType.FP32]: + outputDType = [DType.BOOL, DType.INT48, DType.FP16, DType.FP32] elif input_dtype in [DType.INT8, DType.INT16, DType.INT32]: outputDType = [DType.INT48] else: @@ -366,6 +366,16 @@ class TosaErrorValidator: } wrong_input_dtypes = list(usableDTypes(excludes=allowed_input_dtypes)) + # Turn the wrong dtypes into required list of types + if op["op"] in [ + Op.FULLY_CONNECTED, + Op.CONV2D, + Op.CONV3D, + Op.DEPTHWISE_CONV2D, + Op.TRANSPOSE_CONV2D, + ]: + wrong_input_dtypes = [[t, t, t] for t in wrong_input_dtypes] + if op["op"] == Op.CLAMP: wrong_input_dtypes.remove(DType.INT48) @@ -415,7 +425,7 @@ class TosaErrorValidator: and output_dtype != DType.INT48 ) or (input_dtype == DType.FP16 and output_dtype != DType.FP16) - or (input_dtype == DType.FLOAT and output_dtype != DType.FLOAT) + or (input_dtype == DType.FP32 and output_dtype != DType.FP32) ): error_result = True @@ -430,28 +440,28 @@ class TosaErrorValidator: or (input_dtype == DType.INT16 and output_dtype != DType.INT48) or ( input_dtype == DType.FP16 - and output_dtype not in (DType.FP16, DType.FLOAT) + and output_dtype not in (DType.FP16, DType.FP32) ) - or (input_dtype == DType.FLOAT and output_dtype != DType.FLOAT) + or (input_dtype == DType.FP32 and output_dtype != DType.FP32) ): error_result = True elif op["op"] == Op.ARGMAX: if ( - input_dtype in [DType.INT8, DType.INT16, DType.FP16, DType.FLOAT] + input_dtype in [DType.INT8, DType.INT16, DType.FP16, DType.FP32] and output_dtype != DType.INT32 ): error_result = True elif op["op"] == Op.MUL: if ( - input_dtype not in (DType.FP16, DType.FLOAT) + input_dtype not in (DType.FP16, DType.FP32) and output_dtype != DType.INT32 ): error_result = True elif input_dtype == DType.FP16 and output_dtype != DType.FP16: error_result = True - elif input_dtype == DType.FLOAT and output_dtype != DType.FLOAT: + elif input_dtype == DType.FP32 and output_dtype != DType.FP32: error_result = True elif op["op"] == Op.TABLE: @@ -477,7 +487,7 @@ class TosaErrorValidator: DType.BOOL, DType.INT16, DType.INT32, - DType.FLOAT, + DType.FP32, DType.FP16, ] ) @@ -488,7 +498,7 @@ class TosaErrorValidator: DType.BOOL, DType.INT8, DType.INT32, - DType.FLOAT, + DType.FP32, DType.FP16, ] ) @@ -499,7 +509,7 @@ class TosaErrorValidator: DType.BOOL, DType.INT8, DType.INT16, - DType.FLOAT, + DType.FP32, DType.FP16, ] ) @@ -508,7 +518,7 @@ class TosaErrorValidator: and output_dtype not in [DType.INT8, DType.INT16, DType.INT32] ) or ( - input_dtype == DType.FLOAT + input_dtype == DType.FP32 and output_dtype not in [DType.INT8, DType.INT16, DType.INT32] ) ): @@ -526,9 +536,9 @@ class TosaErrorValidator: or input_dtype == DType.INT16 and output_dtype != DType.INT48 or input_dtype == DType.FP16 - and output_dtype not in (DType.FP16, DType.FLOAT) - or input_dtype == DType.FLOAT - and output_dtype != DType.FLOAT + and output_dtype not in (DType.FP16, DType.FP32) + or input_dtype == DType.FP32 + and output_dtype != DType.FP32 ): error_result = True # invalid input types are ignored, to avoid reporting multiple errors @@ -2306,12 +2316,12 @@ class TosaInvalidValidator: not (input_dtype == DType.INT8 and output_dtype == DType.INT32) and not (input_dtype == DType.INT16 and output_dtype == DType.INT48) and not (input_dtype == DType.FP16 and output_dtype == DType.FP16) - and not (input_dtype == DType.FLOAT and output_dtype == DType.FLOAT) + and not (input_dtype == DType.FP32 and output_dtype == DType.FP32) ) elif mode == ResizeMode.NEAREST: # Invalid output data type / Invalid input datatype return (input_dtype != output_dtype) or ( - input_dtype not in [DType.INT8, DType.INT16, DType.FP16, DType.FLOAT] + input_dtype not in [DType.INT8, DType.INT16, DType.FP16, DType.FP32] ) else: # Invalid resize mode -- cgit v1.2.1