aboutsummaryrefslogtreecommitdiff
path: root/verif/generator/tosa_test_gen.py
diff options
context:
space:
mode:
Diffstat (limited to 'verif/generator/tosa_test_gen.py')
-rw-r--r--verif/generator/tosa_test_gen.py148
1 files changed, 59 insertions, 89 deletions
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py
index 9ff6ec5..78d86cd 100644
--- a/verif/generator/tosa_test_gen.py
+++ b/verif/generator/tosa_test_gen.py
@@ -13,6 +13,7 @@ from generator.tosa_error_if import ErrorIf
from generator.tosa_error_if import TosaErrorIfArgGen
from generator.tosa_error_if import TosaErrorValidator
from generator.tosa_error_if import TosaInvalidValidator
+from generator.tosa_utils import DTYPE_ATTRIBUTES
from generator.tosa_utils import MAX_RESIZE_DIMENSION
from generator.tosa_utils import usableDTypes
from tosa.DType import DType
@@ -83,7 +84,7 @@ class TosaTestGen:
)
elif dtype == DType.FP16:
return np.float16(self.rng.random(size=shape))
- elif dtype == DType.FLOAT:
+ elif dtype == DType.FP32:
return np.float32(self.rng.random(size=shape))
else:
raise Exception("Unrecognized Dtype: {}".format(dtype))
@@ -128,7 +129,7 @@ class TosaTestGen:
return np.int32(self.rng.integers(low=low, high=high, size=1))[0]
def getRandNumberDType(self, dtype):
- if dtype == DType.FLOAT:
+ if dtype == DType.FP32:
return self.rng.random()
elif dtype == DType.FP16:
rand_f32 = self.rng.random()
@@ -162,58 +163,26 @@ class TosaTestGen:
return "x".join(sStr)
- def typeStr(self, t):
- if isinstance(t, list):
- assert len(t) >= 2
- return "{}x{}".format(self.typeStr(t[0]), self.typeStr(t[1]))
+ def typeStr(self, dtype):
+ if isinstance(dtype, list) or isinstance(dtype, tuple):
+ assert len(dtype) >= 2
+ strs = [self.typeStr(t) for t in dtype]
+ # Limit types to the first 2 as the 3rd is the accumulator
+ return "x".join(strs[:2])
else:
- if t == DType.BOOL:
- return "b"
- elif t == DType.INT4:
- return "i4"
- elif t == DType.INT8:
- return "i8"
- elif t == DType.UINT8:
- return "u8"
- elif t == DType.INT16:
- return "i16"
- elif t == DType.UINT16:
- return "u16"
- elif t == DType.INT32:
- return "i32"
- elif t == DType.INT48:
- return "i48"
- elif t == DType.FP16:
- return "f16"
- elif t == DType.FLOAT:
- return "float"
+ if dtype in DTYPE_ATTRIBUTES:
+ return DTYPE_ATTRIBUTES[dtype]["str"]
else:
- raise Exception("Unknown dtype, cannot convert to string: {}".format(t))
+ raise Exception(
+ "Unknown dtype, cannot convert to string: {}".format(dtype)
+ )
- def typeWidth(self, t):
+ def typeWidth(self, dtype):
"""Get the datatype width for data types"""
- if t == DType.INT4:
- return 4
- elif t == DType.INT8:
- return 8
- elif t == DType.UINT8:
- return 8
- elif t == DType.INT16:
- return 16
- elif t == DType.UINT16:
- return 16
- elif t == DType.INT32:
- return 32
- elif t == DType.INT48:
- return 48
- elif t == DType.FP16:
- return 16
- elif t == DType.FLOAT:
- return 32
- elif t == DType.BOOL:
- return 1
+ if dtype in DTYPE_ATTRIBUTES:
+ return DTYPE_ATTRIBUTES[dtype]["width"]
else:
- raise Exception(f"Unknown dtype, cannot determine width: {t}")
+ raise Exception(f"Unknown dtype, cannot determine width: {dtype}")
# Argument generators
# Returns a list of tuples (stringDescriptor, [build_fcn_arg_list])
@@ -355,7 +324,7 @@ class TosaTestGen:
# Special for multiply:
# Force the result to INT32 for INT types
- if a.dtype not in (DType.FP16, DType.FLOAT):
+ if a.dtype not in (DType.FP16, DType.FP32):
result_tens.setDtype(DType.INT32)
if error_name == ErrorIf.WrongOutputType:
all_dtypes = [DType.INT8, DType.INT16, DType.INT48]
@@ -1074,7 +1043,7 @@ class TosaTestGen:
return None
attr = ts.TosaSerializerAttribute()
- if a.dtype in (DType.FP16, DType.FLOAT):
+ if a.dtype in (DType.FP16, DType.FP32):
attr.ClampAttribute(0, 0, min_val, max_val)
else:
attr.ClampAttribute(min_val, max_val, 0, 0)
@@ -1086,7 +1055,7 @@ class TosaTestGen:
result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
attr = ts.TosaSerializerAttribute()
- attr.LeakyReluAttribute(self.getRandNumberDType(DType.FLOAT))
+ attr.LeakyReluAttribute(self.getRandNumberDType(DType.FP32))
self.ser.addOperator(op["op"], [a.name], [result_tens.name], attr)
return result_tens
@@ -1890,7 +1859,7 @@ class TosaTestGen:
op["op"], [cond_tens.name, a.name, b.name], [result_tens.name], attr
)
- if a.dtype in (DType.FLOAT, DType.FP16, DType.INT32):
+ if a.dtype in (DType.FP32, DType.FP16, DType.INT32):
then_op, else_op = Op.ADD, Op.SUB
elif a.dtype in (DType.INT8, DType.INT16):
then_op, else_op = Op.LOGICAL_RIGHT_SHIFT, Op.LOGICAL_LEFT_SHIFT
@@ -2001,7 +1970,7 @@ class TosaTestGen:
if error_name == ErrorIf.CondGraphOutputNotMatchingBool:
cond_tens = self.ser.addOutput(
- [], self.rng.choice([DType.INT8, DType.INT32, DType.FLOAT])
+ [], self.rng.choice([DType.INT8, DType.INT32, DType.FP32])
)
else:
cond_tens = self.ser.addOutput([], DType.BOOL)
@@ -2429,7 +2398,7 @@ class TosaTestGen:
# if not specified, defaults to (1, 4)
# 'build_fcn': tuple of the function to (build_operator(), TensorGen function, ArgGen enum)
# 'types': array of datatypes to be tested
- TYPE_FP = [DType.FLOAT, DType.FP16]
+ TYPE_FP = [DType.FP32, DType.FP16]
TYPE_INT = [DType.INT8, DType.INT16, DType.INT32] # Excludes INT4
TYPE_INT_FP = [
@@ -2437,30 +2406,31 @@ class TosaTestGen:
DType.INT16,
DType.INT32,
DType.FP16,
- DType.FLOAT,
+ DType.FP32,
] # Excludes INT4
TYPE_BOOL = [DType.BOOL]
- TYPE_FI32 = [DType.FLOAT, DType.FP16, DType.INT32] # floating-types and INT32
+ TYPE_FI32 = [DType.FP32, DType.FP16, DType.INT32] # floating-types and INT32
TYPE_FIB = [
DType.FP16,
- DType.FLOAT,
+ DType.FP32,
DType.INT8,
DType.INT16,
DType.INT32,
DType.BOOL,
]
- TYPE_FI16 = [DType.FLOAT, DType.INT16]
+ TYPE_FI16 = [DType.FP32, DType.INT16]
- TYPE_NARROW_INT_FP = [DType.INT8, DType.INT16, DType.FP16, DType.FLOAT]
+ TYPE_NARROW_INT_FP = [DType.INT8, DType.INT16, DType.FP16, DType.FP32]
+ # List of [Input Type 1, Input Type 2, Accumulator Type]
TYPE_CONV = [
[DType.INT8, DType.INT4, DType.INT32],
[DType.INT8, DType.INT8, DType.INT32],
[DType.INT16, DType.INT8, DType.INT48],
[DType.FP16, DType.FP16, DType.FP16],
- [DType.FP16, DType.FP16, DType.FLOAT],
- DType.FLOAT,
+ [DType.FP16, DType.FP16, DType.FP32],
+ [DType.FP32, DType.FP32, DType.FP32],
]
DEFAULT_RANK_RANGE = (1, TOSA_TENSOR_MAX_RANK)
@@ -3478,7 +3448,7 @@ class TosaTestGen:
TosaTensorValuesGen.tvgReduceSum,
TosaArgGen.agAxis,
),
- "types": (DType.FP16, DType.FLOAT, DType.INT32),
+ "types": (DType.FP16, DType.FP32, DType.INT32),
"error_if_validators": (
TosaErrorValidator.evAxisLargerRank,
TosaErrorValidator.evAxisSmallerZero,
@@ -3665,7 +3635,7 @@ class TosaTestGen:
TosaTensorValuesGen.tvgDefault,
None,
),
- "types": (DType.INT8, DType.INT16, DType.INT32, DType.FP16, DType.FLOAT),
+ "types": (DType.INT8, DType.INT16, DType.INT32, DType.FP16, DType.FP32),
"error_if_validators": (
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
@@ -3706,7 +3676,7 @@ class TosaTestGen:
TosaTensorValuesGen.tvgDefault,
TosaArgGen.agResize,
),
- "types": (DType.INT8, DType.INT16, DType.FP16, DType.FLOAT),
+ "types": (DType.INT8, DType.INT16, DType.FP16, DType.FP32),
"invalid_test_validators": (
TosaInvalidValidator.ivWrongDataTypeOrModeResize,
),
@@ -3742,7 +3712,7 @@ class TosaTestGen:
),
"types": (
DType.FP16,
- DType.FLOAT,
+ DType.FP32,
DType.INT8,
DType.INT16,
DType.INT32,
@@ -3872,7 +3842,7 @@ class OutputShaper:
DType.INT16,
DType.INT32,
DType.INT48,
- DType.FLOAT,
+ DType.FP32,
]
wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
outputDType = rng.choice(wrong_dtypes)
@@ -3901,7 +3871,7 @@ class OutputShaper:
DType.INT16,
DType.INT32,
DType.INT48,
- DType.FLOAT,
+ DType.FP32,
]
wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
outputDType = rng.choice(wrong_dtypes)
@@ -3929,7 +3899,7 @@ class OutputShaper:
DType.INT16,
DType.INT32,
DType.INT48,
- DType.FLOAT,
+ DType.FP32,
]
wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
outputDType = rng.choice(wrong_dtypes)
@@ -3958,7 +3928,7 @@ class OutputShaper:
DType.INT16,
DType.INT32,
DType.INT48,
- DType.FLOAT,
+ DType.FP32,
]
outputDType = rng.choice(wrong_dtypes)
else:
@@ -3984,7 +3954,7 @@ class OutputShaper:
DType.INT16,
DType.INT32,
DType.INT48,
- DType.FLOAT,
+ DType.FP32,
]
wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
outputDType = rng.choice(wrong_dtypes)
@@ -4016,7 +3986,7 @@ class OutputShaper:
DType.INT16,
DType.INT32,
DType.INT48,
- DType.FLOAT,
+ DType.FP32,
]
wrong_dtypes = list(set(all_dtypes) - set([DType.INT32]))
outputDType = rng.choice(wrong_dtypes)
@@ -4069,7 +4039,7 @@ class OutputShaper:
if error_name == ErrorIf.WrongOutputType:
if ifm.dtype == DType.FP16:
- excludes = [DType.FP16, DType.FLOAT]
+ excludes = [DType.FP16, DType.FP32]
else:
excludes = [out_dtype]
wrong_dtypes = list(usableDTypes(excludes=excludes))
@@ -4131,7 +4101,7 @@ class OutputShaper:
if error_name == ErrorIf.WrongOutputType:
if ifm.dtype == DType.FP16:
- excludes = [DType.FP16, DType.FLOAT]
+ excludes = [DType.FP16, DType.FP32]
else:
excludes = [out_dtype]
wrong_dtypes = list(usableDTypes(excludes=excludes))
@@ -4182,7 +4152,7 @@ class OutputShaper:
if error_name == ErrorIf.WrongOutputType:
if ifm.dtype == DType.FP16:
- excludes = [DType.FP16, DType.FLOAT]
+ excludes = [DType.FP16, DType.FP32]
else:
excludes = [out_dtype]
wrong_dtypes = list(usableDTypes(excludes=excludes))
@@ -4217,7 +4187,7 @@ class OutputShaper:
DType.INT16,
DType.INT32,
DType.INT48,
- DType.FLOAT,
+ DType.FP32,
DType.FP16,
]
wrong_dtypes = list(set(all_dtypes) - set([ifm.dtype]))
@@ -4255,7 +4225,7 @@ class OutputShaper:
DType.INT8,
DType.INT16,
DType.INT48,
- DType.FLOAT,
+ DType.FP32,
)
elif a.dtype == DType.INT16:
incorrect_types = (
@@ -4263,9 +4233,9 @@ class OutputShaper:
DType.INT8,
DType.INT16,
DType.INT32,
- DType.FLOAT,
+ DType.FP32,
)
- elif a.dtype == DType.FLOAT or a.dtype == DType.FP16:
+ elif a.dtype == DType.FP32 or a.dtype == DType.FP16:
incorrect_types = (
DType.INT4,
DType.INT8,
@@ -4307,7 +4277,7 @@ class OutputShaper:
DType.INT16,
DType.INT32,
DType.INT48,
- DType.FLOAT,
+ DType.FP32,
}
wrong_dtypes = list(all_dtypes - set([input1.dtype]))
outputDType = rng.choice(wrong_dtypes)
@@ -4334,7 +4304,7 @@ class OutputShaper:
DType.INT16,
DType.INT32,
DType.INT48,
- DType.FLOAT,
+ DType.FP32,
DType.FP16,
]
wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
@@ -4358,7 +4328,7 @@ class OutputShaper:
DType.INT16,
DType.INT32,
DType.INT48,
- DType.FLOAT,
+ DType.FP32,
]
wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
outputDType = rng.choice(wrong_dtypes)
@@ -4376,7 +4346,7 @@ class OutputShaper:
DType.INT16,
DType.INT32,
DType.INT48,
- DType.FLOAT,
+ DType.FP32,
]
wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
outputDType = rng.choice(wrong_dtypes)
@@ -4412,7 +4382,7 @@ class OutputShaper:
DType.INT16,
DType.INT32,
DType.INT48,
- DType.FLOAT,
+ DType.FP32,
]
wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
outputDType = rng.choice(wrong_dtypes)
@@ -4440,7 +4410,7 @@ class OutputShaper:
DType.INT16,
DType.INT32,
DType.INT48,
- DType.FLOAT,
+ DType.FP32,
]
wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
outputDType = rng.choice(wrong_dtypes)
@@ -4464,7 +4434,7 @@ class OutputShaper:
DType.INT16,
DType.INT32,
DType.INT48,
- DType.FLOAT,
+ DType.FP32,
]
wrong_dtypes = list(set(all_dtypes) - set([values.dtype]))
outputDType = rng.choice(wrong_dtypes)
@@ -4491,7 +4461,7 @@ class OutputShaper:
DType.INT16,
DType.INT32,
DType.INT48,
- DType.FLOAT,
+ DType.FP32,
]
wrong_dtypes = list(set(all_dtypes) - set([values_in.dtype]))
outputDType = rng.choice(wrong_dtypes)
@@ -4512,7 +4482,7 @@ class OutputShaper:
DType.INT16,
DType.INT32,
DType.INT48,
- DType.FLOAT,
+ DType.FP32,
]
wrong_dtypes.remove(output_dtype)
output_dtype = rng.choice(wrong_dtypes)
@@ -4619,7 +4589,7 @@ class OutputShaper:
if error_name == ErrorIf.WrongOutputType:
if ifm.dtype == DType.FP16:
- excludes = [DType.FP16, DType.FLOAT]
+ excludes = [DType.FP16, DType.FP32]
else:
excludes = [out_dtype]
wrong_dtypes = list(usableDTypes(excludes=excludes))