aboutsummaryrefslogtreecommitdiff
path: root/verif/generator/tosa_test_gen.py
diff options
context:
space:
mode:
Diffstat (limited to 'verif/generator/tosa_test_gen.py')
-rw-r--r--verif/generator/tosa_test_gen.py80
1 files changed, 70 insertions, 10 deletions
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py
index 78d86cd..95e06ed 100644
--- a/verif/generator/tosa_test_gen.py
+++ b/verif/generator/tosa_test_gen.py
@@ -16,6 +16,7 @@ from generator.tosa_error_if import TosaInvalidValidator
from generator.tosa_utils import DTYPE_ATTRIBUTES
from generator.tosa_utils import MAX_RESIZE_DIMENSION
from generator.tosa_utils import usableDTypes
+from generator.tosa_utils import vect_f32_to_bf16
from tosa.DType import DType
from tosa.Op import Op
@@ -84,6 +85,10 @@ class TosaTestGen:
)
elif dtype == DType.FP16:
return np.float16(self.rng.random(size=shape))
+ elif dtype == DType.BF16:
+ f32_tensor = np.float32(self.rng.random(size=shape))
+ # Floor the last 16 bits of each f32 value
+ return np.float32(vect_f32_to_bf16(f32_tensor))
elif dtype == DType.FP32:
return np.float32(self.rng.random(size=shape))
else:
@@ -134,6 +139,9 @@ class TosaTestGen:
elif dtype == DType.FP16:
rand_f32 = self.rng.random()
return np.float16(rand_f32)
+ elif dtype == DType.BF16:
+ rand_f32 = self.rng.random()
+ return vect_f32_to_bf16(rand_f32)
elif dtype == DType.BOOL:
return self.rng.choice([False, True])
# TOSA specific INT4 weight range from -7 to 7
@@ -324,7 +332,7 @@ class TosaTestGen:
# Special for multiply:
# Force the result to INT32 for INT types
- if a.dtype not in (DType.FP16, DType.FP32):
+ if a.dtype not in (DType.FP16, DType.BF16, DType.FP32):
result_tens.setDtype(DType.INT32)
if error_name == ErrorIf.WrongOutputType:
all_dtypes = [DType.INT8, DType.INT16, DType.INT48]
@@ -1043,7 +1051,7 @@ class TosaTestGen:
return None
attr = ts.TosaSerializerAttribute()
- if a.dtype in (DType.FP16, DType.FP32):
+ if a.dtype in (DType.FP16, DType.BF16, DType.FP32):
attr.ClampAttribute(0, 0, min_val, max_val)
else:
attr.ClampAttribute(min_val, max_val, 0, 0)
@@ -1859,7 +1867,7 @@ class TosaTestGen:
op["op"], [cond_tens.name, a.name, b.name], [result_tens.name], attr
)
- if a.dtype in (DType.FP32, DType.FP16, DType.INT32):
+ if a.dtype in (DType.FP32, DType.BF16, DType.FP16, DType.INT32):
then_op, else_op = Op.ADD, Op.SUB
elif a.dtype in (DType.INT8, DType.INT16):
then_op, else_op = Op.LOGICAL_RIGHT_SHIFT, Op.LOGICAL_LEFT_SHIFT
@@ -2398,7 +2406,7 @@ class TosaTestGen:
# if not specified, defaults to (1, 4)
# 'build_fcn': tuple of the function to (build_operator(), TensorGen function, ArgGen enum)
# 'types': array of datatypes to be tested
- TYPE_FP = [DType.FP32, DType.FP16]
+ TYPE_FP = [DType.FP32, DType.FP16, DType.BF16]
TYPE_INT = [DType.INT8, DType.INT16, DType.INT32] # Excludes INT4
TYPE_INT_FP = [
@@ -2406,13 +2414,20 @@ class TosaTestGen:
DType.INT16,
DType.INT32,
DType.FP16,
+ DType.BF16,
DType.FP32,
] # Excludes INT4
TYPE_BOOL = [DType.BOOL]
- TYPE_FI32 = [DType.FP32, DType.FP16, DType.INT32] # floating-types and INT32
+ TYPE_FI32 = [
+ DType.FP32,
+ DType.FP16,
+ DType.BF16,
+ DType.INT32,
+ ] # floating-types and INT32
TYPE_FIB = [
DType.FP16,
+ DType.BF16,
DType.FP32,
DType.INT8,
DType.INT16,
@@ -2421,7 +2436,7 @@ class TosaTestGen:
]
TYPE_FI16 = [DType.FP32, DType.INT16]
- TYPE_NARROW_INT_FP = [DType.INT8, DType.INT16, DType.FP16, DType.FP32]
+ TYPE_NARROW_INT_FP = [DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32]
# List of [Input Type 1, Input Type 2, Accumulator Type]
TYPE_CONV = [
@@ -2430,6 +2445,7 @@ class TosaTestGen:
[DType.INT16, DType.INT8, DType.INT48],
[DType.FP16, DType.FP16, DType.FP16],
[DType.FP16, DType.FP16, DType.FP32],
+ [DType.BF16, DType.BF16, DType.FP32],
[DType.FP32, DType.FP32, DType.FP32],
]
@@ -3448,7 +3464,7 @@ class TosaTestGen:
TosaTensorValuesGen.tvgReduceSum,
TosaArgGen.agAxis,
),
- "types": (DType.FP16, DType.FP32, DType.INT32),
+ "types": (DType.FP16, DType.BF16, DType.FP32, DType.INT32),
"error_if_validators": (
TosaErrorValidator.evAxisLargerRank,
TosaErrorValidator.evAxisSmallerZero,
@@ -3635,7 +3651,14 @@ class TosaTestGen:
TosaTensorValuesGen.tvgDefault,
None,
),
- "types": (DType.INT8, DType.INT16, DType.INT32, DType.FP16, DType.FP32),
+ "types": (
+ DType.INT8,
+ DType.INT16,
+ DType.INT32,
+ DType.FP16,
+ DType.BF16,
+ DType.FP32,
+ ),
"error_if_validators": (
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
@@ -3676,7 +3699,7 @@ class TosaTestGen:
TosaTensorValuesGen.tvgDefault,
TosaArgGen.agResize,
),
- "types": (DType.INT8, DType.INT16, DType.FP16, DType.FP32),
+ "types": (DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32),
"invalid_test_validators": (
TosaInvalidValidator.ivWrongDataTypeOrModeResize,
),
@@ -3712,6 +3735,7 @@ class TosaTestGen:
),
"types": (
DType.FP16,
+ DType.BF16,
DType.FP32,
DType.INT8,
DType.INT16,
@@ -3842,6 +3866,8 @@ class OutputShaper:
DType.INT16,
DType.INT32,
DType.INT48,
+ DType.FP16,
+ DType.BF16,
DType.FP32,
]
wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
@@ -3872,6 +3898,8 @@ class OutputShaper:
DType.INT32,
DType.INT48,
DType.FP32,
+ DType.FP16,
+ DType.BF16,
]
wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
outputDType = rng.choice(wrong_dtypes)
@@ -3900,6 +3928,8 @@ class OutputShaper:
DType.INT32,
DType.INT48,
DType.FP32,
+ DType.FP16,
+ DType.BF16,
]
wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
outputDType = rng.choice(wrong_dtypes)
@@ -3929,6 +3959,8 @@ class OutputShaper:
DType.INT32,
DType.INT48,
DType.FP32,
+ DType.FP16,
+ DType.BF16,
]
outputDType = rng.choice(wrong_dtypes)
else:
@@ -3955,6 +3987,8 @@ class OutputShaper:
DType.INT32,
DType.INT48,
DType.FP32,
+ DType.FP16,
+ DType.BF16,
]
wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
outputDType = rng.choice(wrong_dtypes)
@@ -3987,6 +4021,8 @@ class OutputShaper:
DType.INT32,
DType.INT48,
DType.FP32,
+ DType.FP16,
+ DType.BF16,
]
wrong_dtypes = list(set(all_dtypes) - set([DType.INT32]))
outputDType = rng.choice(wrong_dtypes)
@@ -4189,6 +4225,7 @@ class OutputShaper:
DType.INT48,
DType.FP32,
DType.FP16,
+ DType.BF16,
]
wrong_dtypes = list(set(all_dtypes) - set([ifm.dtype]))
outputDType = rng.choice(wrong_dtypes)
@@ -4226,6 +4263,8 @@ class OutputShaper:
DType.INT16,
DType.INT48,
DType.FP32,
+ DType.FP16,
+ DType.BF16,
)
elif a.dtype == DType.INT16:
incorrect_types = (
@@ -4234,8 +4273,12 @@ class OutputShaper:
DType.INT16,
DType.INT32,
DType.FP32,
+ DType.FP16,
+ DType.BF16,
)
- elif a.dtype == DType.FP32 or a.dtype == DType.FP16:
+ elif (
+ a.dtype == DType.FP32 or a.dtype == DType.FP16 or a.dtype == DType.BF16
+ ):
incorrect_types = (
DType.INT4,
DType.INT8,
@@ -4278,6 +4321,8 @@ class OutputShaper:
DType.INT32,
DType.INT48,
DType.FP32,
+ DType.FP16,
+ DType.BF16,
}
wrong_dtypes = list(all_dtypes - set([input1.dtype]))
outputDType = rng.choice(wrong_dtypes)
@@ -4306,6 +4351,7 @@ class OutputShaper:
DType.INT48,
DType.FP32,
DType.FP16,
+ DType.BF16,
]
wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
outputDType = rng.choice(wrong_dtypes)
@@ -4329,6 +4375,8 @@ class OutputShaper:
DType.INT32,
DType.INT48,
DType.FP32,
+ DType.FP16,
+ DType.BF16,
]
wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
outputDType = rng.choice(wrong_dtypes)
@@ -4347,6 +4395,8 @@ class OutputShaper:
DType.INT32,
DType.INT48,
DType.FP32,
+ DType.FP16,
+ DType.BF16,
]
wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
outputDType = rng.choice(wrong_dtypes)
@@ -4383,6 +4433,8 @@ class OutputShaper:
DType.INT32,
DType.INT48,
DType.FP32,
+ DType.FP16,
+ DType.BF16,
]
wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
outputDType = rng.choice(wrong_dtypes)
@@ -4411,6 +4463,8 @@ class OutputShaper:
DType.INT32,
DType.INT48,
DType.FP32,
+ DType.FP16,
+ DType.BF16,
]
wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
outputDType = rng.choice(wrong_dtypes)
@@ -4435,6 +4489,8 @@ class OutputShaper:
DType.INT32,
DType.INT48,
DType.FP32,
+ DType.FP16,
+ DType.BF16,
]
wrong_dtypes = list(set(all_dtypes) - set([values.dtype]))
outputDType = rng.choice(wrong_dtypes)
@@ -4462,6 +4518,8 @@ class OutputShaper:
DType.INT32,
DType.INT48,
DType.FP32,
+ DType.FP16,
+ DType.BF16,
]
wrong_dtypes = list(set(all_dtypes) - set([values_in.dtype]))
outputDType = rng.choice(wrong_dtypes)
@@ -4483,6 +4541,8 @@ class OutputShaper:
DType.INT32,
DType.INT48,
DType.FP32,
+ DType.FP16,
+ DType.BF16,
]
wrong_dtypes.remove(output_dtype)
output_dtype = rng.choice(wrong_dtypes)