diff options
author | Jeremy Johnson <jeremy.johnson@arm.com> | 2022-09-27 13:50:00 +0100 |
---|---|---|
committer | Jeremy Johnson <jeremy.johnson@arm.com> | 2022-10-13 18:21:15 +0100 |
commit | bc2a3db54ecee48fe2236f7fc03da8fd07d81ca0 (patch) | |
tree | c3908f23c369fd3226e840f81c3ba4b49cc409a0 /verif/generator/tosa_utils.py | |
parent | 93d4390f9aa5c4369f889e1cd336aa4e809ff6a7 (diff) | |
download | reference_model-bc2a3db54ecee48fe2236f7fc03da8fd07d81ca0.tar.gz |
Rename FLOAT type to FP32
Update tensor operations naming to state input type as TxT in
all cases. Effects CONV2D, CONV3D, DEPTHWISE_CONV2D,
FULLY_CONNECTED, TRANSPOSE_CONV2D.
Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com>
Change-Id: Ic959acfcb3aa0a910b33b774a5a85fac08219205
Diffstat (limited to 'verif/generator/tosa_utils.py')
-rw-r--r-- | verif/generator/tosa_utils.py | 25 |
1 files changed, 18 insertions, 7 deletions
diff --git a/verif/generator/tosa_utils.py b/verif/generator/tosa_utils.py index 7fa31e7..104d9bb 100644 --- a/verif/generator/tosa_utils.py +++ b/verif/generator/tosa_utils.py @@ -5,6 +5,19 @@ from tosa.DType import DType # Maximum dimension size for output and inputs for RESIZE MAX_RESIZE_DIMENSION = 16384 +DTYPE_ATTRIBUTES = { + DType.BOOL: {"str": "b", "width": 1}, + DType.INT4: {"str": "i4", "width": 4}, + DType.INT8: {"str": "i8", "width": 8}, + DType.UINT8: {"str": "u8", "width": 8}, + DType.INT16: {"str": "i16", "width": 16}, + DType.UINT16: {"str": "u16", "width": 16}, + DType.INT32: {"str": "i32", "width": 32}, + DType.INT48: {"str": "i48", "width": 48}, + DType.FP16: {"str": "f16", "width": 16}, + DType.FP32: {"str": "f32", "width": 32}, +} + def valueToName(item, value): """Get the name of an attribute with the given value. @@ -88,10 +101,8 @@ def product(shape): def get_accum_dtype_from_tgTypes(dtypes): # Get accumulate data-type from the test generator's defined types - if isinstance(dtypes, list) or isinstance(dtypes, tuple): - return dtypes[-1] - else: - return dtypes + assert isinstance(dtypes, list) or isinstance(dtypes, tuple) + return dtypes[-1] def get_wrong_output_type(op_name, rng, input_dtype): @@ -102,7 +113,7 @@ def get_wrong_output_type(op_name, rng, input_dtype): DType.INT8, DType.INT16, DType.INT48, - DType.FLOAT, + DType.FP32, DType.FP16, ) elif input_dtype == DType.INT16: @@ -111,10 +122,10 @@ def get_wrong_output_type(op_name, rng, input_dtype): DType.INT8, DType.INT16, DType.INT32, - DType.FLOAT, + DType.FP32, DType.FP16, ) - elif input_dtype == DType.FLOAT or input_dtype == DType.FP16: + elif input_dtype == DType.FP32 or input_dtype == DType.FP16: incorrect_types = ( DType.INT4, DType.INT8, |