aboutsummaryrefslogtreecommitdiff
path: root/verif/generator/tosa_arg_gen.py
diff options
context:
space:
mode:
authorJeremy Johnson <jeremy.johnson@arm.com>2022-09-27 13:50:00 +0100
committerJeremy Johnson <jeremy.johnson@arm.com>2022-10-13 18:21:15 +0100
commitbc2a3db54ecee48fe2236f7fc03da8fd07d81ca0 (patch)
treec3908f23c369fd3226e840f81c3ba4b49cc409a0 /verif/generator/tosa_arg_gen.py
parent93d4390f9aa5c4369f889e1cd336aa4e809ff6a7 (diff)
downloadreference_model-bc2a3db54ecee48fe2236f7fc03da8fd07d81ca0.tar.gz
Rename FLOAT type to FP32
Update tensor operations naming to state input type as TxT in all cases. Effects CONV2D, CONV3D, DEPTHWISE_CONV2D, FULLY_CONNECTED, TRANSPOSE_CONV2D. Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com> Change-Id: Ic959acfcb3aa0a910b33b774a5a85fac08219205
Diffstat (limited to 'verif/generator/tosa_arg_gen.py')
-rw-r--r--verif/generator/tosa_arg_gen.py38
1 files changed, 19 insertions, 19 deletions
diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py
index e0c6cf0..791fbf7 100644
--- a/verif/generator/tosa_arg_gen.py
+++ b/verif/generator/tosa_arg_gen.py
@@ -776,7 +776,7 @@ class TosaTensorValuesGen:
), "Op.MUL must have 2 placeholders, 0 consts"
tens = []
- if dtypeList[0] in (DType.FP16, DType.FLOAT):
+ if dtypeList[0] in (DType.FP16, DType.FP32):
tens.extend(testGen.buildPlaceholderTensors(shapeList[:], dtypeList[:]))
else:
placeholders = []
@@ -1106,10 +1106,10 @@ class TosaArgGen:
@staticmethod
def agFullyConnected(testGen, opName, shapeList, dtypes, error_name=None):
- if isinstance(dtypes, list) or isinstance(dtypes, tuple):
- input_dtype = dtypes[0]
- else:
- input_dtype = dtypes
+ assert isinstance(dtypes, list) or isinstance(
+ dtypes, tuple
+ ), f"{dtypes} unexpected"
+ input_dtype = dtypes[0]
if error_name == ErrorIf.WrongOutputType:
accum_dtype = get_wrong_output_type(opName, testGen.rng, input_dtype)
@@ -1129,9 +1129,9 @@ class TosaArgGen:
elif dtype == DType.INT16:
accum_dtypes = [DType.INT48]
elif dtype == DType.FP16:
- accum_dtypes = [DType.FP16, DType.FLOAT]
- elif dtype == DType.FLOAT:
- accum_dtypes = [DType.FLOAT]
+ accum_dtypes = [DType.FP16, DType.FP32]
+ elif dtype == DType.FP32:
+ accum_dtypes = [DType.FP32]
elif error_name is None:
assert False, f"Invalid I/O DType for MatMul: {DTypeNames[dtype]}"
@@ -1245,7 +1245,7 @@ class TosaArgGen:
if dtype in [DType.BOOL, DType.INT8, DType.INT16, DType.INT32]:
pad_const_int = testGen.getRandNumberDType(dtype)
pad_const_fp = 0
- elif dtype in (DType.FP16, DType.FLOAT):
+ elif dtype in (DType.FP16, DType.FP32):
pad_const_int = 0
pad_const_fp = testGen.getRandNumberDType(dtype)
else:
@@ -1303,9 +1303,9 @@ class TosaArgGen:
elif dtype == DType.INT8 or dtype == DType.INT16:
accum_dtypes = [DType.INT32]
elif dtype == DType.FP16:
- accum_dtypes = [DType.FP16, DType.FLOAT]
- elif dtype == DType.FLOAT:
- accum_dtypes = [DType.FLOAT]
+ accum_dtypes = [DType.FP16, DType.FP32]
+ elif dtype == DType.FP32:
+ accum_dtypes = [DType.FP32]
elif error_name is None:
assert False, f"Invalid I/O DType for pooling: {DTypeNames[dtype]}"
else:
@@ -1408,20 +1408,20 @@ class TosaArgGen:
if error_name == ErrorIf.WrongOutputType:
dtypeList = TosaErrorIfArgGen.eiCastErrorIf(testGen, inDtype)
elif inDtype == DType.INT8:
- dtypeList = [DType.BOOL, DType.INT16, DType.INT32, DType.FLOAT]
+ dtypeList = [DType.BOOL, DType.INT16, DType.INT32, DType.FP32]
elif inDtype == DType.INT16:
- dtypeList = [DType.BOOL, DType.INT8, DType.INT32, DType.FLOAT]
+ dtypeList = [DType.BOOL, DType.INT8, DType.INT32, DType.FP32]
elif inDtype == DType.INT32:
- dtypeList = [DType.BOOL, DType.INT8, DType.INT16, DType.FLOAT]
+ dtypeList = [DType.BOOL, DType.INT8, DType.INT16, DType.FP32]
elif inDtype == DType.BOOL:
dtypeList = [DType.INT8, DType.INT16, DType.INT32]
elif inDtype == DType.FP16:
dtypeList = [DType.INT8, DType.INT16, DType.INT32]
- elif inDtype == DType.FLOAT:
+ elif inDtype == DType.FP32:
dtypeList = [DType.INT8, DType.INT16, DType.INT32]
elif error_name == ErrorIf.WrongInputType:
# Pick some potentially correct output type for incorrect input type
- dtypeList = [DType.BOOL, DType.INT8, DType.INT16, DType.FLOAT]
+ dtypeList = [DType.BOOL, DType.INT8, DType.INT16, DType.FP32]
else:
raise Exception("Unexpected input dtype: {}".format(inDtype))
@@ -1826,8 +1826,8 @@ class TosaArgGen:
outputDTypeList = [DType.INT48]
elif dtype == DType.FP16:
outputDTypeList = [DType.FP16]
- elif dtype == DType.FLOAT:
- outputDTypeList = [DType.FLOAT]
+ elif dtype == DType.FP32:
+ outputDTypeList = [DType.FP32]
elif error_name == ErrorIf.WrongInputType:
# If an incorrect input type is used then we set a 'correct'
# output type to avoid other errors