diff options
Diffstat (limited to 'verif/generator/tosa_test_gen.py')
-rw-r--r-- | verif/generator/tosa_test_gen.py | 107 |
1 files changed, 69 insertions, 38 deletions
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py index d1fe11d..35cd78f 100644 --- a/verif/generator/tosa_test_gen.py +++ b/verif/generator/tosa_test_gen.py @@ -352,28 +352,26 @@ class TosaTestGen: self.resultTensor = resultTensor self.complianceDict = complianceDict - def build_unary(self, op, a, validator_fcns=None, error_name=None, qinfo=None): - result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name) + def build_unary( + self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None + ): + assert len(inputs) == 1 + a = inputs[0] + result_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name) - # build_placeholder returns an int, ABS/other ops does not - if isinstance(op, int): - self.ser.addOperator(op, a.name, result_tens.name, None) - return result_tens - elif op["op"] == Op.IDENTITY: - self.ser.addOperator(op["op"], a.name, result_tens.name, None) - return result_tens + assert not isinstance(op, int) # Ensure new output type has correct qinfo if error_name == ErrorIf.WrongOutputType: - if result_tens.dtype not in [DType.INT8, DType.UINT8]: + if result_tensor.dtype not in [DType.INT8, DType.UINT8]: qinfo = [ TosaQuantGen.getZeroPoint(self, a.dtype), - TosaQuantGen.getZeroPoint(self, result_tens.dtype), + TosaQuantGen.getZeroPoint(self, result_tensor.dtype), ] # Invalidate Input/Output list for error if checks. input_list = [a.name] - output_list = [result_tens.name] + output_list = [result_tensor.name] pCount, cCount = op["operands"] num_operands = pCount + cCount input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList( @@ -386,9 +384,9 @@ class TosaTestGen: error_name, op=op, input_dtype=a.dtype, - output_dtype=result_tens.dtype, + output_dtype=result_tensor.dtype, qinfo=qinfo, - result_tensors=[result_tens], + result_tensors=[result_tensor], input_list=input_list, output_list=output_list, num_operands=num_operands, @@ -401,7 +399,15 @@ class TosaTestGen: attr.NegateAttribute(qinfo[0], qinfo[1]) self.ser.addOperator(op["op"], input_list, output_list, attr) - return result_tens + + if op["op"] in (Op.EXP, Op.LOG): + # TODO - add compliance support LOG and EXP + compliance = None + else: + compliance = self.tensorComplianceMetaData( + op, a.dtype, args_dict, result_tensor, error_name + ) + return TosaTestGen.BuildInfo(result_tensor, compliance) def build_binary_broadcast( self, op, inputs, args_dict, validator_fcns, error_name=None, qinfo=None @@ -3622,8 +3628,8 @@ class TosaTestGen: "build_fcn": ( build_unary, TosaTensorGen.tgBasic, - TosaTensorValuesGen.tvgDefault, - None, + TosaTensorValuesGen.tvgLazyGenDefault, + TosaArgGen.agNone, ), "types": TYPE_FI32, "error_if_validators": ( @@ -3632,6 +3638,9 @@ class TosaTestGen: TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, ), + "data_gen": { + "fp": (gtu.DataGenType.PSEUDO_RANDOM,), + }, }, "bitwise_not": { "op": Op.BITWISE_NOT, @@ -3639,8 +3648,8 @@ class TosaTestGen: "build_fcn": ( build_unary, TosaTensorGen.tgBasic, - TosaTensorValuesGen.tvgDefault, - None, + TosaTensorValuesGen.tvgLazyGenDefault, + TosaArgGen.agNone, ), "types": TYPE_INT, "error_if_validators": ( @@ -3656,8 +3665,8 @@ class TosaTestGen: "build_fcn": ( build_unary, TosaTensorGen.tgBasic, - TosaTensorValuesGen.tvgDefault, - None, + TosaTensorValuesGen.tvgLazyGenDefault, + TosaArgGen.agNone, ), "types": TYPE_FP, "error_if_validators": ( @@ -3666,6 +3675,10 @@ class TosaTestGen: TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, ), + "data_gen": { + "fp": (gtu.DataGenType.PSEUDO_RANDOM,), + }, + "compliance": {"ulp": 0.5}, }, "clz": { "op": Op.CLZ, @@ -3673,8 +3686,8 @@ class TosaTestGen: "build_fcn": ( build_unary, TosaTensorGen.tgBasic, - TosaTensorValuesGen.tvgDefault, - None, + TosaTensorValuesGen.tvgLazyGenDefault, + TosaArgGen.agNone, ), "types": [DType.INT32], "error_if_validators": ( @@ -3690,8 +3703,8 @@ class TosaTestGen: "build_fcn": ( build_unary, TosaTensorGen.tgBasic, - TosaTensorValuesGen.tvgDefault, - None, + TosaTensorValuesGen.tvgLazyGenDefault, + TosaArgGen.agNone, ), "types": TYPE_FP, "error_if_validators": ( @@ -3707,8 +3720,8 @@ class TosaTestGen: "build_fcn": ( build_unary, TosaTensorGen.tgBasic, - TosaTensorValuesGen.tvgDefault, - None, + TosaTensorValuesGen.tvgLazyGenDefault, + TosaArgGen.agNone, ), "types": TYPE_FP, "error_if_validators": ( @@ -3717,6 +3730,10 @@ class TosaTestGen: TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, ), + "data_gen": { + "fp": (gtu.DataGenType.PSEUDO_RANDOM,), + }, + "compliance": {"ulp": 0.5}, }, "log": { "op": Op.LOG, @@ -3724,8 +3741,8 @@ class TosaTestGen: "build_fcn": ( build_unary, TosaTensorGen.tgBasic, - TosaTensorValuesGen.tvgDefault, - None, + TosaTensorValuesGen.tvgLazyGenDefault, + TosaArgGen.agNone, ), "types": TYPE_FP, "error_if_validators": ( @@ -3741,8 +3758,8 @@ class TosaTestGen: "build_fcn": ( build_unary, TosaTensorGen.tgBasic, - TosaTensorValuesGen.tvgDefault, - None, + TosaTensorValuesGen.tvgLazyGenDefault, + TosaArgGen.agNone, ), "types": TYPE_BOOL, "error_if_validators": ( @@ -3759,7 +3776,7 @@ class TosaTestGen: build_unary, TosaTensorGen.tgBasic, TosaTensorValuesGen.tvgNegate, - None, + TosaArgGen.agNone, ), "qgen": TosaQuantGen.qgUnary, "types": TYPE_INT_FP, @@ -3771,6 +3788,9 @@ class TosaTestGen: TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, ), + "data_gen": { + "fp": (gtu.DataGenType.PSEUDO_RANDOM,), + }, }, "reciprocal": { "op": Op.RECIPROCAL, @@ -3778,8 +3798,8 @@ class TosaTestGen: "build_fcn": ( build_unary, TosaTensorGen.tgBasic, - TosaTensorValuesGen.tvgDefault, - None, + TosaTensorValuesGen.tvgLazyGenDefault, + TosaArgGen.agNone, ), "types": TYPE_FP, "error_if_validators": ( @@ -3788,6 +3808,10 @@ class TosaTestGen: TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, ), + "data_gen": { + "fp": (gtu.DataGenType.PSEUDO_RANDOM,), + }, + "compliance": {"ulp": 1.0}, }, "rsqrt": { "op": Op.RSQRT, @@ -3795,8 +3819,8 @@ class TosaTestGen: "build_fcn": ( build_unary, TosaTensorGen.tgBasic, - TosaTensorValuesGen.tvgDefault, - None, + TosaTensorValuesGen.tvgLazyGenDefault, + TosaArgGen.agNone, ), "types": TYPE_FP, "error_if_validators": ( @@ -3805,6 +3829,10 @@ class TosaTestGen: TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, ), + "data_gen": { + "fp": (gtu.DataGenType.PSEUDO_RANDOM,), + }, + "compliance": {"ulp": 2}, }, # Elementwise Ternary operators "select": { @@ -4220,10 +4248,13 @@ class TosaTestGen: "build_fcn": ( build_unary, TosaTensorGen.tgBasic, - TosaTensorValuesGen.tvgDefault, - None, + TosaTensorValuesGen.tvgLazyGenDefault, + TosaArgGen.agNone, ), "types": TYPE_FIB, + "data_gen": { + "fp": (gtu.DataGenType.PSEUDO_RANDOM,), + }, }, # Scatter/Gather "gather": { |