From bc2a3db54ecee48fe2236f7fc03da8fd07d81ca0 Mon Sep 17 00:00:00 2001 From: Jeremy Johnson Date: Tue, 27 Sep 2022 13:50:00 +0100 Subject: Rename FLOAT type to FP32 Update tensor operations naming to state input type as TxT in all cases. Effects CONV2D, CONV3D, DEPTHWISE_CONV2D, FULLY_CONNECTED, TRANSPOSE_CONV2D. Signed-off-by: Jeremy Johnson Change-Id: Ic959acfcb3aa0a910b33b774a5a85fac08219205 --- verif/generator/tosa_test_gen.py | 148 ++++++++++++++++----------------------- 1 file changed, 59 insertions(+), 89 deletions(-) (limited to 'verif/generator/tosa_test_gen.py') diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py index 9ff6ec5..78d86cd 100644 --- a/verif/generator/tosa_test_gen.py +++ b/verif/generator/tosa_test_gen.py @@ -13,6 +13,7 @@ from generator.tosa_error_if import ErrorIf from generator.tosa_error_if import TosaErrorIfArgGen from generator.tosa_error_if import TosaErrorValidator from generator.tosa_error_if import TosaInvalidValidator +from generator.tosa_utils import DTYPE_ATTRIBUTES from generator.tosa_utils import MAX_RESIZE_DIMENSION from generator.tosa_utils import usableDTypes from tosa.DType import DType @@ -83,7 +84,7 @@ class TosaTestGen: ) elif dtype == DType.FP16: return np.float16(self.rng.random(size=shape)) - elif dtype == DType.FLOAT: + elif dtype == DType.FP32: return np.float32(self.rng.random(size=shape)) else: raise Exception("Unrecognized Dtype: {}".format(dtype)) @@ -128,7 +129,7 @@ class TosaTestGen: return np.int32(self.rng.integers(low=low, high=high, size=1))[0] def getRandNumberDType(self, dtype): - if dtype == DType.FLOAT: + if dtype == DType.FP32: return self.rng.random() elif dtype == DType.FP16: rand_f32 = self.rng.random() @@ -162,58 +163,26 @@ class TosaTestGen: return "x".join(sStr) - def typeStr(self, t): - if isinstance(t, list): - assert len(t) >= 2 - return "{}x{}".format(self.typeStr(t[0]), self.typeStr(t[1])) + def typeStr(self, dtype): + if isinstance(dtype, list) or isinstance(dtype, tuple): + assert len(dtype) >= 2 + strs = [self.typeStr(t) for t in dtype] + # Limit types to the first 2 as the 3rd is the accumulator + return "x".join(strs[:2]) else: - if t == DType.BOOL: - return "b" - elif t == DType.INT4: - return "i4" - elif t == DType.INT8: - return "i8" - elif t == DType.UINT8: - return "u8" - elif t == DType.INT16: - return "i16" - elif t == DType.UINT16: - return "u16" - elif t == DType.INT32: - return "i32" - elif t == DType.INT48: - return "i48" - elif t == DType.FP16: - return "f16" - elif t == DType.FLOAT: - return "float" + if dtype in DTYPE_ATTRIBUTES: + return DTYPE_ATTRIBUTES[dtype]["str"] else: - raise Exception("Unknown dtype, cannot convert to string: {}".format(t)) + raise Exception( + "Unknown dtype, cannot convert to string: {}".format(dtype) + ) - def typeWidth(self, t): + def typeWidth(self, dtype): """Get the datatype width for data types""" - if t == DType.INT4: - return 4 - elif t == DType.INT8: - return 8 - elif t == DType.UINT8: - return 8 - elif t == DType.INT16: - return 16 - elif t == DType.UINT16: - return 16 - elif t == DType.INT32: - return 32 - elif t == DType.INT48: - return 48 - elif t == DType.FP16: - return 16 - elif t == DType.FLOAT: - return 32 - elif t == DType.BOOL: - return 1 + if dtype in DTYPE_ATTRIBUTES: + return DTYPE_ATTRIBUTES[dtype]["width"] else: - raise Exception(f"Unknown dtype, cannot determine width: {t}") + raise Exception(f"Unknown dtype, cannot determine width: {dtype}") # Argument generators # Returns a list of tuples (stringDescriptor, [build_fcn_arg_list]) @@ -355,7 +324,7 @@ class TosaTestGen: # Special for multiply: # Force the result to INT32 for INT types - if a.dtype not in (DType.FP16, DType.FLOAT): + if a.dtype not in (DType.FP16, DType.FP32): result_tens.setDtype(DType.INT32) if error_name == ErrorIf.WrongOutputType: all_dtypes = [DType.INT8, DType.INT16, DType.INT48] @@ -1074,7 +1043,7 @@ class TosaTestGen: return None attr = ts.TosaSerializerAttribute() - if a.dtype in (DType.FP16, DType.FLOAT): + if a.dtype in (DType.FP16, DType.FP32): attr.ClampAttribute(0, 0, min_val, max_val) else: attr.ClampAttribute(min_val, max_val, 0, 0) @@ -1086,7 +1055,7 @@ class TosaTestGen: result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name) attr = ts.TosaSerializerAttribute() - attr.LeakyReluAttribute(self.getRandNumberDType(DType.FLOAT)) + attr.LeakyReluAttribute(self.getRandNumberDType(DType.FP32)) self.ser.addOperator(op["op"], [a.name], [result_tens.name], attr) return result_tens @@ -1890,7 +1859,7 @@ class TosaTestGen: op["op"], [cond_tens.name, a.name, b.name], [result_tens.name], attr ) - if a.dtype in (DType.FLOAT, DType.FP16, DType.INT32): + if a.dtype in (DType.FP32, DType.FP16, DType.INT32): then_op, else_op = Op.ADD, Op.SUB elif a.dtype in (DType.INT8, DType.INT16): then_op, else_op = Op.LOGICAL_RIGHT_SHIFT, Op.LOGICAL_LEFT_SHIFT @@ -2001,7 +1970,7 @@ class TosaTestGen: if error_name == ErrorIf.CondGraphOutputNotMatchingBool: cond_tens = self.ser.addOutput( - [], self.rng.choice([DType.INT8, DType.INT32, DType.FLOAT]) + [], self.rng.choice([DType.INT8, DType.INT32, DType.FP32]) ) else: cond_tens = self.ser.addOutput([], DType.BOOL) @@ -2429,7 +2398,7 @@ class TosaTestGen: # if not specified, defaults to (1, 4) # 'build_fcn': tuple of the function to (build_operator(), TensorGen function, ArgGen enum) # 'types': array of datatypes to be tested - TYPE_FP = [DType.FLOAT, DType.FP16] + TYPE_FP = [DType.FP32, DType.FP16] TYPE_INT = [DType.INT8, DType.INT16, DType.INT32] # Excludes INT4 TYPE_INT_FP = [ @@ -2437,30 +2406,31 @@ class TosaTestGen: DType.INT16, DType.INT32, DType.FP16, - DType.FLOAT, + DType.FP32, ] # Excludes INT4 TYPE_BOOL = [DType.BOOL] - TYPE_FI32 = [DType.FLOAT, DType.FP16, DType.INT32] # floating-types and INT32 + TYPE_FI32 = [DType.FP32, DType.FP16, DType.INT32] # floating-types and INT32 TYPE_FIB = [ DType.FP16, - DType.FLOAT, + DType.FP32, DType.INT8, DType.INT16, DType.INT32, DType.BOOL, ] - TYPE_FI16 = [DType.FLOAT, DType.INT16] + TYPE_FI16 = [DType.FP32, DType.INT16] - TYPE_NARROW_INT_FP = [DType.INT8, DType.INT16, DType.FP16, DType.FLOAT] + TYPE_NARROW_INT_FP = [DType.INT8, DType.INT16, DType.FP16, DType.FP32] + # List of [Input Type 1, Input Type 2, Accumulator Type] TYPE_CONV = [ [DType.INT8, DType.INT4, DType.INT32], [DType.INT8, DType.INT8, DType.INT32], [DType.INT16, DType.INT8, DType.INT48], [DType.FP16, DType.FP16, DType.FP16], - [DType.FP16, DType.FP16, DType.FLOAT], - DType.FLOAT, + [DType.FP16, DType.FP16, DType.FP32], + [DType.FP32, DType.FP32, DType.FP32], ] DEFAULT_RANK_RANGE = (1, TOSA_TENSOR_MAX_RANK) @@ -3478,7 +3448,7 @@ class TosaTestGen: TosaTensorValuesGen.tvgReduceSum, TosaArgGen.agAxis, ), - "types": (DType.FP16, DType.FLOAT, DType.INT32), + "types": (DType.FP16, DType.FP32, DType.INT32), "error_if_validators": ( TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evAxisSmallerZero, @@ -3665,7 +3635,7 @@ class TosaTestGen: TosaTensorValuesGen.tvgDefault, None, ), - "types": (DType.INT8, DType.INT16, DType.INT32, DType.FP16, DType.FLOAT), + "types": (DType.INT8, DType.INT16, DType.INT32, DType.FP16, DType.FP32), "error_if_validators": ( TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, @@ -3706,7 +3676,7 @@ class TosaTestGen: TosaTensorValuesGen.tvgDefault, TosaArgGen.agResize, ), - "types": (DType.INT8, DType.INT16, DType.FP16, DType.FLOAT), + "types": (DType.INT8, DType.INT16, DType.FP16, DType.FP32), "invalid_test_validators": ( TosaInvalidValidator.ivWrongDataTypeOrModeResize, ), @@ -3742,7 +3712,7 @@ class TosaTestGen: ), "types": ( DType.FP16, - DType.FLOAT, + DType.FP32, DType.INT8, DType.INT16, DType.INT32, @@ -3872,7 +3842,7 @@ class OutputShaper: DType.INT16, DType.INT32, DType.INT48, - DType.FLOAT, + DType.FP32, ] wrong_dtypes = list(set(all_dtypes) - set([a.dtype])) outputDType = rng.choice(wrong_dtypes) @@ -3901,7 +3871,7 @@ class OutputShaper: DType.INT16, DType.INT32, DType.INT48, - DType.FLOAT, + DType.FP32, ] wrong_dtypes = list(set(all_dtypes) - set([a.dtype])) outputDType = rng.choice(wrong_dtypes) @@ -3929,7 +3899,7 @@ class OutputShaper: DType.INT16, DType.INT32, DType.INT48, - DType.FLOAT, + DType.FP32, ] wrong_dtypes = list(set(all_dtypes) - set([a.dtype])) outputDType = rng.choice(wrong_dtypes) @@ -3958,7 +3928,7 @@ class OutputShaper: DType.INT16, DType.INT32, DType.INT48, - DType.FLOAT, + DType.FP32, ] outputDType = rng.choice(wrong_dtypes) else: @@ -3984,7 +3954,7 @@ class OutputShaper: DType.INT16, DType.INT32, DType.INT48, - DType.FLOAT, + DType.FP32, ] wrong_dtypes = list(set(all_dtypes) - set([a.dtype])) outputDType = rng.choice(wrong_dtypes) @@ -4016,7 +3986,7 @@ class OutputShaper: DType.INT16, DType.INT32, DType.INT48, - DType.FLOAT, + DType.FP32, ] wrong_dtypes = list(set(all_dtypes) - set([DType.INT32])) outputDType = rng.choice(wrong_dtypes) @@ -4069,7 +4039,7 @@ class OutputShaper: if error_name == ErrorIf.WrongOutputType: if ifm.dtype == DType.FP16: - excludes = [DType.FP16, DType.FLOAT] + excludes = [DType.FP16, DType.FP32] else: excludes = [out_dtype] wrong_dtypes = list(usableDTypes(excludes=excludes)) @@ -4131,7 +4101,7 @@ class OutputShaper: if error_name == ErrorIf.WrongOutputType: if ifm.dtype == DType.FP16: - excludes = [DType.FP16, DType.FLOAT] + excludes = [DType.FP16, DType.FP32] else: excludes = [out_dtype] wrong_dtypes = list(usableDTypes(excludes=excludes)) @@ -4182,7 +4152,7 @@ class OutputShaper: if error_name == ErrorIf.WrongOutputType: if ifm.dtype == DType.FP16: - excludes = [DType.FP16, DType.FLOAT] + excludes = [DType.FP16, DType.FP32] else: excludes = [out_dtype] wrong_dtypes = list(usableDTypes(excludes=excludes)) @@ -4217,7 +4187,7 @@ class OutputShaper: DType.INT16, DType.INT32, DType.INT48, - DType.FLOAT, + DType.FP32, DType.FP16, ] wrong_dtypes = list(set(all_dtypes) - set([ifm.dtype])) @@ -4255,7 +4225,7 @@ class OutputShaper: DType.INT8, DType.INT16, DType.INT48, - DType.FLOAT, + DType.FP32, ) elif a.dtype == DType.INT16: incorrect_types = ( @@ -4263,9 +4233,9 @@ class OutputShaper: DType.INT8, DType.INT16, DType.INT32, - DType.FLOAT, + DType.FP32, ) - elif a.dtype == DType.FLOAT or a.dtype == DType.FP16: + elif a.dtype == DType.FP32 or a.dtype == DType.FP16: incorrect_types = ( DType.INT4, DType.INT8, @@ -4307,7 +4277,7 @@ class OutputShaper: DType.INT16, DType.INT32, DType.INT48, - DType.FLOAT, + DType.FP32, } wrong_dtypes = list(all_dtypes - set([input1.dtype])) outputDType = rng.choice(wrong_dtypes) @@ -4334,7 +4304,7 @@ class OutputShaper: DType.INT16, DType.INT32, DType.INT48, - DType.FLOAT, + DType.FP32, DType.FP16, ] wrong_dtypes = list(set(all_dtypes) - set([a.dtype])) @@ -4358,7 +4328,7 @@ class OutputShaper: DType.INT16, DType.INT32, DType.INT48, - DType.FLOAT, + DType.FP32, ] wrong_dtypes = list(set(all_dtypes) - set([a.dtype])) outputDType = rng.choice(wrong_dtypes) @@ -4376,7 +4346,7 @@ class OutputShaper: DType.INT16, DType.INT32, DType.INT48, - DType.FLOAT, + DType.FP32, ] wrong_dtypes = list(set(all_dtypes) - set([a.dtype])) outputDType = rng.choice(wrong_dtypes) @@ -4412,7 +4382,7 @@ class OutputShaper: DType.INT16, DType.INT32, DType.INT48, - DType.FLOAT, + DType.FP32, ] wrong_dtypes = list(set(all_dtypes) - set([a.dtype])) outputDType = rng.choice(wrong_dtypes) @@ -4440,7 +4410,7 @@ class OutputShaper: DType.INT16, DType.INT32, DType.INT48, - DType.FLOAT, + DType.FP32, ] wrong_dtypes = list(set(all_dtypes) - set([a.dtype])) outputDType = rng.choice(wrong_dtypes) @@ -4464,7 +4434,7 @@ class OutputShaper: DType.INT16, DType.INT32, DType.INT48, - DType.FLOAT, + DType.FP32, ] wrong_dtypes = list(set(all_dtypes) - set([values.dtype])) outputDType = rng.choice(wrong_dtypes) @@ -4491,7 +4461,7 @@ class OutputShaper: DType.INT16, DType.INT32, DType.INT48, - DType.FLOAT, + DType.FP32, ] wrong_dtypes = list(set(all_dtypes) - set([values_in.dtype])) outputDType = rng.choice(wrong_dtypes) @@ -4512,7 +4482,7 @@ class OutputShaper: DType.INT16, DType.INT32, DType.INT48, - DType.FLOAT, + DType.FP32, ] wrong_dtypes.remove(output_dtype) output_dtype = rng.choice(wrong_dtypes) @@ -4619,7 +4589,7 @@ class OutputShaper: if error_name == ErrorIf.WrongOutputType: if ifm.dtype == DType.FP16: - excludes = [DType.FP16, DType.FLOAT] + excludes = [DType.FP16, DType.FP32] else: excludes = [out_dtype] wrong_dtypes = list(usableDTypes(excludes=excludes)) -- cgit v1.2.1