diff options
Diffstat (limited to 'verif')
-rw-r--r-- | verif/checker/tosa_result_checker.py | 16 | ||||
-rw-r--r-- | verif/conformance/tosa_main_profile_ops_info.json | 5 | ||||
-rw-r--r-- | verif/generator/tosa_arg_gen.py | 22 | ||||
-rw-r--r-- | verif/generator/tosa_test_gen.py | 69 |
4 files changed, 88 insertions, 24 deletions
diff --git a/verif/checker/tosa_result_checker.py b/verif/checker/tosa_result_checker.py index 6948378..212c809 100644 --- a/verif/checker/tosa_result_checker.py +++ b/verif/checker/tosa_result_checker.py @@ -1,5 +1,5 @@ """TOSA result checker script.""" -# Copyright (c) 2020-2023, ARM Limited. +# Copyright (c) 2020-2024, ARM Limited. # SPDX-License-Identifier: Apache-2.0 import argparse import json @@ -55,9 +55,9 @@ def _print_result(color, msg): def compliance_check( - imp_result_path, - ref_result_path, - bnd_result_path, + imp_result_data, + ref_result_data, + bnd_result_data, test_name, compliance_config, ofm_name, @@ -78,14 +78,18 @@ def compliance_check( return (TestResult.INTERNAL_ERROR, 0.0, msg) success = vlib.verify_data( - ofm_name, compliance_config, imp_result_path, ref_result_path, bnd_result_path + ofm_name, compliance_config, imp_result_data, ref_result_data, bnd_result_data ) if success: _print_result(LogColors.GREEN, f"Compliance Results PASS {test_name}") return (TestResult.PASS, 0.0, "") else: _print_result(LogColors.RED, f"Results NON-COMPLIANT {test_name}") - return (TestResult.MISMATCH, 0.0, "Non-compliance results found") + return ( + TestResult.MISMATCH, + 0.0, + f"Non-compliance results found for {ofm_name}", + ) def test_check( diff --git a/verif/conformance/tosa_main_profile_ops_info.json b/verif/conformance/tosa_main_profile_ops_info.json index 5e35e8b..067fab7 100644 --- a/verif/conformance/tosa_main_profile_ops_info.json +++ b/verif/conformance/tosa_main_profile_ops_info.json @@ -980,6 +980,7 @@ "profile": [ "tosa-mi" ], + "support_for": [ "lazy_data_gen" ], "generation": { "standard": { "generator_args": [ @@ -987,13 +988,13 @@ "--target-dtype", "fp32", "--fp-values-range", - "-2.0,2.0" + "-max,max" ], [ "--target-dtype", "fp32", "--fp-values-range", - "-2.0,2.0", + "-max,max", "--target-shape", "1,256,64", "--target-shape", diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py index b4939da..f6a46b4 100644 --- a/verif/generator/tosa_arg_gen.py +++ b/verif/generator/tosa_arg_gen.py @@ -2798,9 +2798,27 @@ class TosaArgGen: def agFFT2d(testGen, opName, shapeList, dtype, error_name=None): arg_list = [] - arg_list.append(("inverseTrue", [True])) - arg_list.append(("inverseFalse", [False])) + shape = shapeList[0] + dot_products = gtu.product(shape) + ks = 2 * shape[1] * shape[2] # 2*H*W + for inverse in (True, False): + args_dict = { + "dot_products": dot_products, + "shape": shape, + "ks": ks, + "acc_type": dtype, + "inverse": inverse, + } + arg_list.append((f"inverse{inverse}", args_dict)) + arg_list = TosaArgGen._add_data_generators( + testGen, + opName, + dtype, + arg_list, + error_name, + ) + # Return list of tuples: (arg_str, args_dict) return arg_list # Helper function for reshape. Gets some factors of a larger number. diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py index bfafd23..68a4e94 100644 --- a/verif/generator/tosa_test_gen.py +++ b/verif/generator/tosa_test_gen.py @@ -381,8 +381,35 @@ class TosaTestGen: """Enhanced build information containing result tensor and associated compliance dict.""" def __init__(self, resultTensor, complianceDict): - self.resultTensor = resultTensor - self.complianceDict = complianceDict + if isinstance(resultTensor, list): + assert complianceDict is None or isinstance(complianceDict, list) + self.resultTensorList = resultTensor + self.complianceDictList = complianceDict + else: + self.resultTensorList = [resultTensor] + if complianceDict is None: + self.complianceDictList = None + else: + self.complianceDictList = [complianceDict] + + def getComplianceInfo(self): + if self.complianceDictList is None: + return None + else: + tens_dict = {} + for tens, comp in zip(self.resultTensorList, self.complianceDictList): + if comp is not None: + tens_dict[tens.name] = comp + + if tens_dict: + # Have some compliance data, so return the info + compliance = { + "version": "0.1", + "tensors": tens_dict, + } + else: + compliance = None + return compliance def build_unary( self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None @@ -2491,12 +2518,16 @@ class TosaTestGen: def build_fft2d( self, op, - val1, - val2, - inverse, + inputs, + args_dict, validator_fcns=None, error_name=None, + qinfo=None, ): + assert len(inputs) == 2 + val1, val2 = inputs + inverse = args_dict["inverse"] + results = OutputShaper.fft2dOp(self.ser, self.rng, val1, val2, error_name) input_names = [val1.name, val2.name] @@ -2537,7 +2568,16 @@ class TosaTestGen: attr.FFTAttribute(inverse, local_bound) self.ser.addOperator(op["op"], input_names, output_names, attr) - return results + + compliance = [] + for res in results: + compliance.append( + self.tensorComplianceMetaData( + op, val1.dtype, args_dict, res, error_name + ) + ) + + return TosaTestGen.BuildInfo(results, compliance) def build_rfft2d( self, @@ -2933,13 +2973,11 @@ class TosaTestGen: if result: # The test is valid, serialize it - if isinstance(result, TosaTestGen.BuildInfo) and result.complianceDict: - # Add the compliance meta data - # NOTE: This currently expects only one result output - tensMeta["compliance"] = { - "version": "0.1", - "tensors": {result.resultTensor.name: result.complianceDict}, - } + if isinstance(result, TosaTestGen.BuildInfo): + # Add the compliance meta data (if any) + compliance = result.getComplianceInfo() + if compliance: + tensMeta["compliance"] = compliance self.serialize("test", tensMeta) else: # The test is not valid @@ -4708,7 +4746,7 @@ class TosaTestGen: "build_fcn": ( build_fft2d, TosaTensorGen.tgFFT2d, - TosaTensorValuesGen.tvgDefault, + TosaTensorValuesGen.tvgLazyGenDefault, TosaArgGen.agFFT2d, ), "types": [DType.FP32], @@ -4723,6 +4761,9 @@ class TosaTestGen: TosaErrorValidator.evFFTInputShapeMismatch, TosaErrorValidator.evFFTOutputShapeMismatch, ), + "data_gen": { + "fp": (gtu.DataGenType.DOT_PRODUCT,), + }, }, "rfft2d": { "op": Op.RFFT2D, |