diff options
Diffstat (limited to 'verif/tosa_test_gen.py')
-rw-r--r-- | verif/tosa_test_gen.py | 140 |
1 files changed, 130 insertions, 10 deletions
diff --git a/verif/tosa_test_gen.py b/verif/tosa_test_gen.py index 8d69831..43b188d 100644 --- a/verif/tosa_test_gen.py +++ b/verif/tosa_test_gen.py @@ -475,11 +475,18 @@ class TosaArgGen: def agAxis(testGen, opName, shapeList, dtype, error_name=None): """Build the axis argument for operators that take a single axis""" axes = [] - shape = shapeList[0] - for a in range(0, len(shape)): - axes.append(("axis{}".format(a), [a])) + if error_name == ErrorIf.AxisSmallerZero: + small_axis = testGen.rng.integers(-5, 0) + axes.append(("axis{}".format(small_axis), [small_axis])) + elif error_name == ErrorIf.AxisLargerRank: + large_axis = testGen.rng.integers(len(shape) + 1, len(shape) + 10) + axes.append(("axis{}".format(large_axis), [large_axis])) + else: + for a in range(0, len(shape)): + axes.append(("axis{}".format(a), [a])) + return axes @staticmethod @@ -1715,6 +1722,70 @@ class TosaErrorValidator: } return info_dict + @staticmethod + def evAxisSmallerZero(check=False, **kwargs): + error_name = ErrorIf.AxisSmallerZero + param_reqs = {"rank": None, "dtype": None, "shape": None} + error_result = False + error_reason = "Axis smaller than zero" + + if check: + axis = kwargs['axis'] + if axis < 0: + error_result = True + + info_dict = { + "error_name": error_name, + "error_result": error_result, + "error_reason": error_reason, + "param_reqs": param_reqs + } + return info_dict + + + @staticmethod + def evAxisLargerRank(check=False, **kwargs): + error_name = ErrorIf.AxisLargerRank + param_reqs = {"rank": None, "dtype": None, "shape": None} + error_result = False + error_reason = "Axis larger than rank" + + if check: + axis = kwargs['axis'] + shape = kwargs['input_shape'] + if axis > len(shape): + error_result = True + + info_dict = { + "error_name": error_name, + "error_result": error_result, + "error_reason": error_reason, + "param_reqs": param_reqs + } + return info_dict + + + @staticmethod + def evShapeOfAxisNotOne(check=False, **kwargs): + error_name = ErrorIf.ShapeOfAxisNotOne + param_reqs = {"rank": None, "dtype": None, "shape": None} + error_result = False + error_reason = "shape[axis] is not equal to 1" + + if check: + axis = kwargs['axis'] + shape = kwargs['output_shape'] + if (0 <= axis < len(shape)) and shape[axis] != 1: + error_result = True + + info_dict = { + "error_name": error_name, + "error_result": error_result, + "error_reason": error_reason, + "param_reqs": param_reqs + } + return info_dict + class TosaInvalidValidator: @@ -2233,13 +2304,36 @@ class TosaTestGen: self.ser.addOperator(op['op'], [a.name, b.name], [result_tens.name], None, qinfo) return result_tens - def build_reduce(self, op, a, axis): - result_tens = OutputShaper.reduceOp(self.ser, a, axis) + def build_reduce(self, op, a, axis, validator_fcns, error_name=None): + result_tens = OutputShaper.reduceOp(self.ser, self.rng, a, axis, error_name) + + # Invalidate Input/Output list for error if checks. + input_list = [a.name] + output_list = [result_tens.name] + pCount, cCount = op["operands"] + num_operands = pCount + cCount + input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list) + + TosaErrorValidator.evValidateErrorIfs( + self.ser, + validator_fcns, + error_name, + op=op, + axis = axis, + input_shape = a.shape, + output_shape = result_tens.shape, + input_dtype = a.dtype, + output_dtype = result_tens.dtype, + result_tensor = result_tens, + input_list=input_list, + output_list=output_list, + num_operands=num_operands, + ) attr = ts.TosaSerializerAttribute() attr.AxisAttribute(axis) - self.ser.addOperator(op['op'], [a.name], result_tens.name, attr) + self.ser.addOperator(op['op'], input_list, output_list, attr) return result_tens def build_clamp(self, op, a): @@ -3606,36 +3700,54 @@ class TosaTestGen: "operands": (1, 0), "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis), "types": TYPE_BOOL, + "error_if_validators": (TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evShapeOfAxisNotOne, + TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank, + TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList) }, "reduce_any": { "op": Op.REDUCE_ANY, "operands": (1, 0), "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis), "types": TYPE_BOOL, + "error_if_validators": (TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evShapeOfAxisNotOne, + TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank, + TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList) }, "reduce_max": { "op": Op.REDUCE_MAX, "operands": (1, 0), "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis), "types": TYPE_INT_FP, + "error_if_validators": (TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evShapeOfAxisNotOne, + TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank, + TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList) }, "reduce_min": { "op": Op.REDUCE_MAX, "operands": (1, 0), "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis), "types": TYPE_INT_FP, + "error_if_validators": (TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evShapeOfAxisNotOne, + TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank, + TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList) }, "reduce_product": { "op": Op.REDUCE_PRODUCT, "operands": (1, 0), "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis), "types": TYPE_FP, + "error_if_validators": (TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evShapeOfAxisNotOne, + TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank, + TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList) }, "reduce_sum": { "op": Op.REDUCE_SUM, "operands": (1, 0), "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis), "types": TYPE_FI32, + "error_if_validators": (TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evShapeOfAxisNotOne, + TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank, + TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList) }, # Data layout operators "concat": { @@ -3865,13 +3977,21 @@ class OutputShaper: return ser.addOutput(shape, DType.BOOL) @staticmethod - def reduceOp(ser, a, axis): - + def reduceOp(ser, rng, a, axis, error_name=None): shape = a.shape.copy() + if error_name not in [ErrorIf.AxisSmallerZero, ErrorIf.AxisLargerRank, ErrorIf.ShapeOfAxisNotOne]: + shape[axis] = 1 + if error_name == ErrorIf.ShapeOfAxisNotOne and shape[axis] == 1: + shape[axis] = rng.integers(2, 10) - shape[axis] = 1 + if error_name == ErrorIf.WrongOutputType: + all_dtypes = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT] + wrong_dtypes = list(set(all_dtypes) - set([a.dtype])) + outputDType = rng.choice(wrong_dtypes) + else: + outputDType = a.dtype - return ser.addOutput(shape, a.dtype) + return ser.addOutput(shape, outputDType) @staticmethod def argmaxOp(ser, a, axis): |