diff options
Diffstat (limited to 'verif/generator/tosa_test_gen.py')
-rw-r--r-- | verif/generator/tosa_test_gen.py | 69 |
1 files changed, 55 insertions, 14 deletions
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py index bfafd23..68a4e94 100644 --- a/verif/generator/tosa_test_gen.py +++ b/verif/generator/tosa_test_gen.py @@ -381,8 +381,35 @@ class TosaTestGen: """Enhanced build information containing result tensor and associated compliance dict.""" def __init__(self, resultTensor, complianceDict): - self.resultTensor = resultTensor - self.complianceDict = complianceDict + if isinstance(resultTensor, list): + assert complianceDict is None or isinstance(complianceDict, list) + self.resultTensorList = resultTensor + self.complianceDictList = complianceDict + else: + self.resultTensorList = [resultTensor] + if complianceDict is None: + self.complianceDictList = None + else: + self.complianceDictList = [complianceDict] + + def getComplianceInfo(self): + if self.complianceDictList is None: + return None + else: + tens_dict = {} + for tens, comp in zip(self.resultTensorList, self.complianceDictList): + if comp is not None: + tens_dict[tens.name] = comp + + if tens_dict: + # Have some compliance data, so return the info + compliance = { + "version": "0.1", + "tensors": tens_dict, + } + else: + compliance = None + return compliance def build_unary( self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None @@ -2491,12 +2518,16 @@ class TosaTestGen: def build_fft2d( self, op, - val1, - val2, - inverse, + inputs, + args_dict, validator_fcns=None, error_name=None, + qinfo=None, ): + assert len(inputs) == 2 + val1, val2 = inputs + inverse = args_dict["inverse"] + results = OutputShaper.fft2dOp(self.ser, self.rng, val1, val2, error_name) input_names = [val1.name, val2.name] @@ -2537,7 +2568,16 @@ class TosaTestGen: attr.FFTAttribute(inverse, local_bound) self.ser.addOperator(op["op"], input_names, output_names, attr) - return results + + compliance = [] + for res in results: + compliance.append( + self.tensorComplianceMetaData( + op, val1.dtype, args_dict, res, error_name + ) + ) + + return TosaTestGen.BuildInfo(results, compliance) def build_rfft2d( self, @@ -2933,13 +2973,11 @@ class TosaTestGen: if result: # The test is valid, serialize it - if isinstance(result, TosaTestGen.BuildInfo) and result.complianceDict: - # Add the compliance meta data - # NOTE: This currently expects only one result output - tensMeta["compliance"] = { - "version": "0.1", - "tensors": {result.resultTensor.name: result.complianceDict}, - } + if isinstance(result, TosaTestGen.BuildInfo): + # Add the compliance meta data (if any) + compliance = result.getComplianceInfo() + if compliance: + tensMeta["compliance"] = compliance self.serialize("test", tensMeta) else: # The test is not valid @@ -4708,7 +4746,7 @@ class TosaTestGen: "build_fcn": ( build_fft2d, TosaTensorGen.tgFFT2d, - TosaTensorValuesGen.tvgDefault, + TosaTensorValuesGen.tvgLazyGenDefault, TosaArgGen.agFFT2d, ), "types": [DType.FP32], @@ -4723,6 +4761,9 @@ class TosaTestGen: TosaErrorValidator.evFFTInputShapeMismatch, TosaErrorValidator.evFFTOutputShapeMismatch, ), + "data_gen": { + "fp": (gtu.DataGenType.DOT_PRODUCT,), + }, }, "rfft2d": { "op": Op.RFFT2D, |