diff options
Diffstat (limited to 'verif')
-rw-r--r-- | verif/checker/tosa_result_checker.py | 4 | ||||
-rw-r--r-- | verif/generator/tosa_arg_gen.py | 230 | ||||
-rw-r--r-- | verif/generator/tosa_error_if.py | 74 | ||||
-rw-r--r-- | verif/generator/tosa_test_gen.py | 298 | ||||
-rw-r--r-- | verif/generator/tosa_utils.py | 39 | ||||
-rw-r--r-- | verif/tests/test_tosa_result_checker.py | 2 |
6 files changed, 444 insertions, 203 deletions
diff --git a/verif/checker/tosa_result_checker.py b/verif/checker/tosa_result_checker.py index 66864c2..8ae3218 100644 --- a/verif/checker/tosa_result_checker.py +++ b/verif/checker/tosa_result_checker.py @@ -147,14 +147,14 @@ def test_check( tolerance = 0.0 # Fall-through to below to add failure values - elif reference_result.dtype == np.float32: + # TODO: update for fp16 tolerance + elif reference_result.dtype == np.float32 or reference_result.dtype == np.float16: tolerance = float_tolerance if np.allclose(reference_result, test_result, atol=tolerance, equal_nan=True): print_color(LogColors.GREEN, "Results PASS {}".format(test_name)) return (TestResult.PASS, tolerance, "") msg = "Float result does not match within tolerance of {}".format(tolerance) # Fall-through to below to add failure values - else: print_color(LogColors.RED, "Results UNSUPPORTED TYPE {}".format(test_name)) msg = "Unsupported results type: {}".format(reference_result.dtype) diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py index a65e220..69968d3 100644 --- a/verif/generator/tosa_arg_gen.py +++ b/verif/generator/tosa_arg_gen.py @@ -2,10 +2,13 @@ # SPDX-License-Identifier: Apache-2.0 import itertools import math +import warnings import numpy as np from generator.tosa_error_if import ErrorIf from generator.tosa_error_if import TosaErrorIfArgGen +from generator.tosa_utils import get_accum_dtype_from_tgTypes +from generator.tosa_utils import get_wrong_output_type from generator.tosa_utils import MAX_RESIZE_DIMENSION from serializer.tosa_serializer import DTypeNames from tosa.DType import DType @@ -773,7 +776,7 @@ class TosaTensorValuesGen: ), "Op.MUL must have 2 placeholders, 0 consts" tens = [] - if dtypeList[0] == DType.FLOAT: + if dtypeList[0] in (DType.FP16, DType.FLOAT): tens.extend(testGen.buildPlaceholderTensors(shapeList[:], dtypeList[:])) else: placeholders = [] @@ -982,7 +985,7 @@ class TosaArgGen: return axes @staticmethod - def agConv(testGen, opName, shapeList, dtype, error_name=None): + def agConv(testGen, opName, shapeList, dtypes, error_name=None): arg_list = [] ifm_shape = shapeList[0] @@ -990,6 +993,8 @@ class TosaArgGen: # determine the kernel shape from operator name (e.g. "conv2d_3x3" => [3,3]) k = [int(x) for x in opName.split("_")[-1].split("x")] + accum_dtype = get_accum_dtype_from_tgTypes(dtypes) + # Check the rank rank = 5 if opName.startswith("conv3d") else 4 if error_name != ErrorIf.WrongRank: @@ -1089,12 +1094,13 @@ class TosaArgGen: ): arg_list.append( ( - "st{}_pad{}_dilat{}".format( + "acc{}_st{}_pad{}_dilat{}".format( + testGen.typeStr(accum_dtype), "".join([str(x) for x in s]), "".join([str(x) for x in p]), "".join([str(x) for x in d]), ), - [s, p, d], + [accum_dtype, s, p, d], ) ) n += 1 @@ -1102,12 +1108,55 @@ class TosaArgGen: return arg_list @staticmethod - def agTransposeConv2D(testGen, opName, shapeList, dtype, error_name=None): + def agFullyConnected(testGen, opName, shapeList, dtypes, error_name=None): + + if isinstance(dtypes, list) or isinstance(dtypes, tuple): + input_dtype = dtypes[0] + else: + input_dtype = dtypes + + if error_name == ErrorIf.WrongOutputType: + accum_dtype = get_wrong_output_type(opName, testGen.rng, input_dtype) + elif error_name == ErrorIf.WrongInputType: + # Pick some potentially correct output dtype if input type is incorrect + accum_dtype = DType.INT32 + else: + accum_dtype = get_accum_dtype_from_tgTypes(dtypes) + + return [(f"acc{testGen.typeStr(accum_dtype)}", [accum_dtype])] + + @staticmethod + def agMatMul(testGen, opName, shapeList, dtype, error_name=None): + # Get valid accumulate type(s) + if dtype == DType.INT8: + accum_dtypes = [DType.INT32] + elif dtype == DType.INT16: + accum_dtypes = [DType.INT48] + elif dtype == DType.FP16: + accum_dtypes = [DType.FP16, DType.FLOAT] + elif dtype == DType.FLOAT: + accum_dtypes = [DType.FLOAT] + elif error_name is None: + assert False, f"Invalid I/O DType for MatMul: {DTypeNames[dtype]}" + + if error_name == ErrorIf.WrongOutputType: + # Get incorrect output dtype for ErrorIf case + accum_dtypes = [get_wrong_output_type(opName, testGen.rng, dtype)] + elif error_name == ErrorIf.WrongInputType: + # Pick some potentially correct output dtype if input type is incorrect + accum_dtypes = [DType.INT32] + + return [(f"acc{testGen.typeStr(a)}", [a]) for a in accum_dtypes] + + @staticmethod + def agTransposeConv2D(testGen, opName, shapeList, dtypes, error_name=None): arg_list = [] ifm_shape = shapeList[0] filter_shape = shapeList[1] + accum_dtype = get_accum_dtype_from_tgTypes(dtypes) + # Must be rank 4 if error_name != ErrorIf.WrongRank: assert len(ifm_shape) == 4 @@ -1169,12 +1218,13 @@ class TosaArgGen: os = [ifm_shape[0], oh, ow, filter_shape[0]] arg_list.append( ( - "st{}_pad{}_os{}".format( + "acc{}_st{}_pad{}_os{}".format( + testGen.typeStr(accum_dtype), "".join([str(x) for x in s]), "".join([str(x) for x in p]), "x".join([str(x) for x in os]), ), - [s, p, os], + [accum_dtype, s, p, os], ) ) n += 1 @@ -1199,18 +1249,38 @@ class TosaArgGen: if dtype in [DType.BOOL, DType.INT8, DType.INT16, DType.INT32]: pad_const_int = testGen.getRandNumberDType(dtype) pad_const_fp = 0 - elif dtype == DType.FLOAT: + elif dtype in (DType.FP16, DType.FLOAT): pad_const_int = 0 pad_const_fp = testGen.getRandNumberDType(dtype) else: return [] for paddings in shape_pad_values: - name = "pad" - for r in range(rank): - before, after = paddings[r] - name = f"{name}{before}{after}" - arg_list.append((name, [np.array(paddings), pad_const_int, pad_const_fp])) + paddings = list(paddings) + args_valid = True + + if error_name == ErrorIf.PadSmallerZero: + # Prevent negative output shapes while ensuring still testing for negative padding + for i in range(rank): + dim_after_padding = ( + paddings[i][0] + paddings[i][1] + shapeList[0][i] + ) + if dim_after_padding < 1: + paddings[i] = (0, 0) + if all([p > -1 for p in paddings[i]]): + args_valid = False + + if args_valid: + name = "pad" + for r in range(rank): + before, after = paddings[r] + name = f"{name}{before}{after}" + arg_list.append( + (name, [np.array(paddings), pad_const_int, pad_const_fp]) + ) + + if error_name == ErrorIf.PadSmallerZero and len(arg_list) == 0: + warnings.warn(f"No ErrorIf test created for input shape: {shapeList[0]}") return arg_list @@ -1232,6 +1302,21 @@ class TosaArgGen: k_vals = [x for x in range(2, testGen.args.max_pooling_kernel + 1)] kernels = {x for x in itertools.product(*([k_vals] * 2))} + if opName == "max_pool2d": + accum_dtypes = [None] # max_pool has no accumulate dtype + elif dtype == DType.INT8 or dtype == DType.INT16: + accum_dtypes = [DType.INT32] + elif dtype == DType.FP16: + accum_dtypes = [DType.FP16, DType.FLOAT] + elif dtype == DType.FLOAT: + accum_dtypes = [DType.FLOAT] + elif error_name is None: + assert False, f"Invalid I/O DType for pooling: {DTypeNames[dtype]}" + else: + # Set to something for the ErrorIf case which has + # incorrect input data-type + accum_dtypes = [DType.INT32] + if testGen.args.oversize: # add some oversize argument values bigStride = 7 @@ -1252,63 +1337,70 @@ class TosaArgGen: sparsity_factor = 2 if error_name else 500 sparsity = len(paddings) * len(strides) * len(kernels) // sparsity_factor + 1 + arg_str = ( + "acc{}_st{}_kern{}_pad{}" + if accum_dtypes[0] is not None + else "st{}_kern{}_pad{}" + ) + + def get_arg_list_element(accum, stride, pad, kern): + # Return tuple containing the formatted argument string and + # the corresponding argument values + arg_str_elems = [ + "".join([str(x) for x in stride]), + "".join([str(x) for x in kern]), + "".join([str(x) for x in pad]), + ] + # Note: different order to string + arg_val_elems = [stride, pad, kern] + + if accum is not None: + arg_str_elems.insert(0, testGen.typeStr(accum)) + arg_val_elems.insert(0, accum) + return (arg_str.format(*arg_str_elems), arg_val_elems) + n = 0 - for s in sorted(list(strides)): - for p in sorted(list(paddings)): - for k in sorted(list(kernels)): - if error_name in [ - ErrorIf.StrideSmallerOne, - ErrorIf.KernelSmallerOne, - ErrorIf.PadSmallerZero, - ErrorIf.PadLargerEqualKernel, - ]: - sNew, pNew, kNew = TosaErrorIfArgGen.eiPoolingErrorIf( - testGen, error_name, s, p, k - ) - if None not in [sNew, pNew, kNew] and n % sparsity == 0: - arg_list.append( - ( - "st{}_kern{}_pad{}".format( - "".join([str(x) for x in sNew]), - "".join([str(x) for x in kNew]), - "".join([str(x) for x in pNew]), - ), - [sNew, pNew, kNew], - ) + for a in accum_dtypes: + for s in sorted(list(strides)): + for p in sorted(list(paddings)): + for k in sorted(list(kernels)): + if error_name in [ + ErrorIf.StrideSmallerOne, + ErrorIf.KernelSmallerOne, + ErrorIf.PadSmallerZero, + ErrorIf.PadLargerEqualKernel, + ]: + sNew, pNew, kNew = TosaErrorIfArgGen.eiPoolingErrorIf( + testGen, error_name, s, p, k ) - elif ( - n % sparsity == 0 - # padding must not exceed the kernel size - and p[0] < k[0] - and p[1] < k[0] - and p[2] < k[1] - and p[3] < k[1] - # the padded shape must exceed the kernel size - and (shape[1] + p[0] + p[1]) > k[0] - and (shape[2] + p[2] + p[3]) > k[1] - ): - remainder_h = (shape[1] + p[0] + p[1] - k[0]) % s[0] - remainder_w = (shape[2] + p[2] + p[3] - k[1]) % s[1] - if ( - # the parameters must produce integer exact output - error_name != ErrorIf.PoolingOutputShapeNonInteger - and remainder_h == 0 - and remainder_w == 0 - ) or ( - error_name == ErrorIf.PoolingOutputShapeNonInteger - and (remainder_h != 0 or remainder_w != 0) + if None not in [sNew, pNew, kNew] and n % sparsity == 0: + arg_vals = [a, sNew, pNew, kNew] + arg_list.append(get_arg_list_element(*arg_vals)) + elif ( + n % sparsity == 0 + # padding must not exceed the kernel size + and p[0] < k[0] + and p[1] < k[0] + and p[2] < k[1] + and p[3] < k[1] + # the padded shape must exceed the kernel size + and (shape[1] + p[0] + p[1]) > k[0] + and (shape[2] + p[2] + p[3]) > k[1] ): - arg_list.append( - ( - "st{}_kern{}_pad{}".format( - "".join([str(x) for x in s]), - "".join([str(x) for x in k]), - "".join([str(x) for x in p]), - ), - [s, p, k], - ) - ) - n += 1 + remainder_h = (shape[1] + p[0] + p[1] - k[0]) % s[0] + remainder_w = (shape[2] + p[2] + p[3] - k[1]) % s[1] + if ( + # the parameters must produce integer exact output + error_name != ErrorIf.PoolingOutputShapeNonInteger + and remainder_h == 0 + and remainder_w == 0 + ) or ( + error_name == ErrorIf.PoolingOutputShapeNonInteger + and (remainder_h != 0 or remainder_w != 0) + ): + arg_vals = [a, s, p, k] + arg_list.append(get_arg_list_element(*arg_vals)) + n += 1 return arg_list @@ -1327,6 +1419,8 @@ class TosaArgGen: dtypeList = [DType.BOOL, DType.INT8, DType.INT16, DType.FLOAT] elif inDtype == DType.BOOL: dtypeList = [DType.INT8, DType.INT16, DType.INT32] + elif inDtype == DType.FP16: + dtypeList = [DType.INT8, DType.INT16, DType.INT32] elif inDtype == DType.FLOAT: dtypeList = [DType.INT8, DType.INT16, DType.INT32] elif error_name == ErrorIf.WrongInputType: @@ -1734,6 +1828,8 @@ class TosaArgGen: outputDTypeList = [DType.INT32] elif mode == ResizeMode.BILINEAR and dtype == DType.INT16: outputDTypeList = [DType.INT48] + elif dtype == DType.FP16: + outputDTypeList = [DType.FP16] elif dtype == DType.FLOAT: outputDTypeList = [DType.FLOAT] elif error_name == ErrorIf.WrongInputType: diff --git a/verif/generator/tosa_error_if.py b/verif/generator/tosa_error_if.py index f9a00f9..a766803 100644 --- a/verif/generator/tosa_error_if.py +++ b/verif/generator/tosa_error_if.py @@ -120,6 +120,7 @@ class TosaErrorIfArgGen: DType.INT32, DType.INT48, DType.FLOAT, + DType.FP16, ) elif mode == ResizeMode.NEAREST and dtype == DType.INT16: incorrect_types = ( @@ -128,6 +129,7 @@ class TosaErrorIfArgGen: DType.INT32, DType.INT48, DType.FLOAT, + DType.FP16, ) elif mode == ResizeMode.BILINEAR and dtype == DType.INT8: incorrect_types = ( @@ -136,6 +138,7 @@ class TosaErrorIfArgGen: DType.INT16, DType.INT48, DType.FLOAT, + DType.FP16, ) elif mode == ResizeMode.BILINEAR and dtype == DType.INT16: incorrect_types = ( @@ -144,6 +147,16 @@ class TosaErrorIfArgGen: DType.INT16, DType.INT32, DType.FLOAT, + DType.FP16, + ) + elif dtype == DType.FP16: + incorrect_types = ( + DType.INT4, + DType.INT8, + DType.INT16, + DType.INT32, + DType.INT48, + DType.FLOAT, ) elif dtype == DType.FLOAT: incorrect_types = ( @@ -152,6 +165,7 @@ class TosaErrorIfArgGen: DType.INT16, DType.INT32, DType.INT48, + DType.FP16, ) outputDType = testGen.rng.choice(a=incorrect_types) @@ -285,8 +299,8 @@ class TosaErrorIfArgGen: @staticmethod def eiCastErrorIf(testGen, input_dtype): - if input_dtype in [DType.BOOL, DType.FLOAT]: - outputDType = [DType.BOOL, DType.INT48, DType.FLOAT] + if input_dtype in [DType.BOOL, DType.FP16, DType.FLOAT]: + outputDType = [DType.BOOL, DType.INT48, DType.FP16, DType.FLOAT] elif input_dtype in [DType.INT8, DType.INT16, DType.INT32]: outputDType = [DType.INT48] else: @@ -400,6 +414,7 @@ class TosaErrorValidator: and input_dtype == DType.INT16 and output_dtype != DType.INT48 ) + or (input_dtype == DType.FP16 and output_dtype != DType.FP16) or (input_dtype == DType.FLOAT and output_dtype != DType.FLOAT) ): error_result = True @@ -413,19 +428,28 @@ class TosaErrorValidator: if ( (input_dtype == DType.INT8 and output_dtype != DType.INT32) or (input_dtype == DType.INT16 and output_dtype != DType.INT48) + or ( + input_dtype == DType.FP16 + and output_dtype not in (DType.FP16, DType.FLOAT) + ) or (input_dtype == DType.FLOAT and output_dtype != DType.FLOAT) ): error_result = True elif op["op"] == Op.ARGMAX: if ( - input_dtype in [DType.INT8, DType.INT16, DType.FLOAT] + input_dtype in [DType.INT8, DType.INT16, DType.FP16, DType.FLOAT] and output_dtype != DType.INT32 ): error_result = True elif op["op"] == Op.MUL: - if input_dtype != DType.FLOAT and output_dtype != DType.INT32: + if ( + input_dtype not in (DType.FP16, DType.FLOAT) + and output_dtype != DType.INT32 + ): + error_result = True + elif input_dtype == DType.FP16 and output_dtype != DType.FP16: error_result = True elif input_dtype == DType.FLOAT and output_dtype != DType.FLOAT: error_result = True @@ -449,17 +473,39 @@ class TosaErrorValidator: or ( input_dtype == DType.INT8 and output_dtype - not in [DType.BOOL, DType.INT16, DType.INT32, DType.FLOAT] + not in [ + DType.BOOL, + DType.INT16, + DType.INT32, + DType.FLOAT, + DType.FP16, + ] ) or ( input_dtype == DType.INT16 and output_dtype - not in [DType.BOOL, DType.INT8, DType.INT32, DType.FLOAT] + not in [ + DType.BOOL, + DType.INT8, + DType.INT32, + DType.FLOAT, + DType.FP16, + ] ) or ( input_dtype == DType.INT32 and output_dtype - not in [DType.BOOL, DType.INT8, DType.INT16, DType.FLOAT] + not in [ + DType.BOOL, + DType.INT8, + DType.INT16, + DType.FLOAT, + DType.FP16, + ] + ) + or ( + input_dtype == DType.FP16 + and output_dtype not in [DType.INT8, DType.INT16, DType.INT32] ) or ( input_dtype == DType.FLOAT @@ -479,6 +525,8 @@ class TosaErrorValidator: and output_dtype != DType.INT32 or input_dtype == DType.INT16 and output_dtype != DType.INT48 + or input_dtype == DType.FP16 + and output_dtype not in (DType.FP16, DType.FLOAT) or input_dtype == DType.FLOAT and output_dtype != DType.FLOAT ): @@ -2257,12 +2305,13 @@ class TosaInvalidValidator: return ( not (input_dtype == DType.INT8 and output_dtype == DType.INT32) and not (input_dtype == DType.INT16 and output_dtype == DType.INT48) + and not (input_dtype == DType.FP16 and output_dtype == DType.FP16) and not (input_dtype == DType.FLOAT and output_dtype == DType.FLOAT) ) elif mode == ResizeMode.NEAREST: # Invalid output data type / Invalid input datatype return (input_dtype != output_dtype) or ( - input_dtype not in [DType.INT8, DType.INT16, DType.FLOAT] + input_dtype not in [DType.INT8, DType.INT16, DType.FP16, DType.FLOAT] ) else: # Invalid resize mode @@ -2276,8 +2325,11 @@ class TosaInvalidValidator: input_shape = inputShapes[0] args = kwargs["args"] - strides = args[0] - padding = args[1] + + # MaxPool2D has no accum_dtype arg + stride_idx, pad_idx = (0, 1) if opName == "max_pool2d" else (1, 2) + strides = args[stride_idx] + padding = args[pad_idx] if opName.endswith("pool2d"): # avg_pool2d, max_pool2d @@ -2365,7 +2417,7 @@ class TosaInvalidValidator: @staticmethod def ivNonPositiveOutputShape(**kwargs): args = kwargs["args"] - output_shape = args[2] + output_shape = args[3] if output_shape[1] <= 0 or output_shape[2] <= 0: # Negative output shape return True diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py index b76b656..9ff6ec5 100644 --- a/verif/generator/tosa_test_gen.py +++ b/verif/generator/tosa_test_gen.py @@ -81,6 +81,8 @@ class TosaTestGen: return np.int64( self.rng.integers(low=-(1 << 47), high=(1 << 47), size=shape) ) + elif dtype == DType.FP16: + return np.float16(self.rng.random(size=shape)) elif dtype == DType.FLOAT: return np.float32(self.rng.random(size=shape)) else: @@ -128,6 +130,9 @@ class TosaTestGen: def getRandNumberDType(self, dtype): if dtype == DType.FLOAT: return self.rng.random() + elif dtype == DType.FP16: + rand_f32 = self.rng.random() + return np.float16(rand_f32) elif dtype == DType.BOOL: return self.rng.choice([False, True]) # TOSA specific INT4 weight range from -7 to 7 @@ -178,13 +183,15 @@ class TosaTestGen: return "i32" elif t == DType.INT48: return "i48" + elif t == DType.FP16: + return "f16" elif t == DType.FLOAT: return "float" else: raise Exception("Unknown dtype, cannot convert to string: {}".format(t)) def typeWidth(self, t): - """Get the datatype width for integer types""" + """Get the datatype width for data types""" if t == DType.INT4: return 4 elif t == DType.INT8: @@ -199,6 +206,8 @@ class TosaTestGen: return 32 elif t == DType.INT48: return 48 + elif t == DType.FP16: + return 16 elif t == DType.FLOAT: return 32 elif t == DType.BOOL: @@ -346,7 +355,7 @@ class TosaTestGen: # Special for multiply: # Force the result to INT32 for INT types - if a.dtype != DType.FLOAT: + if a.dtype not in (DType.FP16, DType.FLOAT): result_tens.setDtype(DType.INT32) if error_name == ErrorIf.WrongOutputType: all_dtypes = [DType.INT8, DType.INT16, DType.INT48] @@ -533,6 +542,7 @@ class TosaTestGen: self, op, input, + accum_dtype, stride, pad, kernel, @@ -585,17 +595,43 @@ class TosaTestGen: qinfo = [0, 0] attr = ts.TosaSerializerAttribute() - attr.PoolAttribute(kernel, stride, pad, qinfo[0], qinfo[1]) + attr.PoolAttribute(kernel, stride, pad, qinfo[0], qinfo[1], accum_dtype) self.ser.addOperator(op["op"], input_list, output_list, attr) return result_tens + def build_maxpool2d( + self, + op, + input, + stride, + pad, + kernel, + validator_fcns=None, + error_name=None, + qinfo=None, + ): + # Same as build_pool2d but manually sets accum_dtype value + # (maxpool has no accum_dtype) + return self.build_pool2d( + op, + input, + DType.UNKNOWN, + stride, + pad, + kernel, + validator_fcns, + error_name, + qinfo, + ) + def build_conv2d( self, op, ifm, filter, bias, + accum_dtype, strides, padding, dilations, @@ -605,7 +641,15 @@ class TosaTestGen: ): assert len(padding) == 4 result_tens = OutputShaper.conv2dOp( - self.ser, self.rng, ifm, filter, strides, padding, dilations, error_name + self.ser, + self.rng, + ifm, + filter, + accum_dtype, + strides, + padding, + dilations, + error_name, ) # Ensure new output type has correct qinfo @@ -648,7 +692,7 @@ class TosaTestGen: return None attr = ts.TosaSerializerAttribute() - attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1]) + attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], accum_dtype) self.ser.addOperator(op["op"], input_list, output_list, attr) return result_tens @@ -659,6 +703,7 @@ class TosaTestGen: ifm, filter, bias, + accum_dtype, strides, padding, dilations, @@ -668,7 +713,15 @@ class TosaTestGen: ): assert len(padding) == 6 result_tens = OutputShaper.conv3dOp( - self.ser, self.rng, ifm, filter, strides, padding, dilations, error_name + self.ser, + self.rng, + ifm, + filter, + accum_dtype, + strides, + padding, + dilations, + error_name, ) # Ensure new output type has correct qinfo @@ -711,7 +764,7 @@ class TosaTestGen: return None attr = ts.TosaSerializerAttribute() - attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1]) + attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], accum_dtype) self.ser.addOperator(op["op"], input_list, output_list, attr) return result_tens @@ -722,6 +775,7 @@ class TosaTestGen: ifm, filter, bias, + accum_dtype, stride, out_pad, output_shape, @@ -731,7 +785,7 @@ class TosaTestGen: ): assert len(out_pad) == 4 result_tens = OutputShaper.transposeConv2DOp( - self.ser, self.rng, ifm, output_shape, error_name + self.ser, self.rng, ifm, output_shape, accum_dtype, error_name ) # Ensure new output type has correct qinfo @@ -773,7 +827,9 @@ class TosaTestGen: return None attr = ts.TosaSerializerAttribute() - attr.TransposeConvAttribute(out_pad, stride, output_shape, qinfo[0], qinfo[1]) + attr.TransposeConvAttribute( + out_pad, stride, output_shape, qinfo[0], qinfo[1], accum_dtype + ) self.ser.addOperator(op["op"], input_list, output_list, attr) return result_tens @@ -784,6 +840,7 @@ class TosaTestGen: ifm, filter, bias, + accum_dtype, strides, padding, dilations, @@ -792,7 +849,15 @@ class TosaTestGen: qinfo=None, ): result_tens = OutputShaper.depthwiseConv2dOp( - self.ser, self.rng, ifm, filter, strides, padding, dilations, error_name + self.ser, + self.rng, + ifm, + filter, + accum_dtype, + strides, + padding, + dilations, + error_name, ) # Ensure new output type has correct qinfo @@ -835,16 +900,24 @@ class TosaTestGen: return None attr = ts.TosaSerializerAttribute() - attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1]) + attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], accum_dtype) self.ser.addOperator(op["op"], input_list, output_list, attr) return result_tens def build_fully_connected( - self, op, ifm, filter, bias, validator_fcns=None, error_name=None, qinfo=None + self, + op, + ifm, + filter, + bias, + accum_dtype, + validator_fcns=None, + error_name=None, + qinfo=None, ): result_tens = OutputShaper.fullyConnectedOp( - self.ser, self.rng, ifm, filter, error_name + self.ser, self.rng, ifm, filter, accum_dtype, error_name ) # Invalidate Input/Output list for error if checks. @@ -871,17 +944,22 @@ class TosaTestGen: input_list=input_list, output_list=output_list, num_operands=num_operands, + accum_dtype=accum_dtype, ): return None attr = ts.TosaSerializerAttribute() - attr.FullyConnectedAttribute(qinfo[0], qinfo[1]) + attr.FullyConnectedAttribute(qinfo[0], qinfo[1], accum_dtype) self.ser.addOperator(op["op"], input_list, output_list, attr) return result_tens - def build_matmul(self, op, a, b, validator_fcns=None, error_name=None, qinfo=None): - result_tens = OutputShaper.matmulOp(self.ser, self.rng, a, b, error_name) + def build_matmul( + self, op, a, b, accum_dtype, validator_fcns=None, error_name=None, qinfo=None + ): + result_tens = OutputShaper.matmulOp( + self.ser, self.rng, a, b, accum_dtype, error_name + ) # Invalidate Input/Output list for error if checks. input_list = [a.name, b.name] @@ -908,11 +986,12 @@ class TosaTestGen: input_list=input_list, output_list=output_list, num_operands=num_operands, + accum_dtype=accum_dtype, ): return None attr = ts.TosaSerializerAttribute() - attr.MatMulAttribute(qinfo[0], qinfo[1]) + attr.MatMulAttribute(qinfo[0], qinfo[1], accum_dtype) self.ser.addOperator(op["op"], input_list, output_list, attr) return result_tens @@ -995,7 +1074,7 @@ class TosaTestGen: return None attr = ts.TosaSerializerAttribute() - if a.dtype == DType.FLOAT: + if a.dtype in (DType.FP16, DType.FLOAT): attr.ClampAttribute(0, 0, min_val, max_val) else: attr.ClampAttribute(min_val, max_val, 0, 0) @@ -1811,7 +1890,7 @@ class TosaTestGen: op["op"], [cond_tens.name, a.name, b.name], [result_tens.name], attr ) - if a.dtype in (DType.FLOAT, DType.INT32): + if a.dtype in (DType.FLOAT, DType.FP16, DType.INT32): then_op, else_op = Op.ADD, Op.SUB elif a.dtype in (DType.INT8, DType.INT16): then_op, else_op = Op.LOGICAL_RIGHT_SHIFT, Op.LOGICAL_LEFT_SHIFT @@ -2350,22 +2429,37 @@ class TosaTestGen: # if not specified, defaults to (1, 4) # 'build_fcn': tuple of the function to (build_operator(), TensorGen function, ArgGen enum) # 'types': array of datatypes to be tested - TYPE_FP = [DType.FLOAT] + TYPE_FP = [DType.FLOAT, DType.FP16] TYPE_INT = [DType.INT8, DType.INT16, DType.INT32] # Excludes INT4 - TYPE_INT_FP = [DType.INT8, DType.INT16, DType.INT32, DType.FLOAT] # Excludes INT4 + TYPE_INT_FP = [ + DType.INT8, + DType.INT16, + DType.INT32, + DType.FP16, + DType.FLOAT, + ] # Excludes INT4 TYPE_BOOL = [DType.BOOL] - TYPE_FI32 = [DType.FLOAT, DType.INT32] - TYPE_FIB = [DType.FLOAT, DType.INT8, DType.INT16, DType.INT32, DType.BOOL] + TYPE_FI32 = [DType.FLOAT, DType.FP16, DType.INT32] # floating-types and INT32 + TYPE_FIB = [ + DType.FP16, + DType.FLOAT, + DType.INT8, + DType.INT16, + DType.INT32, + DType.BOOL, + ] TYPE_FI16 = [DType.FLOAT, DType.INT16] - TYPE_NARROW_INT_FP = [DType.INT8, DType.INT16, DType.FLOAT] + TYPE_NARROW_INT_FP = [DType.INT8, DType.INT16, DType.FP16, DType.FLOAT] TYPE_CONV = [ [DType.INT8, DType.INT4, DType.INT32], [DType.INT8, DType.INT8, DType.INT32], [DType.INT16, DType.INT8, DType.INT48], + [DType.FP16, DType.FP16, DType.FP16], + [DType.FP16, DType.FP16, DType.FLOAT], DType.FLOAT, ] @@ -2524,7 +2618,7 @@ class TosaTestGen: build_fully_connected, TosaTensorGen.tgFullyConnected, TosaTensorValuesGen.tvgDefault, - None, + TosaArgGen.agFullyConnected, ), "qgen": TosaQuantGen.qgConv, "types": TYPE_CONV, @@ -2546,7 +2640,7 @@ class TosaTestGen: build_matmul, TosaTensorGen.tgMatmul, TosaTensorValuesGen.tvgDefault, - None, + TosaArgGen.agMatMul, ), "qgen": TosaQuantGen.qgMatmul, "types": TYPE_NARROW_INT_FP, @@ -2564,7 +2658,7 @@ class TosaTestGen: "operands": (1, 0), "rank": (4, 4), "build_fcn": ( - build_pool2d, + build_maxpool2d, TosaTensorGen.tgNHWC, TosaTensorValuesGen.tvgDefault, TosaArgGen.agPooling, @@ -3384,7 +3478,7 @@ class TosaTestGen: TosaTensorValuesGen.tvgReduceSum, TosaArgGen.agAxis, ), - "types": TYPE_FI32, + "types": (DType.FP16, DType.FLOAT, DType.INT32), "error_if_validators": ( TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evAxisSmallerZero, @@ -3571,7 +3665,7 @@ class TosaTestGen: TosaTensorValuesGen.tvgDefault, None, ), - "types": TYPE_INT_FP, + "types": (DType.INT8, DType.INT16, DType.INT32, DType.FP16, DType.FLOAT), "error_if_validators": ( TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, @@ -3612,7 +3706,7 @@ class TosaTestGen: TosaTensorValuesGen.tvgDefault, TosaArgGen.agResize, ), - "types": [DType.INT8, DType.INT16, DType.FLOAT], + "types": (DType.INT8, DType.INT16, DType.FP16, DType.FLOAT), "invalid_test_validators": ( TosaInvalidValidator.ivWrongDataTypeOrModeResize, ), @@ -3646,7 +3740,14 @@ class TosaTestGen: TosaTensorValuesGen.tvgDefault, TosaArgGen.agCast, ), - "types": [DType.FLOAT, DType.INT8, DType.INT16, DType.INT32, DType.BOOL], + "types": ( + DType.FP16, + DType.FLOAT, + DType.INT8, + DType.INT16, + DType.INT32, + DType.BOOL, + ), "error_if_validators": ( TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, @@ -3925,7 +4026,9 @@ class OutputShaper: return ser.addOutput(shape, outputDType) @staticmethod - def conv2dOp(ser, rng, ifm, filter, strides, padding, dilations, error_name=None): + def conv2dOp( + ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None + ): # IFM: NHWC # Filter: OHWI @@ -3958,26 +4061,26 @@ class OutputShaper: ofm_shape = [ifm.shape[0], h, w, filter.shape[0]] - if ifm.dtype == DType.INT8: - out_dtype = DType.INT32 - elif ifm.dtype == DType.INT16: - out_dtype = DType.INT48 - elif ifm.dtype == DType.FLOAT: - out_dtype = DType.FLOAT - elif error_name == ErrorIf.WrongInputType: + if error_name == ErrorIf.WrongInputType: # Pick some potentially correct output dtype if input type is incorrect out_dtype = DType.INT32 else: - raise Exception(f"Unsupported input dtype: {ifm.dtype}") + out_dtype = accum_dtype if error_name == ErrorIf.WrongOutputType: - wrong_dtypes = list(usableDTypes(excludes=[out_dtype])) + if ifm.dtype == DType.FP16: + excludes = [DType.FP16, DType.FLOAT] + else: + excludes = [out_dtype] + wrong_dtypes = list(usableDTypes(excludes=excludes)) out_dtype = rng.choice(wrong_dtypes) return ser.addOutput(ofm_shape, out_dtype) @staticmethod - def conv3dOp(ser, rng, ifm, filter, strides, padding, dilations, error_name=None): + def conv3dOp( + ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None + ): # IFM: NDHWC # Filter: ODHWI @@ -4020,27 +4123,25 @@ class OutputShaper: ofm_shape = [ifm.shape[0], d, h, w, filter.shape[0]] - if ifm.dtype == DType.INT8: - out_dtype = DType.INT32 - elif ifm.dtype == DType.INT16: - out_dtype = DType.INT48 - elif ifm.dtype == DType.FLOAT: - out_dtype = DType.FLOAT - elif error_name == ErrorIf.WrongInputType: + if error_name == ErrorIf.WrongInputType: # Pick some potentially correct output dtype if input type is incorrect out_dtype = DType.INT32 else: - raise Exception(f"Unsupported input dtype: {ifm.dtype}") + out_dtype = accum_dtype if error_name == ErrorIf.WrongOutputType: - wrong_dtypes = list(usableDTypes(excludes=[out_dtype])) + if ifm.dtype == DType.FP16: + excludes = [DType.FP16, DType.FLOAT] + else: + excludes = [out_dtype] + wrong_dtypes = list(usableDTypes(excludes=excludes)) out_dtype = rng.choice(wrong_dtypes) return ser.addOutput(ofm_shape, out_dtype) @staticmethod def depthwiseConv2dOp( - ser, rng, ifm, filter, strides, padding, dilations, error_name=None + ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None ): # IFM: NHWC # Filter: HWCM @@ -4073,20 +4174,18 @@ class OutputShaper: ofm_shape = [ifm.shape[0], h, w, filter.shape[2] * filter.shape[3]] - if ifm.dtype == DType.INT8: - out_dtype = DType.INT32 - elif ifm.dtype == DType.INT16: - out_dtype = DType.INT48 - elif ifm.dtype == DType.FLOAT: - out_dtype = DType.FLOAT - elif error_name == ErrorIf.WrongInputType: + if error_name == ErrorIf.WrongInputType: # Pick some potentially correct output dtype if input type is incorrect out_dtype = DType.INT32 else: - raise Exception(f"Unsupported input dtype: {ifm.dtype}") + out_dtype = accum_dtype if error_name == ErrorIf.WrongOutputType: - wrong_dtypes = list(usableDTypes(excludes=[out_dtype])) + if ifm.dtype == DType.FP16: + excludes = [DType.FP16, DType.FLOAT] + else: + excludes = [out_dtype] + wrong_dtypes = list(usableDTypes(excludes=excludes)) out_dtype = rng.choice(wrong_dtypes) return ser.addOutput(ofm_shape, out_dtype) @@ -4119,6 +4218,7 @@ class OutputShaper: DType.INT32, DType.INT48, DType.FLOAT, + DType.FP16, ] wrong_dtypes = list(set(all_dtypes) - set([ifm.dtype])) outputDType = rng.choice(wrong_dtypes) @@ -4128,55 +4228,20 @@ class OutputShaper: return ser.addOutput(ofm_shape, outputDType) @staticmethod - def fullyConnectedOp(ser, rng, input, filter, error_name=None): + def fullyConnectedOp(ser, rng, input, filter, accum_dtype, error_name=None): # input: N, IC # filter: OC, IC # output: N, OC output_shape = [input.shape[0], filter.shape[0]] - if error_name == ErrorIf.WrongOutputType: - if input.dtype == DType.INT8: - incorrect_types = ( - DType.INT4, - DType.INT8, - DType.INT16, - DType.INT48, - DType.FLOAT, - ) - elif input.dtype == DType.INT16: - incorrect_types = ( - DType.INT4, - DType.INT8, - DType.INT16, - DType.INT32, - DType.FLOAT, - ) - elif input.dtype == DType.FLOAT: - incorrect_types = ( - DType.INT4, - DType.INT8, - DType.INT16, - DType.INT32, - DType.INT48, - ) - out_dtype = rng.choice(a=incorrect_types) - elif input.dtype == DType.INT8: - out_dtype = DType.INT32 - elif input.dtype == DType.INT16: - out_dtype = DType.INT48 - elif input.dtype == DType.FLOAT: - out_dtype = DType.FLOAT - elif error_name == ErrorIf.WrongInputType: - # Pick some potentially correct output dtype if input type is incorrect - out_dtype = DType.INT32 - else: - raise Exception("Unsupported input dtype: {}".format(input.dtype)) + # Validated in arg_gen (also invalidated for ErrorIf) + out_dtype = accum_dtype return ser.addOutput(output_shape, out_dtype) @staticmethod - def matmulOp(ser, rng, a, b, error_name=None): + def matmulOp(ser, rng, a, b, accum_dtype, error_name=None): # a: N, H, C # b: N, C, W # out: N, H, W @@ -4200,7 +4265,7 @@ class OutputShaper: DType.INT32, DType.FLOAT, ) - elif a.dtype == DType.FLOAT: + elif a.dtype == DType.FLOAT or a.dtype == DType.FP16: incorrect_types = ( DType.INT4, DType.INT8, @@ -4209,17 +4274,11 @@ class OutputShaper: DType.INT48, ) out_dtype = rng.choice(a=incorrect_types) - elif a.dtype == DType.INT8: - out_dtype = DType.INT32 - elif a.dtype == DType.INT16: - out_dtype = DType.INT48 - elif a.dtype == DType.FLOAT: - out_dtype = DType.FLOAT elif error_name == ErrorIf.WrongInputType: # Pick some potentially correct output dtype if input type is incorrect out_dtype = DType.INT32 else: - raise Exception("Unsupported input dtype for matmul: {}".format(a.dtype)) + out_dtype = accum_dtype # Validated in arg_gen return ser.addOutput(output_shape, out_dtype) @@ -4269,10 +4328,6 @@ class OutputShaper: bad_dim = rng.choice(range(len(output_shape))) output_shape[bad_dim] -= rng.choice([1, 2]) - # Fix negative output shape if error_if test causes it - if error_name == ErrorIf.PadSmallerZero and min(output_shape) < 1: - output_shape = [i if i >= 1 else 1 for i in output_shape] - if error_name == ErrorIf.WrongOutputType: all_dtypes = [ DType.INT8, @@ -4280,6 +4335,7 @@ class OutputShaper: DType.INT32, DType.INT48, DType.FLOAT, + DType.FP16, ] wrong_dtypes = list(set(all_dtypes) - set([a.dtype])) outputDType = rng.choice(wrong_dtypes) @@ -4546,7 +4602,7 @@ class OutputShaper: return ser.addOutput(val.shape, out_dtype) @staticmethod - def transposeConv2DOp(ser, rng, ifm, output_shape, error_name=None): + def transposeConv2DOp(ser, rng, ifm, output_shape, accum_dtype, error_name=None): if error_name == ErrorIf.ConvOutputShapeMismatch: choices = [1, 2, 3] change = rng.choice(choices) @@ -4555,20 +4611,18 @@ class OutputShaper: if change in [2, 3]: output_shape[2] = output_shape[2] + rng.choice(choices) - if ifm.dtype == DType.INT8: - out_dtype = DType.INT32 - elif ifm.dtype == DType.INT16: - out_dtype = DType.INT48 - elif ifm.dtype == DType.FLOAT: - out_dtype = DType.FLOAT - elif error_name == ErrorIf.WrongInputType: + if error_name == ErrorIf.WrongInputType: # Pick some potentially correct output dtype if input type is incorrect out_dtype = DType.INT32 else: - raise Exception(f"Unsupported input dtype: {ifm.dtype}") + out_dtype = accum_dtype if error_name == ErrorIf.WrongOutputType: - wrong_dtypes = list(usableDTypes(excludes=[out_dtype])) + if ifm.dtype == DType.FP16: + excludes = [DType.FP16, DType.FLOAT] + else: + excludes = [out_dtype] + wrong_dtypes = list(usableDTypes(excludes=excludes)) out_dtype = rng.choice(wrong_dtypes) return ser.addOutput(output_shape, out_dtype) diff --git a/verif/generator/tosa_utils.py b/verif/generator/tosa_utils.py index 6a689d0..7fa31e7 100644 --- a/verif/generator/tosa_utils.py +++ b/verif/generator/tosa_utils.py @@ -84,3 +84,42 @@ def product(shape): for n in shape: value *= n return value + + +def get_accum_dtype_from_tgTypes(dtypes): + # Get accumulate data-type from the test generator's defined types + if isinstance(dtypes, list) or isinstance(dtypes, tuple): + return dtypes[-1] + else: + return dtypes + + +def get_wrong_output_type(op_name, rng, input_dtype): + if op_name == "fully_connected" or op_name == "matmul": + if input_dtype == DType.INT8: + incorrect_types = ( + DType.INT4, + DType.INT8, + DType.INT16, + DType.INT48, + DType.FLOAT, + DType.FP16, + ) + elif input_dtype == DType.INT16: + incorrect_types = ( + DType.INT4, + DType.INT8, + DType.INT16, + DType.INT32, + DType.FLOAT, + DType.FP16, + ) + elif input_dtype == DType.FLOAT or input_dtype == DType.FP16: + incorrect_types = ( + DType.INT4, + DType.INT8, + DType.INT16, + DType.INT32, + DType.INT48, + ) + return rng.choice(a=incorrect_types) diff --git a/verif/tests/test_tosa_result_checker.py b/verif/tests/test_tosa_result_checker.py index efee23b..d78d158 100644 --- a/verif/tests/test_tosa_result_checker.py +++ b/verif/tests/test_tosa_result_checker.py @@ -40,7 +40,7 @@ def _delete_data_file(file: Path): (np.uint16, trc.TestResult.MISMATCH), (np.uint32, trc.TestResult.MISMATCH), (np.uint64, trc.TestResult.MISMATCH), - (np.float16, trc.TestResult.MISMATCH), + (np.float16, trc.TestResult.PASS), (np.float32, trc.TestResult.PASS), (np.float64, trc.TestResult.MISMATCH), (bool, trc.TestResult.PASS), |