aboutsummaryrefslogtreecommitdiff
path: root/verif/generator/tosa_utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'verif/generator/tosa_utils.py')
-rw-r--r--verif/generator/tosa_utils.py25
1 files changed, 18 insertions, 7 deletions
diff --git a/verif/generator/tosa_utils.py b/verif/generator/tosa_utils.py
index 7fa31e7..104d9bb 100644
--- a/verif/generator/tosa_utils.py
+++ b/verif/generator/tosa_utils.py
@@ -5,6 +5,19 @@ from tosa.DType import DType
# Maximum dimension size for output and inputs for RESIZE
MAX_RESIZE_DIMENSION = 16384
+DTYPE_ATTRIBUTES = {
+ DType.BOOL: {"str": "b", "width": 1},
+ DType.INT4: {"str": "i4", "width": 4},
+ DType.INT8: {"str": "i8", "width": 8},
+ DType.UINT8: {"str": "u8", "width": 8},
+ DType.INT16: {"str": "i16", "width": 16},
+ DType.UINT16: {"str": "u16", "width": 16},
+ DType.INT32: {"str": "i32", "width": 32},
+ DType.INT48: {"str": "i48", "width": 48},
+ DType.FP16: {"str": "f16", "width": 16},
+ DType.FP32: {"str": "f32", "width": 32},
+}
+
def valueToName(item, value):
"""Get the name of an attribute with the given value.
@@ -88,10 +101,8 @@ def product(shape):
def get_accum_dtype_from_tgTypes(dtypes):
# Get accumulate data-type from the test generator's defined types
- if isinstance(dtypes, list) or isinstance(dtypes, tuple):
- return dtypes[-1]
- else:
- return dtypes
+ assert isinstance(dtypes, list) or isinstance(dtypes, tuple)
+ return dtypes[-1]
def get_wrong_output_type(op_name, rng, input_dtype):
@@ -102,7 +113,7 @@ def get_wrong_output_type(op_name, rng, input_dtype):
DType.INT8,
DType.INT16,
DType.INT48,
- DType.FLOAT,
+ DType.FP32,
DType.FP16,
)
elif input_dtype == DType.INT16:
@@ -111,10 +122,10 @@ def get_wrong_output_type(op_name, rng, input_dtype):
DType.INT8,
DType.INT16,
DType.INT32,
- DType.FLOAT,
+ DType.FP32,
DType.FP16,
)
- elif input_dtype == DType.FLOAT or input_dtype == DType.FP16:
+ elif input_dtype == DType.FP32 or input_dtype == DType.FP16:
incorrect_types = (
DType.INT4,
DType.INT8,