diff options
Diffstat (limited to 'verif/conformance/tosa_verif_conformance_generator.py')
-rw-r--r-- | verif/conformance/tosa_verif_conformance_generator.py | 281 |
1 files changed, 156 insertions, 125 deletions
diff --git a/verif/conformance/tosa_verif_conformance_generator.py b/verif/conformance/tosa_verif_conformance_generator.py index 433a33f..59e88bb 100644 --- a/verif/conformance/tosa_verif_conformance_generator.py +++ b/verif/conformance/tosa_verif_conformance_generator.py @@ -26,6 +26,7 @@ from pathlib import Path import conformance.model_files as cmf from conformance.test_select import Operator +from conformance.tosa_profiles import TosaProfiles from convert2conformance.convert2conformance import main as c2c_main from convert2conformance.convert2conformance import OUTPUT_TYPE_DEFAULT from convert2conformance.convert2conformance import OUTPUT_TYPES @@ -47,7 +48,7 @@ PROFILE_OPS_INFO = { "framework_tests": "tosa_main_profile_framework_ops_info.json", }, } -PROFILES_ALL = "all" +PROFILES_EXTENSIONS_ALL = "all" DEFAULT_SEED = 42 @@ -96,10 +97,22 @@ def _run_sh_command(args, cwd, full_cmd): return (rc.stdout, rc.stderr) +def _supports_for_enabled(profile_ext): + # The "supports_for" part of the config only works for MI and related extensions + # TODO - Update with TosaBI etc in future + return profile_ext in ( + TosaProfiles.TosaMI, + TosaProfiles.TosaExtFP8E4M3, + TosaProfiles.TosaExtFP8E5M2, + TosaProfiles.TosaExtBF16, + TosaProfiles.TosaExtFFT, + ) + + def build_op_tests( args, test_type, - profile, + profile_ext, operator, group, gen_args_list, @@ -115,7 +128,7 @@ def build_op_tests( Returns operator output directory """ build_tests_cmd = "tosa_verif_build_tests" - op_build_dir = args.build_dir / profile / group + op_build_dir = args.build_dir / profile_ext / group if gen_filter is None: gen_filter = f"^{operator}$" @@ -131,19 +144,19 @@ def build_op_tests( "--seed", str(args.random_seed), ] - if args.verbosity: build_cmd_base.append("-" + ("v" * args.verbosity)) if args.tests_list_file is not None: build_cmd_base.append("--list-tests") - if "lazy_data_gen" in supports and args.lazy_data_generation: - build_cmd_base.append("--lazy-data-generation") - if "stable_random_gen" in supports and not args.global_random_generation: - build_cmd_base.append("--stable-random-generation") - if "random_const_inputs" in supports: - build_cmd_base.append("--random-const-inputs") + if _supports_for_enabled(profile_ext): + if "lazy_data_gen" in supports and args.lazy_data_generation: + build_cmd_base.append("--lazy-data-generation") + if "stable_random_gen" in supports and not args.global_random_generation: + build_cmd_base.append("--stable-random-generation") + if "random_const_inputs" in supports: + build_cmd_base.append("--random-const-inputs") if "generator_select" in supports: if selector_info is None: @@ -252,11 +265,14 @@ def _get_all_tests_list(test_type, test_root_dir, operator): return tests -def generate_results(args, profile, operator, op_build_dir, supports=[], tests=None): +def generate_results( + args, profile_ext, operator, op_build_dir, supports=[], tests=None +): """Run tests on reference model and save result to the test directory.""" - if "lazy_data_gen" in supports and args.lazy_data_generation: - logger.info("Skipping running tests due to lazy data gen") - return + if _supports_for_enabled(profile_ext): + if "lazy_data_gen" in supports and args.lazy_data_generation: + logger.info("Skipping running tests due to lazy data gen") + return num_cores = args.num_cores @@ -320,11 +336,11 @@ def generate_results(args, profile, operator, op_build_dir, supports=[], tests=N def convert_tests( args, test_type, - profile, + profile_ext, operator, op_build_dir, output_dir, - op_profiles_list, + op_profiles_extensions_list, supports=[], tests=None, group=None, @@ -341,22 +357,23 @@ def convert_tests( c2c_args_base.extend(["--output-type", args.output_type]) # This op maybe in more than one profile - e.g. tosa_bi and tosa_mi # even if we are only producing tests for tosa_mi - for op_profile in op_profiles_list: + for op_profile in op_profiles_extensions_list: c2c_args_base.extend(["--profile", op_profile]) if tags is not None: for tag in tags: c2c_args_base.extend(["--tag", tag]) if args.framework_schema: c2c_args_base.extend(["--framework-schema", str(args.framework_schema)]) - if "lazy_data_gen" in supports and args.lazy_data_generation: - c2c_args_base.append("--lazy-data-generation") + if _supports_for_enabled(profile_ext): + if "lazy_data_gen" in supports and args.lazy_data_generation: + c2c_args_base.append("--lazy-data-generation") c2c_args_base.append("--output-directory") c2c_args_list = [] if not tests: tests = _get_all_tests_list(test_type, op_build_dir, operator) - logger.info(f"Converting all {profile} profile tests of type {test_type}") + logger.info(f"Converting all {profile_ext} profile tests of type {test_type}") # Controls if we copy the tests in their operator sub-directory or not output_dir_relative_pos = -1 if trim_op_subdir else -2 @@ -408,7 +425,7 @@ def convert_tests( def get_op_tests_selection( args, - profile, + profile_ext, operator, op_build_dir, selection_config, @@ -418,7 +435,11 @@ def get_op_tests_selection( """Use test picker to get subsection of tests generated.""" # Need a full copy of the config as the selector updates it config = copy.deepcopy(selection_config) - logger.info("Choosing {} tests".format(("negative" if negative else "positive"))) + logger.info( + "Choosing {} tests for {}".format( + ("negative" if negative else "positive"), profile_ext + ) + ) try: op = Operator.registry[operator]( op_build_dir, config, negative=negative, ignore_missing=ignore_missing @@ -517,15 +538,16 @@ def get_framework_tests_selection(args, operator, test_picks, op_build_dir): def parse_args(argv=None): """Parse the arguments.""" parser = argparse.ArgumentParser() - profiles = list(PROFILE_OPS_INFO.keys()) - profiles.append(PROFILES_ALL) + profiles = TosaProfiles.profiles() + profiles.append(PROFILES_EXTENSIONS_ALL) parser.add_argument( "--profile", dest="profile", choices=profiles, - default=profiles[0], + default=[TosaProfiles.TosaBI], type=str, - help=f"TOSA profile (default is {profiles[0]})", + nargs="*", + help=f"TOSA profile (default is {TosaProfiles.TosaBI})", ) parser.add_argument( "--operators", @@ -535,6 +557,15 @@ def parse_args(argv=None): help="The operator(s) to create tests for, if not supplied all tests will be created", ) parser.add_argument( + "--extension", + dest="extension", + choices=TosaProfiles.extensions() + [PROFILES_EXTENSIONS_ALL], + default=[], + type=str, + nargs="*", + help="TOSA extension(s) to create tests for, if not supplied all tests will be created", + ) + parser.add_argument( "--unit-tests", dest="unit_tests", choices=["operator", "framework", "both"], @@ -658,6 +689,13 @@ def parse_args(argv=None): help=f"Test parameters (ops info) JSON file directory (default is {script_dir})", ) parser.add_argument( + "--test-params-json-config", + "--config", + dest="param_config", + type=Path, + help="Test parameters (ops info) JSON file (overrides --test-param-json-directory)", + ) + parser.add_argument( "--convert-all-tests", action="store_true", help="Converts all tests instead of those picked by test_select", @@ -739,6 +777,9 @@ def parse_args(argv=None): args.param_json_dir = args.param_json_dir.absolute() + if args.param_config is not None: + args.param_config = args.param_config.absolute() + if args.unit_tests in ["framework", "both"]: logger.warning( "DEPRECATION - Framework tests are not part of TOSA conformance testing" @@ -827,86 +868,40 @@ def main(): # TODO: For tosa-mi should really generate tosa-bi profile as well # - for now leave it as subset instead of as superset (for testing) - if args.profile == PROFILES_ALL: - profiles = list(PROFILE_OPS_INFO.keys()) + if PROFILES_EXTENSIONS_ALL in args.profile: + profiles = TosaProfiles.profiles() else: - profiles = [args.profile] + profiles = args.profile + + if PROFILES_EXTENSIONS_ALL in args.extension: + extensions = TosaProfiles.extensions() + else: + extensions = args.extension + profileExtList = profiles + extensions + profileExtDone = [] try: - for profile in profiles: - print(f"Creating conformance tests for TOSA {profile} profile") + for profile_ext in profileExtList: # Framework unit tests if args.unit_tests in ["framework", "both"]: - logger.debug("Creating FRAMEWORK unit tests") - test_picks_file = ( - args.param_json_dir / PROFILE_OPS_INFO[profile]["framework_tests"] - ) - try: - with open(test_picks_file, "r") as fd: - test_picks = json.load(fd) - except Exception as e: - logger.error( - f"Couldn't load framework tests info - {test_picks_file}: {e}" - ) - return 1 - - operators = args.operators - if not operators: - # Create tests for all the operators - operators = list(test_picks.keys()) - - root_output_dir = ( - args.output_dir / "frameworks" / "tflite" / "operators" - ) - for op in operators: - logger.info(f"FRAMEWORK OP: {op}") - if op not in test_picks: - logger.warning( - f"Framework op {op} not found in {test_picks_file} - skipping" - ) - continue - - op_profiles_list = test_picks[op]["profile"] - if ( - args.profile != PROFILES_ALL - and args.profile not in op_profiles_list - ): - # Skip this operator as not part of the profile chosen - logger.debug(f"Skipping {op} as not part of {args.profile}") - continue - - logger.debug(f"Copying and renaming {op}") - framework_test_dir = copy_rename_framework_tests( - args, op, test_picks - ) - - if args.convert_all_tests: - logger.debug("Running and converting all framework tests") - framework_tests = None # Don't select any - else: - logger.debug("Running and converting selected framework tests") - framework_tests = get_framework_tests_selection( - args, op, test_picks, framework_test_dir - ) - convert_tests( - args, - "positive", - profile, - op, - framework_test_dir, - root_output_dir, - op_profiles_list, - tests=framework_tests, - trim_op_subdir=True, - ) + logger.error("Framework test support has been removed") # Operator unit tests if args.unit_tests in ["operator", "both"]: logger.debug("Creating OPERATOR unit tests") - test_params_file = ( - args.param_json_dir - / PROFILE_OPS_INFO[profile]["operator_test_params"] - ) + if args.param_config is None: + # Fall back to old method + if profile_ext in PROFILE_OPS_INFO: + config = PROFILE_OPS_INFO[profile_ext]["operator_test_params"] + test_params_file = args.param_json_dir / config + else: + logger.error( + "Extensions not supported in old conformance configs - skipping" + ) + continue + else: + test_params_file = args.param_config + try: with open(test_params_file, "r") as fd: test_params = json.load(fd) @@ -922,6 +917,10 @@ def main(): # Create tests for all the operators operators = list(test_params.keys()) + print( + f"Creating conformance tests for TOSA {profile_ext} profile/extension" + ) + for op in operators: logger.info(f"OPERATOR: {op}") if op not in test_params: @@ -930,30 +929,49 @@ def main(): ) continue - op_profiles_list = test_params[op]["profile"] - if ( - args.profile != PROFILES_ALL - and args.profile not in op_profiles_list - ): - # Skip this operator as not part of the profile chosen - logger.debug(f"Skipping {op} as not part of {args.profile}") - continue - operator_group = test_params[op]["group"] root_output_dir = args.output_dir / "operators" - supports = ( - test_params[op]["support_for"] - if "support_for" in test_params[op] - else [] - ) - gen_filter = ( - test_params[op]["gen_filter"] - if "gen_filter" in test_params[op] - else None - ) + supports = test_params[op].get("support_for", []) + gen_filter = test_params[op].get("gen_filter", None) + old_profile_info = test_params[op].get("profile", []) # Iterate through the generation groups selecting tests from each for gen_name, gen_dict in test_params[op]["generation"].items(): + supports_any = gen_dict.get("supports_any", []) + supports_all = gen_dict.get("supports_all", []) + + # Fall back for old configs + if not supports_all and not supports_any: + if not old_profile_info: + logger.error( + f"generator {gen_name} for {op} is missing supports_all/supports_any" + ) + raise (GenConformanceError()) + else: + supports_any = old_profile_info + + supported = supports_any + supports_all + + if profile_ext not in supported: + logger.info( + f"No match for profile/extension {profile_ext} for generation group {gen_name} - skipping" + ) + continue + + if any(p in supported for p in profileExtDone): + logger.info( + f"Already used this generator {gen_name} before - skipping" + ) + continue + + if profile_ext not in supports_any and not ( + len(supports_all) > 0 + and all(p in profileExtList for p in supports_all) + ): + logger.info( + f"Profile/extension {profile_ext} is not in {supports_any} or the profiles/extensions chosen do not meet all the requirements of {supports_all} - skipping" + ) + continue if not in_version(args.test_version, gen_dict): logger.warning( @@ -993,16 +1011,23 @@ def main(): selector_name = "default" else: selector_name = "default" + if selector_name not in test_params[op]["selection"]: logger.error( f"Could not find {selector_name} in selection dict for {op}" ) raise (GenConformanceError()) + if test_params[op]["selection"][selector_name].get( + "generator_select", False + ): + # Extend the support to include the new test selection in the generator + supports = supports + ["generator_select"] + op_build_dir = build_op_tests( args, test_type, - profile, + profile_ext, op, gen_name, gen_dict["generator_args"], @@ -1020,7 +1045,11 @@ def main(): if test_type in ["positive", "both"]: logger.info(f"Running and converting all {op} tests") generate_results( - args, profile, op, op_build_dir, supports=supports + args, + profile_ext, + op, + op_build_dir, + supports=supports, ) operator_test_list = None else: @@ -1053,7 +1082,7 @@ def main(): tests_gen, tests_gen2 = tee( get_op_tests_selection( args, - profile, + profile_ext, op, op_build_dir, selection_config, @@ -1062,7 +1091,7 @@ def main(): ) generate_results( args, - profile, + profile_ext, op, op_build_dir, supports=supports, @@ -1075,7 +1104,7 @@ def main(): operator_test_list.extend( get_op_tests_selection( args, - profile, + profile_ext, op, op_build_dir, selection_config, @@ -1086,22 +1115,24 @@ def main(): tags = ( [gen_name] if gen_name != STANDARD_GENERATOR_GROUP else None ) - output_dir = convert_tests( args, test_type, - profile, + profile_ext, op, op_build_dir, root_output_dir, - op_profiles_list, + supported, supports=supports, tests=operator_test_list, group=operator_group, tags=tags, ) if not args.keep_large_files: - check_op_tests(args, profile, op, output_dir) + check_op_tests(args, profile_ext, op, output_dir) + + profileExtDone.append(profile_ext) + except GenConformanceError: return 1 |