diff options
Diffstat (limited to 'verif/generator/tosa_arg_gen.py')
-rw-r--r-- | verif/generator/tosa_arg_gen.py | 10 |
1 files changed, 8 insertions, 2 deletions
diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py index 0203513..932ad55 100644 --- a/verif/generator/tosa_arg_gen.py +++ b/verif/generator/tosa_arg_gen.py @@ -776,7 +776,7 @@ class TosaTensorValuesGen: ), "Op.MUL must have 2 placeholders, 0 consts" tens = [] - if dtypeList[0] in (DType.FP16, DType.FP32): + if dtypeList[0] in (DType.FP16, DType.BF16, DType.FP32): tens.extend(testGen.buildPlaceholderTensors(shapeList[:], dtypeList[:])) else: placeholders = [] @@ -1130,6 +1130,8 @@ class TosaArgGen: accum_dtypes = [DType.INT48] elif dtype == DType.FP16: accum_dtypes = [DType.FP16, DType.FP32] + elif dtype == DType.BF16: + accum_dtypes = [DType.FP32] elif dtype == DType.FP32: accum_dtypes = [DType.FP32] elif error_name is None: @@ -1304,7 +1306,7 @@ class TosaArgGen: accum_dtypes = [DType.INT32] elif dtype == DType.FP16: accum_dtypes = [DType.FP16, DType.FP32] - elif dtype == DType.FP32: + elif dtype == DType.BF16 or dtype == DType.FP32: accum_dtypes = [DType.FP32] elif error_name is None: assert False, f"Invalid I/O DType for pooling: {DTypeNames[dtype]}" @@ -1417,6 +1419,8 @@ class TosaArgGen: dtypeList = [DType.INT8, DType.INT16, DType.INT32] elif inDtype == DType.FP16: dtypeList = [DType.INT8, DType.INT16, DType.INT32] + elif inDtype == DType.BF16: + dtypeList = [DType.INT8, DType.INT16, DType.INT32] elif inDtype == DType.FP32: dtypeList = [DType.INT8, DType.INT16, DType.INT32] elif error_name == ErrorIf.WrongInputType: @@ -1826,6 +1830,8 @@ class TosaArgGen: outputDTypeList = [DType.INT48] elif dtype == DType.FP16: outputDTypeList = [DType.FP16] + elif dtype == DType.BF16: + outputDTypeList = [DType.BF16] elif dtype == DType.FP32: outputDTypeList = [DType.FP32] elif error_name == ErrorIf.WrongInputType: |