aboutsummaryrefslogtreecommitdiff
path: root/verif/generator/tosa_arg_gen.py
diff options
context:
space:
mode:
Diffstat (limited to 'verif/generator/tosa_arg_gen.py')
-rw-r--r--verif/generator/tosa_arg_gen.py10
1 files changed, 8 insertions, 2 deletions
diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py
index 0203513..932ad55 100644
--- a/verif/generator/tosa_arg_gen.py
+++ b/verif/generator/tosa_arg_gen.py
@@ -776,7 +776,7 @@ class TosaTensorValuesGen:
), "Op.MUL must have 2 placeholders, 0 consts"
tens = []
- if dtypeList[0] in (DType.FP16, DType.FP32):
+ if dtypeList[0] in (DType.FP16, DType.BF16, DType.FP32):
tens.extend(testGen.buildPlaceholderTensors(shapeList[:], dtypeList[:]))
else:
placeholders = []
@@ -1130,6 +1130,8 @@ class TosaArgGen:
accum_dtypes = [DType.INT48]
elif dtype == DType.FP16:
accum_dtypes = [DType.FP16, DType.FP32]
+ elif dtype == DType.BF16:
+ accum_dtypes = [DType.FP32]
elif dtype == DType.FP32:
accum_dtypes = [DType.FP32]
elif error_name is None:
@@ -1304,7 +1306,7 @@ class TosaArgGen:
accum_dtypes = [DType.INT32]
elif dtype == DType.FP16:
accum_dtypes = [DType.FP16, DType.FP32]
- elif dtype == DType.FP32:
+ elif dtype == DType.BF16 or dtype == DType.FP32:
accum_dtypes = [DType.FP32]
elif error_name is None:
assert False, f"Invalid I/O DType for pooling: {DTypeNames[dtype]}"
@@ -1417,6 +1419,8 @@ class TosaArgGen:
dtypeList = [DType.INT8, DType.INT16, DType.INT32]
elif inDtype == DType.FP16:
dtypeList = [DType.INT8, DType.INT16, DType.INT32]
+ elif inDtype == DType.BF16:
+ dtypeList = [DType.INT8, DType.INT16, DType.INT32]
elif inDtype == DType.FP32:
dtypeList = [DType.INT8, DType.INT16, DType.INT32]
elif error_name == ErrorIf.WrongInputType:
@@ -1826,6 +1830,8 @@ class TosaArgGen:
outputDTypeList = [DType.INT48]
elif dtype == DType.FP16:
outputDTypeList = [DType.FP16]
+ elif dtype == DType.BF16:
+ outputDTypeList = [DType.BF16]
elif dtype == DType.FP32:
outputDTypeList = [DType.FP32]
elif error_name == ErrorIf.WrongInputType: