aboutsummaryrefslogtreecommitdiff
path: root/verif/conformance/tosa_verif_conformance_generator.py
diff options
context:
space:
mode:
Diffstat (limited to 'verif/conformance/tosa_verif_conformance_generator.py')
-rw-r--r--verif/conformance/tosa_verif_conformance_generator.py281
1 files changed, 156 insertions, 125 deletions
diff --git a/verif/conformance/tosa_verif_conformance_generator.py b/verif/conformance/tosa_verif_conformance_generator.py
index 433a33f..59e88bb 100644
--- a/verif/conformance/tosa_verif_conformance_generator.py
+++ b/verif/conformance/tosa_verif_conformance_generator.py
@@ -26,6 +26,7 @@ from pathlib import Path
import conformance.model_files as cmf
from conformance.test_select import Operator
+from conformance.tosa_profiles import TosaProfiles
from convert2conformance.convert2conformance import main as c2c_main
from convert2conformance.convert2conformance import OUTPUT_TYPE_DEFAULT
from convert2conformance.convert2conformance import OUTPUT_TYPES
@@ -47,7 +48,7 @@ PROFILE_OPS_INFO = {
"framework_tests": "tosa_main_profile_framework_ops_info.json",
},
}
-PROFILES_ALL = "all"
+PROFILES_EXTENSIONS_ALL = "all"
DEFAULT_SEED = 42
@@ -96,10 +97,22 @@ def _run_sh_command(args, cwd, full_cmd):
return (rc.stdout, rc.stderr)
+def _supports_for_enabled(profile_ext):
+ # The "supports_for" part of the config only works for MI and related extensions
+ # TODO - Update with TosaBI etc in future
+ return profile_ext in (
+ TosaProfiles.TosaMI,
+ TosaProfiles.TosaExtFP8E4M3,
+ TosaProfiles.TosaExtFP8E5M2,
+ TosaProfiles.TosaExtBF16,
+ TosaProfiles.TosaExtFFT,
+ )
+
+
def build_op_tests(
args,
test_type,
- profile,
+ profile_ext,
operator,
group,
gen_args_list,
@@ -115,7 +128,7 @@ def build_op_tests(
Returns operator output directory
"""
build_tests_cmd = "tosa_verif_build_tests"
- op_build_dir = args.build_dir / profile / group
+ op_build_dir = args.build_dir / profile_ext / group
if gen_filter is None:
gen_filter = f"^{operator}$"
@@ -131,19 +144,19 @@ def build_op_tests(
"--seed",
str(args.random_seed),
]
-
if args.verbosity:
build_cmd_base.append("-" + ("v" * args.verbosity))
if args.tests_list_file is not None:
build_cmd_base.append("--list-tests")
- if "lazy_data_gen" in supports and args.lazy_data_generation:
- build_cmd_base.append("--lazy-data-generation")
- if "stable_random_gen" in supports and not args.global_random_generation:
- build_cmd_base.append("--stable-random-generation")
- if "random_const_inputs" in supports:
- build_cmd_base.append("--random-const-inputs")
+ if _supports_for_enabled(profile_ext):
+ if "lazy_data_gen" in supports and args.lazy_data_generation:
+ build_cmd_base.append("--lazy-data-generation")
+ if "stable_random_gen" in supports and not args.global_random_generation:
+ build_cmd_base.append("--stable-random-generation")
+ if "random_const_inputs" in supports:
+ build_cmd_base.append("--random-const-inputs")
if "generator_select" in supports:
if selector_info is None:
@@ -252,11 +265,14 @@ def _get_all_tests_list(test_type, test_root_dir, operator):
return tests
-def generate_results(args, profile, operator, op_build_dir, supports=[], tests=None):
+def generate_results(
+ args, profile_ext, operator, op_build_dir, supports=[], tests=None
+):
"""Run tests on reference model and save result to the test directory."""
- if "lazy_data_gen" in supports and args.lazy_data_generation:
- logger.info("Skipping running tests due to lazy data gen")
- return
+ if _supports_for_enabled(profile_ext):
+ if "lazy_data_gen" in supports and args.lazy_data_generation:
+ logger.info("Skipping running tests due to lazy data gen")
+ return
num_cores = args.num_cores
@@ -320,11 +336,11 @@ def generate_results(args, profile, operator, op_build_dir, supports=[], tests=N
def convert_tests(
args,
test_type,
- profile,
+ profile_ext,
operator,
op_build_dir,
output_dir,
- op_profiles_list,
+ op_profiles_extensions_list,
supports=[],
tests=None,
group=None,
@@ -341,22 +357,23 @@ def convert_tests(
c2c_args_base.extend(["--output-type", args.output_type])
# This op maybe in more than one profile - e.g. tosa_bi and tosa_mi
# even if we are only producing tests for tosa_mi
- for op_profile in op_profiles_list:
+ for op_profile in op_profiles_extensions_list:
c2c_args_base.extend(["--profile", op_profile])
if tags is not None:
for tag in tags:
c2c_args_base.extend(["--tag", tag])
if args.framework_schema:
c2c_args_base.extend(["--framework-schema", str(args.framework_schema)])
- if "lazy_data_gen" in supports and args.lazy_data_generation:
- c2c_args_base.append("--lazy-data-generation")
+ if _supports_for_enabled(profile_ext):
+ if "lazy_data_gen" in supports and args.lazy_data_generation:
+ c2c_args_base.append("--lazy-data-generation")
c2c_args_base.append("--output-directory")
c2c_args_list = []
if not tests:
tests = _get_all_tests_list(test_type, op_build_dir, operator)
- logger.info(f"Converting all {profile} profile tests of type {test_type}")
+ logger.info(f"Converting all {profile_ext} profile tests of type {test_type}")
# Controls if we copy the tests in their operator sub-directory or not
output_dir_relative_pos = -1 if trim_op_subdir else -2
@@ -408,7 +425,7 @@ def convert_tests(
def get_op_tests_selection(
args,
- profile,
+ profile_ext,
operator,
op_build_dir,
selection_config,
@@ -418,7 +435,11 @@ def get_op_tests_selection(
"""Use test picker to get subsection of tests generated."""
# Need a full copy of the config as the selector updates it
config = copy.deepcopy(selection_config)
- logger.info("Choosing {} tests".format(("negative" if negative else "positive")))
+ logger.info(
+ "Choosing {} tests for {}".format(
+ ("negative" if negative else "positive"), profile_ext
+ )
+ )
try:
op = Operator.registry[operator](
op_build_dir, config, negative=negative, ignore_missing=ignore_missing
@@ -517,15 +538,16 @@ def get_framework_tests_selection(args, operator, test_picks, op_build_dir):
def parse_args(argv=None):
"""Parse the arguments."""
parser = argparse.ArgumentParser()
- profiles = list(PROFILE_OPS_INFO.keys())
- profiles.append(PROFILES_ALL)
+ profiles = TosaProfiles.profiles()
+ profiles.append(PROFILES_EXTENSIONS_ALL)
parser.add_argument(
"--profile",
dest="profile",
choices=profiles,
- default=profiles[0],
+ default=[TosaProfiles.TosaBI],
type=str,
- help=f"TOSA profile (default is {profiles[0]})",
+ nargs="*",
+ help=f"TOSA profile (default is {TosaProfiles.TosaBI})",
)
parser.add_argument(
"--operators",
@@ -535,6 +557,15 @@ def parse_args(argv=None):
help="The operator(s) to create tests for, if not supplied all tests will be created",
)
parser.add_argument(
+ "--extension",
+ dest="extension",
+ choices=TosaProfiles.extensions() + [PROFILES_EXTENSIONS_ALL],
+ default=[],
+ type=str,
+ nargs="*",
+ help="TOSA extension(s) to create tests for, if not supplied all tests will be created",
+ )
+ parser.add_argument(
"--unit-tests",
dest="unit_tests",
choices=["operator", "framework", "both"],
@@ -658,6 +689,13 @@ def parse_args(argv=None):
help=f"Test parameters (ops info) JSON file directory (default is {script_dir})",
)
parser.add_argument(
+ "--test-params-json-config",
+ "--config",
+ dest="param_config",
+ type=Path,
+ help="Test parameters (ops info) JSON file (overrides --test-param-json-directory)",
+ )
+ parser.add_argument(
"--convert-all-tests",
action="store_true",
help="Converts all tests instead of those picked by test_select",
@@ -739,6 +777,9 @@ def parse_args(argv=None):
args.param_json_dir = args.param_json_dir.absolute()
+ if args.param_config is not None:
+ args.param_config = args.param_config.absolute()
+
if args.unit_tests in ["framework", "both"]:
logger.warning(
"DEPRECATION - Framework tests are not part of TOSA conformance testing"
@@ -827,86 +868,40 @@ def main():
# TODO: For tosa-mi should really generate tosa-bi profile as well
# - for now leave it as subset instead of as superset (for testing)
- if args.profile == PROFILES_ALL:
- profiles = list(PROFILE_OPS_INFO.keys())
+ if PROFILES_EXTENSIONS_ALL in args.profile:
+ profiles = TosaProfiles.profiles()
else:
- profiles = [args.profile]
+ profiles = args.profile
+
+ if PROFILES_EXTENSIONS_ALL in args.extension:
+ extensions = TosaProfiles.extensions()
+ else:
+ extensions = args.extension
+ profileExtList = profiles + extensions
+ profileExtDone = []
try:
- for profile in profiles:
- print(f"Creating conformance tests for TOSA {profile} profile")
+ for profile_ext in profileExtList:
# Framework unit tests
if args.unit_tests in ["framework", "both"]:
- logger.debug("Creating FRAMEWORK unit tests")
- test_picks_file = (
- args.param_json_dir / PROFILE_OPS_INFO[profile]["framework_tests"]
- )
- try:
- with open(test_picks_file, "r") as fd:
- test_picks = json.load(fd)
- except Exception as e:
- logger.error(
- f"Couldn't load framework tests info - {test_picks_file}: {e}"
- )
- return 1
-
- operators = args.operators
- if not operators:
- # Create tests for all the operators
- operators = list(test_picks.keys())
-
- root_output_dir = (
- args.output_dir / "frameworks" / "tflite" / "operators"
- )
- for op in operators:
- logger.info(f"FRAMEWORK OP: {op}")
- if op not in test_picks:
- logger.warning(
- f"Framework op {op} not found in {test_picks_file} - skipping"
- )
- continue
-
- op_profiles_list = test_picks[op]["profile"]
- if (
- args.profile != PROFILES_ALL
- and args.profile not in op_profiles_list
- ):
- # Skip this operator as not part of the profile chosen
- logger.debug(f"Skipping {op} as not part of {args.profile}")
- continue
-
- logger.debug(f"Copying and renaming {op}")
- framework_test_dir = copy_rename_framework_tests(
- args, op, test_picks
- )
-
- if args.convert_all_tests:
- logger.debug("Running and converting all framework tests")
- framework_tests = None # Don't select any
- else:
- logger.debug("Running and converting selected framework tests")
- framework_tests = get_framework_tests_selection(
- args, op, test_picks, framework_test_dir
- )
- convert_tests(
- args,
- "positive",
- profile,
- op,
- framework_test_dir,
- root_output_dir,
- op_profiles_list,
- tests=framework_tests,
- trim_op_subdir=True,
- )
+ logger.error("Framework test support has been removed")
# Operator unit tests
if args.unit_tests in ["operator", "both"]:
logger.debug("Creating OPERATOR unit tests")
- test_params_file = (
- args.param_json_dir
- / PROFILE_OPS_INFO[profile]["operator_test_params"]
- )
+ if args.param_config is None:
+ # Fall back to old method
+ if profile_ext in PROFILE_OPS_INFO:
+ config = PROFILE_OPS_INFO[profile_ext]["operator_test_params"]
+ test_params_file = args.param_json_dir / config
+ else:
+ logger.error(
+ "Extensions not supported in old conformance configs - skipping"
+ )
+ continue
+ else:
+ test_params_file = args.param_config
+
try:
with open(test_params_file, "r") as fd:
test_params = json.load(fd)
@@ -922,6 +917,10 @@ def main():
# Create tests for all the operators
operators = list(test_params.keys())
+ print(
+ f"Creating conformance tests for TOSA {profile_ext} profile/extension"
+ )
+
for op in operators:
logger.info(f"OPERATOR: {op}")
if op not in test_params:
@@ -930,30 +929,49 @@ def main():
)
continue
- op_profiles_list = test_params[op]["profile"]
- if (
- args.profile != PROFILES_ALL
- and args.profile not in op_profiles_list
- ):
- # Skip this operator as not part of the profile chosen
- logger.debug(f"Skipping {op} as not part of {args.profile}")
- continue
-
operator_group = test_params[op]["group"]
root_output_dir = args.output_dir / "operators"
- supports = (
- test_params[op]["support_for"]
- if "support_for" in test_params[op]
- else []
- )
- gen_filter = (
- test_params[op]["gen_filter"]
- if "gen_filter" in test_params[op]
- else None
- )
+ supports = test_params[op].get("support_for", [])
+ gen_filter = test_params[op].get("gen_filter", None)
+ old_profile_info = test_params[op].get("profile", [])
# Iterate through the generation groups selecting tests from each
for gen_name, gen_dict in test_params[op]["generation"].items():
+ supports_any = gen_dict.get("supports_any", [])
+ supports_all = gen_dict.get("supports_all", [])
+
+ # Fall back for old configs
+ if not supports_all and not supports_any:
+ if not old_profile_info:
+ logger.error(
+ f"generator {gen_name} for {op} is missing supports_all/supports_any"
+ )
+ raise (GenConformanceError())
+ else:
+ supports_any = old_profile_info
+
+ supported = supports_any + supports_all
+
+ if profile_ext not in supported:
+ logger.info(
+ f"No match for profile/extension {profile_ext} for generation group {gen_name} - skipping"
+ )
+ continue
+
+ if any(p in supported for p in profileExtDone):
+ logger.info(
+ f"Already used this generator {gen_name} before - skipping"
+ )
+ continue
+
+ if profile_ext not in supports_any and not (
+ len(supports_all) > 0
+ and all(p in profileExtList for p in supports_all)
+ ):
+ logger.info(
+ f"Profile/extension {profile_ext} is not in {supports_any} or the profiles/extensions chosen do not meet all the requirements of {supports_all} - skipping"
+ )
+ continue
if not in_version(args.test_version, gen_dict):
logger.warning(
@@ -993,16 +1011,23 @@ def main():
selector_name = "default"
else:
selector_name = "default"
+
if selector_name not in test_params[op]["selection"]:
logger.error(
f"Could not find {selector_name} in selection dict for {op}"
)
raise (GenConformanceError())
+ if test_params[op]["selection"][selector_name].get(
+ "generator_select", False
+ ):
+ # Extend the support to include the new test selection in the generator
+ supports = supports + ["generator_select"]
+
op_build_dir = build_op_tests(
args,
test_type,
- profile,
+ profile_ext,
op,
gen_name,
gen_dict["generator_args"],
@@ -1020,7 +1045,11 @@ def main():
if test_type in ["positive", "both"]:
logger.info(f"Running and converting all {op} tests")
generate_results(
- args, profile, op, op_build_dir, supports=supports
+ args,
+ profile_ext,
+ op,
+ op_build_dir,
+ supports=supports,
)
operator_test_list = None
else:
@@ -1053,7 +1082,7 @@ def main():
tests_gen, tests_gen2 = tee(
get_op_tests_selection(
args,
- profile,
+ profile_ext,
op,
op_build_dir,
selection_config,
@@ -1062,7 +1091,7 @@ def main():
)
generate_results(
args,
- profile,
+ profile_ext,
op,
op_build_dir,
supports=supports,
@@ -1075,7 +1104,7 @@ def main():
operator_test_list.extend(
get_op_tests_selection(
args,
- profile,
+ profile_ext,
op,
op_build_dir,
selection_config,
@@ -1086,22 +1115,24 @@ def main():
tags = (
[gen_name] if gen_name != STANDARD_GENERATOR_GROUP else None
)
-
output_dir = convert_tests(
args,
test_type,
- profile,
+ profile_ext,
op,
op_build_dir,
root_output_dir,
- op_profiles_list,
+ supported,
supports=supports,
tests=operator_test_list,
group=operator_group,
tags=tags,
)
if not args.keep_large_files:
- check_op_tests(args, profile, op, output_dir)
+ check_op_tests(args, profile_ext, op, output_dir)
+
+ profileExtDone.append(profile_ext)
+
except GenConformanceError:
return 1