diff options
Diffstat (limited to 'verif/generator/tosa_test_gen.py')
-rw-r--r-- | verif/generator/tosa_test_gen.py | 153 |
1 files changed, 63 insertions, 90 deletions
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py index 9f65fd4..04093b8 100644 --- a/verif/generator/tosa_test_gen.py +++ b/verif/generator/tosa_test_gen.py @@ -405,13 +405,9 @@ class TosaTestGen: self.ser.addOperator(op["op"], input_list, output_list, attr) - if op["op"] in (Op.LOG,): - # TODO - add compliance support LOG - compliance = None - else: - compliance = self.tensorComplianceMetaData( - op, a.dtype, args_dict, result_tensor, error_name - ) + compliance = self.tensorComplianceMetaData( + op, a.dtype, args_dict, result_tensor, error_name + ) return TosaTestGen.BuildInfo(result_tensor, compliance) def build_binary_broadcast( @@ -1241,8 +1237,13 @@ class TosaTestGen: return TosaTestGen.BuildInfo(result_tensor, compliance) - def build_clamp(self, op, a, validator_fcns=None, error_name=None): - result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name) + def build_clamp( + self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None + ): + assert len(inputs) == 1 + a = inputs[0] + + result_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name) v = [self.getRandNumberDType(a.dtype), self.getRandNumberDType(a.dtype)] @@ -1258,7 +1259,7 @@ class TosaTestGen: # Invalidate Input/Output list for error if checks. input_list = [a.name] - output_list = [result_tens.name] + output_list = [result_tensor.name] pCount, cCount = op["operands"] num_operands = pCount + cCount input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList( @@ -1273,10 +1274,10 @@ class TosaTestGen: max_val=max_val, min_val=min_val, input_shape=a.shape, - output_shape=result_tens.shape, + output_shape=result_tensor.shape, input_dtype=a.dtype, - output_dtype=result_tens.dtype, - result_tensors=[result_tens], + output_dtype=result_tensor.dtype, + result_tensors=[result_tensor], input_list=input_list, output_list=output_list, num_operands=num_operands, @@ -1295,7 +1296,12 @@ class TosaTestGen: attr.ClampAttribute(self.ser.builder, min_val, max_val, 0, 0) self.ser.addOperator(op["op"], input_list, output_list, attr) - return result_tens + + compliance = self.tensorComplianceMetaData( + op, a.dtype, args_dict, result_tensor, error_name + ) + + return TosaTestGen.BuildInfo(result_tensor, compliance) def build_leaky_relu(self, op, a, validator_fcns=None, error_name=None): result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name) @@ -1313,43 +1319,17 @@ class TosaTestGen: self.ser.addOperator(op["op"], [a.name], [result_tens.name]) return result_tens - def build_sigmoid(self, op, a, validator_fcns=None, error_name=None): - result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name) - - # Invalidate Input/Output list for error if checks. - input_list = [a.name] - output_list = [result_tens.name] - pCount, cCount = op["operands"] - num_operands = pCount + cCount - input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList( - self, error_name, input_list, output_list - ) - - if not TosaErrorValidator.evValidateErrorIfs( - self.ser, - validator_fcns, - error_name, - op=op, - input_shape=a.shape, - output_shape=result_tens.shape, - input_dtype=a.dtype, - output_dtype=result_tens.dtype, - result_tensors=[result_tens], - input_list=input_list, - output_list=output_list, - num_operands=num_operands, - ): - return None - - self.ser.addOperator(op["op"], input_list, output_list) - return result_tens + def build_activation( + self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None + ): + assert len(inputs) == 1 + a = inputs[0] - def build_tanh(self, op, a, validator_fcns=None, error_name=None): - result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name) + result_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name) # Invalidate Input/Output list for error if checks. input_list = [a.name] - output_list = [result_tens.name] + output_list = [result_tensor.name] pCount, cCount = op["operands"] num_operands = pCount + cCount input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList( @@ -1362,10 +1342,10 @@ class TosaTestGen: error_name, op=op, input_shape=a.shape, - output_shape=result_tens.shape, + output_shape=result_tensor.shape, input_dtype=a.dtype, - output_dtype=result_tens.dtype, - result_tensors=[result_tens], + output_dtype=result_tensor.dtype, + result_tensors=[result_tensor], input_list=input_list, output_list=output_list, num_operands=num_operands, @@ -1373,38 +1353,12 @@ class TosaTestGen: return None self.ser.addOperator(op["op"], input_list, output_list) - return result_tens - - def build_erf(self, op, a, validator_fcns=None, error_name=None): - result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name) - # Invalidate Input/Output list for error if checks. - input_list = [a.name] - output_list = [result_tens.name] - pCount, cCount = op["operands"] - num_operands = pCount + cCount - input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList( - self, error_name, input_list, output_list + compliance = self.tensorComplianceMetaData( + op, a.dtype, args_dict, result_tensor, error_name ) - if not TosaErrorValidator.evValidateErrorIfs( - self.ser, - validator_fcns, - error_name, - op=op, - input_shape=a.shape, - output_shape=result_tens.shape, - input_dtype=a.dtype, - output_dtype=result_tens.dtype, - result_tensors=[result_tens], - input_list=input_list, - output_list=output_list, - num_operands=num_operands, - ): - return None - - self.ser.addOperator(op["op"], input_list, output_list) - return result_tens + return TosaTestGen.BuildInfo(result_tensor, compliance) def build_concat( self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None @@ -3220,8 +3174,8 @@ class TosaTestGen: "build_fcn": ( build_clamp, TosaTensorGen.tgBasic, - TosaTensorValuesGen.tvgDefault, - None, + TosaTensorValuesGen.tvgLazyGenDefault, + TosaArgGen.agNone, ), "types": TYPE_NARROW_INT_FP, "error_if_validators": ( @@ -3231,15 +3185,18 @@ class TosaTestGen: TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, ), + "data_gen": { + "fp": (gtu.DataGenType.PSEUDO_RANDOM,), + }, }, "sigmoid": { "op": Op.SIGMOID, "operands": (1, 0), "build_fcn": ( - build_sigmoid, + build_activation, TosaTensorGen.tgBasic, - TosaTensorValuesGen.tvgDefault, - None, + TosaTensorValuesGen.tvgLazyGenDefault, + TosaArgGen.agNone, ), "types": TYPE_FP, "error_if_validators": ( @@ -3248,15 +3205,19 @@ class TosaTestGen: TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, ), + "data_gen": { + "fp": (gtu.DataGenType.PSEUDO_RANDOM,), + }, + "compliance": {"ulp": 5}, }, "tanh": { "op": Op.TANH, "operands": (1, 0), "build_fcn": ( - build_tanh, + build_activation, TosaTensorGen.tgBasic, - TosaTensorValuesGen.tvgDefault, - None, + TosaTensorValuesGen.tvgLazyGenDefault, + TosaArgGen.agNone, ), "types": TYPE_FP, "error_if_validators": ( @@ -3265,15 +3226,19 @@ class TosaTestGen: TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, ), + "data_gen": { + "fp": (gtu.DataGenType.PSEUDO_RANDOM,), + }, + "compliance": {"ulp": 5}, }, "erf": { "op": Op.ERF, "operands": (1, 0), "build_fcn": ( - build_erf, + build_activation, TosaTensorGen.tgBasic, - TosaTensorValuesGen.tvgDefault, - None, + TosaTensorValuesGen.tvgLazyGenDefault, + TosaArgGen.agNone, ), "types": TYPE_FP, "error_if_validators": ( @@ -3282,6 +3247,10 @@ class TosaTestGen: TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, ), + "data_gen": { + "fp": (gtu.DataGenType.PSEUDO_RANDOM,), + }, + "compliance": {"ulp": 5}, }, # Elementwise Binary Operators "add": { @@ -3778,6 +3747,10 @@ class TosaTestGen: TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, ), + "data_gen": { + "fp": (gtu.DataGenType.PSEUDO_RANDOM,), + }, + "compliance": {"ulp": 5}, }, "logical_not": { "op": Op.LOGICAL_NOT, |