From 7e9ac9ab74ff2a793e226abf86d2543d1421d3c9 Mon Sep 17 00:00:00 2001 From: Jeremy Johnson Date: Mon, 8 Nov 2021 18:10:51 +0000 Subject: Add Broadcast DimensionMismatch errors Add RankMismatch and DimensionMismatch support for SELECT Update RankMismatch ops to also support DimensionMismatch Update POW op to have proper broadcast testing A few other broadcastable ops missing Rank/Dimension testing Change-Id: I6566f45a7a0db4f9f008456ea7a8e23d4192f4f9 Signed-off-by: Jeremy Johnson --- verif/tosa_error_if.py | 1 + verif/tosa_test_gen.py | 132 ++++++++++++++++++++++++++++++++++--------------- 2 files changed, 92 insertions(+), 41 deletions(-) diff --git a/verif/tosa_error_if.py b/verif/tosa_error_if.py index c3a9068..eb67ea8 100644 --- a/verif/tosa_error_if.py +++ b/verif/tosa_error_if.py @@ -30,6 +30,7 @@ class ErrorIf(object): BatchMismatch = "BatchMismatch" ChannelMismatch = "ChannelMismatch" RankMismatch = "RankMismatch" + DimensionMismatch = "DimensionMismatch" InputZeroPointNotZero = "InputZeroPointNotZero" WeightZeroPointNotZero = "WeightZeroPointNotZero" OutputZeroPointNotZero = "OutputZeroPointNotZero" diff --git a/verif/tosa_test_gen.py b/verif/tosa_test_gen.py index 80ccff3..db44328 100644 --- a/verif/tosa_test_gen.py +++ b/verif/tosa_test_gen.py @@ -270,21 +270,26 @@ class TosaTensorGen: shape_list = [] # Choose one of the inputs to broadcast - bcast_idx = testGen.randInt(0, pl + const) + # Note: Simplifies OutputShaper code if we don't change first shape for errors + bcast_idx = testGen.randInt(0 if error_name == None else 1, pl + const) for i in range(pl + const): shape_bcast = shape.copy() - if error_name == ErrorIf.RankMismatch: - bcast_idx = -1 # Turn off broadcast because we are not testing it - if rank == 1 and i != 1: - shape_bcast = testGen.makeShape(rank + testGen.rng.choice([1, 2, 3])) - elif i != 1: - shape_bcast = testGen.makeShape(rank + testGen.rng.choice([-1, 1])) - # If the chosen input, pick a random index to broadcast if i == bcast_idx: fuzz_idx = testGen.randInt(0, rank) - shape_bcast[fuzz_idx] = 1 + if error_name == ErrorIf.DimensionMismatch: + shape_bcast[fuzz_idx] += 1 + elif error_name == ErrorIf.RankMismatch: + # Add one rank to the shape (or more for rank of 1) + extra_ranks = testGen.rng.choice([1, 2, 3]) if rank == 1 else 1 + shape_bcast = np.concatenate((shape_bcast, testGen.makeShape(extra_ranks))) + if rank != 1: + # Either keep the extra rank, or remove it + new_len = testGen.rng.choice([-2, len(shape_bcast)]) + shape_bcast = shape_bcast[:new_len] + else: + shape_bcast[fuzz_idx] = 1 shape_list.append(shape_bcast) @@ -2001,8 +2006,14 @@ class TosaErrorValidator: if check: input1_shape = kwargs['input1'].shape input2_shape = kwargs['input2'].shape + # In case of SELECT op + input3_shape = kwargs['input3'].shape if 'input3' in kwargs else input2_shape output_shape = kwargs['result_tensor'].shape - if (len(input1_shape) != len(output_shape)) or (len(input2_shape) != len(output_shape)): + if ( + (len(input1_shape) != len(output_shape)) or + (len(input2_shape) != len(output_shape)) or + (len(input3_shape) != len(output_shape)) + ): error_result = True info_dict = { @@ -2013,6 +2024,35 @@ class TosaErrorValidator: } return info_dict + @staticmethod + def evDimensionMismatch(check=False, **kwargs): + error_name = ErrorIf.DimensionMismatch + param_reqs = {"rank": None, "dtype": None, "shape": None} + error_result = False + error_reason = "Input Dimensions do not match output" + + if check: + input1_shape = kwargs['input1'].shape + input2_shape = kwargs['input2'].shape + # In case of SELECT op + input3_shape = kwargs['input3'].shape if 'input3' in kwargs else input2_shape + output_shape = kwargs['result_tensor'].shape + for i in range(min(len(input1_shape), len(input2_shape), len(input3_shape))): + if ( + (input1_shape[i] != 1 and input1_shape[i] != output_shape[i]) or + (input2_shape[i] != 1 and input2_shape[i] != output_shape[i]) or + (input3_shape[i] != 1 and input3_shape[i] != output_shape[i]) + ): + error_result = True + + info_dict = { + "error_name": error_name, + "error_result": error_result, + "error_reason": error_reason, + "param_reqs": param_reqs + } + return info_dict + @staticmethod def evInputZeroPointNotZero(check=False, **kwargs): op = kwargs['op'] @@ -3492,6 +3532,9 @@ class TosaTestGen: validator_fcns, error_name, op=op, + input1 = cond, + input2 = a, + input3 = b, input_shape = a.shape, input_dtype = a.dtype, output_dtype = result_tens.dtype, @@ -3519,6 +3562,8 @@ class TosaTestGen: validator_fcns, error_name, op=op, + input1 = a, + input2 = b, input_shape = a.shape, input_dtype = a.dtype, output_shape = result_tens.shape, @@ -5019,7 +5064,7 @@ class TosaTestGen: ) tens.extend(placeholders) - elif op["op"] == Op.MUL: + elif op["op"] == Op.MUL and error_name == None: assert ( pCount == 2 and cCount == 0 ), "Op.MUL must have 2 placeholders, 0 consts" @@ -5363,7 +5408,7 @@ class TosaTestGen: "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None), "types": TYPE_FI32, "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, - TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList) + TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch) }, "arithmetic_right_shift": { "op": Op.ARITHMETIC_RIGHT_SHIFT, @@ -5374,8 +5419,8 @@ class TosaTestGen: TosaArgGen.agArithmeticRightShift, ), "types": TYPE_INT, - "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, - TosaErrorValidator.evWrongOutputList) + "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, + TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch) }, "bitwise_and": { "op": Op.BITWISE_AND, @@ -5383,7 +5428,7 @@ class TosaTestGen: "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None), "types": TYPE_INT, "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, - TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList) + TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch) }, "bitwise_or": { "op": Op.BITWISE_OR, @@ -5391,7 +5436,7 @@ class TosaTestGen: "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None), "types": TYPE_INT, "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, - TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList) + TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch) }, "bitwise_xor": { "op": Op.BITWISE_XOR, @@ -5399,7 +5444,7 @@ class TosaTestGen: "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None), "types": TYPE_INT, "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, - TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList) + TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch) }, "intdiv": { "op": Op.INTDIV, @@ -5407,7 +5452,7 @@ class TosaTestGen: "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None), "types": [DType.INT32], "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, - TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList) + TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch) }, "logical_and": { "op": Op.LOGICAL_AND, @@ -5415,7 +5460,7 @@ class TosaTestGen: "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None), "types": TYPE_BOOL, "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, - TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList) + TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch) }, "logical_left_shift": { "op": Op.LOGICAL_LEFT_SHIFT, @@ -5423,7 +5468,7 @@ class TosaTestGen: "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None), "types": TYPE_INT, "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, - TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList) + TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch) }, "logical_right_shift": { "op": Op.LOGICAL_RIGHT_SHIFT, @@ -5431,7 +5476,7 @@ class TosaTestGen: "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None), "types": TYPE_INT, "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, - TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList) + TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch) }, "logical_or": { "op": Op.LOGICAL_OR, @@ -5439,7 +5484,7 @@ class TosaTestGen: "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None), "types": TYPE_BOOL, "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, - TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList) + TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch) }, "logical_xor": { "op": Op.LOGICAL_XOR, @@ -5447,7 +5492,7 @@ class TosaTestGen: "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None), "types": TYPE_BOOL, "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, - TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList) + TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch) }, "maximum": { "op": Op.MAXIMUM, @@ -5455,7 +5500,7 @@ class TosaTestGen: "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None), "types": TYPE_FI32, "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, - TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList) + TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch) }, "minimum": { "op": Op.MINIMUM, @@ -5463,7 +5508,7 @@ class TosaTestGen: "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None), "types": TYPE_FI32, "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, - TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList) + TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch) }, "mul": { "op": Op.MUL, @@ -5471,15 +5516,15 @@ class TosaTestGen: "build_fcn": (build_mul, TosaTensorGen.tgBroadcastFuzz, TosaArgGen.agMul), "types": TYPE_INT_FP, "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, - TosaErrorValidator.evWrongOutputList) + TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evRankMismatch, TosaErrorValidator.evDimensionMismatch) }, "pow": { "op": Op.POW, "operands": (2, 0), - "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBasic, None), + "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None), "types": TYPE_FP, "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, - TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList) + TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch) }, "sub": { "op": Op.SUB, @@ -5487,7 +5532,7 @@ class TosaTestGen: "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None), "types": TYPE_FI32, "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, - TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList) + TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch) }, "table": { "op": Op.TABLE, @@ -5597,8 +5642,8 @@ class TosaTestGen: "operands": (3, 0), "build_fcn": (build_select, TosaTensorGen.tgBroadcastFuzz, None), "types": TYPE_FIB, - "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, - TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList) + "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, + TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch) }, # Comparison operators "equal": { @@ -5606,24 +5651,24 @@ class TosaTestGen: "operands": (2, 0), "build_fcn": (build_comparison, TosaTensorGen.tgBroadcastFuzz, None), "types": TYPE_FI32, - "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, - TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList) + "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, + TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch) }, "greater_equal": { "op": Op.GREATER_EQUAL, "operands": (2, 0), "build_fcn": (build_comparison, TosaTensorGen.tgBroadcastFuzz, None), "types": TYPE_FI32, - "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, - TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList) + "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, + TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch) }, "greater": { "op": Op.GREATER, "operands": (2, 0), "build_fcn": (build_comparison, TosaTensorGen.tgBroadcastFuzz, None), "types": TYPE_FI32, - "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, - TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList) + "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, + TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch) }, # Reduction operators "reduce_all": { @@ -5916,12 +5961,16 @@ class OutputShaper: @staticmethod def selectOp(ser, rng, cond, a, b, error_name=None): - assert len(a.shape) == len(b.shape) and len(a.shape) == len(cond.shape) + if error_name != ErrorIf.RankMismatch: + assert len(a.shape) == len(b.shape) and len(a.shape) == len(cond.shape) assert a.dtype == b.dtype shape = [] - for i in range(len(a.shape)): - shape.append(max(cond.shape[i], a.shape[i], b.shape[i])) + for i in range(len(cond.shape)): + if cond.shape[i] == 1 and error_name == None: + shape.append(max(cond.shape[i], a.shape[i], b.shape[i])) + else: + shape.append(cond.shape[i]) if error_name == ErrorIf.WrongOutputType: all_dtypes = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT] @@ -5934,7 +5983,8 @@ class OutputShaper: @staticmethod def binaryComparisonOp(ser, rng, a, b , error_name=None): - assert len(a.shape) == len(b.shape) + if error_name != ErrorIf.RankMismatch: + assert len(a.shape) == len(b.shape) assert a.dtype == b.dtype # Do broadcast -- cgit v1.2.1