From bc2a3db54ecee48fe2236f7fc03da8fd07d81ca0 Mon Sep 17 00:00:00 2001 From: Jeremy Johnson Date: Tue, 27 Sep 2022 13:50:00 +0100 Subject: Rename FLOAT type to FP32 Update tensor operations naming to state input type as TxT in all cases. Effects CONV2D, CONV3D, DEPTHWISE_CONV2D, FULLY_CONNECTED, TRANSPOSE_CONV2D. Signed-off-by: Jeremy Johnson Change-Id: Ic959acfcb3aa0a910b33b774a5a85fac08219205 --- verif/generator/tosa_arg_gen.py | 38 +++++++++++++++++++------------------- 1 file changed, 19 insertions(+), 19 deletions(-) (limited to 'verif/generator/tosa_arg_gen.py') diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py index e0c6cf0..791fbf7 100644 --- a/verif/generator/tosa_arg_gen.py +++ b/verif/generator/tosa_arg_gen.py @@ -776,7 +776,7 @@ class TosaTensorValuesGen: ), "Op.MUL must have 2 placeholders, 0 consts" tens = [] - if dtypeList[0] in (DType.FP16, DType.FLOAT): + if dtypeList[0] in (DType.FP16, DType.FP32): tens.extend(testGen.buildPlaceholderTensors(shapeList[:], dtypeList[:])) else: placeholders = [] @@ -1106,10 +1106,10 @@ class TosaArgGen: @staticmethod def agFullyConnected(testGen, opName, shapeList, dtypes, error_name=None): - if isinstance(dtypes, list) or isinstance(dtypes, tuple): - input_dtype = dtypes[0] - else: - input_dtype = dtypes + assert isinstance(dtypes, list) or isinstance( + dtypes, tuple + ), f"{dtypes} unexpected" + input_dtype = dtypes[0] if error_name == ErrorIf.WrongOutputType: accum_dtype = get_wrong_output_type(opName, testGen.rng, input_dtype) @@ -1129,9 +1129,9 @@ class TosaArgGen: elif dtype == DType.INT16: accum_dtypes = [DType.INT48] elif dtype == DType.FP16: - accum_dtypes = [DType.FP16, DType.FLOAT] - elif dtype == DType.FLOAT: - accum_dtypes = [DType.FLOAT] + accum_dtypes = [DType.FP16, DType.FP32] + elif dtype == DType.FP32: + accum_dtypes = [DType.FP32] elif error_name is None: assert False, f"Invalid I/O DType for MatMul: {DTypeNames[dtype]}" @@ -1245,7 +1245,7 @@ class TosaArgGen: if dtype in [DType.BOOL, DType.INT8, DType.INT16, DType.INT32]: pad_const_int = testGen.getRandNumberDType(dtype) pad_const_fp = 0 - elif dtype in (DType.FP16, DType.FLOAT): + elif dtype in (DType.FP16, DType.FP32): pad_const_int = 0 pad_const_fp = testGen.getRandNumberDType(dtype) else: @@ -1303,9 +1303,9 @@ class TosaArgGen: elif dtype == DType.INT8 or dtype == DType.INT16: accum_dtypes = [DType.INT32] elif dtype == DType.FP16: - accum_dtypes = [DType.FP16, DType.FLOAT] - elif dtype == DType.FLOAT: - accum_dtypes = [DType.FLOAT] + accum_dtypes = [DType.FP16, DType.FP32] + elif dtype == DType.FP32: + accum_dtypes = [DType.FP32] elif error_name is None: assert False, f"Invalid I/O DType for pooling: {DTypeNames[dtype]}" else: @@ -1408,20 +1408,20 @@ class TosaArgGen: if error_name == ErrorIf.WrongOutputType: dtypeList = TosaErrorIfArgGen.eiCastErrorIf(testGen, inDtype) elif inDtype == DType.INT8: - dtypeList = [DType.BOOL, DType.INT16, DType.INT32, DType.FLOAT] + dtypeList = [DType.BOOL, DType.INT16, DType.INT32, DType.FP32] elif inDtype == DType.INT16: - dtypeList = [DType.BOOL, DType.INT8, DType.INT32, DType.FLOAT] + dtypeList = [DType.BOOL, DType.INT8, DType.INT32, DType.FP32] elif inDtype == DType.INT32: - dtypeList = [DType.BOOL, DType.INT8, DType.INT16, DType.FLOAT] + dtypeList = [DType.BOOL, DType.INT8, DType.INT16, DType.FP32] elif inDtype == DType.BOOL: dtypeList = [DType.INT8, DType.INT16, DType.INT32] elif inDtype == DType.FP16: dtypeList = [DType.INT8, DType.INT16, DType.INT32] - elif inDtype == DType.FLOAT: + elif inDtype == DType.FP32: dtypeList = [DType.INT8, DType.INT16, DType.INT32] elif error_name == ErrorIf.WrongInputType: # Pick some potentially correct output type for incorrect input type - dtypeList = [DType.BOOL, DType.INT8, DType.INT16, DType.FLOAT] + dtypeList = [DType.BOOL, DType.INT8, DType.INT16, DType.FP32] else: raise Exception("Unexpected input dtype: {}".format(inDtype)) @@ -1826,8 +1826,8 @@ class TosaArgGen: outputDTypeList = [DType.INT48] elif dtype == DType.FP16: outputDTypeList = [DType.FP16] - elif dtype == DType.FLOAT: - outputDTypeList = [DType.FLOAT] + elif dtype == DType.FP32: + outputDTypeList = [DType.FP32] elif error_name == ErrorIf.WrongInputType: # If an incorrect input type is used then we set a 'correct' # output type to avoid other errors -- cgit v1.2.1