aboutsummaryrefslogtreecommitdiff
path: root/verif/generator/tosa_arg_gen.py
diff options
context:
space:
mode:
Diffstat (limited to 'verif/generator/tosa_arg_gen.py')
-rw-r--r--verif/generator/tosa_arg_gen.py38
1 files changed, 19 insertions, 19 deletions
diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py
index e0c6cf0..791fbf7 100644
--- a/verif/generator/tosa_arg_gen.py
+++ b/verif/generator/tosa_arg_gen.py
@@ -776,7 +776,7 @@ class TosaTensorValuesGen:
), "Op.MUL must have 2 placeholders, 0 consts"
tens = []
- if dtypeList[0] in (DType.FP16, DType.FLOAT):
+ if dtypeList[0] in (DType.FP16, DType.FP32):
tens.extend(testGen.buildPlaceholderTensors(shapeList[:], dtypeList[:]))
else:
placeholders = []
@@ -1106,10 +1106,10 @@ class TosaArgGen:
@staticmethod
def agFullyConnected(testGen, opName, shapeList, dtypes, error_name=None):
- if isinstance(dtypes, list) or isinstance(dtypes, tuple):
- input_dtype = dtypes[0]
- else:
- input_dtype = dtypes
+ assert isinstance(dtypes, list) or isinstance(
+ dtypes, tuple
+ ), f"{dtypes} unexpected"
+ input_dtype = dtypes[0]
if error_name == ErrorIf.WrongOutputType:
accum_dtype = get_wrong_output_type(opName, testGen.rng, input_dtype)
@@ -1129,9 +1129,9 @@ class TosaArgGen:
elif dtype == DType.INT16:
accum_dtypes = [DType.INT48]
elif dtype == DType.FP16:
- accum_dtypes = [DType.FP16, DType.FLOAT]
- elif dtype == DType.FLOAT:
- accum_dtypes = [DType.FLOAT]
+ accum_dtypes = [DType.FP16, DType.FP32]
+ elif dtype == DType.FP32:
+ accum_dtypes = [DType.FP32]
elif error_name is None:
assert False, f"Invalid I/O DType for MatMul: {DTypeNames[dtype]}"
@@ -1245,7 +1245,7 @@ class TosaArgGen:
if dtype in [DType.BOOL, DType.INT8, DType.INT16, DType.INT32]:
pad_const_int = testGen.getRandNumberDType(dtype)
pad_const_fp = 0
- elif dtype in (DType.FP16, DType.FLOAT):
+ elif dtype in (DType.FP16, DType.FP32):
pad_const_int = 0
pad_const_fp = testGen.getRandNumberDType(dtype)
else:
@@ -1303,9 +1303,9 @@ class TosaArgGen:
elif dtype == DType.INT8 or dtype == DType.INT16:
accum_dtypes = [DType.INT32]
elif dtype == DType.FP16:
- accum_dtypes = [DType.FP16, DType.FLOAT]
- elif dtype == DType.FLOAT:
- accum_dtypes = [DType.FLOAT]
+ accum_dtypes = [DType.FP16, DType.FP32]
+ elif dtype == DType.FP32:
+ accum_dtypes = [DType.FP32]
elif error_name is None:
assert False, f"Invalid I/O DType for pooling: {DTypeNames[dtype]}"
else:
@@ -1408,20 +1408,20 @@ class TosaArgGen:
if error_name == ErrorIf.WrongOutputType:
dtypeList = TosaErrorIfArgGen.eiCastErrorIf(testGen, inDtype)
elif inDtype == DType.INT8:
- dtypeList = [DType.BOOL, DType.INT16, DType.INT32, DType.FLOAT]
+ dtypeList = [DType.BOOL, DType.INT16, DType.INT32, DType.FP32]
elif inDtype == DType.INT16:
- dtypeList = [DType.BOOL, DType.INT8, DType.INT32, DType.FLOAT]
+ dtypeList = [DType.BOOL, DType.INT8, DType.INT32, DType.FP32]
elif inDtype == DType.INT32:
- dtypeList = [DType.BOOL, DType.INT8, DType.INT16, DType.FLOAT]
+ dtypeList = [DType.BOOL, DType.INT8, DType.INT16, DType.FP32]
elif inDtype == DType.BOOL:
dtypeList = [DType.INT8, DType.INT16, DType.INT32]
elif inDtype == DType.FP16:
dtypeList = [DType.INT8, DType.INT16, DType.INT32]
- elif inDtype == DType.FLOAT:
+ elif inDtype == DType.FP32:
dtypeList = [DType.INT8, DType.INT16, DType.INT32]
elif error_name == ErrorIf.WrongInputType:
# Pick some potentially correct output type for incorrect input type
- dtypeList = [DType.BOOL, DType.INT8, DType.INT16, DType.FLOAT]
+ dtypeList = [DType.BOOL, DType.INT8, DType.INT16, DType.FP32]
else:
raise Exception("Unexpected input dtype: {}".format(inDtype))
@@ -1826,8 +1826,8 @@ class TosaArgGen:
outputDTypeList = [DType.INT48]
elif dtype == DType.FP16:
outputDTypeList = [DType.FP16]
- elif dtype == DType.FLOAT:
- outputDTypeList = [DType.FLOAT]
+ elif dtype == DType.FP32:
+ outputDTypeList = [DType.FP32]
elif error_name == ErrorIf.WrongInputType:
# If an incorrect input type is used then we set a 'correct'
# output type to avoid other errors