From e807aae606a78d923a2565052f7c2179e3050650 Mon Sep 17 00:00:00 2001 From: Matthew Haddon Date: Mon, 11 Oct 2021 18:12:58 +0100 Subject: Add Negative tests for pad, reshape, slice, transpose Signed-off-by: Matthew Haddon Change-Id: I659337aadfd0498bf88a95737f69c51efec797cc --- verif/tosa_error_if.py | 8 + verif/tosa_test_gen.py | 500 +++++++++++++++++++++++++++++++++++++++++++++---- 2 files changed, 471 insertions(+), 37 deletions(-) diff --git a/verif/tosa_error_if.py b/verif/tosa_error_if.py index 35a391e..93a35b3 100644 --- a/verif/tosa_error_if.py +++ b/verif/tosa_error_if.py @@ -45,5 +45,13 @@ class ErrorIf(object): PoolingOutputShapeMismatch = "PoolingOutputShapeMismatch" ScaleNotTrue = "ScaleNotTrue" ScaleTrue = "ScaleTrue" + TensorSizeInputOutputMismatch = "TensorSizeInputOutputMismatch" + StartSmallerZero = "StartSmallerZero" + SizeSmallerEqualZero = "SizeSmallerEqualZero" + StartSizeOutsideBounds = "StartSizeOutsideBounds" + SizeOutputShapeMismatch = "SizeOutputShapeMismatch" + InputSizeStartLengthMismatch = "InputSizeStartLengthMismatch" + IndexOutsideBounds = "IndexOutsideBounds" + IndexUsedTwice = "IndexUsedTwice" diff --git a/verif/tosa_test_gen.py b/verif/tosa_test_gen.py index 1ec4a47..1f35b8b 100644 --- a/verif/tosa_test_gen.py +++ b/verif/tosa_test_gen.py @@ -125,7 +125,10 @@ class TosaQuantGen: @staticmethod def qgPad(testGen, op, dtype, error_name=None): qinfo = ts.TosaSerializerQuantInfo() - qinfo.PadQuantInfo(TosaQuantGen.getQinfo(testGen, dtype)) + if error_name == ErrorIf.InputZeroPointNotZero: + qinfo.PadQuantInfo(TosaQuantGen.getQinfo(testGen, dtype, error_name)) + else: + qinfo.PadQuantInfo(TosaQuantGen.getQinfo(testGen, dtype)) return qinfo @staticmethod @@ -668,6 +671,8 @@ class TosaArgGen: # - for padding >9, the name format needs to be more distinctive pad_min, pad_max = 0, 1 pad_values = [x for x in range(pad_min, pad_max + 1)] + if error_name == ErrorIf.PadSmallerZero: + pad_values = [x for x in range(-2, 0)] axis_pad_values = [x for x in itertools.product(pad_values, pad_values)] shape_pad_values = itertools.product(*([axis_pad_values] * rank)) @@ -920,8 +925,23 @@ class TosaArgGen: ifm_shape = shapeList[0] - # Get all permutations - permutations = [p for p in itertools.permutations(range(len(ifm_shape)))] + + if error_name == ErrorIf.IndexOutsideBounds: + incorrect_large_index = range(len(ifm_shape)+1, 2*len(ifm_shape)+1) + incorrect_small_index = range(-len(ifm_shape), 0) + permutations = [p for p in itertools.permutations(incorrect_large_index)] + permutations.extend([p for p in itertools.permutations(incorrect_small_index)]) + elif error_name == ErrorIf.IndexUsedTwice: + # Create list with a duplicated index + perm_range = list(range(len(ifm_shape))) + index_choice = testGen.rng.choice(range(len(perm_range))) + perm_range[(index_choice + 1) % len(perm_range)] = perm_range[index_choice] + permutations = [p for p in itertools.permutations(perm_range)] + + + else: + # Get all permutations + permutations = [p for p in itertools.permutations(range(len(ifm_shape)))] # Limit to possible permutations from shape dimension or argument setting limit = min(len(permutations), testGen.args.num_rand_permutations) @@ -944,25 +964,27 @@ class TosaArgGen: rank = len(ifm_shape) for p in range(testGen.args.num_rand_permutations): - begin = [] + start = [] size = [] valid = True for i in range(rank): if ifm_shape[i] > 1: - begin.append(testGen.randInt(0, ifm_shape[i])) - size.append(testGen.randInt(0, ifm_shape[i] - begin[i])) + start.append(testGen.randInt(0, ifm_shape[i])) + size.append(testGen.randInt(0, ifm_shape[i] - start[i])) # Invalid slice size? if size[i] == 0: valid = False else: - begin.append(0) + start.append(0) size.append(1) if valid: - arg_list.append(("perm{}".format(p), [begin, size])) + # If ERROR_IF test required then incorrect start, size will be returned + start, size = TosaErrorIfArgGen.eiSliceErrorIf(testGen, error_name, ifm_shape, start, size) + arg_list.append(("perm{}".format(p), [start, size])) return arg_list @staticmethod @@ -1241,6 +1263,7 @@ class TosaErrorIfArgGen: return shift, stride, stride_fp, offset, offset_fp, outputDType + @staticmethod def eiPoolingErrorIf(testGen, error_name, stride, pad, kernel): if (error_name == ErrorIf.StrideSmallerOne @@ -1266,6 +1289,7 @@ class TosaErrorIfArgGen: else: return None, None, None + @staticmethod def eiRescaleWrongOutputType(input_dtype, output_dtype): if input_dtype == DType.INT8: @@ -1300,6 +1324,7 @@ class TosaErrorIfArgGen: output_list = [] return input_list, output_list + @staticmethod def eiRestrictDimension(shape, error_name): # Restrict dimension size if rank is large for WrongRank Error_If @@ -1310,6 +1335,38 @@ class TosaErrorIfArgGen: return shape + + def eiSliceErrorIf(testGen, error_name, input_shape, start, size): + if error_name == ErrorIf.StartSmallerZero: + newStart = [] + for i in range(len(input_shape)): + newStart.append(testGen.rng.choice([-3, -2, -1])) + return newStart, size + elif error_name == ErrorIf.SizeSmallerEqualZero: + newSize = [] + for i in range(len(input_shape)): + newSize.append(testGen.rng.choice([-3, -2, -1, 0])) + return start, newSize + elif error_name == ErrorIf.StartSizeOutsideBounds: + newStart, newSize = [], [] + for i in range(len(input_shape)): + newStart.append(input_shape[i]-1) + newSize.append(testGen.rng.choice([2, 3, 4])) + return newStart, newSize + elif error_name == ErrorIf.InputSizeStartLengthMismatch: + remove = testGen.rng.choice([True, False]) + if remove: + newStart = start[1:] + newSize = size[1:] + else: + newStart = start + newStart.append(1) + newSize = size + newSize.append(1) + return newStart, newSize + else: + return start, size + class TosaErrorValidator: @staticmethod @@ -1477,8 +1534,13 @@ class TosaErrorValidator: op = kwargs['op'] input_list = kwargs['input_list'] num_operands = kwargs['num_operands'] - if len(input_list) != num_operands: - error_result = True + # both PAD, TRANSPOSE add an extra const layer in the build function + if op['op'] in [Op.PAD, Op.TRANSPOSE]: + if len(input_list) != num_operands + 1: + error_result = True + else: + if len(input_list) != num_operands: + error_result = True info_dict = { "error_name": error_name, @@ -2020,9 +2082,15 @@ class TosaErrorValidator: error_reason = "At least one pad is smaller than zero" if check: + op = kwargs['op'] pad = kwargs['pad'] - if min(pad) < 0: - error_result = True + if op['op'] == Op.PAD: + for padding in pad: + if min(padding) < 0: + error_result = True + else: + if min(pad) < 0: + error_result = True info_dict = { "error_name": error_name, @@ -2240,6 +2308,203 @@ class TosaErrorValidator: } return info_dict + @staticmethod + def evTensorSizeInputOutputMismatch(check=False, **kwargs): + error_name = ErrorIf.TensorSizeInputOutputMismatch + param_reqs = {"rank": None, "dtype": None, "shape": None} + error_result = False + error_reason = "Input tensor size does not match output tensor size" + + if check: + input_shape = kwargs['input_shape'] + output_shape = kwargs['output_shape'] + input_size = np.prod(input_shape) + output_size = np.prod(output_shape) + if input_size != output_size: + error_result = True + + info_dict = { + "error_name": error_name, + "error_result": error_result, + "error_reason": error_reason, + "param_reqs": param_reqs + } + return info_dict + + @staticmethod + def evStartSmallerZero(check=False, **kwargs): + error_name = ErrorIf.StartSmallerZero + param_reqs = {"rank": None, "dtype": None, "shape": None} + error_result = False + error_reason = "Starting point smaller than zero" + + if check: + input_shape = kwargs['input_shape'] + start = kwargs['start'] + rank = len(input_shape) + if len(start) == rank: + for index in range(rank): + if start[index] < 0: + error_result = True + + info_dict = { + "error_name": error_name, + "error_result": error_result, + "error_reason": error_reason, + "param_reqs": param_reqs + } + return info_dict + + + @staticmethod + def evSizeSmallerEqualZero(check=False, **kwargs): + error_name = ErrorIf.SizeSmallerEqualZero + param_reqs = {"rank": None, "dtype": None, "shape": None} + error_result = False + error_reason = "Size smaller than or equal to zero" + + if check: + input_shape = kwargs['input_shape'] + size = kwargs['size'] + rank = len(input_shape) + if len(size) == rank: + for index in range(rank): + if size[index] <= 0: + error_result = True + + info_dict = { + "error_name": error_name, + "error_result": error_result, + "error_reason": error_reason, + "param_reqs": param_reqs + } + return info_dict + + + @staticmethod + def evStartSizeOutsideBounds(check=False, **kwargs): + error_name = ErrorIf.StartSizeOutsideBounds + param_reqs = {"rank": None, "dtype": None, "shape": None} + error_result = False + error_reason = "starting point plus size larger than input dimension" + + if check: + input_shape = kwargs['input_shape'] + start = kwargs['start'] + size = kwargs['size'] + rank = len(input_shape) + if len(start) == rank and len(size) == rank: + for index in range(rank): + if start[index] + size[index] > input_shape[index]: + error_result = True + + info_dict = { + "error_name": error_name, + "error_result": error_result, + "error_reason": error_reason, + "param_reqs": param_reqs + } + return info_dict + + + @staticmethod + def evSizeOutputShapeMismatch(check=False, **kwargs): + error_name = ErrorIf.SizeOutputShapeMismatch + param_reqs = {"rank": None, "dtype": None, "shape": None} + error_result = False + error_reason = "Size does not match output dimension" + + if check: + input_shape = kwargs['input_shape'] + output_shape = kwargs['output_shape'] + size = kwargs['size'] + rank = len(input_shape) + if len(size) == rank: + for index in range(rank): + if size[index] != output_shape[index]: + error_result = True + + info_dict = { + "error_name": error_name, + "error_result": error_result, + "error_reason": error_reason, + "param_reqs": param_reqs + } + return info_dict + + @staticmethod + def evInputSizeStartLengthMismatch(check=False, **kwargs): + error_name = ErrorIf.InputSizeStartLengthMismatch + param_reqs = {"rank": None, "dtype": None, "shape": None} + error_result = False + error_reason = "rank of input not equal to length of start or size" + + if check: + input_shape = kwargs['input_shape'] + start = kwargs['start'] + size = kwargs['size'] + rank = len(input_shape) + if rank != len(start) or rank != len(size): + error_result = True + + info_dict = { + "error_name": error_name, + "error_result": error_result, + "error_reason": error_reason, + "param_reqs": param_reqs + } + return info_dict + + @staticmethod + def evIndexOutsideBounds(check=False, **kwargs): + error_name = ErrorIf.IndexOutsideBounds + param_reqs = {"rank": None, "dtype": None, "shape": None} + error_result = False + error_reason = "Index outside of allowed bounds" + + if check: + input_shape = kwargs['input_shape'] + perms = kwargs['perms'] + rank = len(input_shape) + + for index in perms: + if index < 0 or index > rank: + error_result = True + + info_dict = { + "error_name": error_name, + "error_result": error_result, + "error_reason": error_reason, + "param_reqs": param_reqs + } + return info_dict + + @staticmethod + def evIndexUsedTwice(check=False, **kwargs): + error_name = ErrorIf.IndexUsedTwice + param_reqs = {"rank": [2,4], "dtype": None, "shape": None} + error_result = False + error_reason = "Index used multiple times" + + if check: + input_shape = kwargs['input_shape'] + perms = kwargs['perms'] + rank = len(input_shape) + + unique_indices = [] + for index in perms: + if index in unique_indices: + error_result = True + else: + unique_indices.append(index) + + info_dict = { + "error_name": error_name, + "error_result": error_result, + "error_reason": error_reason, + "param_reqs": param_reqs + } + return info_dict class TosaInvalidValidator: @@ -2961,26 +3226,72 @@ class TosaTestGen: self.ser.addOperator(op['op'], input_tensor_names, [result_tens.name], attr) return result_tens - def build_pad(self, op, a, padding, qinfo): - result_tens = OutputShaper.padOp(self.ser, a, padding) + def build_pad(self, op, a, padding, validator_fcns=None, error_name=None, qinfo=None): + result_tens = OutputShaper.padOp(self.ser, self.rng, a, padding, error_name) # Need to turn the padding array into a TOSA tensor here. # This is one of the few tensor operands that does not get # randomly generated padding_tens = self.ser.addConst(padding.shape, DType.INT32, padding) + # Invalidate Input/Output list for error if checks. + input_list = [a.name, padding_tens.name] + output_list = [result_tens.name] + pCount, cCount = op["operands"] + num_operands = pCount + cCount + input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list) + + TosaErrorValidator.evValidateErrorIfs( + self.ser, + validator_fcns, + error_name, + op=op, + input_shape = a.shape, + output_shape = result_tens.shape, + input_dtype = a.dtype, + output_dtype = result_tens.dtype, + pad=padding, + qinfo=qinfo, + result_tensor = result_tens, + input_list=input_list, + output_list=output_list, + num_operands=num_operands, + ) + self.ser.addOperator( - op['op'], [a.name, padding_tens.name], [result_tens.name], None, qinfo + op['op'], input_list, output_list, None, qinfo ) return result_tens - def build_reshape(self, op, a, newShape): - result_tens = OutputShaper.reshapeOp(self.ser, a, newShape) + def build_reshape(self, op, a, newShape, validator_fcns=None, error_name=None): + result_tens = OutputShaper.reshapeOp(self.ser, self.rng, a, newShape, error_name) + + # Invalidate Input/Output list for error if checks. + input_list = [a.name] + output_list = [result_tens.name] + pCount, cCount = op["operands"] + num_operands = pCount + cCount + input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list) + + TosaErrorValidator.evValidateErrorIfs( + self.ser, + validator_fcns, + error_name, + op=op, + input_shape = a.shape, + output_shape = result_tens.shape, + input_dtype = a.dtype, + output_dtype = result_tens.dtype, + result_tensor = result_tens, + input_list=input_list, + output_list=output_list, + num_operands=num_operands, + ) attr = ts.TosaSerializerAttribute() attr.ReshapeAttribute(newShape) - self.ser.addOperator(op['op'], [a.name], [result_tens.name], attr) + self.ser.addOperator(op['op'], input_list, output_list, attr) return result_tens def build_reverse(self, op, a, axis): @@ -2992,21 +3303,69 @@ class TosaTestGen: self.ser.addOperator(op['op'], [a.name], [result_tens.name], attr) return result_tens - def build_transpose(self, op, a, perms): - result_tens = OutputShaper.transposeOp(self.ser, a, perms) + def build_transpose(self, op, a, perms, validator_fcns=None, error_name=None): + result_tens = OutputShaper.transposeOp(self.ser, self.rng, a, perms, error_name) perms_tens = self.ser.addConst([len(perms)], DType.INT32, np.int32(perms)) - self.ser.addOperator(op['op'], [a.name, perms_tens.name], [result_tens.name]) + # Invalidate Input/Output list for error if checks. + input_list = [a.name, perms_tens.name] + output_list = [result_tens.name] + pCount, cCount = op["operands"] + num_operands = pCount + cCount + input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list) + + TosaErrorValidator.evValidateErrorIfs( + self.ser, + validator_fcns, + error_name, + op=op, + input_shape = a.shape, + output_shape = result_tens.shape, + perms=perms, + input_dtype = a.dtype, + output_dtype = result_tens.dtype, + result_tensor = result_tens, + input_list=input_list, + output_list=output_list, + num_operands=num_operands, + ) + + + self.ser.addOperator(op['op'], input_list, output_list) return result_tens - def build_slice(self, op, a, begin, size): - result_tens = OutputShaper.sliceOp(self.ser, a, begin, size) + def build_slice(self, op, a, start, size, validator_fcns=None, error_name=None): + result_tens = OutputShaper.sliceOp(self.ser, self.rng, a, start, size, error_name) + + # Invalidate Input/Output list for error if checks. + input_list = [a.name] + output_list = [result_tens.name] + pCount, cCount = op["operands"] + num_operands = pCount + cCount + input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list) + + TosaErrorValidator.evValidateErrorIfs( + self.ser, + validator_fcns, + error_name, + op=op, + input_shape = a.shape, + output_shape = result_tens.shape, + input_dtype = a.dtype, + output_dtype = result_tens.dtype, + start=start, + size=size, + result_tensor = result_tens, + input_list=input_list, + output_list=output_list, + num_operands=num_operands, + ) attr = ts.TosaSerializerAttribute() - attr.SliceAttribute(begin, size) + attr.SliceAttribute(start, size) - self.ser.addOperator(op['op'], [a.name], [result_tens.name], attr) + self.ser.addOperator(op['op'], input_list, output_list, attr) return result_tens def build_tile(self, op, a, multiples): @@ -3417,7 +3776,11 @@ class TosaTestGen: } return filterDict elif testType == 'negative': - validator_info = validator(check=False, op=op) + if validator is not None: + validator_info = validator(check=False, op=op) + else: + return None + error_arguments = validator_info['param_reqs'] #Set parameters as required @@ -3473,6 +3836,8 @@ class TosaTestGen: error_name = None filterDict = self.create_filter_lists(op, shapeFilter, rankFilter, dtypeFilter, testType, validator) + if filterDict == None: + return [] cleanRankFilter = filterDict['rankFilter'] cleanDtypeFilter = filterDict['dtypeFilter'] cleanShapeFilter = filterDict['shapeFilter'] @@ -4377,15 +4742,20 @@ class TosaTestGen: "pad": { "op": Op.PAD, "operands": (1, 0), + "rank": (1, 5), "build_fcn": (build_pad, TosaTensorGen.tgBasic, TosaArgGen.agPad), "qgen": TosaQuantGen.qgPad, "types": TYPE_FIB, + "error_if_validators": (TosaErrorValidator.evInputZeroPointNotZero, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evPadSmallerZero, + TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList) }, "reshape": { "op": Op.RESHAPE, "operands": (1, 0), "build_fcn": (build_reshape, TosaTensorGen.tgBasic, TosaArgGen.agReshape), "types": TYPE_FIB, + "error_if_validators": (TosaErrorValidator.evTensorSizeInputOutputMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, + TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList) }, "reverse": { "op": Op.REVERSE, @@ -4396,8 +4766,12 @@ class TosaTestGen: "slice": { "op": Op.SLICE, "operands": (1, 0), + "rank": (1, 4), "build_fcn": (build_slice, TosaTensorGen.tgBasic, TosaArgGen.agSlice), "types": TYPE_FIB, + "error_if_validators": (TosaErrorValidator.evStartSmallerZero, TosaErrorValidator.evSizeSmallerEqualZero, TosaErrorValidator.evStartSizeOutsideBounds, + TosaErrorValidator.evSizeOutputShapeMismatch, TosaErrorValidator.evInputSizeStartLengthMismatch, TosaErrorValidator.evWrongRank, + TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList) }, "tile": { "op": Op.TILE, @@ -4415,6 +4789,8 @@ class TosaTestGen: TosaArgGen.agTranspose, ), "types": TYPE_FIB, + "error_if_validators": (TosaErrorValidator.evIndexOutsideBounds, TosaErrorValidator.evIndexUsedTwice, TosaErrorValidator.evWrongRank, + TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList) }, # Data nodes "const": { @@ -4862,17 +5238,28 @@ class OutputShaper: return ser.addOutput(output_shape, input1.dtype) @staticmethod - def padOp(ser, a, padding): + def padOp(ser, rng, a, padding, error_name=None): output_shape = a.shape.copy() for i in range(len(output_shape)): output_shape[i] = padding[i][0] + padding[i][1] + output_shape[i] - return ser.addOutput(output_shape, a.dtype) + # Fix negative output shape if error_if test causes it + if error_name == ErrorIf.PadSmallerZero and min(output_shape) < 1: + output_shape = [i if i >= 1 else 1 for i in output_shape] + + if error_name == ErrorIf.WrongOutputType: + all_dtypes = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT] + wrong_dtypes = list(set(all_dtypes) - set([a.dtype])) + outputDType = rng.choice(wrong_dtypes) + else: + outputDType = a.dtype + + return ser.addOutput(output_shape, outputDType) @staticmethod - def reshapeOp(ser, a, shape): + def reshapeOp(ser, rng, a, shape, error_name=None): output_shape = shape.copy() totalElements = 1 @@ -4890,13 +5277,40 @@ class OutputShaper: if output_shape[i] == -1: output_shape[i] = totalElements // totalOutputElements - return ser.addOutput(output_shape, a.dtype) + if error_name == ErrorIf.TensorSizeInputOutputMismatch: + for i in range(len(output_shape)): + output_shape[i] = output_shape[i] + rng.integers(1, 10) + + if error_name == ErrorIf.WrongOutputType: + all_dtypes = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT] + wrong_dtypes = list(set(all_dtypes) - set([a.dtype])) + outputDType = rng.choice(wrong_dtypes) + else: + outputDType = a.dtype + + return ser.addOutput(output_shape, outputDType) @staticmethod - def sliceOp(ser, a, begin, size): + def sliceOp(ser, rng, a, start, size, error_name=None): - output_shape = size.copy() - return ser.addOutput(output_shape, a.dtype) + if error_name == ErrorIf.WrongOutputType: + all_dtypes = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT] + wrong_dtypes = list(set(all_dtypes) - set([a.dtype])) + outputDType = rng.choice(wrong_dtypes) + else: + outputDType = a.dtype + + if error_name == ErrorIf.SizeOutputShapeMismatch: + output_shape = size.copy() + for index in range(len(output_shape)): + if output_shape[index] <= 2: + output_shape[index] = output_shape[index] + rng.choice([1, 2]) + else: + output_shape[index] = output_shape[index] + rng.choice([-2, -1, 1, 2]) + else: + output_shape = size.copy() + + return ser.addOutput(output_shape, outputDType) @staticmethod def tileOp(ser, a, multiples): @@ -4910,14 +5324,26 @@ class OutputShaper: return ser.addOutput(output_shape, a.dtype) @staticmethod - def transposeOp(ser, a, perms): + def transposeOp(ser, rng, a, perms, error_name=None): output_shape = a.shape.copy() + assert len(perms) == len(output_shape) - for i in range(len(output_shape)): - output_shape[i] = a.shape[perms[i]] + if error_name == ErrorIf.IndexOutsideBounds: + for i in range(len(output_shape)): + output_shape[i] = a.shape[0] + else: + for i in range(len(output_shape)): + output_shape[i] = a.shape[perms[i]] - return ser.addOutput(output_shape, a.dtype) + if error_name == ErrorIf.WrongOutputType: + all_dtypes = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT] + wrong_dtypes = list(set(all_dtypes) - set([a.dtype])) + outputDType = rng.choice(wrong_dtypes) + else: + outputDType = a.dtype + + return ser.addOutput(output_shape, outputDType) @staticmethod def gatherOp(ser, values, indices): -- cgit v1.2.1