diff options
Diffstat (limited to 'verif/generator/tosa_test_gen.py')
-rw-r--r-- | verif/generator/tosa_test_gen.py | 164 |
1 files changed, 155 insertions, 9 deletions
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py index 159ee83..b9352ac 100644 --- a/verif/generator/tosa_test_gen.py +++ b/verif/generator/tosa_test_gen.py @@ -167,9 +167,10 @@ class TosaTestGen: rng = (-128, 128) elif dtype == DType.INT16: rng = (-32768, 32768) - elif dtype in (DType.INT32, DType.SHAPE): - # restricting too large value for SHAPE + elif dtype == DType.INT32: rng = (-(1 << 31), (1 << 31)) + elif dtype == DType.SHAPE: + rng = tuple(self.args.tensor_shape_range[0:2]) elif dtype == DType.INT48: rng = (-(1 << 47), (1 << 47)) else: @@ -190,7 +191,7 @@ class TosaTestGen: if dtype == DType.BOOL: return np.bool_(self.rng.choice(a=[False, True], size=shape)) - elif dtype == DType.INT48: + elif dtype in (DType.INT48, DType.SHAPE): return np.int64(self.rng.integers(low=low, high=high, size=shape)) elif dtype in (DType.FP16, DType.BF16, DType.FP32): f_tensor = self.rng.uniform(low=low, high=high, size=shape) @@ -1399,7 +1400,10 @@ class TosaTestGen: def build_concat( self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None ): - axis = args_dict["axis"] + if op["op"] == Op.CONCAT_SHAPE: + axis = 0 + else: + axis = args_dict["axis"] if error_name != ErrorIf.WrongInputType: assert type(axis) == int @@ -1438,9 +1442,12 @@ class TosaTestGen: ): return None - attr = ts.TosaSerializerAttribute() - attr.AxisAttribute(axis) - + if op["op"] == Op.CONCAT: + attr = ts.TosaSerializerAttribute() + attr.AxisAttribute(axis) + else: + assert op["op"] == Op.CONCAT_SHAPE + attr = None self.ser.addOperator(op["op"], input_list, output_list, attr) compliance = self.tensorComplianceMetaData( @@ -2512,6 +2519,52 @@ class TosaTestGen: self.ser.addOperator(op["op"], input_names, output_names, attr) return results + def build_shape_op( + self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None + ): + assert len(inputs) == 2 + a, b = inputs + + result_tensor = OutputShaper.addShapeOp(self.ser, self.rng, a, b, error_name) + + # Invalidate Input/Output list for error if checks. + input_list = [a.name, b.name] + output_list = [result_tensor.name] + pCount, cCount = op["operands"] + num_operands = pCount + cCount + input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList( + self, error_name, input_list, output_list + ) + + if not TosaErrorValidator.evValidateErrorIfs( + self.ser, + validator_fcns, + error_name, + op=op, + input1=a, + input2=b, + input_shape=a.shape, + input_dtype=a.dtype, + output_shape=result_tensor.shape, + output_dtype=result_tensor.dtype, + result_tensors=[result_tensor], + input_list=input_list, + output_list=output_list, + num_operands=num_operands, + ): + return None + + self.ser.addOperator( + op["op"], + input_list, + output_list, + ) + compliance = self.tensorComplianceMetaData( + op, a.dtype, args_dict, result_tensor, error_name + ) + + return TosaTestGen.BuildInfo(result_tensor, compliance) + def create_filter_lists( self, op, shapeFilter, rankFilter, dtypeFilter, testType, validator=None ): @@ -2725,12 +2778,12 @@ class TosaTestGen: if isinstance(dtype_or_dtypeList, list): dtypeList = dtype_or_dtypeList - elif op["op"] == Op.CONCAT: + elif op["op"] in (Op.CONCAT, Op.CONCAT_SHAPE): dtypeList = [dtype_or_dtypeList] * len(shapeList) else: dtypeList = [dtype_or_dtypeList] * (num_operands) - if op["op"] != Op.CONCAT: + if op["op"] not in (Op.CONCAT, Op.CONCAT_SHAPE): assert ( len(shapeList) == num_operands ), "shapeList length {} must match number of operands {}".format( @@ -4605,6 +4658,78 @@ class TosaTestGen: TosaErrorValidator.evFFTOutputShapeMismatch, ), }, + # Shape + "add_shape": { + "op": Op.ADD_SHAPE, + "operands": (2, 0), + "build_fcn": ( + build_shape_op, + TosaTensorGen.tgShape, + TosaTensorValuesGen.tvgAddSub, + TosaArgGen.agNone, + ), + "types": [DType.SHAPE], + "error_if_validators": (TosaErrorValidator.evDimensionMismatch,), + }, + "sub_shape": { + "op": Op.SUB_SHAPE, + "operands": (2, 0), + "build_fcn": ( + build_shape_op, + TosaTensorGen.tgShape, + TosaTensorValuesGen.tvgAddSub, + TosaArgGen.agNone, + ), + "types": [DType.SHAPE], + "error_if_validators": (TosaErrorValidator.evDimensionMismatch,), + }, + "mul_shape": { + "op": Op.MUL_SHAPE, + "operands": (2, 0), + "build_fcn": ( + build_shape_op, + TosaTensorGen.tgShape, + TosaTensorValuesGen.tvgMul, + TosaArgGen.agNone, + ), + "types": [DType.SHAPE], + "error_if_validators": (TosaErrorValidator.evDimensionMismatch,), + }, + "div_shape": { + "op": Op.DIV_SHAPE, + "operands": (2, 0), + "build_fcn": ( + build_shape_op, + TosaTensorGen.tgShape, + TosaTensorValuesGen.tvgIntDiv, + TosaArgGen.agNone, + ), + "types": [DType.SHAPE], + "error_if_validators": (TosaErrorValidator.evDimensionMismatch,), + }, + "concat_shape": { + "op": Op.CONCAT_SHAPE, + "operands": (2, 0), + "build_fcn": ( + build_concat, + TosaTensorGen.tgConcat, + TosaTensorValuesGen.tvgConcat, + TosaArgGen.agNone, + ), + "types": [DType.SHAPE], + "error_if_validators": (), + }, + "const_shape": { + "op": Op.CONST_SHAPE, + "operands": (0, 1), + "build_fcn": ( + build_const, + TosaTensorGen.tgBasic, + TosaTensorValuesGen.tvgDefault, + None, + ), + "types": [DType.SHAPE], + }, } @@ -5524,3 +5649,24 @@ class OutputShaper: outputs.append(serializer.addOutput(output_shape, output_dtype)) outputs.append(serializer.addOutput(output_shape, output_dtype)) return outputs + + @staticmethod + def addShapeOp(ser, rng, a, b, error_name=None): + if error_name != ErrorIf.RankMismatch: + assert len(a.shape) == len(b.shape) + assert a.dtype == b.dtype + + shape = [] + for i in range(len(a.shape)): + shape.append(a.shape[i]) + + fuzz_idx = rng.integers(0, len(a.shape)) + if error_name == ErrorIf.DimensionMismatch: + shape[fuzz_idx] += 1 + + if error_name == ErrorIf.WrongOutputType: + wrong_dtypes = gtu.get_wrong_output_type(a.dtype) + outputDType = rng.choice(wrong_dtypes) + else: + outputDType = DType.SHAPE + return ser.addOutput(shape, outputDType) |